Remove some dead-code annotations. (#629)
* Remove some dead-code annotations. * More dead code removal. * One more. * CI fix.
This commit is contained in:
parent
a3f97c143d
commit
72ebb12bca
|
@ -2,19 +2,16 @@
|
|||
// own forward pass (CPU and GPU versions) as well as their backward pass.
|
||||
//
|
||||
// In this example we add the RMS normalization operation and implement it for f32.
|
||||
#![allow(dead_code)]
|
||||
#![allow(unused)]
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[allow(unused)]
|
||||
mod cuda_kernels;
|
||||
|
||||
use clap::Parser;
|
||||
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::cpu_backend;
|
||||
use candle::{CpuStorage, CustomOp1, DType, Device, Layout, Result, Shape, Tensor};
|
||||
use candle::{CpuStorage, CustomOp1, Layout, Result, Shape, Tensor};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
|
@ -57,8 +54,9 @@ impl CustomOp1 for LayerNorm {
|
|||
storage: &candle::CudaStorage,
|
||||
layout: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::{cudarc, WrapErr};
|
||||
use cudarc::driver::{LaunchAsync, LaunchConfig};
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::cuda_backend::cudarc::driver::{LaunchAsync, LaunchConfig};
|
||||
use candle::cuda_backend::WrapErr;
|
||||
let (d1, d2) = layout.shape().dims2()?;
|
||||
let d1 = d1 as u32;
|
||||
let d2 = d2 as u32;
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
#![allow(dead_code)]
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
|
@ -185,16 +184,6 @@ struct ModelWeights {
|
|||
span_output: tracing::Span,
|
||||
}
|
||||
|
||||
struct WeightMap(HashMap<String, QTensor>);
|
||||
impl WeightMap {
|
||||
fn get(&mut self, name: &str) -> Result<QTensor> {
|
||||
match self.0.remove(name) {
|
||||
None => candle::bail!("cannot find tensor with name '{name}'"),
|
||||
Some(tensor) => Ok(tensor),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
|
||||
let theta: Vec<_> = (0..head_dim)
|
||||
.step_by(2)
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
#![allow(dead_code)]
|
||||
//! Contrastive Language-Image Pre-Training
|
||||
//!
|
||||
//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
#![allow(dead_code)]
|
||||
//! # Denoising Diffusion Implicit Models
|
||||
//!
|
||||
//! The Denoising Diffusion Implicit Models (DDIM) is a simple scheduler
|
||||
|
@ -164,17 +163,6 @@ impl DDIMScheduler {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
|
||||
let timestep = if timestep >= self.alphas_cumprod.len() {
|
||||
timestep - 1
|
||||
} else {
|
||||
timestep
|
||||
};
|
||||
let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt();
|
||||
let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt();
|
||||
(original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
|
||||
}
|
||||
|
||||
pub fn init_noise_sigma(&self) -> f64 {
|
||||
self.init_noise_sigma
|
||||
}
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
#![allow(dead_code)]
|
||||
use candle::{Result, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
use candle_nn::Module;
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
#![allow(dead_code)]
|
||||
//! ResNet Building Blocks
|
||||
//!
|
||||
//! Some Residual Network blocks used in UNet models.
|
||||
|
|
|
@ -19,34 +19,6 @@ pub struct Config {
|
|||
pub suppress_tokens: Vec<u32>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
#[allow(dead_code)]
|
||||
pub fn tiny_en() -> Self {
|
||||
let suppress_tokens = vec![
|
||||
1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93,
|
||||
357, 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391,
|
||||
1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329,
|
||||
7203, 9609, 9959, 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306,
|
||||
16410, 16791, 17992, 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409,
|
||||
34949, 40283, 40493, 40549, 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361,
|
||||
50362,
|
||||
];
|
||||
Self {
|
||||
num_mel_bins: 80,
|
||||
vocab_size: 51864,
|
||||
max_source_positions: 1500,
|
||||
d_model: 384,
|
||||
encoder_attention_heads: 6,
|
||||
encoder_layers: 4,
|
||||
max_target_positions: 448,
|
||||
// n_text_state: 384,
|
||||
decoder_attention_heads: 6,
|
||||
decoder_layers: 4,
|
||||
suppress_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
|
|
|
@ -675,7 +675,6 @@ pub struct YoloV8Pose {
|
|||
head: PoseHead,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
impl YoloV8Pose {
|
||||
pub fn load(
|
||||
vb: VarBuilder,
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
#![allow(dead_code)]
|
||||
|
||||
pub const WITH_TIMER: bool = true;
|
||||
|
||||
struct Timer {
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
#![allow(dead_code)]
|
||||
// We use anyhow rather than candle errors as it provides better support for getting the backtrace
|
||||
// back when using RUST_LIB_BACKTRACE=1.
|
||||
use anyhow::Result;
|
||||
|
@ -97,32 +96,6 @@ fn conv1d(
|
|||
Ok(Conv1d::new(weight, Some(bias), config))
|
||||
}
|
||||
|
||||
fn conv1d_no_bias(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
kernel_size: usize,
|
||||
config: Conv1dConfig,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Conv1d> {
|
||||
let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
|
||||
Ok(Conv1d::new(weight, None, config))
|
||||
}
|
||||
|
||||
struct Dropout {
|
||||
pr: f64,
|
||||
}
|
||||
|
||||
impl Dropout {
|
||||
fn new(pr: f64) -> Self {
|
||||
Self { pr }
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
// TODO
|
||||
Ok(x.clone())
|
||||
}
|
||||
}
|
||||
|
||||
fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let weight = vb.get(size, "weight")?;
|
||||
let bias = vb.get(size, "bias")?;
|
||||
|
@ -414,10 +387,4 @@ impl Whisper {
|
|||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
|
||||
let enc = self.encoder.forward(mel)?;
|
||||
let dec = self.decoder.forward(tokens, &enc)?;
|
||||
Ok(dec)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,9 +33,6 @@ pub const HOP_LENGTH: usize = 160;
|
|||
pub const CHUNK_LENGTH: usize = 30;
|
||||
pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
|
||||
pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
|
||||
pub const N_SAMPLES_PER_TOKEN: usize = HOP_LENGTH * 2; // the initial convolutions has stride 2
|
||||
pub const FRAMES_PER_SECOND: usize = SAMPLE_RATE / HOP_LENGTH; // 10ms per audio frame
|
||||
pub const TOKENS_PER_SECOND: usize = SAMPLE_RATE / N_SAMPLES_PER_TOKEN; // 20ms per audio token
|
||||
|
||||
pub const NO_SPEECH_THRESHOLD: f64 = 0.6;
|
||||
pub const LOGPROB_THRESHOLD: f64 = -1.0;
|
||||
|
@ -46,7 +43,6 @@ pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
|
|||
pub const SOT_TOKEN: u32 = 50257;
|
||||
pub const EOT_TOKEN: u32 = 50256;
|
||||
pub const NO_SPEECH_TOKEN: u32 = 50361;
|
||||
pub const NO_TIMESTAMP_TOKEN: u32 = 50362;
|
||||
// From the _get_suppress_tokens function + 50362 (no timestamp)
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
|
||||
pub const SUPPRESS_TOKENS: [u32; 91] = [
|
||||
|
|
|
@ -1,13 +1,9 @@
|
|||
#![allow(dead_code)]
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder,
|
||||
};
|
||||
use image::DynamicImage;
|
||||
|
||||
const CONFIDENCE_THRESHOLD: f32 = 0.25;
|
||||
const NMS_THRESHOLD: f32 = 0.45;
|
||||
|
||||
// Model architecture from https://github.com/ultralytics/ultralytics/issues/189
|
||||
// https://github.com/tinygrad/tinygrad/blob/master/examples/yolov8.py
|
||||
|
||||
|
|
Loading…
Reference in New Issue