"""A character-aware neural language model: a character-level CNN feeds two
highway layers, whose output drives a two-layer LSTM with a linear projection
to logits over the word vocabulary."""
import torch
import torch.nn as nn

class HighwayNetwork(nn.Module):
    """Highway layer: z = t * H(x) + (1 - t) * x, with sigmoid transform gate t."""

    def __init__(self, input_size, activation='ReLU'):
        super(HighwayNetwork, self).__init__()
        # Transform gate t(x).
        self.trans_gate = nn.Sequential(
            nn.Linear(input_size, input_size),
            nn.Sigmoid())
        # Nonlinear transform H(x).
        if activation == 'ReLU':
            self.activation = nn.ReLU()
        else:
            raise ValueError('unsupported activation: %r' % activation)
        self.h_layer = nn.Sequential(
            nn.Linear(input_size, input_size),
            self.activation)
        # A negative gate bias pushes t toward 0 early in training, so the layer
        # initially carries x through mostly unchanged.
        self.trans_gate[0].bias.data.fill_(-2)

    def forward(self, x):
        t = self.trans_gate(x)
        h = self.h_layer(x)
        # Gated mix of the transformed input and the identity path.
        z = torch.mul(t, h) + torch.mul(1 - t, x)
        return z
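
# A minimal shape-check sketch for HighwayNetwork (illustrative only; the
# feature size 150 below is an arbitrary assumption, not a value used by LM):
#
#     hw = HighwayNetwork(150)
#     x = torch.randn(4, 150)
#     assert hw(x).shape == x.shape  # highway layers preserve dimensionality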

class LM(nn.Module):
    """Language model: character CNN -> highway x2 -> 2-layer LSTM -> word logits."""

    def __init__(self, word_vocab, char_vocab, max_len, embed_dim,
                 out_channels, kernels, hidden_size, batch_size):
        super(LM, self).__init__()
        self.word_vocab = word_vocab
        self.char_vocab = char_vocab
        # Character embedding layer; index 0 is reserved for padding.
        self.embed = nn.Embedding(len(char_vocab) + 1, embed_dim, padding_idx=0)
        # One CNN per kernel width: each produces out_channels * kernel feature
        # maps, max-pooled over all character positions of the word.
        self.cnns = []
        for kernel in kernels:
            self.cnns.append(nn.Sequential(
                nn.Conv2d(1, out_channels * kernel, kernel_size=(kernel, embed_dim)),
                nn.Tanh(),
                nn.MaxPool2d((max_len - kernel + 1, 1))))
        self.cnns = nn.ModuleList(self.cnns)
        # Highway layers over the concatenated CNN features.
        input_size = int(out_channels * sum(kernels))
        self.highway = HighwayNetwork(input_size)
        self.highway2 = HighwayNetwork(input_size)
        # Two-layer LSTM over the sequence of word representations.
        self.lstm = nn.LSTM(input_size, hidden_size, 2, batch_first=True, dropout=0.5)
        # Output projection to word-vocabulary logits.
        self.linear = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(hidden_size, len(word_vocab)))
        # batch_size is accepted for API compatibility but not used here.
        # self.init_weight()  # optional: apply the explicit initialization below
    def init_weight(self):
        # Small uniform weights and zero biases for the conv and output layers.
        for cnn in self.cnns:
            cnn[0].weight.data.uniform_(-0.05, 0.05)
            cnn[0].bias.data.fill_(0)
        self.linear[1].weight.data.uniform_(-0.05, 0.05)
        self.linear[1].bias.data.fill_(0)
        # Same scheme for both LSTM layers.
        for name, param in self.lstm.named_parameters():
            if name.startswith('weight'):
                param.data.uniform_(-0.05, 0.05)
            else:
                param.data.fill_(0)
    def forward(self, x, h):
        # x: (batch, seq_len, max_word_len) character indices.
        batch_size = x.shape[0]
        seq_len = x.shape[1]
        # Flatten so the CNNs see one word per row: (batch * seq_len, max_word_len).
        x = x.contiguous().view(-1, x.shape[2])
        x = self.embed(x)
        # Add a channel dim for Conv2d: (batch * seq_len, 1, max_word_len, embed_dim).
        x = x.contiguous().view(x.shape[0], 1, x.shape[1], x.shape[2])
        # Each CNN yields (batch * seq_len, out_channels * kernel, 1, 1); squeeze
        # only the spatial dims so a batch dimension of size 1 is never dropped.
        y = [cnn(x).squeeze(3).squeeze(2) for cnn in self.cnns]
        w = torch.cat(y, 1)
        w = self.highway(w)
        w = self.highway2(w)
        # Restore the sequence layout for the LSTM: (batch, seq_len, features).
        w = w.contiguous().view(batch_size, seq_len, -1)
        out, h = self.lstm(w, h)
        # Project every timestep to logits over the word vocabulary.
        out = out.contiguous().view(batch_size * seq_len, -1)
        out = self.linear(out)
        return out, h
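
# A minimal smoke-test sketch under assumed hyperparameters: every value below
# (vocab sizes, kernel widths, dimensions) is illustrative, not taken from the
# original training setup.
if __name__ == '__main__':
    word_vocab = {i: str(i) for i in range(100)}                 # hypothetical word vocab
    char_vocab = {c: i + 1 for i, c in enumerate('abcdefghij')}  # hypothetical char vocab
    max_len, embed_dim, out_channels, hidden_size, batch = 10, 15, 25, 64, 4
    kernels = [1, 2, 3]
    model = LM(word_vocab, char_vocab, max_len, embed_dim,
               out_channels, kernels, hidden_size, batch)
    # Initial LSTM state: (h0, c0), each of shape (num_layers=2, batch, hidden_size).
    h0 = (torch.zeros(2, batch, hidden_size),
          torch.zeros(2, batch, hidden_size))
    seq_len = 7
    x = torch.randint(1, len(char_vocab) + 1, (batch, seq_len, max_len))
    out, h = model(x, h0)
    print(out.shape)  # expected: (batch * seq_len, len(word_vocab)) == (28, 100)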