Add support for gemma-2. (#2425)
* Add gemma-2. * Support a couple more models. * Sliding window support. * Example + readme updates. * Update the main readme.
This commit is contained in:
parent
69fdcfe96a
commit
c1b9e07e35
|
@ -65,7 +65,7 @@ We also provide a some command line based examples using state of the art models
|
|||
- [Falcon](./candle-examples/examples/falcon/): general LLM.
|
||||
- [Codegeex4](./candle-examples/examples/codegeex4-9b/): Code completion,code interpreter,web search,fuction calling,repository-level
|
||||
- [GLM4](./candle-examples/examples/glm4/): Open Multilingual Multimodal Chat LMs by THUDM
|
||||
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
|
||||
- [Gemma v1 and v2](./candle-examples/examples/gemma/): 2b and 7b+/9b general LLMs from Google Deepmind.
|
||||
- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
|
||||
Griffin based models from Google that mix attention with a RNN like state.
|
||||
- [Phi-1, Phi-1.5, Phi-2, and Phi-3](./candle-examples/examples/phi/): 1.3b,
|
||||
|
@ -208,7 +208,7 @@ If you have an addition to this list, please submit a pull request.
|
|||
- StarCoder, StarCoder2.
|
||||
- Phi 1, 1.5, 2, and 3.
|
||||
- Mamba, Minimal Mamba
|
||||
- Gemma 2b and 7b.
|
||||
- Gemma v1 2b and 7b+, v2 2b and 9b.
|
||||
- Mistral 7b v0.1.
|
||||
- Mixtral 8x7b v0.1.
|
||||
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
|
||||
|
|
|
@ -1,27 +1,27 @@
|
|||
# candle-gemma: 2b and 7b LLMs from Google DeepMind
|
||||
|
||||
[Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open
|
||||
models published by Google Deepmind with a 2b and a 7b variant.
|
||||
|
||||
In order to use the example below, you have to accept the license on the
|
||||
[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up
|
||||
your access token via the [HuggingFace cli login
|
||||
command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).
|
||||
models published by Google Deepmind with a 2b and a 7b variant for the first
|
||||
version, and a 2b and a 9b variant for v2.
|
||||
|
||||
## Running the example
|
||||
|
||||
```bash
|
||||
$ cargo run --example gemma --release -- --prompt "fn count_primes(max_n: usize)"
|
||||
fn count_primes(max_n: usize) -> usize {
|
||||
let mut primes = vec![true; max_n];
|
||||
for i in 2..=max_n {
|
||||
if primes[i] {
|
||||
for j in i * i..max_n {
|
||||
primes[j] = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
primes.len()
|
||||
}
|
||||
$ cargo run --example gemma --features cuda -r -- \
|
||||
--prompt "Here is a proof that square root of 2 is not rational: "
|
||||
|
||||
Here is a proof that square root of 2 is not rational:
|
||||
|
||||
Let us assume it to be rational. Then, we can write √2 = p/q where q ≠ 0 and p and q are integers with no common factors other than 1. Squaring both sides gives us (p/q)^2 = 2 or p^2/q^2 = 2. This implies that p^2 is divisible by 2, which means that p must be even. Let us write p = 2m where m is an integer. Substituting this in the above equation we get:
|
||||
|
||||
(p^2)/q^2 = 2 or (4m^2)/q^2 = 2 or q^2/2m^2 = 1 which implies that q^2 must be divisible by 2, and hence q is even. This contradicts our assumption that p and q have no common factors other than 1. Hence we conclude that √2 cannot be rational.
|
||||
```
|
||||
|
||||
## Access restrictions
|
||||
|
||||
In order to use the v1 examples, you have to accept the license on the
|
||||
[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up
|
||||
your access token via the [HuggingFace cli login
|
||||
command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).
|
||||
|
||||
|
||||
|
|
|
@ -7,7 +7,8 @@ extern crate accelerate_src;
|
|||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::gemma::{Config, Model};
|
||||
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
|
||||
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
|
@ -38,6 +39,46 @@ enum Which {
|
|||
CodeInstruct2B,
|
||||
#[value(name = "code-7b-it")]
|
||||
CodeInstruct7B,
|
||||
#[value(name = "2-2b")]
|
||||
BaseV2_2B,
|
||||
#[value(name = "2-2b-it")]
|
||||
InstructV2_2B,
|
||||
#[value(name = "2-9b")]
|
||||
BaseV2_9B,
|
||||
#[value(name = "2-9b-it")]
|
||||
InstructV2_9B,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn is_v1(&self) -> bool {
|
||||
match self {
|
||||
Self::Base2B
|
||||
| Self::Base7B
|
||||
| Self::Instruct2B
|
||||
| Self::Instruct7B
|
||||
| Self::InstructV1_1_2B
|
||||
| Self::InstructV1_1_7B
|
||||
| Self::CodeBase2B
|
||||
| Self::CodeBase7B
|
||||
| Self::CodeInstruct2B
|
||||
| Self::CodeInstruct7B => true,
|
||||
Self::BaseV2_2B | Self::InstructV2_2B | Self::BaseV2_9B | Self::InstructV2_9B => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum Model {
|
||||
V1(Model1),
|
||||
V2(Model2),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&mut self, input_ids: &Tensor, pos: usize) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::V1(m) => m.forward(input_ids, pos),
|
||||
Self::V2(m) => m.forward(input_ids, pos),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TextGeneration {
|
||||
|
@ -191,7 +232,7 @@ struct Args {
|
|||
repeat_last_n: usize,
|
||||
|
||||
/// The model to use.
|
||||
#[arg(long, default_value = "2b")]
|
||||
#[arg(long, default_value = "2-2b")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
|
@ -239,6 +280,10 @@ fn main() -> Result<()> {
|
|||
Which::CodeBase7B => "google/codegemma-7b".to_string(),
|
||||
Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
|
||||
Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
|
||||
Which::BaseV2_2B => "google/gemma-2-2b".to_string(),
|
||||
Which::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
|
||||
Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
|
||||
Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
|
||||
},
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
|
@ -263,7 +308,6 @@ fn main() -> Result<()> {
|
|||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
@ -273,7 +317,15 @@ fn main() -> Result<()> {
|
|||
DType::F32
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Model::new(args.use_flash_attn, &config, vb)?;
|
||||
let model = if args.which.is_v1() {
|
||||
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model1::new(args.use_flash_attn, &config, vb)?;
|
||||
Model::V1(model)
|
||||
} else {
|
||||
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model2::new(args.use_flash_attn, &config, vb)?;
|
||||
Model::V2(model)
|
||||
};
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
|
|
|
@ -0,0 +1,449 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||
use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};
|
||||
|
||||
fn default_max_position_embeddings() -> usize {
|
||||
4096
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize, Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub attention_bias: bool,
|
||||
pub head_dim: usize,
|
||||
pub hidden_activation: Activation,
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_key_value_heads: usize,
|
||||
pub rms_norm_eps: f64,
|
||||
pub rope_theta: f64,
|
||||
pub vocab_size: usize,
|
||||
pub final_logit_softcapping: Option<f64>,
|
||||
pub attn_logit_softcapping: Option<f64>,
|
||||
pub query_pre_attn_scalar: usize,
|
||||
// TODO: Handle the sliding window in the attention mask.
|
||||
pub sliding_window: Option<usize>,
|
||||
|
||||
#[serde(default = "default_max_position_embeddings")]
|
||||
pub max_position_embeddings: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RmsNorm {
|
||||
weight: Tensor,
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl RmsNorm {
|
||||
fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
|
||||
let weight = vb.get(dim, "weight")?;
|
||||
Ok(Self { weight, eps })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for RmsNorm {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x_dtype = x.dtype();
|
||||
let internal_dtype = match x_dtype {
|
||||
DType::F16 | DType::BF16 => DType::F32,
|
||||
d => d,
|
||||
};
|
||||
let hidden_size = x.dim(D::Minus1)?;
|
||||
let x = x.to_dtype(internal_dtype)?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
|
||||
x_normed
|
||||
.to_dtype(x_dtype)?
|
||||
.broadcast_mul(&(&self.weight + 1.0)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||
let dim = cfg.head_dim;
|
||||
let max_seq_len = cfg.max_position_embeddings;
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||
.to_dtype(dtype)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
})
|
||||
}
|
||||
|
||||
fn apply_rotary_emb_qkv(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
struct MLP {
|
||||
gate_proj: Linear,
|
||||
up_proj: Linear,
|
||||
down_proj: Linear,
|
||||
act_fn: candle_nn::Activation,
|
||||
}
|
||||
|
||||
impl MLP {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden_sz = cfg.hidden_size;
|
||||
let intermediate_sz = cfg.intermediate_size;
|
||||
let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?;
|
||||
let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?;
|
||||
let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?;
|
||||
Ok(Self {
|
||||
gate_proj,
|
||||
up_proj,
|
||||
down_proj,
|
||||
act_fn: cfg.hidden_activation,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MLP {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
|
||||
let rhs = xs.apply(&self.up_proj)?;
|
||||
(lhs * rhs)?.apply(&self.down_proj)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Attention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
o_proj: Linear,
|
||||
num_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
num_kv_groups: usize,
|
||||
head_dim: usize,
|
||||
attn_logit_softcapping: Option<f64>,
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
use_flash_attn: bool,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
use_flash_attn: bool,
|
||||
cfg: &Config,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let hidden_sz = cfg.hidden_size;
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
let num_kv_heads = cfg.num_key_value_heads;
|
||||
let num_kv_groups = num_heads / num_kv_heads;
|
||||
let head_dim = cfg.head_dim;
|
||||
let bias = cfg.attention_bias;
|
||||
let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?;
|
||||
let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?;
|
||||
let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?;
|
||||
let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
num_kv_groups,
|
||||
head_dim,
|
||||
attn_logit_softcapping: cfg.attn_logit_softcapping,
|
||||
rotary_emb,
|
||||
kv_cache: None,
|
||||
use_flash_attn,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let (b_sz, q_len, _) = xs.dims3()?;
|
||||
|
||||
let query_states = self.q_proj.forward(xs)?;
|
||||
let key_states = self.k_proj.forward(xs)?;
|
||||
let value_states = self.v_proj.forward(xs)?;
|
||||
|
||||
let query_states = query_states
|
||||
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let key_states = key_states
|
||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let value_states = value_states
|
||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let (query_states, key_states) =
|
||||
self.rotary_emb
|
||||
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
|
||||
|
||||
let (key_states, value_states) = match &self.kv_cache {
|
||||
None => (key_states, value_states),
|
||||
Some((prev_k, prev_v)) => {
|
||||
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
|
||||
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
|
||||
(key_states, value_states)
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((key_states.clone(), value_states.clone()));
|
||||
|
||||
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
|
||||
let value_states =
|
||||
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
|
||||
|
||||
let attn_output = if self.use_flash_attn {
|
||||
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
|
||||
let q = query_states.transpose(1, 2)?;
|
||||
let k = key_states.transpose(1, 2)?;
|
||||
let v = value_states.transpose(1, 2)?;
|
||||
let scale = 1f32 / (self.head_dim as f32).sqrt();
|
||||
flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)?
|
||||
} else {
|
||||
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
|
||||
|
||||
let attn_weights = match self.attn_logit_softcapping {
|
||||
None => attn_weights,
|
||||
Some(sc) => ((attn_weights / sc)?.tanh()? * sc)?,
|
||||
};
|
||||
|
||||
let attn_weights = match attention_mask {
|
||||
None => attn_weights,
|
||||
Some(mask) => attn_weights.broadcast_add(mask)?,
|
||||
};
|
||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
attn_weights.matmul(&value_states)?
|
||||
};
|
||||
attn_output
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz, q_len, ()))?
|
||||
.apply(&self.o_proj)
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.kv_cache = None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "flash-attn")]
|
||||
fn flash_attn(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
softmax_scale: f32,
|
||||
causal: bool,
|
||||
) -> Result<Tensor> {
|
||||
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "flash-attn"))]
|
||||
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
|
||||
unimplemented!("compile with '--features flash-attn'")
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct DecoderLayer {
|
||||
self_attn: Attention,
|
||||
mlp: MLP,
|
||||
input_layernorm: RmsNorm,
|
||||
pre_feedforward_layernorm: RmsNorm,
|
||||
post_feedforward_layernorm: RmsNorm,
|
||||
post_attention_layernorm: RmsNorm,
|
||||
}
|
||||
|
||||
impl DecoderLayer {
|
||||
fn new(
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
use_flash_attn: bool,
|
||||
cfg: &Config,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let self_attn = Attention::new(rotary_emb, use_flash_attn, cfg, vb.pp("self_attn"))?;
|
||||
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||
let input_layernorm =
|
||||
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
||||
let pre_feedforward_layernorm = RmsNorm::new(
|
||||
cfg.hidden_size,
|
||||
cfg.rms_norm_eps,
|
||||
vb.pp("pre_feedforward_layernorm"),
|
||||
)?;
|
||||
let post_feedforward_layernorm = RmsNorm::new(
|
||||
cfg.hidden_size,
|
||||
cfg.rms_norm_eps,
|
||||
vb.pp("post_feedforward_layernorm"),
|
||||
)?;
|
||||
let post_attention_layernorm = RmsNorm::new(
|
||||
cfg.hidden_size,
|
||||
cfg.rms_norm_eps,
|
||||
vb.pp("post_attention_layernorm"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
mlp,
|
||||
input_layernorm,
|
||||
pre_feedforward_layernorm,
|
||||
post_feedforward_layernorm,
|
||||
post_attention_layernorm,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = self.input_layernorm.forward(xs)?;
|
||||
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
|
||||
let xs = xs.apply(&self.post_attention_layernorm)?;
|
||||
let xs = (xs + residual)?;
|
||||
let residual = &xs;
|
||||
let xs = xs.apply(&self.pre_feedforward_layernorm)?;
|
||||
let xs = xs.apply(&self.mlp)?;
|
||||
let xs = xs.apply(&self.post_feedforward_layernorm)?;
|
||||
residual + xs
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.self_attn.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Model {
|
||||
embed_tokens: candle_nn::Embedding,
|
||||
layers: Vec<DecoderLayer>,
|
||||
norm: RmsNorm,
|
||||
lm_head: Linear,
|
||||
final_logit_softcapping: Option<f64>,
|
||||
device: Device,
|
||||
dtype: DType,
|
||||
hidden_size: usize,
|
||||
sliding_window: Option<usize>,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vb_m = vb.pp("model");
|
||||
let embed_tokens =
|
||||
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
let vb_l = vb_m.pp("layers");
|
||||
for layer_idx in 0..cfg.num_hidden_layers {
|
||||
let layer =
|
||||
DecoderLayer::new(rotary_emb.clone(), use_flash_attn, cfg, vb_l.pp(layer_idx))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
|
||||
let lm_head = Linear::new(embed_tokens.embeddings().clone(), None);
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
layers,
|
||||
norm,
|
||||
lm_head,
|
||||
final_logit_softcapping: cfg.final_logit_softcapping,
|
||||
device: vb.device().clone(),
|
||||
dtype: vb.dtype(),
|
||||
hidden_size: cfg.hidden_size,
|
||||
sliding_window: cfg.sliding_window,
|
||||
})
|
||||
}
|
||||
|
||||
fn prepare_decoder_attention_mask(
|
||||
&self,
|
||||
b_size: usize,
|
||||
tgt_len: usize,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let mask: Vec<_> = match self.sliding_window {
|
||||
None => (0..tgt_len)
|
||||
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
|
||||
.collect(),
|
||||
Some(sliding_window) => (0..tgt_len)
|
||||
.flat_map(|i| {
|
||||
(0..tgt_len).map(move |j| {
|
||||
if i < j || j + sliding_window < i {
|
||||
f32::NEG_INFINITY
|
||||
} else {
|
||||
0.
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
};
|
||||
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||
let mask = if seqlen_offset > 0 {
|
||||
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
|
||||
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||
} else {
|
||||
mask
|
||||
};
|
||||
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
|
||||
.to_dtype(self.dtype)
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let (b_size, seq_len) = input_ids.dims2()?;
|
||||
let attention_mask = if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
|
||||
Some(mask)
|
||||
};
|
||||
let xs = self.embed_tokens.forward(input_ids)?;
|
||||
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
|
||||
}
|
||||
let logits = xs
|
||||
.narrow(1, seq_len - 1, 1)?
|
||||
.apply(&self.norm)?
|
||||
.apply(&self.lm_head)?;
|
||||
let logits = match self.final_logit_softcapping {
|
||||
None => logits,
|
||||
Some(sc) => ((logits / sc)?.tanh()? * sc)?,
|
||||
};
|
||||
|
||||
Ok(logits)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
for layer in self.layers.iter_mut() {
|
||||
layer.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -20,6 +20,7 @@ pub mod eva2;
|
|||
pub mod falcon;
|
||||
pub mod flux;
|
||||
pub mod gemma;
|
||||
pub mod gemma2;
|
||||
pub mod glm4;
|
||||
pub mod hiera;
|
||||
pub mod jina_bert;
|
||||
|
|
Loading…
Reference in New Issue