forked from maxjhandsome/jittor
load pth support
This commit is contained in:
parent
3274fade1f
commit
e5265505df
|
@ -580,6 +580,7 @@ class Module:
|
|||
return ", ".join(ss)
|
||||
|
||||
def load_parameters(self, params):
|
||||
n_failed = 0
|
||||
for key in params.keys():
|
||||
v = self
|
||||
key_ = key.split('.')
|
||||
|
@ -598,16 +599,21 @@ class Module:
|
|||
end = 1
|
||||
break
|
||||
if end ==1:
|
||||
# print(f'init {key} fail ...')
|
||||
n_failed += 1
|
||||
LOG.w(f'load parameter {key} failed ...')
|
||||
pass
|
||||
else:
|
||||
# print(f'init {key} success ...')
|
||||
LOG.v(f'load parameter {key} success ...')
|
||||
if isinstance(params[key], np.ndarray) or isinstance(params[key], list):
|
||||
v.assign(array(params[key]))
|
||||
elif isinstance(params[key], Var):
|
||||
v.assign(params[key])
|
||||
else:
|
||||
v.assign(array(params[key].cpu( ).detach().numpy()))
|
||||
# assume is pytorch tensor
|
||||
v.assign(array(params[key].cpu().detach().numpy()))
|
||||
if n_failed:
|
||||
LOG.w(f"load total {len(params)} params, {n_failed} failed")
|
||||
|
||||
def save(self, path):
|
||||
params = self.parameters()
|
||||
params_dict = {}
|
||||
|
@ -617,6 +623,14 @@ class Module:
|
|||
pickle.dump(params_dict, f, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def load(self, path):
|
||||
if path.endswith(".pth"):
|
||||
try:
|
||||
dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
except:
|
||||
raise RuntimeError("pytorch need to be installed when load pth format.")
|
||||
self.load_parameters(torch.load(path, map_location=torch.device('cpu')))
|
||||
return
|
||||
with open(path, 'rb') as f:
|
||||
self.load_parameters(pickle.load(f))
|
||||
|
||||
|
|
5
setup.py
5
setup.py
|
@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:
|
|||
|
||||
setuptools.setup(
|
||||
name='jittor',
|
||||
version='1.1.3',
|
||||
version='1.1.3.1',
|
||||
# scripts=[],
|
||||
author="Jittor Group",
|
||||
author_email="ran.donglang@gmail.com",
|
||||
|
@ -44,3 +44,6 @@ setuptools.setup(
|
|||
"astunparse",
|
||||
],
|
||||
)
|
||||
|
||||
# python3.7 setup.py sdist
|
||||
# python3.7 -m twine upload dist/*
|
||||
|
|
Loading…
Reference in New Issue