ModelLink/modellink/arguments.py

# coding=utf-8
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import wraps


def extra_args_provider_decorator(extra_args_provider):
    @wraps(extra_args_provider)
    def wrapper(parser):
        if extra_args_provider is not None:
            parser = extra_args_provider(parser)
        parser = process_args(parser)
        return parser

    return wrapper
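
# Example (illustrative; `my_provider` is hypothetical): Megatron-style entry
# points accept an `extra_args_provider(parser) -> parser` callable. The
# decorator above chains process_args after the user-supplied provider:
#
#     def my_provider(parser):
#         parser.add_argument('--my-flag', action='store_true')
#         return parser
#
#     provider = extra_args_provider_decorator(my_provider)
#     parser = provider(parser)  # adds --my-flag, then all ModelLink args
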
def parse_args_decorator(parse_args):
    @wraps(parse_args)
    def wrapper(extra_args_provider=None, ignore_unknown_args=False):
        decorated_provider = extra_args_provider_decorator(extra_args_provider)
        return parse_args(decorated_provider, ignore_unknown_args)

    return wrapper
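
# Intended use (a sketch; the exact module path depends on the Megatron
# version being patched): wrap Megatron's parse_args so that every call
# routes its extra_args_provider through process_args first, e.g.:
#
#     import megatron.arguments
#     megatron.arguments.parse_args = parse_args_decorator(
#         megatron.arguments.parse_args)
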
def process_args(parser):
    parser.conflict_handler = 'resolve'
    parser = _add_network_size_args(parser)
    parser = _add_lora_args(parser)
    parser = _add_data_args(parser)
    parser = _add_moe_args(parser)
    parser = _add_num_layer_allocation(parser)
    parser = _add_dataset_args(parser)
    return parser
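
# Note: conflict_handler='resolve' lets the helpers below re-register an
# option string Megatron may already define; the later definition silently
# replaces the earlier one instead of raising argparse.ArgumentError:
#
#     parser.add_argument('--foo', type=int, default=1)
#     parser.add_argument('--foo', type=int, default=2)  # ok; this one wins
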
def _add_lora_args(parser):
    group = parser.add_argument_group(title='lora')
    group.add_argument('--lora-target-modules', nargs='+', type=str, default=[],
                       help='Module names to apply LoRA to.')
    group.add_argument('--lora-load', type=str, default=None,
                       help='Directory containing a LoRA model checkpoint.')
    group.add_argument('--lora-r', type=int, default=16,
                       help='LoRA rank.')
    group.add_argument('--lora-alpha', type=int, default=32,
                       help='LoRA scaling alpha.')
    group.add_argument('--lora-modules-to-save', nargs='+', type=str, default=None,
                       help='Extra modules to train and save alongside the LoRA weights.')
    group.add_argument('--lora-register-forward-hook', nargs='+', type=str,
                       default=['word_embeddings', 'input_layernorm'],
                       help='Module names to register a forward hook on.')
    return parser
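
# Example command-line fragment (values and module names are illustrative):
#
#     --lora-r 8 --lora-alpha 16 \
#     --lora-target-modules query_key_value dense \
#     --lora-load /path/to/lora/checkpoint
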
def _add_moe_args(parser):
    group = parser.add_argument_group(title='moe')
    group.add_argument('--moe-router-topk', type=int, default=2,
                       help='Number of experts to route each token to. The default is 2.')
    group.add_argument('--moe-router-load-balancing-type', type=str,
                       choices=['aux_loss', ],
                       default='aux_loss',
                       help='Load balancing strategy for the router. Only "aux_loss" is '
                            'supported here: the load balancing loss used in GShard and '
                            'SwitchTransformer. The default is "aux_loss".')
    group.add_argument('--expert-interval', type=int, default=1,
                       help='Place experts every "expert-interval" layers.')
    group.add_argument('--moe-aux-loss-coeff', type=float, default=0.0,
                       help='Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended.')
    group.add_argument('--moe-z-loss-coeff', type=float, default=0.0,
                       help='Scaling coefficient for the z-loss: a starting value of 1e-3 is recommended.')
    group.add_argument('--moe-train-capacity-factor', type=float, default=1.0,
                       help='Capacity factor of each MoE expert at training time.')
    group.add_argument('--noisy_gate_policy', type=str, default=None,
                       help="Noisy gate policy; valid options are 'Jitter', 'RSample' or 'None'.")
    return parser
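
# Example MoE configuration (values are illustrative, following the
# recommendations in the help strings above):
#
#     --moe-router-topk 2 --moe-aux-loss-coeff 1e-2 \
#     --moe-train-capacity-factor 1.25
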
def _add_data_args(parser):
    group = parser.add_argument_group(title='data')
    group.add_argument('--is-instruction-dataset', action='store_true',
                       help='Whether the dataset is an instruction dataset.')
    group.add_argument('--variable-seq-lengths', action='store_true',
                       help='Whether to use variable sequence lengths.')
    group.add_argument("--tokenizer-kwargs", type=str, nargs='+', default=None,
                       help="Kwargs of the huggingface tokenizer.")
    group.add_argument('--tokenizer-padding-side', type=str, default='right',
                       help="Tokenizer padding side.")
    return parser
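
# Note on --tokenizer-kwargs: nargs='+' collects a flat list of strings; the
# pairing into keyword arguments happens downstream. A plausible call shape
# (key/value names are illustrative):
#
#     --tokenizer-kwargs padding_side left use_fast False
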
def _add_num_layer_allocation(parser):
    group = parser.add_argument_group(title='num_layer_allocation')
    group.add_argument('--num-layer-list',
                       type=str, help='A list of layer counts, '
                       'separated by commas; e.g., 4,4,4,4.')
    return parser
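
# The comma-separated string is split downstream, roughly:
#
#     num_layer_list = [int(x) for x in args.num_layer_list.split(',')]
#     # '4,4,4,4' -> [4, 4, 4, 4], one entry per pipeline stage
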
def _add_network_size_args(parser):
    group = parser.add_argument_group(title='network_size_args')
    group.add_argument('--padded-vocab-size',
                       type=int,
                       default=None,
                       help='Set the padded vocabulary size.')
    group.add_argument('--embed-layernorm',
                       action='store_true',
                       default=False,
                       help='Apply layer normalization to the word embeddings.')
    return parser
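
# Example (values are model-specific and illustrative): pad the vocabulary
# so it divides evenly across tensor-parallel ranks:
#
#     --padded-vocab-size 32000 --embed-layernorm
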
def _add_dataset_args(parser):
    group = parser.add_argument_group(title='dataset_args')
    group.add_argument('--no-shared-storage',
                       action='store_true',
                       default=False,
                       help='Set this flag if there is no shared storage across nodes.')
    return parser
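

# Minimal smoke test (a sketch, not part of the original module): build a
# bare parser, run it through process_args, and parse a sample command line.
# Every argument registered above has a default, so a sparse argv works.
if __name__ == "__main__":
    import argparse

    demo_parser = argparse.ArgumentParser(description='ModelLink extra args demo')
    demo_parser = process_args(demo_parser)
    demo_args = demo_parser.parse_args(['--lora-r', '8', '--moe-router-topk', '4'])
    print(demo_args.lora_r)           # 8
    print(demo_args.moe_router_topk)  # 4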