Gumbel-Softmax sampling. (#2894)
* Gumbel-Softmax sampling. * Add a sampling test. * Share the gumbel-softmax bits.
This commit is contained in:
parent
a52b76ae82
commit
2653002f29
|
@ -46,7 +46,7 @@ impl TextGeneration {
|
|||
Sampling::ArgMax
|
||||
} else {
|
||||
match (top_k, top_p) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(None, None) => Sampling::GumbelSoftmax { temperature },
|
||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
|
|
|
@ -31,6 +31,7 @@ pub mod ops;
|
|||
pub mod optim;
|
||||
pub mod rnn;
|
||||
pub mod rotary_emb;
|
||||
pub mod sampling;
|
||||
pub mod sequential;
|
||||
pub mod var_builder;
|
||||
pub mod var_map;
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
use candle::{Result, Tensor};
|
||||
|
||||
/// Sample according to the Gumbel-Softmax distribution.
|
||||
pub fn gumbel_softmax<D: candle::shape::Dim>(
|
||||
logits: &Tensor,
|
||||
temperature: f64,
|
||||
dim: D,
|
||||
) -> Result<Tensor> {
|
||||
if temperature <= 0.0 {
|
||||
logits.argmax(dim)
|
||||
} else if temperature == 1.0 {
|
||||
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
|
||||
let sampled = (logits - minus_g)?.argmax(dim)?;
|
||||
Ok(sampled)
|
||||
} else {
|
||||
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
|
||||
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
|
||||
Ok(sampled)
|
||||
}
|
||||
}
|
|
@ -13,6 +13,8 @@ pub enum Sampling {
|
|||
TopK { k: usize, temperature: f64 },
|
||||
TopP { p: f64, temperature: f64 },
|
||||
TopKThenTopP { k: usize, p: f64, temperature: f64 },
|
||||
// Note that the rng is not used for the Gumbel-Softmax sampling.
|
||||
GumbelSoftmax { temperature: f64 },
|
||||
}
|
||||
|
||||
pub struct LogitsProcessor {
|
||||
|
@ -49,6 +51,11 @@ impl LogitsProcessor {
|
|||
Ok(next_token)
|
||||
}
|
||||
|
||||
fn sample_gumbel_softmax(&mut self, logits: &Tensor, temperature: f64) -> Result<u32> {
|
||||
let sampled = candle_nn::sampling::gumbel_softmax(logits, temperature, candle::D::Minus1)?;
|
||||
sampled.to_vec0::<u32>()
|
||||
}
|
||||
|
||||
fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
|
||||
let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
|
||||
let next_token = distr.sample(&mut self.rng) as u32;
|
||||
|
@ -127,6 +134,9 @@ impl LogitsProcessor {
|
|||
|
||||
let next_token = match &self.sampling {
|
||||
Sampling::ArgMax => self.sample_argmax(logits)?,
|
||||
Sampling::GumbelSoftmax { temperature } => {
|
||||
self.sample_gumbel_softmax(&logits, *temperature)?
|
||||
}
|
||||
Sampling::All { temperature } => {
|
||||
let prs = prs(*temperature)?;
|
||||
self.sample_multinomial(&prs)?
|
||||
|
|
|
@ -54,3 +54,25 @@ fn sample_with_top_k() -> Result<()> {
|
|||
assert_eq!(token, 2);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
fn sample_gumbel() -> Result<()> {
    // At temperature 1.0, Gumbel-max sampling should reproduce the softmax
    // distribution over the logits; check empirical frequencies against it.
    let mut logits_process = LogitsProcessor::from_sampling(
        42, // rng seed (unused by GumbelSoftmax, which draws noise on the tensor device)
        candle_transformers::generation::Sampling::GumbelSoftmax { temperature: 1.0 },
    );
    let logits = Tensor::new(&[-1.0, 0.0, 0.2, 1.0], &Device::Cpu)?;
    let sm = candle_nn::ops::softmax(&logits, 0)?.to_vec1::<f64>()?;
    let samples = 100000;
    let mut counts = vec![0f64; 4];
    for _ in 0..samples {
        let token = logits_process.sample(&logits)?;
        counts[token as usize] += 1f64 / samples as f64;
    }
    // Each empirical frequency must be within 0.05 absolute of the softmax
    // probability; with 100k samples this tolerance is loose enough to be
    // deterministic in practice.
    for (empirical, expected) in counts.iter().zip(sm.iter()) {
        assert!(
            (empirical - expected).abs() <= 0.05,
            "pr mismatch {counts:?} {sm:?}"
        );
    }
    Ok(())
}
|
||||
|
|
Loading…
Reference in New Issue