Make the metal sdpa tests deterministic. (#2750)
This commit is contained in:
parent
da02b59516
commit
ab9019425a
|
@ -26,6 +26,7 @@ candle-metal-kernels = { workspace = true, optional = true }
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
clap = { workspace = true }
|
clap = { workspace = true }
|
||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
|
rand_distr = { workspace = true }
|
||||||
criterion = { workspace = true }
|
criterion = { workspace = true }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
|
|
|
@ -1,86 +1,84 @@
|
||||||
#[cfg(feature = "metal")]
|
#[cfg(feature = "metal")]
|
||||||
mod metal_sdpa_tests {
|
mod metal_sdpa_tests {
|
||||||
#[test]
|
use candle::{DType, Device, Result, Shape, Tensor};
|
||||||
fn sdpa_full() -> candle::Result<()> {
|
use rand::SeedableRng;
|
||||||
use candle::{DType, Device, Tensor};
|
use rand_distr::Distribution;
|
||||||
|
use std::ops::{Div, Mul};
|
||||||
|
|
||||||
|
fn randn<S: Into<Shape>>(
|
||||||
|
rng: &mut rand::rngs::StdRng,
|
||||||
|
shape: S,
|
||||||
|
dev: &Device,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let shape = shape.into();
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
|
||||||
|
let vs: Vec<f32> = (0..elem_count).map(|_| normal.sample(rng)).collect();
|
||||||
|
Tensor::from_vec(vs, &shape, dev)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn sdpa_full() -> Result<()> {
|
||||||
// Force the full (non-vectorized) path: query seqlen R = 4
|
// Force the full (non-vectorized) path: query seqlen R = 4
|
||||||
const BS: usize = 4;
|
const BS: usize = 4;
|
||||||
const R: usize = 4;
|
const R: usize = 4;
|
||||||
const L: usize = 4;
|
const L: usize = 4;
|
||||||
const DK: usize = 64;
|
const DK: usize = 64;
|
||||||
const H: usize = 3;
|
const H: usize = 3;
|
||||||
|
|
||||||
let scale: f64 = f64::from(DK as u32).sqrt().recip();
|
let scale: f64 = f64::from(DK as u32).sqrt().recip();
|
||||||
|
|
||||||
let device = Device::new_metal(0)?;
|
let device = Device::new_metal(0)?;
|
||||||
|
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
|
||||||
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
|
let q = randn(&mut rng, (BS, H, R, DK), &device)?;
|
||||||
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
|
let k = randn(&mut rng, (BS, H, L, DK), &device)?;
|
||||||
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
|
let v = randn(&mut rng, (BS, H, L, DK), &device)?;
|
||||||
|
|
||||||
let ground_truth = {
|
let ground_truth = {
|
||||||
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
|
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
|
||||||
let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
|
let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
|
||||||
.to_dtype(q.dtype())?;
|
.to_dtype(q.dtype())?;
|
||||||
att.matmul(&v.clone())?
|
att.matmul(&v.clone())?
|
||||||
};
|
};
|
||||||
|
|
||||||
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
|
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
|
||||||
|
|
||||||
assert_eq!(ground_truth.shape(), sdpa_output.shape());
|
assert_eq!(ground_truth.shape(), sdpa_output.shape());
|
||||||
|
|
||||||
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
|
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
|
||||||
.sum_all()?
|
.sum_all()?
|
||||||
.to_scalar()?;
|
.to_scalar()?;
|
||||||
|
assert!(error <= 0.0004, "{}", error);
|
||||||
assert!(error <= 0.0005, "{}", error);
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn sdpa_vector() -> candle::Result<()> {
|
fn sdpa_vector() -> Result<()> {
|
||||||
use candle::{DType, Device, Tensor};
|
|
||||||
|
|
||||||
// Allow vectorized, seqlen = 1
|
// Allow vectorized, seqlen = 1
|
||||||
const BS: usize = 4;
|
const BS: usize = 4;
|
||||||
const R: usize = 1;
|
const R: usize = 1;
|
||||||
const L: usize = 1;
|
const L: usize = 1;
|
||||||
const DK: usize = 64;
|
const DK: usize = 64;
|
||||||
const H: usize = 3;
|
const H: usize = 3;
|
||||||
|
|
||||||
let scale: f64 = f64::from(DK as u32).sqrt().recip();
|
let scale: f64 = f64::from(DK as u32).sqrt().recip();
|
||||||
|
|
||||||
let device = Device::new_metal(0)?;
|
let device = Device::new_metal(0)?;
|
||||||
|
let mut rng = rand::rngs::StdRng::seed_from_u64(4242);
|
||||||
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
|
let q = randn(&mut rng, (BS, H, R, DK), &device)?;
|
||||||
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
|
let k = randn(&mut rng, (BS, H, L, DK), &device)?;
|
||||||
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
|
let v = randn(&mut rng, (BS, H, L, DK), &device)?;
|
||||||
|
|
||||||
let ground_truth = {
|
let ground_truth = {
|
||||||
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
|
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
|
||||||
let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
|
let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
|
||||||
.to_dtype(q.dtype())?;
|
.to_dtype(q.dtype())?;
|
||||||
att.matmul(&v.clone())?
|
att.matmul(&v.clone())?
|
||||||
};
|
};
|
||||||
|
|
||||||
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
|
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
|
||||||
|
|
||||||
assert_eq!(ground_truth.shape(), sdpa_output.shape());
|
assert_eq!(ground_truth.shape(), sdpa_output.shape());
|
||||||
|
|
||||||
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
|
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
|
||||||
.sum_all()?
|
.sum_all()?
|
||||||
.to_scalar()?;
|
.to_scalar()?;
|
||||||
|
assert!(error <= 0.000, "{}", error);
|
||||||
assert!(error <= 0.0001, "{}", error);
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn sdpa_full_softcapping() -> candle::Result<()> {
|
fn sdpa_full_softcapping() -> Result<()> {
|
||||||
use candle::{DType, Device, Tensor};
|
|
||||||
use std::ops::{Div, Mul};
|
|
||||||
|
|
||||||
// Force the full (non-vectorized) path: query seqlen R = 4
|
// Force the full (non-vectorized) path: query seqlen R = 4
|
||||||
const BS: usize = 4;
|
const BS: usize = 4;
|
||||||
const R: usize = 4;
|
const R: usize = 4;
|
||||||
|
@ -88,14 +86,13 @@ mod metal_sdpa_tests {
|
||||||
const DK: usize = 64;
|
const DK: usize = 64;
|
||||||
const H: usize = 3;
|
const H: usize = 3;
|
||||||
const SOFTCAP: f64 = 50.;
|
const SOFTCAP: f64 = 50.;
|
||||||
|
|
||||||
let scale: f64 = f64::from(DK as u32).sqrt().recip();
|
let scale: f64 = f64::from(DK as u32).sqrt().recip();
|
||||||
|
|
||||||
let device = Device::new_metal(0)?;
|
let device = Device::new_metal(0)?;
|
||||||
|
let mut rng = rand::rngs::StdRng::seed_from_u64(424242);
|
||||||
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
|
let q = randn(&mut rng, (BS, H, R, DK), &device)?;
|
||||||
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
|
let k = randn(&mut rng, (BS, H, L, DK), &device)?;
|
||||||
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
|
let v = randn(&mut rng, (BS, H, L, DK), &device)?;
|
||||||
|
|
||||||
let ground_truth = {
|
let ground_truth = {
|
||||||
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
|
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
|
||||||
let att = candle_nn::ops::softmax_last_dim(
|
let att = candle_nn::ops::softmax_last_dim(
|
||||||
|
@ -107,25 +104,17 @@ mod metal_sdpa_tests {
|
||||||
.to_dtype(q.dtype())?;
|
.to_dtype(q.dtype())?;
|
||||||
att.matmul(&v.clone())?
|
att.matmul(&v.clone())?
|
||||||
};
|
};
|
||||||
|
|
||||||
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
|
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
|
||||||
|
|
||||||
assert_eq!(ground_truth.shape(), sdpa_output.shape());
|
assert_eq!(ground_truth.shape(), sdpa_output.shape());
|
||||||
|
|
||||||
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
|
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
|
||||||
.sum_all()?
|
.sum_all()?
|
||||||
.to_scalar()?;
|
.to_scalar()?;
|
||||||
|
|
||||||
assert!(error <= 0.0005, "{}", error);
|
assert!(error <= 0.0005, "{}", error);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn sdpa_vector_softcapping() -> candle::Result<()> {
|
fn sdpa_vector_softcapping() -> Result<()> {
|
||||||
use candle::{DType, Device, Tensor};
|
|
||||||
use std::ops::{Div, Mul};
|
|
||||||
|
|
||||||
// Allow vectorized, seqlen = 1
|
// Allow vectorized, seqlen = 1
|
||||||
const BS: usize = 4;
|
const BS: usize = 4;
|
||||||
const R: usize = 1;
|
const R: usize = 1;
|
||||||
|
@ -133,14 +122,13 @@ mod metal_sdpa_tests {
|
||||||
const DK: usize = 64;
|
const DK: usize = 64;
|
||||||
const H: usize = 3;
|
const H: usize = 3;
|
||||||
const SOFTCAP: f64 = 50.;
|
const SOFTCAP: f64 = 50.;
|
||||||
|
|
||||||
let scale: f64 = f64::from(DK as u32).sqrt().recip();
|
let scale: f64 = f64::from(DK as u32).sqrt().recip();
|
||||||
|
|
||||||
let device = Device::new_metal(0)?;
|
let device = Device::new_metal(0)?;
|
||||||
|
let mut rng = rand::rngs::StdRng::seed_from_u64(42424242);
|
||||||
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
|
let q = randn(&mut rng, (BS, H, R, DK), &device)?;
|
||||||
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
|
let k = randn(&mut rng, (BS, H, L, DK), &device)?;
|
||||||
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
|
let v = randn(&mut rng, (BS, H, L, DK), &device)?;
|
||||||
|
|
||||||
let ground_truth = {
|
let ground_truth = {
|
||||||
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
|
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
|
||||||
let att = candle_nn::ops::softmax_last_dim(
|
let att = candle_nn::ops::softmax_last_dim(
|
||||||
|
@ -152,55 +140,42 @@ mod metal_sdpa_tests {
|
||||||
.to_dtype(q.dtype())?;
|
.to_dtype(q.dtype())?;
|
||||||
att.matmul(&v.clone())?
|
att.matmul(&v.clone())?
|
||||||
};
|
};
|
||||||
|
|
||||||
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
|
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
|
||||||
|
|
||||||
assert_eq!(ground_truth.shape(), sdpa_output.shape());
|
assert_eq!(ground_truth.shape(), sdpa_output.shape());
|
||||||
|
|
||||||
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
|
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
|
||||||
.sum_all()?
|
.sum_all()?
|
||||||
.to_scalar()?;
|
.to_scalar()?;
|
||||||
|
|
||||||
assert!(error <= 0.0001, "{}", error);
|
assert!(error <= 0.0001, "{}", error);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn sdpa_vector_cross() -> candle::Result<()> {
|
fn sdpa_vector_cross() -> Result<()> {
|
||||||
use candle::{DType, Device, Tensor};
|
|
||||||
|
|
||||||
// Allow vectorized, seqlen = 1. Simulate the cross-attention case where R != L, R = 1
|
// Allow vectorized, seqlen = 1. Simulate the cross-attention case where R != L, R = 1
|
||||||
const BS: usize = 4;
|
const BS: usize = 4;
|
||||||
const R: usize = 1;
|
const R: usize = 1;
|
||||||
const L: usize = 24;
|
const L: usize = 24;
|
||||||
const DK: usize = 64;
|
const DK: usize = 64;
|
||||||
const H: usize = 3;
|
const H: usize = 3;
|
||||||
|
|
||||||
let scale: f64 = f64::from(DK as u32).sqrt().recip();
|
let scale: f64 = f64::from(DK as u32).sqrt().recip();
|
||||||
|
|
||||||
let device = Device::new_metal(0)?;
|
let device = Device::new_metal(0)?;
|
||||||
|
let mut rng = rand::rngs::StdRng::seed_from_u64(4242424242);
|
||||||
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
|
let q = randn(&mut rng, (BS, H, R, DK), &device)?;
|
||||||
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
|
let k = randn(&mut rng, (BS, H, L, DK), &device)?;
|
||||||
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
|
let v = randn(&mut rng, (BS, H, L, DK), &device)?;
|
||||||
|
|
||||||
let ground_truth = {
|
let ground_truth = {
|
||||||
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
|
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
|
||||||
let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
|
let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
|
||||||
.to_dtype(q.dtype())?;
|
.to_dtype(q.dtype())?;
|
||||||
att.matmul(&v.clone())?
|
att.matmul(&v.clone())?
|
||||||
};
|
};
|
||||||
|
|
||||||
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
|
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
|
||||||
|
|
||||||
assert_eq!(ground_truth.shape(), sdpa_output.shape());
|
assert_eq!(ground_truth.shape(), sdpa_output.shape());
|
||||||
|
|
||||||
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
|
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
|
||||||
.sum_all()?
|
.sum_all()?
|
||||||
.to_scalar()?;
|
.to_scalar()?;
|
||||||
|
|
||||||
assert!(error <= 0.0013, "{}", error);
|
assert!(error <= 0.0013, "{}", error);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue