Principle

A convolution with weight $W$ and bias $b$ computes:

$$y = W * x + b$$

At inference time, BatchNorm applies a fixed affine transform using the running mean $\mu$, running variance $\sigma^2$, learned scale $\gamma$, and learned shift $\beta$:

$$\hat{y} = \gamma \cdot \frac{y - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

Substituting the convolution into BN:

$$\hat{y} = \gamma \cdot \frac{W * x + b - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

The fused convolution therefore uses, per output channel:

$$W_{\text{fused}} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \cdot W, \qquad b_{\text{fused}} = \frac{\gamma\,(b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

so that $\hat{y} = W_{\text{fused}} * x + b_{\text{fused}}$.
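As a sanity check of these formulas, here is a minimal sketch (layer shapes are arbitrary choices) that fuses a single Conv2d + BatchNorm2d pair by hand and compares the outputs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
bn = nn.BatchNorm2d(8)
bn.eval()  # inference mode: BN uses its running statistics

# Give BN non-trivial statistics so the check is meaningful.
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 1.5)

with torch.no_grad():
    std = torch.sqrt(bn.running_var + bn.eps)
    fused = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
    # W_fused = gamma / sqrt(var + eps) * W, per output channel
    fused.weight.copy_(conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1))
    # b_fused = gamma * (b - mu) / sqrt(var + eps) + beta
    fused.bias.copy_((conv.bias - bn.running_mean) / std * bn.weight + bn.bias)

    x = torch.randn(1, 3, 16, 16)
    print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # expected: True
```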
Code
```python
import torch
import torch.nn as nn
from time import time
from copy import deepcopy

torch.manual_seed(1)
torch.cuda.manual_seed(1)


class DummyModule(nn.Module):
    """Placeholder that replaces a fused-away BatchNorm layer
    (on PyTorch >= 1.1, nn.Identity works just as well)."""

    def __init__(self):
        super(DummyModule, self).__init__()

    def forward(self, x):
        return x


def fuse(conv, bn):
    """Fold a BatchNorm2d (using its running statistics) into the
    preceding Conv2d and return the fused convolution."""
    w = conv.weight
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)
    gamma = bn.weight  # BN scale
    beta = bn.bias     # BN shift
    if conv.bias is not None:
        b = conv.bias
    else:
        b = mean.new_zeros(mean.shape)
    # W_fused = gamma / sqrt(var + eps) * W, scaled per output channel
    w = w * (gamma / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
    # b_fused = gamma * (b - mean) / sqrt(var + eps) + beta
    b = (b - mean) / var_sqrt * gamma + beta
    fused_conv = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        conv.stride,
        conv.padding,
        conv.dilation,
        conv.groups,
        bias=True,
        padding_mode=conv.padding_mode,
    )
    fused_conv.weight = nn.Parameter(w)
    fused_conv.bias = nn.Parameter(b)
    return fused_conv


def fuse_module(m):
    """Recursively fuse every Conv2d that is immediately followed by a
    BatchNorm2d within the same container."""
    children = list(m.named_children())
    conv = None
    conv_name = None
    for name, child in children:
        if isinstance(child, nn.BatchNorm2d) and conv is not None:
            m._modules[conv_name] = fuse(conv, child)
            m._modules[name] = DummyModule()
            conv = None
        elif isinstance(child, nn.Conv2d):
            conv = child
            conv_name = name
        else:
            # Any other layer breaks Conv->BN adjacency; fusing across
            # it would change the math, so reset before recursing.
            conv = None
            fuse_module(child)


def validate(net, cuda=torch.cuda.is_available()):
    net.eval()
    fused_net = deepcopy(net)
    fused_net.eval()
    fuse_module(fused_net)

    error = 0
    origin_time = 0
    fused_time = 0

    if cuda:
        net.cuda()
        fused_net.cuda()
    n = 1
    with torch.no_grad():
        for _ in range(n):
            x = torch.randn(size=(32, 3, 224, 224))
            if cuda:
                x = x.cuda()
                torch.cuda.synchronize()
            start = time()
            out_origin = net(x)
            if cuda:
                torch.cuda.synchronize()
            origin_time += time() - start

            if cuda:
                torch.cuda.synchronize()
            start = time()
            out_fused = fused_net(x)
            if cuda:
                torch.cuda.synchronize()
            fused_time += time() - start

            error += (out_origin - out_fused).abs().max().item()
    print(f"origin time: {origin_time / n}s  fused time: {fused_time / n}s  error: {error / n}")


if __name__ == '__main__':
    import torchvision

    net = torchvision.models.mobilenet_v2(pretrained=True)
    net.eval()
    validate(net, cuda=False)
```
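PyTorch also ships a built-in fuser that performs the same folding for explicitly named Conv-BN pairs. A minimal sketch using `torch.quantization.fuse_modules` (the `['conv1', 'bn1']` names below follow torchvision's ResNet layout):

```python
import torch
import torchvision

# fuse_modules requires eval mode so BN is folded using running stats.
net = torchvision.models.resnet18(pretrained=True).eval()
fused_net = torch.quantization.fuse_modules(net, [['conv1', 'bn1']])

# The outputs should agree up to a tiny numerical difference.
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    print((net(x) - fused_net(x)).abs().max().item())
```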
Benchmark

This fusion is a pure win: accuracy is theoretically unchanged (in practice there is a tiny numerical difference), and inference gets 20%-30% faster, especially for networks with many BN layers.

CPU timing results are below:
| Model | Before (s) | After (s) |
| --- | --- | --- |
| ResNet18 | 0.088 | 0.076 |
| ResNet34 | 0.157 | 0.124 |
| ResNet50 | 0.275 | 0.185 |
| ResNet152 | 0.728 | 0.406 |
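For reference, a sketch of how such a table can be reproduced with the `validate()` helper from the code section above (absolute timings are hardware-dependent, so only the relative gain is meaningful):

```python
import torchvision

# Assumes validate() from the code section above is in scope.
for name in ['resnet18', 'resnet34', 'resnet50', 'resnet152']:
    net = getattr(torchvision.models, name)(pretrained=True)
    net.eval()
    print(name)
    validate(net, cuda=False)
```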
References

PyTorch 卷积与BatchNorm的融合 - 知乎 (zhihu.com)