Overhaul syntax gen: dual cfd & signature approach

Still a WIP. Needs a faster tagging method (e.g. spacy).
This commit is contained in:
Tyler Hallada 2016-11-28 16:07:19 -05:00
parent 3ace25b6e2
commit 82a209c771

View File

@ -1,7 +1,10 @@
import nltk
import operator
import os import os
import pickle import pickle
import random import random
import nltk import re
import codecs
from nltk.tree import Tree from nltk.tree import Tree
from collections import defaultdict from collections import defaultdict
from tqdm import tqdm from tqdm import tqdm
@ -9,38 +12,61 @@ from stat_parser import Parser
syntaxes = defaultdict(set) syntaxes = defaultdict(set)
SYNTAXES_FILE = 'syntaxes.p' SYNTAXES_FILE = 'syntaxes.p'
CFDS_FILE = 'cfds.p'
def tree_hash(self): def tree_hash(self):
return hash(tuple(self.leaves())) return hash(tuple(self.leaves()))
Tree.__hash__ = tree_hash Tree.__hash__ = tree_hash
# NOTE: to me: I need to replace nltk parse and tokenization with spacy because it is much faster and less detailed
# which is actually a plus. The problem is that spacy does not create a syntax tree like nltk does. However, it does
# create a dependency tree, which might be good enough for splitting into chunks that can be swapped out between
# corpora. Shitty bus wifi makes it hard to download spacy data and look up the docs.
def generate(): def generate():
global syntaxes global syntaxes
parser = Parser() parser = Parser()
if not os.path.exists(SYNTAXES_FILE): if not os.path.exists(SYNTAXES_FILE):
sents = nltk.corpus.gutenberg.sents('melville-moby_dick.txt') # sents = nltk.corpus.gutenberg.sents('results.txt')
sents = sents[0:100] # NOTE: results.txt is a big file of raw text not included in source control, provide your own corpus.
for sent in tqdm(sents): with codecs.open('results.txt', encoding='utf-8') as corpus:
try: sents = nltk.sent_tokenize(corpus.read())
parsed = parser.parse(' '.join(sent)) sents = [sent for sent in sents if len(sent) < 150][0:1500]
except TypeError: for sent in tqdm(sents):
pass try:
syntax_signature(parsed, save=True) parsed = parser.parse(sent)
except TypeError:
pass
syntax_signature(parsed, save=True)
with open(SYNTAXES_FILE, 'wb+') as pickle_file: with open(SYNTAXES_FILE, 'wb+') as pickle_file:
pickle.dump(syntaxes, pickle_file) pickle.dump(syntaxes, pickle_file)
else: else:
with open(SYNTAXES_FILE, 'rb+') as pickle_file: with open(SYNTAXES_FILE, 'rb+') as pickle_file:
syntaxes = pickle.load(pickle_file) syntaxes = pickle.load(pickle_file)
if not os.path.exists(CFDS_FILE):
# corpus = nltk.corpus.gutenberg.raw('results.txt')
with codecs.open('results.txt', encoding='utf-8') as corpus:
cfds = [make_cfd(corpus.read(), i, exclude_punctuation=False, case_insensitive=True) for i in range(2, 5)]
with open(CFDS_FILE, 'wb+') as pickle_file:
pickle.dump(cfds, pickle_file)
else:
with open(CFDS_FILE, 'rb+') as pickle_file:
cfds = pickle.load(pickle_file)
sents = nltk.corpus.gutenberg.sents('austen-emma.txt') sents = nltk.corpus.gutenberg.sents('austen-emma.txt')
sents = [sent for sent in sents if len(sent) < 50]
sent = random.choice(sents) sent = random.choice(sents)
parsed = parser.parse(' '.join(sent)) parsed = parser.parse(' '.join(sent))
print(parsed) print(parsed)
print(' '.join(parsed.leaves())) print(' '.join(parsed.leaves()))
replaced_tree = tree_replace(parsed) replaced_tree = tree_replace(parsed, cfds, [])
print('='*30) print('=' * 30)
print(' '.join(replaced_tree.leaves())) print(' '.join(replaced_tree.leaves()))
print(replaced_tree) print(replaced_tree)
@ -72,12 +98,22 @@ def syntax_signature_recurse(tree, save=False):
raise ValueError('Not a nltk.tree.Tree: {}'.format(tree)) raise ValueError('Not a nltk.tree.Tree: {}'.format(tree))
def tree_replace(tree): def tree_replace(tree, cfds, preceding_children=[]):
condition_search = ' '.join([' '.join(child.leaves()) for child in preceding_children]).lower()
sig = syntax_signature(tree) sig = syntax_signature(tree)
if sig in syntaxes: if sig in syntaxes:
return random.choice(tuple(syntaxes[sig])) matching_fragments = tuple(syntaxes[sig])
if len(matching_fragments) > 1 and condition_search:
matching_leaves = [' '.join(frag.leaves()) for frag in matching_fragments]
most_common = get_most_common(condition_search, cfds)
candidates = list(set(matching_leaves).intersection(set(most_common)))
if candidates:
return Tree(tree.label(), [random.choice(candidates)])
# find the first element of get_most_common that is also in this list of matching_leaves
return random.choice(matching_fragments)
else: else:
children = [tree_replace(child) for child in tree if type(child) is Tree] children = [tree_replace(child, cfds, preceding_children + tree[0:i])
for i, child in enumerate(tree) if type(child) is Tree]
if not children: if not children:
# unable to replace this leaf # unable to replace this leaf
return tree return tree
@ -85,5 +121,48 @@ def tree_replace(tree):
return Tree(tree.label(), children) return Tree(tree.label(), children)
# TODO: this part should definitely be in a different class or module. I need to be able to resuse this method
# among all of my nlp expirements. See notes in this repo for more detail.
def make_cfd(text, n, cfd=None, exclude_punctuation=True, case_insensitive=True):
if not cfd:
cfd = {}
if exclude_punctuation:
nopunct = re.compile('^\w+$')
sentences = nltk.sent_tokenize(text)
for sent in sentences:
sent = nltk.word_tokenize(sent)
if case_insensitive:
sent = [word.lower() for word in sent]
if exclude_punctuation:
sent = [word for word in sent if nopunct.match(word)]
for i in range(len(sent) - (n - 1)):
condition = ' '.join(sent[i:(i + n) - 1])
sample = sent[(i + n) - 1]
if condition in cfd:
if sample in cfd[condition]:
cfd[condition][sample] += 1
else:
cfd[condition].update({sample: 1})
else:
cfd[condition] = {sample: 1}
return cfd
def get_most_common(search, cfds, most_common=None):
if not most_common:
most_common = list()
words = search.split(' ')
for i in reversed(range(len(cfds))):
n = i + 2
if len(words) >= (n - 1):
query = ' '.join(words[len(words) - (n - 1):])
if query in cfds[i]:
most_common.extend([entry[0] for entry in sorted(cfds[i][query].items(),
key=operator.itemgetter(1),
reverse=True)
if entry[0] not in most_common])
return most_common
if __name__ == '__main__': if __name__ == '__main__':
generate() generate()