| | import math |
| | import pickle |
| | import re |
| | import string |
| |
|
| | from nltk.tokenize import word_tokenize |
| | from nltk.tokenize.treebank import TreebankWordDetokenizer |
| |
|
| |
|
| | class TrueCaser(object): |
| | def __init__(self, dist_file_path): |
| | with open(dist_file_path, "rb") as distributions_file: |
| | pickle_dict = pickle.load(distributions_file) |
| | self.uni_dist = pickle_dict["uni_dist"] |
| | self.backward_bi_dist = pickle_dict["backward_bi_dist"] |
| | self.forward_bi_dist = pickle_dict["forward_bi_dist"] |
| | self.trigram_dist = pickle_dict["trigram_dist"] |
| | self.word_casing_lookup = pickle_dict["word_casing_lookup"] |
| | self.detknzr = TreebankWordDetokenizer() |
| |
|
| | def get_score(self, prev_token, possible_token, next_token): |
| | pseudo_count = 5.0 |
| |
|
| | |
| | numerator = self.uni_dist[possible_token] + pseudo_count |
| | denominator = 0 |
| | for alternativeToken in self.word_casing_lookup[possible_token.lower()]: |
| | denominator += self.uni_dist[alternativeToken] + pseudo_count |
| |
|
| | unigram_score = numerator / denominator |
| |
|
| | |
| | bigram_backward_score = 1 |
| | if prev_token is not None: |
| | key = prev_token + "_" + possible_token |
| | numerator = self.backward_bi_dist[key] + pseudo_count |
| | denominator = 0 |
| | for alternativeToken in self.word_casing_lookup[possible_token.lower()]: |
| | key = prev_token + "_" + alternativeToken |
| | denominator += self.backward_bi_dist[key] + pseudo_count |
| |
|
| | bigram_backward_score = numerator / denominator |
| |
|
| | |
| | bigram_forward_score = 1 |
| | if next_token is not None: |
| | next_token = next_token.lower() |
| | key = possible_token + "_" + next_token |
| | numerator = self.forward_bi_dist[key] + pseudo_count |
| | denominator = 0 |
| | for alternativeToken in self.word_casing_lookup[possible_token.lower()]: |
| | key = alternativeToken + "_" + next_token |
| | denominator += self.forward_bi_dist[key] + pseudo_count |
| |
|
| | bigram_forward_score = numerator / denominator |
| |
|
| | |
| | trigram_score = 1 |
| | if prev_token is not None and next_token is not None: |
| | next_token = next_token.lower() |
| | trigram_key = prev_token + "_" + possible_token + "_" + next_token |
| | numerator = self.trigram_dist[trigram_key] + pseudo_count |
| | denominator = 0 |
| | for alternativeToken in self.word_casing_lookup[possible_token.lower()]: |
| | trigram_key = prev_token + "_" + alternativeToken + "_" + next_token |
| | denominator += self.trigram_dist[trigram_key] + pseudo_count |
| |
|
| | trigram_score = numerator / denominator |
| |
|
| | result = ( |
| | math.log(unigram_score) |
| | + math.log(bigram_backward_score) |
| | + math.log(bigram_forward_score) |
| | + math.log(trigram_score) |
| | ) |
| |
|
| | return result |
| |
|
| | @staticmethod |
| | def first_token_case(raw): |
| | return raw.capitalize() |
| |
|
| | @staticmethod |
| | def upper_replacement(match): |
| | return '. ' + match.group(0)[-1].upper() |
| |
|
| | def get_true_case(self, sentence, out_of_vocabulary_token_option="title"): |
| | tokens = word_tokenize(sentence) |
| | tokens_true_case = self.get_true_case_from_tokens(tokens, out_of_vocabulary_token_option) |
| | text = self.detknzr.detokenize(tokens_true_case) |
| | text = re.sub(r' \. .', self.upper_replacement, text) |
| | return text |
| |
|
| | def get_true_case_from_tokens(self, tokens, out_of_vocabulary_token_option="title"): |
| | tokens_true_case = [] |
| |
|
| | if not len(tokens): |
| | return tokens_true_case |
| |
|
| | for token_idx, token in enumerate(tokens): |
| | if token in string.punctuation or token.isdigit(): |
| | tokens_true_case.append(token) |
| | continue |
| |
|
| | token = token.lower() |
| | if token not in self.word_casing_lookup: |
| | if out_of_vocabulary_token_option == "title": |
| | tokens_true_case.append(token.title()) |
| | elif out_of_vocabulary_token_option == "capitalize": |
| | tokens_true_case.append(token.capitalize()) |
| | elif out_of_vocabulary_token_option == "lower": |
| | tokens_true_case.append(token.lower()) |
| | else: |
| | tokens_true_case.append(token) |
| | continue |
| |
|
| | if len(self.word_casing_lookup[token]) == 1: |
| | tokens_true_case.append(list(self.word_casing_lookup[token])[0]) |
| | continue |
| |
|
| | prev_token = tokens_true_case[token_idx - 1] if token_idx > 0 else None |
| | next_token = tokens[token_idx + 1] if token_idx < len(tokens) - 1 else None |
| |
|
| | best_token = None |
| | highest_score = float("-inf") |
| |
|
| | for possible_token in self.word_casing_lookup[token]: |
| | score = self.get_score(prev_token, possible_token, next_token) |
| |
|
| | if score > highest_score: |
| | best_token = possible_token |
| | highest_score = score |
| |
|
| | tokens_true_case.append(best_token) |
| |
|
| | tokens_true_case[0] = self.first_token_case(tokens_true_case[0]) |
| | return tokens_true_case |
| |
|