# coding=utf-8
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import wraps


def extra_args_provider_decorator(extra_args_provider):
    """Wrap an optional extra-args provider so process_args() is always applied to the parser."""
    @wraps(extra_args_provider)
    def wrapper(parser):
        if extra_args_provider is not None:
            parser = extra_args_provider(parser)
        parser = process_args(parser)
        return parser

    return wrapper


def parse_args_decorator(parse_args):
    """Decorate an upstream parse_args so the extra argument groups defined below are registered."""
    @wraps(parse_args)
    def wrapper(extra_args_provider=None, ignore_unknown_args=False):
        decorated_provider = extra_args_provider_decorator(extra_args_provider)
        return parse_args(decorated_provider, ignore_unknown_args)

    return wrapper
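

# Illustrative usage (an assumption added for clarity, not part of the original
# module): these decorators are intended to wrap an upstream
# ``parse_args(extra_args_provider, ignore_unknown_args)`` entry point so the
# argument groups below are always registered. The ``megatron`` module path in
# this sketch is hypothetical.
#
#     from megatron.training import arguments
#     arguments.parse_args = parse_args_decorator(arguments.parse_args)
#     args = arguments.parse_args(extra_args_provider=None, ignore_unknown_args=False)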


def process_args(parser):
    """Register all additional argument groups on the given parser."""
    # Let re-defined arguments override earlier (upstream) definitions.
    parser.conflict_handler = 'resolve'
    parser = _add_network_size_args(parser)
    parser = _add_lora_args(parser)
    parser = _add_data_args(parser)
    parser = _add_moe_args(parser)
    parser = _add_num_layer_allocation(parser)
    parser = _add_dataset_args(parser)
    return parser


def _add_lora_args(parser):
    group = parser.add_argument_group(title='lora')

    group.add_argument('--lora-target-modules', nargs='+', type=str, default=[],
                       help='LoRA target modules.')
    group.add_argument('--lora-load', type=str, default=None,
                       help='Directory containing a LoRA model checkpoint.')
    group.add_argument('--lora-r', type=int, default=16,
                       help='LoRA rank r.')
    group.add_argument('--lora-alpha', type=int, default=32,
                       help='LoRA alpha (scaling hyperparameter).')
    group.add_argument('--lora-modules-to-save', nargs='+', type=str, default=None,
                       help='Modules to save alongside the LoRA weights.')
    group.add_argument('--lora-register-forward-hook', nargs='+', type=str,
                       default=['word_embeddings', 'input_layernorm'],
                       help='Modules on which to register the LoRA forward hook.')

    return parser


def _add_moe_args(parser):
    group = parser.add_argument_group(title='moe')
    group.add_argument('--moe-router-topk', type=int, default=2,
                       help='Number of experts to route to for each token. The default is 2.')
    group.add_argument('--moe-router-load-balancing-type', type=str,
                       choices=['aux_loss', ],
                       default='aux_loss',
                       help='Determines the load balancing strategy for the router. "aux_loss" corresponds '
                            'to the load balancing loss used in GShard and SwitchTransformer; it is '
                            'currently the only supported choice and the default.')
    group.add_argument('--expert-interval', type=int, default=1,
                       help='Use experts in every "expert-interval" layers.')
    group.add_argument('--moe-aux-loss-coeff', type=float, default=0.0,
                       help='Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended.')
    group.add_argument('--moe-z-loss-coeff', type=float, default=0.0,
                       help='Scaling coefficient for the z-loss: a starting value of 1e-3 is recommended.')
    group.add_argument('--moe-train-capacity-factor', type=float, default=1.0,
                       help='Capacity factor of each MoE expert at training time.')
    group.add_argument('--noisy_gate_policy', type=str, default=None,
                       help="Noisy gate policy; valid options are 'Jitter', 'RSample' or 'None'.")
    return parser


def _add_data_args(parser):
    group = parser.add_argument_group(title='data')
    group.add_argument('--is-instruction-dataset', action='store_true',
                       help='Whether to use an instruction dataset.')
    group.add_argument('--variable-seq-lengths', action='store_true',
                       help='Whether to use variable sequence lengths.')
    group.add_argument('--tokenizer-kwargs', type=str, nargs='+', default=None,
                       help='Kwargs of the Hugging Face tokenizer.')
    group.add_argument('--tokenizer-padding-side', type=str, default='right',
                       help='Tokenizer padding side.')

    return parser


def _add_num_layer_allocation(parser):
    group = parser.add_argument_group(title='num_layer_allocation')
    group.add_argument('--num-layer-list',
                       type=str,
                       help='A comma-separated list of layer counts, e.g. 4,4,4,4.')
    return parser


def _add_network_size_args(parser):
    group = parser.add_argument_group(title='network_size_args')
    group.add_argument('--padded-vocab-size',
                       type=int,
                       default=None,
                       help='Set the padded vocab size.')
    group.add_argument('--embed-layernorm',
                       action='store_true',
                       default=False,
                       help='Apply a layernorm to the word embeddings.')
    return parser


def _add_dataset_args(parser):
    group = parser.add_argument_group(title='dataset_args')
    group.add_argument('--no-shared-storage',
                       action='store_true',
                       default=False,
                       help='Set this flag when no shared storage is available.')
    return parser
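

# Minimal standalone sketch (an illustrative addition, not part of the original
# module): the helpers above only need an argparse.ArgumentParser, so they can
# be exercised directly for a quick sanity check.
if __name__ == '__main__':
    import argparse

    _parser = argparse.ArgumentParser(conflict_handler='resolve')
    _parser = process_args(_parser)
    # Example flags; the values here are arbitrary.
    _args = _parser.parse_args(['--lora-r', '8', '--moe-router-topk', '4'])
    print(_args.lora_r, _args.moe_router_topk)  # -> 8 4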