# coding=utf-8
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

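# This module collects the wrapped and patched Megatron-LM components defined in
# this package; exe_adaptor() applies them to the upstream `megatron` modules at
# runtime (monkey patching), presumably to adapt stock Megatron-LM for Ascend
# hardware.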
import megatron
from .model import (
    GPTModel, parallel_transformer_init, seq_length_wrapper,
    norm_wrapper, SwitchMLP, state_dict_for_save_checkpoint_wrapper,
    core_attention_wrapper, core_attention_forward, FlashSelfAttention,
    ParallelAttention_wrapper
)
from .core import (vocab_embedding_wrapper, initialize_model_parallel_decorator,
                   destroy_model_parallel_decorator, get_expert_parallel_group,
                   get_expert_parallel_rank, get_expert_model_parallel_rank,
                   get_expert_parallel_world_size, get_expert_model_parallel_world_size,
                   set_expert_model_parallel_rank, set_expert_model_parallel_world_size,
                   _build_generic_dataset, _build_document_sample_shuffle_indices)
from .data import build_pretraining_data_loader
from .tokenizer import build_tokenizer
from .arguments import parse_args_decorator
from .training import get_model_wrapper
from .utils import ALL_MODULE_WRAPPER_CLASSNAMES
from .checkpointing import _load_base_checkpoint_wrapper, load_checkpoint_wrapper
from .initialize import _compile_dependencies


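# Patch the imported `megatron` package in place. This should be called before
# model, tokenizer, data loader, or checkpoint code is exercised, so that later
# attribute lookups resolve to the adapted implementations assigned below.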
def exe_adaptor():
    import megatron
    megatron.utils.ALL_MODULE_WRAPPER_CLASSNAMES = ALL_MODULE_WRAPPER_CLASSNAMES
    megatron.initialize.parse_args = parse_args_decorator(megatron.initialize.parse_args)
    megatron.initialize._compile_dependencies = _compile_dependencies
    megatron.arguments.parse_args = parse_args_decorator(megatron.arguments.parse_args)
    megatron.global_vars.build_tokenizer = build_tokenizer

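    # Wrap model construction and replace the pretraining data loader builder.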
    import megatron.training
    megatron.training.get_model = get_model_wrapper(megatron.training.get_model)
    megatron.training.build_pretraining_data_loader = build_pretraining_data_loader

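    # Swap in the adapted GPT model and MoE layer, and hook transformer
    # initialization, checkpoint state collection, and the language model forward.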
    megatron.model.GPTModel = GPTModel
    megatron.model.transformer.SwitchMLP = SwitchMLP
    megatron.model.transformer.ParallelTransformer.__init__ = parallel_transformer_init
    megatron.model.transformer.ParallelTransformer.state_dict_for_save_checkpoint \
        = state_dict_for_save_checkpoint_wrapper(
            megatron.model.transformer.ParallelTransformer.state_dict_for_save_checkpoint)
    megatron.model.language_model.TransformerLanguageModel.forward = (seq_length_wrapper(
        megatron.model.language_model.TransformerLanguageModel.forward))

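    # Replace the attention paths: ParallelAttention/CoreAttention init,
    # CoreAttention.forward, and the FlashSelfAttention implementation.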
    megatron.model.transformer.ParallelAttention.__init__ = ParallelAttention_wrapper(
        megatron.model.transformer.ParallelAttention.__init__)
    megatron.model.transformer.CoreAttention.__init__ = core_attention_wrapper(
        megatron.model.transformer.CoreAttention.__init__)
    megatron.model.transformer.CoreAttention.forward = core_attention_forward
    megatron.model.transformer.FlashSelfAttention = FlashSelfAttention

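    # Wrap the tensor-parallel vocabulary embedding's forward and __init__.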
    megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward = vocab_embedding_wrapper(
        megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward)
    megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__ = norm_wrapper(
        megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__)

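    # Register the expert-parallel helpers, decorate model-parallel setup and
    # teardown, and alias megatron.core.parallel_state as megatron.core.mpu.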
    set_moe_attr()
    megatron.core.parallel_state.initialize_model_parallel = initialize_model_parallel_decorator(
        megatron.core.parallel_state.initialize_model_parallel)
    megatron.core.parallel_state.destroy_model_parallel = destroy_model_parallel_decorator(
        megatron.core.parallel_state.destroy_model_parallel)
    megatron.core.mpu = megatron.core.parallel_state

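    # Wrap checkpoint loading, both the low-level loader and the entry point
    # used by training.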
    megatron.checkpointing._load_base_checkpoint = _load_base_checkpoint_wrapper(
        megatron.checkpointing._load_base_checkpoint)
    megatron.training.load_checkpoint = load_checkpoint_wrapper(
        megatron.checkpointing.load_checkpoint)

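    # Override GPT dataset index construction and the generic dataset builder.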
    from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
    from megatron.core.datasets.gpt_dataset import GPTDataset
    GPTDataset._build_document_sample_shuffle_indices = _build_document_sample_shuffle_indices
    BlendedMegatronDatasetBuilder._build_generic_dataset = _build_generic_dataset


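# Expose the expert (MoE) parallel group/rank/world-size helpers on
# megatron.core.parallel_state so patched code can look them up there.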
def set_moe_attr():
    setattr(megatron.core.parallel_state,
            "get_expert_parallel_group", get_expert_parallel_group)
    setattr(megatron.core.parallel_state,
            "get_expert_parallel_rank", get_expert_parallel_rank)
    setattr(megatron.core.parallel_state,
            "get_expert_model_parallel_rank", get_expert_model_parallel_rank)
    setattr(megatron.core.parallel_state,
            "get_expert_parallel_world_size", get_expert_parallel_world_size)
    setattr(megatron.core.parallel_state,
            "get_expert_model_parallel_world_size", get_expert_model_parallel_world_size)
    setattr(megatron.core.parallel_state,
            "set_expert_model_parallel_rank", set_expert_model_parallel_rank)
    setattr(megatron.core.parallel_state,
            "set_expert_model_parallel_world_size", set_expert_model_parallel_world_size)


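# Minimal usage sketch (an assumption, not part of this module): a launcher
# script would call the adaptor before running any Megatron-LM entry point, e.g.
#
#     from this_package import exe_adaptor  # substitute the real package name
#     exe_adaptor()                          # patch Megatron-LM in place
#     # ... then invoke Megatron-LM training/inference as usual.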