Do not implement Module for BatchNorm. (#1513)

Author: Laurent Mazare, 2024-01-01 10:13:13 +01:00 (committed by GitHub)
parent 1fb2dd905c
commit b0fe5e4453
9 changed files with 31 additions and 33 deletions
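
With this change, BatchNorm no longer implements Module: Module::forward has no way to distinguish training from inference (the removed impl always ran the inference path), so it is replaced by a ModuleT impl and call sites must pass the mode explicitly via apply_t / forward_t, or call forward_train directly. A rough usage sketch of the new API, not part of the diff; the zero-initialized VarBuilder, the "bn" prefix, and the input shape are placeholders chosen only to make the snippet self-contained:

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::{batch_norm, VarBuilder};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Placeholder parameters: every tensor this builder hands out is zero-initialized.
        let vb = VarBuilder::zeros(DType::F32, &dev);
        let bn = batch_norm(5, 1e-5, vb.pp("bn"))?;

        // NCHW input whose channel count matches the 5 batch-norm features.
        let xs = Tensor::randn(0f32, 1f32, (2, 5, 3, 4), &dev)?;
        // Training mode: normalization uses the statistics of the current batch.
        let train_out = bn.forward_train(&xs)?;
        // Inference mode: the stored running mean and variance are used instead.
        let eval_out = xs.apply_t(&bn, false)?;
        assert_eq!(train_out.dims(), eval_out.dims());
        Ok(())
    }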

View File

@@ -147,7 +147,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
     let func = candle_nn::func(move |xs| {
         let xs = conv.forward(xs)?;
         let xs = match &bn {
-            Some(bn) => bn.forward(&xs)?,
+            Some(bn) => xs.apply_t(bn, false)?,
             None => xs,
         };
         let xs = if leaky {

View File

@@ -7,7 +7,7 @@
 //! running stats.
 //!
 //! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
-use candle::{DType, Module, Result, Tensor, Var};
+use candle::{DType, Result, Tensor, Var};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct BatchNormConfig {
@@ -192,7 +192,7 @@ impl BatchNorm {
         self.momentum
     }
 
-    pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
+    pub fn forward_train(&self, x: &Tensor) -> Result<Tensor> {
         let num_features = self.running_mean.as_tensor().dim(0)?;
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
@@ -252,17 +252,7 @@ impl BatchNorm {
         x.reshape(x_dims_post_transpose)?.transpose(0, 1)
     }
 
-    pub fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
-        if train {
-            self.forward_learning(x)
-        } else {
-            self.forward(x)
-        }
-    }
-}
-
-impl Module for BatchNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+    fn forward_eval(&self, x: &Tensor) -> Result<Tensor> {
         let target_shape: Vec<usize> = x
             .dims()
             .iter()
@@ -288,6 +278,16 @@ impl Module for BatchNorm {
     }
 }
 
+impl crate::ModuleT for BatchNorm {
+    fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
+        if train {
+            self.forward_train(x)
+        } else {
+            self.forward_eval(x)
+        }
+    }
+}
+
 pub fn batch_norm<C: Into<BatchNormConfig>>(
     num_features: usize,
     config: C,
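
Call sites that only ever run inference, like the model files in the hunks below, keep a Module-shaped value by fixing the train flag inside a candle_nn::func closure; that is the pattern the remaining changes apply. A hedged sketch of the pattern, not taken from the diff; the name eval_block and its layer sizes are made up for illustration:

    // Sketch: wraps a conv plus eval-mode batch norm behind a plain Module.
    fn eval_block(dim: usize, vb: candle_nn::VarBuilder) -> candle::Result<impl candle::Module> {
        let conv = candle_nn::conv2d(dim, dim, 1, Default::default(), vb.pp(0))?;
        let bn = candle_nn::batch_norm(dim, 1e-5, vb.pp(1))?;
        // The train flag is fixed to false, so the returned closure still
        // implements Module even though BatchNorm itself no longer does.
        Ok(candle_nn::func(move |xs| {
            xs.apply(&conv)?.apply_t(&bn, false)
        }))
    }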

View File

@@ -39,7 +39,7 @@ fn batch_norm() -> Result<()> {
         1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205,
     ];
     let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?;
-    let output = bn.forward_learning(&input)?;
+    let output = bn.forward_train(&input)?;
     assert_eq!(output.dims(), &[2, 5, 3, 4]);
     let output = output.flatten_all()?;
     assert_eq!(
@@ -67,7 +67,7 @@ fn batch_norm() -> Result<()> {
         Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?,
         1e-8,
     )?;
-    let output2 = bn2.forward_learning(&input)?;
+    let output2 = bn2.forward_train(&input)?;
     assert_eq!(output2.dims(), &[2, 5, 3, 4]);
     let output2 = output2.flatten_all()?;
     let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;

View File

@@ -40,8 +40,8 @@ fn block(dim: usize, kernel_size: usize, vb: VarBuilder) -> Result<impl Module>
     let conv2 = candle_nn::conv2d(dim, dim, 1, Default::default(), vb.pp(1))?;
     let bn2 = batch_norm(dim, 1e-5, vb.pp(3))?;
     Ok(candle_nn::func(move |xs| {
-        let ys = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
-        (xs + ys)?.apply(&conv2)?.gelu_erf()?.apply(&bn2)
+        let ys = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
+        (xs + ys)?.apply(&conv2)?.gelu_erf()?.apply_t(&bn2, false)
     }))
 }
@@ -64,7 +64,7 @@ fn convmixer(
         .collect::<Result<Vec<_>>>()?;
     let fc = candle_nn::linear(dim, nclasses, vb.pp(25))?;
     Ok(candle_nn::func(move |xs| {
-        let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
+        let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
         for block in blocks.iter() {
             xs = xs.apply(block)?
         }

View File

@@ -169,8 +169,7 @@ impl ConvNormActivation {
 
 impl Module for ConvNormActivation {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let xs = self.conv2d.forward(xs)?;
-        let xs = self.bn2d.forward(&xs)?;
+        let xs = self.conv2d.forward(xs)?.apply_t(&self.bn2d, false)?;
         if self.activation {
             swish(&xs)
         } else {

View File

@@ -25,7 +25,7 @@ fn downsample(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resul
     if stride != 1 || c_in != c_out {
         let conv = conv2d(c_in, c_out, 1, 0, stride, vb.pp(0))?;
         let bn = batch_norm(c_out, 1e-5, vb.pp(1))?;
-        Ok(Func::new(move |xs| xs.apply(&conv)?.apply(&bn)))
+        Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
     } else {
         Ok(Func::new(|xs| Ok(xs.clone())))
     }
@@ -40,10 +40,10 @@ fn basic_block(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resu
     Ok(Func::new(move |xs| {
         let ys = xs
             .apply(&conv1)?
-            .apply(&bn1)?
+            .apply_t(&bn1, false)?
             .relu()?
             .apply(&conv2)?
-            .apply(&bn2)?;
+            .apply_t(&bn2, false)?;
         (xs.apply(&downsample)? + ys)?.relu()
     }))
 }
@@ -94,7 +94,7 @@ fn resnet(
     Ok(Func::new(move |xs| {
         let xs = xs
             .apply(&conv1)?
-            .apply(&bn1)?
+            .apply_t(&bn1, false)?
             .relu()?
             .pad_with_same(D::Minus1, 1, 1)?
             .pad_with_same(D::Minus2, 1, 1)?
@@ -149,13 +149,13 @@ fn bottleneck_block(
     Ok(Func::new(move |xs| {
         let ys = xs
            .apply(&conv1)?
-            .apply(&bn1)?
+            .apply_t(&bn1, false)?
             .relu()?
             .apply(&conv2)?
-            .apply(&bn2)?
+            .apply_t(&bn2, false)?
             .relu()?
             .apply(&conv3)?
-            .apply(&bn3)?;
+            .apply_t(&bn3, false)?;
         (xs.apply(&downsample)? + ys)?.relu()
     }))
 }
@@ -206,7 +206,7 @@ fn bottleneck_resnet(
     Ok(Func::new(move |xs| {
         let xs = xs
             .apply(&conv1)?
-            .apply(&bn1)?
+            .apply_t(&bn1, false)?
             .relu()?
             .pad_with_same(D::Minus1, 1, 1)?
             .pad_with_same(D::Minus2, 1, 1)?

View File

@@ -28,7 +28,7 @@ impl Conv2dBN {
 
 impl Module for Conv2dBN {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        xs.apply(&self.c)?.apply(&self.bn)
+        xs.apply(&self.c)?.apply_t(&self.bn, false)
     }
 }

View File

@@ -185,7 +185,7 @@ impl PaellaVQ {
             xs = xs.apply(&down_block.1)?
         }
         xs.apply(&self.down_blocks_conv)?
-            .apply(&self.down_blocks_bn)
+            .apply_t(&self.down_blocks_bn, false)
     }
 
     pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {

View File

@@ -107,8 +107,7 @@ impl ConvBlock {
 
 impl Module for ConvBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let xs = self.conv.forward(xs)?;
-        let xs = self.bn.forward(&xs)?;
+        let xs = self.conv.forward(xs)?.apply_t(&self.bn, false)?;
         candle_nn::ops::silu(&xs)
     }
 }