Adding matmul
This commit is contained in:
parent
ce977b489e
commit
86e4cbbc3d
|
@ -0,0 +1,13 @@
|
|||
repos:
|
||||
- repo: https://github.com/Narsil/pre-commit-rust
|
||||
rev: 2eed6366172ef2a5186e8785ec0e67243d7d73d0
|
||||
hooks:
|
||||
- id: fmt
|
||||
name: "Rust (fmt)"
|
||||
- id: clippy
|
||||
name: "Rust (clippy)"
|
||||
args:
|
||||
[
|
||||
"--",
|
||||
"-Dwarnings",
|
||||
]
|
|
@ -178,7 +178,7 @@ impl Tensor {
|
|||
device: Device,
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
let storage = device.storage(a);
|
||||
let storage = device.storage(a)?;
|
||||
let stride = shape.stride_contiguous();
|
||||
let is_variable = false;
|
||||
let tensor_ = Tensor_ {
|
||||
|
@ -514,7 +514,7 @@ impl Tensor {
|
|||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||
}
|
||||
Op::Matmul(lhs, rhs) => {
|
||||
Op::Matmul(_lhs, _rhs) => {
|
||||
// let (m, k) = lhs.shape;
|
||||
// let n = rhs.shape.1;
|
||||
// let strides = (m, n).strides();
|
||||
|
@ -539,12 +539,12 @@ impl Tensor {
|
|||
// rhs.strides,
|
||||
// );
|
||||
|
||||
let lhs_grad = grad.matmul(rhs)?;
|
||||
let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
|
||||
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
|
||||
let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
|
||||
let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||
// let lhs_grad = grad.matmul(rhs)?;
|
||||
// let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
|
||||
// *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
|
||||
// let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
|
||||
// let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
|
||||
// *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||
}
|
||||
Op::Affine { arg, mul, .. } => {
|
||||
let arg_grad = grad.affine(*mul, 0.)?;
|
||||
|
|
Loading…
Reference in New Issue