其中 μ 为 x 的均值,σ 为 x 的方差,γ 和 β 是可训练的模型参数,γ 是缩放参数,新分布的方差 γ2 ; β 是平移系数,新分布的均值为 β 。 ε 为一个小数,添加到方差上,避免分母为0。
def layerNorm(feature):
size = feature.shape
alpha = torch.nn.Parameter(torch.ones(size[-1]))
beta = torch.nn.Parameter(torch.ones(size[-1]))
input_dtype = feature.dtype
feature = torch.nn.Parameter(feature.to(torch.float32))
mean = feature.mean(-1, keepdim=True)
std = feature.std(-1, keepdim=True)
feature = alpha * (feature - mean)
return (feature / (std + 1e-6) + beta).to(input_dtype)
对于layerNorm和RMSNorm,layerNorm包含缩放和平移两部分,RMSNorm去除了平移部分,只保留了缩放部分。
def RMSNorm(feature):
size = feature.shape
weight = torch.nn.Parameter(torch.ones(size[-1]))
input_dtype = feature.dtype
feature = torch.nn.Parameter(feature.to(torch.float32))
variance = feature.pow(2).mean(-1, keepdim=True)
feature = feature * torch.rsqrt(variance + 1e-6)
return weight * feature.to(input_dtype)
RMSNorm 相比一般的 layerNorm,减少了计算均值和平移系数的部分,训练速度更快,效果基本相当,甚至有所提升。
DeepNorm 是由微软提出的一种 Normalization 方法。主要对 Transformer 结构中的残差链接做修正。
DeepNorm 可以缓解模型参数爆炸式更新的问题,把模型参数更新限制在一个常数域范围内,使得模型训练过程可以更稳定。模型规模可以达到 1000 层。
DeepNorm 兼具 PreLN 的训练稳定和 PostLN 的效果性能。
在 transformer 的原始结构中,采用了 PostLN 结构,即在残差连接之后 layerNorm,如上图(a)所示。在 LLM 训练过程中发现,PostLN 的输出层附近的梯度过大会造成训练的不稳定性。在 LLM 很少单独使用 PostLN,如在 GLM-130B 中采用 PostLN 与 PreLN 结合的方式。
PreLN 将 layerNorm 放置在残差连接的过程中,如上图(a)所示。PreLN 在每层的梯度范数近似相等,有利于提升训练稳定性。相比 PostLN,使用 PreLN 的深层 transforme 的训练更稳定,但是性能有一定损害。为了提升训练稳定性,很多大模型都采用了 PreLN。
阅读量:1029
点赞量:0
收藏量:0