Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling
* Update changelog
* rustfmt
* Add tests
* Fix clippy warning
* Fix another clippy error
parent 42da17694a
commit 805bf9ffa7
```diff
@@ -4,6 +4,8 @@ This documents the main changes to the `candle` crate.
 ## v0.2.2 - Unreleased
 
 ### Added
+- Support for `top_p` sampling
+  [819](https://github.com/huggingface/candle/pull/819).
 
 ### Modified
 
```
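For reference, here is a minimal sketch of how the extended API is called after this change. It is not part of the commit; the seed, temperature, and cutoff values are illustrative, and the toy logits tensor stands in for real model output.

```rust
use candle::{Device, Result, Tensor};
use candle_transformers::generation::LogitsProcessor;

fn main() -> Result<()> {
    // `LogitsProcessor::new` now takes (seed, temperature, top_p);
    // passing None for top_p keeps the previous sampling behavior.
    let mut processor = LogitsProcessor::new(299792458, Some(0.8), Some(0.9));
    let logits = Tensor::new(&[0.1f32, 0.2, 0.3, 0.4], &Device::Cpu)?;
    let token = processor.sample(&logits)?;
    println!("sampled token id: {token}");
    Ok(())
}
```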
```diff
@@ -28,9 +28,10 @@ impl TextGeneration {
         tokenizer: Tokenizer,
         seed: u64,
         temp: Option<f64>,
+        top_p: Option<f64>,
         device: &Device,
     ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp);
+        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
         Self {
             model,
             tokenizer,
@@ -94,6 +95,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -149,7 +154,14 @@ fn main() -> Result<()> {
     let model = GPTBigCode::load(vb, config)?;
     println!("loaded the model in {:?}", start.elapsed());
 
-    let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
+    let mut pipeline = TextGeneration::new(
+        model,
+        tokenizer,
+        args.seed,
+        args.temperature,
+        args.top_p,
+        &device,
+    );
     pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())
 }
```
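The same plumbing repeats in the example binaries below: each one gains a `--top-p` command-line flag and threads `args.top_p` through to `LogitsProcessor::new`.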
```diff
@@ -25,17 +25,25 @@ struct TextGeneration {
     repeat_last_n: usize,
 }
 
+struct GenerationOptions {
+    temp: Option<f64>,
+    top_p: Option<f64>,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+}
+
 impl TextGeneration {
     fn new(
         model: Falcon,
         tokenizer: Tokenizer,
+        generation_options: GenerationOptions,
         seed: u64,
-        temp: Option<f64>,
         device: &Device,
-        repeat_penalty: f32,
-        repeat_last_n: usize,
     ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp);
+        let logits_processor =
+            LogitsProcessor::new(seed, generation_options.temp, generation_options.top_p);
+        let repeat_penalty = generation_options.repeat_penalty;
+        let repeat_last_n = generation_options.repeat_last_n;
         Self {
             model,
             tokenizer,
```
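Rather than growing the constructor by yet another positional argument, the falcon example bundles the sampling knobs into a `GenerationOptions` struct and unpacks it inside `new`. This keeps the signature short, presumably also heading off clippy's too-many-arguments lint (the commit message mentions two clippy fixes).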
```diff
@@ -118,6 +126,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -185,15 +197,14 @@ fn main() -> Result<()> {
     let model = Falcon::load(vb, config)?;
     println!("loaded the model in {:?}", start.elapsed());
 
-    let mut pipeline = TextGeneration::new(
-        model,
-        tokenizer,
-        args.seed,
-        args.temperature,
-        &device,
-        args.repeat_penalty,
-        args.repeat_last_n,
-    );
+    let generation_options = GenerationOptions {
+        temp: args.temperature,
+        top_p: args.top_p,
+        repeat_penalty: args.repeat_penalty,
+        repeat_last_n: args.repeat_last_n,
+    };
+    let mut pipeline =
+        TextGeneration::new(model, tokenizer, generation_options, args.seed, &device);
     pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())
 }
```
```diff
@@ -42,6 +42,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -193,7 +197,7 @@ fn main() -> Result<()> {
 
     println!("starting the inference loop");
     print!("{prompt}");
-    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
+    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
     let start_gen = std::time::Instant::now();
     let mut index_pos = 0;
     let mut token_generated = 0;
```
```diff
@@ -27,6 +27,10 @@ struct InferenceCmd {
     #[arg(long)]
     temperature: Option<f64>,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     #[arg(long, default_value = "")]
     prompt: String,
 
@@ -133,6 +137,7 @@ fn main() -> anyhow::Result<()> {
         None => {
             let cmd = InferenceCmd {
                 temperature: None,
+                top_p: None,
                 prompt: "".to_string(),
                 config: None,
                 model_id: "karpathy/tinyllamas".to_string(),
@@ -256,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
     let model = Llama::load(vb, &cache, config)?;
 
     println!("starting the inference loop");
-    let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
+    let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
     let mut index_pos = 0;
 
     print!("{}", args.prompt);
```
```diff
@@ -71,6 +71,10 @@ struct Args {
     #[arg(long, default_value_t = 0.8)]
     temperature: f64,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -310,7 +314,7 @@ fn main() -> anyhow::Result<()> {
         prompt_tokens
     };
     let mut all_tokens = vec![];
-    let mut logits_processor = LogitsProcessor::new(args.seed, temperature);
+    let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
 
     let start_prompt_processing = std::time::Instant::now();
     let mut next_token = {
```
```diff
@@ -4,32 +4,76 @@ use rand::{distributions::Distribution, SeedableRng};
 pub struct LogitsProcessor {
     rng: rand::rngs::StdRng,
     temperature: Option<f64>,
+    top_p: Option<f64>,
 }
 
 impl LogitsProcessor {
-    pub fn new(seed: u64, temperature: Option<f64>) -> Self {
+    pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
         Self {
             rng: rand::rngs::StdRng::seed_from_u64(seed),
             temperature,
+            top_p,
         }
     }
 
+    fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
+        let logits_v: Vec<f32> = logits.to_vec1()?;
+        let next_token = logits_v
+            .iter()
+            .enumerate()
+            .max_by(|(_, u), (_, v)| u.total_cmp(v))
+            .map(|(i, _)| i as u32)
+            .unwrap();
+        Ok(next_token)
+    }
+
+    fn sample_multi(&mut self, prs: &Vec<f32>) -> Result<u32> {
+        let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
+        let next_token = distr.sample(&mut self.rng) as u32;
+        Ok(next_token)
+    }
+
+    fn sample_topp(&mut self, prs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
+        // top-p sampling (or "nucleus sampling") samples from the smallest set of
+        // tokens that exceed probability top_p. This way we never sample tokens that
+        // have very low probabilities and are less likely to go "off the rails".
+        let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
+
+        // Sort by descending probability.
+        argsort_indices.sort_by(|&i, &j| prs[j].partial_cmp(&prs[i]).unwrap());
+
+        // Clamp smaller probabilities to zero.
+        let mut cumsum = 0.;
+        for index in &argsort_indices {
+            if cumsum >= top_p {
+                prs[*index] = 0.0;
+            } else {
+                cumsum += prs[*index];
+            }
+        }
+
+        // Sample with clamped probabilities.
+        let next_token = self.sample_multi(prs)?;
+        Ok(next_token)
+    }
+
     pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
         let logits = logits.to_dtype(DType::F32)?;
         let temperature = self.temperature.unwrap_or(0.);
-        let next_token = if temperature > 0. {
-            let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?;
-            let prs: Vec<f32> = prs.to_vec1()?;
-            let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
-            distr.sample(&mut self.rng) as u32
+        let top_p = self.top_p.unwrap_or(1.);
+        let next_token = if temperature == 0. {
+            self.sample_argmax(logits)?
         } else {
-            let logits_v: Vec<f32> = logits.to_vec1()?;
-            logits_v
-                .iter()
-                .enumerate()
-                .max_by(|(_, u), (_, v)| u.total_cmp(v))
-                .map(|(i, _)| i as u32)
-                .unwrap()
+            let logits = &(&logits / temperature)?;
+            let prs = candle_nn::ops::softmax(logits, D::Minus1)?;
+            let mut prs: Vec<f32> = prs.to_vec1()?;
+            if top_p <= 0.0 || top_p >= 1.0 {
+                // simply sample from the predicted probability distribution
+                self.sample_multi(&prs)?
+            } else {
+                // top-p (nucleus) sampling, clamping the least likely tokens to zero
+                self.sample_topp(&mut prs, top_p as f32)?
+            }
         };
         Ok(next_token)
     }
```
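To make the cutoff concrete, here is a small dependency-free sketch, not part of the commit, that runs the same clamping loop as `sample_topp` on the toy distribution used by the tests below:

```rust
// Nucleus cutoff on a toy distribution, mirroring the loop in `sample_topp`.
fn main() {
    // Softmax of the logits [0.1, 0.2, 0.3, 0.4] at temperature 1.0.
    let logits = [0.1f32, 0.2, 0.3, 0.4];
    let exps: Vec<f32> = logits.iter().map(|l| l.exp()).collect();
    let total: f32 = exps.iter().sum();
    let mut prs: Vec<f32> = exps.iter().map(|e| e / total).collect();

    let top_p = 0.5f32;

    // Sort indices by descending probability.
    let mut argsort_indices: Vec<usize> = (0..prs.len()).collect();
    argsort_indices.sort_by(|&i, &j| prs[j].partial_cmp(&prs[i]).unwrap());

    // Accumulate mass until it reaches top_p, then clamp the rest to zero.
    let mut cumsum = 0.;
    for &index in &argsort_indices {
        if cumsum >= top_p {
            prs[index] = 0.0;
        } else {
            cumsum += prs[index];
        }
    }

    // Prints roughly [0.0, 0.0, 0.261, 0.289]: only tokens 2 and 3 survive,
    // and `WeightedIndex` renormalizes over the survivors when sampling.
    println!("{prs:?}");
}
```

Note that `sample` treats a `top_p` outside the open interval (0, 1) as "no cutoff" and falls back to plain weighted sampling; the wasm front ends below map such values to `None` at the boundary for the same reason.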
```diff
@@ -0,0 +1,29 @@
+use candle::{Device, Result, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+
+#[test]
+fn sample_with_zero_temperature() -> Result<()> {
+    let mut logits_process = LogitsProcessor::new(1337, None, None);
+    let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+    let token = logits_process.sample(&logits)?;
+    assert_eq!(token, 3);
+    Ok(())
+}
+
+#[test]
+fn sample_with_temperature() -> Result<()> {
+    let mut logits_process = LogitsProcessor::new(42, Some(0.9), None);
+    let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+    let token = logits_process.sample(&logits)?;
+    assert_eq!(token, 0);
+    Ok(())
+}
+
+#[test]
+fn sample_with_top_p() -> Result<()> {
+    let mut logits_process = LogitsProcessor::new(42, Some(1.0), Some(0.5));
+    let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+    let token = logits_process.sample(&logits)?;
+    assert_eq!(token, 2);
+    Ok(())
+}
```
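The `sample_with_top_p` expectation follows from the arithmetic: softmax of `[0.1, 0.2, 0.3, 0.4]` at temperature 1.0 is roughly `[0.21, 0.24, 0.26, 0.29]`, so a cutoff of 0.5 keeps only tokens 3 and 2 (cumulative mass 0.29, then 0.55) and the sampled token must be one of those two; with seed 42 it comes out as token 2.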
```diff
@@ -56,6 +56,7 @@
       const weightsURL = `${MODELS_BASE_URL}/${model.url}`;
       const prompt = getValue("prompt");
       const temperature = getValue("temperature");
+      const topP = getValue("top-p");
       const repeatPenalty = getValue("repeat_penalty");
       const seed = getValue("seed");
       const maxSeqLen = getValue("max-seq");
@@ -99,6 +100,7 @@
         tokenizerURL: "tokenizer.json",
         prompt,
         temp: temperature,
+        top_p: topP,
         repeatPenalty,
         seed: BigInt(seed),
         maxSeqLen,
@@ -251,7 +253,7 @@
         <input
           type="range"
           id="max-seq"
-          name="temperature"
+          name="max-seq"
           min="1"
           max="256"
           step="1"
@@ -279,6 +281,22 @@
         >
           0.50</output
         >
+        <label class="text-sm font-medium" for="top-p">Top-p</label>
+        <input
+          type="range"
+          id="top-p"
+          name="top-p"
+          min="0"
+          max="1"
+          step="0.01"
+          value="1.00"
+          oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
+        />
+        <output
+          class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
+        >
+          1.00</output
+        >
 
         <label class="text-sm font-medium" for="repeat_penalty"
           >Repeat Penalty</label
```
```diff
@@ -46,6 +46,7 @@ pub struct App {
     status: String,
     loaded: bool,
     temperature: std::rc::Rc<std::cell::RefCell<f64>>,
+    top_p: std::rc::Rc<std::cell::RefCell<f64>>,
     prompt: std::rc::Rc<std::cell::RefCell<String>>,
     generated: String,
     n_tokens: usize,
@@ -81,6 +82,7 @@ impl Component for App {
             status,
             n_tokens: 0,
             temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)),
+            top_p: std::rc::Rc::new(std::cell::RefCell::new(1.0)),
             prompt: std::rc::Rc::new(std::cell::RefCell::new("".to_string())),
             generated: String::new(),
             current_decode: None,
@@ -122,10 +124,11 @@ impl Component for App {
                 self.n_tokens = 0;
                 self.generated.clear();
                 let temp = *self.temperature.borrow();
+                let top_p = *self.top_p.borrow();
                 let prompt = self.prompt.borrow().clone();
-                console_log!("temp: {}, prompt: {}", temp, prompt);
+                console_log!("temp: {}, top_p: {}, prompt: {}", temp, top_p, prompt);
                 ctx.link()
-                    .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, prompt)))
+                    .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, top_p, prompt)))
             }
             true
         }
@@ -177,13 +180,21 @@ impl Component for App {
     fn view(&self, ctx: &Context<Self>) -> Html {
         use yew::TargetCast;
         let temperature = self.temperature.clone();
-        let oninput = ctx.link().callback(move |e: yew::InputEvent| {
+        let oninput_temperature = ctx.link().callback(move |e: yew::InputEvent| {
            let input: web_sys::HtmlInputElement = e.target_unchecked_into();
             if let Ok(temp) = f64::from_str(&input.value()) {
                 *temperature.borrow_mut() = temp
             }
             Msg::Refresh
         });
+        let top_p = self.top_p.clone();
+        let oninput_top_p = ctx.link().callback(move |e: yew::InputEvent| {
+            let input: web_sys::HtmlInputElement = e.target_unchecked_into();
+            if let Ok(top_p_input) = f64::from_str(&input.value()) {
+                *top_p.borrow_mut() = top_p_input
+            }
+            Msg::Refresh
+        });
         let prompt = self.prompt.clone();
         let oninput_prompt = ctx.link().callback(move |e: yew::InputEvent| {
             let input: web_sys::HtmlInputElement = e.target_unchecked_into();
@@ -201,9 +212,13 @@ impl Component for App {
                 </p>
             </div>
             {"temperature \u{00a0} "}
-            <input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/>
+            <input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} oninput={oninput_temperature} id="temp"/>
             {format!(" \u{00a0} {}", self.temperature.borrow())}
             <br/ >
+            {"top_p \u{00a0} "}
+            <input type="range" min="0." max="1.0" step="0.05" value={self.top_p.borrow().to_string()} oninput={oninput_top_p} id="top_p"/>
+            {format!(" \u{00a0} {}", self.top_p.borrow())}
+            <br/ >
             {"prompt: "}<input type="text" value={self.prompt.borrow().to_string()} oninput={oninput_prompt} id="prompt"/>
             <br/ >
             {
```
```diff
@@ -47,7 +47,7 @@ impl Model {
             tokenizer,
             model: weights,
         });
-        let logits_processor = LogitsProcessor::new(299792458, None);
+        let logits_processor = LogitsProcessor::new(299792458, None, None);
         match model {
             Ok(inner) => Ok(Self {
                 inner,
@@ -69,6 +69,7 @@ impl Model {
         &mut self,
         prompt: String,
         temp: f64,
+        top_p: f64,
         repeat_penalty: f32,
         seed: u64,
     ) -> Result<String, JsError> {
@@ -80,7 +81,12 @@ impl Model {
             }
         }
         let temp = if temp <= 0. { None } else { Some(temp) };
-        self.logits_processor = LogitsProcessor::new(seed, temp);
+        let top_p = if top_p <= 0. || top_p >= 1. {
+            None
+        } else {
+            Some(top_p)
+        };
+        self.logits_processor = LogitsProcessor::new(seed, temp, top_p);
         self.repeat_penalty = repeat_penalty;
         self.tokens.clear();
         let tokens = self
```
```diff
@@ -62,12 +62,18 @@ impl Model {
         link: &WorkerLink<Worker>,
         id: HandlerId,
         temp: f64,
+        top_p: f64,
         prompt: String,
     ) -> Result<()> {
         let dev = Device::Cpu;
         let temp = if temp <= 0. { None } else { Some(temp) };
-        console_log!("{temp:?} {prompt}");
-        let mut logits_processor = LogitsProcessor::new(299792458, temp);
+        let top_p = if top_p <= 0. || top_p >= 1.0 {
+            None
+        } else {
+            Some(top_p)
+        };
+        console_log!("temp: {temp:?} top_p: {top_p:?} prompt: {prompt}");
+        let mut logits_processor = LogitsProcessor::new(299792458, temp, top_p);
         let mut index_pos = 0;
         let mut tokens = self
             .tokenizer
@@ -268,7 +274,7 @@ pub struct Worker {
 #[derive(Serialize, Deserialize)]
 pub enum WorkerInput {
     ModelData(ModelData),
-    Run(f64, String),
+    Run(f64, f64, String),
 }
 
 #[derive(Serialize, Deserialize)]
@@ -301,7 +307,7 @@ impl yew_agent::Worker for Worker {
                 }
                 Err(err) => Err(format!("model creation error {err:?}")),
             },
-            WorkerInput::Run(temp, prompt) => match &mut self.model {
+            WorkerInput::Run(temp, top_p, prompt) => match &mut self.model {
                 None => Err("model has not been set yet".to_string()),
                 Some(model) => {
                     {
@@ -311,7 +317,7 @@ impl yew_agent::Worker for Worker {
                     }
                 }
                 let result = model
-                    .run(&self.link, id, temp, prompt)
+                    .run(&self.link, id, temp, top_p, prompt)
                     .map_err(|e| e.to_string());
                 Ok(WorkerOutput::GenerationDone(result))
             }
```