Add the upblocks. (#853)

This commit is contained in:
Laurent Mazare 2023-09-14 23:24:56 +02:00 committed by GitHub
parent 91ec546feb
commit 130fe5a087
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 63 additions and 5 deletions

View File

@ -110,7 +110,7 @@ impl ToUsize2 for (usize, usize) {
}
// A simple trait defining a module with forward method using a single argument.
pub trait Module: std::fmt::Debug {
pub trait Module {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}
@ -119,3 +119,9 @@ impl Module for quantized::QMatMul {
self.forward(xs)
}
}
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self(xs)
}
}

View File

@ -44,6 +44,10 @@ pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
(xs.neg()?.exp()? + 1.0)?.recip()
}
pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
xs.relu()?.minimum(&(xs * negative_slope)?)
}
pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
// This implementation is inefficient as it stores the full mask for the backward pass.
// Instead we could just store the seed and have a specialized kernel that would both

View File

@ -161,8 +161,57 @@ impl WDiffNeXt {
down_blocks.push(down_block)
}
// TODO: populate.
let up_blocks = Vec::with_capacity(C_HIDDEN.len());
let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
let vb = vb.pp("up_blocks").pp(i);
let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
let mut layer_i = 0;
for j in 0..BLOCKS[i] {
let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
let c_skip_res = if i < BLOCKS.len() - 1 && j == 0 {
c_hidden + c_skip
} else {
c_skip
};
let res_block = ResBlockStageB::new(c_hidden, c_skip_res, 3, vb.pp(layer_i))?;
layer_i += 1;
let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
layer_i += 1;
let attn_block = if j == 0 {
None
} else {
let attn_block =
AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
layer_i += 1;
Some(attn_block)
};
let sub_block = SubBlock {
res_block,
ts_block,
attn_block,
};
sub_blocks.push(sub_block)
}
let (layer_norm, conv, start_layer_i) = if i > 0 {
let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(layer_i))?;
layer_i += 1;
let cfg = candle_nn::Conv2dConfig {
stride: 2,
..Default::default()
};
let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(layer_i))?;
layer_i += 1;
(Some(layer_norm), Some(conv), 2)
} else {
(None, None, 0)
};
let up_block = UpBlock {
layer_norm,
conv,
sub_blocks,
};
up_blocks.push(up_block)
}
let clf_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("clf.0"))?;
let clf_conv = candle_nn::conv2d(

View File

@ -85,10 +85,9 @@ impl WPrior {
pub fn forward(&self, xs: &Tensor, r: &Tensor, c: &Tensor) -> Result<Tensor> {
let x_in = xs;
let mut xs = xs.apply(&self.projection)?;
// TODO: leaky relu
let c_embed = c
.apply(&self.cond_mapper_lin1)?
.relu()?
.apply(&|xs: &_| candle_nn::ops::leaky_relu(xs, 0.2))?
.apply(&self.cond_mapper_lin2)?;
let r_embed = self.gen_r_embedding(r)?;
for block in self.blocks.iter() {