-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathtwoStreamModel.py
24 lines (21 loc) · 1.07 KB
/
twoStreamModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from flow_resnet import *
from objectAttentionModelConvLSTM import *
import torch.nn as nn
class twoStreamAttentionModel(nn.Module):
def __init__(self, flowModel='', frameModel='', stackSize=5, memSize=512, num_classes=61):
super(twoStreamAttentionModel, self).__init__()
self.flowModel = flow_resnet34(False, channels=2*stackSize, num_classes=num_classes)
if flowModel != '':
self.flowModel.load_state_dict(torch.load(flowModel))
self.frameModel = attentionModel(num_classes, memSize)
if frameModel != '':
self.frameModel.load_state_dict(torch.load(frameModel))
self.fc2 = nn.Linear(512 * 2, num_classes, bias=True)
self.dropout = nn.Dropout(0.5)
self.classifier = nn.Sequential(self.dropout, self.fc2)
def forward(self, inputVariableFlow, inputVariableFrame):
_, flowFeats = self.flowModel(inputVariableFlow)
_, rgbFeats = self.frameModel(inputVariableFrame)
twoStreamFeats = torch.cat((flowFeats, rgbFeats), 1)
return self.classifier(twoStreamFeats)