ModelLink/tasks/inference/text_generation/generation.py

# coding=utf-8
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generation utilities."""
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron.core import parallel_state
from tasks.inference.text_generation.utils import pad_batch, top_k_logits
from .forward_step import ForwardStep
from .beam_utils import BeamHypotheses
from .communication import broadcast_from_last_pipeline_stage


def beam_search(model, tokens, **kwargs):
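    """Run beam search decoding for a single prompt.

    This is a generator: intermediate beams are yielded by ``forward_loop`` while
    decoding is still in progress, and the final ``(tokens, lengths, scores)``
    triple is yielded once decoding finishes. Recognized kwargs are ``beam_size``,
    ``stop_token``, ``num_return_gen`` and ``length_penalty``.

    A minimal driver sketch (hypothetical caller, assuming ``model`` is the wrapped
    inference model and ``tokens`` the already-tokenized prompt batch):

        for out_tokens, lengths, probs in beam_search(model, tokens, beam_size=4):
            pass  # the last iteration carries the final beams and their exponentiated scores
    """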
    beam_size = kwargs.pop("beam_size", 1)
    stop_token = kwargs.pop("stop_token", 0)
    num_return_gen = kwargs.pop("num_return_gen", 1)
    length_penalty = kwargs.pop("length_penalty", 1.0)
    args = get_args()
    if args.micro_batch_size > 1:
        raise NotImplementedError("The number of input prompts should not be greater than 1 "
                                  "(i.e. micro_batch_size must be 1) in beam search mode.")
    # ==========================
    # Pad tokens
    # ==========================
    prompt_length, context_lengths, tokens = _pad_tokens(args, tokens, beam_size, num_return_gen)
    final_sequence_length = args.max_length_ori

    # ==========================
    # Forward step
    # ==========================
    forward_step = ForwardStep(model, beam_size, final_sequence_length)

    # ==========================
    # Build BeamHypotheses
    # ==========================
    beam_hyp = BeamHypotheses(beam_size, length_penalty)
    done = torch.zeros(1, dtype=torch.uint8, device=torch.cuda.current_device())
    scores = torch.zeros(beam_size, dtype=torch.float32, device=torch.cuda.current_device()).unsqueeze(1)
    scores_size_tensor, tokens_size_tensor = None, None
    output_scores, output_tokens = None, None

    # ==========================
    # Run inference
    # ==========================
    with torch.no_grad():
        tokens = tokens.repeat(beam_size, 1)
        batch_size, seq_length = tokens.size()
        attention_mask = torch.tril(torch.ones(
            (args.micro_batch_size, seq_length, seq_length), device=tokens.device)).view(
            args.micro_batch_size, 1, seq_length, seq_length)
        attention_mask = (attention_mask < 0.5)
        position_ids = torch.arange(seq_length, dtype=torch.long,
                                    device=tokens.device)
        position_ids = position_ids.unsqueeze(0).expand_as(tokens)
        context_length, done, scores, tokens = yield from forward_loop(args,
                                                                       attention_mask=attention_mask,
                                                                       beam_hyp=beam_hyp,
                                                                       beam_size=beam_size,
                                                                       done=done,
                                                                       final_sequence_length=final_sequence_length,
                                                                       forward_step=forward_step,
                                                                       num_return_gen=num_return_gen,
                                                                       position_ids=position_ids,
                                                                       prompt_length=prompt_length,
                                                                       context_lengths=context_lengths,
                                                                       scores=scores,
                                                                       stop_token=stop_token,
                                                                       tokens=tokens)
        output_scores, output_tokens = _beam_search_post_process(beam_hyp=beam_hyp,
                                                                 beam_size=beam_size,
                                                                 done=done,
                                                                 num_return_gen=num_return_gen,
                                                                 output_scores=output_scores,
                                                                 output_tokens=output_tokens,
                                                                 context_length=context_length,
                                                                 prompt_length=prompt_length,
                                                                 scores=scores,
                                                                 scores_size_tensor=scores_size_tensor,
                                                                 tokens=tokens,
                                                                 tokens_size_tensor=tokens_size_tensor)
        yield output_tokens, context_lengths, torch.exp(output_scores)


def forward_loop(args, **kwargs):
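    """Decode one token per step from ``prompt_length`` to ``final_sequence_length``.

    On the last pipeline stage the next-token candidates are scored and the beam
    bookkeeping is updated; ``done`` and ``tokens`` are then broadcast to the other
    pipeline stages. Intermediate results are yielded at every step, and
    ``(context_length, done, scores, tokens)`` is returned when decoding stops.
    """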
    attention_mask = kwargs.pop("attention_mask")
    beam_hyp = kwargs.pop("beam_hyp")
    beam_size = kwargs.pop("beam_size")
    done = kwargs.pop("done")
    final_sequence_length = kwargs.pop("final_sequence_length")
    forward_step = kwargs.pop("forward_step")
    num_return_gen = kwargs.pop("num_return_gen")
    position_ids = kwargs.pop("position_ids")
    prompt_length = kwargs.pop("prompt_length")
    context_lengths = kwargs.pop("context_lengths")
    scores = kwargs.pop("scores")
    stop_token = kwargs.pop("stop_token")
    tokens = kwargs.pop("tokens")

    context_length = None
    for context_length in range(prompt_length, final_sequence_length):
        logits = forward_step(tokens, position_ids, attention_mask)

        if parallel_state.is_pipeline_last_stage():
            vocab_size = logits.size(2)
            best_beam_ids, best_scores, best_words = _beam_candidates_with_sampling(args,
                                                                                    beam_size=beam_size,
                                                                                    context_length=context_length,
                                                                                    logits=logits,
                                                                                    prompt_length=prompt_length,
                                                                                    scores=scores,
                                                                                    stop_token=stop_token,
                                                                                    vocab_size=vocab_size)
            done, scores, tokens = _beam_search_process(beam_hyp=beam_hyp,
                                                        beam_size=beam_size,
                                                        best_beam_ids=best_beam_ids,
                                                        best_scores=best_scores,
                                                        best_words=best_words,
                                                        context_length=context_length,
                                                        done=done,
                                                        prompt_length=prompt_length,
                                                        scores=scores,
                                                        stop_token=stop_token,
                                                        tokens=tokens)

        done = broadcast_from_last_pipeline_stage(1, torch.uint8, done)
        if done:
            break

        tokens = broadcast_from_last_pipeline_stage(tokens.size(), torch.int64, tokens)
        yield tokens[:num_return_gen], context_lengths, torch.exp(scores[:num_return_gen])

    output_info = (context_length, done, scores, tokens)
    return output_info


def _beam_search_post_process(**kwargs):
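    """Pick the final hypotheses and broadcast them from the last pipeline stage.

    If decoding ended without ``done`` being set, the surviving beams are added to
    ``beam_hyp`` first. Hypotheses are then ranked by score, the best
    ``num_return_gen`` are stacked, and their shapes and values are broadcast so
    every pipeline stage returns the same ``(output_scores, output_tokens)``.
    """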
    beam_hyp = kwargs.pop("beam_hyp")
    beam_size = kwargs.pop("beam_size")
    context_length = kwargs.pop("context_length")
    done = kwargs.pop("done")
    num_return_gen = kwargs.pop("num_return_gen")
    output_scores = kwargs.pop("output_scores")
    output_tokens = kwargs.pop("output_tokens")
    prompt_length = kwargs.pop("prompt_length")
    scores = kwargs.pop("scores")
    scores_size_tensor = kwargs.pop("scores_size_tensor")
    tokens = kwargs.pop("tokens")
    tokens_size_tensor = kwargs.pop("tokens_size_tensor")

    if parallel_state.is_pipeline_last_stage():
        if not done:
            for beam_id in range(beam_size):
                beam_hyp.add(tokens[beam_id].clone(),
                             scores[beam_id].squeeze(),
                             context_length + 1 - prompt_length)

        # rank based on scores
        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
        num_return_gen = min(num_return_gen, len(sorted_hyps))
        output_scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
        output_tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
        output_scores = torch.stack(output_scores, dim=0)
        output_tokens = torch.stack(output_tokens, dim=0)
        scores_size_tensor = torch.tensor(output_scores.shape,
                                          dtype=torch.int64,
                                          device=torch.cuda.current_device())
        tokens_size_tensor = torch.tensor(output_tokens.shape,
                                          dtype=torch.int64,
                                          device=torch.cuda.current_device())

    scores_size_tensor = broadcast_from_last_pipeline_stage(1, torch.int64, scores_size_tensor)
    tokens_size_tensor = broadcast_from_last_pipeline_stage(2, torch.int64, tokens_size_tensor)
    output_scores = broadcast_from_last_pipeline_stage(tuple(scores_size_tensor),
                                                       torch.float32,
                                                       output_scores)
    output_tokens = broadcast_from_last_pipeline_stage(tuple(tokens_size_tensor),
                                                       torch.int64,
                                                       output_tokens)
    return output_scores, output_tokens


def _beam_search_process(**kwargs):
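    """Advance the beam state by one decoding step.

    Candidates ending in ``stop_token`` are moved into ``beam_hyp`` (unless they
    fall outside the top ``beam_size``); the remaining best candidates become the
    next beams. Returns the updated ``(done, scores, tokens)``.
    """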
    beam_hyp = kwargs.pop("beam_hyp")
    beam_size = kwargs.pop("beam_size")
    best_beam_ids = kwargs.pop("best_beam_ids")
    best_scores = kwargs.pop("best_scores")
    best_words = kwargs.pop("best_words")
    context_length = kwargs.pop("context_length")
    done = kwargs.pop("done")
    prompt_length = kwargs.pop("prompt_length")
    scores = kwargs.pop("scores")
    stop_token = kwargs.pop("stop_token")
    tokens = kwargs.pop("tokens")

    next_beams = []
    for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
        zip(best_words, best_scores, best_beam_ids)
    ):
        if token_id.item() == stop_token:
            # if beam_token does not belong to top num_beams tokens, it should not be added
            is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
            if is_beam_token_worse_than_top_num_beams:
                continue
            beam_hyp.add(
                tokens[beam_id].clone(),
                beam_score,
                context_length + 1 - prompt_length
            )
        else:
            # add next predicted token since it is not eos_token
            next_beams.append((token_id, beam_score, beam_id))

        if len(next_beams) == beam_size:
            break

    if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
        done = torch.ones(1, dtype=torch.uint8, device=torch.cuda.current_device())

    best_batches = tokens.new([item[2] for item in next_beams])
    tokens = tokens[best_batches, :]
    tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
    scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)
    return done, scores, tokens


def _beam_candidates_with_sampling(args, **kwargs):
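    """Score next-token candidates for every beam.

    Applies temperature scaling (with a guard for a zero temperature), optional
    top-k/top-p filtering, then adds the log-probabilities to the running beam
    scores and returns the best ``2 * beam_size`` ``(beam_ids, scores, words)``.
    """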
    beam_size = kwargs.pop("beam_size")
    context_length = kwargs.pop("context_length")
    logits = kwargs.pop("logits")
    prompt_length = kwargs.pop("prompt_length")
    scores = kwargs.pop("scores")
    vocab_size = kwargs.pop("vocab_size")
    stop_token = kwargs.pop("stop_token")

    try:
        logits = logits[:, context_length - 1, :] / args.text_generation_config["temperature"]
    except ZeroDivisionError:
        logits = logits[:, context_length - 1, :] * 10000

    if args.text_generation_config["top_k"] > 1 and (0.0 < args.text_generation_config["top_p"] <= 1.0):
        logits = top_k_logits(logits,
                              top_k=args.text_generation_config["top_k"],
                              top_p=args.text_generation_config["top_p"])

    log_probs = F.log_softmax(logits, dim=1)
    new_scores = log_probs + scores

    if context_length == prompt_length:
        indices, sorted_scores = _beam_candidates_at_beginning(args, beam_size, new_scores)
    else:
        indices, sorted_scores = _beam_candidates_at_later(args, beam_size, new_scores)

    best_beam_ids = torch.div(indices[:2 * beam_size], vocab_size).trunc().long()
    best_words = indices[:2 * beam_size] % vocab_size
    best_scores = sorted_scores[:2 * beam_size]
    return best_beam_ids, best_scores, best_words


def _beam_candidates_at_later(args, beam_size, new_scores):
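    """Select candidates after the first generated token.

    In greedy mode the accumulated scores over all beams are sorted; otherwise
    candidates are drawn with ``torch.multinomial`` when the exponentiated scores
    form a usable distribution, falling back to sorting.
    """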
    if args.text_generation_config["greedy"]:
        sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)
    else:
        accumulate_logits = torch.exp(new_scores)
        accumulate_logits_sum = accumulate_logits.sum()
        if accumulate_logits_sum > 1e-5 and accumulate_logits_sum < 1.0:
            indices = torch.multinomial(accumulate_logits.view(-1), num_samples=2 * beam_size)
            sorted_scores = torch.gather(new_scores.view(-1), dim=0, index=indices)
        else:
            sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)
    return indices, sorted_scores


def _beam_candidates_at_beginning(args, beam_size, new_scores):
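    """Select candidates for the first generated token.

    Only the first beam's scores are used, since every beam starts from the same
    prompt; the selection otherwise mirrors ``_beam_candidates_at_later``.
    """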
    if args.text_generation_config["greedy"]:
        sorted_scores, indices = torch.sort(new_scores[0, :], descending=True)
    else:
        accumulate_logits = torch.exp(new_scores[0, :])
        accumulate_logits_sum = accumulate_logits.sum()
        if accumulate_logits_sum > 1e-5 and accumulate_logits_sum < 1.0:
            indices = torch.multinomial(accumulate_logits, num_samples=2 * beam_size)
            sorted_scores = torch.gather(new_scores[0, :], dim=0, index=indices)
        else:
            sorted_scores, indices = torch.sort(new_scores[0, :], descending=True)
    return indices, sorted_scores


def _pad_tokens(args, tokens, beam_size, num_return_gen):
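    """Pad the prompt batch and expose the per-sequence prompt lengths.

    Returns the minimum prompt length, the prompt lengths replicated once per
    returned sequence (as a Python list), and the padded tokens.
    """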
    tokens, lengths = pad_batch(tokens, args)
    prompt_length = lengths.min().item()
    lengths = lengths.repeat(min(beam_size, num_return_gen)).cpu().numpy().tolist()
    return prompt_length, lengths, tokens