load pth support

This commit is contained in:
Dun Liang 2020-05-18 13:19:05 +08:00
parent 3274fade1f
commit e5265505df
2 changed files with 21 additions and 4 deletions

View File

@ -580,6 +580,7 @@ class Module:
return ", ".join(ss)
def load_parameters(self, params):
n_failed = 0
for key in params.keys():
v = self
key_ = key.split('.')
@ -598,16 +599,21 @@ class Module:
end = 1
break
if end ==1:
# print(f'init {key} fail ...')
n_failed += 1
LOG.w(f'load parameter {key} failed ...')
pass
else:
# print(f'init {key} success ...')
LOG.v(f'load parameter {key} success ...')
if isinstance(params[key], np.ndarray) or isinstance(params[key], list):
v.assign(array(params[key]))
elif isinstance(params[key], Var):
v.assign(params[key])
else:
v.assign(array(params[key].cpu( ).detach().numpy()))
# assume is pytorch tensor
v.assign(array(params[key].cpu().detach().numpy()))
if n_failed:
LOG.w(f"load total {len(params)} params, {n_failed} failed")
def save(self, path):
params = self.parameters()
params_dict = {}
@ -617,6 +623,14 @@ class Module:
pickle.dump(params_dict, f, pickle.HIGHEST_PROTOCOL)
def load(self, path):
if path.endswith(".pth"):
try:
dirty_fix_pytorch_runtime_error()
import torch
except:
raise RuntimeError("pytorch need to be installed when load pth format.")
self.load_parameters(torch.load(path, map_location=torch.device('cpu')))
return
with open(path, 'rb') as f:
self.load_parameters(pickle.load(f))

View File

@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:
setuptools.setup(
name='jittor',
version='1.1.3',
version='1.1.3.1',
# scripts=[],
author="Jittor Group",
author_email="ran.donglang@gmail.com",
@ -44,3 +44,6 @@ setuptools.setup(
"astunparse",
],
)
# python3.7 setup.py sdist
# python3.7 -m twine upload dist/*