# Sandiway Fong (c) University of Arizona
import concurrent.futures                 # ThreadPoolExecutor
import nltk
from nltk.corpus import ptb
import re
from datetime import datetime
from time import time

## GLOBALS

# Module-level tables, filled in by make_lexicalrules() / make_syntaxrules():
#   wts      - (word, tag) tuples, one per lexical production
#   rules    - syntax rule strings 'lhs -> rhs', one per syntax production
#   mc_wts   - ((word, tag), count) pairs, most common first
#   mc_rules - syntax rule strings seen more than once, most common first
wts, rules, mc_wts, mc_rules = [], [], [], []

## GRAMMAR SYMBOL PROCESSING

def trim_num(s):
    """Strip a numeric index suffix such as '=1-4', '=3' or '-1' from symbol s.

    The shortest non-empty prefix followed by one or more '[=-]<digits>' runs
    is returned (e.g. 'NP-SBJ-1' -> 'NP-SBJ'). The match is not anchored at
    the end, so anything after that run is dropped as well. Symbols without
    such a run are returned unchanged.
    """
    found = re.match(r'(.+?)([=-][0-9]+)+', s)
    return found.group(1) if found else s

def trim_punc(s):
    """Return '' when s is made only of non-word characters, otherwise s itself."""
    return '' if re.match(r'\W+$', s) else s

# Overcome miscellaneous nltk grammar problems:
def misc(s):
    """Rewrite symbol spellings that nltk's CFG reader cannot take as-is.

    A trailing '$' becomes 'S' (e.g. 'WP$' -> 'WPS'). Otherwise, for a symbol
    starting with '-', the text between that '-' and the last later '-' is kept
    (e.g. '-NONE-' -> 'NONE'). Any other symbol is returned unchanged.
    """
    dollar = re.match(r"(.+)\$$", s)
    if dollar:
        return dollar.group(1) + 'S'
    dashed = re.match(r"-(.+)-", s)
    return dashed.group(1) if dashed else s
        
def wordtag(p):
    """Return (word, tag) for a lexical production p of the form TAG -> word.

    The tag is cleaned with trim_num() then misc(); the word is the first
    right-hand-side item, unchanged.
    """
    word = p.rhs()[0]
    tag = misc(trim_num(p.lhs().symbol()))
    return (word, tag)

def p_str2(p):
    """Render a syntax production as the string 'lhs -> rhs' with cleaned symbols.

    The lhs goes through trim_num() and misc(). On the rhs, symbols made only
    of non-word characters are dropped. Every other symbol goes through
    trim_num(), trim_punc() and misc(), and the results are joined with single
    spaces.
    """
    lhs = misc(trim_num(p.lhs().symbol()))
    kept = []
    for nt in p.rhs():
        sym = nt.symbol()
        if re.match(r"\W+$", sym):
            continue                    # punctuation nonterminal: leave it out
        kept.append(misc(trim_punc(trim_num(sym))))
    return lhs + ' -> ' + ' '.join(kept)

### LEXICAL RULES

class WordError(Exception):
    """Raised by find_wt() when a word has no (word, tag) entry.

    Attributes:
        word: the word that was looked up.
        message: human-readable description (also the exception's str()).
    """

    def __init__(self, word, message="word not found!"):
        super().__init__(message)
        self.word = word
        self.message = message

def find_wt(word):
    """Return the (word, tag) tuple to use for word.

    Entries in the global overrides dict win. Otherwise the first match in the
    global mc_wts (most common first) is returned.

    Raises:
        WordError: if word is in neither table (an error line is printed first).
    """
    if word in overrides:
        return overrides[word]
    match = next((wt for (wt, _count) in mc_wts if wt[0] == word), None)
    if match is None:
        print(f"Error: {word} not found!")
        raise WordError(word)
    return match

def make_lexrule(wt):
    """Turn a (word, tag) tuple into the nltk lexical rule string tag -> 'word'.

    The quote character delimits an nltk terminal, so a word containing an
    apostrophe anywhere (e.g. "'s", "n't", "o'clock") must be wrapped in
    double quotes. Checking only for a leading apostrophe produced rules like
    RB -> 'n't' that nltk cannot read. All other words use single quotes.
    """
    word, tag = wt[0], wt[1]
    if "'" in word:
        return tag + ' -> "' + word + '"'   # tag -> "n't"
    return tag + " -> '" + word + "'"       # tag -> 'word'

def lexrules(words):
    """Map a list of words to the lexical rules 'tag -> word' they need.

    Each word is looked up with find_wt(). Duplicate (word, tag) pairs are
    collapsed while keeping first-occurrence order. A set would order them by
    string hash, which varies from one Python process to the next, so the
    printed lexicon and the grammar's rule order were not reproducible.

    Returns:
        list of rule strings, or None (after printing the word and message)
        if some word is unknown.
    """
    try:
        wordtags = [find_wt(word) for word in words]
    except WordError as e:
        print(e.word, e.message)
        return None
    unique = dict.fromkeys(wordtags)     # ordered de-duplication
    return [make_lexrule(wt) for wt in unique]

def _word_pattern(word):
    """Compile a regex that finds word as a whole word, with regex metacharacters escaped."""
    # Without re.escape a token like '(' raises re.error and '.' matches any
    # single character (e.g. the 'a' in "DT -> 'a'").
    return re.compile(r"\b" + re.escape(word) + r"\b")

def word_in_lexrules(word, lexrules):
    """True if some rule in lexrules contains word as a whole word."""
    return find_lexrule(word, lexrules) is not None

def find_lexrule(word, lexrules):
    """Return the first rule in lexrules containing word as a whole word, else None."""
    pattern = _word_pattern(word)
    for rule in lexrules:
        if pattern.search(rule):
            return rule
    return None

def lexrules2(words, lex):
    """Lexical rules for words, reusing matching rules from lex where possible.

    For each word, the first rule in lex that mentions it is used. Otherwise a
    rule is built from find_wt() (which raises WordError for unknown words).
    One rule is returned per word, in order, so duplicates are possible.
    """
    result = []
    for w in words:
        rule = find_lexrule(w, lex)      # single scan instead of test-then-find
        result.append(rule if rule is not None else make_lexrule(find_wt(w)))
    return result

### SYNTAX RULES

def rule_blocked(r):
    """True if syntax rule r is listed in the global blockedrules."""
    is_blocked = r in blockedrules
    return is_blocked

def find_rule(rule):
    """Return the 0-based position of rule in the global mc_rules (ValueError if absent)."""
    rank = mc_rules.index(rule)
    return rank

def _report_found(trees, printparses, oneline, g, printgrammar):
    """Print the parses and/or the grammar, skipping whichever was not requested."""
    if printparses:
        print_parses(trees, printparses, oneline)
    if printgrammar:
        print_grammar(g, printgrammar)

def stopping(stop, trees, printparses, oneline, g, printgrammar):
    """Decide whether find_scfg() should stop growing grammar g.

    stop can be None, a single rule string, or a list of rule strings:
      - None: stop immediately.
      - a rule string: stop once that rule is in g.
      - a list of rules: stop once every rule in it is in g.
    When stopping on a rule or rules, a FOUND line is printed. In every
    stopping case the parses (if printparses) and the grammar (if
    printgrammar) are printed.

    Returns:
        True to stop growing, False to keep adding rules.
    """
    if stop is None:
        _report_found(trees, printparses, oneline, g, printgrammar)
        return True
    if isinstance(stop, str):
        # The old check `type(stop) == 'str'` was always False, so a single
        # rule was iterated character by character and could never be found.
        if stop not in g:
            return False
        print("FOUND {}!".format(stop))
    else:
        if not all(r in g for r in stop):
            return False
        print("FOUND ALL {}!".format(stop))
    _report_found(trees, printparses, oneline, g, printgrammar)
    return True

### PRINTING

def print_grammar(g, flag):
    """Print every rule of grammar g, one per line, when flag is truthy."""
    if not flag:
        return
    for rule in g:
        print(rule)

def print_parses(trees, flag, oneline):
    """Print each tree in trees when flag is truthy.

    With oneline, tree.pprint(margin=1000) is used (wide margin, so a tree is
    kept on one line); otherwise the tree is printed with print().
    """
    if not flag:
        return
    for tree in trees:
        if not oneline:
            print(tree)
            continue
        tree.pprint(margin=1000)

def find_scfg(s, parses=1, ec=False, printparses=True, oneline=True, printgrammar=False, lex=None, stop=None):
    """finds the smallest cfg that can parse sentence s (string)

    Syntax rules from mc_rules are added one at a time, most frequent first,
    skipping those for which rule_blocked() is true. Each time they are
    combined with the lexical rules for the words of s and the sentence is
    parsed. Whenever at least `parses` parses are found, the counts are printed
    and `parses` is raised to one more than the number found. stopping() then
    decides whether to stop.

    Args:
        s: sentence, tokenized on whitespace.
        parses: minimum number of parses a reported grammar must yield.
        ec: if True, add the empty-category rule 'NONE ->'.
        printparses, oneline, printgrammar: output flags passed to stopping().
        lex: optional list of lexical rules; rules mentioning a word of s are
            reused, other words are looked up with find_wt().
        stop: None, a rule string, or a list of rule strings (see stopping()).

    Returns None. If a word cannot be found, the error is printed and no
    grammar search is done.
    """
    words = s.split()
    try:
        lex = lexrules(words) if not lex else lexrules2(words, lex)
    except WordError as e:
        # lexrules2 lets unknown words escape; report them like lexrules does.
        print(e.word, e.message)
        return None
    if lex is None:
        # lexrules already reported the unknown word; continuing would crash
        # on lex.append / syntax + lex.
        return None

    if ec:
        lex.append('NONE ->')        # EC rule
    print(lex)

    start = time()
    print('Begin CFG grow')

    blocked = 0
    syntax = []
    for idx, rule in enumerate(mc_rules):
        if rule_blocked(rule):
            blocked += 1
            continue
        syntax.append(rule)
        g = syntax + lex
        parser = nltk.ChartParser(nltk.CFG.fromstring(g))
        trees = list(parser.parse(words))
        found = len(trees)
        if found and found >= parses:
            print(f"Parses: {found:,}, # syntax rules: {idx + 1:,} (minus {blocked:,} blocked), # lexical rules: {len(lex):,}")
            parses = found + 1       # from now on only report grammars with more parses
            if stopping(stop, trees, printparses, oneline, g, printgrammar):
                break
    print(f"grow CFG time: {time() - start:.2f} (s)")

def pos_rule(p):
    """True when production p is lexical (POS -> word): exactly one rhs item, a str."""
    rhs = p.rhs()
    return len(rhs) == 1 and isinstance(rhs[0], str)

def make_lexicalrules(ps):
    """Fill the globals wts and mc_wts from the lexical productions in ps.

    wts gets one (word, tag) tuple per production accepted by pos_rule().
    mc_wts gets nltk.FreqDist(wts).most_common(), i.e. ((word, tag), count)
    pairs. The number of lexical productions is printed.
    """
    global wts, mc_wts
    pairs = []
    for prod in ps:
        if pos_rule(prod):
            pairs.append(wordtag(prod))
    wts = pairs
    mc_wts = nltk.FreqDist(pairs).most_common()
    print(f"Lexical rules: {len(wts):,}")

def make_syntaxrules(ps):
    """Fill the globals rules and mc_rules from the non-lexical productions in ps.

    rules gets the p_str2() string of every production rejected by pos_rule().
    mc_rules gets those strings seen more than once, most common first. The
    number of syntax productions is printed.
    """
    global rules, mc_rules
    strs = []
    for prod in ps:
        if not pos_rule(prod):
            strs.append(p_str2(prod))
    rules = strs
    frequent = []
    for rule, count in nltk.FreqDist(strs).most_common():
        if count > 1:
            frequent.append(rule)
    mc_rules = frequent
    print(f"Syntax rules: {len(rules):,}")
    
# Words whose (word, tag) is forced instead of taken from mc_wts (see find_wt).
overrides = {'that': ('that', 'WDT')}

# Syntax rules that find_scfg() never adds to a grammar (see rule_blocked).
blockedrules = ['S -> NONE', 'NP -> NP NP', 'NP-SBJ -> NP NP', 'NP -> DT', 'NP -> NP NP-ADV', 'PP -> NONE', 'SBAR -> NONE']

print("Creates lexical and syntactic rules from ptb: stored in mc_wts, mc_rules")
print("find_scfg(string) finds the smallest CFG that can parse string.")
print("  rules are added to the grammar in order of frequency: high > low.")
print("  optional parameters:")
print("    parses=1, ec=False, printparses=True, oneline=True, printgrammar=False, lex=None, stop=None.")
print("    parses: # of parses for string the smallest CFG must produce.")
print("    ec: True, permits -NONE- -> ''")
print("    printparses: True, prints the parses found.")
print("    oneline: True, prints each parse on one line. False, pprint()")
print("    printgrammar: True, prints the CFG found.")
print("    lex: list of lexical rules to use (overrides default)")
print("         example lexical rule: \"DT -> 'the'\"")
print("    stop: don't stop adding to the grammar until stop rule(s) are in.")
print("         example: 'VP -> VBD NP ADVP-TMP'")

print(datetime.now().strftime("%H:%M:%S"))
start = time()

# Every production of every parsed sentence returned by nltk's ptb reader.
ps = [p for t in ptb.parsed_sents() for p in t.productions()]
print(f"Productions: {len(ps):,}")

# Build the lexical and syntax tables. Both jobs are pure-Python CPU work, so
# under CPython's GIL the threads likely give little speed-up.
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(make_lexicalrules, ps),
               executor.submit(make_syntaxrules, ps)]
# submit() stores a worker's exception in its Future. Without result() a
# failure would be dropped silently, leaving mc_wts / mc_rules empty.
for future in futures:
    future.result()

print(f"Total time: {time()-start:.0f} (s)")
