Depthwise-STFT卷积复现

发布 : 2020-02-23 分类 : 深度学习 浏览 :

image.png
《Depthwise-STFT based separable Convolutional Neural Networks》论文复现。该论文提出使用STFT替换Depthwise结构中的卷积层,以达到提速的效果。这里未经优化地复现了文中提到的Depthwise-STFT卷积操作(卷积操作部分还可以优化,预计计算量可减少约8倍);由于未做优化,速度并没有得到明显的提升。

但是该文提供的改进思路值得学习。

  • 改进思路

image.png

  • 论文中达到的效果:

image.png

Depthwise-STFT Conv

卷积操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
import cv2


class STFTConv(nn.Module):
    """Fixed (non-trainable) depthwise 2-D STFT convolution.

    Every input channel is convolved independently (``groups=in_c``) with the
    real and imaginary parts of a local STFT kernel evaluated at four
    frequency points ``V = [(a, 0), (0, a), (a, a), (a, -a)]`` with
    ``a = 1 / kernel_size[0]``. The absolute values of the 8 responses are
    concatenated along the channel axis in the order
    ``[v1 real, v1 imag, v2 real, v2 imag, v3 real, v3 imag, v4 real, v4 imag]``,
    each part holding ``out_c`` channels (so ``8 * out_c`` in total).

    Args:
        in_c: number of input channels.
        out_c: channels per STFT part; must be a multiple of ``in_c``. Each
            input channel feeds ``out_c // in_c`` identical output channels.
        kernel_size: ``(kh, kw)`` tuple, or an int for a square kernel.
        stride: forwarded to ``F.conv2d``.
        padding: forwarded to ``F.conv2d``.

    Raises:
        ValueError: if ``out_c`` is not a multiple of ``in_c`` or
            ``kernel_size`` does not have exactly two entries.
    """

    def __init__(self, in_c, out_c, kernel_size, stride=1, padding=0):
        super(STFTConv, self).__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if out_c % in_c != 0:
            raise ValueError(
                f"out_c ({out_c}) must be a multiple of in_c ({in_c}) "
                "for a depthwise STFT convolution"
            )
        self.in_c = in_c
        self.out_c = out_c
        self.stride = stride
        self.padding = padding
        self.kernel_size = kernel_size

        # (kh, kw, 2, 1) grid of 0-based (row, col) coordinates.
        self.Y = self.define_Y(kernel_size)
        # Frequency step uses only the first kernel dimension, as before;
        # for non-square kernels this is the row size.
        a = 1 / kernel_size[0]
        self.V = np.array([[a, 0], [0, a], [a, a], [a, -a]])

        # Registers w1_r, w1_i, ..., w4_r, w4_i (same names and order as
        # before) as frozen parameters of shape (out_c, 1, kh, kw): the
        # per-group weight layout F.conv2d expects with groups=in_c.
        for idx, v in enumerate(self.V, start=1):
            real, imag = self.kernel_fn(v, self.Y)
            setattr(self, f"w{idx}_r", self._frozen_kernel(real, out_c))
            setattr(self, f"w{idx}_i", self._frozen_kernel(imag, out_c))

    def forward(self, X):
        """Return ``|STFT parts|`` of ``X`` concatenated along dim 1."""
        weights = (
            self.w1_r, self.w1_i, self.w2_r, self.w2_i,
            self.w3_r, self.w3_i, self.w4_r, self.w4_i,
        )
        # groups=in_c keeps the transform depthwise: each channel is filtered
        # only with its own kernels instead of being summed with the others.
        responses = [
            F.conv2d(X, w, stride=self.stride, padding=self.padding, groups=self.in_c)
            for w in weights
        ]
        return torch.abs(torch.cat(responses, dim=1))

    @staticmethod
    def define_Y(kernel_size):
        """Return the ``(rows, cols, 2, 1)`` array with ``Y[i, j] = [[i], [j]]``.

        Coordinates are 0-based.

        Raises:
            ValueError: if ``kernel_size`` does not have exactly two entries.
        """
        if len(kernel_size) != 2:
            raise ValueError(f"kernel_size must be 2-D (kh, kw), got {kernel_size!r}")
        rows, cols = kernel_size
        ii, jj = np.meshgrid(np.arange(rows), np.arange(cols), indexing="ij")
        # Trailing axis of size 1 makes each entry a column vector.
        return np.stack((ii, jj), axis=-1)[..., np.newaxis]

    def kernel_fn(self, v, Y):
        """Return ``(cos(2*pi*v.y), -sin(2*pi*v.y))`` over the coordinate grid ``Y``.

        ``v`` is a length-2 frequency vector. ``Y`` is the ``(kh, kw, 2, 1)``
        array from ``define_Y``. Both outputs have shape ``(kh, kw)``.
        """
        w = v.dot(Y)
        w = np.squeeze(w, axis=2)
        return np.cos(2 * np.pi * w), -np.sin(2 * np.pi * w)

    @staticmethod
    def _frozen_kernel(kernel, out_c):
        """Wrap a 2-D kernel as a frozen ``(out_c, 1, kh, kw)`` float32 parameter."""
        tensor = torch.as_tensor(kernel, dtype=torch.float32).repeat(out_c, 1, 1, 1)
        return nn.Parameter(tensor, requires_grad=False)


def _get_stft_kernels(size, v):
assert len(size) % 2 == 0 # 2D
h, w = size
Y = []
for i in range(w):
yi = []
for j in range(h):
yi.append([[i], [j]]) # 列向量
Y.append(yi)
Y = np.array(Y) + 1 # 从1开始

def kernel_fn():
w = v.dot(Y)
w = np.squeeze(w, axis=2)
return np.cos(2 * np.pi * w), -np.sin(2 * np.pi * w)

return kernel_fn()

Depthwise-STFT

image.png

1
2
3
4
5
6
7
8
9
class DepthwiseSTFT(nn.Module):
    """Depthwise-STFT separable layer: a fixed STFTConv, then a 1x1 conv.

    STFTConv emits 8 maps (real/imaginary parts at four frequency points) per
    channel, so the trainable pointwise convolution consumes ``8 * in_c``
    channels and projects them to ``out_c``.
    """

    def __init__(self, in_c, out_c, kernel_size=(3, 3), stride=1, padding=0):
        super().__init__()
        self.stft = STFTConv(in_c, in_c, kernel_size, stride=stride, padding=padding)
        stft_channels = 8 * in_c
        self.conv = nn.Conv2d(stft_channels, out_c, kernel_size=(1, 1), stride=1)

    def forward(self, X):
        return self.conv(self.stft(X))

Block

image.png

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

class Block1(nn.Module):
    """Two parallel Depthwise-STFT branches (3x3 and 5x5) fused by a 1x1 conv.

    Each branch maps ``in_c`` to ``in_c`` channels. With the default stride of
    1, the paddings (1 and 2) keep the spatial size. The ``2 * in_c``
    concatenated channels are projected to ``out_c``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.stft_3 = DepthwiseSTFT(in_c, in_c, kernel_size=(3, 3), padding=1)
        self.stft_5 = DepthwiseSTFT(in_c, in_c, kernel_size=(5, 5), padding=2)
        self.conv = nn.Conv2d(2 * in_c, out_c, kernel_size=(1, 1))

    def forward(self, X):
        branches = [self.stft_3(X), self.stft_5(X)]
        return self.conv(torch.cat(branches, dim=1))


class Block2(nn.Module):
    """Bottleneck Depthwise-STFT block with a parallel 1x1 projection branch.

    ``X`` is reduced by a 1x1 conv to ``b`` channels and fed to 3x3 and 5x5
    Depthwise-STFT branches. In parallel, a 1x1 conv maps ``X`` to ``out_c``
    channels. The three results (``2 * b + out_c`` channels) are concatenated
    and fused to ``out_c`` by a final 1x1 conv.

    Args:
        in_c: input channels.
        b: bottleneck width used by the STFT branches.
        out_c: output channels.
    """

    def __init__(self, in_c, b, out_c):
        super().__init__()
        self.b = b
        self.conv1 = nn.Conv2d(in_c, b, kernel_size=(1, 1))
        self.stft_3 = DepthwiseSTFT(b, b, kernel_size=(3, 3), padding=1)
        self.stft_5 = DepthwiseSTFT(b, b, kernel_size=(5, 5), padding=2)
        self.conv2 = nn.Conv2d(in_c, out_c, kernel_size=(1, 1))
        self.conv = nn.Conv2d(2 * b + out_c, out_c, kernel_size=(1, 1))

    def forward(self, X):
        reduced = self.conv1(X)
        projected = self.conv2(X)
        merged = torch.cat(
            (self.stft_3(reduced), self.stft_5(reduced), projected), dim=1
        )
        return self.conv(merged)

网络整体架构

image.png

本文作者 : HeoLis
原文链接 : http://ishero.net/Depthwise-STFT%E5%8D%B7%E7%A7%AF%E5%A4%8D%E7%8E%B0.html
版权声明 : 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明出处!

学习、记录、分享、获得

微信扫一扫, 向我投食

微信扫一扫, 向我投食