diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index c2ed0e25..46f474bb 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -33,6 +33,8 @@ enum Which {
     V2,
     #[value(name = "solar-10.7b")]
     Solar10_7B,
+    #[value(name = "tiny-llama-1.1b-chat")]
+    TinyLlama1_1BChat,
 }
 
 #[derive(Parser, Debug)]
@@ -124,6 +126,7 @@ fn main() -> Result<()> {
             Which::V1 => "Narsil/amall-7b".to_string(),
             Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
             Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
+            Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
         });
         println!("loading the model weights from {model_id}");
         let revision = args.revision.unwrap_or("main".to_string());
@@ -134,8 +137,12 @@ fn main() -> Result<()> {
         let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
         let config = config.into_config(args.use_flash_attn);
 
-        let filenames =
-            candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
+        let filenames = match args.which {
+            Which::V1 | Which::V2 | Which::Solar10_7B => {
+                candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
+            }
+            Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
+        };
 
         println!("building the model");
         let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
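
The diff adds a new clap value for TinyLlama and, since that checkpoint ships a single `model.safetensors` file rather than a sharded index, matches on `args.which` to fetch the weights accordingly. A hedged usage sketch, assuming the `--which` flag name follows from the `args.which` field under clap's derive conventions:

    cargo run --example llama --release -- --which tiny-llama-1.1b-chat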