VarBuilder path creation (#131)

* Use a struct for the safetensor+routing.

* Group the path and the var-builder together.

* Fix for the empty path case.
This commit is contained in:
Laurent Mazare 2023-07-10 22:37:34 +01:00 committed by GitHub
parent 1aa7fbbc33
commit b46c28a2ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 196 additions and 200 deletions

View File

@ -109,14 +109,14 @@ impl Config {
}
}
fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
let bias = vb.get(size2, &format!("{p}.bias"))?;
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
let bias = vb.get(size2, "bias")?;
Ok(Linear::new(weight, Some(bias)))
}
@ -135,17 +135,11 @@ impl Dropout {
}
}
fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
let (weight, bias) = match (
vb.get(size, &format!("{p}.weight")),
vb.get(size, &format!("{p}.bias")),
) {
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
(Ok(weight), Ok(bias)) => (weight, bias),
(Err(err), _) | (_, Err(err)) => {
if let (Ok(weight), Ok(bias)) = (
vb.get(size, &format!("{p}.gamma")),
vb.get(size, &format!("{p}.beta")),
) {
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
(weight, bias)
} else {
return Err(err.into());
@ -167,33 +161,29 @@ struct BertEmbeddings {
}
impl BertEmbeddings {
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let word_embeddings = embedding(
config.vocab_size,
config.hidden_size,
&format!("{p}.word_embeddings"),
vb,
vb.pp("word_embeddings"),
)?;
let position_embeddings = embedding(
config.max_position_embeddings,
config.hidden_size,
&format!("{p}.position_embeddings"),
vb,
vb.pp("position_embeddings"),
)?;
let token_type_embeddings = embedding(
config.type_vocab_size,
config.hidden_size,
&format!("{p}.token_type_embeddings"),
vb,
vb.pp("token_type_embeddings"),
)?;
let layer_norm = layer_norm(
config.hidden_size,
config.layer_norm_eps,
&format!("{p}.LayerNorm"),
vb,
vb.pp("LayerNorm"),
)?;
let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
let position_ids = Tensor::new(&position_ids[..], &vb.device)?.unsqueeze(0)?;
let position_ids = Tensor::new(&position_ids[..], vb.device())?.unsqueeze(0)?;
let token_type_ids = position_ids.zeros_like()?;
Ok(Self {
word_embeddings,
@ -233,14 +223,14 @@ struct BertSelfAttention {
}
impl BertSelfAttention {
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let attention_head_size = config.hidden_size / config.num_attention_heads;
let all_head_size = config.num_attention_heads * attention_head_size;
let dropout = Dropout::new(config.hidden_dropout_prob);
let hidden_size = config.hidden_size;
let query = linear(hidden_size, all_head_size, &format!("{p}.query"), vb)?;
let value = linear(hidden_size, all_head_size, &format!("{p}.value"), vb)?;
let key = linear(hidden_size, all_head_size, &format!("{p}.key"), vb)?;
let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
Ok(Self {
query,
key,
@ -289,18 +279,12 @@ struct BertSelfOutput {
}
impl BertSelfOutput {
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let dense = linear(
config.hidden_size,
config.hidden_size,
&format!("{p}.dense"),
vb,
)?;
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
let layer_norm = layer_norm(
config.hidden_size,
config.layer_norm_eps,
&format!("{p}.LayerNorm"),
vb,
vb.pp("LayerNorm"),
)?;
let dropout = Dropout::new(config.hidden_dropout_prob);
Ok(Self {
@ -324,9 +308,9 @@ struct BertAttention {
}
impl BertAttention {
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let self_attention = BertSelfAttention::load(&format!("{p}.self"), vb, config)?;
let self_output = BertSelfOutput::load(&format!("{p}.output"), vb, config)?;
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
Ok(Self {
self_attention,
self_output,
@ -347,13 +331,8 @@ struct BertIntermediate {
}
impl BertIntermediate {
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let dense = linear(
config.hidden_size,
config.intermediate_size,
&format!("{p}.dense"),
vb,
)?;
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
Ok(Self {
dense,
intermediate_act: config.hidden_act,
@ -375,18 +354,12 @@ struct BertOutput {
}
impl BertOutput {
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let dense = linear(
config.intermediate_size,
config.hidden_size,
&format!("{p}.dense"),
vb,
)?;
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
let layer_norm = layer_norm(
config.hidden_size,
config.layer_norm_eps,
&format!("{p}.LayerNorm"),
vb,
vb.pp("LayerNorm"),
)?;
let dropout = Dropout::new(config.hidden_dropout_prob);
Ok(Self {
@ -411,10 +384,10 @@ struct BertLayer {
}
impl BertLayer {
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
let attention = BertAttention::load(&format!("{p}.attention"), vb, config)?;
let intermediate = BertIntermediate::load(&format!("{p}.intermediate"), vb, config)?;
let output = BertOutput::load(&format!("{p}.output"), vb, config)?;
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let attention = BertAttention::load(vb.pp("attention"), config)?;
let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
let output = BertOutput::load(vb.pp("output"), config)?;
Ok(Self {
attention,
intermediate,
@ -441,12 +414,9 @@ struct BertEncoder {
}
impl BertEncoder {
fn load(p: &str, vb: &VarBuilder, config: &Config) -> Result<Self> {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let layers = (0..config.num_hidden_layers)
.map(|index| {
let p = format!("{p}.layer.{index}");
BertLayer::load(&p, vb, config)
})
.map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config))
.collect::<Result<Vec<_>>>()?;
Ok(BertEncoder { layers })
}
@ -469,17 +439,17 @@ struct BertModel {
}
impl BertModel {
fn load(vb: &VarBuilder, config: &Config) -> Result<Self> {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let (embeddings, encoder) = match (
BertEmbeddings::load("embeddings", vb, config),
BertEncoder::load("encoder", vb, config),
BertEmbeddings::load(vb.pp("embeddings"), config),
BertEncoder::load(vb.pp("encoder"), config),
) {
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
(Err(err), _) | (_, Err(err)) => {
if let Some(model_type) = &config.model_type {
if let (Ok(embeddings), Ok(encoder)) = (
BertEmbeddings::load(&format!("{model_type}.embeddings"), vb, config),
BertEncoder::load(&format!("{model_type}.encoder"), vb, config),
BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config),
) {
(embeddings, encoder)
} else {
@ -493,7 +463,7 @@ impl BertModel {
Ok(Self {
embeddings,
encoder,
device: vb.device.clone(),
device: vb.device().clone(),
})
}
@ -576,7 +546,7 @@ impl Args {
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
let model = BertModel::load(&vb, &config)?;
let model = BertModel::load(vb, &config)?;
Ok((model, tokenizer))
}
}

View File

@ -169,7 +169,7 @@ fn main() -> Result<()> {
let vb = VarBuilder::from_safetensors(weights, DTYPE, &device);
let config = Config::falcon7b();
config.validate()?;
let model = Falcon::load(&vb, config)?;
let model = Falcon::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);

View File

@ -4,27 +4,21 @@ use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
const MAX_SEQ_LEN: usize = 5000;
fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
let bias = if bias {
Some(vb.get(size2, &format!("{p}.bias"))?)
Some(vb.get(size2, "bias")?)
} else {
None
};
Ok(Linear::new(weight, bias))
}
fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
let (weight, bias) = match (
vb.get(size, &format!("{p}.weight")),
vb.get(size, &format!("{p}.bias")),
) {
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
(Ok(weight), Ok(bias)) => (weight, bias),
(Err(err), _) | (_, Err(err)) => {
if let (Ok(weight), Ok(bias)) = (
vb.get(size, &format!("{p}.gamma")),
vb.get(size, &format!("{p}.beta")),
) {
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
(weight, bias)
} else {
return Err(err.into());
@ -50,8 +44,8 @@ impl Dropout {
}
}
fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
@ -164,14 +158,14 @@ struct FalconRotaryEmbedding {
}
impl FalconRotaryEmbedding {
fn load(vb: &VarBuilder, cfg: &Config) -> Result<Self> {
fn load(device: &Device, cfg: &Config) -> Result<Self> {
let head_dim = cfg.head_dim();
let inv_freq: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
.collect();
Ok(Self {
inv_freq: Tensor::new(inv_freq.as_slice(), &vb.device)?,
inv_freq: Tensor::new(inv_freq.as_slice(), device)?,
cache: None,
})
}
@ -237,9 +231,9 @@ struct FalconAttention {
}
impl FalconAttention {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let maybe_rotary = if cfg.rotary() {
let rotary = FalconRotaryEmbedding::load(vb, cfg)?;
let rotary = FalconRotaryEmbedding::load(vb.device(), cfg)?;
Some(rotary)
} else {
None
@ -251,20 +245,8 @@ impl FalconAttention {
} else {
3 * hidden_size
};
let query_key_value = linear(
hidden_size,
qkv_out_dim,
cfg.bias,
&format!("{p}.query_key_value"),
vb,
)?;
let dense = linear(
hidden_size,
hidden_size,
cfg.bias,
&format!("{p}.dense"),
vb,
)?;
let query_key_value = linear(hidden_size, qkv_out_dim, cfg.bias, vb.pp("query_key_value"))?;
let dense = linear(hidden_size, hidden_size, cfg.bias, vb.pp("dense"))?;
Ok(Self {
query_key_value,
dense,
@ -367,11 +349,11 @@ struct FalconMlp {
}
impl FalconMlp {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let h = cfg.hidden_size;
let b = cfg.bias;
let dense_h_to_4h = linear(h, 4 * h, b, &format!("{p}.dense_h_to_4h"), vb)?;
let dense_4h_to_h = linear(4 * h, h, b, &format!("{p}.dense_4h_to_h"), vb)?;
let dense_h_to_4h = linear(h, 4 * h, b, vb.pp("dense_h_to_4h"))?;
let dense_4h_to_h = linear(4 * h, h, b, vb.pp("dense_4h_to_h"))?;
let dropout = Dropout::new(cfg.hidden_dropout);
Ok(Self {
dense_h_to_4h,
@ -397,23 +379,21 @@ struct FalconDecoderLayer {
}
impl FalconDecoderLayer {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let mlp = FalconMlp::load(&format!("{p}.mlp"), vb, cfg)?;
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let mlp = FalconMlp::load(vb.pp("mlp"), cfg)?;
let inp_layernorm = layer_norm(
cfg.hidden_size,
cfg.layer_norm_epsilon,
&format!("{p}.input_layernorm"),
vb,
vb.pp("input_layernorm"),
)?;
let self_attention = FalconAttention::load(&format!("{p}.self_attention"), vb, cfg)?;
let self_attention = FalconAttention::load(vb.pp("self_attention"), cfg)?;
let post_attention_layernorm = if cfg.parallel_attn {
None
} else {
let ln = layer_norm(
cfg.hidden_size,
cfg.layer_norm_epsilon,
&format!("{p}.post_attention_layernorm"),
vb,
vb.pp("post_attention_layernorm"),
)?;
Some(ln)
};
@ -480,23 +460,21 @@ impl Falcon {
&self.config
}
pub fn load(vb: &VarBuilder, cfg: Config) -> Result<Self> {
pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
let word_embeddings = embedding(
cfg.vocab_size,
cfg.hidden_size,
"transformer.word_embeddings",
vb,
vb.pp("transformer.word_embeddings"),
)?;
let blocks = (0..cfg.num_hidden_layers)
.map(|i| FalconDecoderLayer::load(&format!("transformer.h.{i}"), vb, &cfg))
.map(|i| FalconDecoderLayer::load(vb.pp(&format!("transformer.h.{i}")), &cfg))
.collect::<Result<Vec<_>>>()?;
let ln_f = layer_norm(
cfg.hidden_size,
cfg.layer_norm_epsilon,
"transformer.ln_f",
vb,
vb.pp("transformer.ln_f"),
)?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
Ok(Self {
word_embeddings,
blocks,

View File

@ -38,19 +38,19 @@ impl Config {
}
}
fn embedding(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
fn linear(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
let bias = vb.get(size2, &format!("{p}.bias"))?;
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
let bias = vb.get(size2, "bias")?;
Ok(Linear::new(weight, Some(bias)))
}
fn linear_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
Ok(Linear::new(weight, None))
}
@ -59,14 +59,10 @@ fn conv1d(
out_channels: usize,
kernel_size: usize,
config: Conv1dConfig,
p: &str,
vb: &VarBuilder,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight = vb.get(
(out_channels, in_channels, kernel_size),
&format!("{p}.weight"),
)?;
let bias = vb.get(out_channels, &format!("{p}.bias"))?;
let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
let bias = vb.get(out_channels, "bias")?;
Ok(Conv1d::new(weight, Some(bias), config))
}
@ -75,13 +71,9 @@ fn conv1d_no_bias(
out_channels: usize,
kernel_size: usize,
config: Conv1dConfig,
p: &str,
vb: &VarBuilder,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight = vb.get(
(out_channels, in_channels, kernel_size),
&format!("{p}.weight"),
)?;
let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
Ok(Conv1d::new(weight, None, config))
}
@ -100,9 +92,9 @@ impl Dropout {
}
}
fn layer_norm(size: usize, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
let weight = vb.get(size, &format!("{p}.weight"))?;
let bias = vb.get(size, &format!("{p}.bias"))?;
fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
let weight = vb.get(size, "weight")?;
let bias = vb.get(size, "bias")?;
Ok(LayerNorm::new(weight, bias, 1e-5))
}
@ -116,11 +108,11 @@ struct MultiHeadAttention {
}
impl MultiHeadAttention {
fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
let query = linear(n_state, n_state, &format!("{p}.q_proj"), vb)?;
let value = linear(n_state, n_state, &format!("{p}.v_proj"), vb)?;
let key = linear_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?;
let out = linear(n_state, n_state, &format!("{p}.out_proj"), vb)?;
fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
let query = linear(n_state, n_state, vb.pp("q_proj"))?;
let value = linear(n_state, n_state, vb.pp("v_proj"))?;
let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
let out = linear(n_state, n_state, vb.pp("out_proj"))?;
Ok(Self {
query,
key,
@ -179,21 +171,20 @@ struct ResidualAttentionBlock {
}
impl ResidualAttentionBlock {
fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.self_attn"), vb)?;
let attn_ln = layer_norm(n_state, &format!("{p}.self_attn_layer_norm"), vb)?;
fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
let cross_attn = if ca {
let cross_attn =
MultiHeadAttention::load(n_state, n_head, &format!("{p}.encoder_attn"), vb)?;
let cross_attn_ln = layer_norm(n_state, &format!("{p}.encoder_attn_layer_norm"), vb)?;
let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
Some((cross_attn, cross_attn_ln))
} else {
None
};
let n_mlp = n_state * 4;
let mlp_linear1 = linear(n_state, n_mlp, &format!("{p}.fc1"), vb)?;
let mlp_linear2 = linear(n_mlp, n_state, &format!("{p}.fc2"), vb)?;
let mlp_ln = layer_norm(n_state, &format!("{p}.final_layer_norm"), vb)?;
let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
Ok(Self {
attn,
attn_ln,
@ -245,7 +236,7 @@ pub struct AudioEncoder {
}
impl AudioEncoder {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let n_state = cfg.d_model;
let n_head = cfg.encoder_attention_heads;
let n_ctx = cfg.max_source_positions;
@ -257,22 +248,15 @@ impl AudioEncoder {
padding: 1,
stride: 2,
};
let conv1 = conv1d(
cfg.num_mel_bins,
n_state,
3,
cfg1,
&format!("{p}.conv1"),
vb,
)?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(&vb.device)?;
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
let blocks = (0..cfg.encoder_layers)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.layers.{i}"), vb)
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
})
.collect::<Result<Vec<_>>>()?;
let ln_post = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?;
let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
Ok(Self {
conv1,
conv2,
@ -306,23 +290,22 @@ pub struct TextDecoder {
}
impl TextDecoder {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let n_state = cfg.d_model;
let n_head = cfg.decoder_attention_heads;
let n_ctx = cfg.max_target_positions;
let token_embedding = embedding(cfg.vocab_size, n_state, &format!("{p}.embed_tokens"), vb)?;
let positional_embedding =
vb.get((n_ctx, n_state), &format!("{p}.embed_positions.weight"))?;
let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?;
let blocks = (0..cfg.decoder_layers)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.layers.{i}"), vb)
ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}")))
})
.collect::<Result<Vec<_>>>()?;
let ln = layer_norm(n_state, &format!("{p}.layer_norm"), vb)?;
let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
let mask: Vec<_> = (0..n_ctx)
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
.collect();
let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), &vb.device)?;
let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
Ok(Self {
token_embedding,
@ -361,8 +344,8 @@ pub struct Whisper {
impl Whisper {
pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
let encoder = AudioEncoder::load("model.encoder", vb, &config)?;
let decoder = TextDecoder::load("model.decoder", vb, &config)?;
let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
Ok(Self {
encoder,
decoder,

View File

@ -1,53 +1,118 @@
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
use std::collections::HashMap;
use std::sync::Arc;
pub struct VarBuilder<'a> {
safetensors: Option<(HashMap<String, usize>, Vec<SafeTensors<'a>>)>,
struct SafeTensorWithRouting<'a> {
routing: HashMap<String, usize>,
safetensors: Vec<SafeTensors<'a>>,
}
struct TensorData<'a> {
// TODO: Make this part generic, probably via some Box<dyn> to avoid too much generics.
safetensors: Option<SafeTensorWithRouting<'a>>,
pub dtype: DType,
pub device: Device,
}
impl<'a> VarBuilder<'a> {
pub fn from_safetensors(
safetensors: Vec<SafeTensors<'a>>,
dtype: DType,
device: &Device,
) -> Self {
impl<'a> TensorData<'a> {
fn from_safetensors(safetensors: Vec<SafeTensors<'a>>, dtype: DType, device: &Device) -> Self {
let mut routing = HashMap::new();
for (index, sf) in safetensors.iter().enumerate() {
for k in sf.names() {
routing.insert(k.to_string(), index);
}
}
let safetensors = SafeTensorWithRouting {
routing,
safetensors,
};
Self {
safetensors: Some((routing, safetensors)),
safetensors: Some(safetensors),
device: device.clone(),
dtype,
}
}
pub fn zeros(dtype: DType, device: Device) -> Self {
fn zeros(dtype: DType, device: &Device) -> Self {
Self {
safetensors: None,
device,
device: device.clone(),
dtype,
}
}
}
#[derive(Clone)]
pub struct VarBuilder<'a> {
data: Arc<TensorData<'a>>,
path: Vec<String>,
}
impl<'a> VarBuilder<'a> {
/// Create a `VarBuilder` accessing data frome the safetensors storage. The initial path is
/// set to the root path and sub-paths can be created via the `push_prefix` method.
pub fn from_safetensors(st: Vec<SafeTensors<'a>>, dtype: DType, device: &Device) -> Self {
let data = TensorData::from_safetensors(st, dtype, device);
Self {
data: Arc::new(data),
path: vec![],
}
}
pub fn zeros(dtype: DType, device: &Device) -> Self {
let data = TensorData::zeros(dtype, device);
Self {
data: Arc::new(data),
path: vec![],
}
}
pub fn push_prefix(&self, s: &str) -> Self {
let mut path = self.path.clone();
path.push(s.to_string());
Self {
data: self.data.clone(),
path,
}
}
/// Short alias for `push_prefix`.
pub fn pp(&self, s: &str) -> Self {
self.push_prefix(s)
}
pub fn device(&self) -> &Device {
&self.data.device
}
pub fn dtype(&self) -> DType {
self.data.dtype
}
}
impl<'a> VarBuilder<'a> {
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
let data = self.data.as_ref();
let s: Shape = s.into();
match &self.safetensors {
None => Tensor::zeros(s, self.dtype, &self.device),
Some((routing, safetensors)) => {
match &self.data.safetensors {
None => Tensor::zeros(s, data.dtype, &data.device),
Some(SafeTensorWithRouting {
routing,
safetensors,
}) => {
let path = if self.path.is_empty() {
tensor_name.to_string()
} else {
[&self.path.join("."), tensor_name].join(".")
};
// Unwrap or 0 just to let the proper error flow.
let index = routing.get(tensor_name).unwrap_or(&0);
let index = routing.get(&path).unwrap_or(&0);
let tensor = safetensors[*index]
.tensor(tensor_name, &self.device)?
.to_dtype(self.dtype)?;
.tensor(&path, &data.device)?
.to_dtype(data.dtype)?;
if *tensor.shape() != s {
let msg = format!("shape mismatch for {tensor_name}");
Err(candle::Error::UnexpectedShape {
msg,
msg: format!("shape mismatch for {path}"),
expected: s,
got: tensor.shape().clone(),
})?