Use rayon directly rather than constraining the number of threads. (#749)

This commit is contained in:
Laurent Mazare 2023-09-05 21:26:15 +02:00 committed by GitHub
parent 6a40decc76
commit a4f40f3dc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 4 additions and 8 deletions

View File

@@ -2,6 +2,7 @@ use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use half::{bf16, f16};
use rayon::prelude::*;
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator and
// intercept the OOM errors, so that we avoid panicking and provide a proper error.
@@ -1052,10 +1053,8 @@ impl<'a> Map2 for Conv1D<'a> {
}
}
let num_threads = crate::utils::get_num_threads();
for offset in 0..p.k_size {
crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let dst_idx = dst_c_idx * l_out;
let k_cont = (0..p.c_in)
.map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
@@ -1123,11 +1122,9 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
let num_threads = crate::utils::get_num_threads();
for offset_h in 0..p.k_h {
for offset_w in 0..p.k_w {
crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let dst_idx = dst_c_idx * out_w * out_h;
let k_cont = (0..p.c_in)
.map(|c_in_idx| {
@@ -1216,11 +1213,10 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
}
}
}
let num_threads = crate::utils::get_num_threads();
for k_y in 0..p.k_h {
for k_x in 0..p.k_w {
crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
let k_cont = (0..p.c_in)
.map(|c_in_idx| {
k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]