Get some embeddings out.
This commit is contained in:
parent
54850e7525
commit
f379b8feae
|
@ -1,7 +1,9 @@
|
|||
#![allow(dead_code)]
|
||||
// The tokenizer.json and weights should be retrieved from:
|
||||
// https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
|
||||
|
||||
use anyhow::Error as E;
|
||||
use candle::{safetensors::SafeTensors, DType, Device, Result, Shape, Tensor};
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
|
||||
use clap::Parser;
|
||||
use std::collections::HashMap;
|
||||
|
||||
|
@ -40,7 +42,7 @@ impl<'a> VarBuilder<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
|
||||
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
|
||||
let s: Shape = s.into();
|
||||
match &self.safetensors {
|
||||
None => Tensor::zeros(s, self.dtype, &self.device),
|
||||
|
@ -71,7 +73,7 @@ enum HiddenAct {
|
|||
}
|
||||
|
||||
impl HiddenAct {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::Gelu => xs.gelu(),
|
||||
Self::Relu => xs.relu(),
|
||||
|
@ -164,7 +166,8 @@ impl Embedding {
|
|||
}
|
||||
|
||||
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
|
||||
Tensor::embedding(indexes, &self.embeddings)
|
||||
let values = Tensor::embedding(indexes, &self.embeddings)?;
|
||||
Ok(values)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -281,11 +284,15 @@ impl BertEmbeddings {
|
|||
}
|
||||
|
||||
fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
|
||||
let seq_len = input_ids.shape().r1()?;
|
||||
let input_embeddings = self.word_embeddings.forward(input_ids)?;
|
||||
let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
|
||||
let mut embeddings = (input_embeddings + token_type_embeddings)?;
|
||||
if let Some(position_embeddings) = &self.position_embeddings {
|
||||
embeddings = (&embeddings + position_embeddings.forward(&embeddings))?
|
||||
// TODO: Proper absolute positions?
|
||||
let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
|
||||
let position_ids = Tensor::new(&position_ids[..], &input_ids.device())?;
|
||||
embeddings = (&embeddings + position_embeddings.forward(&position_ids)?)?
|
||||
}
|
||||
let embeddings = self.layer_norm.forward(&embeddings)?;
|
||||
let embeddings = self.dropout.forward(&embeddings)?;
|
||||
|
@ -326,7 +333,8 @@ impl BertSelfAttention {
|
|||
new_x_shape.pop();
|
||||
new_x_shape.push(self.num_attention_heads);
|
||||
new_x_shape.push(self.attention_head_size);
|
||||
xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)
|
||||
let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
|
||||
Ok(xs)
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
|
@ -425,7 +433,8 @@ impl BertIntermediate {
|
|||
|
||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let hidden_states = self.dense.forward(hidden_states)?;
|
||||
self.intermediate_act.forward(&hidden_states)
|
||||
let ys = self.intermediate_act.forward(&hidden_states)?;
|
||||
Ok(ys)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -534,8 +543,8 @@ impl BertModel {
|
|||
})
|
||||
}
|
||||
|
||||
fn forward(&self, input_ids: &Tensor, position_ids: &Tensor) -> Result<Tensor> {
|
||||
let embedding_output = self.embeddings.forward(input_ids, position_ids)?;
|
||||
fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
|
||||
let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
|
||||
let sequence_output = self.encoder.forward(&embedding_output)?;
|
||||
Ok(sequence_output)
|
||||
}
|
||||
|
@ -555,7 +564,7 @@ struct Args {
|
|||
weights: String,
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
fn main() -> Result<()> {
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
let args = Args::parse();
|
||||
|
@ -579,9 +588,9 @@ fn main() -> anyhow::Result<()> {
|
|||
.get_ids()
|
||||
.to_vec();
|
||||
let token_ids = Tensor::new(&tokens[..], &device)?;
|
||||
let position_ids: Vec<_> = (0..tokens.len() as u32).collect();
|
||||
let position_ids = Tensor::new(&position_ids[..], &device)?.unsqueeze(0)?;
|
||||
let ys = model.forward(&token_ids, &position_ids)?;
|
||||
println!("{token_ids}");
|
||||
let token_type_ids = token_ids.zeros_like()?;
|
||||
let ys = model.forward(&token_ids, &token_type_ids)?;
|
||||
println!("{ys}");
|
||||
Ok(())
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue