Do not implement Module for BatchNorm. (#1513)

parent 1fb2dd905c
commit b0fe5e4453
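In short: `BatchNorm` no longer implements `Module`; it only implements `ModuleT`, so every caller has to state explicitly whether it is in training or evaluation mode. A minimal sketch of the call-site change (the `vb`/`xs` names and the feature count are placeholders, not taken from the diff):

    use candle::{Result, Tensor};
    use candle_nn::{batch_norm, BatchNorm, VarBuilder};

    // Sketch only: `vb`, `xs`, and the 64 features are illustrative.
    fn apply_bn(vb: VarBuilder, xs: &Tensor) -> Result<Tensor> {
        let bn: BatchNorm = batch_norm(64, 1e-5, vb.pp("bn"))?;
        // Before this change: xs.apply(&bn)? through the blanket `Module` impl.
        // Now the mode is explicit:
        let _train = bn.forward_train(xs)?; // training path, uses batch statistics
        xs.apply_t(&bn, false)              // inference path, uses running statistics
    }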
@@ -147,7 +147,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
         let func = candle_nn::func(move |xs| {
             let xs = conv.forward(xs)?;
             let xs = match &bn {
-                Some(bn) => bn.forward(&xs)?,
+                Some(bn) => xs.apply_t(bn, false)?,
                 None => xs,
             };
             let xs = if leaky {
@@ -7,7 +7,7 @@
 //! running stats.
 //!
 //! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
-use candle::{DType, Module, Result, Tensor, Var};
+use candle::{DType, Result, Tensor, Var};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct BatchNormConfig {
@@ -192,7 +192,7 @@ impl BatchNorm {
         self.momentum
     }
 
-    pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
+    pub fn forward_train(&self, x: &Tensor) -> Result<Tensor> {
         let num_features = self.running_mean.as_tensor().dim(0)?;
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
@@ -252,17 +252,7 @@ impl BatchNorm {
         x.reshape(x_dims_post_transpose)?.transpose(0, 1)
     }
 
-    pub fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
-        if train {
-            self.forward_learning(x)
-        } else {
-            self.forward(x)
-        }
-    }
-}
-
-impl Module for BatchNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+    fn forward_eval(&self, x: &Tensor) -> Result<Tensor> {
         let target_shape: Vec<usize> = x
             .dims()
             .iter()
@@ -288,6 +278,16 @@ impl Module for BatchNorm {
     }
 }
 
+impl crate::ModuleT for BatchNorm {
+    fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
+        if train {
+            self.forward_train(x)
+        } else {
+            self.forward_eval(x)
+        }
+    }
+}
+
 pub fn batch_norm<C: Into<BatchNormConfig>>(
     num_features: usize,
     config: C,
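Note that `forward_eval` is private, so this `ModuleT` impl together with `forward_train` is now the only way in. The updated examples below all wrap their layers in `candle_nn::func`/`Func::new` closures that only see a `Module`-style call, which is why they hard-code `false` for inference. A hedged sketch of a composite layer that threads the flag through instead (the `ConvBn` struct is illustrative, not part of this commit):

    use candle::{Result, Tensor};
    use candle_nn::{BatchNorm, Conv2d, ModuleT};

    // Illustrative only: a block that forwards the train flag instead of fixing it.
    struct ConvBn {
        conv: Conv2d,
        bn: BatchNorm,
    }

    impl ModuleT for ConvBn {
        fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
            // Pass `train` through so the same block serves training and inference.
            xs.apply(&self.conv)?.apply_t(&self.bn, train)
        }
    }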
@@ -39,7 +39,7 @@ fn batch_norm() -> Result<()> {
         1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205,
     ];
     let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?;
-    let output = bn.forward_learning(&input)?;
+    let output = bn.forward_train(&input)?;
     assert_eq!(output.dims(), &[2, 5, 3, 4]);
     let output = output.flatten_all()?;
     assert_eq!(
@@ -67,7 +67,7 @@ fn batch_norm() -> Result<()> {
         Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?,
         1e-8,
     )?;
-    let output2 = bn2.forward_learning(&input)?;
+    let output2 = bn2.forward_train(&input)?;
     assert_eq!(output2.dims(), &[2, 5, 3, 4]);
     let output2 = output2.flatten_all()?;
     let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;
@@ -40,8 +40,8 @@ fn block(dim: usize, kernel_size: usize, vb: VarBuilder) -> Result<impl Module>
     let conv2 = candle_nn::conv2d(dim, dim, 1, Default::default(), vb.pp(1))?;
     let bn2 = batch_norm(dim, 1e-5, vb.pp(3))?;
     Ok(candle_nn::func(move |xs| {
-        let ys = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
-        (xs + ys)?.apply(&conv2)?.gelu_erf()?.apply(&bn2)
+        let ys = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
+        (xs + ys)?.apply(&conv2)?.gelu_erf()?.apply_t(&bn2, false)
     }))
 }
 
@@ -64,7 +64,7 @@ fn convmixer(
         .collect::<Result<Vec<_>>>()?;
     let fc = candle_nn::linear(dim, nclasses, vb.pp(25))?;
     Ok(candle_nn::func(move |xs| {
-        let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
+        let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
         for block in blocks.iter() {
             xs = xs.apply(block)?
         }
@@ -169,8 +169,7 @@ impl ConvNormActivation {
 
 impl Module for ConvNormActivation {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let xs = self.conv2d.forward(xs)?;
-        let xs = self.bn2d.forward(&xs)?;
+        let xs = self.conv2d.forward(xs)?.apply_t(&self.bn2d, false)?;
         if self.activation {
             swish(&xs)
         } else {
@@ -25,7 +25,7 @@ fn downsample(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resul
     if stride != 1 || c_in != c_out {
         let conv = conv2d(c_in, c_out, 1, 0, stride, vb.pp(0))?;
         let bn = batch_norm(c_out, 1e-5, vb.pp(1))?;
-        Ok(Func::new(move |xs| xs.apply(&conv)?.apply(&bn)))
+        Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
     } else {
         Ok(Func::new(|xs| Ok(xs.clone())))
     }
@@ -40,10 +40,10 @@ fn basic_block(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resu
     Ok(Func::new(move |xs| {
         let ys = xs
             .apply(&conv1)?
-            .apply(&bn1)?
+            .apply_t(&bn1, false)?
             .relu()?
             .apply(&conv2)?
-            .apply(&bn2)?;
+            .apply_t(&bn2, false)?;
         (xs.apply(&downsample)? + ys)?.relu()
     }))
 }
@@ -94,7 +94,7 @@ fn resnet(
     Ok(Func::new(move |xs| {
         let xs = xs
             .apply(&conv1)?
-            .apply(&bn1)?
+            .apply_t(&bn1, false)?
             .relu()?
             .pad_with_same(D::Minus1, 1, 1)?
             .pad_with_same(D::Minus2, 1, 1)?
@@ -149,13 +149,13 @@ fn bottleneck_block(
     Ok(Func::new(move |xs| {
         let ys = xs
             .apply(&conv1)?
-            .apply(&bn1)?
+            .apply_t(&bn1, false)?
             .relu()?
             .apply(&conv2)?
-            .apply(&bn2)?
+            .apply_t(&bn2, false)?
             .relu()?
             .apply(&conv3)?
-            .apply(&bn3)?;
+            .apply_t(&bn3, false)?;
         (xs.apply(&downsample)? + ys)?.relu()
     }))
 }
@@ -206,7 +206,7 @@ fn bottleneck_resnet(
     Ok(Func::new(move |xs| {
         let xs = xs
             .apply(&conv1)?
-            .apply(&bn1)?
+            .apply_t(&bn1, false)?
             .relu()?
             .pad_with_same(D::Minus1, 1, 1)?
             .pad_with_same(D::Minus2, 1, 1)?
@@ -28,7 +28,7 @@ impl Conv2dBN {
 impl Module for Conv2dBN {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        xs.apply(&self.c)?.apply(&self.bn)
+        xs.apply(&self.c)?.apply_t(&self.bn, false)
     }
 }
 
@@ -185,7 +185,7 @@ impl PaellaVQ {
             xs = xs.apply(&down_block.1)?
         }
         xs.apply(&self.down_blocks_conv)?
-            .apply(&self.down_blocks_bn)
+            .apply_t(&self.down_blocks_bn, false)
     }
 
     pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
@@ -107,8 +107,7 @@ impl ConvBlock {
 
 impl Module for ConvBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let xs = self.conv.forward(xs)?;
-        let xs = self.bn.forward(&xs)?;
+        let xs = self.conv.forward(xs)?.apply_t(&self.bn, false)?;
         candle_nn::ops::silu(&xs)
     }
 }