-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathConvOffset.py
112 lines (105 loc) · 5.19 KB
/
ConvOffset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
import torch.nn as nn
import torch.nn.functional as F
class Conv2D_Offset(nn.Module):
    """Sample every input channel 9 times with ``F.grid_sample`` and tile the result.

    For an NCHW input of shape (B, C, H, W), input channel ``c`` is sampled
    9 times (intermediate channels ``9*c .. 9*c + 8``) using the grid returned
    by :meth:`getgrid`. The (B, 9*C, H, W) result is moved to channels-last and
    reshaped to (B, C, 3*H, 3*W).

    The base grid from :meth:`getgrid` points every sample at its own pixel,
    so all 9 samples are copies of the input. Presumably per-sample offsets
    are meant to be added to this grid later — TODO confirm.

    NOTE(review): this module defines ``Conv2D_Offset`` more than once; the
    last definition is the one bound at import time.
    """

    def __init__(self, mode='bilinear', padding_mode='border', align_corners=True, device="cuda:7"):
        """
        :param mode: interpolation mode forwarded to ``F.grid_sample``.
        :param padding_mode: padding mode forwarded to ``F.grid_sample``.
        :param align_corners: ``align_corners`` flag forwarded to
            ``F.grid_sample``; it also selects how :meth:`getgrid` normalises
            pixel indices.
        :param device: default device for :meth:`getgrid` when it is called
            directly. ``forward`` always builds its grid on the input's device,
            because ``F.grid_sample`` needs input and grid on the same device.
        """
        super().__init__()
        self.mode = mode
        self.padding_mode = padding_mode
        self.align_corners = align_corners
        self.device = device

    @property
    def paading_mode(self):
        """Misspelled alias of :attr:`padding_mode`, kept for backward compatibility."""
        return self.padding_mode

    @paading_mode.setter
    def paading_mode(self, value):
        self.padding_mode = value

    def _normalize(self, coords, size):
        """Map pixel indices ``0 .. size-1`` onto ``F.grid_sample``'s [-1, 1] range."""
        if not self.align_corners:
            # -1 / +1 are the outer edges of the first / last pixel.
            return (2 * coords + 1) / size - 1
        if size == 1:
            # Any value un-normalises to pixel 0 here; avoid dividing by zero.
            return torch.zeros_like(coords)
        # -1 / +1 are the centres of the first / last pixel.
        return coords * (2.0 / (size - 1)) - 1

    def getgrid(self, row, column, dimension, batch_size, device=None, dtype=None):
        """Build the base (identity) sampling grid for ``F.grid_sample``.

        :param row: input height H.
        :param column: input width W.
        :param dimension: number of input channels C.
        :param batch_size: batch size B.
        :param device: device for the grid; defaults to ``self.device``.
        :param dtype: floating dtype for the grid; defaults to torch's default dtype.
        :return: tensor of shape (B, 18*C, H, W). Even channels hold the
            normalised column (x) coordinate of each pixel and odd channels the
            normalised row (y) coordinate, so channels ``(2k, 2k+1)`` are the
            (x, y) grid for sample ``k``.
        """
        device = self.device if device is None else device
        dtype = torch.get_default_dtype() if dtype is None else dtype
        xs = self._normalize(torch.arange(column, device=device, dtype=dtype), column)
        ys = self._normalize(torch.arange(row, device=device, dtype=dtype), row)
        # (2, H, W): plane 0 varies along the width, plane 1 along the height.
        base = torch.stack((xs.expand(row, column), ys.unsqueeze(1).expand(row, column)))
        # Repeat the (x, y) pair for all 9*C samples of every batch element.
        return base.repeat(batch_size, 9 * dimension, 1, 1)

    def forward(self, x):
        """Sample each channel of ``x`` 9 times and tile the samples spatially.

        :param x: floating tensor of shape (B, C, H, W).
        :return: tensor of shape (B, C, 3*H, 3*W).
        :raises ValueError: if ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(f"expected a 4-D NCHW tensor, got shape {tuple(x.shape)}")
        batch_size, dimension, row, column = x.shape
        n_samples = 9 * dimension
        # Build the grid on x's device and dtype: F.grid_sample requires both to match.
        grid = self.getgrid(row, column, dimension, batch_size, device=x.device, dtype=x.dtype)
        # Sample k reads input channel k // 9. Fold (batch, sample) into
        # grid_sample's batch axis so one call performs all B*9*C samplings.
        inputs = x.repeat_interleave(9, dim=1).reshape(batch_size * n_samples, 1, row, column)
        # Grid channels (2k, 2k+1) become the trailing (x, y) axis of sample k.
        grids = (
            grid.reshape(batch_size, n_samples, 2, row, column)
            .permute(0, 1, 3, 4, 2)
            .reshape(batch_size * n_samples, row, column, 2)
        )
        sampled = F.grid_sample(
            inputs,
            grids,
            mode=self.mode,
            padding_mode=self.padding_mode,
            align_corners=self.align_corners,
        )
        offset_output = sampled.reshape(batch_size, n_samples, row, column)
        # Channels-last, then flatten into a (B, C, 3H, 3W) map.
        # NOTE(review): this flattening does not in general arrange the 9
        # samples of a pixel as a 3x3 tile (e.g. for C == 1 they are consecutive
        # in row-major order of the 3H x 3W map). Confirm whether a
        # (H, W, 3, 3) -> (H, 3, W, 3) tile layout was intended.
        offset_output = offset_output.permute(0, 2, 3, 1)
        return torch.reshape(offset_output, (batch_size, -1, 3 * row, 3 * column))
# b = Conv2D_Offset()
# a = torch.tensor([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18],dtype=torch.float)
# a = torch.reshape(a,(1,2,3,3))
# b.forward(a)
class Conv2D_Offset(nn.Module):
    """Sample every input channel 9 times with ``F.grid_sample`` and tile the result.

    For an NCHW input of shape (B, C, H, W), input channel ``c`` is sampled
    9 times (intermediate channels ``9*c .. 9*c + 8``) using the grid returned
    by :meth:`getgrid`. The (B, 9*C, H, W) result is moved to channels-last and
    reshaped to (B, C, 3*H, 3*W).

    The base grid from :meth:`getgrid` points every sample at its own pixel,
    so all 9 samples are copies of the input. Presumably per-sample offsets
    are meant to be added to this grid later — TODO confirm.

    NOTE(review): this module defines ``Conv2D_Offset`` more than once; the
    last definition is the one bound at import time.
    """

    def __init__(self, mode='bilinear', padding_mode='border', align_corners=True, device="cuda:7"):
        """
        :param mode: interpolation mode forwarded to ``F.grid_sample``.
        :param padding_mode: padding mode forwarded to ``F.grid_sample``.
        :param align_corners: ``align_corners`` flag forwarded to
            ``F.grid_sample``; it also selects how :meth:`getgrid` normalises
            pixel indices.
        :param device: default device for :meth:`getgrid` when it is called
            directly. ``forward`` always builds its grid on the input's device,
            because ``F.grid_sample`` needs input and grid on the same device.
        """
        super().__init__()
        self.mode = mode
        self.padding_mode = padding_mode
        self.align_corners = align_corners
        self.device = device

    @property
    def paading_mode(self):
        """Misspelled alias of :attr:`padding_mode`, kept for backward compatibility."""
        return self.padding_mode

    @paading_mode.setter
    def paading_mode(self, value):
        self.padding_mode = value

    def _normalize(self, coords, size):
        """Map pixel indices ``0 .. size-1`` onto ``F.grid_sample``'s [-1, 1] range."""
        if not self.align_corners:
            # -1 / +1 are the outer edges of the first / last pixel.
            return (2 * coords + 1) / size - 1
        if size == 1:
            # Any value un-normalises to pixel 0 here; avoid dividing by zero.
            return torch.zeros_like(coords)
        # -1 / +1 are the centres of the first / last pixel.
        return coords * (2.0 / (size - 1)) - 1

    def getgrid(self, row, column, dimension, batch_size, device=None, dtype=None):
        """Build the base (identity) sampling grid for ``F.grid_sample``.

        :param row: input height H.
        :param column: input width W.
        :param dimension: number of input channels C.
        :param batch_size: batch size B.
        :param device: device for the grid; defaults to ``self.device``.
        :param dtype: floating dtype for the grid; defaults to torch's default dtype.
        :return: tensor of shape (B, 18*C, H, W). Even channels hold the
            normalised column (x) coordinate of each pixel and odd channels the
            normalised row (y) coordinate, so channels ``(2k, 2k+1)`` are the
            (x, y) grid for sample ``k``.
        """
        device = self.device if device is None else device
        dtype = torch.get_default_dtype() if dtype is None else dtype
        xs = self._normalize(torch.arange(column, device=device, dtype=dtype), column)
        ys = self._normalize(torch.arange(row, device=device, dtype=dtype), row)
        # (2, H, W): plane 0 varies along the width, plane 1 along the height.
        base = torch.stack((xs.expand(row, column), ys.unsqueeze(1).expand(row, column)))
        # Repeat the (x, y) pair for all 9*C samples of every batch element.
        return base.repeat(batch_size, 9 * dimension, 1, 1)

    def forward(self, x):
        """Sample each channel of ``x`` 9 times and tile the samples spatially.

        :param x: floating tensor of shape (B, C, H, W).
        :return: tensor of shape (B, C, 3*H, 3*W).
        :raises ValueError: if ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(f"expected a 4-D NCHW tensor, got shape {tuple(x.shape)}")
        batch_size, dimension, row, column = x.shape
        n_samples = 9 * dimension
        # Build the grid on x's device and dtype: F.grid_sample requires both to match.
        grid = self.getgrid(row, column, dimension, batch_size, device=x.device, dtype=x.dtype)
        # Sample k reads input channel k // 9. Fold (batch, sample) into
        # grid_sample's batch axis so one call performs all B*9*C samplings.
        inputs = x.repeat_interleave(9, dim=1).reshape(batch_size * n_samples, 1, row, column)
        # Grid channels (2k, 2k+1) become the trailing (x, y) axis of sample k.
        grids = (
            grid.reshape(batch_size, n_samples, 2, row, column)
            .permute(0, 1, 3, 4, 2)
            .reshape(batch_size * n_samples, row, column, 2)
        )
        sampled = F.grid_sample(
            inputs,
            grids,
            mode=self.mode,
            padding_mode=self.padding_mode,
            align_corners=self.align_corners,
        )
        offset_output = sampled.reshape(batch_size, n_samples, row, column)
        # Channels-last, then flatten into a (B, C, 3H, 3W) map.
        # NOTE(review): this flattening does not in general arrange the 9
        # samples of a pixel as a 3x3 tile (e.g. for C == 1 they are consecutive
        # in row-major order of the 3H x 3W map). Confirm whether a
        # (H, W, 3, 3) -> (H, 3, W, 3) tile layout was intended.
        offset_output = offset_output.permute(0, 2, 3, 1)
        return torch.reshape(offset_output, (batch_size, -1, 3 * row, 3 * column))