Add a python function to save as safetensors. (#740)
This commit is contained in:
parent
ab0d9fbdd1
commit
000487c36f
|
@ -1,5 +1,4 @@
|
||||||
#![allow(clippy::redundant_closure_call)]
|
#![allow(clippy::redundant_closure_call)]
|
||||||
// TODO: Handle negative dimension indexes.
|
|
||||||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
use pyo3::types::{IntoPyDict, PyTuple};
|
use pyo3::types::{IntoPyDict, PyTuple};
|
||||||
|
@ -714,6 +713,18 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
|
||||||
Ok(res.into_py_dict(py).to_object(py))
|
Ok(res.into_py_dict(py).to_object(py))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
fn save_safetensors(
|
||||||
|
path: &str,
|
||||||
|
tensors: std::collections::HashMap<String, PyTensor>,
|
||||||
|
) -> PyResult<()> {
|
||||||
|
let tensors = tensors
|
||||||
|
.into_iter()
|
||||||
|
.map(|(s, t)| (s, t.0))
|
||||||
|
.collect::<std::collections::HashMap<_, _>>();
|
||||||
|
::candle::safetensors::save(&tensors, path).map_err(wrap_err)
|
||||||
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
|
fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
|
||||||
let mut file = std::fs::File::open(path)?;
|
let mut file = std::fs::File::open(path)?;
|
||||||
|
@ -867,6 +878,7 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||||
m.add_function(wrap_pyfunction!(rand, m)?)?;
|
m.add_function(wrap_pyfunction!(rand, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(randn, m)?)?;
|
m.add_function(wrap_pyfunction!(randn, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(tensor, m)?)?;
|
m.add_function(wrap_pyfunction!(tensor, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(save_safetensors, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(stack, m)?)?;
|
m.add_function(wrap_pyfunction!(stack, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(zeros, m)?)?;
|
m.add_function(wrap_pyfunction!(zeros, m)?)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
Loading…
Reference in New Issue