火炬:'1.9.0+cu111'
Tensorflow-gpu:'2.5.0'
我遇到了一个奇怪的事情,当使用tensorflow 2.5的Batch Normal层和Pytorch 1.9的BatchNorm2d层计算同一个Tensor时,结果相差很大(TensorFlow接近1,Pytorch接近0)。一开始以为是momentum和epsilon的区别,后来改成一样,结果是一样的。
from torch import nn
import torch
x = torch.ones((20, 100, 35, 45))
a = nn.Sequential(
# nn.Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), padding=0, bias=True),
nn.BatchNorm2d(100)
)
b = a(x)
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers import *
x = tf.ones((20, 35, 45, 100))
a = keras.models.Sequential([
# Conv2D(128, (1, 1), (1, 1), padding='same', use_bias=True),
BatchNormalization()
])
b = a(x)
批量归一化在训练和推理中的工作方式不同,
在训练期间(即使用fit()
或调用带有参数的层/模型时training=True
),层使用当前输入批次的均值和标准差对其输出进行归一化。也就是说,对于每个被归一化的通道,该层返回
gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta
在哪里:
在推理过程中(即当使用evaluate()
或predict()
或当使用参数调用层/模型时training=False
(这是默认值),层使用它在训练期间看到的批次的均值和标准差的移动平均值对其输出进行归一化。即说,它返回
gamma * (batch - self.moving_mean) / sqrt(self.moving_var + epsilon) + beta.
self.moving_mean
和self.moving_var
是不可训练的变量,每次在训练模式下调用层时都会更新,例如:
moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)
moving_var = moving_var * momentum + var(batch) * (1 - momentum)
参考:https : //www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization
如果你在eval
模式下运行 pytorch batchnorm ,你会得到接近的结果(其余的差异来自不同的内部实现、参数选择等),
from torch import nn
import torch
x = torch.ones((1, 2, 2, 2))
a = nn.Sequential(
# nn.Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), padding=0, bias=True),
nn.BatchNorm2d(2)
)
a.eval()
b = a(x)
print(b)
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers import *
x = tf.ones((1, 2, 2, 2))
a = keras.models.Sequential([
# Conv2D(128, (1, 1), (1, 1), padding='same', use_bias=True),
BatchNormalization()
])
b = a(x)
print(b)
out:
tensor([[[[1.0000, 1.0000],
[1.0000, 1.0000]],
[[1.0000, 1.0000],
[1.0000, 1.0000]]]], grad_fn=<NativeBatchNormBackward>)
tf.Tensor(
[[[[0.9995004 0.9995004]
[0.9995004 0.9995004]]
[[0.9995004 0.9995004]
[0.9995004 0.9995004]]]], shape=(1, 2, 2, 2), dtype=float32)
本文收集自互联网,转载请注明来源。
如有侵权,请联系 [email protected] 删除。
我来说两句