dataset statistics

This commit is contained in:
Dun Liang 2020-06-16 23:14:09 +08:00
parent 0c50dbe0d9
commit d9b76de87c
3 changed files with 84 additions and 10 deletions

View File

@ -15,11 +15,13 @@ from jittor.dataset.utils import get_random_list, get_order_list, collate_batch
from collections.abc import Sequence, Mapping
import pathlib
from PIL import Image
from jittor_utils import ring_buffer
from jittor_utils.ring_buffer import RingBuffer
import multiprocessing as mp
import signal
from jittor_utils import LOG
import jittor as jt
import time
dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset")
mp_log_v = os.environ.get("mp_log_v", 0)
@ -29,7 +31,8 @@ class Worker:
def __init__(self, target, args, buffer_size):
    """Spawn one daemon worker process.

    Args:
        target: callable run in the child process; it receives
            ``args + (buffer, status)``.
        args: tuple of positional arguments forwarded to *target*.
        buffer_size: size in bytes of the shared ring buffer used to
            send batches back to the parent.
    """
    # Shared byte array backing the ring buffer (lock-free; the
    # RingBuffer wrapper handles its own synchronization).
    buffer = mp.Array('c', buffer_size, lock=False)
    self.buffer = RingBuffer(buffer)
    # 4 floats reported by the worker: wait, load, send, total time
    # of the most recent batch.
    self.status = mp.Array('f', 4, lock=False)
    # NOTE(review): the rendered diff showed an older Process
    # construction without `self.status`; only the final form is kept,
    # otherwise an unstarted Process object would be created and leaked.
    self.p = mp.Process(target=target, args=args+(self.buffer, self.status))
    self.p.daemon = True  # die with the parent process
    self.p.start()
@ -64,7 +67,8 @@ class Dataset(object):
shuffle = False,
drop_last = False,
num_workers = 0,
buffer_size = 512*1024*1024):
buffer_size = 512*1024*1024,
stop_grad = True):
super().__init__()
self.total_len = None
self.batch_size = batch_size
@ -72,6 +76,7 @@ class Dataset(object):
self.drop_last = drop_last
self.num_workers = num_workers
self.buffer_size = buffer_size
self.stop_grad = stop_grad
def __getitem__(self, index):
    """Return the dataset sample at *index*.

    Abstract: subclasses must override this method.
    """
    raise NotImplementedError
@ -90,6 +95,16 @@ class Dataset(object):
Example::
dataset = YourDataset().set_attrs(batch_size=256, shuffle=True)
Attrs:
* batch_size(int): batch size, default 16.
* total_len(int): total length.
* shuffle(bool): shuffle at each epoch, default False.
* drop_last(bool): if False, the last batch of the dataset might be smaller than batch_size, default False.
* num_workers: number of workers for loading data
* buffer_size: buffer size for each worker in bytes, default(512MB).
* stop_grad: stop grad for data, default(True).
'''
for k,v in kw.items():
assert hasattr(self, k), k
@ -100,16 +115,17 @@ class Dataset(object):
'''
Change batch data to jittor array, such as np.ndarray, int, and float.
'''
to_jt = lambda x: jt.array(x).stop_grad() \
if self.stop_grad else jt.array(x)
if isinstance(batch, np.ndarray):
return jt.array(batch)
return to_jt(batch)
assert isinstance(batch, Sequence)
new_batch = []
for a in batch:
if isinstance(a, np.ndarray) or \
isinstance(a, int) or \
isinstance(a, float):
new_batch.append(jt.array(a))
new_batch.append(to_jt(a))
else:
new_batch.append(a)
return new_batch
@ -133,11 +149,14 @@ class Dataset(object):
for w in self.workers:
w.p.terminate()
def _worker_main(self, worker_id, buffer):
def _worker_main(self, worker_id, buffer, status):
import time
try:
gid_obj = self.gid.get_obj()
gid_lock = self.gid.get_lock()
start = time.time()
while True:
# get id
with gid_lock:
while gid_obj.value >= self.batch_len:
self.num_idle.value += 1
@ -148,19 +167,51 @@ class Dataset(object):
self.idmap[cid] = worker_id
gid_obj.value += 1
self.gidc.notify()
now = time.time()
other_time = now - start
start = now
# load and transform data
batch = []
if mp_log_v:
print(f"#{worker_id} {os.getpid()} load batch", cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size))
for i in range(cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size)):
batch.append(self[self.index_list[i]])
batch = self.collate_batch(batch)
now = time.time()
data_time = now - start
start = now
# send data to main process
if mp_log_v:
print(f"#{worker_id} {os.getpid()} send", type(batch).__name__, [ type(b).__name__ for b in batch ], buffer)
buffer.send(batch)
now = time.time()
send_time = now - start
start = now
status[0], status[1], status[2], status[3] = \
other_time, data_time, send_time, \
other_time + data_time + send_time
except:
os.kill(os.getppid(), signal.SIGINT)
raise
def display_worker_status(self):
    """Log a timing summary of the main loop and of every worker.

    Silently returns if the worker pool has not been started yet.
    """
    if not hasattr(self, "workers"):
        return
    # Header lines: main-loop progress and accumulated timings.
    report = [
        "",
        f"progress:{self.last_id}/{self.batch_len}",
        f"batch(s): {self.batch_time:.3f}\twait(s):{self.wait_time:.3f}",
        f"recv(s): {self.recv_time:.3f}\tto_jittor(s):{self.to_jittor_time:.3f}",
        f"recv_raw_call: {ring_buffer.recv_raw_call}",
        f"last 10 workers: {self.idmap[max(0, self.last_id-9):self.last_id+1]}",
        f"ID\twait(s)\tload(s)\tsend(s)",
    ]
    # One row per worker: the four timing slots plus buffer state.
    for wid in range(self.num_workers):
        worker = self.workers[wid]
        st = worker.status
        report.append(
            f"#{wid}\t{st[0]:.3f}\t{st[1]:.3f}\t{st[2]:.3f}\t{st[3]:.3f}\t{worker.buffer.allocator}")
    LOG.i('\n'.join(report))
def _stop_all_workers(self):
# wait until all workers idle
if self.num_idle.value < self.num_workers:
@ -269,6 +320,8 @@ class Dataset(object):
with gid_lock:
gid_obj.value = 0
self.gidc.notify_all()
start = time.time()
self.batch_time = 0
for i in range(self.batch_len):
# try not get lock first
if gid_obj.value <= i:
@ -277,15 +330,32 @@ class Dataset(object):
if mp_log_v:
print("wait")
self.gidc.wait()
now = time.time()
self.wait_time = now - start
start = now
self.last_id = i
worker_id = self.idmap[i]
w = self.workers[worker_id]
if mp_log_v:
print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer)
batch = w.buffer.recv()
now = time.time()
self.recv_time = now - start
start = now
if mp_log_v:
print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [ type(b).__name__ for b in batch ])
batch = self.to_jittor(batch)
now = time.time()
self.to_jittor_time = now - start
start = now
yield batch
now = time.time()
self.batch_time = now - start
start = now
else:
batch_data = []
for idx in index_list:

View File

@ -12,6 +12,8 @@ import random
import pickle
import ctypes
recv_raw_call = 0.0
class RingBufferAllocator:
def __init__(self, size):
self.size = size
@ -28,9 +30,9 @@ class RingBufferAllocator:
if is_full:
cap = 0
else:
cap = int((r - l) / self.size * 100)
if cap<=0: cap += 100
return f"Buffer(free={int(cap)}%)"
cap = (r - l) / self.size
if cap<=0: cap += 1
return f"Buffer(free={cap*100:.3f}% l={l} r={r} size={self.size})"
def alloc_with_lock(self, size):
with self.lock:
@ -231,6 +233,8 @@ class RingBuffer:
assert window.nbytes == data.nbytes
def recv_raw(self, nbytes, shape, dtype):
global recv_raw_call
recv_raw_call += 1
with self.allocator.lock:
location = self.allocator.free(nbytes)
while location is None:

View File

@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:
setuptools.setup(
name='jittor',
version='1.1.4.7',
version='1.1.4.8',
# scripts=[],
author="Jittor Group",
author_email="ran.donglang@gmail.com",