Pass directly the buffer ownership. (#949)
This commit is contained in:
parent
e32c89d90c
commit
7edd755756
|
@ -18,8 +18,7 @@ impl Model {
|
|||
console_error_panic_hook::set_once();
|
||||
console_log!("loading model");
|
||||
let device = &Device::Cpu;
|
||||
let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F64, device);
|
||||
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F64, device)?;
|
||||
let config: Config = serde_json::from_slice(&config)?;
|
||||
let tokenizer =
|
||||
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
|
||||
|
|
|
@ -29,8 +29,7 @@ impl ModelConditionalGeneration {
|
|||
console_error_panic_hook::set_once();
|
||||
console_log!("loading model");
|
||||
let device = &Device::Cpu;
|
||||
let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device);
|
||||
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;
|
||||
let mut config: Config = serde_json::from_slice(&config)?;
|
||||
let tokenizer =
|
||||
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
|
||||
|
@ -128,8 +127,7 @@ impl ModelEncoder {
|
|||
console_error_panic_hook::set_once();
|
||||
console_log!("loading model");
|
||||
let device = &Device::Cpu;
|
||||
let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device);
|
||||
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;
|
||||
let mut config: Config = serde_json::from_slice(&config)?;
|
||||
config.use_cache = false;
|
||||
let tokenizer =
|
||||
|
|
|
@ -253,8 +253,7 @@ impl Decoder {
|
|||
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
|
||||
console_log!("loaded mel filters {:?}", mel_filters.shape());
|
||||
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
|
||||
let weights = safetensors::tensor::SafeTensors::deserialize(&md.weights)?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
||||
let vb = VarBuilder::from_buffered_safetensors(md.weights, DTYPE, &device)?;
|
||||
let config = Config::tiny_en();
|
||||
let whisper = Whisper::load(&vb, config)?;
|
||||
console_log!("done loading model");
|
||||
|
|
|
@ -13,7 +13,7 @@ pub struct Model {
|
|||
impl Model {
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(data: Vec<u8>, model_size: &str) -> Result<Model, JsError> {
|
||||
let inner = M::load_(&data, model_size)?;
|
||||
let inner = M::load_(data, model_size)?;
|
||||
Ok(Self { inner })
|
||||
}
|
||||
|
||||
|
@ -46,7 +46,7 @@ pub struct ModelPose {
|
|||
impl ModelPose {
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(data: Vec<u8>, model_size: &str) -> Result<ModelPose, JsError> {
|
||||
let inner = P::load_(&data, model_size)?;
|
||||
let inner = P::load_(data, model_size)?;
|
||||
Ok(Self { inner })
|
||||
}
|
||||
|
||||
|
|
|
@ -92,7 +92,7 @@ impl Model {
|
|||
Ok(bboxes)
|
||||
}
|
||||
|
||||
pub fn load_(weights: &[u8], model_size: &str) -> Result<Self> {
|
||||
pub fn load_(weights: Vec<u8>, model_size: &str) -> Result<Self> {
|
||||
let multiples = match model_size {
|
||||
"n" => Multiples::n(),
|
||||
"s" => Multiples::s(),
|
||||
|
@ -104,14 +104,13 @@ impl Model {
|
|||
))?,
|
||||
};
|
||||
let dev = &Device::Cpu;
|
||||
let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
|
||||
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, dev)?;
|
||||
let model = YoloV8::load(vb, multiples, 80)?;
|
||||
Ok(Self { model })
|
||||
}
|
||||
|
||||
pub fn load(md: ModelData) -> Result<Self> {
|
||||
Self::load_(&md.weights, &md.model_size.to_string())
|
||||
Self::load_(md.weights, &md.model_size.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -172,7 +171,7 @@ impl ModelPose {
|
|||
Ok(bboxes)
|
||||
}
|
||||
|
||||
pub fn load_(weights: &[u8], model_size: &str) -> Result<Self> {
|
||||
pub fn load_(weights: Vec<u8>, model_size: &str) -> Result<Self> {
|
||||
let multiples = match model_size {
|
||||
"n" => Multiples::n(),
|
||||
"s" => Multiples::s(),
|
||||
|
@ -184,14 +183,13 @@ impl ModelPose {
|
|||
))?,
|
||||
};
|
||||
let dev = &Device::Cpu;
|
||||
let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
|
||||
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, dev)?;
|
||||
let model = YoloV8Pose::load(vb, multiples, 1, (17, 3))?;
|
||||
Ok(Self { model })
|
||||
}
|
||||
|
||||
pub fn load(md: ModelData) -> Result<Self> {
|
||||
Self::load_(&md.weights, &md.model_size.to_string())
|
||||
Self::load_(md.weights, &md.model_size.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue