Fix the prompt for mistral when using instruct/interactive mode. (#1013)

Laurent Mazare 2023-10-01 06:44:30 +01:00 committed by GitHub
parent 328167ec04
commit f6054e9d60
1 changed file with 31 additions and 12 deletions

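In short: when the selected model is one of the Mistral variants, the interactive prompt is now wrapped in Mistral's [INST] ... [/INST] instruction tags, and generation stops early once the </s> end-of-sequence token is sampled. A minimal standalone sketch of the prompt wrapping, assuming an is_mistral flag like the helper added in the diff (the function name and test strings here are illustrative, not part of the commit):

// Illustrative sketch only: mirrors the prompt handling added in this
// commit, but is not the example's actual code.
fn format_prompt(raw: &str, is_mistral: bool) -> String {
    // read_line keeps the trailing newline; drop it (and a stray '\r').
    let prompt = raw.trim_end_matches('\n').trim_end_matches('\r');
    if is_mistral {
        // Mistral instruct models expect the instruction wrapped in [INST] tags.
        format!("[INST] {prompt} [/INST]")
    } else {
        prompt.to_string()
    }
}

fn main() {
    assert_eq!(
        format_prompt("Write a haiku about autumn.\n", true),
        "[INST] Write a haiku about autumn. [/INST]"
    );
    assert_eq!(format_prompt("Hello\r\n", false), "Hello");
}

The diff below applies the same idea inside the example's interactive loop and adds the matching </s> early exit during sampling.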

@@ -50,6 +50,23 @@ enum Which {
     Mistral7bInstruct,
 }
 
+impl Which {
+    fn is_mistral(&self) -> bool {
+        match self {
+            Self::L7b
+            | Self::L13b
+            | Self::L70b
+            | Self::L7bChat
+            | Self::L13bChat
+            | Self::L70bChat
+            | Self::L7bCode
+            | Self::L13bCode
+            | Self::L34bCode => false,
+            Self::Mistral7b | Self::Mistral7bInstruct => true,
+        }
+    }
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -114,17 +131,10 @@ impl Args {
             Some(config) => std::path::PathBuf::from(config),
             None => {
                 let api = hf_hub::api::sync::Api::new()?;
-                let repo = match self.which {
-                    Which::L7b
-                    | Which::L13b
-                    | Which::L70b
-                    | Which::L7bCode
-                    | Which::L13bCode
-                    | Which::L34bCode
-                    | Which::L7bChat
-                    | Which::L13bChat
-                    | Which::L70bChat => "hf-internal-testing/llama-tokenizer",
-                    Which::Mistral7b | Which::Mistral7bInstruct => "mistralai/Mistral-7B-v0.1",
+                let repo = if self.which.is_mistral() {
+                    "mistralai/Mistral-7B-v0.1"
+                } else {
+                    "hf-internal-testing/llama-tokenizer"
                 };
                 let api = api.model(repo.to_string());
                 api.get("tokenizer.json")?
@@ -315,7 +325,11 @@ fn main() -> anyhow::Result<()> {
                         prompt.pop();
                     }
                 }
-                prompt
+                if args.which.is_mistral() {
+                    format!("[INST] {prompt} [/INST]")
+                } else {
+                    prompt
+                }
             }
         };
         print!("{}", &prompt_str);
@@ -351,6 +365,8 @@
     all_tokens.push(next_token);
     print_token(next_token, &tokenizer);
 
+    let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
+
     let start_post_prompt = std::time::Instant::now();
     for index in 0..to_sample {
         let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
@@ -369,6 +385,9 @@
         next_token = logits_processor.sample(&logits)?;
         all_tokens.push(next_token);
         print_token(next_token, &tokenizer);
+        if next_token == eos_token {
+            break;
+        };
     }
     let dt = start_post_prompt.elapsed();
     println!(