模型层融合conv与bn

本文主要介绍了卷积层 conv 和 批归一化层 batch normalization 融合的原理。

1. 批归一化 Batch Normalization

​ 批归一化(Batch Normalization)因其可以加速神经网络训练、使网络训练更稳定,而且还有一定的正则化效果,所以得到了非常广泛的应用。但是,在推理阶段,BN层一般是可以完全融合到前面的卷积层的,而且丝毫不影响性能。

​ Batch Normalization 的思想非常简单,一句话概括就是,对一个神经元(或者一个卷积核)的输出减去统计得到的均值除以标准差,然后乘以一个可学习的系数,再加上一个偏置,这个过程就完成了。

​ 在训练过程中, 主要执行如下四个步骤:

其中 $\epsilon$ 为一个非常小的常数, 例如 0.0001, 主要是为了避免除零错误。 而 $\gamma$ 和 $\beta$ 则是可学习参数, 在训练过程中,和其他卷积核的参数一样, 通过梯度下降来学习。 在训练过程中,为保持稳定,一般使用滑动平均法更新均值和方差,滑动平均就是在更新当前值的时候,以一定比例保存之前的数值,以均值 $\mu$ 为例,以一定比例 $\theta$ (例如这里0.99)保存之前的均值,当前只更新 $(1-\theta)$ 倍(也就是0.001倍)的本Batch 的均值,计算方法如下:

标准差的滑动平均计算方法也一样。

​ 在推理(测试)阶段,我们不太会对一个batch图像进行预测,一般是对单张图像测试。因此,通过前面公式计算 $\mu$ 和 $\sigma$ 就不可能。其实对于预测阶段时所使用的均值和方差,其实也是来源于训练集。比如我们在模型训练时我们就记录下每个 batch 下的均值和方差,待训练完毕后,我们求整个训练样本的均值和方差期望值(滑动均值和方差),作为我们进行预测时进行BN的的均值和方差。

具体的 batchnorm 的一维 python 实现可以参考如下代码:

import numpy as np

def batchnorm_forward(x, gamma, beta, bn_param):
    mode = bn_param['mode']
    eps = bn_param.get('eps', 1e-5) 
    momentum = bn_param.get('momentum', 0.9)
    
    N, D = s.shape  # N is batch_size * H * W, D is channels
    running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
    running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))
    
    out, cahce = None, None
    if mode == 'train':    
        batch_mean = np.mean(x, axis=0, keepdims=True)
        batch_var = np.var(x, axis=0, keepdims=True) 
        x_norm = (inp - batch_mean) / np.sqrt(batch_var + eps)
        out = x_norm * gamma + beta
        # store variables in cache
        cache = (x, x_norm, gamma, beta, eps, batch_mean, batch_var)
        # update running_mean & running var
        running_mean = momentum * running_mean + (1 - momentum) * batch_mean
        running_var = momentum * running_var + (1 - momentum) * batch_var
    elif mode == 'test':
        x_norm = (x - running_mean) / np.sqrt(running_var + eps)
        out = x_norm * gamma + beta
        
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var
    return out, cache

2. conv 与 Batchnorm 的融合

​ 网络完成训练后,在 inference 阶段,为了加速运算,通常将卷积层和BN层进行融合:

(1)卷积层可以抽象为如下公式:

(2)bn 层可以抽象为如下公式:

其中 $\mu$ 和 $\sigma$ 是整个训练集的均值和方差(滑动), 这是两个常量。$\gamma$ 和 $\beta$ 是学习完成的参数, 也是两个常量。

(3)融合两层:将 conv 层的公式带入到 BN 层的公式

融合后相当于:

​ 通过公式可以看出, 我们可以将 Batch Normalization 层融合到卷积层中,相当于对卷积核进行一定的修改,没有增加卷积的计算量,同时整个 Batch Normalization 层的计算量都省去了。

Pytorch 提供了相关的 conv 和 bn 融合的代码: https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/fusion.py

def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)  # 注意一下,这里直接取了倒数 rsqrt

    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b

    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)

3. 实际测试

In [1]: import torch
In [2]: import torchvision
In [3]: from torch.nn.utils.fusion import fuse_conv_bn_weights

In [4]: resnet18 = torchvision.models.resnet18(pretrained=True)
In [5]: resnet18.eval()
In [6]: conv_bn = torch.nn.Sequential(
   ...:     resnet18.conv1, 
   ...:     resnet18.bn1 
   ...: )                                                                                                                         

In [7]: bn_var, bn_mean = resnet18.bn1.running_var, resnet18.bn1.running_mean
In [8]: bn_eps = resnet18.bn1.eps
In [9]: bn_weight, bn_bias = resnet18.bn1.weight, resnet18.bn1.bias
In [10]: conv_weight, conv_bias = resnet18.conv1.weight, resnet18.conv1.bias
In [11]: fused_w, fused_b = fuse_conv_bn_weights(conv_weight, conv_bias, mean, var, eps, weight, bias)
In [12]: fused_conv = torch.nn.Conv2d(
    ...:             resnet18.conv1.in_channels,
    ...:             resnet18.conv1.out_channels,
    ...:             kernel_size=resnet18.conv1.kernel_size,
    ...:             stride=resnet18.conv1.stride,
    ...:             padding=resnet18.conv1.padding,
    ...:             bias=True
    ...:         )

In [13]: fused.weight.copy_(fused_w)
In [14]: fused.bias.copy_(fused_b)

In [15]: torch.set_grad_enabled(False) 
In [16]: x = torch.randn(16, 3, 256, 256)
In [17]: res1 = conv_bn.forward(x)
In [18]: res1
Out[18]: 
tensor([[[[ 3.7646e-01,  7.3849e-01, -1.1688e-03,  ...,  4.6396e-01,
            5.1061e-01,  7.1462e-02],
          [ 1.0769e+00,  1.2775e+00,  5.5155e-01,  ...,  1.2121e+00,
           -7.6443e-02,  4.3850e-01],
          [ 5.2745e-02, -3.7510e-01, -3.4277e-01,  ..., -2.5618e-01,
           -7.3687e-01, -2.6337e-01],
          ...,
          [ 3.1750e-01,  2.7688e-01, -7.8872e-01,  ...,  4.0573e-01,
            2.3306e-01, -6.2662e-01],
          [ 7.0026e-01, -1.9506e-01,  5.6528e-01,  ..., -1.3667e-01,
           -5.3668e-03,  5.2011e-01],
          [ 4.6626e-02,  5.6441e-01,  5.7992e-01,  ..., -1.9146e-01,
            3.9299e-01,  2.9972e-01]]]])


In [19]: res2 = fused(x)
In [20]: res2
Out[20]: 
tensor([[[[ 3.7646e-01,  7.3850e-01, -1.1688e-03,  ...,  4.6396e-01,
            5.1061e-01,  7.1462e-02],
          [ 1.0769e+00,  1.2775e+00,  5.5155e-01,  ...,  1.2121e+00,
           -7.6443e-02,  4.3850e-01],
          [ 5.2745e-02, -3.7510e-01, -3.4277e-01,  ..., -2.5618e-01,
           -7.3687e-01, -2.6337e-01],
          ...,
          [ 3.1750e-01,  2.7688e-01, -7.8872e-01,  ...,  4.0573e-01,
            2.3306e-01, -6.2662e-01],
          [ 7.0026e-01, -1.9506e-01,  5.6528e-01,  ..., -1.3667e-01,
           -5.3669e-03,  5.2011e-01],
          [ 4.6626e-02,  5.6441e-01,  5.7992e-01,  ..., -1.9146e-01,
            3.9299e-01,  2.9972e-01]]]])

​ 运行代码,会发现融合 conv 和 bn 层之后推理结果是一样,所以是等效替换。另外也可以对比前后推理时间的差异,会发现融合后推理时间会减少。


本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!