Pass the buffer ownership directly. (#949)

Laurent Mazare 2023-09-24 06:34:44 +01:00 committed by GitHub
parent e32c89d90c
commit 7edd755756
5 changed files with 12 additions and 18 deletions
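
For context on the change: instead of deserializing the safetensors from a borrowed byte slice and then building a `VarBuilder` over it, the examples now hand the owned weight buffer straight to `VarBuilder::from_buffered_safetensors`. A minimal sketch of the new loading pattern, assuming the `candle`/`candle_nn` imports used by these wasm examples (the `load_weights` wrapper itself is illustrative):

```rust
use candle::{DType, Device, Result};
use candle_nn::VarBuilder;

// Sketch: build a VarBuilder directly from an owned safetensors byte buffer
// (e.g. bytes fetched in the browser). Ownership of `weights` moves into the
// builder, so there is no separate SafeTensors::deserialize step and no
// borrowed slice that has to outlive the VarBuilder.
fn load_weights(weights: Vec<u8>) -> Result<()> {
    let device = Device::Cpu;
    let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, &device)?;
    // ...hand `vb` to a model constructor, e.g. `Whisper::load(&vb, config)?`
    let _ = vb;
    Ok(())
}
```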

View File

@@ -18,8 +18,7 @@ impl Model {
         console_error_panic_hook::set_once();
         console_log!("loading model");
         let device = &Device::Cpu;
-        let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
-        let vb = VarBuilder::from_safetensors(vec![weights], DType::F64, device);
+        let vb = VarBuilder::from_buffered_safetensors(weights, DType::F64, device)?;
         let config: Config = serde_json::from_slice(&config)?;
         let tokenizer =
             Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;

View File

@@ -29,8 +29,7 @@ impl ModelConditionalGeneration {
         console_error_panic_hook::set_once();
         console_log!("loading model");
         let device = &Device::Cpu;
-        let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
-        let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device);
+        let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;
         let mut config: Config = serde_json::from_slice(&config)?;
         let tokenizer =
             Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
@@ -128,8 +127,7 @@ impl ModelEncoder {
         console_error_panic_hook::set_once();
         console_log!("loading model");
         let device = &Device::Cpu;
-        let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
-        let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device);
+        let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;
         let mut config: Config = serde_json::from_slice(&config)?;
         config.use_cache = false;
         let tokenizer =

View File

@@ -253,8 +253,7 @@ impl Decoder {
         let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
         console_log!("loaded mel filters {:?}", mel_filters.shape());
         let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
-        let weights = safetensors::tensor::SafeTensors::deserialize(&md.weights)?;
-        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
+        let vb = VarBuilder::from_buffered_safetensors(md.weights, DTYPE, &device)?;
         let config = Config::tiny_en();
         let whisper = Whisper::load(&vb, config)?;
         console_log!("done loading model");

View File

@@ -13,7 +13,7 @@ pub struct Model {
 impl Model {
     #[wasm_bindgen(constructor)]
     pub fn new(data: Vec<u8>, model_size: &str) -> Result<Model, JsError> {
-        let inner = M::load_(&data, model_size)?;
+        let inner = M::load_(data, model_size)?;
         Ok(Self { inner })
     }
@@ -46,7 +46,7 @@ pub struct ModelPose {
 impl ModelPose {
     #[wasm_bindgen(constructor)]
     pub fn new(data: Vec<u8>, model_size: &str) -> Result<ModelPose, JsError> {
-        let inner = P::load_(&data, model_size)?;
+        let inner = P::load_(data, model_size)?;
         Ok(Self { inner })
     }

View File

@@ -92,7 +92,7 @@ impl Model {
         Ok(bboxes)
     }

-    pub fn load_(weights: &[u8], model_size: &str) -> Result<Self> {
+    pub fn load_(weights: Vec<u8>, model_size: &str) -> Result<Self> {
         let multiples = match model_size {
             "n" => Multiples::n(),
             "s" => Multiples::s(),
@@ -104,14 +104,13 @@
             ))?,
         };
         let dev = &Device::Cpu;
-        let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
-        let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
+        let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, dev)?;
         let model = YoloV8::load(vb, multiples, 80)?;
         Ok(Self { model })
     }

     pub fn load(md: ModelData) -> Result<Self> {
-        Self::load_(&md.weights, &md.model_size.to_string())
+        Self::load_(md.weights, &md.model_size.to_string())
     }
 }
@@ -172,7 +171,7 @@ impl ModelPose {
         Ok(bboxes)
     }

-    pub fn load_(weights: &[u8], model_size: &str) -> Result<Self> {
+    pub fn load_(weights: Vec<u8>, model_size: &str) -> Result<Self> {
         let multiples = match model_size {
             "n" => Multiples::n(),
             "s" => Multiples::s(),
@@ -184,14 +183,13 @@
             ))?,
         };
         let dev = &Device::Cpu;
-        let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
-        let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
+        let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, dev)?;
         let model = YoloV8Pose::load(vb, multiples, 1, (17, 3))?;
         Ok(Self { model })
     }

     pub fn load(md: ModelData) -> Result<Self> {
-        Self::load_(&md.weights, &md.model_size.to_string())
+        Self::load_(md.weights, &md.model_size.to_string())
     }
 }
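
As a side note on the updated `Self::load_(md.weights, ...)` call sites: because `load` takes `ModelData` by value, the weights buffer can be moved out of the struct while another field is still read. A small self-contained sketch of that partial-move pattern (this `ModelData` is illustrative, not the crate's actual type):

```rust
// Illustrative only: mirrors how the call sites above move the weights buffer
// out of an owned struct while still using its other fields.
struct ModelData {
    weights: Vec<u8>,
    model_size: String,
}

fn load_(weights: Vec<u8>, model_size: &str) -> usize {
    // Stand-in for the real loader: just report the buffer size.
    let _ = model_size;
    weights.len()
}

fn load(md: ModelData) -> usize {
    // `md.weights` is moved out of `md`; `md.model_size` is borrowed afterwards,
    // which is fine because they are distinct fields.
    load_(md.weights, &md.model_size)
}

fn main() {
    let md = ModelData {
        weights: vec![0u8; 4],
        model_size: "n".to_string(),
    };
    assert_eq!(load(md), 4);
}
```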