Pass directly the buffer ownership. (#949)
This commit is contained in:
parent
e32c89d90c
commit
7edd755756
|
@ -18,8 +18,7 @@ impl Model {
|
||||||
console_error_panic_hook::set_once();
|
console_error_panic_hook::set_once();
|
||||||
console_log!("loading model");
|
console_log!("loading model");
|
||||||
let device = &Device::Cpu;
|
let device = &Device::Cpu;
|
||||||
let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
|
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F64, device)?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F64, device);
|
|
||||||
let config: Config = serde_json::from_slice(&config)?;
|
let config: Config = serde_json::from_slice(&config)?;
|
||||||
let tokenizer =
|
let tokenizer =
|
||||||
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
|
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
|
||||||
|
|
|
@ -29,8 +29,7 @@ impl ModelConditionalGeneration {
|
||||||
console_error_panic_hook::set_once();
|
console_error_panic_hook::set_once();
|
||||||
console_log!("loading model");
|
console_log!("loading model");
|
||||||
let device = &Device::Cpu;
|
let device = &Device::Cpu;
|
||||||
let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
|
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device);
|
|
||||||
let mut config: Config = serde_json::from_slice(&config)?;
|
let mut config: Config = serde_json::from_slice(&config)?;
|
||||||
let tokenizer =
|
let tokenizer =
|
||||||
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
|
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
|
||||||
|
@ -128,8 +127,7 @@ impl ModelEncoder {
|
||||||
console_error_panic_hook::set_once();
|
console_error_panic_hook::set_once();
|
||||||
console_log!("loading model");
|
console_log!("loading model");
|
||||||
let device = &Device::Cpu;
|
let device = &Device::Cpu;
|
||||||
let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
|
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device);
|
|
||||||
let mut config: Config = serde_json::from_slice(&config)?;
|
let mut config: Config = serde_json::from_slice(&config)?;
|
||||||
config.use_cache = false;
|
config.use_cache = false;
|
||||||
let tokenizer =
|
let tokenizer =
|
||||||
|
|
|
@ -253,8 +253,7 @@ impl Decoder {
|
||||||
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
|
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
|
||||||
console_log!("loaded mel filters {:?}", mel_filters.shape());
|
console_log!("loaded mel filters {:?}", mel_filters.shape());
|
||||||
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
|
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
|
||||||
let weights = safetensors::tensor::SafeTensors::deserialize(&md.weights)?;
|
let vb = VarBuilder::from_buffered_safetensors(md.weights, DTYPE, &device)?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
|
||||||
let config = Config::tiny_en();
|
let config = Config::tiny_en();
|
||||||
let whisper = Whisper::load(&vb, config)?;
|
let whisper = Whisper::load(&vb, config)?;
|
||||||
console_log!("done loading model");
|
console_log!("done loading model");
|
||||||
|
|
|
@ -13,7 +13,7 @@ pub struct Model {
|
||||||
impl Model {
|
impl Model {
|
||||||
#[wasm_bindgen(constructor)]
|
#[wasm_bindgen(constructor)]
|
||||||
pub fn new(data: Vec<u8>, model_size: &str) -> Result<Model, JsError> {
|
pub fn new(data: Vec<u8>, model_size: &str) -> Result<Model, JsError> {
|
||||||
let inner = M::load_(&data, model_size)?;
|
let inner = M::load_(data, model_size)?;
|
||||||
Ok(Self { inner })
|
Ok(Self { inner })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,7 +46,7 @@ pub struct ModelPose {
|
||||||
impl ModelPose {
|
impl ModelPose {
|
||||||
#[wasm_bindgen(constructor)]
|
#[wasm_bindgen(constructor)]
|
||||||
pub fn new(data: Vec<u8>, model_size: &str) -> Result<ModelPose, JsError> {
|
pub fn new(data: Vec<u8>, model_size: &str) -> Result<ModelPose, JsError> {
|
||||||
let inner = P::load_(&data, model_size)?;
|
let inner = P::load_(data, model_size)?;
|
||||||
Ok(Self { inner })
|
Ok(Self { inner })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -92,7 +92,7 @@ impl Model {
|
||||||
Ok(bboxes)
|
Ok(bboxes)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_(weights: &[u8], model_size: &str) -> Result<Self> {
|
pub fn load_(weights: Vec<u8>, model_size: &str) -> Result<Self> {
|
||||||
let multiples = match model_size {
|
let multiples = match model_size {
|
||||||
"n" => Multiples::n(),
|
"n" => Multiples::n(),
|
||||||
"s" => Multiples::s(),
|
"s" => Multiples::s(),
|
||||||
|
@ -104,14 +104,13 @@ impl Model {
|
||||||
))?,
|
))?,
|
||||||
};
|
};
|
||||||
let dev = &Device::Cpu;
|
let dev = &Device::Cpu;
|
||||||
let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
|
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, dev)?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
|
|
||||||
let model = YoloV8::load(vb, multiples, 80)?;
|
let model = YoloV8::load(vb, multiples, 80)?;
|
||||||
Ok(Self { model })
|
Ok(Self { model })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load(md: ModelData) -> Result<Self> {
|
pub fn load(md: ModelData) -> Result<Self> {
|
||||||
Self::load_(&md.weights, &md.model_size.to_string())
|
Self::load_(md.weights, &md.model_size.to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -172,7 +171,7 @@ impl ModelPose {
|
||||||
Ok(bboxes)
|
Ok(bboxes)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_(weights: &[u8], model_size: &str) -> Result<Self> {
|
pub fn load_(weights: Vec<u8>, model_size: &str) -> Result<Self> {
|
||||||
let multiples = match model_size {
|
let multiples = match model_size {
|
||||||
"n" => Multiples::n(),
|
"n" => Multiples::n(),
|
||||||
"s" => Multiples::s(),
|
"s" => Multiples::s(),
|
||||||
|
@ -184,14 +183,13 @@ impl ModelPose {
|
||||||
))?,
|
))?,
|
||||||
};
|
};
|
||||||
let dev = &Device::Cpu;
|
let dev = &Device::Cpu;
|
||||||
let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
|
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, dev)?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
|
|
||||||
let model = YoloV8Pose::load(vb, multiples, 1, (17, 3))?;
|
let model = YoloV8Pose::load(vb, multiples, 1, (17, 3))?;
|
||||||
Ok(Self { model })
|
Ok(Self { model })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load(md: ModelData) -> Result<Self> {
|
pub fn load(md: ModelData) -> Result<Self> {
|
||||||
Self::load_(&md.weights, &md.model_size.to_string())
|
Self::load_(md.weights, &md.model_size.to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue