Fix VarBuilder::from_slice_safetensors (#2180)
Also implement SimpleBackend for SliceSafetensors Signed-off-by: Harry Stern <harry@harrystern.net>
This commit is contained in:
parent
21f82a5155
commit
13c64f6828
|
@ -422,6 +422,32 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'a> SimpleBackend for candle::safetensors::SliceSafetensors<'a> {
|
||||||
|
fn get(
|
||||||
|
&self,
|
||||||
|
s: Shape,
|
||||||
|
name: &str,
|
||||||
|
_: crate::Init,
|
||||||
|
dtype: DType,
|
||||||
|
dev: &Device,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let tensor = self.load(name, dev)?.to_dtype(dtype)?;
|
||||||
|
if tensor.shape() != &s {
|
||||||
|
Err(candle::Error::UnexpectedShape {
|
||||||
|
msg: format!("shape mismatch for {name}"),
|
||||||
|
expected: s,
|
||||||
|
got: tensor.shape().clone(),
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
Ok(tensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn contains_tensor(&self, name: &str) -> bool {
|
||||||
|
self.get(name).is_ok()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<'a> VarBuilder<'a> {
|
impl<'a> VarBuilder<'a> {
|
||||||
/// Initializes a `VarBuilder` using a custom backend.
|
/// Initializes a `VarBuilder` using a custom backend.
|
||||||
///
|
///
|
||||||
|
@ -481,15 +507,15 @@ impl<'a> VarBuilder<'a> {
|
||||||
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Initializes a `VarBuilder` from a binary builder in the safetensor format.
|
/// Initializes a `VarBuilder` from a binary buffer in the safetensor format.
|
||||||
pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
|
pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
|
||||||
let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
|
let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
|
||||||
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Initializes a `VarBuilder` from a binary builder in the safetensor format.
|
/// Initializes a `VarBuilder` from a binary slice in the safetensor format.
|
||||||
pub fn from_slice_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
|
pub fn from_slice_safetensors(data: &'a [u8], dtype: DType, dev: &Device) -> Result<Self> {
|
||||||
let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
|
let tensors = candle::safetensors::SliceSafetensors::new(data)?;
|
||||||
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue