Add support for tiny-llama-1.1b. (#1512)
This commit is contained in:
parent
a0facd0e67
commit
1fb2dd905c
|
@ -33,6 +33,8 @@ enum Which {
|
||||||
V2,
|
V2,
|
||||||
#[value(name = "solar-10.7b")]
|
#[value(name = "solar-10.7b")]
|
||||||
Solar10_7B,
|
Solar10_7B,
|
||||||
|
#[value(name = "tiny-llama-1.1b-chat")]
|
||||||
|
TinyLlama1_1BChat,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
|
@ -124,6 +126,7 @@ fn main() -> Result<()> {
|
||||||
Which::V1 => "Narsil/amall-7b".to_string(),
|
Which::V1 => "Narsil/amall-7b".to_string(),
|
||||||
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
|
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
|
||||||
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
|
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
|
||||||
|
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
|
||||||
});
|
});
|
||||||
println!("loading the model weights from {model_id}");
|
println!("loading the model weights from {model_id}");
|
||||||
let revision = args.revision.unwrap_or("main".to_string());
|
let revision = args.revision.unwrap_or("main".to_string());
|
||||||
|
@ -134,8 +137,12 @@ fn main() -> Result<()> {
|
||||||
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||||
let config = config.into_config(args.use_flash_attn);
|
let config = config.into_config(args.use_flash_attn);
|
||||||
|
|
||||||
let filenames =
|
let filenames = match args.which {
|
||||||
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
|
Which::V1 | Which::V2 | Which::Solar10_7B => {
|
||||||
|
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
|
||||||
|
}
|
||||||
|
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
|
||||||
|
};
|
||||||
println!("building the model");
|
println!("building the model");
|
||||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue