模型层融合conv与bn
本文主要介绍了卷积层 conv 和 批归一化层 batch normalization 融合的原理。
1. 批归一化 Batch Normalization
批归一化(Batch Normalization)因其可以加速神经网络训练、使网络训练更稳定,而且还有一定的正则化效果,所以得到了非常广泛的应用。但是,在推理阶段,BN层一般是可以完全融合到前面的卷积层的,而且丝毫不影响性能。
Batch Normalization 的思想非常简单,一句话概括就是,对一个神经元(或者一个卷积核)的输出减去统计得到的均值除以标准差,然后乘以一个可学习的系数,再加上一个偏置,这个过程就完成了。
在训练过程中, 主要执行如下四个步骤:
其中 $\epsilon$ 为一个非常小的常数, 例如 0.0001, 主要是为了避免除零错误。 而 $\gamma$ 和 $\beta$ 则是可学习参数, 在训练过程中,和其他卷积核的参数一样, 通过梯度下降来学习。 在训练过程中,为保持稳定,一般使用滑动平均法更新均值和方差,滑动平均就是在更新当前值的时候,以一定比例保存之前的数值,以均值 $\mu$ 为例,以一定比例 $\theta$ (例如这里0.99)保存之前的均值,当前只更新 $(1-\theta)$ 倍(也就是0.001倍)的本Batch 的均值,计算方法如下:
标准差的滑动平均计算方法也一样。
在推理(测试)阶段,我们不太会对一个batch图像进行预测,一般是对单张图像测试。因此,通过前面公式计算 $\mu$ 和 $\sigma$ 就不可能。其实对于预测阶段时所使用的均值和方差,其实也是来源于训练集。比如我们在模型训练时我们就记录下每个 batch 下的均值和方差,待训练完毕后,我们求整个训练样本的均值和方差期望值(滑动均值和方差),作为我们进行预测时进行BN的的均值和方差。
具体的 batchnorm 的一维 python 实现可以参考如下代码:
import numpy as np
def batchnorm_forward(x, gamma, beta, bn_param):
mode = bn_param['mode']
eps = bn_param.get('eps', 1e-5)
momentum = bn_param.get('momentum', 0.9)
N, D = s.shape # N is batch_size * H * W, D is channels
running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))
out, cahce = None, None
if mode == 'train':
batch_mean = np.mean(x, axis=0, keepdims=True)
batch_var = np.var(x, axis=0, keepdims=True)
x_norm = (inp - batch_mean) / np.sqrt(batch_var + eps)
out = x_norm * gamma + beta
# store variables in cache
cache = (x, x_norm, gamma, beta, eps, batch_mean, batch_var)
# update running_mean & running var
running_mean = momentum * running_mean + (1 - momentum) * batch_mean
running_var = momentum * running_var + (1 - momentum) * batch_var
elif mode == 'test':
x_norm = (x - running_mean) / np.sqrt(running_var + eps)
out = x_norm * gamma + beta
bn_param['running_mean'] = running_mean
bn_param['running_var'] = running_var
return out, cache
2. conv 与 Batchnorm 的融合
网络完成训练后,在 inference 阶段,为了加速运算,通常将卷积层和BN层进行融合:
(1)卷积层可以抽象为如下公式:
(2)bn 层可以抽象为如下公式:
其中 $\mu$ 和 $\sigma$ 是整个训练集的均值和方差(滑动), 这是两个常量。$\gamma$ 和 $\beta$ 是学习完成的参数, 也是两个常量。
(3)融合两层:将 conv 层的公式带入到 BN 层的公式
融合后相当于:
通过公式可以看出, 我们可以将 Batch Normalization 层融合到卷积层中,相当于对卷积核进行一定的修改,没有增加卷积的计算量,同时整个 Batch Normalization 层的计算量都省去了。
Pytorch 提供了相关的 conv 和 bn 融合的代码: https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/fusion.py
def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
if conv_b is None:
conv_b = torch.zeros_like(bn_rm)
if bn_w is None:
bn_w = torch.ones_like(bn_rm)
if bn_b is None:
bn_b = torch.zeros_like(bn_rm)
bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) # 注意一下,这里直接取了倒数 rsqrt
conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)
3. 实际测试
In [1]: import torch
In [2]: import torchvision
In [3]: from torch.nn.utils.fusion import fuse_conv_bn_weights
In [4]: resnet18 = torchvision.models.resnet18(pretrained=True)
In [5]: resnet18.eval()
In [6]: conv_bn = torch.nn.Sequential(
...: resnet18.conv1,
...: resnet18.bn1
...: )
In [7]: bn_var, bn_mean = resnet18.bn1.running_var, resnet18.bn1.running_mean
In [8]: bn_eps = resnet18.bn1.eps
In [9]: bn_weight, bn_bias = resnet18.bn1.weight, resnet18.bn1.bias
In [10]: conv_weight, conv_bias = resnet18.conv1.weight, resnet18.conv1.bias
In [11]: fused_w, fused_b = fuse_conv_bn_weights(conv_weight, conv_bias, mean, var, eps, weight, bias)
In [12]: fused_conv = torch.nn.Conv2d(
...: resnet18.conv1.in_channels,
...: resnet18.conv1.out_channels,
...: kernel_size=resnet18.conv1.kernel_size,
...: stride=resnet18.conv1.stride,
...: padding=resnet18.conv1.padding,
...: bias=True
...: )
In [13]: fused.weight.copy_(fused_w)
In [14]: fused.bias.copy_(fused_b)
In [15]: torch.set_grad_enabled(False)
In [16]: x = torch.randn(16, 3, 256, 256)
In [17]: res1 = conv_bn.forward(x)
In [18]: res1
Out[18]:
tensor([[[[ 3.7646e-01, 7.3849e-01, -1.1688e-03, ..., 4.6396e-01,
5.1061e-01, 7.1462e-02],
[ 1.0769e+00, 1.2775e+00, 5.5155e-01, ..., 1.2121e+00,
-7.6443e-02, 4.3850e-01],
[ 5.2745e-02, -3.7510e-01, -3.4277e-01, ..., -2.5618e-01,
-7.3687e-01, -2.6337e-01],
...,
[ 3.1750e-01, 2.7688e-01, -7.8872e-01, ..., 4.0573e-01,
2.3306e-01, -6.2662e-01],
[ 7.0026e-01, -1.9506e-01, 5.6528e-01, ..., -1.3667e-01,
-5.3668e-03, 5.2011e-01],
[ 4.6626e-02, 5.6441e-01, 5.7992e-01, ..., -1.9146e-01,
3.9299e-01, 2.9972e-01]]]])
In [19]: res2 = fused(x)
In [20]: res2
Out[20]:
tensor([[[[ 3.7646e-01, 7.3850e-01, -1.1688e-03, ..., 4.6396e-01,
5.1061e-01, 7.1462e-02],
[ 1.0769e+00, 1.2775e+00, 5.5155e-01, ..., 1.2121e+00,
-7.6443e-02, 4.3850e-01],
[ 5.2745e-02, -3.7510e-01, -3.4277e-01, ..., -2.5618e-01,
-7.3687e-01, -2.6337e-01],
...,
[ 3.1750e-01, 2.7688e-01, -7.8872e-01, ..., 4.0573e-01,
2.3306e-01, -6.2662e-01],
[ 7.0026e-01, -1.9506e-01, 5.6528e-01, ..., -1.3667e-01,
-5.3669e-03, 5.2011e-01],
[ 4.6626e-02, 5.6441e-01, 5.7992e-01, ..., -1.9146e-01,
3.9299e-01, 2.9972e-01]]]])
运行代码,会发现融合 conv 和 bn 层之后推理结果是一样,所以是等效替换。另外也可以对比前后推理时间的差异,会发现融合后推理时间会减少。
本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!