# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# pylint: disable=global-statement

"""Model and data parallel groups."""

from typing import Optional

import torch

from modellink.error_utils import (
    ensure_valid,
    ensure_var_is_none,
    ensure_var_is_not_none,
    check_equal
)
from .utils import GlobalMemoryBuffer
from ..global_vars import get_args

# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP_GLOO = None
# FP8 amax reduction group.
_AMAX_REDUCTION_GROUP = None

_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
_PIPELINE_PREV_GROUP = None
_PIPELINE_NEXT_GROUP = None
# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None

# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None

# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None

# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None

# A list of global ranks for each data parallel group to ease calculation of the source
# rank when broadcasting weights from src to all other data parallel ranks.
_DATA_PARALLEL_GLOBAL_RANKS = None

# Memory buffers to avoid dynamic memory allocation.
_GLOBAL_MEMORY_BUFFER = None


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    pipeline_model_parallel_split_rank: Optional[int] = None,
    use_fp8: bool = False,
) -> None:
    """
    Initialize model data parallel groups.

    Arguments:
        tensor_model_parallel_size (int, default = 1):
            The number of GPUs to split individual tensors across.

        pipeline_model_parallel_size (int, default = 1):
            The number of tensor parallel GPU groups to split the
            Transformer layers across. For example, if
            tensor_model_parallel_size is 4 and
            pipeline_model_parallel_size is 2, the model will be split
            into 2 groups of 4 GPUs.

        virtual_pipeline_model_parallel_size (int, optional):
            The number of stages that each pipeline group will have,
            interleaving as necessary. If None, no interleaving is
            performed. For example, if tensor_model_parallel_size is 1,
            pipeline_model_parallel_size is 4,
            virtual_pipeline_model_parallel_size is 2, and there are
            16 transformer layers in the model, the model will be
            split into 8 stages with two layers each and each GPU
            would get 2 stages as such (layer number starting with 1):

                GPU 0: [1, 2] [9, 10]
                GPU 1: [3, 4] [11, 12]
                GPU 2: [5, 6] [13, 14]
                GPU 3: [7, 8] [15, 16]

        pipeline_model_parallel_split_rank (int, optional):
            For models with both an encoder and decoder, the rank in
            pipeline to switch between encoder and decoder (i.e. the
            first rank of the decoder). This allows the user to set
            the pipeline parallel size of the encoder and decoder
            independently. For example, if
            pipeline_model_parallel_size is 8 and
            pipeline_model_parallel_split_rank is 3, then ranks 0-2
            will be the encoder and ranks 3-7 will be the decoder.

        use_fp8 (bool, default = False):
            Construct GPU groups needed for FP8 training, namely for
            amax reduction across the product of the data-parallel and
            tensor-parallel groups.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will create 8 tensor
    model-parallel groups, 4 pipeline model-parallel groups and 8
    data-parallel groups as:
        8 data-parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        8 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        4 pipeline model-parallel groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, ranks 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    ensure_valid(not use_fp8, error_message="FP8 not supported by AscendSpeed")
    # Get world size and rank. Ensure some consistencies. The initialization
    # check must come before any torch.distributed query.
    ensure_valid(torch.distributed.is_initialized())
    if torch.distributed.get_rank() == 0:
        print('> initializing tensor model parallel with size {}'.format(
            tensor_model_parallel_size))
        print('> initializing pipeline model parallel with size {}'.format(
            pipeline_model_parallel_size))
    world_size: int = torch.distributed.get_world_size()

    if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
        raise RuntimeError(
            f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
            f"({tensor_model_parallel_size}) x pipeline_model_parallel_size "
            f"({pipeline_model_parallel_size})"
        )
    data_parallel_size: int = world_size // (
        tensor_model_parallel_size * pipeline_model_parallel_size
    )

    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size

    if virtual_pipeline_model_parallel_size is not None:
        if pipeline_model_parallel_size <= 2:
            raise RuntimeError(
                "pipeline-model-parallel size should be greater than 2 with interleaved schedule"
            )
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size

    if pipeline_model_parallel_split_rank is not None:
        global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank

    rank = torch.distributed.get_rank()

    # Build the data-parallel groups.
    global _DATA_PARALLEL_GROUP
    global _DATA_PARALLEL_GROUP_GLOO
    global _DATA_PARALLEL_GLOBAL_RANKS
    ensure_var_is_none(_DATA_PARALLEL_GROUP, error_message='data parallel group is already initialized')
    all_data_parallel_group_ranks = []
    args = get_args()
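    # A data-parallel group collects the ranks that hold the same model shard:
    # ranks within one pipeline-stage block that share the same tensor-parallel
    # index. With the 16-GPU example from the docstring (TP=2, PP=4), iteration
    # i=0, j=0 yields range(0, 4, 2) == [g0, g2].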
    for i in range(pipeline_model_parallel_size):
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
            all_data_parallel_group_ranks.append(list(ranks))
            group = torch.distributed.new_group(ranks)
            if args.use_distributed_optimizer:
                group_gloo = torch.distributed.new_group(ranks, backend="gloo")
            if rank in ranks:
                _DATA_PARALLEL_GROUP = group
                _DATA_PARALLEL_GLOBAL_RANKS = ranks
                if args.use_distributed_optimizer:
                    _DATA_PARALLEL_GROUP_GLOO = group_gloo

    # Build the model-parallel groups.
    global _MODEL_PARALLEL_GROUP
    ensure_var_is_none(_MODEL_PARALLEL_GROUP, error_message='model parallel group is already initialized')
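    # A model-parallel group takes the i-th rank from every data-parallel
    # group, i.e. all tensor- and pipeline-parallel ranks that jointly hold
    # one model replica. In the 16-GPU docstring example, replica 0 is
    # [g0, g1, g4, g5, g8, g9, g12, g13].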
    for i in range(data_parallel_size):
        ranks = [
            data_parallel_group_ranks[i]
            for data_parallel_group_ranks in all_data_parallel_group_ranks
        ]
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _MODEL_PARALLEL_GROUP = group

    # Build the tensor model-parallel groups.
    global _TENSOR_MODEL_PARALLEL_GROUP
    ensure_var_is_none(_TENSOR_MODEL_PARALLEL_GROUP, error_message='tensor model parallel'
                       ' group is already initialized')
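    # Tensor model-parallel groups are contiguous blocks of
    # tensor_model_parallel_size ranks, e.g. [g0, g1], [g2, g3], ... for TP=2
    # in the docstring example above.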
    for i in range(num_tensor_model_parallel_groups):
        ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group

    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    global _PIPELINE_PREV_GROUP
    global _PIPELINE_NEXT_GROUP
    ensure_var_is_none(_PIPELINE_MODEL_PARALLEL_GROUP, error_message='pipeline model parallel'
                       ' group is already initialized')
    global _EMBEDDING_GROUP
    global _EMBEDDING_GLOBAL_RANKS
    ensure_var_is_none(_EMBEDDING_GROUP, error_message='embedding group is already initialized')
    global _POSITION_EMBEDDING_GROUP
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    ensure_var_is_none(_POSITION_EMBEDDING_GROUP, error_message='position embedding'
                       ' group is already initialized')
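    # Pipeline groups stride through the world with step
    # num_pipeline_model_parallel_groups, e.g. [g0, g4, g8, g12] for PP=4 in
    # the docstring example above.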
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks
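        # Build two-rank groups linking each stage to its successor, so
        # point-to-point activation/gradient exchange has a dedicated group;
        # the wrap-around pair also links the last stage back to the first.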
        for j in range(len(ranks)):
            ranks_ = [ranks[j], ranks[(j + 1) % len(ranks)]] if world_size != 1 else [ranks[j]]
            group = torch.distributed.new_group(ranks_)
            if rank == ranks[j]:
                _PIPELINE_NEXT_GROUP = group
            if rank == ranks[(j + 1) % len(ranks)]:
                _PIPELINE_PREV_GROUP = group
        # Setup embedding group (to exchange gradients between
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
            position_embedding_ranks = [ranks[0]]
            if pipeline_model_parallel_split_rank is not None:
                if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
                    embedding_ranks = [
                        ranks[0],
                        ranks[pipeline_model_parallel_split_rank],
                        ranks[-1],
                    ]
                if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
                    position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]
        else:
            embedding_ranks = ranks
            position_embedding_ranks = ranks

        group = torch.distributed.new_group(embedding_ranks)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
        if rank in ranks:
            _EMBEDDING_GLOBAL_RANKS = embedding_ranks

        group = torch.distributed.new_group(position_embedding_ranks)
        if rank in position_embedding_ranks:
            _POSITION_EMBEDDING_GROUP = group
        if rank in ranks:
            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

    # Initialize global memory buffer.
    # This isn't really "parallel state" but there isn't another good place to
    # put this. If we end up with a more generic initialization of megatron-core
    # we could stick it there.
    _set_global_memory_buffer()


def is_unitialized():
    """Useful for code segments that may be accessed with or without mpu initialization."""
    return _DATA_PARALLEL_GROUP is None


def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized."""
    if (
        _TENSOR_MODEL_PARALLEL_GROUP is None
        or _PIPELINE_MODEL_PARALLEL_GROUP is None
        or _DATA_PARALLEL_GROUP is None
    ):
        return False
    return True


def get_model_parallel_group():
    """Get the model parallel group the caller rank belongs to."""
    ensure_var_is_not_none(_MODEL_PARALLEL_GROUP, error_message='model parallel group is not initialized')
    return _MODEL_PARALLEL_GROUP


def get_tensor_model_parallel_group():
    """Get the tensor model parallel group the caller rank belongs to."""
    ensure_var_is_not_none(_TENSOR_MODEL_PARALLEL_GROUP, error_message='intra_layer_model'
                           ' parallel group is not initialized')
    return _TENSOR_MODEL_PARALLEL_GROUP


def get_pipeline_model_parallel_group():
    """Get the pipeline model parallel group the caller rank belongs to."""
    ensure_var_is_not_none(_PIPELINE_MODEL_PARALLEL_GROUP, error_message='pipeline_model'
                           ' parallel group is not initialized')
    return _PIPELINE_MODEL_PARALLEL_GROUP


def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    ensure_var_is_not_none(_DATA_PARALLEL_GROUP, error_message='data parallel group is not initialized')
    return _DATA_PARALLEL_GROUP


def get_data_parallel_group_gloo():
    """Get the data parallel group-gloo the caller rank belongs to."""
    ensure_var_is_not_none(_DATA_PARALLEL_GROUP_GLOO, error_message='data parallel'
                           ' group-gloo is not initialized')
    return _DATA_PARALLEL_GROUP_GLOO


def get_embedding_group():
    """Get the embedding group the caller rank belongs to."""
    ensure_var_is_not_none(_EMBEDDING_GROUP, error_message='embedding group is not initialized')
    return _EMBEDDING_GROUP


def get_position_embedding_group():
    """Get the position embedding group the caller rank belongs to."""
    ensure_var_is_not_none(_POSITION_EMBEDDING_GROUP, error_message='position embedding'
                           ' group is not initialized')
    return _POSITION_EMBEDDING_GROUP


def set_tensor_model_parallel_world_size(world_size):
    """Set the tensor model parallel size."""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size


def set_pipeline_model_parallel_world_size(world_size):
    """Set the pipeline model parallel size."""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size


def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())


def get_model_parallel_world_size():
    """Return world size for the model parallel group (legacy; requires PP == 1)."""
    check_equal(get_pipeline_model_parallel_world_size(), 1, error_info="legacy get_model_parallel_world_size"
                " is only supported if PP is disabled")
    return get_tensor_model_parallel_world_size()


def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())


def set_tensor_model_parallel_rank(rank):
    """Set tensor model parallel rank."""
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank


def set_pipeline_model_parallel_rank(rank):
    """Set pipeline model parallel rank."""
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank


def set_pipeline_model_parallel_split_rank(rank):
    """Set pipeline model parallel split rank."""
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank


def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group."""
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_RANK
    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())


def get_model_parallel_rank():
    """Return my rank for the model parallel group (legacy; requires PP == 1)."""
    check_equal(get_pipeline_model_parallel_world_size(), 1, error_info="legacy get_model_parallel_rank"
                " is only supported if PP is disabled")
    return get_tensor_model_parallel_rank()


def get_pipeline_model_parallel_rank():
    """Return my rank for the pipeline model parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_RANK
    return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())


def get_pipeline_model_parallel_split_rank():
    """Return pipeline model parallel split rank."""
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK


def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        if (
            get_virtual_pipeline_model_parallel_world_size() is not None
            and get_virtual_pipeline_model_parallel_rank() != 0
        ):
            return False
    return get_pipeline_model_parallel_rank() == 0


def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        virtual_pipeline_model_parallel_world_size = (
            get_virtual_pipeline_model_parallel_world_size()
        )
        if virtual_pipeline_model_parallel_world_size is not None \
                and get_virtual_pipeline_model_parallel_rank() != (
                    virtual_pipeline_model_parallel_world_size - 1):
            return False
    return get_pipeline_model_parallel_rank() == (
        get_pipeline_model_parallel_world_size() - 1)


def is_rank_in_embedding_group(ignore_virtual=False):
    """Return True if current rank is in embedding group, False otherwise."""
    rank = torch.distributed.get_rank()
    global _EMBEDDING_GLOBAL_RANKS
    if ignore_virtual:
        return rank in _EMBEDDING_GLOBAL_RANKS
    if rank in _EMBEDDING_GLOBAL_RANKS:
        if rank == _EMBEDDING_GLOBAL_RANKS[0]:
            return is_pipeline_first_stage(ignore_virtual=False)
        elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
            return is_pipeline_last_stage(ignore_virtual=False)
        else:
            return True
    return False


def is_rank_in_position_embedding_group():
    """Return True if current rank is in position embedding group, False otherwise."""
    rank = torch.distributed.get_rank()
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    return rank in _POSITION_EMBEDDING_GLOBAL_RANKS


def is_pipeline_stage_before_split(rank=None):
    """
    Return True if pipeline stage executes encoder block for a model
    with both encoder and decoder.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
        return True
    if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
        return True
    return False


def is_pipeline_stage_after_split(rank=None):
    """
    Return True if pipeline stage executes decoder block for a model
    with both encoder and decoder.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
        return True
    if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
        return True
    return False


def is_pipeline_stage_at_split():
    """
    Return True if pipeline stage executes encoder block and the next
    stage executes decoder block for a model with both encoder and
    decoder.
    """
    rank = get_pipeline_model_parallel_rank()
    return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1)


def get_virtual_pipeline_model_parallel_rank():
    """Return the virtual pipeline-parallel rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK


def set_virtual_pipeline_model_parallel_rank(rank):
    """Set the virtual pipeline-parallel rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank


def get_virtual_pipeline_model_parallel_world_size():
    """Return the virtual pipeline-parallel world size."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE


def set_virtual_pipeline_model_parallel_world_size(world_size):
    """Set the virtual pipeline-parallel world size."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size


def get_tensor_model_parallel_src_rank():
    """
    Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group.
    """
    global_rank = torch.distributed.get_rank()
    local_world_size = get_tensor_model_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size


def get_data_parallel_src_rank():
    """
    Calculate the global rank corresponding to the first local rank
    in the data parallel group.
    """
    ensure_var_is_not_none(_DATA_PARALLEL_GLOBAL_RANKS, error_message="Data parallel group is not initialized")
    return _DATA_PARALLEL_GLOBAL_RANKS[0]


def get_pipeline_model_parallel_first_rank():
    """
    Return the global rank of the first process in the pipeline for the
    current tensor parallel group.
    """
    ensure_var_is_not_none(_PIPELINE_GLOBAL_RANKS, error_message="Pipeline parallel group is not initialized")
    return _PIPELINE_GLOBAL_RANKS[0]


def get_pipeline_model_parallel_last_rank():
    """
    Return the global rank of the last process in the pipeline for the
    current tensor parallel group.
    """
    ensure_var_is_not_none(_PIPELINE_GLOBAL_RANKS, error_message="Pipeline parallel group is not initialized")
    last_rank_local = get_pipeline_model_parallel_world_size() - 1
    return _PIPELINE_GLOBAL_RANKS[last_rank_local]


def get_pipeline_model_parallel_next_rank():
    """Return the global rank that follows the caller in the pipeline."""
    ensure_var_is_not_none(_PIPELINE_GLOBAL_RANKS, error_message="Pipeline parallel group is not initialized")
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]


def get_pipeline_model_parallel_prev_rank():
    """Return the global rank that precedes the caller in the pipeline."""
    ensure_var_is_not_none(_PIPELINE_GLOBAL_RANKS, error_message="Pipeline parallel group is not initialized")
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]


def get_pipeline_model_parallel_prev_rank_group():
    """Return the two-rank group linking the caller to the previous pipeline stage."""
    ensure_var_is_not_none(_PIPELINE_PREV_GROUP)
    return _PIPELINE_PREV_GROUP


def get_pipeline_model_parallel_next_rank_group():
    """Return the two-rank group linking the caller to the next pipeline stage."""
    ensure_var_is_not_none(_PIPELINE_NEXT_GROUP)
    return _PIPELINE_NEXT_GROUP


def get_data_parallel_world_size():
    """Return world size for the data parallel group."""
    return torch.distributed.get_world_size(group=get_data_parallel_group())


def get_data_parallel_rank():
    """Return my rank for the data parallel group."""
    return torch.distributed.get_rank(group=get_data_parallel_group())


def _set_global_memory_buffer():
    """Initialize the global memory buffer."""
    global _GLOBAL_MEMORY_BUFFER
    ensure_var_is_none(_GLOBAL_MEMORY_BUFFER, error_message='global memory buffer is already initialized')
    _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()


def get_global_memory_buffer():
    """Return the global GlobalMemoryBuffer object."""
    ensure_var_is_not_none(_GLOBAL_MEMORY_BUFFER, error_message='global memory buffer is not initialized')
    return _GLOBAL_MEMORY_BUFFER


def destroy_global_memory_buffer():
    """Set the global memory buffer to None."""
    global _GLOBAL_MEMORY_BUFFER
    _GLOBAL_MEMORY_BUFFER = None


def destroy_model_parallel():
    """Set the groups to None."""
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None
    global _TENSOR_MODEL_PARALLEL_GROUP
    _TENSOR_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_MODEL_PARALLEL_GROUP
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None
    global _PIPELINE_NEXT_GROUP
    _PIPELINE_NEXT_GROUP = None
    global _PIPELINE_PREV_GROUP
    _PIPELINE_PREV_GROUP = None
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None
    global _POSITION_EMBEDDING_GROUP
    _POSITION_EMBEDDING_GROUP = None
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = None
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
    global _GLOBAL_MEMORY_BUFFER
    _GLOBAL_MEMORY_BUFFER = None
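
# Example usage (a minimal sketch; assumes a process group is launched with
# 16 ranks and an NCCL-style backend). This reproduces the group layout from
# initialize_model_parallel's docstring:
#
#   torch.distributed.init_process_group(backend="nccl")
#   initialize_model_parallel(tensor_model_parallel_size=2,
#                             pipeline_model_parallel_size=4)
#   assert get_data_parallel_world_size() == 2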