Adds support for stella_en_v5 embedding model -400M variant (#2608)

* Adds support for stella_en_v5 embedding model -400M variant

* Unified stella

* WIP: Unified Stella

* Combined stella for both 1.5B and 400M variants

* Cargo fmt for the CI

* removed redundant stella-400m model and example after merge into stella-en-v5

* cargo fmt --all

---------

Co-authored-by: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
iskng 2024-11-29 00:01:08 -08:00 committed by GitHub
parent 54e7fc3c97
commit 4f59ed38b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 556 additions and 113 deletions

View File

@ -21,7 +21,7 @@ Stella_en_1.5B_v5 is trained by [MRL](https://arxiv.org/abs/2205.13147) enabling
The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) for a retrieval task (s2p). The sample queries and docs are hardcoded in the example.
```bash
$ cargo run --example stella-en-v5 --release --features <metal | cuda>
$ cargo run --example stella-en-v5 --release --features <metal | cuda> -- --which 1.5b
>
> Score: 0.8178786
@ -37,9 +37,29 @@ $ cargo run --example stella-en-v5 --release --features <metal | cuda>
> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types >
> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
>
$ cargo run --example stella-en-v5 --release --features <metal | cuda> -- --which 400m
>
> Score: 0.8397539
> Query: What are some ways to reduce stress?
> Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending
> time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent
> stress from building up.
>
>
>
> Score: 0.809545
> Query: What are the benefits of drinking green tea?
> Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage
> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types
> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
>
```
## Supported options:
- `Stella_en_15B_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`.
- `Stella_en_v5` has 2 model variants published - a 1.5B variant and 400M variant. This is enabled through the flag `--which`. E.g. `--which 400m` or `--which 1.5b`.
- `Stella_en_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`.
- As per the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5), the model has been primarily trained on `s2s` (similarity) and `s2p` (retrieval) tasks. These require a slightly different `query` preprocessing (a different prompt template for each). In this example this is enabled though `--task` option.

View File

@ -212,6 +212,14 @@ impl EncodeTask {
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "1.5b")]
Large,
#[value(name = "400m")]
Small,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -219,6 +227,9 @@ struct Args {
#[arg(long)]
cpu: bool,
#[arg(long)]
which: Which,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
@ -250,24 +261,33 @@ struct Args {
// Tokenizer creation is super critical in our case.
// We are going to be `padding: Left` for each batch
fn create_tokenizer(tokenizer_file: &Path) -> Result<Tokenizer> {
fn create_tokenizer(tokenizer_file: &Path, which: Which) -> Result<Tokenizer> {
let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
pad_id
} else {
return Err(anyhow!(
"Tokenizer doesn't contain expected `<|endoftext|>` token"
));
};
// This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Left,
pad_id,
pad_token: "<|endoftext|>".to_string(),
..Default::default()
}));
if which == Which::Large {
let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
pad_id
} else {
return Err(anyhow!(
"Tokenizer doesn't contain expected `<|endoftext|>` token"
));
};
// This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Left,
pad_id,
pad_token: "<|endoftext|>".to_string(),
..Default::default()
}));
} else {
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Right,
..Default::default()
}));
}
Ok(tokenizer)
}
@ -298,7 +318,19 @@ fn main() -> Result<()> {
Some(d) => d,
None => EmbedDim::Dim1024,
};
let repo = api.repo(Repo::model("dunzhang/stella_en_1.5B_v5".to_string()));
let (repo, cfg) = match args.which {
Which::Large => (
"dunzhang/stella_en_1.5B_v5",
Config::new_1_5_b_v5(embed_dim.embed_dim()),
),
Which::Small => (
"dunzhang/stella_en_400M_v5",
Config::new_400_m_v5(embed_dim.embed_dim()),
),
};
let repo = api.repo(Repo::model(repo.to_string()));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
@ -330,7 +362,7 @@ fn main() -> Result<()> {
println!("retrieved the files in {:?}", start.elapsed());
// Initializing the tokenizer which would require us to add padding to the `left` for batch encoding
let tokenizer = create_tokenizer(tokenizer_filename.as_path())?;
let tokenizer = create_tokenizer(tokenizer_filename.as_path(), args.which)?;
let start = std::time::Instant::now();
@ -343,11 +375,7 @@ fn main() -> Result<()> {
let embed_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? };
let model = EmbeddingModel::new(
&Config::new_1_5_b_v5(embed_dim.embed_dim()),
base_vb,
embed_vb,
)?;
let model = EmbeddingModel::new(&cfg, base_vb, embed_vb)?;
println!("loaded the model in {:?}", start.elapsed());

View File

@ -16,33 +16,49 @@
//!
use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
use candle::{DType, Device, IndexOp, Module, Result, Tensor};
use candle_nn::{Activation, VarBuilder};
use candle::{DType, Device, Error, IndexOp, Module, Result, Tensor, D};
use candle_nn::{layer_norm, Activation, LayerNorm, VarBuilder};
use std::sync::Arc;
// internal representation for identifying which model is being used
#[derive(Debug, Copy, Clone, PartialEq, serde::Deserialize)]
pub enum ModelVariant {
Large, // 1.5B
Small, // 400M
}
impl Default for ModelVariant {
fn default() -> Self {
Self::Large
}
}
// Same as `qwen2` family of models with the exception being the `embed_head`
// The final `output` causal modelling head is swapped with a learned `dense` layer, `embed_head`
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
#[derive(Debug, Default, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
pub variant: ModelVariant,
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub max_position_embeddings: usize,
pub max_window_layers: usize,
pub tie_word_embeddings: bool,
pub rope_theta: f64,
pub rms_norm_eps: f64,
pub hidden_act: Activation,
pub embed_head: EmbedHead,
pub norm_eps: f64, // RMSNorm for 1.5B || LayerNorm for 400M
pub activation_fn: Activation, // Silu for 1.5B || Gelu for 400M
// Unique to 1.5B
pub num_key_value_heads: usize,
// Unique to 400M
pub type_vocab_size: usize,
pub scaling_factor: f64,
}
// Excerpt from `stella` model card:
// `Stella_en_1.5B_v5` models have been trained on [MRL](https://arxiv.org/abs/2205.13147) enabling multiple output dimensions
// Embed head represents the config for various embedding dims supported
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
#[derive(Debug, Default, Clone, PartialEq, serde::Deserialize)]
pub struct EmbedHead {
pub in_features: usize,
pub out_features: usize,
@ -68,9 +84,9 @@ impl Default for EmbedDim {
}
impl EmbedDim {
pub fn config(&self) -> EmbedHead {
pub fn config(&self, in_features: usize) -> EmbedHead {
EmbedHead {
in_features: 1536,
in_features,
out_features: match &self {
Self::Dim256 => 256,
Self::Dim768 => 768,
@ -91,7 +107,8 @@ impl Config {
// Representing config.json at https://huggingface.co/dunzhang/stella_en_1.5B_v5/blob/main/config.json
// Removed `sliding_window` related config which is basically being carried forward from `qwen2` but not used here
Self {
hidden_act: candle_nn::Activation::Silu,
variant: ModelVariant::Large,
activation_fn: candle_nn::Activation::Silu,
vocab_size: 151646,
hidden_size: 1536,
intermediate_size: 8960,
@ -99,11 +116,30 @@ impl Config {
num_attention_heads: 12,
num_key_value_heads: 2,
max_position_embeddings: 131072,
max_window_layers: 21,
tie_word_embeddings: false,
rope_theta: 1000000.,
rms_norm_eps: 1e-06,
embed_head: embed_dim.config(),
norm_eps: 1e-06,
embed_head: embed_dim.config(1536),
..Default::default()
}
}
/// Initialize new `stella_en_400M_v5`
pub fn new_400_m_v5(embed_dim: EmbedDim) -> Self {
Self {
variant: ModelVariant::Small,
vocab_size: 30528,
hidden_size: 1024,
intermediate_size: 4096,
num_hidden_layers: 24,
num_attention_heads: 16,
max_position_embeddings: 8192,
type_vocab_size: 2,
norm_eps: 1e-12,
scaling_factor: 2.0,
rope_theta: 160000.0,
activation_fn: Activation::Gelu,
embed_head: embed_dim.config(1024),
..Default::default()
}
}
}
@ -117,27 +153,57 @@ struct RotaryEmbedding {
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.hidden_size / cfg.num_attention_heads;
let max_seq_len = cfg.max_position_embeddings;
// Factoring in `scaling factor` for `400M` variant
let max_seq_len = if cfg.scaling_factor == 0. {
cfg.max_position_embeddings
} else {
((cfg.max_position_embeddings as f64) * cfg.scaling_factor) as usize
};
// let rot_dim = if cfg.variant == ModelVariant::Small { dim / 2 } else { dim };
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.map(|i| {
// Scaled rope_theta for 400M variant
let rope_theta = if cfg.scaling_factor == 0. {
cfg.rope_theta
} else {
cfg.rope_theta * cfg.scaling_factor
};
let mut freq = 1. / rope_theta.powf(i as f64 / dim as f64);
if cfg.scaling_factor != 0. {
freq /= cfg.scaling_factor.powf(2.0 / (dim as f64))
}
freq as f32
})
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
// Calculate position embeddings with scaled sequence length
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
// if cfg.variant == ModelVariant::Small {
// freqs = Tensor::cat(&[&freqs, &freqs], 1)?
// }
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
// TODO: re-visit this
fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, 0, seq_len)?;
let sin = self.sin.narrow(0, 0, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
@ -147,8 +213,9 @@ impl RotaryEmbedding {
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
variant: ModelVariant,
gate_proj: Linear,
up_proj: Linear,
up_proj: Option<Linear>, // `up_proj` only for 1.5B variant
down_proj: Linear,
act_fn: Activation,
}
@ -157,31 +224,65 @@ impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let intermediate_sz = cfg.intermediate_size;
let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
let (gate_proj, up_proj, down_proj) = match cfg.variant {
ModelVariant::Large => (
linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?,
Some(linear_no_bias(
hidden_sz,
intermediate_sz,
vb.pp("up_proj"),
)?),
linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?,
),
ModelVariant::Small => (
linear_no_bias(hidden_sz, intermediate_sz * 2, vb.pp("up_gate_proj"))?,
None,
linear(intermediate_sz, hidden_sz, vb.pp("down_proj"))?,
),
};
Ok(Self {
variant: cfg.variant,
gate_proj,
up_proj,
down_proj,
act_fn: cfg.hidden_act,
act_fn: cfg.activation_fn,
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = xs.apply(&self.up_proj)?;
let up = self.gate_proj.forward(xs)?;
let (lhs, rhs) = match self.variant {
ModelVariant::Large => {
let lhs = up.apply(&self.act_fn)?;
let rhs = xs.apply(self.up_proj.as_ref().unwrap())?;
(lhs, rhs)
}
ModelVariant::Small => {
// Get the dimensions
let (_batch_size, _seq_len, hidden_dim) = up.dims3()?;
let split_size = hidden_dim / 2;
// Split along the last dimension (hidden_dim)
let up_states = up.narrow(2, 0, split_size)?;
let gate = up.narrow(2, split_size, split_size)?.apply(&self.act_fn)?;
(up_states, gate)
}
};
(lhs * rhs)?.apply(&self.down_proj)
}
}
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
qkv_proj: Linear,
o_proj: Linear,
num_heads: usize,
num_kv_heads: usize,
@ -189,6 +290,7 @@ struct Attention {
head_dim: usize,
hidden_size: usize,
rotary_emb: Arc<RotaryEmbedding>,
variant: ModelVariant,
}
impl Attention {
@ -196,16 +298,47 @@ impl Attention {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let num_kv_groups = num_heads / num_kv_heads;
let num_kv_groups = if num_kv_heads > 0 {
num_heads / num_kv_heads
} else {
0
};
let head_dim = hidden_sz / num_heads;
let q_proj = linear(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
let k_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
let v_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
let (qkv_proj, o_proj) = match cfg.variant {
ModelVariant::Large => {
// The 1.5B variant comes with separate `q, k, v` layers, let's merge it and standardize
// Weights
let q_w = vb
.pp("q_proj")
.get((num_heads * head_dim, hidden_sz), "weight")?;
let k_w = vb
.pp("k_proj")
.get((num_kv_heads * head_dim, hidden_sz), "weight")?;
let v_w = vb
.pp("v_proj")
.get((num_kv_heads * head_dim, hidden_sz), "weight")?;
// Biases
let q_b = vb.pp("q_proj").get(num_heads * head_dim, "bias")?;
let k_b = vb.pp("k_proj").get(num_kv_heads * head_dim, "bias")?;
let v_b = vb.pp("v_proj").get(num_kv_heads * head_dim, "bias")?;
let qkv_w = Tensor::cat(&[&q_w, &k_w, &v_w], 0)?;
let qkv_b = Tensor::cat(&[&q_b, &k_b, &v_b], 0)?;
(
Linear::from_weights(qkv_w, Some(qkv_b)),
linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?,
)
}
ModelVariant::Small => (
linear(hidden_sz, 3 * num_heads * head_dim, vb.pp("qkv_proj"))?,
linear(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?,
),
};
Ok(Self {
q_proj,
k_proj,
v_proj,
qkv_proj,
o_proj,
num_heads,
num_kv_heads,
@ -213,45 +346,90 @@ impl Attention {
head_dim,
hidden_size: hidden_sz,
rotary_emb,
variant: cfg.variant,
})
}
fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
let key_states = self.k_proj.forward(xs)?;
let value_states = self.v_proj.forward(xs)?;
let qkv = self.qkv_proj.forward(xs)?;
let query_states = query_states
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let key_states = key_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let n_kv_heads = match self.variant {
ModelVariant::Large => self.num_kv_heads,
ModelVariant::Small => self.num_heads,
};
let (query_states, key_states, value_states) = match self.variant {
ModelVariant::Large => {
let q_sz = self.num_heads * self.head_dim;
let kv_sz = n_kv_heads * self.head_dim;
let q = qkv.narrow(D::Minus1, 0, q_sz)?.reshape((
b_sz,
q_len,
self.num_heads,
self.head_dim,
))?;
let k = qkv.narrow(D::Minus1, q_sz, kv_sz)?.reshape((
b_sz,
q_len,
n_kv_heads,
self.head_dim,
))?;
let v = qkv.narrow(D::Minus1, q_sz + kv_sz, kv_sz)?.reshape((
b_sz,
q_len,
n_kv_heads,
self.head_dim,
))?;
(q, k, v)
}
ModelVariant::Small => {
// Split into Q, K, V and reshape to match PyTorch shapes
let qkv = qkv.reshape((b_sz, q_len, 3, self.num_heads, self.head_dim))?;
(
qkv.i((.., .., 0, .., ..))?,
qkv.i((.., .., 1, .., ..))?,
qkv.i((.., .., 2, .., ..))?,
)
}
};
let query_states = query_states.transpose(1, 2)?.contiguous()?;
let key_states = key_states.transpose(1, 2)?.contiguous()?;
let value_states = value_states.transpose(1, 2)?.contiguous()?;
let (query_states, key_states) = self
.rotary_emb
.apply_rotary_emb_qkv(&query_states, &key_states)?;
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
let value_states =
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
// The 1.5B is expected to have grouped query attention
let (key_states, value_states) = if self.variant == ModelVariant::Large {
(
crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?,
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?,
)
} else {
(key_states, value_states)
};
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
let attn_weights = query_states.matmul(&key_states.transpose(2, 3)?)?;
let attn_weights = (attn_weights * scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(&value_states)?
};
attn_output
.transpose(1, 2)?
.reshape((b_sz, q_len, self.hidden_size))?
@ -260,70 +438,282 @@ impl Attention {
}
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
mlp: MLP,
input_layernorm: RmsNorm,
post_attention_layernorm: RmsNorm,
enum NormType {
Layer(LayerNorm),
Rms(RmsNorm),
}
impl DecoderLayer {
#[derive(Debug, Clone)]
struct Layer {
variant: ModelVariant,
attention: Attention,
mlp: MLP,
// For 1.5B: this is `input_layernorm`
// For 400M: this is `output_layernorm`
layernorm: NormType,
post_attention_layernorm: NormType,
}
impl Layer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let input_layernorm =
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_attention_layernorm"),
let attention = Attention::new(
rotary_emb,
cfg,
vb.pp(if cfg.variant == ModelVariant::Large {
"self_attn"
} else {
"attention"
}),
)?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let (layernorm, post_attention_layernorm) = match cfg.variant {
ModelVariant::Large => (
NormType::Rms(RmsNorm::new(
cfg.hidden_size,
cfg.norm_eps,
vb.pp("input_layernorm"),
)?),
NormType::Rms(RmsNorm::new(
cfg.hidden_size,
cfg.norm_eps,
vb.pp("post_attention_layernorm"),
)?),
),
ModelVariant::Small => (
NormType::Layer(layer_norm(
cfg.hidden_size,
candle_nn::LayerNormConfig {
eps: cfg.norm_eps,
..Default::default()
},
vb.pp("mlp_ln"),
)?),
NormType::Layer(layer_norm(
cfg.hidden_size,
candle_nn::LayerNormConfig {
eps: cfg.norm_eps,
..Default::default()
},
vb.pp("attn_ln"),
)?),
),
};
Ok(Self {
self_attn,
variant: cfg.variant,
attention,
mlp,
input_layernorm,
layernorm,
post_attention_layernorm,
})
}
fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
// Here, the application of normalizations and activation calculations differ
// For Large [1.5B]:
// residual = x
// state = other_layernorm(xs)
// state = attention(state)
// state += residual
// residual = state
// state = mlp(attention_layernorm(state))
// -> residual + state
// For Small [400M]:
// residual = x;
// state = attention(x)
// state += residual
// state = attention_layernorm(state)
// residual = state
// state = mlp(state)
// state += residual
// -> other_layernorm(state)
let residual = xs;
let xs = self.input_layernorm.forward(xs)?;
let xs = self.self_attn.forward(&xs, attention_mask)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
residual + xs
match self.variant {
ModelVariant::Large => {
let (attn_ln, input_ln) = if let (NormType::Rms(attn_ln), NormType::Rms(input_ln)) =
(&self.post_attention_layernorm, &self.layernorm)
{
(attn_ln, input_ln)
} else {
return Err(candle::error::Error::Msg(
"Stella 1.5B expects RMSNorm".to_string(),
));
};
let xs = input_ln.forward(xs)?;
let xs = (self.attention.forward(&xs, attention_mask)? + residual)?;
let residual = &xs;
let xs = xs.apply(attn_ln)?.apply(&self.mlp)?;
residual + xs
}
ModelVariant::Small => {
let (attn_ln, output_ln) =
if let (NormType::Layer(attn_ln), NormType::Layer(input_ln)) =
(&self.post_attention_layernorm, &self.layernorm)
{
(attn_ln, input_ln)
} else {
return Err(candle::error::Error::Msg(
"Stella 400M expects RMSNorm".to_string(),
));
};
let xs = (self.attention.forward(xs, attention_mask)? + residual)?;
let xs = attn_ln.forward(&xs)?;
let residual = &xs;
let xs = (self.mlp.forward(&xs)? + residual)?;
output_ln.forward(&xs)
}
}
}
}
#[derive(Debug, Clone)]
pub struct Embeddings {
variant: ModelVariant,
// For 1.5B: this is the `embed_tokens`
// For 400M: this is the `word_embeddings`
embeddings: candle_nn::Embedding,
// folloing are specifically for 400M
token_type_embeddings: Option<candle_nn::Embedding>,
layer_norm: Option<LayerNorm>,
position_ids: Option<Tensor>,
}
impl Embeddings {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let (embeddings, token_type_embeddings, layer_norm, position_ids) = match cfg.variant {
ModelVariant::Large => (
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?,
None,
None,
None,
),
ModelVariant::Small => {
let vb = vb.pp("embeddings");
let weight = vb.pp("LayerNorm").get_with_hints(
cfg.hidden_size,
"weight",
candle_nn::Init::Const(1.0),
)?;
let bias = vb.pp("LayerNorm").get_with_hints(
cfg.hidden_size,
"bias",
candle_nn::Init::Const(0.0),
)?;
let dev = bias.device().clone();
let layer_norm = candle_nn::LayerNorm::new(weight, bias, cfg.norm_eps);
(
candle_nn::embedding(
cfg.vocab_size,
cfg.hidden_size,
vb.pp("word_embeddings"),
)?,
Some(candle_nn::embedding(
cfg.type_vocab_size,
cfg.hidden_size,
vb.pp("token_type_embeddings"),
)?),
Some(layer_norm),
Some(Tensor::arange(
0u32,
cfg.max_position_embeddings as u32,
&dev,
)?),
)
}
};
Ok(Self {
variant: cfg.variant,
embeddings,
token_type_embeddings,
layer_norm,
position_ids,
})
}
}
impl Module for Embeddings {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let embd = self.embeddings.forward(xs)?;
// For 1.5B just forward the embeddings
if self.variant == ModelVariant::Large {
return Ok(embd);
}
let (token_type_embed, layer_norm, pos_ids) =
if let (Some(token_type_embd), Some(layer_norm), Some(position_ids)) = (
&self.token_type_embeddings,
&self.layer_norm,
&self.position_ids,
) {
(token_type_embd, layer_norm, position_ids)
} else {
return Err(Error::Msg(
"Stella 400M requires `token_type_embeddings`, `layer_norm` and `position_ids`"
.to_string(),
));
};
let (batch_size, seq_length) = xs.dims2()?;
let pos_ids = pos_ids
.as_ref()
.narrow(0, 0, seq_length)?
.expand((batch_size, seq_length))?;
layer_norm.forward(&embd.add(&token_type_embed.forward(&pos_ids.zeros_like()?)?)?)
}
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
embeddings: Embeddings,
layers: Vec<Layer>,
norm: Option<RmsNorm>,
device: Device,
dtype: DType,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("model");
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
let vb_m = match cfg.variant {
ModelVariant::Large => vb.pp("model"),
ModelVariant::Small => vb.pp("new"),
};
// let embed_tokens =
// candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
let embeddings = Embeddings::new(cfg, vb_m.clone())?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
let vb_l = match cfg.variant {
ModelVariant::Large => vb_m.pp("layers"),
ModelVariant::Small => vb_m.pp("encoder").pp("layer"),
};
for layer_idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
let layer = Layer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
let norm = match cfg.variant {
ModelVariant::Large => Some(RmsNorm::new(
cfg.hidden_size,
cfg.norm_eps,
vb_m.pp("norm"),
)?),
ModelVariant::Small => None,
};
Ok(Self {
embed_tokens,
embeddings,
layers,
norm,
// sliding_window: 0,
device: vb.device().clone(),
dtype: vb.dtype(),
})
@ -352,15 +742,20 @@ impl Model {
Some(self.prepare_attention_mask(mask)?)
};
let mut xs = self.embed_tokens.forward(input_ids)?;
let mut xs = self.embeddings.forward(input_ids)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref())?
}
xs.apply(&self.norm)
if let Some(n) = &self.norm {
xs.apply(n)
} else {
Ok(xs)
}
}
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct EmbeddingModel {
base_model: Model,
lm_head: Linear,