Fix the prompt for Mistral when using instruct/interactive mode. (#1013)
This commit is contained in:
parent
328167ec04
commit
f6054e9d60
|
@ -50,6 +50,23 @@ enum Which {
|
|||
Mistral7bInstruct,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn is_mistral(&self) -> bool {
|
||||
match self {
|
||||
Self::L7b
|
||||
| Self::L13b
|
||||
| Self::L70b
|
||||
| Self::L7bChat
|
||||
| Self::L13bChat
|
||||
| Self::L70bChat
|
||||
| Self::L7bCode
|
||||
| Self::L13bCode
|
||||
| Self::L34bCode => false,
|
||||
Self::Mistral7b | Self::Mistral7bInstruct => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
|
@ -114,17 +131,10 @@ impl Args {
|
|||
Some(config) => std::path::PathBuf::from(config),
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let repo = match self.which {
|
||||
Which::L7b
|
||||
| Which::L13b
|
||||
| Which::L70b
|
||||
| Which::L7bCode
|
||||
| Which::L13bCode
|
||||
| Which::L34bCode
|
||||
| Which::L7bChat
|
||||
| Which::L13bChat
|
||||
| Which::L70bChat => "hf-internal-testing/llama-tokenizer",
|
||||
Which::Mistral7b | Which::Mistral7bInstruct => "mistralai/Mistral-7B-v0.1",
|
||||
let repo = if self.which.is_mistral() {
|
||||
"mistralai/Mistral-7B-v0.1"
|
||||
} else {
|
||||
"hf-internal-testing/llama-tokenizer"
|
||||
};
|
||||
let api = api.model(repo.to_string());
|
||||
api.get("tokenizer.json")?
|
||||
|
@ -315,7 +325,11 @@ fn main() -> anyhow::Result<()> {
|
|||
prompt.pop();
|
||||
}
|
||||
}
|
||||
prompt
|
||||
if args.which.is_mistral() {
|
||||
format!("[INST] {prompt} [/INST]")
|
||||
} else {
|
||||
prompt
|
||||
}
|
||||
}
|
||||
};
|
||||
print!("{}", &prompt_str);
|
||||
|
@ -351,6 +365,8 @@ fn main() -> anyhow::Result<()> {
|
|||
all_tokens.push(next_token);
|
||||
print_token(next_token, &tokenizer);
|
||||
|
||||
let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
|
||||
|
||||
let start_post_prompt = std::time::Instant::now();
|
||||
for index in 0..to_sample {
|
||||
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
|
||||
|
@ -369,6 +385,9 @@ fn main() -> anyhow::Result<()> {
|
|||
next_token = logits_processor.sample(&logits)?;
|
||||
all_tokens.push(next_token);
|
||||
print_token(next_token, &tokenizer);
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
};
|
||||
}
|
||||
let dt = start_post_prompt.elapsed();
|
||||
println!(
|
||||
|
|
Loading…
Reference in New Issue