bert attention mask (#1934)
* bert attention mask
* Allow for using None as a mask.
* Revert part of the changes so that the proper default mask applies.
* Cosmetic change.
* Another cosmetic tweak.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
parent 24d54d0ff9
commit 4a52aeb437
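The change threads an optional attention mask from the examples down through BertModel, BertEncoder, BertLayer, BertAttention and BertSelfAttention. For orientation, here is a minimal sketch (not part of the diff) of the updated public call sites; it assumes a BertModel and input tensors built exactly as in the example code below, with `candle` being the candle-core crate as aliased in this workspace.

use candle::{Result, Tensor};
use candle_transformers::models::bert::BertModel;

// Single, unpadded prompt: pass `None` and the model falls back to an
// all-ones mask internally (nothing gets masked out).
fn embed_prompt(model: &BertModel, token_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
    model.forward(token_ids, token_type_ids, None)
}

// Padded batch: pass the tokenizer's attention mask so that padding
// positions are excluded from attention.
fn embed_batch(
    model: &BertModel,
    token_ids: &Tensor,
    token_type_ids: &Tensor,
    attention_mask: &Tensor,
) -> Result<Tensor> {
    model.forward(token_ids, token_type_ids, Some(attention_mask))
}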
@@ -126,7 +126,7 @@ fn main() -> Result<()> {
         println!("Loaded and encoded {:?}", start.elapsed());
         for idx in 0..args.n {
             let start = std::time::Instant::now();
-            let ys = model.forward(&token_ids, &token_type_ids)?;
+            let ys = model.forward(&token_ids, &token_type_ids, None)?;
             if idx == 0 {
                 println!("{ys}");
             }
@@ -163,11 +163,19 @@ fn main() -> Result<()> {
                 Ok(Tensor::new(tokens.as_slice(), device)?)
             })
             .collect::<Result<Vec<_>>>()?;
+        let attention_mask = tokens
+            .iter()
+            .map(|tokens| {
+                let tokens = tokens.get_attention_mask().to_vec();
+                Ok(Tensor::new(tokens.as_slice(), device)?)
+            })
+            .collect::<Result<Vec<_>>>()?;
 
         let token_ids = Tensor::stack(&token_ids, 0)?;
+        let attention_mask = Tensor::stack(&attention_mask, 0)?;
         let token_type_ids = token_ids.zeros_like()?;
         println!("running inference on batch {:?}", token_ids.shape());
-        let embeddings = model.forward(&token_ids, &token_type_ids)?;
+        let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
         println!("generated embeddings {:?}", embeddings.shape());
         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
         let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
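For reference, a small standalone sketch (illustrative values, not from the diff) of what this batch path builds: the tokenizer's per-sentence masks are 1 for real tokens and 0 for padding, and stacking them yields a (batch, seq_len) tensor matching the stacked token_ids.

use candle::{Device, Result, Tensor};

fn demo_mask() -> Result<()> {
    let device = Device::Cpu;
    // Per-sentence masks as returned by the tokenizer: 1 = real token, 0 = padding.
    let mask_a = Tensor::new(&[1u32, 1, 1, 1, 0, 0], &device)?;
    let mask_b = Tensor::new(&[1u32, 1, 1, 1, 1, 1], &device)?;
    // Stacking along dim 0 gives a (batch, seq_len) = (2, 6) mask,
    // the same shape as the stacked `token_ids`.
    let attention_mask = Tensor::stack(&[mask_a, mask_b], 0)?;
    assert_eq!(attention_mask.dims(), &[2, 6]);
    Ok(())
}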
@@ -230,10 +230,8 @@ impl BertSelfAttention {
         let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
         xs.contiguous()
     }
-}
 
-impl Module for BertSelfAttention {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let query_layer = self.query.forward(hidden_states)?;
         let key_layer = self.key.forward(hidden_states)?;
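BertSelfAttention::forward now takes the mask as a second argument, so it can no longer implement candle's Module trait, whose forward takes a single input tensor (roughly the definition below); the same impl-Module-to-inherent-method change repeats for BertAttention, BertLayer and BertEncoder in the hunks that follow.

// For reference only: the approximate shape of candle's Module trait. A
// forward method with an extra attention_mask argument cannot satisfy it,
// hence the switch to inherent methods in this file.
use candle::{Result, Tensor};

pub trait Module {
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}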
@@ -245,6 +243,7 @@ impl Module for BertSelfAttention {
 
         let attention_scores = query_layer.matmul(&key_layer.t()?)?;
         let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
+        let attention_scores = attention_scores.broadcast_add(attention_mask)?;
         let attention_probs = {
             let _enter_sm = self.span_softmax.enter();
             candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
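A small standalone sketch (illustrative values, not part of the diff) of what this broadcast_add does: the extended mask holds 0.0 for positions to keep and a very large negative number for masked positions, so after the softmax the masked positions get essentially zero attention weight while the remaining ones renormalize among themselves.

use candle::{Device, Result, Tensor, D};

fn demo_masked_softmax() -> Result<()> {
    let device = Device::Cpu;
    // One row of attention scores over 4 key positions (illustrative values).
    let scores = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device)?;
    // Extended mask: 0.0 keeps a position, f32::MIN effectively removes it.
    let mask = Tensor::new(&[0.0f32, 0.0, 0.0, f32::MIN], &device)?;
    let masked = scores.broadcast_add(&mask)?;
    let probs = candle_nn::ops::softmax(&masked, D::Minus1)?;
    // The last position ends up with ~0 probability; the first three sum to ~1.
    println!("{probs}");
    Ok(())
}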
@@ -307,12 +306,10 @@ impl BertAttention {
             span: tracing::span!(tracing::Level::TRACE, "attn"),
         })
     }
-}
 
-impl Module for BertAttention {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let self_outputs = self.self_attention.forward(hidden_states)?;
+        let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
         let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
         Ok(attention_output)
     }
@@ -398,12 +395,10 @@ impl BertLayer {
             span: tracing::span!(tracing::Level::TRACE, "layer"),
         })
     }
-}
 
-impl Module for BertLayer {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let attention_output = self.attention.forward(hidden_states)?;
+        let attention_output = self.attention.forward(hidden_states, attention_mask)?;
         // TODO: Support cross-attention?
         // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
         // TODO: Support something similar to `apply_chunking_to_forward`?
@@ -429,15 +424,13 @@ impl BertEncoder {
         let span = tracing::span!(tracing::Level::TRACE, "encoder");
         Ok(BertEncoder { layers, span })
     }
-}
 
-impl Module for BertEncoder {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let mut hidden_states = hidden_states.clone();
         // Use a loop rather than a fold as it's easier to modify when adding debug/...
         for layer in self.layers.iter() {
-            hidden_states = layer.forward(&hidden_states)?
+            hidden_states = layer.forward(&hidden_states, attention_mask)?
         }
         Ok(hidden_states)
     }
@@ -481,10 +474,32 @@ impl BertModel {
         })
     }
 
-    pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+    pub fn forward(
+        &self,
+        input_ids: &Tensor,
+        token_type_ids: &Tensor,
+        attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
         let _enter = self.span.enter();
         let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
-        let sequence_output = self.encoder.forward(&embedding_output)?;
+        let attention_mask = match attention_mask {
+            Some(attention_mask) => attention_mask.clone(),
+            None => input_ids.ones_like()?,
+        };
+        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
+        let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
+        let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
         Ok(sequence_output)
     }
 }
+
+fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
+    let attention_mask = match attention_mask.rank() {
+        3 => attention_mask.unsqueeze(1)?,
+        2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?,
+        _ => candle::bail!("Wrong shape for input_ids or attention_mask"),
+    };
+    let attention_mask = attention_mask.to_dtype(dtype)?;
+    // torch.finfo(dtype).min
+    (attention_mask.ones_like()? - attention_mask)?.broadcast_mul(&Tensor::try_from(f32::MIN)?)
+}
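For reference, a standalone sketch of the values get_extended_attention_mask produces for a 2D mask (the helper itself is private to this file, so the same transform is restated inline, with illustrative inputs): a (batch, seq_len) mask of ones and zeros becomes a (batch, 1, 1, seq_len) tensor holding 0.0 for kept positions and f32::MIN for masked ones, which is then broadcast-added to the attention scores.

use candle::{DType, Device, Result, Tensor};

fn demo_extended_mask() -> Result<()> {
    let device = Device::Cpu;
    // (batch = 1, seq_len = 3) mask: the last token is padding.
    let attention_mask = Tensor::new(&[[1u32, 1, 0]], &device)?;
    // Same transform as the helper above: add head/query dims, then
    // map 1 -> 0.0 (keep) and 0 -> f32::MIN (mask out).
    let mask = attention_mask.unsqueeze(1)?.unsqueeze(1)?.to_dtype(DType::F32)?;
    let extended = (mask.ones_like()? - mask)?.broadcast_mul(&Tensor::try_from(f32::MIN)?)?;
    assert_eq!(extended.dims(), &[1, 1, 1, 3]);
    // `extended` now holds [0.0, 0.0, f32::MIN] along the last dimension.
    Ok(())
}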
@@ -55,11 +55,21 @@ impl Model {
                 Tensor::new(tokens.as_slice(), device)
             })
             .collect::<Result<Vec<_>, _>>()?;
+        let attention_mask: Vec<Tensor> = tokens
+            .iter()
+            .map(|tokens| {
+                let tokens = tokens.get_attention_mask().to_vec();
+                Tensor::new(tokens.as_slice(), device)
+            })
+            .collect::<Result<Vec<_>, _>>()?;
 
         let token_ids = Tensor::stack(&token_ids, 0)?;
+        let attention_mask = Tensor::stack(&attention_mask, 0)?;
         let token_type_ids = token_ids.zeros_like()?;
         console_log!("running inference on batch {:?}", token_ids.shape());
-        let embeddings = self.bert.forward(&token_ids, &token_type_ids)?;
+        let embeddings = self
+            .bert
+            .forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
         console_log!("generated embeddings {:?}", embeddings.shape());
         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
         let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;