Implement top_p / nucleus sampling (#819)

* Implement top_p / nucleus sampling

* Update changelog

* rustfmt

* Add tests

* Fix clippy warning

* Fix another clippy error
Juarez Bochi authored 2023-09-12 09:10:16 -07:00, committed by GitHub
parent 42da17694a
commit 805bf9ffa7
12 changed files with 199 additions and 43 deletions
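
In short, `LogitsProcessor::new` now takes an optional nucleus cutoff alongside the seed and temperature, and `sample` bypasses nucleus clamping whenever `top_p` is unset or falls outside (0, 1). A minimal sketch of the extended API on a small CPU tensor of raw logits:

use candle::{Device, Result, Tensor};
use candle_transformers::generation::LogitsProcessor;

fn sample_once() -> Result<u32> {
    // Seed, temperature, and nucleus cutoff; pass None to disable either knob.
    let mut processor = LogitsProcessor::new(299792458, Some(0.8), Some(0.9));
    let logits = Tensor::new(&[0.1f32, 0.2, 0.3, 0.4], &Device::Cpu)?;
    processor.sample(&logits)
}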


@@ -4,6 +4,8 @@ This documents the main changes to the `candle` crate.
 ## v0.2.2 - Unreleased
 ### Added
+- Support for `top_p` sampling
+  [819](https://github.com/huggingface/candle/pull/819).
 ### Modified


@@ -28,9 +28,10 @@ impl TextGeneration {
         tokenizer: Tokenizer,
         seed: u64,
         temp: Option<f64>,
+        top_p: Option<f64>,
         device: &Device,
     ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp);
+        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
         Self {
             model,
             tokenizer,
@@ -94,6 +95,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,

+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -149,7 +154,14 @@ fn main() -> Result<()> {
     let model = GPTBigCode::load(vb, config)?;
     println!("loaded the model in {:?}", start.elapsed());

-    let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
+    let mut pipeline = TextGeneration::new(
+        model,
+        tokenizer,
+        args.seed,
+        args.temperature,
+        args.top_p,
+        &device,
+    );
     pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())
 }


@@ -25,17 +25,25 @@ struct TextGeneration {
     repeat_last_n: usize,
 }

+struct GenerationOptions {
+    temp: Option<f64>,
+    top_p: Option<f64>,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+}
+
 impl TextGeneration {
     fn new(
         model: Falcon,
         tokenizer: Tokenizer,
+        generation_options: GenerationOptions,
         seed: u64,
-        temp: Option<f64>,
         device: &Device,
-        repeat_penalty: f32,
-        repeat_last_n: usize,
     ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp);
+        let logits_processor =
+            LogitsProcessor::new(seed, generation_options.temp, generation_options.top_p);
+        let repeat_penalty = generation_options.repeat_penalty;
+        let repeat_last_n = generation_options.repeat_last_n;
         Self {
             model,
             tokenizer,
@@ -118,6 +126,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,

+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -185,15 +197,14 @@ fn main() -> Result<()> {
     let model = Falcon::load(vb, config)?;
     println!("loaded the model in {:?}", start.elapsed());

-    let mut pipeline = TextGeneration::new(
-        model,
-        tokenizer,
-        args.seed,
-        args.temperature,
-        &device,
-        args.repeat_penalty,
-        args.repeat_last_n,
-    );
+    let generation_options = GenerationOptions {
+        temp: args.temperature,
+        top_p: args.top_p,
+        repeat_penalty: args.repeat_penalty,
+        repeat_last_n: args.repeat_last_n,
+    };
+    let mut pipeline =
+        TextGeneration::new(model, tokenizer, generation_options, args.seed, &device);
     pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())
 }
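
Folding `temp`, `top_p`, `repeat_penalty`, and `repeat_last_n` into a `GenerationOptions` struct keeps `TextGeneration::new` at five parameters; adding `top_p` as an eighth positional argument would have pushed the constructor past clippy's `too_many_arguments` threshold, which is likely the clippy warning the commit message mentions fixing.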


@@ -42,6 +42,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,

+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -193,7 +197,7 @@ fn main() -> Result<()> {
     println!("starting the inference loop");
     print!("{prompt}");
-    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
+    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
     let start_gen = std::time::Instant::now();
     let mut index_pos = 0;
     let mut token_generated = 0;


@@ -27,6 +27,10 @@ struct InferenceCmd {
     #[arg(long)]
     temperature: Option<f64>,

+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     #[arg(long, default_value = "")]
     prompt: String,
@@ -133,6 +137,7 @@ fn main() -> anyhow::Result<()> {
         None => {
             let cmd = InferenceCmd {
                 temperature: None,
+                top_p: None,
                 prompt: "".to_string(),
                 config: None,
                 model_id: "karpathy/tinyllamas".to_string(),
@@ -256,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
     let model = Llama::load(vb, &cache, config)?;
     println!("starting the inference loop");
-    let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
+    let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
     let mut index_pos = 0;
     print!("{}", args.prompt);


@@ -71,6 +71,10 @@ struct Args {
     #[arg(long, default_value_t = 0.8)]
     temperature: f64,

+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -310,7 +314,7 @@ fn main() -> anyhow::Result<()> {
         prompt_tokens
     };
     let mut all_tokens = vec![];
-    let mut logits_processor = LogitsProcessor::new(args.seed, temperature);
+    let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
     let start_prompt_processing = std::time::Instant::now();
     let mut next_token = {


@@ -4,32 +4,76 @@ use rand::{distributions::Distribution, SeedableRng};
 pub struct LogitsProcessor {
     rng: rand::rngs::StdRng,
     temperature: Option<f64>,
+    top_p: Option<f64>,
 }

 impl LogitsProcessor {
-    pub fn new(seed: u64, temperature: Option<f64>) -> Self {
+    pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
         Self {
             rng: rand::rngs::StdRng::seed_from_u64(seed),
             temperature,
+            top_p,
         }
     }

+    fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
+        let logits_v: Vec<f32> = logits.to_vec1()?;
+        let next_token = logits_v
+            .iter()
+            .enumerate()
+            .max_by(|(_, u), (_, v)| u.total_cmp(v))
+            .map(|(i, _)| i as u32)
+            .unwrap();
+        Ok(next_token)
+    }
+
+    fn sample_multi(&mut self, prs: &Vec<f32>) -> Result<u32> {
+        let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
+        let next_token = distr.sample(&mut self.rng) as u32;
+        Ok(next_token)
+    }
+
+    fn sample_topp(&mut self, prs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
+        // Top-p sampling (or "nucleus sampling") samples from the smallest set of
+        // tokens whose cumulative probability exceeds top_p. This way we never
+        // sample tokens with very low probabilities, so generation is less likely
+        // to go "off the rails".
+        let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
+        // Sort by descending probability.
+        argsort_indices.sort_by(|&i, &j| prs[j].partial_cmp(&prs[i]).unwrap());
+        // Clamp smaller probabilities to zero.
+        let mut cumsum = 0.;
+        for index in &argsort_indices {
+            if cumsum >= top_p {
+                prs[*index] = 0.0;
+            } else {
+                cumsum += prs[*index];
+            }
+        }
+        // Sample with clamped probabilities.
+        let next_token = self.sample_multi(prs)?;
+        Ok(next_token)
+    }
+
     pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
         let logits = logits.to_dtype(DType::F32)?;
         let temperature = self.temperature.unwrap_or(0.);
-        let next_token = if temperature > 0. {
-            let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?;
-            let prs: Vec<f32> = prs.to_vec1()?;
-            let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
-            distr.sample(&mut self.rng) as u32
+        let top_p = self.top_p.unwrap_or(1.);
+        let next_token = if temperature == 0. {
+            self.sample_argmax(logits)?
         } else {
-            let logits_v: Vec<f32> = logits.to_vec1()?;
-            logits_v
-                .iter()
-                .enumerate()
-                .max_by(|(_, u), (_, v)| u.total_cmp(v))
-                .map(|(i, _)| i as u32)
-                .unwrap()
+            let logits = &(&logits / temperature)?;
+            let prs = candle_nn::ops::softmax(logits, D::Minus1)?;
+            let mut prs: Vec<f32> = prs.to_vec1()?;
+            if top_p <= 0.0 || top_p >= 1.0 {
+                // Simply sample from the predicted probability distribution.
+                self.sample_multi(&prs)?
+            } else {
+                // Top-p (nucleus) sampling, clamping the least likely tokens to zero.
+                self.sample_topp(&mut prs, top_p as f32)?
+            }
         };
         Ok(next_token)
     }
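
Traced on a toy distribution, the clamping loop in `sample_topp` is easy to follow. The sketch below copies the loop into a standalone function (hypothetical `clamp_top_p` name) and checks the surviving nucleus by hand:

// Standalone copy of the nucleus-clamping step, for illustration only.
fn clamp_top_p(prs: &mut [f32], top_p: f32) {
    let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
    // Sort indices by descending probability.
    argsort_indices.sort_by(|&i, &j| prs[j].partial_cmp(&prs[i]).unwrap());
    let mut cumsum = 0.;
    for &index in &argsort_indices {
        if cumsum >= top_p {
            prs[index] = 0.0; // outside the nucleus: clamped to zero
        } else {
            cumsum += prs[index];
        }
    }
}

fn main() {
    // With probabilities [0.1, 0.2, 0.3, 0.4] and top_p = 0.5, the descending
    // walk keeps 0.4 (cumsum 0.4 < 0.5) and 0.3 (included because cumsum only
    // reaches 0.7 after it), then zeroes 0.2 and 0.1.
    let mut prs = vec![0.1f32, 0.2, 0.3, 0.4];
    clamp_top_p(&mut prs, 0.5);
    assert_eq!(prs, vec![0.0, 0.0, 0.3, 0.4]);
    // WeightedIndex renormalizes implicitly, so sampling then picks token 3
    // with probability 4/7 and token 2 with probability 3/7.
}

Note that the token which crosses the `top_p` boundary stays in the nucleus, which is what makes the kept set the smallest one whose cumulative probability exceeds `top_p`.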


@@ -0,0 +1,29 @@
+use candle::{Device, Result, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+
+#[test]
+fn sample_with_zero_temperature() -> Result<()> {
+    let mut logits_process = LogitsProcessor::new(1337, None, None);
+    let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+    let token = logits_process.sample(&logits)?;
+    assert_eq!(token, 3);
+    Ok(())
+}
+
+#[test]
+fn sample_with_temperature() -> Result<()> {
+    let mut logits_process = LogitsProcessor::new(42, Some(0.9), None);
+    let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+    let token = logits_process.sample(&logits)?;
+    assert_eq!(token, 0);
+    Ok(())
+}
+
+#[test]
+fn sample_with_top_p() -> Result<()> {
+    let mut logits_process = LogitsProcessor::new(42, Some(1.0), Some(0.5));
+    let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+    let token = logits_process.sample(&logits)?;
+    assert_eq!(token, 2);
+    Ok(())
+}
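
The `sample_with_top_p` expectation follows from the same arithmetic: softmax over the logits [0.1, 0.2, 0.3, 0.4] gives roughly [0.21, 0.24, 0.26, 0.29], the descending walk crosses the 0.5 cutoff once tokens 3 and 2 are in the nucleus, every other probability is clamped to zero, and seed 42 then draws token 2 from the two survivors.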


@@ -56,6 +56,7 @@
       const weightsURL = `${MODELS_BASE_URL}/${model.url}`;
       const prompt = getValue("prompt");
       const temperature = getValue("temperature");
+      const topP = getValue("top-p");
       const repeatPenalty = getValue("repeat_penalty");
       const seed = getValue("seed");
       const maxSeqLen = getValue("max-seq");
@@ -99,6 +100,7 @@
         tokenizerURL: "tokenizer.json",
         prompt,
         temp: temperature,
+        top_p: topP,
         repeatPenalty,
         seed: BigInt(seed),
         maxSeqLen,
@@ -251,7 +253,7 @@
         <input
           type="range"
           id="max-seq"
-          name="temperature"
+          name="max-seq"
           min="1"
           max="256"
           step="1"
@@ -279,6 +281,22 @@
         >
           0.50</output
         >
+        <label class="text-sm font-medium" for="top-p">Top-p</label>
+        <input
+          type="range"
+          id="top-p"
+          name="top-p"
+          min="0"
+          max="1"
+          step="0.01"
+          value="1.00"
+          oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
+        />
+        <output
+          class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
+        >
+          1.00</output
+        >
         <label class="text-sm font-medium" for="repeat_penalty"
           >Repeat Penalty</label


@@ -46,6 +46,7 @@ pub struct App {
     status: String,
     loaded: bool,
     temperature: std::rc::Rc<std::cell::RefCell<f64>>,
+    top_p: std::rc::Rc<std::cell::RefCell<f64>>,
     prompt: std::rc::Rc<std::cell::RefCell<String>>,
     generated: String,
     n_tokens: usize,
@@ -81,6 +82,7 @@ impl Component for App {
             status,
             n_tokens: 0,
             temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)),
+            top_p: std::rc::Rc::new(std::cell::RefCell::new(1.0)),
             prompt: std::rc::Rc::new(std::cell::RefCell::new("".to_string())),
             generated: String::new(),
             current_decode: None,
@@ -122,10 +124,11 @@ impl Component for App {
                     self.n_tokens = 0;
                     self.generated.clear();
                     let temp = *self.temperature.borrow();
+                    let top_p = *self.top_p.borrow();
                     let prompt = self.prompt.borrow().clone();
-                    console_log!("temp: {}, prompt: {}", temp, prompt);
+                    console_log!("temp: {}, top_p: {}, prompt: {}", temp, top_p, prompt);
                     ctx.link()
-                        .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, prompt)))
+                        .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, top_p, prompt)))
                 }
                 true
             }
@@ -177,13 +180,21 @@ impl Component for App {
     fn view(&self, ctx: &Context<Self>) -> Html {
         use yew::TargetCast;
         let temperature = self.temperature.clone();
-        let oninput = ctx.link().callback(move |e: yew::InputEvent| {
+        let oninput_temperature = ctx.link().callback(move |e: yew::InputEvent| {
             let input: web_sys::HtmlInputElement = e.target_unchecked_into();
             if let Ok(temp) = f64::from_str(&input.value()) {
                 *temperature.borrow_mut() = temp
             }
             Msg::Refresh
         });
+        let top_p = self.top_p.clone();
+        let oninput_top_p = ctx.link().callback(move |e: yew::InputEvent| {
+            let input: web_sys::HtmlInputElement = e.target_unchecked_into();
+            if let Ok(top_p_input) = f64::from_str(&input.value()) {
+                *top_p.borrow_mut() = top_p_input
+            }
+            Msg::Refresh
+        });
         let prompt = self.prompt.clone();
         let oninput_prompt = ctx.link().callback(move |e: yew::InputEvent| {
             let input: web_sys::HtmlInputElement = e.target_unchecked_into();
@@ -201,9 +212,13 @@ impl Component for App {
                 </p>
             </div>
             {"temperature \u{00a0} "}
-            <input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/>
+            <input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} oninput={oninput_temperature} id="temp"/>
             {format!(" \u{00a0} {}", self.temperature.borrow())}
             <br/ >
+            {"top_p \u{00a0} "}
+            <input type="range" min="0." max="1.0" step="0.05" value={self.top_p.borrow().to_string()} oninput={oninput_top_p} id="top_p"/>
+            {format!(" \u{00a0} {}", self.top_p.borrow())}
+            <br/ >
             {"prompt: "}<input type="text" value={self.prompt.borrow().to_string()} oninput={oninput_prompt} id="prompt"/>
             <br/ >
             {


@@ -47,7 +47,7 @@ impl Model {
             tokenizer,
             model: weights,
         });
-        let logits_processor = LogitsProcessor::new(299792458, None);
+        let logits_processor = LogitsProcessor::new(299792458, None, None);
         match model {
             Ok(inner) => Ok(Self {
                 inner,
@@ -69,6 +69,7 @@ impl Model {
         &mut self,
         prompt: String,
         temp: f64,
+        top_p: f64,
         repeat_penalty: f32,
         seed: u64,
     ) -> Result<String, JsError> {
@@ -80,7 +81,12 @@ impl Model {
             }
         }
         let temp = if temp <= 0. { None } else { Some(temp) };
-        self.logits_processor = LogitsProcessor::new(seed, temp);
+        let top_p = if top_p <= 0. || top_p >= 1. {
+            None
+        } else {
+            Some(top_p)
+        };
+        self.logits_processor = LogitsProcessor::new(seed, temp, top_p);
         self.repeat_penalty = repeat_penalty;
         self.tokens.clear();
         let tokens = self
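
Both wasm entry points (here and in the worker below) map a `top_p` outside the open interval (0, 1) to `None`, matching the fast path in `sample` that bypasses nucleus clamping for such values. The duplicated check could live in one shared helper; a sketch, with `normalize_top_p` as a hypothetical name:

// Hypothetical shared helper: a top_p outside (0, 1) means "nucleus sampling off".
fn normalize_top_p(top_p: f64) -> Option<f64> {
    if top_p <= 0. || top_p >= 1. {
        None
    } else {
        Some(top_p)
    }
}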


@@ -62,12 +62,18 @@ impl Model {
         link: &WorkerLink<Worker>,
         id: HandlerId,
         temp: f64,
+        top_p: f64,
         prompt: String,
     ) -> Result<()> {
         let dev = Device::Cpu;
         let temp = if temp <= 0. { None } else { Some(temp) };
-        console_log!("{temp:?} {prompt}");
-        let mut logits_processor = LogitsProcessor::new(299792458, temp);
+        let top_p = if top_p <= 0. || top_p >= 1.0 {
+            None
+        } else {
+            Some(top_p)
+        };
+        console_log!("temp: {temp:?} top_p: {top_p:?} prompt: {prompt}");
+        let mut logits_processor = LogitsProcessor::new(299792458, temp, top_p);
         let mut index_pos = 0;
         let mut tokens = self
             .tokenizer
@@ -268,7 +274,7 @@ pub struct Worker {
 #[derive(Serialize, Deserialize)]
 pub enum WorkerInput {
     ModelData(ModelData),
-    Run(f64, String),
+    Run(f64, f64, String),
 }

 #[derive(Serialize, Deserialize)]
@@ -301,7 +307,7 @@ impl yew_agent::Worker for Worker {
                 }
                 Err(err) => Err(format!("model creation error {err:?}")),
             },
-            WorkerInput::Run(temp, prompt) => match &mut self.model {
+            WorkerInput::Run(temp, top_p, prompt) => match &mut self.model {
                 None => Err("model has not been set yet".to_string()),
                 Some(model) => {
                     {
@@ -311,7 +317,7 @@ impl yew_agent::Worker for Worker {
                     }
                 }
                 let result = model
-                    .run(&self.link, id, temp, prompt)
+                    .run(&self.link, id, temp, top_p, prompt)
                     .map_err(|e| e.to_string());
                 Ok(WorkerOutput::GenerationDone(result))
             }