bert attention mask (#1934)
* bert attention mask
* Allow for using None as a mask.
* Revert part of the changes so that the proper default mask applies.
* Cosmetic change.
* Another cosmetic tweak.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
commit 4a52aeb437
parent 24d54d0ff9
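Call sites change accordingly: the mask is now an explicit `Option<&Tensor>` argument on `BertModel::forward`. The sketch below is not part of the diff; it only illustrates the two call shapes, assuming the `candle` / `candle_transformers::models::bert::BertModel` paths used in this repo and batched `token_ids` / `token_type_ids` / `attention_mask` tensors built as in the hunks further down.

```rust
use candle::{Result, Tensor};
use candle_transformers::models::bert::BertModel;

// Single unpadded prompt: `None` falls back to the default all-ones mask,
// reproducing the behaviour before this change.
fn embed_prompt(model: &BertModel, token_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
    model.forward(token_ids, token_type_ids, None)
}

// Padded batch: pass the tokenizer's 0/1 attention mask so real tokens
// do not attend to the padding positions.
fn embed_batch(
    model: &BertModel,
    token_ids: &Tensor,
    token_type_ids: &Tensor,
    attention_mask: &Tensor,
) -> Result<Tensor> {
    model.forward(token_ids, token_type_ids, Some(attention_mask))
}
```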
@@ -126,7 +126,7 @@ fn main() -> Result<()> {
         println!("Loaded and encoded {:?}", start.elapsed());
         for idx in 0..args.n {
             let start = std::time::Instant::now();
-            let ys = model.forward(&token_ids, &token_type_ids)?;
+            let ys = model.forward(&token_ids, &token_type_ids, None)?;
             if idx == 0 {
                 println!("{ys}");
             }
@@ -163,11 +163,19 @@ fn main() -> Result<()> {
                 Ok(Tensor::new(tokens.as_slice(), device)?)
             })
             .collect::<Result<Vec<_>>>()?;
+        let attention_mask = tokens
+            .iter()
+            .map(|tokens| {
+                let tokens = tokens.get_attention_mask().to_vec();
+                Ok(Tensor::new(tokens.as_slice(), device)?)
+            })
+            .collect::<Result<Vec<_>>>()?;
 
         let token_ids = Tensor::stack(&token_ids, 0)?;
+        let attention_mask = Tensor::stack(&attention_mask, 0)?;
         let token_type_ids = token_ids.zeros_like()?;
         println!("running inference on batch {:?}", token_ids.shape());
-        let embeddings = model.forward(&token_ids, &token_type_ids)?;
+        let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
         println!("generated embeddings {:?}", embeddings.shape());
         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
         let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
@@ -230,10 +230,8 @@ impl BertSelfAttention {
         let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
         xs.contiguous()
     }
-}
 
-impl Module for BertSelfAttention {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let query_layer = self.query.forward(hidden_states)?;
         let key_layer = self.key.forward(hidden_states)?;
@@ -245,6 +243,7 @@ impl Module for BertSelfAttention {
 
         let attention_scores = query_layer.matmul(&key_layer.t()?)?;
         let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
+        let attention_scores = attention_scores.broadcast_add(attention_mask)?;
         let attention_probs = {
             let _enter_sm = self.span_softmax.enter();
             candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
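A note that is not part of the diff: the `attention_mask` reaching this `broadcast_add` is already the additive mask produced by `get_extended_attention_mask` further below, shaped `(batch, 1, 1, seq_len)`, so it broadcasts across heads and query positions against the `(batch, heads, seq_len, seq_len)` scores. A minimal shape check along those lines, assuming candle's standard broadcasting rules:

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Fake attention scores: (batch = 1, heads = 2, seq_len = 4, seq_len = 4).
    let scores = Tensor::zeros((1, 2, 4, 4), DType::F32, &device)?;
    // Additive mask with the last key position masked out: (1, 1, 1, 4).
    let mask = Tensor::new(&[0f32, 0.0, 0.0, f32::MIN], &device)?.reshape((1, 1, 1, 4))?;
    let masked = scores.broadcast_add(&mask)?;
    // The mask broadcasts over heads and query positions.
    println!("masked scores shape: {:?}", masked.dims()); // [1, 2, 4, 4]
    Ok(())
}
```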
@@ -307,12 +306,10 @@ impl BertAttention {
             span: tracing::span!(tracing::Level::TRACE, "attn"),
         })
     }
-}
 
-impl Module for BertAttention {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let self_outputs = self.self_attention.forward(hidden_states)?;
+        let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
         let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
         Ok(attention_output)
     }
@@ -398,12 +395,10 @@ impl BertLayer {
             span: tracing::span!(tracing::Level::TRACE, "layer"),
         })
     }
-}
 
-impl Module for BertLayer {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let attention_output = self.attention.forward(hidden_states)?;
+        let attention_output = self.attention.forward(hidden_states, attention_mask)?;
         // TODO: Support cross-attention?
         // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
         // TODO: Support something similar to `apply_chunking_to_forward`?
@@ -429,15 +424,13 @@ impl BertEncoder {
         let span = tracing::span!(tracing::Level::TRACE, "encoder");
         Ok(BertEncoder { layers, span })
     }
-}
 
-impl Module for BertEncoder {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let mut hidden_states = hidden_states.clone();
         // Use a loop rather than a fold as it's easier to modify when adding debug/...
         for layer in self.layers.iter() {
-            hidden_states = layer.forward(&hidden_states)?
+            hidden_states = layer.forward(&hidden_states, attention_mask)?
         }
         Ok(hidden_states)
     }
@@ -481,10 +474,32 @@ impl BertModel {
         })
     }
 
-    pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+    pub fn forward(
+        &self,
+        input_ids: &Tensor,
+        token_type_ids: &Tensor,
+        attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
         let _enter = self.span.enter();
         let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
-        let sequence_output = self.encoder.forward(&embedding_output)?;
+        let attention_mask = match attention_mask {
+            Some(attention_mask) => attention_mask.clone(),
+            None => input_ids.ones_like()?,
+        };
+        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
+        let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
+        let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
         Ok(sequence_output)
     }
 }
+
+fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
+    let attention_mask = match attention_mask.rank() {
+        3 => attention_mask.unsqueeze(1)?,
+        2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?,
+        _ => candle::bail!("Wrong shape for input_ids or attention_mask"),
+    };
+    let attention_mask = attention_mask.to_dtype(dtype)?;
+    // torch.finfo(dtype).min
+    (attention_mask.ones_like()? - attention_mask)?.broadcast_mul(&Tensor::try_from(f32::MIN)?)
+}
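Again as a side note rather than part of the diff: the helper mirrors the transformers logic linked above, mapping the tokenizer's 0/1 mask to additive values, 0.0 for kept positions and `f32::MIN` for padding, so padded keys effectively vanish after the softmax. A small sketch of the same arithmetic on a toy mask, under the assumption that it runs on CPU:

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Tokenizer-style mask for one sequence of length 4 with one padded slot.
    let mask = Tensor::new(&[1u32, 1, 1, 0], &device)?.reshape((1, 4))?;
    // Same steps as get_extended_attention_mask:
    // (batch, seq_len) -> (batch, 1, 1, seq_len), then (1 - mask) * f32::MIN.
    let mask = mask.unsqueeze(1)?.unsqueeze(1)?.to_dtype(DType::F32)?;
    let additive = (mask.ones_like()? - mask)?.broadcast_mul(&Tensor::try_from(f32::MIN)?)?;
    // Kept positions contribute 0.0 to the attention scores, the padded one f32::MIN.
    println!("{additive}");
    Ok(())
}
```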
@@ -55,11 +55,21 @@ impl Model {
                 Tensor::new(tokens.as_slice(), device)
             })
             .collect::<Result<Vec<_>, _>>()?;
+        let attention_mask: Vec<Tensor> = tokens
+            .iter()
+            .map(|tokens| {
+                let tokens = tokens.get_attention_mask().to_vec();
+                Tensor::new(tokens.as_slice(), device)
+            })
+            .collect::<Result<Vec<_>, _>>()?;
 
         let token_ids = Tensor::stack(&token_ids, 0)?;
+        let attention_mask = Tensor::stack(&attention_mask, 0)?;
         let token_type_ids = token_ids.zeros_like()?;
         console_log!("running inference on batch {:?}", token_ids.shape());
-        let embeddings = self.bert.forward(&token_ids, &token_type_ids)?;
+        let embeddings = self
+            .bert
+            .forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
         console_log!("generated embeddings {:?}", embeddings.shape());
         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
         let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;