forked from maxjhandsome/jittor

dataset statistic

commit d9b76de87c
parent 0c50dbe0d9
python/jittor/dataset/dataset.py

@@ -15,11 +15,13 @@ from jittor.dataset.utils import get_random_list, get_order_list, collate_batch
from collections.abc import Sequence, Mapping
import pathlib
from PIL import Image
from jittor_utils import ring_buffer
from jittor_utils.ring_buffer import RingBuffer
import multiprocessing as mp
import signal
from jittor_utils import LOG
import jittor as jt
import time

dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset")
mp_log_v = os.environ.get("mp_log_v", 0)

@@ -29,7 +31,8 @@ class Worker:
    def __init__(self, target, args, buffer_size):
        buffer = mp.Array('c', buffer_size, lock=False)
        self.buffer = RingBuffer(buffer)
        self.p = mp.Process(target=target, args=args+(self.buffer,))
        self.status = mp.Array('f', 4, lock=False)
        self.p = mp.Process(target=target, args=args+(self.buffer,self.status))
        self.p.daemon = True
        self.p.start()
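
The status array added to Worker is a lock-free block of four floats shared between each worker process and the parent: _worker_main fills it with the worker's per-batch bookkeeping, load, and send times plus their sum, and display_worker_status reads it from the main process. A minimal standalone sketch of that sharing pattern (the worker function and the two-slot layout here are illustrative, not part of the patch):

    import multiprocessing as mp
    import time

    def worker(status):
        # child process: time a fake load phase and publish it through the shared array
        start = time.time()
        time.sleep(0.1)                               # stand-in for loading a batch
        load_time = time.time() - start
        status[0], status[1] = load_time, load_time   # phase time, running total

    if __name__ == "__main__":
        # lock-free shared float array, mirroring mp.Array('f', 4, lock=False) in Worker
        status = mp.Array('f', 2, lock=False)
        p = mp.Process(target=worker, args=(status,))
        p.daemon = True
        p.start()
        p.join()
        print(f"load(s): {status[0]:.3f}  total(s): {status[1]:.3f}")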

@@ -64,7 +67,8 @@ class Dataset(object):
                 shuffle = False,
                 drop_last = False,
                 num_workers = 0,
                 buffer_size = 512*1024*1024):
                 buffer_size = 512*1024*1024,
                 stop_grad = True):
        super().__init__()
        self.total_len = None
        self.batch_size = batch_size

@@ -72,6 +76,7 @@ class Dataset(object):
        self.drop_last = drop_last
        self.num_workers = num_workers
        self.buffer_size = buffer_size
        self.stop_grad = stop_grad

    def __getitem__(self, index):
        raise NotImplementedError

@@ -90,6 +95,16 @@ class Dataset(object):
        Example::

            dataset = YourDataset().set_attrs(batch_size=256, shuffle=True)

        Attrs:

            * batch_size(int): batch size, default 16.
            * total_len(int): total length.
            * shuffle(bool): shuffle at each epoch, default False.
            * drop_last(bool): if True, the last incomplete batch is dropped, default False.
            * num_workers: number of workers for loading data.
            * buffer_size: buffer size for each worker in bytes, default 512MB.
            * stop_grad: stop grad for data, default True.
        '''
        for k,v in kw.items():
            assert hasattr(self, k), k
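
The new stop_grad flag can be set through the constructor or, like the attributes documented above, through set_attrs. A usage sketch (YourDataset is a placeholder subclass, not defined by this patch):

    from jittor.dataset import Dataset

    class YourDataset(Dataset):            # illustrative subclass
        def __init__(self):
            super().__init__()
            self.set_attrs(total_len=1000)

        def __getitem__(self, index):
            return index                    # trivial sample for demonstration

    # stop_grad=True (the default) makes every converted batch a stop-grad jt.array
    dataset = YourDataset().set_attrs(batch_size=256, shuffle=True, stop_grad=True)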

@@ -100,16 +115,17 @@ class Dataset(object):
        '''
        Change batch data to jittor array, such as np.ndarray, int, and float.
        '''

        to_jt = lambda x: jt.array(x).stop_grad() \
            if self.stop_grad else jt.array(x)
        if isinstance(batch, np.ndarray):
            return jt.array(batch)
            return to_jt(batch)
        assert isinstance(batch, Sequence)
        new_batch = []
        for a in batch:
            if isinstance(a, np.ndarray) or \
               isinstance(a, int) or \
               isinstance(a, float):
                new_batch.append(jt.array(a))
                new_batch.append(to_jt(a))
            else:
                new_batch.append(a)
        return new_batch
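
to_jittor now routes every convertible element through a small helper so that, with stop_grad enabled, batches are detached from gradient tracking as they are converted. The same logic as a standalone sketch (assuming numpy and jittor are installed; the function name mirrors the lambda in the patch):

    import numpy as np
    import jittor as jt

    def to_jt(x, stop_grad=True):
        # convert to a jittor array; optionally mark it as stop-grad,
        # as Dataset.to_jittor now does for np.ndarray, int, and float elements
        v = jt.array(x)
        return v.stop_grad() if stop_grad else v

    batch = np.ones((2, 3), dtype="float32")
    print(to_jt(batch))                    # converted, gradients stopped (default)
    print(to_jt(batch, stop_grad=False))   # converted, still tracked by autograd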

@@ -133,11 +149,14 @@ class Dataset(object):
        for w in self.workers:
            w.p.terminate()

    def _worker_main(self, worker_id, buffer):
    def _worker_main(self, worker_id, buffer, status):
        import time
        try:
            gid_obj = self.gid.get_obj()
            gid_lock = self.gid.get_lock()
            start = time.time()
            while True:
                # get id
                with gid_lock:
                    while gid_obj.value >= self.batch_len:
                        self.num_idle.value += 1

@@ -148,19 +167,51 @@ class Dataset(object):
                    self.idmap[cid] = worker_id
                    gid_obj.value += 1
                    self.gidc.notify()
                now = time.time()
                other_time = now - start
                start = now

                # load and transform data
                batch = []
                if mp_log_v:
                    print(f"#{worker_id} {os.getpid()} load batch", cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size))
                for i in range(cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size)):
                    batch.append(self[self.index_list[i]])
                batch = self.collate_batch(batch)
                now = time.time()
                data_time = now - start
                start = now

                # send data to main process
                if mp_log_v:
                    print(f"#{worker_id} {os.getpid()} send", type(batch).__name__, [ type(b).__name__ for b in batch ], buffer)
                buffer.send(batch)
                now = time.time()
                send_time = now - start
                start = now
                status[0], status[1], status[2], status[3] = \
                    other_time, data_time, send_time, \
                    other_time + data_time + send_time
        except:
            os.kill(os.getppid(), signal.SIGINT)
            raise
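
The worker-side accounting follows a simple stopwatch pattern: start is reset after each phase, so other_time, data_time, and send_time partition one trip through the loop and their sum is the per-batch total written into status. The same pattern in isolation (phase names and sleeps are purely illustrative):

    import time

    def timed_phases(phases):
        # run the callables back to back and record how long each took,
        # resetting the stopwatch after every phase as _worker_main does
        timings, start = {}, time.time()
        for name, fn in phases:
            fn()
            now = time.time()
            timings[name] = now - start
            start = now
        timings["total"] = sum(timings.values())
        return timings

    print(timed_phases([("other", lambda: time.sleep(0.01)),
                        ("data",  lambda: time.sleep(0.05)),
                        ("send",  lambda: time.sleep(0.01))]))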

    def display_worker_status(self):
        if not hasattr(self, "workers"):
            return
        msg = [""]
        msg.append(f"progress:{self.last_id}/{self.batch_len}")
        msg.append(f"batch(s): {self.batch_time:.3f}\twait(s):{self.wait_time:.3f}")
        msg.append(f"recv(s): {self.recv_time:.3f}\tto_jittor(s):{self.to_jittor_time:.3f}")
        msg.append(f"recv_raw_call: {ring_buffer.recv_raw_call}")
        msg.append(f"last 10 workers: {self.idmap[max(0, self.last_id-9):self.last_id+1]}")
        msg.append("ID\twait(s)\tload(s)\tsend(s)\ttotal(s)\tbuffer")
        for i in range(self.num_workers):
            w = self.workers[i]
            s = w.status
            msg.append(f"#{i}\t{s[0]:.3f}\t{s[1]:.3f}\t{s[2]:.3f}\t{s[3]:.3f}\t{w.buffer.allocator}")
        LOG.i('\n'.join(msg))

    def _stop_all_workers(self):
        # wait until all workers idle
        if self.num_idle.value < self.num_workers:
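
display_worker_status is meant to be called from the main process while iteration with num_workers > 0 is under way; it logs overall progress, the main-process timings, the raw receive count, and one line per worker taken from that worker's shared status array and ring-buffer allocator. A hedged usage sketch (the dataset variable and the logging cadence are illustrative):

    # assuming `dataset` is a Dataset subclass being iterated with num_workers > 0
    for step, batch in enumerate(dataset):
        if step and step % 100 == 0:
            dataset.display_worker_status()   # writes the summary via LOG.i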

@@ -269,6 +320,8 @@ class Dataset(object):
            with gid_lock:
                gid_obj.value = 0
                self.gidc.notify_all()
            start = time.time()
            self.batch_time = 0
            for i in range(self.batch_len):
                # try not get lock first
                if gid_obj.value <= i:

@@ -277,15 +330,32 @@ class Dataset(object):
                    if mp_log_v:
                        print("wait")
                    self.gidc.wait()
                now = time.time()
                self.wait_time = now - start
                start = now

                self.last_id = i
                worker_id = self.idmap[i]
                w = self.workers[worker_id]
                if mp_log_v:
                    print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer)
                batch = w.buffer.recv()
                now = time.time()
                self.recv_time = now - start
                start = now

                if mp_log_v:
                    print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [ type(b).__name__ for b in batch ])
                batch = self.to_jittor(batch)
                now = time.time()
                self.to_jittor_time = now - start
                start = now

                yield batch

                now = time.time()
                self.batch_time = now - start
                start = now
        else:
            batch_data = []
            for idx in index_list:
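
On the consumer side the same stopwatch produces four counters stored on the dataset itself: wait_time (waiting for worker i's batch to be ready), recv_time (pulling it out of the ring buffer), to_jittor_time (conversion), and batch_time (whatever the caller does between yields, i.e. the training step). They can be read directly; a small sketch, assuming iteration with workers has already produced at least one batch:

    # attribute names come from the patch; `dataset` is the Dataset being iterated
    print(f"wait={dataset.wait_time:.3f}s  recv={dataset.recv_time:.3f}s  "
          f"to_jittor={dataset.to_jittor_time:.3f}s  batch={dataset.batch_time:.3f}s")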

python/jittor_utils/ring_buffer.py

@@ -12,6 +12,8 @@ import random
import pickle
import ctypes

recv_raw_call = 0.0

class RingBufferAllocator:
    def __init__(self, size):
        self.size = size

@@ -28,9 +30,9 @@ class RingBufferAllocator:
        if is_full:
            cap = 0
        else:
            cap = int((r - l) / self.size * 100)
            if cap<=0: cap += 100
        return f"Buffer(free={int(cap)}%)"
            cap = (r - l) / self.size
            if cap<=0: cap += 1
        return f"Buffer(free={cap*100:.3f}% l={l} r={r} size={self.size})"

    def alloc_with_lock(self, size):
        with self.lock:

@@ -231,6 +233,8 @@ class RingBuffer:
        assert window.nbytes == data.nbytes

    def recv_raw(self, nbytes, shape, dtype):
        global recv_raw_call
        recv_raw_call += 1
        with self.allocator.lock:
            location = self.allocator.free(nbytes)
            while location is None:
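
recv_raw_call is a plain module-level counter incremented on every raw receive. Since each process has its own copy of the module, the value that display_worker_status reads via ring_buffer.recv_raw_call only reflects receives performed in the main process. It can also be inspected directly:

    from jittor_utils import ring_buffer

    # number of RingBuffer.recv_raw calls made so far in this process
    print(ring_buffer.recv_raw_call)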