Fix lints for clippy 1.75. (#1494)

This commit is contained in:
Laurent Mazare 2023-12-28 20:26:20 +01:00 committed by GitHub
parent cd889c0f8a
commit 1e442d4bb9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 38 additions and 40 deletions

View File

@ -478,23 +478,6 @@ extract_dims!(
(usize, usize, usize, usize, usize) (usize, usize, usize, usize, usize)
); );
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn stride() {
let shape = Shape::from(());
assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
let shape = Shape::from(42);
assert_eq!(shape.stride_contiguous(), [1]);
let shape = Shape::from((42, 1337));
assert_eq!(shape.stride_contiguous(), [1337, 1]);
let shape = Shape::from((299, 792, 458));
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
}
}
pub trait ShapeWithOneHole { pub trait ShapeWithOneHole {
fn into_shape(self, el_count: usize) -> Result<Shape>; fn into_shape(self, el_count: usize) -> Result<Shape>;
} }
@ -627,3 +610,20 @@ impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
Ok((d1, d2, d3, d4, d).into()) Ok((d1, d2, d3, d4, d).into())
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn stride() {
let shape = Shape::from(());
assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
let shape = Shape::from(42);
assert_eq!(shape.stride_contiguous(), [1]);
let shape = Shape::from((42, 1337));
assert_eq!(shape.stride_contiguous(), [1337, 1]);
let shape = Shape::from((299, 792, 458));
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
}
}

View File

@ -321,7 +321,7 @@ impl MusicgenDecoder {
let positions = self.embed_positions.forward(&input)?.to_device(dev)?; let positions = self.embed_positions.forward(&input)?.to_device(dev)?;
let mut xs = inputs_embeds.broadcast_add(&positions)?; let mut xs = inputs_embeds.broadcast_add(&positions)?;
let attention_mask = self.prepare_decoder_attention_mask(b_sz, seq_len)?; let attention_mask = self.prepare_decoder_attention_mask(b_sz, seq_len)?;
for (_layer_idx, decoder_layer) in self.layers.iter_mut().enumerate() { for decoder_layer in self.layers.iter_mut() {
xs = decoder_layer.forward(&xs, &attention_mask, None)?; xs = decoder_layer.forward(&xs, &attention_mask, None)?;
} }
let xs = self.layer_norm.forward(&xs)?; let xs = self.layer_norm.forward(&xs)?;

View File

@ -184,7 +184,7 @@ impl Sam {
let labels = Tensor::from_vec(labels, (1, n_points), img_embeddings.device())?; let labels = Tensor::from_vec(labels, (1, n_points), img_embeddings.device())?;
Some((points, labels)) Some((points, labels))
}; };
let points = points.as_ref().map(|(x, y)| (x, y)); let points = points.as_ref().map(|xy| (&xy.0, &xy.1));
let (sparse_prompt_embeddings, dense_prompt_embeddings) = let (sparse_prompt_embeddings, dense_prompt_embeddings) =
self.prompt_encoder.forward(points, None, None)?; self.prompt_encoder.forward(points, None, None)?;
self.mask_decoder.forward( self.mask_decoder.forward(

View File

@ -34,8 +34,8 @@ pub enum Msg {
Run, Run,
UpdateStatus(String), UpdateStatus(String),
SetModel(ModelData), SetModel(ModelData),
WorkerInMsg(WorkerInput), WorkerIn(WorkerInput),
WorkerOutMsg(Result<WorkerOutput, String>), WorkerOut(Result<WorkerOutput, String>),
} }
pub struct CurrentDecode { pub struct CurrentDecode {
@ -75,7 +75,7 @@ impl Component for App {
let status = "loading weights".to_string(); let status = "loading weights".to_string();
let cb = { let cb = {
let link = ctx.link().clone(); let link = ctx.link().clone();
move |e| link.send_message(Self::Message::WorkerOutMsg(e)) move |e| link.send_message(Self::Message::WorkerOut(e))
}; };
let worker = Worker::bridge(std::rc::Rc::new(cb)); let worker = Worker::bridge(std::rc::Rc::new(cb));
Self { Self {
@ -128,11 +128,11 @@ impl Component for App {
let prompt = self.prompt.borrow().clone(); let prompt = self.prompt.borrow().clone();
console_log!("temp: {}, top_p: {}, prompt: {}", temp, top_p, prompt); console_log!("temp: {}, top_p: {}, prompt: {}", temp, top_p, prompt);
ctx.link() ctx.link()
.send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, top_p, prompt))) .send_message(Msg::WorkerIn(WorkerInput::Run(temp, top_p, prompt)))
} }
true true
} }
Msg::WorkerOutMsg(output) => { Msg::WorkerOut(output) => {
match output { match output {
Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(), Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(),
Ok(WorkerOutput::GenerationDone(Err(err))) => { Ok(WorkerOutput::GenerationDone(Err(err))) => {
@ -165,7 +165,7 @@ impl Component for App {
} }
true true
} }
Msg::WorkerInMsg(inp) => { Msg::WorkerIn(inp) => {
self.worker.send(inp); self.worker.send(inp);
true true
} }

View File

@ -42,8 +42,8 @@ pub enum Msg {
Run(usize), Run(usize),
UpdateStatus(String), UpdateStatus(String),
SetDecoder(ModelData), SetDecoder(ModelData),
WorkerInMsg(WorkerInput), WorkerIn(WorkerInput),
WorkerOutMsg(Result<WorkerOutput, String>), WorkerOut(Result<WorkerOutput, String>),
} }
pub struct CurrentDecode { pub struct CurrentDecode {
@ -116,7 +116,7 @@ impl Component for App {
let status = "loading weights".to_string(); let status = "loading weights".to_string();
let cb = { let cb = {
let link = ctx.link().clone(); let link = ctx.link().clone();
move |e| link.send_message(Self::Message::WorkerOutMsg(e)) move |e| link.send_message(Self::Message::WorkerOut(e))
}; };
let worker = Worker::bridge(std::rc::Rc::new(cb)); let worker = Worker::bridge(std::rc::Rc::new(cb));
Self { Self {
@ -165,18 +165,16 @@ impl Component for App {
Err(err) => { Err(err) => {
let output = Err(format!("decoding error: {err:?}")); let output = Err(format!("decoding error: {err:?}"));
// Mimic a worker output to so as to release current_decode // Mimic a worker output to so as to release current_decode
Msg::WorkerOutMsg(output) Msg::WorkerOut(output)
}
Ok(wav_bytes) => {
Msg::WorkerInMsg(WorkerInput::DecodeTask { wav_bytes })
} }
Ok(wav_bytes) => Msg::WorkerIn(WorkerInput::DecodeTask { wav_bytes }),
} }
}) })
} }
// //
true true
} }
Msg::WorkerOutMsg(output) => { Msg::WorkerOut(output) => {
let dt = self.current_decode.as_ref().and_then(|current_decode| { let dt = self.current_decode.as_ref().and_then(|current_decode| {
current_decode.start_time.and_then(|start_time| { current_decode.start_time.and_then(|start_time| {
performance_now().map(|stop_time| stop_time - start_time) performance_now().map(|stop_time| stop_time - start_time)
@ -198,7 +196,7 @@ impl Component for App {
} }
true true
} }
Msg::WorkerInMsg(inp) => { Msg::WorkerIn(inp) => {
self.worker.send(inp); self.worker.send(inp);
true true
} }

View File

@ -33,8 +33,8 @@ pub enum Msg {
Run, Run,
UpdateStatus(String), UpdateStatus(String),
SetModel(ModelData), SetModel(ModelData),
WorkerInMsg(WorkerInput), WorkerIn(WorkerInput),
WorkerOutMsg(Result<WorkerOutput, String>), WorkerOut(Result<WorkerOutput, String>),
} }
pub struct CurrentDecode { pub struct CurrentDecode {
@ -117,7 +117,7 @@ impl Component for App {
let status = "loading weights".to_string(); let status = "loading weights".to_string();
let cb = { let cb = {
let link = ctx.link().clone(); let link = ctx.link().clone();
move |e| link.send_message(Self::Message::WorkerOutMsg(e)) move |e| link.send_message(Self::Message::WorkerOut(e))
}; };
let worker = Worker::bridge(std::rc::Rc::new(cb)); let worker = Worker::bridge(std::rc::Rc::new(cb));
Self { Self {
@ -166,7 +166,7 @@ impl Component for App {
let status = format!("{err:?}"); let status = format!("{err:?}");
Msg::UpdateStatus(status) Msg::UpdateStatus(status)
} }
Ok(image_data) => Msg::WorkerInMsg(WorkerInput::RunData(RunData { Ok(image_data) => Msg::WorkerIn(WorkerInput::RunData(RunData {
image_data, image_data,
conf_threshold: 0.5, conf_threshold: 0.5,
iou_threshold: 0.5, iou_threshold: 0.5,
@ -176,7 +176,7 @@ impl Component for App {
} }
true true
} }
Msg::WorkerOutMsg(output) => { Msg::WorkerOut(output) => {
match output { match output {
Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(), Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(),
Ok(WorkerOutput::ProcessingDone(Err(err))) => { Ok(WorkerOutput::ProcessingDone(Err(err))) => {
@ -218,7 +218,7 @@ impl Component for App {
} }
true true
} }
Msg::WorkerInMsg(inp) => { Msg::WorkerIn(inp) => {
self.worker.send(inp); self.worker.send(inp);
true true
} }