# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import operator
from functools import reduce
from typing import List, Optional, Tuple, Union

import torch
from deepspeed.accelerator import get_accelerator
import torch.distributed as dist

from modellink import core
from modellink.core import tensor_parallel
from modellink import get_args
from modellink.core import ModelParallelConfig
from modellink.core.parallel_state import (
    get_pipeline_model_parallel_group,
    get_pipeline_model_parallel_next_rank,
    get_pipeline_model_parallel_prev_rank,
    get_pipeline_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_pipeline_model_parallel_prev_rank_group,
    get_pipeline_model_parallel_next_rank_group
)

# Types
Shape = Union[List[int], torch.Size]


def _communicate(
    *,
    tensor_send_next: Optional[torch.Tensor],
    tensor_send_prev: Optional[torch.Tensor],
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Shape,
    config: ModelParallelConfig,
    wait_on_reqs: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, list]:
    """Communicate tensors between stages. Used as a helper method by the other
    communication methods that are used in megatron/schedules.py.

    Arguments:
        tensor_send_next (torch.Tensor, optional):
            Tensor to send to the next rank (no tensor is sent if None).

        tensor_send_prev (torch.Tensor, optional):
            Tensor to send to the previous rank (no tensor is sent if None).

        recv_prev (boolean, required):
            Whether a tensor should be received from the previous rank.

        recv_next (boolean, required):
            Whether a tensor should be received from the next rank.

        tensor_shape (List[int] or torch.Size, required):
            Shape of the tensors to receive (this method assumes that all
            tensors sent and received in a single function call have the
            same shape).

        config (ModelParallelConfig, required):
            Parallel configuration; pipeline_dtype, batch_p2p_comm,
            use_ring_exchange_p2p and variable_seq_lengths are read here.

        wait_on_reqs (boolean, optional, default=True):
            For non-batched p2p communication, wait on each request
            before returning.

    Returns:
        tuple containing

        - tensor_recv_prev: torch.Tensor if recv_prev is True, None otherwise.
        - tensor_recv_next: torch.Tensor if recv_next is True, None otherwise.
        - reqs: outstanding communication requests (only meaningful when
          wait_on_reqs is False).
    """

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    args = get_args()

    tensor_shape = tensor_shape if args.optimized_pipeline and (recv_prev or recv_next) \
        else (args.seq_length, args.micro_batch_size, args.hidden_size)

    if args.sequence_parallel:
        seq_length = args.seq_length // get_tensor_model_parallel_world_size()
        tensor_shape = (seq_length, args.micro_batch_size, args.hidden_size)

    if not config.variable_seq_lengths:
        recv_prev_shape = tensor_shape
        recv_next_shape = tensor_shape
    else:
        recv_prev_shape, recv_next_shape = _communicate_shapes(
            tensor_send_next, tensor_send_prev, recv_prev, recv_next, config
        )
    recv_prev_shape_origin = recv_prev_shape
    recv_next_shape_origin = recv_next_shape
    if args.scatter_gather_tensors_in_pipeline and not config.sequence_parallel:
        recv_prev_shape = reduce(operator.mul, recv_prev_shape, 1) // \
            get_tensor_model_parallel_world_size()
        recv_next_shape = reduce(operator.mul, recv_next_shape, 1) // \
            get_tensor_model_parallel_world_size()

    if recv_prev:
        if config.pipeline_dtype is None:
            raise RuntimeError("pipeline_dtype must be provided if recv_prev is True")
        if tensor_shape is None:
            raise RuntimeError(
                "tensor_shape must be specified if recv_prev is True. "
                "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
            )
        tensor_recv_prev = torch.empty(
            recv_prev_shape,
            requires_grad=True,
            device=get_accelerator().current_device(),
            dtype=config.pipeline_dtype,
        )
    if recv_next:
        if config.pipeline_dtype is None:
            raise RuntimeError("pipeline_dtype must be provided if recv_next is True")
        if tensor_shape is None:
            raise RuntimeError(
                "tensor_shape must be specified if recv_next is True. "
                "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
            )
        tensor_recv_next = torch.empty(
            recv_next_shape,
            requires_grad=True,
            device=get_accelerator().current_device(),
            dtype=config.pipeline_dtype,
        )

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if args.scatter_gather_tensors_in_pipeline and not config.sequence_parallel:
        if tensor_send_next is not None:
            tensor_send_next = tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    if config.use_ring_exchange_p2p:
        def _ring_exchange_wrapper(**kwargs):
            torch.distributed.ring_exchange(**kwargs)
            return []

        p2p_func = _ring_exchange_wrapper
    elif config.batch_p2p_comm:
        if not wait_on_reqs:
            raise Exception("wait_on_reqs should be True when batch_p2p_comm is enabled")
        p2p_func = _batched_p2p_ops
    else:
        p2p_func = _p2p_ops

    reqs = p2p_func(
        tensor_send_prev=tensor_send_prev,
        tensor_recv_prev=tensor_recv_prev,
        tensor_send_next=tensor_send_next,
        tensor_recv_next=tensor_recv_next,
        group=get_pipeline_model_parallel_group(),
    )

    if wait_on_reqs and len(reqs) > 0:
        for req in reqs:
            req.wait()
        reqs = None

    # To protect against a race condition when using batch_isend_irecv().
    # User should assert that we have a modern enough PyTorch to not need this.
    get_accelerator().synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if args.scatter_gather_tensors_in_pipeline and not config.sequence_parallel:
        if recv_prev:
            tensor_recv_prev = tensor_parallel.gather_split_1d_tensor(
                tensor_recv_prev).view(recv_prev_shape_origin).requires_grad_()

        if recv_next:
            tensor_recv_next = tensor_parallel.gather_split_1d_tensor(
                tensor_recv_next).view(recv_next_shape_origin).requires_grad_()

    return tensor_recv_prev, tensor_recv_next, reqs


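# Editor's note: illustrative usage sketch only, not part of the original
# module. Assuming `config` is a ModelParallelConfig with `pipeline_dtype` set
# and `output_tensor` is a hypothetical activation tensor of shape
# (seq_length, micro_batch_size, hidden_size), an intermediate pipeline stage
# could send forward while receiving its own forward input like this:
#
#     input_tensor, _, reqs = _communicate(
#         tensor_send_next=output_tensor,
#         tensor_send_prev=None,
#         recv_prev=True,
#         recv_next=False,
#         tensor_shape=output_tensor.shape,
#         config=config,
#     )
#     # With the default wait_on_reqs=True, `reqs` is None and `input_tensor`
#     # is ready to read; with wait_on_reqs=False (non-batched p2p only), the
#     # caller must wait on the returned requests before touching the buffers.

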
def async_communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
    args = get_args()

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None

    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)

    if args.sequence_parallel:
        seq_length = args.seq_length // get_tensor_model_parallel_world_size()
        tensor_shape = (seq_length, args.micro_batch_size, args.hidden_size)

    if args.scatter_gather_tensors_in_pipeline and not args.sequence_parallel:
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
            get_tensor_model_parallel_world_size()
    else:
        tensor_chunk_shape = tensor_shape
    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float
    if recv_prev:
        tensor_recv_prev = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=get_accelerator().current_device_name(),
                                       dtype=dtype)
    if recv_next:
        tensor_recv_next = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=get_accelerator().current_device_name(),
                                       dtype=dtype)

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if args.scatter_gather_tensors_in_pipeline and not args.sequence_parallel:
        if tensor_send_next is not None:
            tensor_send_next = tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Issue non-blocking sends and receives. Only the receive handles are
    # collected and returned; callers wait on them before reading the buffers.
    ops = []
    if tensor_send_prev is not None:
        torch.distributed.isend(tensor_send_prev,
                                get_pipeline_model_parallel_prev_rank(),
                                group=get_pipeline_model_parallel_prev_rank_group())
    if tensor_recv_prev is not None:
        ops.append(torch.distributed.irecv(tensor_recv_prev,
                                           get_pipeline_model_parallel_prev_rank(),
                                           group=get_pipeline_model_parallel_prev_rank_group()))
    if tensor_send_next is not None:
        torch.distributed.isend(tensor_send_next,
                                get_pipeline_model_parallel_next_rank(),
                                group=get_pipeline_model_parallel_next_rank_group())
    if tensor_recv_next is not None:
        ops.append(torch.distributed.irecv(tensor_recv_next,
                                           get_pipeline_model_parallel_next_rank(),
                                           group=get_pipeline_model_parallel_next_rank_group()))
    return tensor_recv_prev, tensor_recv_next, ops


def recv_gather(tensor_recv):
    """Gather a tensor that was received as scattered 1D chunks back to its
    full (seq_length, micro_batch_size, hidden_size) shape."""
    args = get_args()
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)

    if args.scatter_gather_tensors_in_pipeline and not args.sequence_parallel:
        tensor_recv = tensor_parallel.gather_split_1d_tensor(
            tensor_recv).view(tensor_shape).requires_grad_()

    return tensor_recv


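# Editor's note: illustrative sketch, not part of the original module, of how
# async_communicate and recv_gather are meant to pair up. `output_tensor` is a
# hypothetical activation of shape (seq_length, micro_batch_size, hidden_size).
#
#     tensor_recv_prev, _, ops = async_communicate(
#         tensor_send_next=output_tensor,
#         tensor_send_prev=None,
#         recv_prev=True,
#         recv_next=False,
#     )
#     for op in ops:            # only the irecv handles are returned
#         op.wait()
#     # If scatter_gather_tensors_in_pipeline is enabled, the received buffer
#     # is a flattened chunk; gather it back to the full activation shape:
#     input_tensor = recv_gather(tensor_recv_prev)

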
def recv_forward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor:
    """Receive tensor from previous rank in pipeline (forward receive).

    See _communicate for argument details.
    """

    if core.parallel_state.is_pipeline_first_stage():
        input_tensor = None
    else:
        if config.timers is not None:
            config.timers('forward-recv', log_level=2).start()
        input_tensor, _, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=True,
            recv_next=False,
            tensor_shape=tensor_shape,
            config=config,
        )
        if config.timers is not None:
            config.timers('forward-recv').stop()
    return input_tensor


def recv_backward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor:
    """Receive tensor from next rank in pipeline (backward receive).

    See _communicate for argument details.
    """
    if core.parallel_state.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if config.timers is not None:
            config.timers('backward-recv', log_level=2).start()
        _, output_tensor_grad, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True,
            tensor_shape=tensor_shape,
            config=config,
        )
        if config.timers is not None:
            config.timers('backward-recv').stop()
    return output_tensor_grad


def send_forward(output_tensor: torch.Tensor, config: ModelParallelConfig) -> None:
    """Send tensor to next rank in pipeline (forward send).

    See _communicate for argument details.
    """

    if not core.parallel_state.is_pipeline_last_stage():
        if config.timers is not None:
            config.timers('forward-send', log_level=2).start()
        _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=False,
            tensor_shape=None,
            config=config,
        )
        if config.timers is not None:
            config.timers('forward-send').stop()


def send_backward(input_tensor_grad: torch.Tensor, config: ModelParallelConfig) -> None:
    """Send tensor to previous rank in pipeline (backward send).

    See _communicate for argument details.
    """
    if not core.parallel_state.is_pipeline_first_stage():
        if config.timers is not None:
            config.timers('backward-send', log_level=2).start()
        _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=False,
            recv_next=False,
            tensor_shape=None,
            config=config,
        )
        if config.timers is not None:
            config.timers('backward-send').stop()


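# Editor's note: illustrative sketch, not part of the original module, of how
# the four point-to-point wrappers above combine into one naive pipeline step.
# `model`, `shape`, and the local backward pass are hypothetical; the real
# schedules live in megatron/schedules.py.
#
#     input_tensor = recv_forward(shape, config)           # None on first stage
#     output_tensor = model(input_tensor)
#     send_forward(output_tensor, config)
#     output_tensor_grad = recv_backward(shape, config)    # None on last stage
#     # ... run the local backward pass, producing input_tensor_grad ...
#     send_backward(input_tensor_grad, config)

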
def send_forward_recv_backward(
    output_tensor: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig
) -> torch.Tensor:
    """Batched send and recv with next rank in pipeline.

    See _communicate for argument details.
    """
    if core.parallel_state.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if config.timers is not None:
            config.timers('forward-send-backward-recv', log_level=2).start()
        _, output_tensor_grad, _ = _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True,
            tensor_shape=tensor_shape,
            config=config,
        )
        if config.timers is not None:
            config.timers('forward-send-backward-recv').stop()
    return output_tensor_grad


def send_backward_recv_forward(
    input_tensor_grad: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig
) -> torch.Tensor:
    """Batched send and recv with previous rank in pipeline.

    See _communicate for argument details.
    """
    if core.parallel_state.is_pipeline_first_stage():
        input_tensor = None
    else:
        if config.timers is not None:
            config.timers('backward-send-forward-recv', log_level=2).start()
        input_tensor, _, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=True,
            recv_next=False,
            tensor_shape=tensor_shape,
            config=config,
        )
        if config.timers is not None:
            config.timers('backward-send-forward-recv').stop()
    return input_tensor


def send_forward_recv_forward(
    output_tensor: torch.Tensor,
    recv_prev: bool,
    tensor_shape: Shape,
    config: ModelParallelConfig,
    overlap_p2p_comm: bool = False,
) -> torch.Tensor:
    """Batched recv from previous rank and send to next rank in pipeline.

    See _communicate for argument details.
    """
    if config.timers is not None:
        config.timers('forward-send-forward-recv', log_level=2).start()
    input_tensor, _, wait_handles = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False,
        tensor_shape=tensor_shape,
        wait_on_reqs=(not overlap_p2p_comm),
        config=config,
    )
    if config.timers is not None:
        config.timers('forward-send-forward-recv').stop()
    if overlap_p2p_comm:
        return input_tensor, wait_handles
    return input_tensor


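# Editor's note: illustrative sketch, not part of the original module, showing
# the overlap_p2p_comm=True path, where the caller overlaps independent compute
# with the outstanding communication and waits on the returned handles later.
# `output_tensor`, `shape`, and `config` are assumed to exist in the caller.
#
#     input_tensor, wait_handles = send_forward_recv_forward(
#         output_tensor,
#         recv_prev=True,
#         tensor_shape=shape,
#         config=config,
#         overlap_p2p_comm=True,
#     )
#     # ... do work that does not depend on input_tensor ...
#     for req in wait_handles:
#         req.wait()             # input_tensor is safe to read after this

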
def send_backward_recv_backward(
    input_tensor_grad: torch.Tensor,
    recv_next: bool,
    tensor_shape: Shape,
    config: ModelParallelConfig,
    overlap_p2p_comm: bool = False,
) -> torch.Tensor:
    """Batched recv from next rank and send to previous rank in pipeline.

    See _communicate for argument details.
    """
    if config.timers is not None:
        config.timers('backward-send-backward-recv', log_level=2).start()
    _, output_tensor_grad, wait_handles = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        wait_on_reqs=(not overlap_p2p_comm),
        config=config,
    )
    if config.timers is not None:
        config.timers('backward-send-backward-recv').stop()
    if overlap_p2p_comm:
        return output_tensor_grad, wait_handles
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
    output_tensor: torch.Tensor,
    input_tensor_grad: torch.Tensor,
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Shape,
    config: ModelParallelConfig,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Batched send and recv with previous and next ranks in pipeline.

    See _communicate for argument details.
    """
    if config.timers is not None:
        config.timers('forward-backward-send-forward-backward-recv', log_level=2).start()
    input_tensor, output_tensor_grad, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        config=config,
    )
    if config.timers is not None:
        config.timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad


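# Editor's note: illustrative sketch, not part of the original module. In the
# steady state of a 1F1B-style schedule, a stage can exchange forward
# activations and backward gradients with both neighbours in a single call.
# `is_first_stage`, `is_last_stage`, and `shape` are hypothetical names.
#
#     input_tensor, output_tensor_grad = send_forward_backward_recv_forward_backward(
#         output_tensor,
#         input_tensor_grad,
#         recv_prev=not is_first_stage,
#         recv_next=not is_last_stage,
#         tensor_shape=shape,
#         config=config,
#     )

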
def _communicate_shapes(tensor_send_next, tensor_send_prev, recv_prev, recv_next, config):
    """Communicate tensor shapes between stages. Used to communicate
    tensor shapes before the actual tensor communication happens.
    This is required when the sequence lengths across micro batches
    are not uniform.

    Takes the following arguments:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
                          set to None).
        recv_prev: boolean for whether tensor should be received from
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
    Returns:
        (recv_prev_shape, recv_next_shape)
    """

    recv_prev_shape_tensor = None
    recv_next_shape_tensor = None
    send_prev_shape_tensor = None
    send_next_shape_tensor = None
    if recv_prev:
        recv_prev_shape_tensor = torch.empty((3,),
                                             device=get_accelerator().current_device(),
                                             dtype=torch.int64)
    if recv_next:
        recv_next_shape_tensor = torch.empty((3,),
                                             device=get_accelerator().current_device(),
                                             dtype=torch.int64)
    if tensor_send_prev is not None:
        send_prev_shape_tensor = torch.tensor(tensor_send_prev.size(),
                                              device=get_accelerator().current_device(),
                                              dtype=torch.int64)
    if tensor_send_next is not None:
        send_next_shape_tensor = torch.tensor(tensor_send_next.size(),
                                              device=get_accelerator().current_device(),
                                              dtype=torch.int64)

    if config.use_ring_exchange_p2p:
        torch.distributed.ring_exchange(
            tensor_send_prev=send_prev_shape_tensor,
            tensor_recv_prev=recv_prev_shape_tensor,
            tensor_send_next=send_next_shape_tensor,
            tensor_recv_next=recv_next_shape_tensor,
            group=get_pipeline_model_parallel_group(),
        )
    else:
        ops = []
        if send_prev_shape_tensor is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend,
                send_prev_shape_tensor,
                get_pipeline_model_parallel_prev_rank(),
            )
            ops.append(send_prev_op)
        if recv_prev_shape_tensor is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv,
                recv_prev_shape_tensor,
                get_pipeline_model_parallel_prev_rank(),
            )
            ops.append(recv_prev_op)
        if recv_next_shape_tensor is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv,
                recv_next_shape_tensor,
                get_pipeline_model_parallel_next_rank(),
            )
            ops.append(recv_next_op)
        if send_next_shape_tensor is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend,
                send_next_shape_tensor,
                get_pipeline_model_parallel_next_rank(),
            )
            ops.append(send_next_op)

        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()

        # To protect against a race condition when using batch_isend_irecv().
        # Should take this out once the bug with batch_isend_irecv is resolved.
        get_accelerator().synchronize()

    recv_prev_shape = [0, 0, 0]
    if recv_prev_shape_tensor is not None:
        recv_prev_shape = recv_prev_shape_tensor.tolist()

    recv_next_shape = [0, 0, 0]
    if recv_next_shape_tensor is not None:
        recv_next_shape = recv_next_shape_tensor.tolist()

    return recv_prev_shape, recv_next_shape


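# Editor's note: illustrative sketch, not part of the original module. The
# shape handshake above only runs when config.variable_seq_lengths is True,
# i.e. when micro batches may carry different sequence lengths and a receiving
# stage cannot pre-compute its buffer size:
#
#     # inside _communicate, before allocating the receive buffers:
#     #     recv_prev_shape, recv_next_shape = _communicate_shapes(
#     #         tensor_send_next, tensor_send_prev, recv_prev, recv_next, config)
#     # each recv_*_shape is a 3-element list such as
#     # [seq_length, micro_batch_size, hidden_size].

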
def _batched_p2p_ops(
    *,
    tensor_send_prev: Optional[torch.Tensor],
    tensor_recv_prev: Optional[torch.Tensor],
    tensor_send_next: Optional[torch.Tensor],
    tensor_recv_next: Optional[torch.Tensor],
    group: torch.distributed.ProcessGroup
):
    ops = []
    if tensor_send_prev is not None:
        send_prev_op = torch.distributed.P2POp(
            torch.distributed.isend,
            tensor_send_prev,
            get_pipeline_model_parallel_prev_rank(),
            group,
        )
        ops.append(send_prev_op)
    if tensor_recv_prev is not None:
        recv_prev_op = torch.distributed.P2POp(
            torch.distributed.irecv,
            tensor_recv_prev,
            get_pipeline_model_parallel_prev_rank(),
            group,
        )
        ops.append(recv_prev_op)
    if tensor_send_next is not None:
        send_next_op = torch.distributed.P2POp(
            torch.distributed.isend,
            tensor_send_next,
            get_pipeline_model_parallel_next_rank(),
            group,
        )
        ops.append(send_next_op)
    if tensor_recv_next is not None:
        recv_next_op = torch.distributed.P2POp(
            torch.distributed.irecv,
            tensor_recv_next,
            get_pipeline_model_parallel_next_rank(),
            group,
        )
        ops.append(recv_next_op)

    if get_pipeline_model_parallel_rank() % 2 == 1:
        ops.reverse()

    if len(ops) > 0:
        reqs = torch.distributed.batch_isend_irecv(ops)
    else:
        reqs = []
    return reqs


def _p2p_ops(
    *,
    tensor_send_prev: Optional[torch.Tensor],
    tensor_recv_prev: Optional[torch.Tensor],
    tensor_send_next: Optional[torch.Tensor],
    tensor_recv_next: Optional[torch.Tensor],
    group: torch.distributed.ProcessGroup
):
    reqs = []
    rank = get_pipeline_model_parallel_rank()
    # NOTE: the two rank-parity branches below currently issue their isend/irecv
    # calls in the same order (send_prev, recv_prev, recv_next, send_next).
    if rank % 2 == 0:
        if tensor_send_prev is not None:
            send_prev_req = torch.distributed.isend(
                tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group,
            )
            reqs.append(send_prev_req)

        if tensor_recv_prev is not None:
            recv_prev_req = torch.distributed.irecv(
                tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group,
            )
            reqs.append(recv_prev_req)

        if tensor_recv_next is not None:
            recv_next_req = torch.distributed.irecv(
                tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group,
            )
            reqs.append(recv_next_req)

        if tensor_send_next is not None:
            send_next_req = torch.distributed.isend(
                tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group,
            )
            reqs.append(send_next_req)

    else:
        if tensor_send_prev is not None:
            send_prev_req = torch.distributed.isend(
                tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group,
            )
            reqs.append(send_prev_req)

        if tensor_recv_prev is not None:
            recv_prev_req = torch.distributed.irecv(
                tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group,
            )
            reqs.append(recv_prev_req)

        if tensor_recv_next is not None:
            recv_next_req = torch.distributed.irecv(
                tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group,
            )
            reqs.append(recv_next_req)

        if tensor_send_next is not None:
            send_next_req = torch.distributed.isend(
                tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group,
            )
            reqs.append(send_next_req)

    return reqs