-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathutils.py
107 lines (87 loc) · 2.89 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import math

import torch
import numpy as np
import bitarray
from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer
def decode(self, token_ids, **kwargs):
    """Turn a sequence of token ids back into text.

    Installed on ``GPT2Tokenizer`` by this module as a replacement ``decode``.
    It maps ids -> tokens -> string and nothing else. Any keyword arguments
    are accepted only for signature compatibility and are ignored.
    NOTE(review): presumably this keeps the id -> text mapping free of the
    library's optional clean-up steps; confirm against the installed version.
    """
    pieces = self.convert_ids_to_tokens(token_ids)
    return self.convert_tokens_to_string(pieces)
# Monkeypatch: every GPT2Tokenizer instance now uses the module-level decode().
setattr(GPT2Tokenizer, 'decode', decode)
def _convert_token_to_id(self, token):
return self.encoder.get(token, 0)
# Monkeypatch: unknown tokens resolve to id 0 instead of the library's default handling.
setattr(GPT2Tokenizer, '_convert_token_to_id', _convert_token_to_id)
def limit_past(past, max_len=1022):
    """Keep only the last ``max_len`` positions of each cached past tensor.

    Args:
        past: iterable of per-layer tensors. Each one is sliced along dimension 3.
            NOTE(review): presumably that is the sequence axis of GPT-2's stacked
            key/value cache; confirm against the pytorch_transformers version.
        max_len: number of trailing positions to keep. The default of 1022 is
            presumably chosen to stay under GPT-2's 1024-token context; confirm.

    Returns:
        A new list of sliced tensors. The input container is not modified.

    Raises:
        ValueError: if ``max_len`` is not positive (a slice of ``-0:`` would
            silently keep everything).
    """
    if max_len < 1:
        raise ValueError("max_len must be a positive integer")
    return [layer[:, :, :, -max_len:] for layer in past]
def kl(q, logq, logp):
    """Return the KL divergence KL(q || p), in bits, as a Python float.

    Args:
        q: probabilities of distribution q. Tensor-like; it must support boolean-mask
            assignment and ``.sum().item()`` (torch tensors in this file's usage).
        logq: natural log of ``q``, broadcast-compatible with ``q``.
        logp: natural log of distribution p, broadcast-compatible with ``q``.

    Entries with q == 0 contribute 0 (0 * log 0 := 0). Entries with q > 0 and
    p == 0 make the result infinite.
    """
    # Divide by the exact ln 2 to convert nats to bits.
    res = q * (logq - logp) / math.log(2)
    # Where q == 0 the product can be NaN (0 * -inf when logq = log(0)); force it to 0.
    res[q == 0] = 0
    return res.sum().item()
def entropy(q, logq):
    """Return the Shannon entropy H(q), in bits, as a Python float.

    Args:
        q: probabilities. Tensor-like; it must support boolean-mask assignment and
            ``.sum().item()`` (torch tensors in this file's usage).
        logq: natural log of ``q``, broadcast-compatible with ``q``.

    Entries with q == 0 contribute 0 (0 * log 0 := 0).
    """
    # Divide by the exact ln 2 to convert nats to bits.
    res = q * logq / math.log(2)
    # Where q == 0 the product can be NaN (0 * -inf when logq = log(0)); force it to 0.
    res[q == 0] = 0
    return -res.sum().item()
# Bit lists are little-endian (least-significant bit first):
# e.g. [0, 1, 1, 1] is binary 1110 = 14.
def bits2int(bits):
    """Return the integer whose little-endian bit list is ``bits`` (0 for an empty list)."""
    return sum(bit * (1 << pos) for pos, bit in enumerate(bits))
def int2bits(inp, num_bits):
    """Return the bits of ``inp`` as a little-endian list of 0/1 ints.

    The list is zero-padded to at least ``num_bits`` entries. If ``inp`` needs
    more bits than that, all of its bits are returned; the width is a minimum,
    not a truncation. ``num_bits == 0`` always yields ``[]``. ``inp`` is
    expected to be non-negative.
    """
    if num_bits == 0:
        return []
    spec = '0%db' % num_bits
    msb_first = format(inp, spec)
    return [int(digit) for digit in msb_first[::-1]]
def is_sent_finish(token_idx, enc):
    """Return True if the vocabulary entry for ``token_idx`` contains '.', '!' or '?' anywhere."""
    piece = enc.decoder[token_idx]
    return any(mark in piece for mark in ('.', '!', '?'))
def num_same_from_beg(bits1, bits2):
    """Return how many leading positions ``bits1`` and ``bits2`` share.

    This is the length of their common prefix. Identical sequences return
    ``len(bits1)``, and empty sequences return 0.

    Raises:
        AssertionError: if the two sequences differ in length.
    """
    assert len(bits1) == len(bits2)
    for pos, (a, b) in enumerate(zip(bits1, bits2)):
        if a != b:
            return pos
    # No mismatch: the entire sequence is a shared prefix.
    return len(bits1)
def encode_context(raw_text, enc):
    """Return the token ids of ``raw_text`` with the '<|endoftext|>' id prepended."""
    eot_id = enc.encoder['<|endoftext|>']
    text_ids = enc.encode(raw_text)
    return [eot_id] + text_ids
# Pass model_name='gpt2-medium' for the 345M-parameter model,
# or model_name='gpt2-large' for the 774M-parameter model.
def get_model(seed=1234, model_name='gpt2'):
    """Seed the RNGs, then load a pretrained GPT-2 tokenizer and LM-head model.

    Args:
        seed: seed for NumPy's RNG, torch's RNG and torch's CUDA RNG.
        model_name: checkpoint name passed to both ``from_pretrained`` calls.

    Returns:
        ``(enc, model)``. ``enc`` is the tokenizer with ``unk_token``,
        ``bos_token`` and ``eos_token`` set to None. ``model`` is moved to CUDA
        when available (otherwise CPU) and switched to eval mode.
    """
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    enc = GPT2Tokenizer.from_pretrained(model_name)
    # NOTE(review): clearing these presumably stops the tokenizer from treating
    # '<|endoftext|>' as a special token; confirm against the library version.
    enc.unk_token = None
    enc.bos_token = None
    enc.eos_token = None

    model = GPT2LMHeadModel.from_pretrained(model_name)
    model.to(target_device)
    model.eval()
    return enc, model
# 32-symbol alphabet for packing a restricted-text message into 5-bit codes.
# Index 0 is '\0' (dec32 stops decoding when it sees it), 1-26 are 'a'-'z',
# and 27-31 are '.', ',', "'", '!' and ' '. With exactly 32 entries, every
# index fits in 5 bits.
enc32_itoc = ['\0', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '.', ',', "'", '!', ' ']
# Inverse lookup: character -> 5-bit index.
enc32_ctoi = {k: v for v, k in enumerate(enc32_itoc)}
def enc32(text):
    """Pack ``text`` into a flat bit list using 5 bits per character.

    Each character's index in ``enc32_itoc`` is emitted through ``int2bits`` (least
    significant bit first). A character outside that alphabet, including any
    uppercase letter, raises KeyError.
    """
    return [bit for ch in text for bit in int2bits(enc32_ctoi[ch], 5)]
def dec32(bits):
    """Decode a bit list produced by ``enc32`` back into text.

    Bits are read in consecutive 5-bit groups, and decoding stops at the first
    group that maps to '\\0'. A trailing group shorter than 5 bits is still
    decoded as a character.
    """
    chars = []
    for start in range(0, len(bits), 5):
        ch = enc32_itoc[bits2int(bits[start:start + 5])]
        if ch == '\0':
            break
        chars.append(ch)
    return ''.join(chars)
# message should be a bit sequence (e.g. a bit string or list of 0/1 values);
# encoded should be a text string.
def expansion_ratio(message, encoded):
    """Return the number of UTF-8 bits of cover text spent per message bit.

    Args:
        message: the hidden message as a sequence of bits. Only its length is used.
        encoded: the cover text carrying the message.

    Returns:
        ``8 * len(encoded.encode('utf-8')) / len(message)`` as a float.

    Raises:
        ZeroDivisionError: if ``message`` is empty.
    """
    message_bits = len(message)
    # Each UTF-8 byte is exactly 8 bits, so count bytes rather than
    # materializing a per-bit list.
    encoded_bits = 8 * len(encoded.encode('utf-8'))
    return encoded_bits / message_bits