Various scripts for playing around with natural language processing/generation

syntax_aware_generate.py 6.3KB

import nltk
import operator
import os
import pickle
import random
import re
import codecs
import sys

from nltk.tree import Tree
from collections import defaultdict
from tqdm import tqdm
from stat_parser import Parser

syntaxes = defaultdict(set)
SYNTAXES_FILE = 'syntaxes.p'
CFDS_FILE = 'cfds.p'


# Make trees hashable by their leaves so parsed fragments can be stored in sets.
def tree_hash(self):
    return hash(tuple(self.leaves()))

Tree.__hash__ = tree_hash

# NOTE to self: I need to replace the nltk parsing and tokenization with spacy because it is much
# faster and less detailed, which is actually a plus. The problem is that spacy does not create a
# syntax tree like nltk does. However, it does create a dependency tree, which might be good enough
# for splitting sentences into chunks that can be swapped out between corpora. Shitty bus wifi makes
# it hard to download the spacy data and look up the docs.
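#
# A rough sketch of what the spacy-based chunking might look like (hypothetical, not wired in yet;
# it assumes the 'en_core_web_sm' model has been downloaded):
#
#   import spacy
#   nlp = spacy.load('en_core_web_sm')
#   doc = nlp(sent)
#   for token in doc:
#       # every token carries a dependency label and a head, and token.subtree yields the
#       # contiguous span it governs, which could serve as the swappable chunk
#       print(token.text, token.dep_, token.head.text, [t.text for t in token.subtree])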


def generate(filename, word_limit=None):
    global syntaxes
    parser = Parser()
    if not os.path.exists(SYNTAXES_FILE):
        # Parse the corpus and index its syntax fragments by signature (cached in syntaxes.p).
        # sents = nltk.corpus.gutenberg.sents('results.txt')
        # NOTE: results.txt is a big file of raw text not included in source control; provide your own corpus.
        with codecs.open(filename, encoding='utf-8') as corpus:
            sents = nltk.sent_tokenize(corpus.read())
        if word_limit:
            sents = [sent for sent in sents if len(sent) < word_limit]
        sent_limit = min(1500, len(sents))
        sents = sents[0:sent_limit]
        for sent in tqdm(sents):
            try:
                parsed = parser.parse(sent)
            except TypeError:
                continue
            syntax_signature(parsed, save=True)
        with open(SYNTAXES_FILE, 'wb+') as pickle_file:
            pickle.dump(syntaxes, pickle_file)
    else:
        with open(SYNTAXES_FILE, 'rb+') as pickle_file:
            syntaxes = pickle.load(pickle_file)

    if not os.path.exists(CFDS_FILE):
        with codecs.open(filename, encoding='utf-8') as corpus:
            text = corpus.read()
        # build 2-, 3-, and 4-gram conditional frequency distributions (cached in cfds.p)
        cfds = [make_cfd(text, i, exclude_punctuation=False, case_insensitive=True) for i in range(2, 5)]
        with open(CFDS_FILE, 'wb+') as pickle_file:
            pickle.dump(cfds, pickle_file)
    else:
        with open(CFDS_FILE, 'rb+') as pickle_file:
            cfds = pickle.load(pickle_file)

    # Take a random sentence from Emma, parse it, and rebuild it from fragments of the user corpus.
    sents = nltk.corpus.gutenberg.sents('austen-emma.txt')
    if word_limit:
        sents = [sent for sent in sents if len(sent) < word_limit]
    sent = random.choice(sents)
    parsed = parser.parse(' '.join(sent))
    print(parsed)
    print(' '.join(parsed.leaves()))
    replaced_tree = tree_replace(parsed, cfds, [])
    print('=' * 30)
    print(' '.join(replaced_tree.leaves()))
    print(replaced_tree)


def list_to_string(l):
    return str(l).replace(" ", "").replace("'", "")


def syntax_signature(tree, save=False):
    return list_to_string(syntax_signature_recurse(tree, save=save))


def syntax_signature_recurse(tree, save=False):
    global syntaxes
    if type(tree) is Tree:
        label = tree.label()
        if label == ',':
            label = 'COMMA'
        children = [syntax_signature_recurse(child, save=save) for child in tree if type(child) is Tree]
        if not children:
            if save:
                syntaxes[label].add(tree)
            return label
        else:
            if save:
                syntaxes[list_to_string([label, children])].add(tree)
            return [label, children]
    else:
        raise ValueError('Not a nltk.tree.Tree: {}'.format(tree))


def tree_replace(tree, cfds, preceding_children=[]):
    condition_search = ' '.join([' '.join(child.leaves()) for child in preceding_children]).lower()
    sig = syntax_signature(tree)
    if sig in syntaxes:
        matching_fragments = tuple(syntaxes[sig])
        if len(matching_fragments) > 1 and condition_search:
            matching_leaves = [' '.join(frag.leaves()) for frag in matching_fragments]
            most_common = get_most_common(condition_search, cfds)
            candidates = list(set(matching_leaves).intersection(set(most_common)))
            if candidates:
                return Tree(tree.label(), [random.choice(candidates)])
        # find the first element of get_most_common that is also in this list of matching_leaves
        return random.choice(matching_fragments)
    else:
        children = [tree_replace(child, cfds, preceding_children + tree[0:i])
                    for i, child in enumerate(tree) if type(child) is Tree]
        if not children:
            # unable to replace this leaf
            return tree
        else:
            return Tree(tree.label(), children)


# TODO: this part should definitely be in a different class or module. I need to be able to reuse
# this method among all of my nlp experiments. See notes in this repo for more detail.
def make_cfd(text, n, cfd=None, exclude_punctuation=True, case_insensitive=True):
    if not cfd:
        cfd = {}
    if exclude_punctuation:
        nopunct = re.compile(r'^\w+$')
    sentences = nltk.sent_tokenize(text)
    for sent in sentences:
        sent = nltk.word_tokenize(sent)
        if case_insensitive:
            sent = [word.lower() for word in sent]
        if exclude_punctuation:
            sent = [word for word in sent if nopunct.match(word)]
        for i in range(len(sent) - (n - 1)):
            condition = ' '.join(sent[i:(i + n) - 1])
            sample = sent[(i + n) - 1]
            if condition in cfd:
                if sample in cfd[condition]:
                    cfd[condition][sample] += 1
                else:
                    cfd[condition].update({sample: 1})
            else:
                cfd[condition] = {sample: 1}
    return cfd


def get_most_common(search, cfds, most_common=None):
    if not most_common:
        most_common = list()
    words = search.split(' ')
    for i in reversed(range(len(cfds))):
        n = i + 2
        if len(words) >= (n - 1):
            query = ' '.join(words[len(words) - (n - 1):])
            if query in cfds[i]:
                most_common.extend([entry[0] for entry in sorted(cfds[i][query].items(),
                                                                 key=operator.itemgetter(1),
                                                                 reverse=True)
                                    if entry[0] not in most_common])
    return most_common


if __name__ == '__main__':
    generate(sys.argv[1])
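
A rough usage sketch, assuming a local plain-text corpus file ('my_corpus.txt' below is a placeholder path). The script caches parsed fragments in syntaxes.p and the n-gram counts in cfds.p in the working directory, and it expects nltk's punkt tokenizer and gutenberg corpus to be available:

    python -c "import nltk; nltk.download('punkt'); nltk.download('gutenberg')"
    python syntax_aware_generate.py my_corpus.txt

Delete syntaxes.p and cfds.p to force the script to re-parse a new corpus.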