-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathngram.py
More file actions
65 lines (52 loc) · 1.85 KB
/
ngram.py
File metadata and controls
65 lines (52 loc) · 1.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import logging
from typing import Generator, MutableMapping
import dbm
import random
from fastapi.logger import logger
from nltk.tokenize import word_tokenize
from nltk.tokenize.treebank import TreebankWordDetokenizer
# Configure root logging once at import time so log records are emitted.
logging.basicConfig(level=logging.INFO)
# NOTE(review): rebinds `logger`, shadowing the `fastapi.logger` import above —
# confirm the FastAPI logger import is intentionally unused.
logger = logging.getLogger(__name__)
def get_gram(input_text: str, n: int = 3, append: bool = False) -> tuple[str, ...]:
    """Return the trailing n-gram of *input_text* as a tuple of tokens.

    Args:
        input_text: Raw text, tokenized with NLTK's ``word_tokenize``.
        n: Number of tokens in the n-gram.
        append: If True, left-pad inputs shorter than ``n`` tokens with "."
            tokens instead of raising.

    Returns:
        A tuple of exactly ``n`` tokens (the last ``n`` tokens of the text).

    Raises:
        ValueError: If the text has fewer than ``n`` tokens and ``append``
            is False.
    """
    tokens = word_tokenize(input_text)
    if len(tokens) < n:
        if append:
            # Left-pad with sentence-final "." so the gram still has n slots.
            tokens = ["."] * (n - len(tokens)) + tokens
        else:
            raise ValueError(f"Input text must contain at least {n} tokens.")
    return tuple(tokens[-n:])
def generate_text(
    current_gram: tuple[str, ...],
    model: MutableMapping[str, list[str]] | MutableMapping[bytes, bytes],
    num_tokens: int = 50,
) -> str:
    """Generate up to *num_tokens* words continuing *current_gram*.

    Pulls words from ``token_generator`` until it signals a dead end
    (by yielding ``None``), then joins them with NLTK's Treebank
    detokenizer.

    Args:
        current_gram: Seed n-gram to continue from.
        model: Mapping from n-gram keys to candidate next words.
        num_tokens: Upper bound on the number of generated words.

    Returns:
        The generated words detokenized into a single string.
    """
    output = []
    for next_word in token_generator(current_gram, model, num_tokens):
        if next_word is None:
            # token_generator yields None once when the chain dead-ends.
            break
        output.append(next_word)
    return TreebankWordDetokenizer().detokenize(output)
def token_generator(
current_gram: tuple[str],
model: dict[str, list[str]] | MutableMapping[bytes, bytes],
num_tokens: int = 50,
) -> Generator[str, None]:
i = 0
while i < num_tokens:
current_ngram_bytes = bytearray('\x00'.join(current_gram), 'utf-8')
possibilities = (
list(model[current_ngram_bytes].decode('utf-8').split('\x00'))
if current_ngram_bytes in model
else None
)
if not possibilities:
logger.warning("No possibilities found, stopping generation.")
yield None
break
next_word = random.choice(possibilities)
i += 1
current_gram = (*current_gram[1:], next_word)
yield next_word
def load_model(db_path: str) -> MutableMapping[bytes, bytes]:
    """Open the dbm-backed n-gram model at *db_path* in read-only mode."""
    handle = dbm.open(db_path, "r")
    return handle