forked from maxjhandsome/jittor
fix dataset
This commit is contained in:
parent
adfa5dcdb5
commit
32abf3db0b
|
@ -26,13 +26,9 @@ def collate_batch(batch):
|
|||
elem = batch[0]
|
||||
elem_type = type(elem)
|
||||
if isinstance(elem, jt.Var):
|
||||
if elem.ndim == 1:
|
||||
temp_data = np.stack([data.numpy() for data in batch], 0)
|
||||
temp_data = np.squeeze(temp_data, -1)
|
||||
return jt.array(temp_data)
|
||||
else:
|
||||
temp_data = np.stack([data.numpy() for data in batch], 0)
|
||||
return jt.array(temp_data)
|
||||
# TODO: use jittor
|
||||
temp_data = np.stack([data.data for data in batch], 0)
|
||||
return temp_data
|
||||
if elem_type is np.ndarray:
|
||||
temp_data = np.stack([data for data in batch], 0)
|
||||
return temp_data
|
||||
|
|
Loading…
Reference in New Issue