Cadam Optimizer
parent 711911f9fc
commit 0211eea70d

README.md | 83

@@ -55,7 +55,7 @@ LLaMA2-13B/34B, Aquila-7B, Baichuan2-7B/13B and other models coming soon...
   * [Sequence Parallelism](#jump5)
   * [ZeRO-1/2](#jump6)
   * [Inverted Triangle Acceleration](#jump7)
-  * [Fused Optimizer](#jump8)
+  * [High-performance Optimizer](#jump8)
   * [Forward Network Merging](#jump9)
   * [Gradient Accumulation](#jump9)

@@ -874,9 +874,88 @@ SP splits the allreduce operation in TP into reduce-scatter and allgather operations, which

In AscendSpeed, inverted triangle acceleration can be enabled with the `--triangle-attn` flag.

-### <span id="jump8"> Fused Optimizer </span>
+### <span id="jump8"> High-performance Optimizer </span>

The fused optimizer reduces computation time by cutting down on operator dispatch. In AscendSpeed, the optimizer is selected with `--optimizer`: specifically, `--optimizer adam` saves more memory, while `--optimizer fused_adam` runs faster.

In addition, we also provide some self-developed low-memory optimizers, for example:

Cadam is based on the [Lion optimizer](https://arxiv.org/abs/2302.06675) proposed by Google and adopts its minimalist form with the two beta parameters set equal. This removes the second-order momentum while still yielding parameter updates similar to Adam's. The first-order momentum is then quantized and compressed row by row, the computation uses Ascend-friendly FP16, and scale and clip operations are combined to prevent overflow.
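
For reference, a simplified sketch of the resulting update rule as implemented by the `CAdamW` class added in this commit (the FP16 row-wise quantization of the momentum and the special handling of no-weight-decay groups are omitted); here $g_t$ is the gradient, $\beta$ the single momentum coefficient, $\eta$ the learning rate, and $\lambda$ the weight decay:

$$
m_t = \beta\, m_{t-1} + (1-\beta)\, g_t, \qquad
\theta_t = (1 - \eta\lambda)\,\theta_{t-1} - \eta\,\operatorname{sign}(m_t)
$$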

<div align=center>
<img src="sources/images/cadam.png" height="300px" width="600px"></div>

To use the Cadam optimizer, specify `--optimizer cadam` in the training script, shrink the learning rate `lr` and minimum learning rate `min-lr` that train the model stably with Adam by a factor of 3-10, enlarge `weight_decay` by the same factor of 3-10, and set `--adam-beta1 0.965`.
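
As a rough illustration of this scaling, here is a minimal standalone sketch using the `CAdamW` class added by this commit (it assumes AscendSpeed with this change is installed); the model and the baseline Adam hyperparameters are hypothetical and only serve to show the factor-of-5 adjustment:

```python
import torch
from ascendspeed.optimizer.cadam import CAdamW

# Hypothetical baseline values that would train stably with Adam.
adam_lr, adam_weight_decay = 2e-5, 1e-2

model = torch.nn.Linear(1024, 1024)
optimizer = CAdamW(
    model.parameters(),
    lr=adam_lr / 5,                      # shrink lr by 3-10x
    weight_decay=adam_weight_decay * 5,  # enlarge weight decay by 3-10x
    beta=0.965,                          # corresponds to --adam-beta1 0.965
)

loss = model(torch.randn(8, 1024)).pow(2).mean()
loss.backward()
optimizer.step()       # sign-based update with FP16-quantized momentum
optimizer.zero_grad()
```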

Test results for some large models are shown in the table below:

<table>
  <thead>
    <tr>
      <th>Model</th>
      <th>Optimizer</th>
      <th>Performance</th>
      <th>Average HBM reduction</th>
      <th>BoolQ</th>
      <th>PIQA</th>
      <th>HellaSwag</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td rowspan="2"> LLama-7B </td>
      <td> Adam </td>
      <td> 5.65s/iteration </td>
      <td> -- </td>
      <td> 38.78% </td>
      <td> 52.88% </td>
      <td> 26.69% </td>
    </tr>
    <tr>
      <td> Cadam </td>
      <td> 5.68s/iteration </td>
      <td> ↓ 21.8% </td>
      <td> 41.50% </td>
      <td> 55.11% </td>
      <td> 26.59% </td>
    </tr>
    <tr>
      <td rowspan="2"> LLama2-7B </td>
      <td> Adam </td>
      <td> 7.09s/iteration </td>
      <td> -- </td>
      <td> 40.36% </td>
      <td> 51.20% </td>
      <td> 25.13% </td>
    </tr>
    <tr>
      <td> Cadam </td>
      <td> 7.12s/iteration </td>
      <td> ↓ 17.4% </td>
      <td> 44.16% </td>
      <td> 51.20% </td>
      <td> 26.29% </td>
    </tr>
    <tr>
      <td rowspan="2"> Bloom-7B </td>
      <td> Adam </td>
      <td> 4.10s/iteration </td>
      <td> -- </td>
      <td> 37.83% </td>
      <td> 49.24% </td>
      <td> 25.31% </td>
    </tr>
    <tr>
      <td> Cadam </td>
      <td> 4.25s/iteration </td>
      <td> ↓ 10.1% </td>
      <td> 37.83% </td>
      <td> 51.16% </td>
      <td> 25.25% </td>
    </tr>
  </tbody>
</table>

### <span id="jump9"> Forward Network Merging </span>

For llama and other models whose FFN has no bias, the linear layers in the FFN can be merged to reduce the amount of communication in TP, as illustrated below. This feature is enabled by setting the `--mlp-layer-fusion` flag.
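
A toy sketch of the idea (not the AscendSpeed implementation): because the projections carry no bias, their weights can be concatenated and applied in a single matmul, so a tensor-parallel rank launches one GEMM and one communication for the pair instead of two. The shapes and names below are illustrative only:

```python
import torch

hidden, ffn = 1024, 4096
x = torch.randn(2, hidden)

# Two bias-free FFN projections (e.g. the gate and up projections in llama).
w_gate = torch.randn(ffn, hidden)
w_up = torch.randn(ffn, hidden)

# Unfused: two separate matmuls.
gate, up = x @ w_gate.t(), x @ w_up.t()

# Fused: one matmul over the concatenated weight, then split the result.
w_fused = torch.cat([w_gate, w_up], dim=0)
gate_f, up_f = (x @ w_fused.t()).split(ffn, dim=-1)

assert torch.allclose(gate, gate_f, atol=1e-4)
assert torch.allclose(up, up_f, atol=1e-4)
```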

@@ -634,7 +634,7 @@ def _add_training_args(parser):
                        help='Create separate groups for MoE params.'
                             'This is necessary for techniques like ZeRO.')
     group.add_argument('--optimizer', type=str, default='adam',
-                       choices=['adam', 'sgd', 'fused_adam'],
+                       choices=['adam', 'sgd', 'fused_adam', 'cadam'],
                        help='Optimizer function')
     group.add_argument('--dataloader-type', type=str, default=None,
                        choices=['single', 'cyclic'],

@@ -21,6 +21,7 @@ from ascendspeed.model import LayerNorm
 from ascendspeed.model.fused_layer_norm import MixedFusedLayerNorm
 from ascendspeed.model.lora_utils import is_enable_lora
 from ascendspeed.optimizer.adam import AdamW as Adam
+from ascendspeed.optimizer.cadam import CAdamW as CAdam
 from ascendspeed.error_utils import check_equal

 from .grad_scaler import ConstantGradScaler, DynamicGradScaler

@@ -141,7 +142,11 @@ def get_megatron_optimizer(model):
                           weight_decay=args.weight_decay,
                           betas=(args.adam_beta1, args.adam_beta2),
                           eps=args.adam_eps)
+    elif args.optimizer == 'cadam':
+        optimizer = CAdam(param_groups,
+                          lr=args.lr,
+                          weight_decay=args.weight_decay,
+                          beta=args.adam_beta1)
     else:
         raise Exception('{} optimizer is not supported.'.format(
             args.optimizer))

@@ -0,0 +1,54 @@ ascendspeed/optimizer/cadam.py (new file)

import torch
from torch.optim.optimizer import Optimizer
from ascendspeed.error_utils import check_divisible_by_zero


class CAdamW(Optimizer):
    def __init__(self, params, lr=5e-6, beta=0.965, weight_decay=3e-2, eps=1e-8):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= beta < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(beta))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, eps=eps)
        super(CAdamW, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            params_with_grad = []
            beta = group['beta']
            group_name = group['params']
            lr, weight_decay, eps = group['lr'], group['weight_decay'], group['eps']
            for _, p in enumerate(group['params']):
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    device = p.device
                    # Momentum is stored in FP16, with a per-row FP32 scale for dequantization.
                    state["exp_avg"] = torch.zeros(list(p.shape), dtype=torch.float16).to(device)
                    state["qs_val_1"] = torch.ones(list(p.shape)[:-1] + [1], dtype=torch.float32).to(device)
                state['step'] += 1
                grad = p.grad
                if group_name == 'no_weight_decay_params':
                    lr = 0.5 * lr
                    weight_decay = 0
                # Decoupled weight decay.
                p.mul_(1 - lr * weight_decay)
                exp_avg_q = state['exp_avg']
                qs_val_1 = state["qs_val_1"]
                # Dequantize the momentum, update it, then re-quantize it row by row,
                # clamping to the FP16 range (±65503) to avoid overflow.
                exp_avg = exp_avg_q.float() * qs_val_1
                exp_avg.mul_(beta).add_(grad, alpha=1 - beta)
                qs_val_1.copy_((torch.max(torch.abs(exp_avg), dim=-1, keepdim=True)[0]) / 65503.0 + eps)
                exp_avg_q.copy_((exp_avg * qs_val_1.reciprocal()).clamp_(-65503.0, 65503.0).half())
                # Sign-based update, as in Lion.
                p.add_(torch.sign(exp_avg), alpha=-lr)
        return loss

sources/images/cadam.png: new binary file (106 KiB), not shown.