Do not implement Module for BatchNorm. (#1513)
This commit is contained in:
parent
1fb2dd905c
commit
b0fe5e4453
|
@ -147,7 +147,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
|
|||
let func = candle_nn::func(move |xs| {
|
||||
let xs = conv.forward(xs)?;
|
||||
let xs = match &bn {
|
||||
Some(bn) => bn.forward(&xs)?,
|
||||
Some(bn) => xs.apply_t(bn, false)?,
|
||||
None => xs,
|
||||
};
|
||||
let xs = if leaky {
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
//! running stats.
|
||||
//!
|
||||
//! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
|
||||
use candle::{DType, Module, Result, Tensor, Var};
|
||||
use candle::{DType, Result, Tensor, Var};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub struct BatchNormConfig {
|
||||
|
@ -192,7 +192,7 @@ impl BatchNorm {
|
|||
self.momentum
|
||||
}
|
||||
|
||||
pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
|
||||
pub fn forward_train(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let num_features = self.running_mean.as_tensor().dim(0)?;
|
||||
let x_dtype = x.dtype();
|
||||
let internal_dtype = match x_dtype {
|
||||
|
@ -252,17 +252,7 @@ impl BatchNorm {
|
|||
x.reshape(x_dims_post_transpose)?.transpose(0, 1)
|
||||
}
|
||||
|
||||
pub fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
|
||||
if train {
|
||||
self.forward_learning(x)
|
||||
} else {
|
||||
self.forward(x)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for BatchNorm {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
fn forward_eval(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let target_shape: Vec<usize> = x
|
||||
.dims()
|
||||
.iter()
|
||||
|
@ -288,6 +278,16 @@ impl Module for BatchNorm {
|
|||
}
|
||||
}
|
||||
|
||||
impl crate::ModuleT for BatchNorm {
|
||||
fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
|
||||
if train {
|
||||
self.forward_train(x)
|
||||
} else {
|
||||
self.forward_eval(x)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn batch_norm<C: Into<BatchNormConfig>>(
|
||||
num_features: usize,
|
||||
config: C,
|
||||
|
|
|
@ -39,7 +39,7 @@ fn batch_norm() -> Result<()> {
|
|||
1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205,
|
||||
];
|
||||
let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?;
|
||||
let output = bn.forward_learning(&input)?;
|
||||
let output = bn.forward_train(&input)?;
|
||||
assert_eq!(output.dims(), &[2, 5, 3, 4]);
|
||||
let output = output.flatten_all()?;
|
||||
assert_eq!(
|
||||
|
@ -67,7 +67,7 @@ fn batch_norm() -> Result<()> {
|
|||
Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?,
|
||||
1e-8,
|
||||
)?;
|
||||
let output2 = bn2.forward_learning(&input)?;
|
||||
let output2 = bn2.forward_train(&input)?;
|
||||
assert_eq!(output2.dims(), &[2, 5, 3, 4]);
|
||||
let output2 = output2.flatten_all()?;
|
||||
let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;
|
||||
|
|
|
@ -40,8 +40,8 @@ fn block(dim: usize, kernel_size: usize, vb: VarBuilder) -> Result<impl Module>
|
|||
let conv2 = candle_nn::conv2d(dim, dim, 1, Default::default(), vb.pp(1))?;
|
||||
let bn2 = batch_norm(dim, 1e-5, vb.pp(3))?;
|
||||
Ok(candle_nn::func(move |xs| {
|
||||
let ys = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
|
||||
(xs + ys)?.apply(&conv2)?.gelu_erf()?.apply(&bn2)
|
||||
let ys = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
|
||||
(xs + ys)?.apply(&conv2)?.gelu_erf()?.apply_t(&bn2, false)
|
||||
}))
|
||||
}
|
||||
|
||||
|
@ -64,7 +64,7 @@ fn convmixer(
|
|||
.collect::<Result<Vec<_>>>()?;
|
||||
let fc = candle_nn::linear(dim, nclasses, vb.pp(25))?;
|
||||
Ok(candle_nn::func(move |xs| {
|
||||
let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
|
||||
let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
|
||||
for block in blocks.iter() {
|
||||
xs = xs.apply(block)?
|
||||
}
|
||||
|
|
|
@ -169,8 +169,7 @@ impl ConvNormActivation {
|
|||
|
||||
impl Module for ConvNormActivation {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.conv2d.forward(xs)?;
|
||||
let xs = self.bn2d.forward(&xs)?;
|
||||
let xs = self.conv2d.forward(xs)?.apply_t(&self.bn2d, false)?;
|
||||
if self.activation {
|
||||
swish(&xs)
|
||||
} else {
|
||||
|
|
|
@ -25,7 +25,7 @@ fn downsample(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resul
|
|||
if stride != 1 || c_in != c_out {
|
||||
let conv = conv2d(c_in, c_out, 1, 0, stride, vb.pp(0))?;
|
||||
let bn = batch_norm(c_out, 1e-5, vb.pp(1))?;
|
||||
Ok(Func::new(move |xs| xs.apply(&conv)?.apply(&bn)))
|
||||
Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
|
||||
} else {
|
||||
Ok(Func::new(|xs| Ok(xs.clone())))
|
||||
}
|
||||
|
@ -40,10 +40,10 @@ fn basic_block(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resu
|
|||
Ok(Func::new(move |xs| {
|
||||
let ys = xs
|
||||
.apply(&conv1)?
|
||||
.apply(&bn1)?
|
||||
.apply_t(&bn1, false)?
|
||||
.relu()?
|
||||
.apply(&conv2)?
|
||||
.apply(&bn2)?;
|
||||
.apply_t(&bn2, false)?;
|
||||
(xs.apply(&downsample)? + ys)?.relu()
|
||||
}))
|
||||
}
|
||||
|
@ -94,7 +94,7 @@ fn resnet(
|
|||
Ok(Func::new(move |xs| {
|
||||
let xs = xs
|
||||
.apply(&conv1)?
|
||||
.apply(&bn1)?
|
||||
.apply_t(&bn1, false)?
|
||||
.relu()?
|
||||
.pad_with_same(D::Minus1, 1, 1)?
|
||||
.pad_with_same(D::Minus2, 1, 1)?
|
||||
|
@ -149,13 +149,13 @@ fn bottleneck_block(
|
|||
Ok(Func::new(move |xs| {
|
||||
let ys = xs
|
||||
.apply(&conv1)?
|
||||
.apply(&bn1)?
|
||||
.apply_t(&bn1, false)?
|
||||
.relu()?
|
||||
.apply(&conv2)?
|
||||
.apply(&bn2)?
|
||||
.apply_t(&bn2, false)?
|
||||
.relu()?
|
||||
.apply(&conv3)?
|
||||
.apply(&bn3)?;
|
||||
.apply_t(&bn3, false)?;
|
||||
(xs.apply(&downsample)? + ys)?.relu()
|
||||
}))
|
||||
}
|
||||
|
@ -206,7 +206,7 @@ fn bottleneck_resnet(
|
|||
Ok(Func::new(move |xs| {
|
||||
let xs = xs
|
||||
.apply(&conv1)?
|
||||
.apply(&bn1)?
|
||||
.apply_t(&bn1, false)?
|
||||
.relu()?
|
||||
.pad_with_same(D::Minus1, 1, 1)?
|
||||
.pad_with_same(D::Minus2, 1, 1)?
|
||||
|
|
|
@ -28,7 +28,7 @@ impl Conv2dBN {
|
|||
impl Module for Conv2dBN {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
xs.apply(&self.c)?.apply(&self.bn)
|
||||
xs.apply(&self.c)?.apply_t(&self.bn, false)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -185,7 +185,7 @@ impl PaellaVQ {
|
|||
xs = xs.apply(&down_block.1)?
|
||||
}
|
||||
xs.apply(&self.down_blocks_conv)?
|
||||
.apply(&self.down_blocks_bn)
|
||||
.apply_t(&self.down_blocks_bn, false)
|
||||
}
|
||||
|
||||
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
|
|
|
@ -107,8 +107,7 @@ impl ConvBlock {
|
|||
|
||||
impl Module for ConvBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.conv.forward(xs)?;
|
||||
let xs = self.bn.forward(&xs)?;
|
||||
let xs = self.conv.forward(xs)?.apply_t(&self.bn, false)?;
|
||||
candle_nn::ops::silu(&xs)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue