diff --git a/candle-examples/examples/yolo-v3/darknet.rs b/candle-examples/examples/yolo-v3/darknet.rs
index 6702618e..b1dd3127 100644
--- a/candle-examples/examples/yolo-v3/darknet.rs
+++ b/candle-examples/examples/yolo-v3/darknet.rs
@@ -147,7 +147,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
     let func = candle_nn::func(move |xs| {
         let xs = conv.forward(xs)?;
         let xs = match &bn {
-            Some(bn) => bn.forward(&xs)?,
+            Some(bn) => xs.apply_t(bn, false)?,
             None => xs,
         };
         let xs = if leaky {
diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs
index 1782e47a..856c2c7a 100644
--- a/candle-nn/src/batch_norm.rs
+++ b/candle-nn/src/batch_norm.rs
@@ -7,7 +7,7 @@
 //! running stats.
 //!
 //! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
-use candle::{DType, Module, Result, Tensor, Var};
+use candle::{DType, Result, Tensor, Var};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct BatchNormConfig {
@@ -192,7 +192,7 @@
         self.momentum
     }
 
-    pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
+    pub fn forward_train(&self, x: &Tensor) -> Result<Tensor> {
         let num_features = self.running_mean.as_tensor().dim(0)?;
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
@@ -252,17 +252,7 @@
         x.reshape(x_dims_post_transpose)?.transpose(0, 1)
     }
 
-    pub fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
-        if train {
-            self.forward_learning(x)
-        } else {
-            self.forward(x)
-        }
-    }
-}
-
-impl Module for BatchNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+    fn forward_eval(&self, x: &Tensor) -> Result<Tensor> {
         let target_shape: Vec<usize> = x
             .dims()
             .iter()
@@ -288,6 +278,16 @@
     }
 }
 
+impl crate::ModuleT for BatchNorm {
+    fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
+        if train {
+            self.forward_train(x)
+        } else {
+            self.forward_eval(x)
+        }
+    }
+}
+
 pub fn batch_norm<C: Into<BatchNormConfig>>(
     num_features: usize,
     config: C,
diff --git a/candle-nn/tests/batch_norm.rs b/candle-nn/tests/batch_norm.rs
index 73a38545..6fd7361a 100644
--- a/candle-nn/tests/batch_norm.rs
+++ b/candle-nn/tests/batch_norm.rs
@@ -39,7 +39,7 @@ fn batch_norm() -> Result<()> {
         1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205,
     ];
     let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?;
-    let output = bn.forward_learning(&input)?;
+    let output = bn.forward_train(&input)?;
     assert_eq!(output.dims(), &[2, 5, 3, 4]);
     let output = output.flatten_all()?;
     assert_eq!(
@@ -67,7 +67,7 @@
         Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?,
         1e-8,
     )?;
-    let output2 = bn2.forward_learning(&input)?;
+    let output2 = bn2.forward_train(&input)?;
     assert_eq!(output2.dims(), &[2, 5, 3, 4]);
     let output2 = output2.flatten_all()?;
     let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;
diff --git a/candle-transformers/src/models/convmixer.rs b/candle-transformers/src/models/convmixer.rs
index 76245f37..f5abfa5d 100644
--- a/candle-transformers/src/models/convmixer.rs
+++ b/candle-transformers/src/models/convmixer.rs
@@ -40,8 +40,8 @@ fn block(dim: usize, kernel_size: usize, vb: VarBuilder) -> Result
     let conv2 = candle_nn::conv2d(dim, dim, 1, Default::default(), vb.pp(1))?;
     let bn2 = batch_norm(dim, 1e-5, vb.pp(3))?;
     Ok(candle_nn::func(move |xs| {
-        let ys = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
-        (xs + ys)?.apply(&conv2)?.gelu_erf()?.apply(&bn2)
+        let ys = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
+        (xs + ys)?.apply(&conv2)?.gelu_erf()?.apply_t(&bn2, false)
     }))
 }
 
@@ -64,7 +64,7 @@ fn convmixer(
         .collect::<Result<Vec<_>>>()?;
     let fc = candle_nn::linear(dim, nclasses, vb.pp(25))?;
     Ok(candle_nn::func(move |xs| {
-        let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
+        let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
         for block in blocks.iter() {
             xs = xs.apply(block)?
         }
diff --git a/candle-transformers/src/models/efficientnet.rs b/candle-transformers/src/models/efficientnet.rs
index ab51c76d..f15c9c79 100644
--- a/candle-transformers/src/models/efficientnet.rs
+++ b/candle-transformers/src/models/efficientnet.rs
@@ -169,8 +169,7 @@ impl ConvNormActivation {
 
 impl Module for ConvNormActivation {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let xs = self.conv2d.forward(xs)?;
-        let xs = self.bn2d.forward(&xs)?;
+        let xs = self.conv2d.forward(xs)?.apply_t(&self.bn2d, false)?;
         if self.activation {
             swish(&xs)
         } else {
diff --git a/candle-transformers/src/models/resnet.rs b/candle-transformers/src/models/resnet.rs
index f2588e01..30029a0b 100644
--- a/candle-transformers/src/models/resnet.rs
+++ b/candle-transformers/src/models/resnet.rs
@@ -25,7 +25,7 @@ fn downsample(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resul
     if stride != 1 || c_in != c_out {
         let conv = conv2d(c_in, c_out, 1, 0, stride, vb.pp(0))?;
         let bn = batch_norm(c_out, 1e-5, vb.pp(1))?;
-        Ok(Func::new(move |xs| xs.apply(&conv)?.apply(&bn)))
+        Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
     } else {
         Ok(Func::new(|xs| Ok(xs.clone())))
     }
@@ -40,10 +40,10 @@ fn basic_block(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resu
     Ok(Func::new(move |xs| {
         let ys = xs
             .apply(&conv1)?
-            .apply(&bn1)?
+            .apply_t(&bn1, false)?
             .relu()?
             .apply(&conv2)?
-            .apply(&bn2)?;
+            .apply_t(&bn2, false)?;
         (xs.apply(&downsample)? + ys)?.relu()
     }))
 }
@@ -94,7 +94,7 @@ fn resnet(
     Ok(Func::new(move |xs| {
         let xs = xs
             .apply(&conv1)?
-            .apply(&bn1)?
+            .apply_t(&bn1, false)?
             .relu()?
             .pad_with_same(D::Minus1, 1, 1)?
             .pad_with_same(D::Minus2, 1, 1)?
@@ -149,13 +149,13 @@ fn bottleneck_block(
     Ok(Func::new(move |xs| {
         let ys = xs
             .apply(&conv1)?
-            .apply(&bn1)?
+            .apply_t(&bn1, false)?
             .relu()?
             .apply(&conv2)?
-            .apply(&bn2)?
+            .apply_t(&bn2, false)?
             .relu()?
             .apply(&conv3)?
-            .apply(&bn3)?;
+            .apply_t(&bn3, false)?;
         (xs.apply(&downsample)? + ys)?.relu()
     }))
 }
@@ -206,7 +206,7 @@ fn bottleneck_resnet(
     Ok(Func::new(move |xs| {
         let xs = xs
             .apply(&conv1)?
-            .apply(&bn1)?
+            .apply_t(&bn1, false)?
             .relu()?
             .pad_with_same(D::Minus1, 1, 1)?
             .pad_with_same(D::Minus2, 1, 1)?
diff --git a/candle-transformers/src/models/segment_anything/tiny_vit.rs b/candle-transformers/src/models/segment_anything/tiny_vit.rs
index cd2936ab..d1700cc5 100644
--- a/candle-transformers/src/models/segment_anything/tiny_vit.rs
+++ b/candle-transformers/src/models/segment_anything/tiny_vit.rs
@@ -28,7 +28,7 @@ impl Conv2dBN {
 impl Module for Conv2dBN {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        xs.apply(&self.c)?.apply(&self.bn)
+        xs.apply(&self.c)?.apply_t(&self.bn, false)
     }
 }
 
diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs
index 4a69cca0..58f795bb 100644
--- a/candle-transformers/src/models/wuerstchen/paella_vq.rs
+++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs
@@ -185,7 +185,7 @@ impl PaellaVQ {
             xs = xs.apply(&down_block.1)?
         }
         xs.apply(&self.down_blocks_conv)?
-            .apply(&self.down_blocks_bn)
+            .apply_t(&self.down_blocks_bn, false)
     }
 
     pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
diff --git a/candle-wasm-examples/yolo/src/model.rs b/candle-wasm-examples/yolo/src/model.rs
index d49cf55f..f1d7ea20 100644
--- a/candle-wasm-examples/yolo/src/model.rs
+++ b/candle-wasm-examples/yolo/src/model.rs
@@ -107,8 +107,7 @@ impl ConvBlock {
 
 impl Module for ConvBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let xs = self.conv.forward(xs)?;
-        let xs = self.bn.forward(&xs)?;
+        let xs = self.conv.forward(xs)?.apply_t(&self.bn, false)?;
        candle_nn::ops::silu(&xs)
     }
 }
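
Usage note (not part of the patch): after this change `BatchNorm` no longer implements `Module`; callers go through the `ModuleT` trait, whose boolean selects between `forward_train` (batch statistics, running stats updated) and the private `forward_eval` path (running statistics), either directly or via `Tensor::apply_t` as in the call sites above. A minimal sketch, assuming the `candle`/`candle-nn` items used in the patch plus `VarBuilder::zeros` as a stand-in for real weights:

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::{batch_norm, BatchNorm, ModuleT, VarBuilder};

    fn main() -> Result<()> {
        // All-zero weights/running stats, only so the example is self-contained.
        let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
        let bn: BatchNorm = batch_norm(5, 1e-5, vb)?;
        // NCHW input; the shape is arbitrary, chosen for illustration.
        let xs = Tensor::randn(0f32, 1f32, (2, 5, 3, 4), &Device::Cpu)?;
        let train_out = bn.forward_t(&xs, true)?; // training path (batch stats)
        let eval_out = xs.apply_t(&bn, false)?; // inference path, as in the call sites above
        assert_eq!(train_out.dims(), eval_out.dims());
        Ok(())
    }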