Adding matmul

This commit is contained in:
Nicolas Patry 2023-06-21 17:01:32 +02:00
parent ce977b489e
commit 86e4cbbc3d
2 changed files with 21 additions and 8 deletions

13
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,13 @@
repos:
- repo: https://github.com/Narsil/pre-commit-rust
rev: 2eed6366172ef2a5186e8785ec0e67243d7d73d0
hooks:
- id: fmt
name: "Rust (fmt)"
- id: clippy
name: "Rust (clippy)"
args:
[
"--",
"-Dwarnings",
]

View File

@ -178,7 +178,7 @@ impl Tensor {
device: Device,
) -> Result<Self> {
let shape = shape.into();
let storage = device.storage(a);
let storage = device.storage(a)?;
let stride = shape.stride_contiguous();
let is_variable = false;
let tensor_ = Tensor_ {
@ -514,7 +514,7 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
Op::Matmul(lhs, rhs) => {
Op::Matmul(_lhs, _rhs) => {
// let (m, k) = lhs.shape;
// let n = rhs.shape.1;
// let strides = (m, n).strides();
@ -539,12 +539,12 @@ impl Tensor {
// rhs.strides,
// );
let lhs_grad = grad.matmul(rhs)?;
let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
// let lhs_grad = grad.matmul(rhs)?;
// let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
// *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
// let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
// let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
// *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
Op::Affine { arg, mul, .. } => {
let arg_grad = grad.affine(*mul, 0.)?;