fix dataset

This commit is contained in:
Dun Liang 2020-06-18 21:42:04 +08:00
parent adfa5dcdb5
commit 32abf3db0b
2 changed files with 4 additions and 8 deletions

View File

@ -26,13 +26,9 @@ def collate_batch(batch):
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, jt.Var):
if elem.ndim == 1:
temp_data = np.stack([data.numpy() for data in batch], 0)
temp_data = np.squeeze(temp_data, -1)
return jt.array(temp_data)
else:
temp_data = np.stack([data.numpy() for data in batch], 0)
return jt.array(temp_data)
# TODO: use jittor
temp_data = np.stack([data.data for data in batch], 0)
return temp_data
if elem_type is np.ndarray:
temp_data = np.stack([data for data in batch], 0)
return temp_data

View File

@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:
setuptools.setup(
name='jittor',
version='1.1.4.8',
version='1.1.4.9',
# scripts=[],
author="Jittor Group",
author_email="ran.donglang@gmail.com",