bert attention mask (#1934)

* bert attention mask

* Allow for using None as a mask.

* Revert part of the changes so that the proper default mask applies.

* Cosmetic change.

* Another cosmetic tweak.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
Zheng Li 2024-08-01 14:26:19 +08:00 committed by GitHub
parent 24d54d0ff9
commit 4a52aeb437
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 53 additions and 20 deletions

View File

@ -126,7 +126,7 @@ fn main() -> Result<()> {
println!("Loaded and encoded {:?}", start.elapsed());
for idx in 0..args.n {
let start = std::time::Instant::now();
let ys = model.forward(&token_ids, &token_type_ids)?;
let ys = model.forward(&token_ids, &token_type_ids, None)?;
if idx == 0 {
println!("{ys}");
}
@ -163,11 +163,19 @@ fn main() -> Result<()> {
Ok(Tensor::new(tokens.as_slice(), device)?)
})
.collect::<Result<Vec<_>>>()?;
let attention_mask = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_attention_mask().to_vec();
Ok(Tensor::new(tokens.as_slice(), device)?)
})
.collect::<Result<Vec<_>>>()?;
let token_ids = Tensor::stack(&token_ids, 0)?;
let attention_mask = Tensor::stack(&attention_mask, 0)?;
let token_type_ids = token_ids.zeros_like()?;
println!("running inference on batch {:?}", token_ids.shape());
let embeddings = model.forward(&token_ids, &token_type_ids)?;
let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
println!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;

View File

@ -230,10 +230,8 @@ impl BertSelfAttention {
let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
xs.contiguous()
}
}
impl Module for BertSelfAttention {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let query_layer = self.query.forward(hidden_states)?;
let key_layer = self.key.forward(hidden_states)?;
@ -245,6 +243,7 @@ impl Module for BertSelfAttention {
let attention_scores = query_layer.matmul(&key_layer.t()?)?;
let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
let attention_scores = attention_scores.broadcast_add(attention_mask)?;
let attention_probs = {
let _enter_sm = self.span_softmax.enter();
candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
@ -307,12 +306,10 @@ impl BertAttention {
span: tracing::span!(tracing::Level::TRACE, "attn"),
})
}
}
impl Module for BertAttention {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let self_outputs = self.self_attention.forward(hidden_states)?;
let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
Ok(attention_output)
}
@ -398,12 +395,10 @@ impl BertLayer {
span: tracing::span!(tracing::Level::TRACE, "layer"),
})
}
}
impl Module for BertLayer {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let attention_output = self.attention.forward(hidden_states)?;
let attention_output = self.attention.forward(hidden_states, attention_mask)?;
// TODO: Support cross-attention?
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
// TODO: Support something similar to `apply_chunking_to_forward`?
@ -429,15 +424,13 @@ impl BertEncoder {
let span = tracing::span!(tracing::Level::TRACE, "encoder");
Ok(BertEncoder { layers, span })
}
}
impl Module for BertEncoder {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut hidden_states = hidden_states.clone();
// Use a loop rather than a fold as it's easier to modify when adding debug/...
for layer in self.layers.iter() {
hidden_states = layer.forward(&hidden_states)?
hidden_states = layer.forward(&hidden_states, attention_mask)?
}
Ok(hidden_states)
}
@ -481,10 +474,32 @@ impl BertModel {
})
}
pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
pub fn forward(
&self,
input_ids: &Tensor,
token_type_ids: &Tensor,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let _enter = self.span.enter();
let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
let sequence_output = self.encoder.forward(&embedding_output)?;
let attention_mask = match attention_mask {
Some(attention_mask) => attention_mask.clone(),
None => input_ids.ones_like()?,
};
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
Ok(sequence_output)
}
}
fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
let attention_mask = match attention_mask.rank() {
3 => attention_mask.unsqueeze(1)?,
2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?,
_ => candle::bail!("Wrong shape for input_ids or attention_mask"),
};
let attention_mask = attention_mask.to_dtype(dtype)?;
// torch.finfo(dtype).min
(attention_mask.ones_like()? - attention_mask)?.broadcast_mul(&Tensor::try_from(f32::MIN)?)
}

View File

@ -55,11 +55,21 @@ impl Model {
Tensor::new(tokens.as_slice(), device)
})
.collect::<Result<Vec<_>, _>>()?;
let attention_mask: Vec<Tensor> = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_attention_mask().to_vec();
Tensor::new(tokens.as_slice(), device)
})
.collect::<Result<Vec<_>, _>>()?;
let token_ids = Tensor::stack(&token_ids, 0)?;
let attention_mask = Tensor::stack(&attention_mask, 0)?;
let token_type_ids = token_ids.zeros_like()?;
console_log!("running inference on batch {:?}", token_ids.shape());
let embeddings = self.bert.forward(&token_ids, &token_type_ids)?;
let embeddings = self
.bert
.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
console_log!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;