ModelLink/tasks/inference/text_generation/utils.py

# coding=utf-8
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for generating text."""
import time
import math
import torch
import torch.nn.functional as F
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import get_tokenizer
from megatron.core import parallel_state
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.core.pipeline_parallel.p2p_communication import recv_forward, send_forward
from megatron.core.distributed import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.core.utils import get_model_config
from tasks.inference.text_generation.communication import broadcast_tensor
from tasks.finetune.lora.utils import is_enable_lora, get_lora_model_classes


def get_batch(context_tokens):
"""Generate batch from context tokens."""
tokenizer = get_tokenizer()
# Move to GPU.
tokens = context_tokens.contiguous().to(torch.cuda.current_device())
# Get the attention mask and position ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.pad_token_id,
False, False, False)
    return tokens, attention_mask, position_ids


def pad_batch(batch, args):
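    """Pad every prompt in the batch to a shared, 64-aligned maximum length and
    broadcast the padded tokens and the original context lengths so every rank
    works on the same batch."""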
max_context_length = torch.LongTensor([max(len(val) for val in batch)]).cuda()
micro_batch_size = torch.LongTensor([args.micro_batch_size]).cuda()
torch.distributed.all_reduce(max_context_length, op=torch.distributed.ReduceOp.MAX)
torch.distributed.all_reduce(micro_batch_size, op=torch.distributed.ReduceOp.MAX)
tokenizer = get_tokenizer()
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
context_lengths = [len(val) for val in batch]
if args.text_generation_config['max_new_tokens'] > 0:
max_length = max_context_length[0].item() + args.text_generation_config['max_new_tokens']
else:
max_length = args.text_generation_config['max_length']
    # Pad the maximum length up to a multiple of 64 so fused operators work on aligned, contiguous blocks.
    max_length_padded = math.ceil(max_length / 64) * 64
for i, tokens in enumerate(batch):
if context_lengths[i] < max_length_padded:
tokens.extend([pad_id] * (max_length_padded - context_lengths[i]))
context_tokens_tensor = torch.LongTensor(batch).cuda()
context_length_tensor = torch.LongTensor(context_lengths).cuda()
context_tokens_tensor = broadcast_tensor(
(micro_batch_size[0].item(), max_length_padded),
torch.int64,
tensor=context_tokens_tensor,
rank=args.master_rank
)
    context_length_tensor = broadcast_tensor(
        micro_batch_size[0].item(),
        torch.int64,
        tensor=context_length_tensor,
        rank=args.master_rank
    )
args.seq_length = context_tokens_tensor.shape[1]
args.max_position_embeddings = args.seq_length
args.max_length_ori = max_length
    return context_tokens_tensor, context_length_tensor


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
"""
This function has been mostly taken from huggingface conversational ai code at
https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer
-learning-2d818ac26313
"""
if top_k > 0:
# Remove all tokens with a probability less than the
# last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
        # Sort the logits so cumulative probabilities can be computed.
sorted_logits, sorted_indices = torch.sort(
logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token
# above the threshold
sorted_indices_to_remove[..., 1:] \
= sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for i in range(sorted_indices.size(0)):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = filter_value
    return logits


def greedy_search_or_sampling(model, context_tokens, model_latencies=None, single_token_latency=None):
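    """Generate continuations for a batch of prompts, yielding
    (tokens, context_lengths, log_probs) after every decoding step.
    `single_token_latency` is accepted but currently unused.

    A minimal sketch of a hypothetical caller (the real driver lives elsewhere in
    ModelLink), assuming `model` is an initialized inference model and `prompts`
    is a list of token-id lists:

        for tokens, lengths, log_probs in greedy_search_or_sampling(model, prompts):
            pass  # consume the partially generated batch step by step
    """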
args = get_args()
model_latencies = [] if model_latencies is None else model_latencies
context_tokens_tensor, context_length_tensor = pad_batch(context_tokens, args)
context_length = context_length_tensor.min().item()
batch_token_iterator = sample_sequence_batch(
model,
context_tokens_tensor,
context_length_tensor,
model_latencies=model_latencies
)
yield from _post_process(
batch_token_iterator,
context_length,
context_length_tensor
    )


def _post_process(batch_token_iterator, context_length, context_lengths):
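    """Trim each step's output to the tokens generated so far and expose the
    context lengths as a plain Python list."""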
for tokens, _, log_probs in batch_token_iterator:
context_length += 1
if tokens is not None:
yield tokens[:, :context_length], context_lengths.cpu().numpy().tolist(), log_probs
else:
            yield None, None, None


def switch(val1, val2, boolean):
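    """Element-wise select: positions where `boolean` is 1 take val2, the rest keep
    val1 (used below to leave prompt tokens untouched until generation has started
    for that sample)."""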
boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2


def forward_step(model, tokens, **kwargs):
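    """Run a single forward pass, receiving activations from the previous pipeline
    stage and sending this stage's output to the next one."""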
    # The hidden-state size changes when recompute is not used, so the p2p
    # communication functions need to be told the correct tensor shape.
position_ids = kwargs.pop("position_ids")
attention_mask = kwargs.pop("attention_mask")
tokentype_ids = kwargs.pop("tokentype_ids")
model_latencies = kwargs.pop("model_latencies", None)
model_latencies = [] if model_latencies is None else model_latencies
torch.cuda.synchronize()
t0 = time.time()
args = get_args()
    orig_seq_length = args.seq_length
    # Use the actual length of this step's input so the p2p tensor shapes below match.
    args.seq_length = tokens.shape[1]
    args.micro_batch_size = tokens.shape[0]
config = get_model_config(model)
tensor_shapes = [args.seq_length, args.micro_batch_size, args.hidden_size]
input_tensor = recv_forward(tensor_shapes, config)
_unwrap_and_set_input_tensor(args, input_tensor, model)
output_tensor = model(
input_ids=tokens,
position_ids=position_ids,
attention_mask=attention_mask,
tokentype_ids=tokentype_ids
)
output_tensor = _check_forward_output(output_tensor)
send_forward(output_tensor, config)
args.seq_length = orig_seq_length
torch.cuda.synchronize()
model_latencies.append(time.time() - t0)
    return output_tensor


def _check_forward_output(output_tensor):
if isinstance(output_tensor, (list, tuple)):
if output_tensor[0] is not None:
output_tensor = output_tensor[0]
else:
raise ValueError("Please make sure that the output of the model is 'Tensor' or '[Tensor, ...]'")
return output_tensor


def _unwrap_and_set_input_tensor(args, input_tensor, model):
    # Strip DDP / fp16 (and, when enabled, LoRA) wrappers so the received pipeline
    # activations can be attached to the underlying model before its forward pass.
unwrap_classes = (torchDDP, LocalDDP, Float16Module)
if is_enable_lora():
unwrap_classes += get_lora_model_classes()
unwrapped_model = unwrap_model(model, unwrap_classes)
if hasattr(unwrapped_model, 'set_input_tensor'):
        unwrapped_model.set_input_tensor(input_tensor)


def sample_sequence_batch(model, context_tokens, context_lengths, type_ids=None, model_latencies=None):
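    """Core decoding loop: yields (tokens, max_length, output_log_probs) once per
    generated position until every sequence has finished or max_length is reached."""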
model_latencies = [] if model_latencies is None else model_latencies
args = get_args()
tokenizer = get_tokenizer()
tokens, attention_mask, position_ids = get_batch(context_tokens)
model.eval()
with torch.no_grad():
counter = 0
layer_past = None
batch_size = tokens.size(0)
max_length = args.max_length_ori
context_length = context_lengths.min().item()
is_done = torch.zeros([batch_size]).byte().to(torch.cuda.current_device())
while context_length < max_length:
if args.text_generation_config['recompute']:
logits = _recompute_forward(model,
attention_mask=attention_mask,
context_length=context_length,
position_ids=position_ids,
tokens=tokens,
type_ids=type_ids)
else:
logits = _disable_recompute_forward(model,
attention_mask=attention_mask,
batch_size=batch_size,
context_length=context_length,
counter=counter,
layer_past=layer_past,
model_latencies=model_latencies,
position_ids=position_ids,
tokens=tokens,
type_ids=type_ids)
group, next_log_probs, prev, src, started, vocab_size = _sample_and_synchronize(
args, batch_size, (context_length, context_lengths), logits, tokens
)
if counter == 0:
log_probs_seq = _init_log_probs(args, batch_size, max_length, vocab_size)
output_log_probs = _get_log_probs(args, context_length, log_probs_seq, next_log_probs, (group, src))
done = _is_done(is_done, prev, started, tokenizer)
yield tokens, max_length, output_log_probs
context_length += 1
counter += 1
if done:
                break


def _is_done(is_done, prev, started, tokenizer):
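    """Determine on the last pipeline stage whether every sequence has produced EOS,
    then make the result visible to all pipeline stages."""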
    if parallel_state.is_pipeline_last_stage():
        done_token = (prev == tokenizer.eos_token_id).byte() & started.byte()
        # Update the running mask in place so finished sequences stay marked across steps.
        is_done |= done_token
        done = torch.all(is_done)
    else:
        done = torch.ByteTensor([0]).cuda()
    # Share the termination flag from the last pipeline stage with every other stage.
    if parallel_state.get_pipeline_model_parallel_world_size() > 1:
        src = parallel_state.get_pipeline_model_parallel_last_rank()
        group = parallel_state.get_pipeline_model_parallel_group()
        torch.distributed.broadcast(done, src, group)
    return done


def _get_log_probs(args, context_length, log_probs_seq, next_log_probs, comm_info):
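    """Store this step's token probabilities at position `context_length` in the
    running buffer and return the accumulated slice, or None when
    return_output_log_probs is disabled."""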
group, src = comm_info
output_log_probs = None
if log_probs_seq is None:
return output_log_probs
if args.text_generation_config['return_output_log_probs']:
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
torch.distributed.broadcast(next_log_probs, src, group)
log_probs_seq[:, context_length, :] = next_log_probs
output_log_probs = log_probs_seq[:, :context_length + 1, :]
    return output_log_probs


def _init_log_probs(args, batch_size, max_length, vocab_size):
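    """Allocate a (batch, max_length, vocab) buffer for per-step probabilities, or
    return None when return_output_log_probs is disabled."""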
log_probs_seq = None
if args.text_generation_config['return_output_log_probs']:
log_probs_seq = torch.zeros(
(batch_size, max_length, int(vocab_size))
).to(torch.cuda.current_device())
    return log_probs_seq


def _sample_and_synchronize(args, batch_size, context_length_info, logits, tokens):
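    """Sample the next token on the last pipeline stage, write it into `tokens`, and
    broadcast the new tokens and vocabulary size so all pipeline stages stay in sync."""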
log_probs = None
prev = None
started = None
context_length, context_lengths = context_length_info
if parallel_state.is_pipeline_last_stage():
vocab_size = torch.Tensor([logits.size(1)]).to(torch.cuda.current_device())
log_probs = F.softmax(logits, dim=-1)
logits, prev = _sample_strategy(args, logits)
started = context_lengths <= context_length
new_tokens = switch(
tokens[:, context_length].view(-1), prev, started)
tokens[:, context_length] = new_tokens
src = parallel_state.get_pipeline_model_parallel_last_rank()
group = parallel_state.get_pipeline_model_parallel_group()
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
torch.distributed.broadcast(new_tokens, src, group)
torch.distributed.broadcast(vocab_size, src, group)
else:
src = parallel_state.get_pipeline_model_parallel_last_rank()
group = parallel_state.get_pipeline_model_parallel_group()
new_tokens = torch.empty_like(tokens[:, context_length])
        vocab_size = torch.empty(1, dtype=torch.float32, device=torch.cuda.current_device())
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
torch.distributed.broadcast(new_tokens, src, group)
torch.distributed.broadcast(vocab_size, src, group)
tokens[:, context_length] = new_tokens
if log_probs is None:
log_probs = torch.empty([batch_size, int(vocab_size)],
dtype=torch.float32,
device=torch.cuda.current_device())
res = (group, log_probs, prev, src, started, vocab_size)
    return res


def _sample_strategy(args, logits):
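    """Choose the next tokens: argmax when greedy decoding, otherwise temperature-scaled
    top-k/top-p sampling from the softmax distribution."""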
if args.text_generation_config['greedy']:
prev = torch.argmax(logits, dim=-1).view(-1)
else:
logits = logits.float()
logits /= args.text_generation_config["temperature"]
logits = top_k_logits(logits,
top_k=args.text_generation_config["top_k"],
top_p=args.text_generation_config["top_p"])
logits = F.softmax(logits, dim=-1)
prev = torch.multinomial(logits, num_samples=1).view(-1)
    return logits, prev


def _disable_recompute_forward(model, **kwargs):
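    """Forward pass without recompute: feed the full prompt on the first step and only
    the newest token afterwards; returns the last position's logits on the final
    pipeline stage and None elsewhere."""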
attention_mask = kwargs.pop("attention_mask")
batch_size = kwargs.pop("batch_size")
context_length = kwargs.pop("context_length")
counter = kwargs.pop("counter")
layer_past = kwargs.pop("layer_past")
model_latencies = kwargs.pop("model_latencies")
position_ids = kwargs.pop("position_ids")
tokens = kwargs.pop("tokens")
type_ids = kwargs.pop("type_ids")
types2use = None
logits = None
if counter == 0:
tokens2use = tokens[:, :context_length]
positions2use = position_ids[:, :context_length]
if type_ids is not None:
types2use = type_ids[:, :context_length]
else:
tokens2use = tokens[:, context_length - 1].view(
batch_size, -1)
positions2use = position_ids[:, context_length - 1].view(
batch_size, -1)
if type_ids is not None:
types2use = type_ids[:, context_length - 1].view(
batch_size, -1)
output = forward_step(model,
tokens2use,
position_ids=positions2use,
attention_mask=attention_mask,
tokentype_ids=types2use,
model_latencies=model_latencies)
if parallel_state.is_pipeline_last_stage():
if output is None:
raise ValueError("In pipeline_last_stage group, the forward output should not be None")
logits = output[:, -1].view(batch_size, -1).contiguous()
    return logits


def _recompute_forward(model, **kwargs):
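    """Forward pass with recompute: run the whole padded sequence every step and take
    the logits at the current position on the final pipeline stage (None elsewhere)."""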
attention_mask = kwargs.pop("attention_mask")
context_length = kwargs.pop("context_length")
position_ids = kwargs.pop("position_ids")
tokens = kwargs.pop("tokens")
type_ids = kwargs.pop("type_ids")
logits = None
output = forward_step(model,
tokens,
position_ids=position_ids,
attention_mask=attention_mask,
tokentype_ids=type_ids)
if parallel_state.is_pipeline_last_stage():
if output is None:
raise ValueError("In pipeline_last_stage group, the forward output should not be None")
logits = output[:, context_length - 1, :]
return logits