Metal: support unary abs (#1503)
* Metal: support unary abs * cargo fmt
This commit is contained in:
parent
87d7f81b43
commit
0a245e6fa4
|
@ -665,6 +665,7 @@ impl BackendStorage for MetalStorage {
|
||||||
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
|
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
|
||||||
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
|
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
|
||||||
("uerf", DType::F32) => contiguous::erf::FLOAT,
|
("uerf", DType::F32) => contiguous::erf::FLOAT,
|
||||||
|
("uabs", DType::F32) => contiguous::abs::FLOAT,
|
||||||
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
||||||
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
||||||
("uround", DType::F32) => contiguous::round::FLOAT,
|
("uround", DType::F32) => contiguous::round::FLOAT,
|
||||||
|
@ -680,6 +681,7 @@ impl BackendStorage for MetalStorage {
|
||||||
("ugelu", DType::F16) => contiguous::gelu::HALF,
|
("ugelu", DType::F16) => contiguous::gelu::HALF,
|
||||||
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
|
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
|
||||||
("uerf", DType::F16) => contiguous::erf::HALF,
|
("uerf", DType::F16) => contiguous::erf::HALF,
|
||||||
|
("uabs", DType::F16) => contiguous::abs::HALF,
|
||||||
("uceil", DType::F16) => contiguous::ceil::HALF,
|
("uceil", DType::F16) => contiguous::ceil::HALF,
|
||||||
("ufloor", DType::F16) => contiguous::floor::HALF,
|
("ufloor", DType::F16) => contiguous::floor::HALF,
|
||||||
("uround", DType::F16) => contiguous::round::HALF,
|
("uround", DType::F16) => contiguous::round::HALF,
|
||||||
|
@ -712,6 +714,7 @@ impl BackendStorage for MetalStorage {
|
||||||
("ugelu", DType::F32) => strided::gelu::FLOAT,
|
("ugelu", DType::F32) => strided::gelu::FLOAT,
|
||||||
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
|
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
|
||||||
("uerf", DType::F32) => strided::erf::FLOAT,
|
("uerf", DType::F32) => strided::erf::FLOAT,
|
||||||
|
("uabs", DType::F32) => strided::abs::FLOAT,
|
||||||
("uceil", DType::F32) => strided::ceil::FLOAT,
|
("uceil", DType::F32) => strided::ceil::FLOAT,
|
||||||
("ufloor", DType::F32) => strided::floor::FLOAT,
|
("ufloor", DType::F32) => strided::floor::FLOAT,
|
||||||
("uround", DType::F32) => strided::round::FLOAT,
|
("uround", DType::F32) => strided::round::FLOAT,
|
||||||
|
@ -725,6 +728,7 @@ impl BackendStorage for MetalStorage {
|
||||||
("ugelu", DType::F16) => strided::gelu::HALF,
|
("ugelu", DType::F16) => strided::gelu::HALF,
|
||||||
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
|
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
|
||||||
("uerf", DType::F16) => strided::erf::HALF,
|
("uerf", DType::F16) => strided::erf::HALF,
|
||||||
|
("uabs", DType::F16) => strided::abs::HALF,
|
||||||
("uceil", DType::F16) => strided::ceil::HALF,
|
("uceil", DType::F16) => strided::ceil::HALF,
|
||||||
("ufloor", DType::F16) => strided::floor::HALF,
|
("ufloor", DType::F16) => strided::floor::HALF,
|
||||||
("uround", DType::F16) => strided::round::HALF,
|
("uround", DType::F16) => strided::round::HALF,
|
||||||
|
|
|
@ -173,7 +173,10 @@ macro_rules! ops{
|
||||||
}
|
}
|
||||||
|
|
||||||
pub mod unary {
|
pub mod unary {
|
||||||
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh, recip);
|
ops!(
|
||||||
|
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh,
|
||||||
|
recip
|
||||||
|
);
|
||||||
}
|
}
|
||||||
pub mod binary {
|
pub mod binary {
|
||||||
ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
|
ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
|
||||||
|
|
|
@ -102,6 +102,7 @@ UNARY_OP(neg)
|
||||||
UNARY_OP(exp)
|
UNARY_OP(exp)
|
||||||
UNARY_OP(log)
|
UNARY_OP(log)
|
||||||
UNARY_OP(gelu)
|
UNARY_OP(gelu)
|
||||||
|
UNARY_OP(abs)
|
||||||
UNARY_OP(ceil)
|
UNARY_OP(ceil)
|
||||||
UNARY_OP(floor)
|
UNARY_OP(floor)
|
||||||
UNARY_OP(round)
|
UNARY_OP(round)
|
||||||
|
|
Loading…
Reference in New Issue