《Transformer分析之模型训练内存计算的一个简单公式》介绍了Transformer算法内存占用的一个公式,但不是特别严谨,遗漏了很多细节,同时也没有与实际训练进行比较。
本文以 GPT-2 训练作为例子得出内存占用计算公式,GPT-2 和 Transformer 论文描述的实现是有一些不同的。
这个公式有很多假设,比如没有自动混合精度计算,也没有使用 checkpoints 这样的优化技术,因为它们对公式的形成有很大不同,这些务必要牢记。
本文不光有公式,据说最终的数值也和实际的训练占用进行过比较,但原文没有提供代码,所以我简单写了一个,代码部分在后面描述,本文主要聊聊公式以及我的理解。
首先看 GPT-2 的一些超参数:
• L = 12 # blocks数量• N = 12 # 注意力头个数• E = 768 # embedding 维度• B = 8 # 批次大小• T = 1024 # 序列长度• TOKS = 50257 # 词表大小• param_bytes = 4 # float32 全精度训练• bytes_to_gigs = 1_000_000_000 # 单位转换整体的公式:
model_params = (TOKS*E)+ L*( 4*E**2 + 2*E*4*E + 4*E)act_params = B*T*(2*TOKS+L*(14*E + N*T ))backprop_model_params = 3*model_paramsbackprop_act_params = act_params可以看出训练过程内存占用包含几部分:
• 模型参数本身内存占用• 激活值占用• 优化器adam和梯度,前者是模型参数大小的二倍,后者一倍• backprop_act_params,我认为计算以及理解是有问题的原始公式:
total_params = model_params+act_params+backprop_model_params+backprop_act_params=4*model_params+2*act_paramsgigabytes_used = total_params*param_bytes/bytes_to_gigsbackprop_act_params从名称上看好像是反向传播的激活占用,且内存消耗等同于前向传播激活值,一方面激活值在前向传播过程中已经存储了,所以不应该需要,我理解这个值是反向过程中的一个临时存储以及遗漏部分激活计算,但整个值预估的偏大一点!所以我乘了一个系数。
公式变更为:
total_params = model_params+act_params+backprop_model_params+backprop_act_params=4*model_params+1.3*act_paramsgigabytes_used = total_params*param_bytes/bytes_to_gigs1:模型参数
(TOKS*E) [embedding layer ]+ L [number of blocks]*( 4*E**2 [Q,K,V matrices and the linear projection after Attention] + 2*E*4*E [the MLP layer that projects up to 4*E hidden neurons and then back down again] + 4*E [Two layer norms and their scale and bias terms])稍微补充说明:
• embedding layer 和 head_params(最后softmax计算logits)参数共享,前者(vocab_size, d_model),后者(d_model, vocab_size)• 忽略了位置嵌入层• layer norms 的bits 占用还是很大的,所以 MLP 和注意力机制中需要 4*E整个公式还是很清晰的,此处就不多解释了,以前文章聊了很多了。
2:激活值
B[batch]*T[seq. length]*(2*TOKS [one hot vectors at input and output]+L[number of blocks]*(3E [K,Q,V projections] + N*T [Attention Heads softmax weightings] + E [value vector] + E [linear projection] + E [residual connection] + E [LayerNorm] +4E [MLP activation]+E [MLP projection down]+E[residual]+E[LayerNorm] ))它说 2*TOKS 是输入输出的独热向量,应该是有问题的,输入实际上是 (B,T,E),输出是 (B,T,TOKS),前者是每个token有一个 E 大小的隐藏维度,后者是每个 token 的 logits,对应的是 TOKS 大小的维度,不过对于理解问题不大!
关于 Attention Heads softmax weightings 计算权重的时候,每个头都需要计算,所以最终是 BTNT,没有问题。
而 value vector 是权重乘以 V 得出最终的权重值,也没有问题,但最后的残差是否要再额外内存,我觉得不需要,直接和权重值相加即可。
同时 MLP 中激活函数的内存占用也是不小的,它遗漏了,只是描述了两个线性层的内存计算。
整体看来,这个公式的可信度较高,描述了模型的大部分细节,但激活部分实际上还是缺少了很多内容,也没有考虑自动混合计算等,所以该公式只能作为计算内存消耗的一个参考。
而实际训练过程中内存占用多少呢?下文聊!
系列文章:
• Transformer分析之算力分析• Transformer分析之参数量分析• Transformer分析之从吞吐量的角度理解模型训练时长• Transformer分析之从算力的角度理解模型训练时长• Transformer分析之模型训练需要占用多少内存(上)• Transformer分析之模型训练需要占用多少内存(下)• Transformer分析之模型训练内存计算的一个简单公式