Principle

What the convolution computes:
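
$$y = W \ast x + b$$

where $W$ is the convolution kernel, $b$ is the (optional) convolution bias, and $x$ is the input.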

What BatchNorm computes (at inference time, using the running statistics):
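
$$\mathrm{BN}(y) = \gamma \cdot \frac{y - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where $\mu$ and $\sigma^2$ are the running mean and variance collected during training, $\gamma$ and $\beta$ are the learned affine parameters (bn.weight and bn.bias), and $\epsilon$ is the numerical-stability constant.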

Substituting the convolution into BN:
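
$$\mathrm{BN}(W \ast x + b) = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \cdot (W \ast x) + \frac{\gamma\,(b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta$$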

The new convolution after fusion:
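
$$W_{\mathrm{fused}} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \cdot W, \qquad b_{\mathrm{fused}} = \frac{\gamma\,(b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

so that $\mathrm{BN}(\mathrm{conv}(x)) = W_{\mathrm{fused}} \ast x + b_{\mathrm{fused}}$: at inference time the Conv + BN pair collapses into a single convolution with a bias, and the BN layer can be dropped.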

Code

import torch
import torch.nn as nn
from time import time
from copy import deepcopy

torch.manual_seed(1)
torch.cuda.manual_seed(1)

# Identity placeholder that takes the place of a removed BN layer,
# so the parent module's structure and forward calls stay intact.
class DummyModule(nn.Module):
    def __init__(self):
        super(DummyModule, self).__init__()

    def forward(self, x):
        return x

# Core function: fuse a Conv2d and the BatchNorm2d that follows it into a single Conv2d.
def fuse(conv, bn):
    w = conv.weight
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)

    gamma = bn.weight  # BN scale
    beta = bn.bias     # BN shift

    if conv.bias is not None:
        b = conv.bias
    else:
        b = mean.new_zeros(mean.shape)

    # W_fused = gamma / sqrt(var + eps) * W  (scaled per output channel)
    w = w * (gamma / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
    # b_fused = gamma * (b - mean) / sqrt(var + eps) + beta
    b = (b - mean) / var_sqrt * gamma + beta

    fused_conv = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        conv.stride,
        conv.padding,
        conv.dilation,
        conv.groups,
        bias=True,
        padding_mode=conv.padding_mode
    )
    fused_conv.weight = nn.Parameter(w.detach())
    fused_conv.bias = nn.Parameter(b.detach())
    return fused_conv

# Fuse every Conv2d + BatchNorm2d pair in the whole network, in place.
def fuse_module(m):
    children = list(m.named_children())
    conv = None
    conv_name = None
    # Walk over the direct children of this module.
    for name, child in children:
        # If this layer is a BN and the layer before it was a conv, fuse the two.
        if isinstance(child, nn.BatchNorm2d) and conv is not None:
            bc = fuse(conv, child)
            m._modules[conv_name] = bc
            m._modules[name] = DummyModule()
            conv = None
        # If this layer is a conv, remember it.
        elif isinstance(child, nn.Conv2d):
            conv = child
            conv_name = name
        else:
            # Recurse into submodules.
            fuse_module(child)

# Compare the fused network against the original: timing and maximum output difference.
def validate(net, cuda=torch.cuda.is_available()):
    net.eval()
    fused_net = deepcopy(net)
    fused_net.eval()
    fuse_module(fused_net)

    error = 0
    origin_time = 0
    fused_time = 0

    if cuda:
        net.cuda()
        fused_net.cuda()
    n = 1
    with torch.no_grad():
        for _ in range(n):
            x = torch.randn(size=(32, 3, 224, 224))
            if cuda:
                x = x.cuda()
                torch.cuda.synchronize()

            start = time()
            out_origin = net(x)
            if cuda:
                torch.cuda.synchronize()
            end = time()
            origin_time += end - start

            if cuda:
                torch.cuda.synchronize()
            start = time()
            out_fused = fused_net(x)
            if cuda:
                torch.cuda.synchronize()
            end = time()
            fused_time += end - start

            error += (out_origin - out_fused).abs().max().item()
    print(f"origin time: {origin_time / n}s fused time: {fused_time / n}s error: {error / n}")

if __name__ == '__main__':
    import torchvision
    net = torchvision.models.mobilenet_v2(pretrained=True)
    net.eval()
    validate(net, cuda=False)
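
For deployment, the same fuse_module can be applied once to a trained model before it is saved or exported. A minimal sketch (the resnet18 model and the file name are only illustrative):

import torch
import torchvision

# Fusion relies on the BN running statistics, so the model must be in eval() mode.
model = torchvision.models.resnet18(pretrained=True).eval()
fuse_module(model)  # fuse_module as defined above

# The fused model is now a plain convolutional network; save it as usual.
torch.save(model.state_dict(), "resnet18_fused.pth")  # illustrative file name

Note that the fused state_dict no longer matches the original architecture (the BN parameters are gone and every fused conv gains a bias), so it has to be loaded into a model that has also been passed through fuse_module.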

Test

This fusion is essentially a free win: in theory the output is unchanged (in practice there is a tiny numerical difference from floating-point rounding, but it is negligible), and inference speed improves by roughly 20%-30%, especially for networks with many BN layers.

The timing results on CPU are as follows (in seconds):

ResNet       Before    After
ResNet18     0.088     0.076
ResNet34     0.157     0.124
ResNet50     0.275     0.185
ResNet152    0.728     0.406

References

PyTorch 卷积与BatchNorm的融合 - 知乎 (zhihu.com)