ModelLink/modellink/core/parallel_state.py

665 lines
26 KiB
Python

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# pylint: disable=global-statement
"""Model and data parallel groups."""
from typing import Optional
import torch
from modellink.error_utils import (
ensure_valid,
ensure_var_is_none,
ensure_var_is_not_none,
check_equal
)
from .utils import GlobalMemoryBuffer
from ..global_vars import get_args
# Module-level parallel state. Everything starts as None; the values are
# filled in by initialize_model_parallel() or by the explicit set_* helpers.

# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
# Data parallel group created with backend="gloo"; initialize_model_parallel()
# only creates it when args.use_distributed_optimizer is set.
_DATA_PARALLEL_GROUP_GLOO = None
# FP8 amax reduction group.
# NOTE(review): nothing in this file assigns it; initialize_model_parallel()
# rejects use_fp8=True.
_AMAX_REDUCTION_GROUP = None
# Virtual (interleaved) pipeline rank and world size. Both stay None unless a
# virtual_pipeline_model_parallel_size is passed to initialize_model_parallel()
# or the set_virtual_* helpers are called.
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
# Pipeline rank at which the decoder starts for encoder/decoder models
# (None when no split was requested).
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
# Groups pairing this rank with its previous / next neighbour in its pipeline.
_PIPELINE_PREV_GROUP = None
_PIPELINE_NEXT_GROUP = None
# These values enable us to change the mpu sizes on the fly.
# When one of them is not None, the matching get_* accessor returns it instead
# of querying torch.distributed.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None
# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None
# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None
# A list of global ranks for each data parallel group to ease calculation of the source
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None
# Memory buffers to avoid dynamic memory allocation
_GLOBAL_MEMORY_BUFFER = None
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    pipeline_model_parallel_split_rank: Optional[int] = None,
    use_fp8: bool = False,
) -> None:
    """
    Initialize model data parallel groups.

    Creates the data-parallel, model-parallel, tensor-model-parallel,
    pipeline-model-parallel, pipeline prev/next, embedding and position
    embedding process groups, stores the ones the calling rank belongs to in
    this module's globals, and finally creates the global memory buffer.

    Arguments:
        tensor_model_parallel_size (int, default = 1):
            The number of GPUs to split individual tensors across.

        pipeline_model_parallel_size (int, default = 1):
            The number of tensor parallel GPU groups to split the
            Transformer layers across. For example, if
            tensor_model_parallel_size is 4 and
            pipeline_model_parallel_size is 2, the model will be split
            into 2 groups of 4 GPUs.

        virtual_pipeline_model_parallel_size (int, optional):
            The number of stages that each pipeline group will have,
            interleaving as necessary. If None, no interleaving is
            performed. For example, if tensor_model_parallel_size is 1,
            pipeline_model_parallel_size is 4,
            virtual_pipeline_model_parallel_size is 2, and there are
            16 transformer layers in the model, the model will be
            split into 8 stages with two layers each and each GPU
            would get 2 stages as such (layer number starting with 1):

            GPU 0: [1, 2] [9, 10]
            GPU 1: [3, 4] [11, 12]
            GPU 2: [5, 6] [13, 14]
            GPU 3: [7, 8] [15, 16]

        pipeline_model_parallel_split_rank (int, optional):
            For models with both an encoder and decoder, the rank in
            pipeline to switch between encoder and decoder (i.e. the
            first rank of the decoder). This allows the user to set
            the pipeline parallel size of the encoder and decoder
            independently. For example, if
            pipeline_model_parallel_size is 8 and
            pipeline_model_parallel_split_rank is 3, then ranks 0-2
            will be the encoder and ranks 3-7 will be the decoder.

        use_fp8 (bool, default = False):
            Construct GPU groups needed for FP8 training, namely for
            amax reduction across the product of the data-parallel and
            tensor-parallel groups. Passing True is rejected by the
            ensure_valid check at the top of this function.

    Raises:
        RuntimeError: if the world size is not divisible by
            tensor_model_parallel_size * pipeline_model_parallel_size, or if
            a virtual pipeline size is given while
            pipeline_model_parallel_size is not greater than 2.
        The ensure_valid / ensure_var_is_none helpers are also used for the
        FP8, "torch.distributed is initialized" and "already initialized"
        checks; their exception type is defined in modellink.error_utils.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
    and 8 data-parallel groups as:
        8 data_parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        8 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        4 pipeline model-parallel groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    ensure_valid(not use_fp8, error_message="FP8 not supported by AscendSpeed")
    # NOTE(review): get_rank() is called here before the is_initialized()
    # check a few lines below, so that check cannot be the first thing to fail.
    if torch.distributed.get_rank() == 0:
        print('> initializing tensor model parallel with size {}'.format(
            tensor_model_parallel_size))
        print('> initializing pipeline model parallel with size {}'.format(
            pipeline_model_parallel_size))
    # Get world size and rank. Ensure some consistencies.
    ensure_valid(torch.distributed.is_initialized())
    world_size: int = torch.distributed.get_world_size()
    if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
        raise RuntimeError(
            f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
            f"({tensor_model_parallel_size}) x pipeline_model_parallel_size "
            f"({pipeline_model_parallel_size})"
        )
    # Whatever is left after tensor and pipeline parallelism is data parallel.
    data_parallel_size: int = world_size // (
        tensor_model_parallel_size * pipeline_model_parallel_size
    )
    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
    if virtual_pipeline_model_parallel_size is not None:
        if not pipeline_model_parallel_size > 2:
            raise RuntimeError(
                "pipeline-model-parallel size should be greater than 2 with interleaved schedule"
            )
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
        # The virtual rank always starts at 0.
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size
    if pipeline_model_parallel_split_rank is not None:
        global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
    rank = torch.distributed.get_rank()
    # Build the data-parallel groups.
    # Every rank executes every new_group() call below, including those for
    # groups it is not a member of; only the matching group is stored.
    global _DATA_PARALLEL_GROUP
    global _DATA_PARALLEL_GROUP_GLOO
    global _DATA_PARALLEL_GLOBAL_RANKS
    ensure_var_is_none(_DATA_PARALLEL_GROUP, error_message='data parallel group is already initialized')
    all_data_parallel_group_ranks = []
    args = get_args()
    for i in range(pipeline_model_parallel_size):
        # Global ranks [start_rank, end_rank) form pipeline stage i.
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            # Ranks of this stage that share tensor-parallel position j.
            ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
            all_data_parallel_group_ranks.append(list(ranks))
            group = torch.distributed.new_group(ranks)
            if args.use_distributed_optimizer:
                # Second group over the same ranks using the gloo backend.
                group_gloo = torch.distributed.new_group(ranks, backend="gloo")
            if rank in ranks:
                _DATA_PARALLEL_GROUP = group
                _DATA_PARALLEL_GLOBAL_RANKS = ranks
                if args.use_distributed_optimizer:
                    _DATA_PARALLEL_GROUP_GLOO = group_gloo
    # Build the model-parallel groups.
    global _MODEL_PARALLEL_GROUP
    ensure_var_is_none(_MODEL_PARALLEL_GROUP, error_message='model parallel group is already initialized')
    for i in range(data_parallel_size):
        # Take the i-th member of every data-parallel group.
        ranks = [
            data_parallel_group_ranks[i]
            for data_parallel_group_ranks in all_data_parallel_group_ranks
        ]
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _MODEL_PARALLEL_GROUP = group
    # Build the tensor model-parallel groups.
    global _TENSOR_MODEL_PARALLEL_GROUP
    ensure_var_is_none(_TENSOR_MODEL_PARALLEL_GROUP, error_message='tensor model parallel' \
                       ' group is already initialized')
    for i in range(num_tensor_model_parallel_groups):
        # Tensor-parallel groups are contiguous blocks of global ranks.
        ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group
    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    global _PIPELINE_PREV_GROUP
    global _PIPELINE_NEXT_GROUP
    ensure_var_is_none(_PIPELINE_MODEL_PARALLEL_GROUP, error_message='pipeline model parallel' \
                       ' group is already initialized')
    global _EMBEDDING_GROUP
    global _EMBEDDING_GLOBAL_RANKS
    ensure_var_is_none(_EMBEDDING_GROUP, error_message='embedding group is already initialized')
    global _POSITION_EMBEDDING_GROUP
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    ensure_var_is_none(_POSITION_EMBEDDING_GROUP, error_message='position embedding' \
                       ' group is already initialized')
    for i in range(num_pipeline_model_parallel_groups):
        # One rank from every pipeline stage, strided across the world.
        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks
        # Pairwise groups between each stage and its successor (wrapping
        # around): the first member records it as its "next" group, the second
        # member as its "prev" group.
        # NOTE(review): iter() around range() is redundant.
        # NOTE(review): the guard tests world_size, not len(ranks); with
        # pipeline size 1 and world_size > 1 this passes [r, r] to
        # new_group(). Confirm whether len(ranks) != 1 was intended.
        for j in iter(range(len(ranks))):
            ranks_ = [ranks[j], ranks[(j + 1) % len(ranks)]] if world_size != 1 else [ranks[j]]
            group = torch.distributed.new_group(ranks_)
            if rank == ranks[j]:
                _PIPELINE_NEXT_GROUP = group
            if rank == ranks[(j + 1) % len(ranks)]:
                _PIPELINE_PREV_GROUP = group
        # Setup embedding group (to exchange gradients between
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
            position_embedding_ranks = [ranks[0]]
            if pipeline_model_parallel_split_rank is not None:
                # The split stage also joins the groups unless it is already
                # one of the members.
                if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
                    embedding_ranks = [
                        ranks[0],
                        ranks[pipeline_model_parallel_split_rank],
                        ranks[-1],
                    ]
                if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
                    position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]
        else:
            # Single-stage pipeline: the only rank is both first and last.
            embedding_ranks = ranks
            position_embedding_ranks = ranks
        group = torch.distributed.new_group(embedding_ranks)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
        # The rank list is recorded for every member of the pipeline, even
        # those that are not in the embedding group itself.
        if rank in ranks:
            _EMBEDDING_GLOBAL_RANKS = embedding_ranks
        group = torch.distributed.new_group(position_embedding_ranks)
        if rank in position_embedding_ranks:
            _POSITION_EMBEDDING_GROUP = group
        if rank in ranks:
            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
    # Initialize global memory buffer
    # This isn't really "parallel state" but there isn't another good place to
    # put this. If we end up with a more generic initialization of megatron-core
    # we could stick it there
    _set_global_memory_buffer()
def is_unitialized():
    """Tell whether the data parallel group has not been created yet.

    Useful for code segments that may run with or without mpu initialization.

    Returns:
        bool: True while ``_DATA_PARALLEL_GROUP`` is still None.
    """
    data_parallel_group = _DATA_PARALLEL_GROUP
    return data_parallel_group is None
def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized.

    Returns:
        bool: True only when the tensor, pipeline and data parallel groups
        are all set (not None).
    """
    required_groups = (
        _TENSOR_MODEL_PARALLEL_GROUP,
        _PIPELINE_MODEL_PARALLEL_GROUP,
        _DATA_PARALLEL_GROUP,
    )
    return all(group is not None for group in required_groups)
def get_model_parallel_group():
    """Return the model parallel group of the calling rank.

    The group is first passed through ``ensure_var_is_not_none`` so that an
    uninitialized state is reported instead of silently returning None.
    """
    group = _MODEL_PARALLEL_GROUP
    ensure_var_is_not_none(
        group,
        error_message='model parallel group is not initialized',
    )
    return group
def get_tensor_model_parallel_group():
    """Return the tensor model parallel group of the calling rank.

    The group is first passed through ``ensure_var_is_not_none`` so that an
    uninitialized state is reported instead of silently returning None.
    """
    group = _TENSOR_MODEL_PARALLEL_GROUP
    ensure_var_is_not_none(
        group,
        error_message=(
            'intra_layer_model'
            ' parallel group is not initialized'
        ),
    )
    return group
def get_pipeline_model_parallel_group():
    """Return the pipeline model parallel group of the calling rank.

    The group is first passed through ``ensure_var_is_not_none`` so that an
    uninitialized state is reported instead of silently returning None.
    """
    group = _PIPELINE_MODEL_PARALLEL_GROUP
    ensure_var_is_not_none(
        group,
        error_message=(
            'pipeline_model'
            ' parallel group is not initialized'
        ),
    )
    return group
def get_data_parallel_group():
    """Return the data parallel group of the calling rank.

    The group is first passed through ``ensure_var_is_not_none`` so that an
    uninitialized state is reported instead of silently returning None.
    """
    group = _DATA_PARALLEL_GROUP
    ensure_var_is_not_none(
        group,
        error_message='data parallel group is not initialized',
    )
    return group
def get_data_parallel_group_gloo():
    """Return the gloo data parallel group of the calling rank.

    The group is first passed through ``ensure_var_is_not_none``; it is only
    populated when ``initialize_model_parallel`` ran with
    ``args.use_distributed_optimizer`` enabled.
    """
    group = _DATA_PARALLEL_GROUP_GLOO
    ensure_var_is_not_none(
        group,
        error_message=(
            'data parallel'
            ' group-gloo is not initialized'
        ),
    )
    return group
def get_embedding_group():
    """Return the embedding group of the calling rank.

    The group is first passed through ``ensure_var_is_not_none`` so that an
    unset group is reported instead of silently returning None.
    """
    group = _EMBEDDING_GROUP
    ensure_var_is_not_none(
        group,
        error_message='embedding group is not initialized',
    )
    return group
def get_position_embedding_group():
    """Return the position embedding group of the calling rank.

    The group is first passed through ``ensure_var_is_not_none`` so that an
    unset group is reported instead of silently returning None.
    """
    group = _POSITION_EMBEDDING_GROUP
    ensure_var_is_not_none(
        group,
        error_message=(
            'position embedding'
            ' group is not initialized'
        ),
    )
    return group
def set_tensor_model_parallel_world_size(world_size):
    """Override the tensor model parallel world size.

    Args:
        world_size: value that ``get_tensor_model_parallel_world_size`` will
            return from now on; None removes the override.
    """
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE

    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
def set_pipeline_model_parallel_world_size(world_size):
    """Override the pipeline model parallel world size.

    Args:
        world_size: value that ``get_pipeline_model_parallel_world_size``
            will return from now on; None removes the override.
    """
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE

    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group.

    An override installed through ``set_tensor_model_parallel_world_size``
    wins; otherwise the size is queried from torch.distributed.
    """
    override = _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    if override is None:
        return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
    return override
def get_model_parallel_world_size():
    """Legacy accessor: return the tensor model parallel world size.

    ``check_equal`` is used to require that the pipeline model parallel world
    size is 1 before the tensor model parallel world size is returned.
    """
    pipeline_world_size = get_pipeline_model_parallel_world_size()
    check_equal(
        pipeline_world_size,
        1,
        error_info=(
            "legacy get_model_parallel_world_size"
            " is only supported if PP is disabled"
        ),
    )
    return get_tensor_model_parallel_world_size()
def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group.

    An override installed through ``set_pipeline_model_parallel_world_size``
    wins; otherwise the size is queried from torch.distributed.
    """
    override = _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    if override is None:
        return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
    return override
def set_tensor_model_parallel_rank(rank):
    """Override the tensor model parallel rank.

    Args:
        rank: value that ``get_tensor_model_parallel_rank`` will return from
            now on; None removes the override.
    """
    global _MPU_TENSOR_MODEL_PARALLEL_RANK

    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank
def set_pipeline_model_parallel_rank(rank):
    """Override the pipeline model parallel rank.

    Args:
        rank: value that ``get_pipeline_model_parallel_rank`` will return
            from now on; None removes the override.
    """
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK

    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
def set_pipeline_model_parallel_split_rank(rank):
    """Store the pipeline model parallel split rank.

    Args:
        rank: pipeline rank to record in
            ``_PIPELINE_MODEL_PARALLEL_SPLIT_RANK``.
    """
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK

    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group.

    An override installed through ``set_tensor_model_parallel_rank`` wins;
    otherwise the rank is queried from torch.distributed.
    """
    override = _MPU_TENSOR_MODEL_PARALLEL_RANK
    if override is None:
        return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
    return override
def get_model_parallel_rank():
    """Legacy accessor: return the tensor model parallel rank.

    ``check_equal`` is used to require that the pipeline model parallel world
    size is 1 before the tensor model parallel rank is returned.
    """
    pipeline_world_size = get_pipeline_model_parallel_world_size()
    check_equal(
        pipeline_world_size,
        1,
        error_info=(
            "legacy get_model_parallel_rank"
            " is only supported if PP is disabled"
        ),
    )
    return get_tensor_model_parallel_rank()
def get_pipeline_model_parallel_rank():
    """Return my rank for the pipeline model parallel group.

    An override installed through ``set_pipeline_model_parallel_rank`` wins;
    otherwise the rank is queried from torch.distributed.
    """
    override = _MPU_PIPELINE_MODEL_PARALLEL_RANK
    if override is None:
        return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
    return override
def get_pipeline_model_parallel_split_rank():
    """Return pipeline model parallel split rank.

    Returns:
        The stored ``_PIPELINE_MODEL_PARALLEL_SPLIT_RANK`` (None when unset).
    """
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return split_rank
def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise.

    Args:
        ignore_virtual: when False and a virtual pipeline world size is
            configured, the virtual pipeline rank must also be 0.
    """
    check_virtual = not ignore_virtual
    if check_virtual and get_virtual_pipeline_model_parallel_world_size() is not None:
        if get_virtual_pipeline_model_parallel_rank() != 0:
            return False
    return get_pipeline_model_parallel_rank() == 0
def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline model-parallel stage, False otherwise.

    Args:
        ignore_virtual: when False and a virtual pipeline world size is
            configured, the virtual pipeline rank must also be the last
            virtual rank (virtual world size - 1).
    """
    if not ignore_virtual:
        virtual_size = get_virtual_pipeline_model_parallel_world_size()
        if virtual_size is not None:
            last_virtual_rank = virtual_size - 1
            if get_virtual_pipeline_model_parallel_rank() != last_virtual_rank:
                return False
    stage_rank = get_pipeline_model_parallel_rank()
    return stage_rank == (get_pipeline_model_parallel_world_size() - 1)
def is_rank_in_embedding_group(ignore_virtual=False):
    """Return true if current rank is in embedding group, False otherwise.

    Args:
        ignore_virtual: when True only membership in
            ``_EMBEDDING_GLOBAL_RANKS`` is checked. Otherwise the first and
            last listed ranks additionally have to be in the first / last
            pipeline stage (virtual stages included); any other listed rank
            always counts as a member.
    """
    current_rank = torch.distributed.get_rank()
    embedding_ranks = _EMBEDDING_GLOBAL_RANKS
    is_member = current_rank in embedding_ranks
    if ignore_virtual or not is_member:
        return is_member
    # Order matters: the first listed rank is tested before the last one.
    if current_rank == embedding_ranks[0]:
        return is_pipeline_first_stage(ignore_virtual=False)
    if current_rank == embedding_ranks[-1]:
        return is_pipeline_last_stage(ignore_virtual=False)
    return True
def is_rank_in_position_embedding_group():
    """Return true if current rank is in position embedding group, False otherwise.

    Membership is decided by looking up the global rank in
    ``_POSITION_EMBEDDING_GLOBAL_RANKS``.
    """
    current_rank = torch.distributed.get_rank()
    position_embedding_ranks = _POSITION_EMBEDDING_GLOBAL_RANKS
    return current_rank in position_embedding_ranks
def is_pipeline_stage_before_split(rank=None):
    """
    Return True if pipeline stage executes encoder block for a model
    with both encoder and decoder.

    Args:
        rank: pipeline rank to test; defaults to the caller's pipeline rank.

    Returns:
        bool: True when the pipeline world size is 1, when no split rank is
        set, or when ``rank`` is smaller than the split rank.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    stage_rank = get_pipeline_model_parallel_rank() if rank is None else rank
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    if split_rank is None:
        return True
    return bool(stage_rank < split_rank)
def is_pipeline_stage_after_split(rank=None):
    """
    Return True if pipeline stage executes decoder block for a model
    with both encoder and decoder.

    Args:
        rank: pipeline rank to test; defaults to the caller's pipeline rank.

    Returns:
        bool: True when the pipeline world size is 1, when no split rank is
        set, or when ``rank`` is greater than or equal to the split rank.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    stage_rank = get_pipeline_model_parallel_rank() if rank is None else rank
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    if split_rank is None:
        return True
    return bool(stage_rank >= split_rank)
def is_pipeline_stage_at_split():
    """
    Return True if this pipeline stage is before the split (encoder side)
    while the following stage is after the split (decoder side), for a
    model with both encoder and decoder.
    """
    current_stage = get_pipeline_model_parallel_rank()
    if not is_pipeline_stage_before_split(current_stage):
        return False
    return is_pipeline_stage_after_split(current_stage + 1)
def get_virtual_pipeline_model_parallel_rank():
    """Return the virtual pipeline-parallel rank.

    Returns:
        The stored ``_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK`` (None when unset).
    """
    virtual_rank = _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    return virtual_rank
def set_virtual_pipeline_model_parallel_rank(rank):
    """Set the virtual pipeline-parallel rank.

    Args:
        rank: value to store in ``_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK``.
    """
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK

    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
def get_virtual_pipeline_model_parallel_world_size():
    """Return the virtual pipeline-parallel world size.

    Returns:
        The stored ``_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE``
        (None when unset).
    """
    virtual_world_size = _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return virtual_world_size
def set_virtual_pipeline_model_parallel_world_size(world_size):
    """Set the virtual pipeline-parallel world size

    Args:
        world_size: value to store in
            ``_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE``.
    """
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE

    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
def get_tensor_model_parallel_src_rank():
    """
    Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group.

    The global rank is rounded down to the nearest multiple of the tensor
    model parallel world size.
    """
    rank = torch.distributed.get_rank()
    tensor_parallel_size = get_tensor_model_parallel_world_size()
    # Dropping the remainder is the same as (rank // size) * size.
    return rank - rank % tensor_parallel_size
def get_data_parallel_src_rank():
    """
    Calculate the global rank corresponding to the first local rank
    in the data parallel group.

    Returns the first entry of ``_DATA_PARALLEL_GLOBAL_RANKS`` after
    ``ensure_var_is_not_none`` has checked that the list is set.
    """
    data_parallel_ranks = _DATA_PARALLEL_GLOBAL_RANKS
    ensure_var_is_not_none(
        data_parallel_ranks,
        error_message="Data parallel group is not initialized",
    )
    return data_parallel_ranks[0]
def get_pipeline_model_parallel_first_rank():
    """
    Return the global rank of the first process in the pipeline for the
    current tensor parallel group

    Returns the first entry of ``_PIPELINE_GLOBAL_RANKS`` after
    ``ensure_var_is_not_none`` has checked that the list is set.
    """
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    ensure_var_is_not_none(
        pipeline_ranks,
        error_message="Pipeline parallel group is not initialized",
    )
    return pipeline_ranks[0]
def get_pipeline_model_parallel_last_rank():
    """
    Return the global rank of the last process in the pipeline for the
    current tensor parallel group

    The entry of ``_PIPELINE_GLOBAL_RANKS`` at index
    (pipeline world size - 1) is returned.
    """
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    ensure_var_is_not_none(
        pipeline_ranks,
        error_message="Pipeline parallel group is not initialized",
    )
    final_stage = get_pipeline_model_parallel_world_size() - 1
    return pipeline_ranks[final_stage]
def get_pipeline_model_parallel_next_rank():
    """Return the global rank that follows the caller in the pipeline

    The lookup wraps around: the last stage is followed by the first one.
    """
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    ensure_var_is_not_none(
        pipeline_ranks,
        error_message="Pipeline parallel group is not initialized",
    )
    stage = get_pipeline_model_parallel_rank()
    num_stages = get_pipeline_model_parallel_world_size()
    next_stage = (stage + 1) % num_stages
    return pipeline_ranks[next_stage]
def get_pipeline_model_parallel_prev_rank():
    """Return the global rank that precedes the caller in the pipeline

    The lookup wraps around: the first stage is preceded by the last one.
    """
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    ensure_var_is_not_none(
        pipeline_ranks,
        error_message="Pipeline parallel group is not initialized",
    )
    stage = get_pipeline_model_parallel_rank()
    num_stages = get_pipeline_model_parallel_world_size()
    prev_stage = (stage - 1) % num_stages
    return pipeline_ranks[prev_stage]
def get_pipeline_model_parallel_prev_rank_group():
    """Return the stored ``_PIPELINE_PREV_GROUP``.

    The value is first passed through ``ensure_var_is_not_none``.
    """
    prev_group = _PIPELINE_PREV_GROUP
    ensure_var_is_not_none(prev_group)
    return prev_group
def get_pipeline_model_parallel_next_rank_group():
    """Return the stored ``_PIPELINE_NEXT_GROUP``.

    The value is first passed through ``ensure_var_is_not_none``.
    """
    next_group = _PIPELINE_NEXT_GROUP
    ensure_var_is_not_none(next_group)
    return next_group
def get_data_parallel_world_size():
    """Return world size for the data parallel group.

    The size is queried from torch.distributed for the group returned by
    ``get_data_parallel_group``.
    """
    data_parallel_group = get_data_parallel_group()
    return torch.distributed.get_world_size(group=data_parallel_group)
def get_data_parallel_rank():
    """Return my rank for the data parallel group.

    The rank is queried from torch.distributed for the group returned by
    ``get_data_parallel_group``.
    """
    data_parallel_group = get_data_parallel_group()
    return torch.distributed.get_rank(group=data_parallel_group)
def _set_global_memory_buffer():
    """Initialize global buffer

    ``ensure_var_is_none`` is used to check that no buffer exists yet, then a
    fresh ``GlobalMemoryBuffer`` is stored in ``_GLOBAL_MEMORY_BUFFER``.
    """
    global _GLOBAL_MEMORY_BUFFER

    ensure_var_is_none(
        _GLOBAL_MEMORY_BUFFER,
        error_message='global memory buffer is already initialized',
    )
    _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
def get_global_memory_buffer():
    """Return the global GlobalMemoryBuffer object

    The buffer is first passed through ``ensure_var_is_not_none`` so that an
    uninitialized state is reported instead of silently returning None.
    """
    memory_buffer = _GLOBAL_MEMORY_BUFFER
    ensure_var_is_not_none(
        memory_buffer,
        error_message='global memory buffer is not initialized',
    )
    return memory_buffer
def destroy_global_memory_buffer():
    """Sets the global memory buffer to None

    Only the module-level reference is dropped.
    """
    global _GLOBAL_MEMORY_BUFFER

    _GLOBAL_MEMORY_BUFFER = None
def destroy_model_parallel():
    """Reset all module-level parallel state to None.

    Besides the process-group handles this also clears the gloo data parallel
    group, the FP8 amax group, the pipeline split rank and the cached
    global-rank lists. Leaving those behind would let accessors such as
    ``get_data_parallel_src_rank`` return values from a previous
    initialization, and a later ``initialize_model_parallel`` call that does
    not pass a split rank (or runs without the distributed optimizer) would
    silently inherit the old ones.

    Only the references held by this module are dropped; no
    ``torch.distributed`` call is made to tear the groups down.
    """
    # Process-group handles.
    global _MODEL_PARALLEL_GROUP
    global _TENSOR_MODEL_PARALLEL_GROUP
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _DATA_PARALLEL_GROUP
    global _DATA_PARALLEL_GROUP_GLOO
    global _AMAX_REDUCTION_GROUP
    global _PIPELINE_NEXT_GROUP
    global _PIPELINE_PREV_GROUP
    global _EMBEDDING_GROUP
    global _POSITION_EMBEDDING_GROUP
    _MODEL_PARALLEL_GROUP = None
    _TENSOR_MODEL_PARALLEL_GROUP = None
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    _DATA_PARALLEL_GROUP = None
    _DATA_PARALLEL_GROUP_GLOO = None
    _AMAX_REDUCTION_GROUP = None
    _PIPELINE_NEXT_GROUP = None
    _PIPELINE_PREV_GROUP = None
    _EMBEDDING_GROUP = None
    _POSITION_EMBEDDING_GROUP = None

    # Virtual pipeline bookkeeping and the encoder/decoder split rank.
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

    # Manual size / rank overrides.
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    _MPU_TENSOR_MODEL_PARALLEL_RANK = None
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None

    # Cached global-rank lists that belong to the groups cleared above.
    global _EMBEDDING_GLOBAL_RANKS
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    global _PIPELINE_GLOBAL_RANKS
    global _DATA_PARALLEL_GLOBAL_RANKS
    _EMBEDDING_GLOBAL_RANKS = None
    _POSITION_EMBEDDING_GLOBAL_RANKS = None
    _PIPELINE_GLOBAL_RANKS = None
    _DATA_PARALLEL_GLOBAL_RANKS = None

    # Shared scratch buffer.
    global _GLOBAL_MEMORY_BUFFER
    _GLOBAL_MEMORY_BUFFER = None