2016-11-28 21:07:19 +00:00
|
|
|
import nltk
|
|
|
|
import operator
|
2016-05-02 02:34:43 +00:00
|
|
|
import os
|
|
|
|
import pickle
|
|
|
|
import random
|
2016-11-28 21:07:19 +00:00
|
|
|
import re
|
|
|
|
import codecs
|
2017-03-14 05:03:23 +00:00
|
|
|
import sys
|
2016-05-02 02:34:43 +00:00
|
|
|
from nltk.tree import Tree
|
|
|
|
from collections import defaultdict
|
|
|
|
from tqdm import tqdm
|
|
|
|
from stat_parser import Parser
|
|
|
|
|
|
|
|
# Maps a flattened syntax signature (see syntax_signature) to the set of
# parsed subtrees that were observed with that signature.
syntaxes = defaultdict(set)

# Pickle cache files written and read by generate().
SYNTAXES_FILE, CFDS_FILE = 'syntaxes.p', 'cfds.p'
|
2016-05-02 02:34:43 +00:00
|
|
|
|
|
|
|
|
|
|
|
def tree_hash(self):
    """Hash a tree by its sequence of leaf tokens only.

    Used as ``Tree.__hash__`` so parsed trees can be stored in the sets held
    by ``syntaxes``. Trees with the same words but a different structure get
    the same hash.
    """
    leaf_tokens = tuple(self.leaves())
    return hash(leaf_tokens)
|
|
|
|
|
2016-11-28 21:07:19 +00:00
|
|
|
|
2016-05-02 02:34:43 +00:00
|
|
|
# Monkeypatch nltk's Tree so instances can be added to the sets in `syntaxes`
# (see syntax_signature_recurse); hashing is by leaf tokens via tree_hash.
setattr(Tree, '__hash__', tree_hash)
|
|
|
|
|
|
|
|
|
2016-11-28 21:07:19 +00:00
|
|
|
# NOTE (to self): I need to replace the nltk-based parsing and tokenization with spaCy. spaCy is much faster and
# less detailed, which is actually a plus here. The catch is that spaCy does not build a phrase-structure (syntax)
# tree like nltk does. It does build a dependency tree, though, which might be good enough for splitting sentences
# into chunks that can be swapped between corpora. (Unreliable bus wifi made it hard to download the spaCy data and
# look up the docs when I wrote this.)
|
|
|
|
|
|
|
|
|
2017-03-14 05:03:23 +00:00
|
|
|
def generate(filename):
    """Rewrite a random Austen sentence using syntactic fragments from a corpus.

    On the first run this builds two pickle caches in the working directory:

    * ``SYNTAXES_FILE``: the global ``syntaxes`` map, filled by calling
      ``syntax_signature(..., save=True)`` on the parse of each of the first
      1500 sentences of ``filename`` that are shorter than 150 characters.
      Sentences the parser rejects with ``TypeError`` are skipped.
    * ``CFDS_FILE``: a list of conditional frequency dicts built by
      ``make_cfd`` over the whole text of ``filename`` for n = 2, 3 and 4
      (index i holds the (i + 2)-gram table).

    Later runs load both caches and do not read ``filename`` again. Delete the
    pickle files to rebuild them from a different corpus.

    It then picks a random sentence of fewer than 50 tokens from nltk's
    Gutenberg ``austen-emma.txt`` and parses it. It rewrites the parse with
    ``tree_replace`` and prints the original and rewritten sentence and tree.

    :param filename: path to a large UTF-8 raw-text corpus. The corpus is not
        kept in source control; provide your own.
    """
    global syntaxes
    parser = Parser()

    if not os.path.exists(SYNTAXES_FILE):
        with codecs.open(filename, encoding='utf-8') as corpus:
            sents = nltk.sent_tokenize(corpus.read())
        sents = [sent for sent in sents if len(sent) < 150][0:1500]
        for sent in tqdm(sents):
            try:
                parsed = parser.parse(sent)
            except TypeError:
                # Skip sentences the parser cannot handle. Falling through
                # would re-save the previous sentence's tree, or raise a
                # NameError if the very first sentence fails.
                continue
            syntax_signature(parsed, save=True)
        with open(SYNTAXES_FILE, 'wb') as pickle_file:
            pickle.dump(syntaxes, pickle_file)
    else:
        # NOTE: pickle.load can execute arbitrary code; only load cache files
        # that this script wrote itself.
        with open(SYNTAXES_FILE, 'rb') as pickle_file:
            syntaxes = pickle.load(pickle_file)

    if not os.path.exists(CFDS_FILE):
        with codecs.open(filename, encoding='utf-8') as corpus:
            # Read the corpus exactly once. Calling corpus.read() per n would
            # return '' after the first call and leave the 3-gram and 4-gram
            # tables empty.
            text = corpus.read()
        cfds = [make_cfd(text, i, exclude_punctuation=False, case_insensitive=True)
                for i in range(2, 5)]
        with open(CFDS_FILE, 'wb') as pickle_file:
            pickle.dump(cfds, pickle_file)
    else:
        with open(CFDS_FILE, 'rb') as pickle_file:
            cfds = pickle.load(pickle_file)

    # Pick a short target sentence (lists of tokens) and rewrite its parse.
    sents = nltk.corpus.gutenberg.sents('austen-emma.txt')
    sents = [sent for sent in sents if len(sent) < 50]
    sent = random.choice(sents)
    parsed = parser.parse(' '.join(sent))
    print(parsed)
    print(' '.join(parsed.leaves()))

    replaced_tree = tree_replace(parsed, cfds, [])
    print('=' * 30)
    print(' '.join(replaced_tree.leaves()))
    print(replaced_tree)
|
|
|
|
|
|
|
|
|
|
|
|
def list_to_string(l):
    """Render ``l`` with ``str`` and strip every space and single quote.

    For example, ``['NP', ['DT', 'NN']]`` becomes ``[NP,[DT,NN]]``. The result
    is used as the compact string form of a syntax signature.
    """
    rendered = str(l)
    for unwanted in (" ", "'"):
        rendered = rendered.replace(unwanted, "")
    return rendered
|
|
|
|
|
|
|
|
|
|
|
|
def syntax_signature(tree, save=False):
    """Return the syntax signature of ``tree`` as a compact string key.

    Delegates to ``syntax_signature_recurse`` and forwards ``save``, so with
    ``save=True`` every subtree is also recorded in ``syntaxes``. The nested
    result is flattened with ``list_to_string``, giving either a bare label or
    a string such as ``[NP,[DT,NN]]``.
    """
    nested = syntax_signature_recurse(tree, save=save)
    return list_to_string(nested)
|
|
|
|
|
|
|
|
|
|
|
|
def syntax_signature_recurse(tree, save=False):
    """Compute the label-only structure of ``tree``, optionally recording subtrees.

    A node with no ``Tree`` children (typically a part-of-speech node over a
    word) has its label as its signature. Any other node has
    ``[label, [child signatures...]]``. Leaf strings are ignored, so the
    signature describes shape, not words.

    :param tree: must be exactly an ``nltk.tree.Tree`` (not a subclass).
    :param save: when true, every visited subtree is added to the global
        ``syntaxes`` map. A childless node is keyed by its bare label; any
        other node is keyed by its signature flattened with ``list_to_string``.
    :returns: the label string or the nested ``[label, children]`` list.
    :raises ValueError: if ``tree`` is not an ``nltk.tree.Tree``.
    """
    if type(tree) is not Tree:
        raise ValueError('Not a nltk.tree.Tree: {}'.format(tree))

    raw_label = tree.label()
    # Rename ',' (presumably so it cannot be confused with the list
    # separators once list_to_string flattens the signature).
    label = 'COMMA' if raw_label == ',' else raw_label

    # Children are processed (and saved, if requested) before this node.
    children = [syntax_signature_recurse(node, save=save)
                for node in tree if type(node) is Tree]

    if children:
        signature = [label, children]
        key = list_to_string(signature)
    else:
        signature = label
        key = label

    if save:
        syntaxes[key].add(tree)
    return signature
|
|
|
|
|
|
|
|
|
2016-11-28 21:07:19 +00:00
|
|
|
def tree_replace(tree, cfds, preceding_children=None):
    """Return ``tree`` with subtrees swapped for fragments recorded in ``syntaxes``.

    If ``tree``'s syntax signature is a key of the global ``syntaxes`` map, the
    whole subtree is replaced by one of the fragments stored under it. When
    there are several fragments and the preceding context is non-empty, the
    function prefers fragments whose text is among the continuations predicted
    by ``get_most_common``; otherwise it picks a fragment at random. If the
    signature is unknown, it recurses into the ``Tree`` children and returns
    ``tree`` unchanged when there are none.

    :param tree: the ``nltk.tree.Tree`` to rewrite.
    :param cfds: list of conditional frequency dicts as built by ``make_cfd``,
        where index i holds the (i + 2)-gram table.
    :param preceding_children: subtrees that precede ``tree`` in the sentence.
        Their lowercased words form the n-gram context. Defaults to no context.
    :returns: a ``Tree``. This may be ``tree`` itself or a fragment object
        shared with ``syntaxes``.
    """
    if preceding_children is None:
        preceding_children = []

    sig = syntax_signature(tree)
    if sig in syntaxes:
        matching_fragments = tuple(syntaxes[sig])
        if len(matching_fragments) > 1:
            # Build the context string only when it can influence the choice.
            condition_search = ' '.join(
                ' '.join(child.leaves()) for child in preceding_children).lower()
            if condition_search:
                likely_words = set(get_most_common(condition_search, cfds))
                by_text = {' '.join(frag.leaves()): frag for frag in matching_fragments}
                candidates = [text for text in by_text if text in likely_words]
                if candidates:
                    # Return the stored fragment itself (rather than a Tree with
                    # the joined text as a single leaf) so its internal structure
                    # is kept. NOTE: this picks randomly among likely candidates;
                    # always taking the top-ranked continuation is an alternative.
                    return by_text[random.choice(candidates)]
        return random.choice(matching_fragments)

    # Unknown signature: rewrite each Tree child. Each child's context is the
    # outer context plus the siblings that precede it.
    children = [tree_replace(child, cfds, preceding_children + list(tree[0:i]))
                for i, child in enumerate(tree) if type(child) is Tree]
    if not children:
        # Nothing below this node can be replaced; keep it as is.
        return tree
    return Tree(tree.label(), children)
|
|
|
|
|
|
|
|
|
2016-11-28 21:07:19 +00:00
|
|
|
# TODO: move this into a separate class or module; I need to be able to reuse this method
# across all of my NLP experiments. See the notes in this repo for more detail.
|
|
|
|
def make_cfd(text, n, cfd=None, exclude_punctuation=True, case_insensitive=True):
    """Count n-gram continuations in ``text`` as a conditional frequency dict.

    The text is split into sentences with ``nltk.sent_tokenize`` and each
    sentence into tokens with ``nltk.word_tokenize``, so n-grams never cross a
    sentence boundary. For every n-gram, the first ``n - 1`` tokens joined by
    single spaces form the condition and the last token is the sample. The
    result has the shape ``{condition: {sample: count}}``.

    :param text: raw text to count.
    :param n: n-gram order, at least 1. With n == 1 the only condition is ''.
    :param cfd: existing dict to add counts to. It is updated in place and
        returned. A new dict is created when omitted.
    :param exclude_punctuation: drop tokens that are not made up entirely of
        word characters before counting.
    :param case_insensitive: lowercase tokens before counting.
    :returns: the populated conditional frequency dict.
    :raises ValueError: if ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError('n must be at least 1, got {}'.format(n))
    # `is None`, not truthiness: a caller-supplied empty dict must be filled
    # in place rather than silently replaced.
    if cfd is None:
        cfd = {}
    # Raw string: '\w' in a plain string literal is an invalid escape sequence.
    nopunct = re.compile(r'^\w+$') if exclude_punctuation else None

    for sent in nltk.sent_tokenize(text):
        tokens = nltk.word_tokenize(sent)
        if case_insensitive:
            tokens = [word.lower() for word in tokens]
        if nopunct is not None:
            tokens = [word for word in tokens if nopunct.match(word)]
        for i in range(len(tokens) - (n - 1)):
            condition = ' '.join(tokens[i:i + n - 1])
            sample = tokens[i + n - 1]
            counts = cfd.setdefault(condition, {})
            counts[sample] = counts.get(sample, 0) + 1
    return cfd
|
|
|
|
|
|
|
|
|
|
|
|
def get_most_common(search, cfds, most_common=None):
    """List likely next words after ``search``, backing off from longer n-grams.

    ``cfds[i]`` is treated as the (i + 2)-gram table, keyed on the last
    ``i + 1`` words of context. Tables are consulted from the highest order
    down. Each matching table contributes its samples in descending count
    order, skipping words already in the result.

    :param search: context text. Words are split on any whitespace.
    :param cfds: list of ``{condition: {sample: count}}`` dicts as built by
        ``make_cfd``.
    :param most_common: optional list to append to. It is extended in place
        and returned, and its existing entries are never repeated.
    :returns: the list of candidate words, best first.
    """
    # `is None`, not truthiness: a caller-supplied empty list must be extended
    # in place rather than silently replaced.
    if most_common is None:
        most_common = []
    seen = set(most_common)
    # split() rather than split(' '): repeated spaces would otherwise produce
    # empty tokens and break the n-gram lookups.
    words = search.split()

    for i in reversed(range(len(cfds))):
        context_len = i + 1  # an (i + 2)-gram is conditioned on i + 1 words
        if len(words) < context_len:
            continue
        query = ' '.join(words[len(words) - context_len:])
        samples = cfds[i].get(query)
        if not samples:
            continue
        for word, _count in sorted(samples.items(), key=operator.itemgetter(1), reverse=True):
            if word not in seen:
                seen.add(word)
                most_common.append(word)
    return most_common
|
|
|
|
|
|
|
|
|
2016-05-02 02:34:43 +00:00
|
|
|
if __name__ == '__main__':
    # Expect exactly one argument: the path to the raw-text corpus. Print a
    # usage message instead of an IndexError traceback when it is missing.
    if len(sys.argv) != 2:
        sys.exit('usage: {} CORPUS_FILE'.format(sys.argv[0]))
    generate(sys.argv[1])
|