# inference.py — command-line script: translate English text to Hindi or Telugu
# using a Seq2SeqTransformer checkpoint and a pretrained tokenizer.
from tokenizers import Tokenizer
import torch
from model import Seq2SeqTransformer, config
import argparse
# Load the shared subword tokenizer used for both source and target text.
tokenizer = Tokenizer.from_file('./model/tokenizer.json')

# Special-token ids baked into the tokenizer vocabulary; the language-specific
# BOS tokens select the decoding target language in translate() below.
special_tokens = {
    '<unk>': 0,
    '<pad>': 1,
    '<s-en>': 2,  # bos english
    '<s-hi>': 3,  # bos hindi
    '<s-te>': 4,  # bos telugu
    '</s>': 5,    # eos
}

# Build the model and restore the best checkpoint on CPU.
model = Seq2SeqTransformer(config)
state_dict = torch.load('./model/best_model.pt', map_location='cpu')
model.load_state_dict(state_dict)
# Inference script: switch to eval mode so dropout / normalization layers
# behave deterministically (the original left the model in training mode).
model.eval()
def translate(input_sentence, language='hi', d=True, t=1.0):
    """Translate an English sentence into the target language.

    Args:
        input_sentence: English source text.
        language: target language code, 'hi' (hindi) or 'te' (telugu);
            must match one of the '<s-XX>' entries in special_tokens.
        d: deterministic (greedy) decoding when True; sample when False.
        t: sampling temperature (relevant only when d is False).

    Returns:
        The decoded translation as a string.
    """
    # Wrap the source in the English-BOS / EOS markers the model was trained on.
    source = f"<s-en>{input_sentence.strip()}</s>"
    token_ids = tokenizer.encode(source).ids
    batch = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)  # shape (1, seq_len)
    # The target-language BOS token tells the decoder which language to emit.
    bos = special_tokens[f"<s-{language}>"]
    # No gradients are needed at inference time; avoids building the autograd graph.
    with torch.no_grad():
        outputs = model.generate(batch, deterministic=d, bos=bos, temperature=t)
    # tokenizers' decode expects a sequence of ints; tolist() avoids the numpy
    # round-trip. Assumes generate returns a 1-D tensor of token ids — the
    # original's .numpy() call implies the same shape.
    return tokenizer.decode(outputs.tolist())
parser = argparse.ArgumentParser()
parser.add_argument('-l',default='hi',required=True,help="hi:hindi,te:telugu")
parser.add_argument('--text',required=True,help="english text to translate")
parser.add_argument('-s',default=True,action='store_true',help='do_sample')
parser.add_argument('-t',default=1.0,type=float,help='temperature')
args = parser.parse_args()
print(translate(input_sentence=args.text,language=args.l,d=not args.s,t=args.t))