# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from modellink.core.parallel_state import (
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_rank
)
from modellink.core.tensor_parallel.utils import split_tensor_along_last_dim
from modellink.error_utils import check_divisible


def _reduce(input_):
    """All-reduce the input tensor across the model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_

    # All-reduce (in place) over the tensor model parallel group.
    torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())

    return input_


def _split_along_last_dim(input_):
    """Split the tensor along its last dimension and keep the
    corresponding slice."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Split along last dimension.
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Note: torch.split does not create contiguous tensors by default.
    rank = get_tensor_model_parallel_rank()
    output = input_list[rank].contiguous()

    return output


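# A single-process sketch of the split semantics above (a world size of 2 is
# an assumption made for illustration; split_tensor_along_last_dim behaves
# like torch.split along the last dimension, per the note above):
#
#   >>> x = torch.arange(8.).reshape(2, 4)
#   >>> chunks = torch.split(x, 4 // 2, dim=-1)   # two slices of shape [2, 2]
#   >>> chunks[0]                                 # the slice rank 0 keeps
#   tensor([[0., 1.],
#           [4., 5.]])

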
def _gather_along_last_dim(input_):
    """Gather tensors and concatenate along the last dimension."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = get_tensor_model_parallel_rank()

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output


def _gather_along_first_dim(input_):
    """Gather tensors and concatenate along the first dimension."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    dim_size[0] = dim_size[0] * world_size

    output = torch.empty(dim_size, dtype=input_.dtype,
                         device=torch.cuda.current_device())
    torch.distributed._all_gather_base(output, input_.contiguous(),
                                       group=get_tensor_model_parallel_group())

    return output


def _reduce_scatter_along_first_dim(input_):
    """Reduce-scatter the input tensor across the model parallel group."""
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    check_divisible(dim_size[0], world_size,
                    error_info="{} % {}, First dimension of the tensor should be divisible by tensor parallel size")

    dim_size[0] = dim_size[0] // world_size

    output = torch.empty(dim_size, dtype=input_.dtype,
                         device=torch.cuda.current_device())

    torch.distributed._reduce_scatter_base(output, input_.contiguous(),
                                           group=get_tensor_model_parallel_group())
    return output


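# A single-process sketch of the reduce-scatter semantics above: it is
# equivalent to an all-reduce (sum) followed by _split_along_first_dim.
# Two "ranks" are an assumption made for illustration:
#
#   >>> per_rank_in = [torch.full((4, 2), float(r + 1)) for r in range(2)]
#   >>> summed = per_rank_in[0] + per_rank_in[1]    # the "reduce" step
#   >>> summed.chunk(2, dim=0)[0]                   # rank 0's output, shape [2, 2]
#   tensor([[3., 3.],
#           [3., 3.]])

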
def _split_along_first_dim(input_):
    """Split the tensor along its first dimension and keep the
    corresponding slice."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Split along first dimension.
    dim_size = input_.size()[0]
    check_divisible(dim_size, world_size,
                    error_info="{} % {}, First dimension of the tensor should be divisible by tensor parallel size")
    local_dim_size = dim_size // world_size
    rank = get_tensor_model_parallel_rank()
    dim_offset = rank * local_dim_size

    output = input_[dim_offset:dim_offset + local_dim_size].contiguous()

    return output


class _ScatterToSequenceParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to the rank."""

    @staticmethod
    def symbolic(graph, input_):
        return _split_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_first_dim(grad_output)


class _GatherFromSequenceParallelRegion(torch.autograd.Function):
    """Gather the input from the sequence parallel region and concatenate."""

    @staticmethod
    def symbolic(graph, input_, tensor_parallel_output_grad=True):
        return _gather_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_, tensor_parallel_output_grad=True):
        ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
        return _gather_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        tensor_parallel_output_grad = ctx.tensor_parallel_output_grad

        # If the computation graph after the gather operation runs in
        # tensor parallel mode, the output gradients need to be
        # reduce-scattered; whereas if the computation is duplicated
        # across ranks, the output gradients only need to be split.
        if tensor_parallel_output_grad:
            return _reduce_scatter_along_first_dim(grad_output), None
        else:
            return _split_along_first_dim(grad_output), None


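# A single-process sketch of the two backward paths above (two "ranks" are an
# assumption made for illustration). The tensor parallel branch sums the
# gradients from all ranks before taking this rank's slice; the duplicated
# branch takes the slice without summing:
#
#   >>> g = [torch.full((4, 2), float(r + 1)) for r in range(2)]  # per-rank grads
#   >>> (g[0] + g[1]).chunk(2, dim=0)[0]   # reduce-scatter path (all values 3.)
#   >>> g[0].chunk(2, dim=0)[0]            # plain split path on rank 0 (all values 1.)

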
class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output)


class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to the rank."""

    @staticmethod
    def symbolic(graph, input_):
        return _split_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_last_dim(grad_output)


class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
    """Reduce-scatter the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return _reduce_scatter_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce_scatter_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_first_dim(grad_output)


class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate."""

    @staticmethod
    def symbolic(graph, input_):
        return _gather_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _gather_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split_along_last_dim(grad_output)


# -----------------
# Helper functions.
# -----------------

def copy_to_tensor_model_parallel_region(input_):
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_tensor_model_parallel_region(input_):
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_tensor_model_parallel_region(input_):
    return _ScatterToModelParallelRegion.apply(input_)


def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
    return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)


def gather_from_tensor_model_parallel_region(input_):
    return _GatherFromModelParallelRegion.apply(input_)


def reduce_scatter_to_sequence_parallel_region(input_):
    return _ReduceScatterToSequenceParallelRegion.apply(input_)


def scatter_to_sequence_parallel_region(input_):
    return _ScatterToSequenceParallelRegion.apply(input_)
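

# A usage sketch of the classic column-parallel linear pattern built from the
# helpers above (illustrative only; it assumes torch.distributed and the
# ModelLink tensor model parallel state are already initialized, and
# `hidden_states` and `weight_shard` are hypothetical names for an activation
# and this rank's shard of a column-parallel weight):
#
#   >>> x = copy_to_tensor_model_parallel_region(hidden_states)   # identity fwd, all-reduce bwd
#   >>> partial = torch.nn.functional.linear(x, weight_shard)     # per-rank slice of the output
#   >>> full = gather_from_tensor_model_parallel_region(partial)  # concat slices along last dim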