Cadam Optimizer

l00619700 2023-11-13 17:24:27 +08:00
parent 711911f9fc
commit 0211eea70d
5 changed files with 142 additions and 4 deletions

View File

@@ -55,7 +55,7 @@ LLaMA2-13B/34B, Aquila-7B, Baichuan2-7B/13B and other models coming soon...
* [Sequence Parallelism](#jump5)
* [ZeRO-1/2](#jump6)
* [Inverted Triangle Acceleration](#jump7)
* [Fused Optimizer](#jump8)
* [High-Performance Optimizer](#jump8)
* [Feed-Forward Network Fusion](#jump9)
* [Gradient Accumulation](#jump9)
@@ -874,9 +874,88 @@ SP splits the allreduce operation in TP into reduce-scatter and allgather operations
In AscendSpeed, inverted triangle acceleration can be enabled with the `--triangle-attn` flag.
### <span id="jump8"> Fused Optimizer </span>
### <span id="jump8"> High-Performance Optimizer </span>
The fused optimizer shortens computation time by reducing operator-launch overhead. In AscendSpeed, different optimizers can be selected via `--optimizer`: specifically, `--optimizer adam` saves more memory, while `--optimizer fused_adam` runs faster.
Beyond these, we also provide several in-house low-memory optimizers, for example:
Cadam builds on Google's [Lion optimizer](https://arxiv.org/abs/2302.06675) and adopts the same minimalist form and beta parameters, so it drops the second-order momentum while still delivering Adam-like parameter updates. The first-order momentum is then quantized and compressed row by row and computed in Ascend-friendly FP16, with scale and clip operations to prevent overflow.
<div align=center>
<img src="sources/images/cadam.png" height="300px" width="600px"></div>
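The update below is a condensed, illustrative sketch of this scheme, distilled from the `cadam.py` file added later in this commit; the standalone function name `cadam_like_step` is ours, not part of the repository.

```python
import torch

def cadam_like_step(p, grad, exp_avg_fp16, row_scale, lr, beta, weight_decay, eps=1e-8):
    """Lion-form sign update with the first-order momentum kept as an FP16
    payload plus a per-row FP32 scale, roughly as described above."""
    # Decoupled weight decay, as in AdamW.
    p.mul_(1 - lr * weight_decay)

    # Dequantize the stored momentum: FP16 payload times per-row scale.
    exp_avg = exp_avg_fp16.float() * row_scale

    # Single-beta momentum update; no second moment is kept.
    exp_avg.mul_(beta).add_(grad, alpha=1 - beta)

    # Re-quantize row by row: choose a scale so the largest entry stays within
    # FP16 range (the 65503 constant mirrors the implementation), then clamp
    # to guard against overflow.
    row_scale.copy_(exp_avg.abs().max(dim=-1, keepdim=True)[0] / 65503.0 + eps)
    exp_avg_fp16.copy_((exp_avg / row_scale).clamp_(-65503.0, 65503.0).half())

    # The parameter step uses only the sign of the momentum, as in Lion.
    p.add_(torch.sign(exp_avg), alpha=-lr)
```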
To use the Cadam optimizer, specify `--optimizer cadam` in the script, shrink the learning rate `lr` and minimum learning rate `min-lr` that train the model stably with Adam by a factor of 3-10, enlarge `weight_decay` by the same factor of 3-10, and set `--adam-beta1 0.965`.
Test results for several large models are shown in the table below:
<table>
<thead>
<tr>
<th>Model</th>
<th>Optimizer</th>
<th>Performance</th>
<th>Average HBM reduction</th>
<th>BoolQ</th>
<th>PIQA</th>
<th>HellaSwag</th>
</tr>
</thead>
<tbody>
<tr>
<td rowspan="2"> LLama-7B </td>
<td> Adam </td>
<td> 5.65s/iteration </td>
<td> -- </td>
<td> 38.78% </td>
<td> 52.88% </td>
<td> 26.69% </td>
</tr>
<tr>
<td> Cadam </td>
<td> 5.68s/iteration </td>
<td> ↓ 21.8% </td>
<td> 41.50% </td>
<td> 55.11% </td>
<td> 26.59% </td>
</tr>
<tr>
<td rowspan="2"> LLama2-7B </td>
<td> Adam </td>
<td> 7.09s/iteration </td>
<td> -- </td>
<td> 40.36% </td>
<td> 51.20% </td>
<td> 25.13% </td>
</tr>
<tr>
<td> Cadam </td>
<td> 7.12s/iteration </td>
<td> ↓ 17.4% </td>
<td> 44.16% </td>
<td> 51.20% </td>
<td> 26.29% </td>
</tr>
<tr>
<td rowspan="2"> Bloom-7B </td>
<td> Adam </td>
<td> 4.10s/iteration </td>
<td> -- </td>
<td> 37.83% </td>
<td> 49.24% </td>
<td> 25.31% </td>
</tr>
<tr>
<td> Cadam </td>
<td> 4.25s/iteration </td>
<td> ↓ 10.1% </td>
<td> 37.83% </td>
<td> 51.16% </td>
<td> 25.25% </td>
</tr>
</tbody>
</table>
### <span id="jump9"> 前向网络合并 </span>
对于llama以及一些在FFN中没有bias的模型FFN中的线性层可以合并计算以减少TP中的通信量通过设置 `--mlp-layer-fusion` 标志,可以使用该特性。
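As a rough single-device illustration of the idea (toy sizes and weight names of our own choosing, not the AscendSpeed implementation): in a LLaMA-style gated FFN without biases, the gate and up projections can be concatenated into one weight so the block issues a single matmul over the shared input.

```python
import torch
import torch.nn.functional as F

# Toy sizes; real models use much larger hidden/ffn dimensions.
hidden, ffn = 128, 344
x = torch.randn(4, hidden)

w_gate = torch.randn(ffn, hidden)   # gate projection (no bias)
w_up = torch.randn(ffn, hidden)     # up projection (no bias)
w_down = torch.randn(hidden, ffn)   # down projection

# Unfused reference: two separate projections of the same input.
out_ref = (F.silu(x @ w_gate.t()) * (x @ w_up.t())) @ w_down.t()

# Fused: concatenate the weights, issue one projection, then split the result.
w_fused = torch.cat([w_gate, w_up], dim=0)        # [2 * ffn, hidden]
gate, up = (x @ w_fused.t()).chunk(2, dim=-1)
out_fused = (F.silu(gate) * up) @ w_down.t()

assert torch.allclose(out_ref, out_fused, atol=1e-4)
```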

View File

@@ -634,7 +634,7 @@ def _add_training_args(parser):
                       help='Create separate groups for MoE params.'
                       'This is necessary for techniques like ZeRO.')
    group.add_argument('--optimizer', type=str, default='adam',
                       choices=['adam', 'sgd', 'fused_adam'],
                       choices=['adam', 'sgd', 'fused_adam', 'cadam'],
                       help='Optimizer function')
    group.add_argument('--dataloader-type', type=str, default=None,
                       choices=['single', 'cyclic'],
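A quick standalone check of the extended argument (illustrative only, not part of the commit):

```python
import argparse

# With 'cadam' added to choices, the training flag parses cleanly.
parser = argparse.ArgumentParser()
parser.add_argument('--optimizer', type=str, default='adam',
                    choices=['adam', 'sgd', 'fused_adam', 'cadam'],
                    help='Optimizer function')

args = parser.parse_args(['--optimizer', 'cadam'])
assert args.optimizer == 'cadam'
```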

View File

@@ -21,6 +21,7 @@ from ascendspeed.model import LayerNorm
from ascendspeed.model.fused_layer_norm import MixedFusedLayerNorm
from ascendspeed.model.lora_utils import is_enable_lora
from ascendspeed.optimizer.adam import AdamW as Adam
from ascendspeed.optimizer.cadam import CAdamW as CAdam
from ascendspeed.error_utils import check_equal
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
@@ -141,7 +142,11 @@ def get_megatron_optimizer(model):
                              weight_decay=args.weight_decay,
                              betas=(args.adam_beta1, args.adam_beta2),
                              eps=args.adam_eps)
    elif args.optimizer == 'cadam':
        optimizer = CAdam(param_groups,
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          beta=args.adam_beta1)
    else:
        raise Exception('{} optimizer is not supported.'.format(
            args.optimizer))

View File

@@ -0,0 +1,54 @@
import torch
from torch.optim.optimizer import Optimizer
from ascendspeed.error_utils import check_divisible_by_zero


class CAdamW(Optimizer):
    """Lion-style optimizer that keeps only a first-order momentum, stored as an
    FP16 payload with a per-row FP32 scale to cut optimizer-state memory."""

    def __init__(self, params, lr=5e-6, beta=0.965, weight_decay=3e-2, eps=1e-8):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= beta < 1.0:
            raise ValueError("Invalid beta parameter: {}".format(beta))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, eps=eps)
        super(CAdamW, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta = group['beta']
            # Megatron-style param groups carry a 'name' key ('weight_decay_params'
            # or 'no_weight_decay_params'); groups without one get no special casing.
            group_name = group.get('name', '')
            lr, weight_decay, eps = group['lr'], group['weight_decay'], group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError('CAdamW does not support sparse gradients')

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    device = p.device
                    # Quantized momentum (FP16 payload) and its per-row FP32 scale.
                    state["exp_avg"] = torch.zeros(list(p.shape), dtype=torch.float16).to(device)
                    state["qs_val_1"] = torch.ones(list(p.shape)[:-1] + [1], dtype=torch.float32).to(device)
                state['step'] += 1
                grad = p.grad

                # Parameters exempt from weight decay (LayerNorm weights, biases, ...)
                # take half the step size and skip the decay term.
                if group_name == 'no_weight_decay_params':
                    step_lr, step_wd = 0.5 * lr, 0.0
                else:
                    step_lr, step_wd = lr, weight_decay

                # Decoupled weight decay, as in AdamW.
                p.mul_(1 - step_lr * step_wd)

                exp_avg_q = state['exp_avg']
                qs_val_1 = state["qs_val_1"]

                # Dequantize and update the single momentum (no second moment is kept).
                exp_avg = exp_avg_q.float() * qs_val_1
                exp_avg.mul_(beta).add_(grad, alpha=1 - beta)

                # Re-quantize row by row: the scale keeps the largest entry within
                # FP16 range, and the clamp guards against overflow.
                qs_val_1.copy_((torch.max(torch.abs(exp_avg), dim=-1, keepdim=True)[0]) / 65503.0 + eps)
                exp_avg_q.copy_((exp_avg * qs_val_1.reciprocal()).clamp_(-65503.0, 65503.0).half())

                # Lion-style update: step in the direction of sign(momentum).
                p.add_(torch.sign(exp_avg), alpha=-step_lr)
        return loss
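For reference, a minimal standalone way to exercise the new class outside the training loop, assuming the module above is importable as `ascendspeed.optimizer.cadam` (as in the import added earlier in this commit); the toy model and hyperparameters are placeholders, not recommendations. Inside AscendSpeed the optimizer is normally constructed by `get_megatron_optimizer` when `--optimizer cadam` is set.

```python
import torch
from ascendspeed.optimizer.cadam import CAdamW

# Toy model and data, just to drive one optimizer step.
model = torch.nn.Linear(16, 4)
opt = CAdamW(model.parameters(), lr=1e-5, beta=0.965, weight_decay=3e-2)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()       # quantized-momentum sign update
opt.zero_grad()
```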

BIN
sources/images/cadam.png Normal file

Binary file not shown.
