# ModelLink/modellink/core/pipeline_parallel/p2p_communication.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import operator
from functools import reduce
from typing import List, Optional, Tuple, Union
import torch
from deepspeed.accelerator import get_accelerator
from modellink import core
from modellink.core import tensor_parallel
from modellink import get_args
from modellink.core import ModelParallelConfig
from modellink.core.parallel_state import (
get_pipeline_model_parallel_group,
get_pipeline_model_parallel_next_rank,
get_pipeline_model_parallel_prev_rank,
get_pipeline_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_pipeline_model_parallel_prev_rank_group,
get_pipeline_model_parallel_next_rank_group
)
# Types
Shape = Union[List[int], torch.Size]
def _communicate(
*,
tensor_send_next: Optional[torch.Tensor],
tensor_send_prev: Optional[torch.Tensor],
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
config: ModelParallelConfig,
wait_on_reqs: bool = True
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[List]]:
"""Communicate tensors between stages. Used as helper method in other
communication methods that are used in megatron/schedules.py.
Arguments:
tensor_send_next (torch.Tensor, optional):
Tensor to send to next rank (no tensor sent if None)
tensor_send_prev (torch.Tensor, optional):
Tensor to send to prev rank (no tensor sent if None)
recv_prev (boolean, required):
whether tensor should be received from previous rank.
recv_next (boolean, required):
whether tensor should be received from next rank.
tensor_shape (List[int] or torch.Size, required):
shape of tensor to receive (this method assumes that all
tensors sent and received in a single function call are
the same shape).
        config (ModelParallelConfig, required):
            Configuration object carrying the pipeline communication settings
            (pipeline_dtype, batch_p2p_comm, use_ring_exchange_p2p,
            variable_seq_lengths, ...).
        wait_on_reqs (boolean, optional, default=True):
            For non-batched p2p communication, wait on each request
            before returning.
    Returns:
        tuple containing
        - tensor_recv_prev: torch.Tensor if recv_prev is True, None otherwise.
        - tensor_recv_next: torch.Tensor if recv_next is True, None otherwise.
        - reqs: list of outstanding requests when wait_on_reqs is False,
          None otherwise.
    """
# Create placeholder tensors for receive in forward and backward directions
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
args = get_args()
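    # Unless the optimized pipeline supplies a per-call shape for an actual
    # receive, fall back to the default (seq_length, micro_batch, hidden) shape.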
tensor_shape = tensor_shape if args.optimized_pipeline and (recv_prev or recv_next) \
else (args.seq_length, args.micro_batch_size, args.hidden_size)
if args.sequence_parallel:
seq_length = args.seq_length // get_tensor_model_parallel_world_size()
tensor_shape = (seq_length, args.micro_batch_size, args.hidden_size)
if not config.variable_seq_lengths:
recv_prev_shape = tensor_shape
recv_next_shape = tensor_shape
else:
recv_prev_shape, recv_next_shape = _communicate_shapes(
tensor_send_next, tensor_send_prev, recv_prev, recv_next, config
)
recv_prev_shape_origin = recv_prev_shape
recv_next_shape_origin = recv_next_shape
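    # Keep the unsplit shapes so the 1-D chunks received under the
    # scatter-gather optimization can be viewed back to 3-D afterwards.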
if args.scatter_gather_tensors_in_pipeline and not config.sequence_parallel:
recv_prev_shape = reduce(operator.mul, recv_prev_shape, 1) // \
get_tensor_model_parallel_world_size()
recv_next_shape = reduce(operator.mul, recv_next_shape, 1) // \
get_tensor_model_parallel_world_size()
if recv_prev:
if config.pipeline_dtype is None:
raise RuntimeError("pipeline_dtype must be provided if recv_prev is True")
if tensor_shape is None:
raise RuntimeError(
"tensor_shape must be specified if recv_prev is True. "
"Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
)
tensor_recv_prev = torch.empty(
recv_prev_shape,
requires_grad=True,
device=get_accelerator().current_device(),
dtype=config.pipeline_dtype,
)
if recv_next:
if config.pipeline_dtype is None:
raise RuntimeError("dtype must be provided if recv_next is True")
if tensor_shape is None:
raise RuntimeError(
"tensor_shape must be specified if recv_next is True. "
"Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
)
tensor_recv_next = torch.empty(
recv_next_shape,
requires_grad=True,
device=get_accelerator().current_device(),
dtype=config.pipeline_dtype,
)
# Split tensor into smaller chunks if using scatter-gather optimization.
if args.scatter_gather_tensors_in_pipeline and not config.sequence_parallel:
if tensor_send_next is not None:
tensor_send_next = tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_next)
if tensor_send_prev is not None:
tensor_send_prev = tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_prev)
# Send tensors in both the forward and backward directions as appropriate.
if config.use_ring_exchange_p2p:
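        # torch.distributed.ring_exchange is not part of stock PyTorch; it is
        # only available in builds that carry Megatron's ring-exchange patch.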
def _ring_exchange_wrapper(**kwargs):
torch.distributed.ring_exchange(**kwargs)
return []
p2p_func = _ring_exchange_wrapper
elif config.batch_p2p_comm:
if not wait_on_reqs:
raise Exception("Wait_on_reqs should be true")
p2p_func = _batched_p2p_ops
else:
p2p_func = _p2p_ops
reqs = p2p_func(
tensor_send_prev=tensor_send_prev,
tensor_recv_prev=tensor_recv_prev,
tensor_send_next=tensor_send_next,
tensor_recv_next=tensor_recv_next,
group=get_pipeline_model_parallel_group(),
)
if wait_on_reqs and len(reqs) > 0:
for req in reqs:
req.wait()
reqs = None
        # To protect against a race condition when using batch_isend_irecv();
        # modern enough PyTorch versions should no longer need this.
get_accelerator().synchronize()
# If using scatter-gather optimization, gather smaller chunks.
if args.scatter_gather_tensors_in_pipeline and not config.sequence_parallel:
if recv_prev:
tensor_recv_prev = tensor_parallel.gather_split_1d_tensor(
tensor_recv_prev).view(recv_prev_shape_origin).requires_grad_()
if recv_next:
tensor_recv_next = tensor_parallel.gather_split_1d_tensor(
tensor_recv_next).view(recv_next_shape_origin).requires_grad_()
return tensor_recv_prev, tensor_recv_next, reqs


def async_communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
    """Asynchronous variant of _communicate: issue isend/irecv against the
    dedicated prev/next rank groups and return the receive buffers together
    with the outstanding request handles instead of waiting on them."""
    args = get_args()
    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
tensor_recv_prev = None
tensor_recv_next = None
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
if args.sequence_parallel:
seq_length = args.seq_length // get_tensor_model_parallel_world_size()
tensor_shape = (seq_length, args.micro_batch_size, args.hidden_size)
if args.scatter_gather_tensors_in_pipeline and not args.sequence_parallel:
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
get_tensor_model_parallel_world_size()
else:
tensor_chunk_shape = tensor_shape
dtype = args.params_dtype
if args.fp32_residual_connection:
dtype = torch.float
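    # Match the dtype the pipeline actually carries: fp32 when residual
    # connections are kept in fp32, the parameter dtype otherwise.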
if recv_prev:
tensor_recv_prev = torch.empty(tensor_chunk_shape,
requires_grad=True,
device=get_accelerator().current_device_name(),
dtype=dtype)
if recv_next:
tensor_recv_next = torch.empty(tensor_chunk_shape,
requires_grad=True,
device=get_accelerator().current_device_name(),
dtype=dtype)
# Split tensor into smaller chunks if using scatter-gather optimization.
if args.scatter_gather_tensors_in_pipeline and not args.sequence_parallel:
if tensor_send_next is not None:
tensor_send_next = tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_next)
if tensor_send_prev is not None:
tensor_send_prev = tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_prev)
    ops = []
    # Collect the request handles for sends as well as receives so callers
    # can wait on all outstanding work before reusing the buffers.
    if tensor_send_prev is not None:
        ops.append(torch.distributed.isend(tensor_send_prev,
                                           get_pipeline_model_parallel_prev_rank(),
                                           group=get_pipeline_model_parallel_prev_rank_group()))
    if tensor_recv_prev is not None:
        ops.append(torch.distributed.irecv(tensor_recv_prev,
                                           get_pipeline_model_parallel_prev_rank(),
                                           group=get_pipeline_model_parallel_prev_rank_group()))
    if tensor_send_next is not None:
        ops.append(torch.distributed.isend(tensor_send_next,
                                           get_pipeline_model_parallel_next_rank(),
                                           group=get_pipeline_model_parallel_next_rank_group()))
    if tensor_recv_next is not None:
        ops.append(torch.distributed.irecv(tensor_recv_next,
                                           get_pipeline_model_parallel_next_rank(),
                                           group=get_pipeline_model_parallel_next_rank_group()))
return tensor_recv_prev, tensor_recv_next, ops
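

# A minimal sketch of the expected calling pattern (hypothetical; `out` and
# the flag values are placeholders, not part of this module):
#
#   recv_prev, _, handles = async_communicate(out, None, True, False)
#   for handle in handles:
#       handle.wait()
#   input_tensor = recv_gather(recv_prev)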
def recv_gather(tensor_recv):
    """Gather a received 1-D chunk back to the full activation shape when the
    scatter-gather optimization is enabled; otherwise return the tensor as-is."""
    args = get_args()
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
if args.scatter_gather_tensors_in_pipeline and not args.sequence_parallel:
tensor_recv = tensor_parallel.gather_split_1d_tensor(
tensor_recv).view(tensor_shape).requires_grad_()
return tensor_recv
def recv_forward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor:
""" Receive tensor from previous rank in pipeline (forward receive).
See _communicate for argument details.
"""
if core.parallel_state.is_pipeline_first_stage():
input_tensor = None
else:
if config.timers is not None:
config.timers('forward-recv', log_level=2).start()
input_tensor, _, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape,
config=config,
)
if config.timers is not None:
config.timers('forward-recv').stop()
return input_tensor
def recv_backward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor:
"""Receive tensor from next rank in pipeline (backward receive).
See _communicate for argument details.
"""
if core.parallel_state.is_pipeline_last_stage():
output_tensor_grad = None
else:
if config.timers is not None:
config.timers('backward-recv', log_level=2).start()
_, output_tensor_grad, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape,
config=config,
)
if config.timers is not None:
config.timers('backward-recv').stop()
return output_tensor_grad
def send_forward(output_tensor: torch.Tensor, config: ModelParallelConfig) -> None:
"""Send tensor to next rank in pipeline (forward send).
See _communicate for argument details.
"""
if not core.parallel_state.is_pipeline_last_stage():
if config.timers is not None:
config.timers('forward-send', log_level=2).start()
_communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=False,
tensor_shape=None,
config=config,
)
if config.timers is not None:
config.timers('forward-send').stop()
def send_backward(input_tensor_grad: torch.Tensor, config: ModelParallelConfig) -> None:
"""Send tensor to previous rank in pipeline (backward send).
See _communicate for argument details.
"""
if not core.parallel_state.is_pipeline_first_stage():
if config.timers is not None:
config.timers('backward-send', log_level=2).start()
_communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False,
tensor_shape=None,
config=config,
)
if config.timers is not None:
config.timers('backward-send').stop()
def send_forward_recv_backward(
output_tensor: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig
) -> torch.Tensor:
"""Batched send and recv with next rank in pipeline.
See _communicate for argument details.
"""
if core.parallel_state.is_pipeline_last_stage():
output_tensor_grad = None
else:
if config.timers is not None:
config.timers('forward-send-backward-recv', log_level=2).start()
_, output_tensor_grad, _ = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape,
config=config,
)
if config.timers is not None:
config.timers('forward-send-backward-recv').stop()
return output_tensor_grad
def send_backward_recv_forward(
input_tensor_grad: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig
) -> torch.Tensor:
"""Batched send and recv with previous rank in pipeline.
See _communicate for argument details.
"""
if core.parallel_state.is_pipeline_first_stage():
input_tensor = None
else:
if config.timers is not None:
config.timers('backward-send-forward-recv', log_level=2).start()
input_tensor, _, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape,
config=config,
)
if config.timers is not None:
config.timers('backward-send-forward-recv').stop()
return input_tensor
def send_forward_recv_forward(
output_tensor: torch.Tensor,
recv_prev: bool,
tensor_shape: Shape,
config: ModelParallelConfig,
overlap_p2p_comm: bool = False,
) -> torch.Tensor:
    """Batched recv from previous rank and send to next rank in pipeline.
    Returns (input_tensor, wait_handles) instead when overlap_p2p_comm is True.
    See _communicate for argument details.
    """
if config.timers is not None:
config.timers('forward-send-forward-recv', log_level=2).start()
input_tensor, _, wait_handles = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False,
tensor_shape=tensor_shape,
wait_on_reqs=(not overlap_p2p_comm),
config=config,
)
if config.timers is not None:
config.timers('forward-send-forward-recv').stop()
if overlap_p2p_comm:
return input_tensor, wait_handles
return input_tensor
def send_backward_recv_backward(
input_tensor_grad: torch.Tensor,
recv_next: bool,
tensor_shape: Shape,
config: ModelParallelConfig,
overlap_p2p_comm: bool = False,
) -> torch.Tensor:
    """Batched recv from next rank and send to previous rank in pipeline.
    Returns (output_tensor_grad, wait_handles) instead when overlap_p2p_comm
    is True. See _communicate for argument details.
    """
if config.timers is not None:
config.timers('backward-send-backward-recv', log_level=2).start()
_, output_tensor_grad, wait_handles = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next,
tensor_shape=tensor_shape,
wait_on_reqs=(not overlap_p2p_comm),
config=config,
)
if config.timers is not None:
config.timers('backward-send-backward-recv').stop()
if overlap_p2p_comm:
return output_tensor_grad, wait_handles
return output_tensor_grad
def send_forward_backward_recv_forward_backward(
output_tensor: torch.Tensor,
input_tensor_grad: torch.Tensor,
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
config: ModelParallelConfig,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Batched send and recv with previous and next ranks in pipeline.
See _communicate for argument details.
"""
if config.timers is not None:
config.timers('forward-backward-send-forward-backward-recv', log_level=2).start()
input_tensor, output_tensor_grad, _ = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
)
if config.timers is not None:
config.timers('forward-backward-send-forward-backward-recv').stop()
return input_tensor, output_tensor_grad
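

# A minimal sketch of how the helpers above are typically driven by a
# pipeline schedule (hypothetical stage loop; `model`, `shape`, and `cfg`
# are placeholders, not part of this module):
#
#   input_tensor = recv_forward(shape, cfg)     # None on the first stage
#   output_tensor = model(input_tensor)
#   output_grad = send_forward_recv_backward(output_tensor, shape, cfg)
#   ...                                         # run the backward pass
#   send_backward(input_tensor.grad, cfg)       # no-op on the first stage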
def _communicate_shapes(tensor_send_next, tensor_send_prev, recv_prev, recv_next, config):
"""Communicate tensor shapes between stages. Used to communicate
tensor shapes before the actual tensor communication happens.
This is required when the sequence lengths across micro batches
are not uniform.
Takes the following arguments:
tensor_send_next: tensor to send to next rank (no tensor sent if
set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if
set to None).
recv_prev: boolean for whether tensor should be received from
previous rank.
        recv_next: boolean for whether tensor should be received from
            next rank.
        config: ModelParallelConfig carrying the communication settings
            (e.g. use_ring_exchange_p2p).
    Returns:
        (recv_prev_shape, recv_next_shape)
    """
recv_prev_shape_tensor = None
recv_next_shape_tensor = None
send_prev_shape_tensor = None
send_next_shape_tensor = None
    if recv_prev:
        recv_prev_shape_tensor = torch.empty((3,),
                                             device=get_accelerator().current_device(),
                                             dtype=torch.int64)
    if recv_next:
        recv_next_shape_tensor = torch.empty((3,),
                                             device=get_accelerator().current_device(),
                                             dtype=torch.int64)
if tensor_send_prev is not None:
send_prev_shape_tensor = torch.tensor(tensor_send_prev.size(),
device=get_accelerator().current_device(),
dtype=torch.int64)
if tensor_send_next is not None:
send_next_shape_tensor = torch.tensor(tensor_send_next.size(),
device=get_accelerator().current_device(),
dtype=torch.int64)
if config.use_ring_exchange_p2p:
torch.distributed.ring_exchange(
tensor_send_prev=send_prev_shape_tensor,
tensor_recv_prev=recv_prev_shape_tensor,
tensor_send_next=send_next_shape_tensor,
tensor_recv_next=recv_next_shape_tensor,
group=get_pipeline_model_parallel_group(),
)
else:
ops = []
if send_prev_shape_tensor is not None:
send_prev_op = torch.distributed.P2POp(
torch.distributed.isend,
send_prev_shape_tensor,
get_pipeline_model_parallel_prev_rank(),
)
ops.append(send_prev_op)
if recv_prev_shape_tensor is not None:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv,
recv_prev_shape_tensor,
get_pipeline_model_parallel_prev_rank(),
)
ops.append(recv_prev_op)
if recv_next_shape_tensor is not None:
recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv,
recv_next_shape_tensor,
get_pipeline_model_parallel_next_rank(),
)
ops.append(recv_next_op)
if send_next_shape_tensor is not None:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend,
send_next_shape_tensor,
get_pipeline_model_parallel_next_rank(),
)
ops.append(send_next_op)
if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
        # To protect against a race condition when using batch_isend_irecv();
        # should be removed once the underlying bug is resolved.
        get_accelerator().synchronize()
recv_prev_shape = [0, 0, 0]
if recv_prev_shape_tensor is not None:
recv_prev_shape = recv_prev_shape_tensor.tolist()
recv_next_shape = [0, 0, 0]
if recv_next_shape_tensor is not None:
recv_next_shape = recv_next_shape_tensor.tolist()
return recv_prev_shape, recv_next_shape
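

# Example: with variable sequence lengths, a stage about to send a
# (384, 4, 1024) activation first transmits the int64 shape tensor
# [384, 4, 1024], letting the receiver allocate a matching buffer before
# the payload arrives.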
def _batched_p2p_ops(
*,
tensor_send_prev: Optional[torch.Tensor],
tensor_recv_prev: Optional[torch.Tensor],
tensor_send_next: Optional[torch.Tensor],
tensor_recv_next: Optional[torch.Tensor],
group: torch.distributed.ProcessGroup
):
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor_send_prev,
get_pipeline_model_parallel_prev_rank(),
group)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv_prev,
get_pipeline_model_parallel_prev_rank(),
group,
)
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor_send_next,
get_pipeline_model_parallel_next_rank(),
group,
)
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv_next,
get_pipeline_model_parallel_next_rank(),
group,
)
ops.append(recv_next_op)
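    # On odd ranks, reverse the issue order so each send lines up with the
    # neighbouring even rank's matching recv, avoiding ordering deadlocks.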
if get_pipeline_model_parallel_rank() % 2 == 1:
ops.reverse()
if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops)
else:
reqs = []
return reqs
def _p2p_ops(
*,
tensor_send_prev: Optional[torch.Tensor],
tensor_recv_prev: Optional[torch.Tensor],
tensor_send_next: Optional[torch.Tensor],
tensor_recv_next: Optional[torch.Tensor],
group: torch.distributed.ProcessGroup
):
    reqs = []
    # Both rank parities issue the ops in the same order (unlike upstream
    # Megatron-LM, which alternates the send/recv issue order by rank parity),
    # so a single sequence covers every rank.
    if tensor_send_prev is not None:
        send_prev_req = torch.distributed.isend(
            tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group,
        )
        reqs.append(send_prev_req)
    if tensor_recv_prev is not None:
        recv_prev_req = torch.distributed.irecv(
            tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group,
        )
        reqs.append(recv_prev_req)
    if tensor_recv_next is not None:
        recv_next_req = torch.distributed.irecv(
            tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group,
        )
        reqs.append(recv_next_req)
    if tensor_send_next is not None:
        send_next_req = torch.distributed.isend(
            tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group,
        )
        reqs.append(send_next_req)
    return reqs