Get some embeddings out.

laurent 2023-07-03 16:11:16 +01:00
parent 54850e7525
commit f379b8feae
1 changed file with 23 additions and 14 deletions

@@ -1,7 +1,9 @@
 #![allow(dead_code)]
 // The tokenizer.json and weights should be retrieved from:
 // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
-use anyhow::Error as E;
-use candle::{safetensors::SafeTensors, DType, Device, Result, Shape, Tensor};
+use anyhow::{Error as E, Result};
+use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
 use clap::Parser;
 use std::collections::HashMap;
@@ -40,7 +42,7 @@ impl<'a> VarBuilder<'a> {
         }
     }
-    pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
+    pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
         let s: Shape = s.into();
         match &self.safetensors {
             None => Tensor::zeros(s, self.dtype, &self.device),
@@ -71,7 +73,7 @@ enum HiddenAct {
 }
 impl HiddenAct {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
         match self {
             Self::Gelu => xs.gelu(),
             Self::Relu => xs.relu(),
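
A note on the three return-type edits above: once use anyhow::{Error as E, Result}; is in scope, a bare Result<Tensor> resolves to anyhow's alias, so the tensor-level helpers now spell out candle::Result explicitly. A minimal sketch of the two-alias pattern (function names hypothetical):

use anyhow::Result; // from here on, Result<T> = Result<T, anyhow::Error>

// Tensor-level helper: stays on candle's concrete error type.
fn tensor_op(xs: &candle::Tensor) -> candle::Result<candle::Tensor> {
    xs.relu()
}

// Application-level code: `?` converts candle errors into anyhow errors.
fn run(xs: &candle::Tensor) -> Result<()> {
    let _ys = tensor_op(xs)?;
    Ok(())
}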
@@ -164,7 +166,8 @@ impl Embedding {
     }
     fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        Tensor::embedding(indexes, &self.embeddings)
+        let values = Tensor::embedding(indexes, &self.embeddings)?;
+        Ok(values)
     }
 }
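
For context, Tensor::embedding(ids, table) (the argument order used at this point in the tree) is a row lookup: each u32 id selects the matching row of the embedding table. A hedged sketch on made-up data, assuming that signature:

use candle::{Device, Tensor};

fn embedding_demo() -> candle::Result<()> {
    let device = Device::Cpu;
    // A 4-entry vocabulary with 3-dimensional embeddings.
    let table = Tensor::new(
        &[[0f32, 0., 0.], [1., 1., 1.], [2., 2., 2.], [3., 3., 3.]],
        &device,
    )?;
    let ids = Tensor::new(&[2u32, 0, 3], &device)?;
    // Each id picks one row of the table, so the result has shape (3, 3).
    let rows = Tensor::embedding(&ids, &table)?;
    assert_eq!(rows.shape().dims(), &[3, 3]);
    Ok(())
}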
@@ -281,11 +284,15 @@ impl BertEmbeddings {
     }
     fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+        let seq_len = input_ids.shape().r1()?;
         let input_embeddings = self.word_embeddings.forward(input_ids)?;
         let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
         let mut embeddings = (input_embeddings + token_type_embeddings)?;
         if let Some(position_embeddings) = &self.position_embeddings {
-            embeddings = (&embeddings + position_embeddings.forward(&embeddings))?
+            // TODO: Proper absolute positions?
+            let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
+            let position_ids = Tensor::new(&position_ids[..], &input_ids.device())?;
+            embeddings = (&embeddings + position_embeddings.forward(&position_ids)?)?
         }
         let embeddings = self.layer_norm.forward(&embeddings)?;
         let embeddings = self.dropout.forward(&embeddings)?;
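
Two things happen in this hunk: the removed line was passing the embeddings themselves as position indices, which the new code fixes by building explicit 0..seq_len ids, and the TODO flags that upstream BERT starts positions at an offset when cached keys/values are present rather than always at 0. A hedged sketch of that offset variant (past_len is hypothetical; nothing in this commit uses a cache):

fn position_ids(past_len: usize, seq_len: usize, device: &candle::Device) -> candle::Result<candle::Tensor> {
    // Absolute positions past_len..past_len + seq_len, HF-style.
    let ids: Vec<u32> = (past_len as u32..(past_len + seq_len) as u32).collect();
    candle::Tensor::new(&ids[..], device)
}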
@@ -326,7 +333,8 @@ impl BertSelfAttention {
         new_x_shape.pop();
         new_x_shape.push(self.num_attention_heads);
         new_x_shape.push(self.attention_head_size);
-        xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)
+        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
+        Ok(xs)
     }
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
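
The rewrite above only unsugars the final expression into a bind-and-return; the shape logic is unchanged. It mirrors HF's transpose_for_scores: split the hidden dimension into (heads, head_size), then move the head axis in front of the sequence axis. A small sketch of the shape mechanics with a leading batch dimension (dimensions made up):

use candle::{DType, Device, Tensor};

fn transpose_demo() -> candle::Result<()> {
    let device = Device::Cpu;
    let (seq_len, num_heads, head_size) = (4, 2, 3);
    let xs = Tensor::zeros((1, seq_len, num_heads * head_size), DType::F32, &device)?;
    // Split the hidden dimension: (1, seq, heads * head_size) -> (1, seq, heads, head_size).
    let xs = xs.reshape((1, seq_len, num_heads, head_size))?;
    // Swap dims 1 and 2: (1, heads, seq, head_size), the usual attention layout.
    let xs = xs.transpose(1, 2)?;
    assert_eq!(xs.shape().dims(), &[1, num_heads, seq_len, head_size]);
    Ok(())
}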
@@ -425,7 +433,8 @@ impl BertIntermediate {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let hidden_states = self.dense.forward(hidden_states)?;
-        self.intermediate_act.forward(&hidden_states)
+        let ys = self.intermediate_act.forward(&hidden_states)?;
+        Ok(ys)
     }
 }
@@ -534,8 +543,8 @@ impl BertModel {
         })
     }
-    fn forward(&self, input_ids: &Tensor, position_ids: &Tensor) -> Result<Tensor> {
-        let embedding_output = self.embeddings.forward(input_ids, position_ids)?;
+    fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+        let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
         let sequence_output = self.encoder.forward(&embedding_output)?;
         Ok(sequence_output)
     }
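
BertModel::forward now threads token type ids (instead of position ids) down to the embeddings and returns the encoder's per-token sequence output. For the sentence-embedding use the referenced all-MiniLM-L6-v2 checkpoint targets, the usual next step is mean pooling over the token axis; a hedged sketch assuming an unbatched (seq_len, hidden_size) output and a mean reduction as in later candle versions (the tree at this commit may spell it differently):

fn mean_pool(sequence_output: &candle::Tensor) -> candle::Result<candle::Tensor> {
    // (seq_len, hidden_size) -> (hidden_size,): average the token embeddings.
    sequence_output.mean(0)
}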
@@ -555,7 +564,7 @@ struct Args {
     weights: String,
 }
-fn main() -> anyhow::Result<()> {
+fn main() -> Result<()> {
     use tokenizers::Tokenizer;
     let args = Args::parse();
@@ -579,9 +588,9 @@ fn main() -> anyhow::Result<()> {
         .get_ids()
         .to_vec();
     let token_ids = Tensor::new(&tokens[..], &device)?;
-    let position_ids: Vec<_> = (0..tokens.len() as u32).collect();
-    let position_ids = Tensor::new(&position_ids[..], &device)?.unsqueeze(0)?;
-    let ys = model.forward(&token_ids, &position_ids)?;
+    println!("{token_ids}");
+    let token_type_ids = token_ids.zeros_like()?;
+    let ys = model.forward(&token_ids, &token_type_ids)?;
     println!("{ys}");
     Ok(())
 }
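
Worth noting at the call site: zeros_like keeps both the shape and the dtype of its receiver, so for u32 token ids of shape (seq_len,) it yields u32 zeros of the same shape, i.e. "every token belongs to segment 0", which is what a single-sentence BERT input expects. A small usage sketch (the id values are made up):

use candle::{Device, Tensor};

fn token_type_demo() -> candle::Result<()> {
    let device = Device::Cpu;
    let token_ids = Tensor::new(&[101u32, 2023, 2003, 102], &device)?;
    let token_type_ids = token_ids.zeros_like()?;
    assert_eq!(token_type_ids.shape().dims(), token_ids.shape().dims());
    Ok(())
}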