Simd support (#448)

* Import the simd intrinsics in candle-core.

* simd version of reduce-sum.

* Bugfix.

* Fix some clippy lints.
This commit is contained in:
Laurent Mazare 2023-08-15 09:50:38 +01:00 committed by GitHub
parent 90374097dc
commit 495e0b7580
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 487 additions and 14 deletions

View File

@ -31,9 +31,8 @@ clap = { version = "4.2.4", features = ["derive"] }
cudarc = { version = "0.9.14", features = ["f16"] }
# TODO: Switch back to the official gemm implementation once it has caught up.
gemm = { version = "0.15.6", package = "candle-gemm" }
ggblas = "0.1.2"
hf-hub = "0.2.0"
half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }

View File

@ -15,7 +15,6 @@ byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.1.0", optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
ggblas = { workspace = true }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
libc = { workspace = true, optional = true }

148
candle-core/src/cpu/avx.rs Normal file
View File

@ -0,0 +1,148 @@
use super::{Cpu, CpuF16};
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
use half::f16;
pub struct CurrentCpu {}
const STEP: usize = 32;
const EPR: usize = 8;
const ARR: usize = STEP / EPR;
impl Cpu<ARR> for CurrentCpu {
type Unit = __m256;
type Array = [__m256; ARR];
const STEP: usize = STEP;
const EPR: usize = EPR;
fn n() -> usize {
ARR
}
unsafe fn zero() -> Self::Unit {
_mm256_setzero_ps()
}
unsafe fn zero_array() -> Self::Array {
[Self::zero(); ARR]
}
unsafe fn from_f32(v: f32) -> Self::Unit {
_mm256_set1_ps(v)
}
unsafe fn load(mem_addr: *const f32) -> Self::Unit {
_mm256_loadu_ps(mem_addr)
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
_mm256_add_ps(a, b)
}
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
_mm256_add_ps(_mm256_mul_ps(b, c), a)
}
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
_mm256_storeu_ps(mem_addr, a);
}
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
for i in 0..ARR / 2 {
x[2 * i] = _mm256_add_ps(x[2 * i], x[2 * i + 1]);
}
for i in 0..ARR / 4 {
x[4 * i] = _mm256_add_ps(x[4 * i], x[4 * i + 2]);
}
#[allow(clippy::reversed_empty_ranges)]
for i in 0..ARR / 8 {
x[8 * i] = _mm256_add_ps(x[8 * i], x[8 * i + 4]);
}
let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
let t1 = _mm_hadd_ps(t0, t0);
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
}
}
pub struct CurrentCpuF16 {}
impl CpuF16<ARR> for CurrentCpuF16 {
type Unit = __m256;
type Array = [__m256; ARR];
const STEP: usize = STEP;
const EPR: usize = EPR;
fn n() -> usize {
ARR
}
unsafe fn zero() -> Self::Unit {
_mm256_setzero_ps()
}
unsafe fn zero_array() -> Self::Array {
[Self::zero(); ARR]
}
unsafe fn from_f32(v: f32) -> Self::Unit {
_mm256_set1_ps(v)
}
#[cfg(target_feature = "f16c")]
unsafe fn load(mem_addr: *const f16) -> Self::Unit {
_mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i))
}
#[cfg(not(target_feature = "f16c"))]
unsafe fn load(mem_addr: *const f16) -> Self::Unit {
let mut tmp = [0.0f32; 8];
for i in 0..8 {
tmp[i] = (*mem_addr.add(i)).to_f32();
}
_mm_loadu_ps(tmp.as_ptr())
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
_mm256_add_ps(a, b)
}
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
_mm256_add_ps(_mm256_mul_ps(b, c), a)
}
#[cfg(target_feature = "f16c")]
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
_mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0))
}
#[cfg(not(target_feature = "f16c"))]
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
let mut tmp = [0.0f32; 8];
_mm256_storeu_ps(tmp.as_mut_ptr(), a);
for i in 0..8 {
*mem_addr.add(i) = f16::from_f32(tmp[i]);
}
}
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
let mut offset = ARR >> 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
offset >>= 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
offset >>= 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
let t1 = _mm_hadd_ps(t0, t0);
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
}
}

View File

@ -31,19 +31,26 @@ pub trait VecDot: num_traits::NumAssign + Copy {
impl VecDot for f32 {
#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
ggblas::ggml::vec_dot_f32(lhs, rhs, res, len)
super::vec_dot_f32(lhs, rhs, res, len)
}
// TODO: enable the following once the updated ggblas is available.
// #[inline(always)]
// unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
// ggblas::ggml::vec_reduce_sum(xs, res, len)
// }
#[inline(always)]
unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
super::vec_sum(xs, res, len)
}
}
impl VecDot for half::f16 {
#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
let mut res_f32 = 0f32;
super::vec_dot_f16(lhs, rhs, &mut res_f32, len);
*res = half::f16::from_f32(res_f32);
}
}
impl VecDot for f64 {}
impl VecDot for half::bf16 {}
impl VecDot for half::f16 {}
impl VecDot for u8 {}
impl VecDot for u32 {}

179
candle-core/src/cpu/mod.rs Normal file
View File

@ -0,0 +1,179 @@
pub mod kernels;
trait Cpu<const ARR: usize> {
type Unit;
type Array;
const STEP: usize;
const EPR: usize;
fn n() -> usize;
unsafe fn zero() -> Self::Unit;
unsafe fn zero_array() -> Self::Array;
unsafe fn load(mem_addr: *const f32) -> Self::Unit;
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
unsafe fn from_f32(v: f32) -> Self::Unit;
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
}
trait CpuF16<const ARR: usize> {
type Unit;
type Array;
const STEP: usize;
const EPR: usize;
fn n() -> usize;
unsafe fn zero() -> Self::Unit;
unsafe fn zero_array() -> Self::Array;
unsafe fn load(mem_addr: *const f16) -> Self::Unit;
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
unsafe fn from_f32(v: f32) -> Self::Unit;
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit);
}
use half::f16;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx")]
pub mod avx;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx")]
pub use avx::{CurrentCpu, CurrentCpuF16};
#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
pub mod simd128;
#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
pub use simd128::CurrentCpu;
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(target_feature = "neon")]
pub use neon::CurrentCpu;
#[cfg(any(
target_feature = "neon",
target_feature = "avx",
target_feature = "simd128"
))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
let np = k & !(CurrentCpu::STEP - 1);
let mut sum = CurrentCpu::zero_array();
let mut ax = CurrentCpu::zero_array();
let mut ay = CurrentCpu::zero_array();
for i in (0..np).step_by(CurrentCpu::STEP) {
for j in 0..CurrentCpu::n() {
ax[j] = CurrentCpu::load(a_row.add(i + j * CurrentCpu::EPR));
ay[j] = CurrentCpu::load(b_row.add(i + j * CurrentCpu::EPR));
sum[j] = CurrentCpu::vec_fma(sum[j], ax[j], ay[j]);
}
}
CurrentCpu::vec_reduce(sum, c);
// leftovers
for i in np..k {
*c += *a_row.add(i) * (*b_row.add(i));
}
}
#[cfg(not(any(
target_feature = "neon",
target_feature = "avx",
target_feature = "simd128"
)))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
// leftovers
for i in 0..k {
*c += *a_row.add(i) * (*b_row.add(i));
}
}
#[cfg(any(
target_feature = "neon",
target_feature = "avx",
target_feature = "simd128"
))]
#[inline(always)]
pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
let np = k & !(CurrentCpu::STEP - 1);
let mut sum = CurrentCpu::zero_array();
let mut x = CurrentCpu::zero_array();
for i in (0..np).step_by(CurrentCpu::STEP) {
for j in 0..CurrentCpu::n() {
x[j] = CurrentCpu::load(row.add(i + j * CurrentCpu::EPR));
sum[j] = CurrentCpu::vec_add(sum[j], x[j]);
}
}
CurrentCpu::vec_reduce(sum, b);
// leftovers
for i in np..k {
*b += *row.add(i)
}
}
#[cfg(not(any(
target_feature = "neon",
target_feature = "avx",
target_feature = "simd128"
)))]
#[inline(always)]
pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
*b = 0f32;
for i in 0..k {
*b += *row.add(i)
}
}
#[cfg(target_feature = "avx")]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
let mut sumf = 0.0f32;
let np = k & !(CurrentCpuF16::STEP - 1);
let mut sum = CurrentCpuF16::zero_array();
let mut ax = CurrentCpuF16::zero_array();
let mut ay = CurrentCpuF16::zero_array();
for i in (0..np).step_by(CurrentCpuF16::STEP) {
for j in 0..CurrentCpuF16::n() {
ax[j] = CurrentCpuF16::load(a_row.add(i + j * CurrentCpuF16::EPR));
ay[j] = CurrentCpuF16::load(b_row.add(i + j * CurrentCpuF16::EPR));
sum[j] = CurrentCpuF16::vec_fma(sum[j], ax[j], ay[j]);
}
}
CurrentCpuF16::vec_reduce(sum, &mut sumf);
// leftovers
for i in np..k {
sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
}
*c = sumf;
}
#[cfg(not(target_feature = "avx"))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
// leftovers
let mut sum = 0.0;
for i in 0..k {
sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
}
*c = sum;
}

View File

@ -0,0 +1,77 @@
use super::Cpu;
#[cfg(target_arch = "arm")]
use core::arch::arm::*;
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
pub struct CurrentCpu {}
const STEP: usize = 16;
const EPR: usize = 4;
const ARR: usize = STEP / EPR;
impl CurrentCpu {
#[cfg(target_arch = "aarch64")]
unsafe fn reduce_one(x: float32x4_t) -> f32 {
vaddvq_f32(x)
}
#[cfg(target_arch = "arm")]
unsafe fn reduce_one(x: float32x4_t) -> f32 {
vgetq_lane_f32(x, 0) + vgetq_lane_f32(x, 1) + vgetq_lane_f32(x, 2) + vgetq_lane_f32(x, 3)
}
}
impl Cpu<ARR> for CurrentCpu {
type Unit = float32x4_t;
type Array = [float32x4_t; ARR];
const STEP: usize = STEP;
const EPR: usize = EPR;
fn n() -> usize {
ARR
}
unsafe fn zero() -> Self::Unit {
vdupq_n_f32(0.0)
}
unsafe fn from_f32(x: f32) -> Self::Unit {
vdupq_n_f32(x)
}
unsafe fn zero_array() -> Self::Array {
[Self::zero(); ARR]
}
unsafe fn load(mem_addr: *const f32) -> Self::Unit {
vld1q_f32(mem_addr)
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
vaddq_f32(a, b)
}
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
vfmaq_f32(a, b, c)
}
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
vst1q_f32(mem_addr, a);
}
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
for i in 0..ARR / 2 {
x[2 * i] = vaddq_f32(x[2 * i], x[2 * i + 1]);
}
for i in 0..ARR / 4 {
x[4 * i] = vaddq_f32(x[4 * i], x[4 * i + 2]);
}
for i in 0..ARR / 8 {
x[8 * i] = vaddq_f32(x[8 * i], x[8 * i + 4]);
}
*y = Self::reduce_one(x[0]);
}
}

View File

@ -0,0 +1,64 @@
use super::Cpu;
use core::arch::wasm32::*;
pub struct CurrentCpu {}
const STEP: usize = 16;
const EPR: usize = 4;
const ARR: usize = STEP / EPR;
impl Cpu<ARR> for CurrentCpu {
type Unit = v128;
type Array = [v128; ARR];
const STEP: usize = STEP;
const EPR: usize = EPR;
fn n() -> usize {
ARR
}
unsafe fn zero() -> Self::Unit {
f32x4_splat(0.0)
}
unsafe fn zero_array() -> Self::Array {
[Self::zero(); ARR]
}
unsafe fn from_f32(v: f32) -> Self::Unit {
f32x4_splat(v)
}
unsafe fn load(mem_addr: *const f32) -> Self::Unit {
v128_load(mem_addr as *mut v128)
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
f32x4_add(a, b)
}
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
f32x4_add(f32x4_mul(b, c), a)
}
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
v128_store(mem_addr as *mut v128, a);
}
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
for i in 0..ARR / 2 {
x[2 * i] = f32x4_add(x[2 * i], x[2 * i + 1]);
}
for i in 0..ARR / 4 {
x[4 * i] = f32x4_add(x[4 * i], x[4 * i + 2]);
}
for i in 0..ARR / 8 {
x[8 * i] = f32x4_add(x[8 * i], x[8 * i + 4]);
}
*y = f32x4_extract_lane::<0>(x[0])
+ f32x4_extract_lane::<1>(x[0])
+ f32x4_extract_lane::<2>(x[0])
+ f32x4_extract_lane::<3>(x[0]);
}
}

View File

@ -1051,7 +1051,7 @@ impl<'a> Map2 for Conv1D<'a> {
let num_threads = crate::utils::get_num_threads();
for offset in 0..p.k_size {
crate::cpu_kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
let dst_idx = dst_c_idx * l_out;
let k_cont = (0..p.c_in)
.map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
@ -1123,7 +1123,7 @@ impl<'a> Map2 for Conv2D<'a> {
for offset_h in 0..p.k_h {
for offset_w in 0..p.k_w {
crate::cpu_kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
let dst_idx = dst_c_idx * out_w * out_h;
let k_cont = (0..p.c_in)
.map(|c_in_idx| {

View File

@ -62,7 +62,7 @@ pub trait WithDType:
+ 'static
+ Send
+ Sync
+ crate::cpu_kernels::VecDot
+ crate::cpu::kernels::VecDot
{
const DTYPE: DType;

View File

@ -39,8 +39,8 @@ pub mod backend;
pub mod backprop;
mod conv;
mod convert;
pub mod cpu;
pub mod cpu_backend;
pub mod cpu_kernels;
#[cfg(feature = "cuda")]
pub mod cuda_backend;
#[cfg(feature = "cudnn")]