Do not implement Module for BatchNorm. (#1513)

This commit is contained in:
Laurent Mazare 2024-01-01 10:13:13 +01:00 committed by GitHub
parent 1fb2dd905c
commit b0fe5e4453
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 31 additions and 33 deletions

View File

@ -147,7 +147,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
let func = candle_nn::func(move |xs| {
let xs = conv.forward(xs)?;
let xs = match &bn {
Some(bn) => bn.forward(&xs)?,
Some(bn) => xs.apply_t(bn, false)?,
None => xs,
};
let xs = if leaky {

View File

@ -7,7 +7,7 @@
//! running stats.
//!
//! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
use candle::{DType, Module, Result, Tensor, Var};
use candle::{DType, Result, Tensor, Var};
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BatchNormConfig {
@ -192,7 +192,7 @@ impl BatchNorm {
self.momentum
}
pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
pub fn forward_train(&self, x: &Tensor) -> Result<Tensor> {
let num_features = self.running_mean.as_tensor().dim(0)?;
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
@ -252,17 +252,7 @@ impl BatchNorm {
x.reshape(x_dims_post_transpose)?.transpose(0, 1)
}
pub fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
if train {
self.forward_learning(x)
} else {
self.forward(x)
}
}
}
impl Module for BatchNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
fn forward_eval(&self, x: &Tensor) -> Result<Tensor> {
let target_shape: Vec<usize> = x
.dims()
.iter()
@ -288,6 +278,16 @@ impl Module for BatchNorm {
}
}
impl crate::ModuleT for BatchNorm {
fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
if train {
self.forward_train(x)
} else {
self.forward_eval(x)
}
}
}
pub fn batch_norm<C: Into<BatchNormConfig>>(
num_features: usize,
config: C,

View File

@ -39,7 +39,7 @@ fn batch_norm() -> Result<()> {
1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205,
];
let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?;
let output = bn.forward_learning(&input)?;
let output = bn.forward_train(&input)?;
assert_eq!(output.dims(), &[2, 5, 3, 4]);
let output = output.flatten_all()?;
assert_eq!(
@ -67,7 +67,7 @@ fn batch_norm() -> Result<()> {
Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?,
1e-8,
)?;
let output2 = bn2.forward_learning(&input)?;
let output2 = bn2.forward_train(&input)?;
assert_eq!(output2.dims(), &[2, 5, 3, 4]);
let output2 = output2.flatten_all()?;
let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;

View File

@ -40,8 +40,8 @@ fn block(dim: usize, kernel_size: usize, vb: VarBuilder) -> Result<impl Module>
let conv2 = candle_nn::conv2d(dim, dim, 1, Default::default(), vb.pp(1))?;
let bn2 = batch_norm(dim, 1e-5, vb.pp(3))?;
Ok(candle_nn::func(move |xs| {
let ys = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
(xs + ys)?.apply(&conv2)?.gelu_erf()?.apply(&bn2)
let ys = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
(xs + ys)?.apply(&conv2)?.gelu_erf()?.apply_t(&bn2, false)
}))
}
@ -64,7 +64,7 @@ fn convmixer(
.collect::<Result<Vec<_>>>()?;
let fc = candle_nn::linear(dim, nclasses, vb.pp(25))?;
Ok(candle_nn::func(move |xs| {
let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
for block in blocks.iter() {
xs = xs.apply(block)?
}

View File

@ -169,8 +169,7 @@ impl ConvNormActivation {
impl Module for ConvNormActivation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.conv2d.forward(xs)?;
let xs = self.bn2d.forward(&xs)?;
let xs = self.conv2d.forward(xs)?.apply_t(&self.bn2d, false)?;
if self.activation {
swish(&xs)
} else {

View File

@ -25,7 +25,7 @@ fn downsample(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resul
if stride != 1 || c_in != c_out {
let conv = conv2d(c_in, c_out, 1, 0, stride, vb.pp(0))?;
let bn = batch_norm(c_out, 1e-5, vb.pp(1))?;
Ok(Func::new(move |xs| xs.apply(&conv)?.apply(&bn)))
Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
} else {
Ok(Func::new(|xs| Ok(xs.clone())))
}
@ -40,10 +40,10 @@ fn basic_block(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resu
Ok(Func::new(move |xs| {
let ys = xs
.apply(&conv1)?
.apply(&bn1)?
.apply_t(&bn1, false)?
.relu()?
.apply(&conv2)?
.apply(&bn2)?;
.apply_t(&bn2, false)?;
(xs.apply(&downsample)? + ys)?.relu()
}))
}
@ -94,7 +94,7 @@ fn resnet(
Ok(Func::new(move |xs| {
let xs = xs
.apply(&conv1)?
.apply(&bn1)?
.apply_t(&bn1, false)?
.relu()?
.pad_with_same(D::Minus1, 1, 1)?
.pad_with_same(D::Minus2, 1, 1)?
@ -149,13 +149,13 @@ fn bottleneck_block(
Ok(Func::new(move |xs| {
let ys = xs
.apply(&conv1)?
.apply(&bn1)?
.apply_t(&bn1, false)?
.relu()?
.apply(&conv2)?
.apply(&bn2)?
.apply_t(&bn2, false)?
.relu()?
.apply(&conv3)?
.apply(&bn3)?;
.apply_t(&bn3, false)?;
(xs.apply(&downsample)? + ys)?.relu()
}))
}
@ -206,7 +206,7 @@ fn bottleneck_resnet(
Ok(Func::new(move |xs| {
let xs = xs
.apply(&conv1)?
.apply(&bn1)?
.apply_t(&bn1, false)?
.relu()?
.pad_with_same(D::Minus1, 1, 1)?
.pad_with_same(D::Minus2, 1, 1)?

View File

@ -28,7 +28,7 @@ impl Conv2dBN {
impl Module for Conv2dBN {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
xs.apply(&self.c)?.apply(&self.bn)
xs.apply(&self.c)?.apply_t(&self.bn, false)
}
}

View File

@ -185,7 +185,7 @@ impl PaellaVQ {
xs = xs.apply(&down_block.1)?
}
xs.apply(&self.down_blocks_conv)?
.apply(&self.down_blocks_bn)
.apply_t(&self.down_blocks_bn, false)
}
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {

View File

@ -107,8 +107,7 @@ impl ConvBlock {
impl Module for ConvBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.conv.forward(xs)?;
let xs = self.bn.forward(&xs)?;
let xs = self.conv.forward(xs)?.apply_t(&self.bn, false)?;
candle_nn::ops::silu(&xs)
}
}