# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig


class UniXcoder(nn.Module):
    def __init__(self, model_name):
        """
        Build UniXcoder.

        Parameters:

        * `model_name` - huggingface model card name, e.g. microsoft/unixcoder-base
        """
        super(UniXcoder, self).__init__()
        self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
        self.config = RobertaConfig.from_pretrained(model_name)
        self.config.is_decoder = True
        self.model = RobertaModel.from_pretrained(model_name, config=self.config)

        # Lower-triangular (causal) attention mask for up to 1024 positions,
        # used for unidirectional decoding.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((1024, 1024), dtype=torch.uint8)).view(1, 1024, 1024),
        )
        # LM head tied to the input word embeddings.
        self.lm_head = nn.Linear(
            self.config.hidden_size, self.config.vocab_size, bias=False
        )
        self.lm_head.weight = self.model.embeddings.word_embeddings.weight
        self.lsm = nn.LogSoftmax(dim=-1)

        self.tokenizer.add_tokens(["<mask0>"], special_tokens=True)

    def tokenize(self, inputs, mode="<encoder-only>", max_length=512, padding=False):
        """
        Convert strings to token ids.

        Parameters:

        * `inputs` - list of input strings.
        * `max_length` - the maximum total source sequence length after tokenization.
        * `padding` - whether to pad the source sequence length to max_length.
        * `mode` - which mode the sequence will use, i.e. <encoder-only>, <decoder-only>, <encoder-decoder>
        """
        assert mode in ["<encoder-only>", "<decoder-only>", "<encoder-decoder>"]
        assert max_length < 1024

        tokenizer = self.tokenizer

        tokens_ids = []
        for x in inputs:
            tokens = tokenizer.tokenize(x)
            if mode == "<encoder-only>":
                # Keep the leading tokens and wrap them with special tokens.
                tokens = tokens[: max_length - 4]
                tokens = (
                    [tokenizer.cls_token, mode, tokenizer.sep_token]
                    + tokens
                    + [tokenizer.sep_token]
                )
            elif mode == "<decoder-only>":
                # Keep the rightmost tokens so the context to be completed is preserved.
                tokens = tokens[-(max_length - 3) :]
                tokens = [tokenizer.cls_token, mode, tokenizer.sep_token] + tokens
            else:
                tokens = tokens[: max_length - 5]
                tokens = (
                    [tokenizer.cls_token, mode, tokenizer.sep_token]
                    + tokens
                    + [tokenizer.sep_token]
                )

            tokens_id = tokenizer.convert_tokens_to_ids(tokens)
            if padding:
                tokens_id = tokens_id + [self.config.pad_token_id] * (
                    max_length - len(tokens_id)
                )
            tokens_ids.append(tokens_id)
        return tokens_ids

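    # For reference, with the RoBERTa special tokens (cls_token "<s>",
    # sep_token "</s>"), `tokenize` lays sequences out as:
    #   <encoder-only>   : <s> <encoder-only> </s> code tokens </s>
    #   <decoder-only>   : <s> <decoder-only> </s> code tokens   (kept from the right)
    #   <encoder-decoder>: <s> <encoder-decoder> </s> code tokens </s>
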
    def decode(self, source_ids):
        """Convert token ids to strings."""
        predictions = []
        for x in source_ids:
            prediction = []
            for y in x:
                t = y.cpu().numpy()
                t = list(t)
                # `generate` pads each hypothesis with zeros, so stop at the first 0.
                if 0 in t:
                    t = t[: t.index(0)]
                text = self.tokenizer.decode(t, clean_up_tokenization_spaces=False)
                prediction.append(text)
            predictions.append(prediction)
        return predictions

    def forward(self, source_ids):
        """Obtain token embeddings and sentence embeddings."""
        mask = source_ids.ne(self.config.pad_token_id)
        token_embeddings = self.model(
            source_ids, attention_mask=mask.unsqueeze(1) * mask.unsqueeze(2)
        )[0]
        # Sentence embedding: mean of the token embeddings over non-padding positions.
        sentence_embeddings = (token_embeddings * mask.unsqueeze(-1)).sum(1) / mask.sum(
            -1
        ).unsqueeze(-1)
        return token_embeddings, sentence_embeddings

    def generate(
        self, source_ids, decoder_only=True, eos_id=None, beam_size=5, max_length=64
    ):
        """Generate a sequence given the context (source_ids)."""

        # Set the encoder attention mask: unidirectional (causal) for
        # <decoder-only>, bidirectional for <encoder-decoder>.
        if decoder_only:
            mask = self.bias[:, : source_ids.size(-1), : source_ids.size(-1)]
        else:
            mask = source_ids.ne(self.config.pad_token_id)
            mask = mask.unsqueeze(1) * mask.unsqueeze(2)

        if eos_id is None:
            eos_id = self.config.eos_token_id

        device = source_ids.device

        # Decoding using beam search
        preds = []
        zero = torch.LongTensor(1).fill_(0).to(device)
        # Length of each context, counting tokens that are not the pad id (1).
        source_len = list(source_ids.ne(1).sum(-1).cpu().numpy())
        length = source_ids.size(-1)
        encoder_output = self.model(source_ids, attention_mask=mask)
        for i in range(source_ids.shape[0]):
            # Expand the cached key/values of the context to beam_size hypotheses.
            context = [
                [x[i : i + 1, :, : source_len[i]].repeat(beam_size, 1, 1, 1) for x in y]
                for y in encoder_output.past_key_values
            ]
            beam = Beam(beam_size, eos_id, device)
            input_ids = beam.getCurrentState().clone()
            context_ids = source_ids[i : i + 1, : source_len[i]].repeat(beam_size, 1)
            out = encoder_output.last_hidden_state[i : i + 1, : source_len[i]].repeat(
                beam_size, 1, 1
            )
            for _ in range(max_length):
                if beam.done():
                    break
                if _ == 0:
                    # First step: predict from the last hidden state of the context.
                    hidden_states = out[:, -1, :]
                    out = self.lsm(self.lm_head(hidden_states)).data
                    beam.advance(out)
                    input_ids.data.copy_(
                        input_ids.data.index_select(0, beam.getCurrentOrigin())
                    )
                    input_ids = beam.getCurrentState().clone()
                else:
                    # Subsequent steps: feed only the generated tokens and reuse
                    # the cached context key/values.
                    length = context_ids.size(-1) + input_ids.size(-1)
                    out = self.model(
                        input_ids,
                        attention_mask=self.bias[
                            :, context_ids.size(-1) : length, :length
                        ],
                        past_key_values=context,
                    ).last_hidden_state
                    hidden_states = out[:, -1, :]
                    out = self.lsm(self.lm_head(hidden_states)).data
                    beam.advance(out)
                    input_ids.data.copy_(
                        input_ids.data.index_select(0, beam.getCurrentOrigin())
                    )
                    input_ids = torch.cat(
                        (input_ids, beam.getCurrentState().clone()), -1
                    )
            # Collect the best hypotheses and pad each one with zeros to max_length.
            hyp = beam.getHyp(beam.getFinal())
            pred = beam.buildTargetTokens(hyp)[:beam_size]
            pred = [
                torch.cat(
                    [x.view(-1) for x in p] + [zero] * (max_length - len(p))
                ).view(1, -1)
                for p in pred
            ]
            preds.append(torch.cat(pred, 0).unsqueeze(0))

        # preds: (batch_size, beam_size, max_length) tensor of generated token ids.
        preds = torch.cat(preds, 0)

        return preds

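# Beam implements the beam-search bookkeeping used by UniXcoder.generate: it
# keeps `size` partial hypotheses with their running scores and backpointers,
# and records a hypothesis as finished whenever the eos token reaches the top
# of the beam.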
class Beam(object):
    def __init__(self, size, eos, device):
        self.size = size
        self.device = device
        # The score for each translation on the beam.
        self.scores = torch.FloatTensor(size).zero_().to(device)
        # The backpointers at each time-step.
        self.prevKs = []
        # The outputs at each time-step.
        self.nextYs = [torch.LongTensor(size).fill_(0).to(device)]
        # Has EOS topped the beam yet.
        self._eos = eos
        self.eosTop = False
        # Time and k pair for finished.
        self.finished = []

    def getCurrentState(self):
        "Get the outputs for the current timestep."
        batch = self.nextYs[-1].view(-1, 1)
        return batch

    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk):
        """
        Given the log-probs over words for every last beam `wordLk`,
        compute and update the beam search.

        Parameters:

        * `wordLk` - log-probs of advancing from the last step (K x words)
        """
        numWords = wordLk.size(1)

        # Sum the previous scores.
        if len(self.prevKs) > 0:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)

            # Don't let EOS have children.
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] == self._eos:
                    beamLk[i] = -1e20
        else:
            beamLk = wordLk[0]
        flatBeamLk = beamLk.view(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)

        self.scores = bestScores

        # bestScoresId is a flattened beam x word array, so calculate which
        # word and beam each score came from.
        prevK = torch.div(bestScoresId, numWords, rounding_mode="floor")
        self.prevKs.append(prevK)
        self.nextYs.append(bestScoresId - prevK * numWords)

        for i in range(self.nextYs[-1].size(0)):
            if self.nextYs[-1][i] == self._eos:
                s = self.scores[i]
                self.finished.append((s, len(self.nextYs) - 1, i))

        # End condition is when top-of-beam is EOS.
        if self.nextYs[-1][0] == self._eos:
            self.eosTop = True

    def done(self):
        return self.eosTop and len(self.finished) >= self.size

    def getFinal(self):
        if len(self.finished) == 0:
            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
        self.finished.sort(key=lambda a: -a[0])
        if len(self.finished) != self.size:
            unfinished = []
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] != self._eos:
                    s = self.scores[i]
                    unfinished.append((s, len(self.nextYs) - 1, i))
            unfinished.sort(key=lambda a: -a[0])
            self.finished += unfinished[: self.size - len(self.finished)]
        return self.finished[: self.size]

    def getHyp(self, beam_res):
        """
        Walk back to construct the full hypothesis.
        """
        hyps = []
        for _, timestep, k in beam_res:
            hyp = []
            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
                hyp.append(self.nextYs[j + 1][k])
                k = self.prevKs[j][k]
            hyps.append(hyp[::-1])
        return hyps

    def buildTargetTokens(self, preds):
        """Truncate each hypothesis at the first eos token."""
        sentence = []
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok == self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence
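

if __name__ == "__main__":
    # Minimal usage sketch of the API defined above (tokenize / forward /
    # generate / decode), assuming the "microsoft/unixcoder-base" checkpoint
    # can be downloaded from the Hugging Face hub.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UniXcoder("microsoft/unixcoder-base").to(device)

    # Encoder-only mode: token and sentence embeddings for a code snippet.
    func = "def f(a, b): return a if a > b else b"
    tokens_ids = model.tokenize([func], max_length=512, mode="<encoder-only>")
    source_ids = torch.tensor(tokens_ids).to(device)
    token_embeddings, sentence_embeddings = model(source_ids)
    print(sentence_embeddings.shape)

    # Decoder-only mode: complete a partial function with beam search.
    context = "def f(data, file_path):\n    # write json data into file_path\n"
    tokens_ids = model.tokenize([context], max_length=512, mode="<decoder-only>")
    source_ids = torch.tensor(tokens_ids).to(device)
    prediction_ids = model.generate(
        source_ids, decoder_only=True, beam_size=3, max_length=64
    )
    predictions = model.decode(prediction_ids)
    print(context + predictions[0][0])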