diff --git a/syntax_aware_generate.py b/syntax_aware_generate.py
index a9c38c6..1c7545a 100644
--- a/syntax_aware_generate.py
+++ b/syntax_aware_generate.py
@@ -1,7 +1,10 @@
+import nltk
+import operator
 import os
 import pickle
 import random
-import nltk
+import re
+import codecs
 from nltk.tree import Tree
 from collections import defaultdict
 from tqdm import tqdm
@@ -9,38 +12,61 @@ from stat_parser import Parser
 
 syntaxes = defaultdict(set)
 SYNTAXES_FILE = 'syntaxes.p'
+CFDS_FILE = 'cfds.p'
 
 
 def tree_hash(self):
     return hash(tuple(self.leaves()))
+
+
 Tree.__hash__ = tree_hash
+# NOTE to self: I need to replace the nltk parsing and tokenization with spacy because it is much faster and less
+# detailed, which is actually a plus. The problem is that spacy does not create a syntax tree like nltk does.
+# However, it does create a dependency tree, which might be good enough for splitting into chunks that can be
+# swapped out between corpora. Shitty bus wifi makes it hard to download the spacy data and look up the docs.
+
+
 def generate():
     global syntaxes
     parser = Parser()
     if not os.path.exists(SYNTAXES_FILE):
-        sents = nltk.corpus.gutenberg.sents('melville-moby_dick.txt')
-        sents = sents[0:100]
-        for sent in tqdm(sents):
-            try:
-                parsed = parser.parse(' '.join(sent))
-            except TypeError:
-                pass
-            syntax_signature(parsed, save=True)
+        # sents = nltk.corpus.gutenberg.sents('results.txt')
+        # NOTE: results.txt is a big file of raw text that is not included in source control; provide your own corpus.
+        with codecs.open('results.txt', encoding='utf-8') as corpus:
+            sents = nltk.sent_tokenize(corpus.read())
+        sents = [sent for sent in sents if len(sent) < 150][0:1500]
+        for sent in tqdm(sents):
+            try:
+                parsed = parser.parse(sent)
+            except TypeError:
+                continue  # skip sentences the parser cannot handle
+            syntax_signature(parsed, save=True)
         with open(SYNTAXES_FILE, 'wb+') as pickle_file:
             pickle.dump(syntaxes, pickle_file)
     else:
         with open(SYNTAXES_FILE, 'rb+') as pickle_file:
             syntaxes = pickle.load(pickle_file)
+
+    if not os.path.exists(CFDS_FILE):
+        # corpus = nltk.corpus.gutenberg.raw('results.txt')
+        with codecs.open('results.txt', encoding='utf-8') as corpus:
+            text = corpus.read()
+        cfds = [make_cfd(text, i, exclude_punctuation=False, case_insensitive=True) for i in range(2, 5)]
+        with open(CFDS_FILE, 'wb+') as pickle_file:
+            pickle.dump(cfds, pickle_file)
+    else:
+        with open(CFDS_FILE, 'rb+') as pickle_file:
+            cfds = pickle.load(pickle_file)
+
     sents = nltk.corpus.gutenberg.sents('austen-emma.txt')
+    sents = [sent for sent in sents if len(sent) < 50]
     sent = random.choice(sents)
     parsed = parser.parse(' '.join(sent))
     print(parsed)
     print(' '.join(parsed.leaves()))
-    replaced_tree = tree_replace(parsed)
-    print('='*30)
+    replaced_tree = tree_replace(parsed, cfds, [])
+    print('=' * 30)
     print(' '.join(replaced_tree.leaves()))
     print(replaced_tree)
 
@@ -72,12 +98,22 @@ def syntax_signature_recurse(tree, save=False):
         raise ValueError('Not a nltk.tree.Tree: {}'.format(tree))
 
 
-def tree_replace(tree):
+def tree_replace(tree, cfds, preceding_children=[]):
+    condition_search = ' '.join([' '.join(child.leaves()) for child in preceding_children]).lower()
     sig = syntax_signature(tree)
     if sig in syntaxes:
-        return random.choice(tuple(syntaxes[sig]))
+        matching_fragments = tuple(syntaxes[sig])
+        if len(matching_fragments) > 1 and condition_search:
+            matching_leaves = [' '.join(frag.leaves()) for frag in matching_fragments]
+            most_common = get_most_common(condition_search, cfds)
+            candidates = list(set(matching_leaves).intersection(set(most_common)))
+            if candidates:
+                return Tree(tree.label(), [random.choice(candidates)])
+        # TODO: prefer the first element of get_most_common that is also in matching_leaves over a random choice
+        return random.choice(matching_fragments)
     else:
-        children = [tree_replace(child) for child in tree if type(child) is Tree]
+        children = [tree_replace(child, cfds, preceding_children + tree[0:i])
+                    for i, child in enumerate(tree) if type(child) is Tree]
         if not children:
             # unable to replace this leaf
             return tree
@@ -85,5 +121,48 @@ def tree_replace(tree):
     return Tree(tree.label(), children)
 
 
+# TODO: this part should definitely be in a different class or module. I need to be able to reuse this method
+# among all of my nlp experiments. See the notes in this repo for more detail.
+def make_cfd(text, n, cfd=None, exclude_punctuation=True, case_insensitive=True):
+    if not cfd:
+        cfd = {}
+    if exclude_punctuation:
+        nopunct = re.compile(r'^\w+$')
+    sentences = nltk.sent_tokenize(text)
+    for sent in sentences:
+        sent = nltk.word_tokenize(sent)
+        if case_insensitive:
+            sent = [word.lower() for word in sent]
+        if exclude_punctuation:
+            sent = [word for word in sent if nopunct.match(word)]
+        for i in range(len(sent) - (n - 1)):
+            condition = ' '.join(sent[i:(i + n) - 1])
+            sample = sent[(i + n) - 1]
+            if condition in cfd:
+                if sample in cfd[condition]:
+                    cfd[condition][sample] += 1
+                else:
+                    cfd[condition].update({sample: 1})
+            else:
+                cfd[condition] = {sample: 1}
+    return cfd
+
+
+def get_most_common(search, cfds, most_common=None):
+    if not most_common:
+        most_common = list()
+    words = search.split(' ')
+    for i in reversed(range(len(cfds))):
+        n = i + 2
+        if len(words) >= (n - 1):
+            query = ' '.join(words[len(words) - (n - 1):])
+            if query in cfds[i]:
+                most_common.extend([entry[0] for entry in sorted(cfds[i][query].items(),
+                                                                 key=operator.itemgetter(1),
+                                                                 reverse=True)
+                                    if entry[0] not in most_common])
+    return most_common
+
+
 if __name__ == '__main__':
     generate()
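
For orientation, below is a minimal, self-contained sketch of the data layout that make_cfd builds and that get_most_common walks: a dict mapping each (n-1)-word condition to a dict of continuation counts, queried from the longest n-gram down to the bigram. It uses plain whitespace tokenization and a made-up two-sentence corpus instead of nltk and results.txt, so the tokens and counts are illustrative only, not output from this repo.

from collections import defaultdict


def toy_cfd(text, n):
    # Same layout as make_cfd: {condition (n-1 words): {next word: count}}, but with naive tokenization.
    cfd = defaultdict(lambda: defaultdict(int))
    words = text.lower().split()
    for i in range(len(words) - (n - 1)):
        condition = ' '.join(words[i:i + n - 1])
        cfd[condition][words[i + n - 1]] += 1
    return cfd


corpus = 'the cat sat on the mat . the cat ate the fish .'  # made-up toy corpus
cfds = [toy_cfd(corpus, n) for n in range(2, 5)]             # bigram..4-gram, mirroring range(2, 5) in generate()

# Query the way get_most_common does: try the longest condition first, then back off to shorter ones.
search = 'on the cat'                                        # hypothetical preceding words
for i in reversed(range(len(cfds))):
    n = i + 2
    condition = ' '.join(search.split()[-(n - 1):])
    if condition in cfds[i]:
        ranked = sorted(cfds[i][condition].items(), key=lambda kv: kv[1], reverse=True)
        print('{}-gram condition {!r} ->'.format(n, condition), ranked)

In generate(), the three pickled cfds play exactly this role: tree_replace joins the leaves of the already-visited preceding siblings into the search string, and get_most_common supplies the ranked continuations that are intersected with the syntactically matching fragments.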