update to 1.2.2.20

This commit is contained in:
Dun Liang 2021-01-22 19:07:27 +08:00
parent eec9a44038
commit 4297ba6fde
13 changed files with 100 additions and 45 deletions

View File

@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.19'
__version__ = '1.2.2.20'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@ -852,6 +852,11 @@ with jit_utils.import_scope(import_flags):
jit_utils.try_import_jit_utils_core()
python_path = sys.executable
# sometimes python does not return the correct sys executable
# this can happen when multiple python versions are installed
ex_python_path = python_path + '.' + str(sys.version_info.minor)
if os.path.isfile(ex_python_path):
python_path = ex_python_path
py3_config_path = jit_utils.py3_config_path
nvcc_path = env_or_try_find('nvcc_path', '/usr/local/cuda/bin/nvcc')
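The fallback above guards against interpreters whose sys.executable does not name a versioned binary; a minimal standalone sketch of the same probing logic (function name chosen here for illustration):
import os
import sys
def resolve_python_path():
    # prefer e.g. /usr/bin/python3.7 when sys.executable only reports /usr/bin/python3
    python_path = sys.executable
    ex_python_path = python_path + '.' + str(sys.version_info.minor)
    return ex_python_path if os.path.isfile(ex_python_path) else python_path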

View File

@ -162,6 +162,12 @@ class Dataset(object):
def _worker_main(self, worker_id, buffer, status):
import jittor_utils
jittor_utils.cc.init_subprocess()
jt.jt_init_subprocess()
# parallel_op_compiler is still problematic:
# it does not work on ubuntu 16.04, but works on ubuntu 20.04.
# it seems the static state of the parallel compiler
# is not correctly initialized.
jt.flags.use_parallel_op_compiler = 0
import time
try:
gid_obj = self.gid.get_obj()
@ -293,6 +299,8 @@ Example::
w.buffer.clear()
def _init_workers(self):
jt.clean()
jt.gc()
self.index_list = mp.Array('i', self.real_len, lock=False)
workers = []
# batch id to worker id
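A hedged reading of the _init_workers change above: jt.clean() and jt.gc() release jittor's cached state before the worker processes are created, presumably so each forked child starts from a smaller, cleaner process image. A minimal sketch of that pattern (the worker body is a placeholder):
import multiprocessing as mp
import jittor as jt
def start_workers(num_workers, worker_main):
    jt.clean()  # drop jittor's internal references, as in the diff above
    jt.gc()     # free now-unreferenced Vars before fork() duplicates the process
    workers = [mp.Process(target=worker_main, args=(i,)) for i in range(num_workers)]
    for w in workers:
        w.start()
    return workers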

View File

@ -18,12 +18,15 @@ if has_cupy:
import jittor as jt
import os
import ctypes
cupy_device = cp.cuda.Device(jt.mpi.local_rank())
device_num = 0
if jt.mpi:
device_num = jt.mpi.local_rank()
cupy_device = cp.cuda.Device(device_num)
cupy_device.__enter__()
def cvt(a):
a_pointer, read_only_flag = a.__array_interface__['data']
aptr=cp.cuda.MemoryPointer(cp.cuda.memory.UnownedMemory(a_pointer,a.size*a.itemsize,a, jt.mpi.local_rank()),0)
aptr=cp.cuda.MemoryPointer(cp.cuda.memory.UnownedMemory(a_pointer,a.size*a.itemsize,a, device_num),0)
a = cp.ndarray(a.shape,a.dtype,aptr)
return a

View File

@ -16,8 +16,6 @@ from torch.autograd import Variable
class TestCumprod(unittest.TestCase):
def test_cumprod_cpu(self):
jt.flags.use_cuda = 0
for i in range(1,6):
for j in range(i):
x = np.random.rand(*((10,)*i))
@ -31,21 +29,10 @@ class TestCumprod(unittest.TestCase):
assert np.allclose(y_jt.numpy(), y_tc.data)
assert np.allclose(g_jt.numpy(), g_tc.data)
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
@jt.flag_scope(use_cuda=1)
def test_cumprod_gpu(self):
jt.flags.use_cuda = 1
for i in range(1,6):
for j in range(i):
x = np.random.rand(*((10,)*i))
x_jt = jt.array(x)
y_jt = jt.cumprod(x_jt, j).sqr()
g_jt = jt.grad(y_jt.sum(), x_jt)
x_tc = Variable(torch.from_numpy(x), requires_grad=True)
y_tc = torch.cumprod(x_tc, j)**2
y_tc.sum().backward()
g_tc = x_tc.grad
assert np.allclose(y_jt.numpy(), y_tc.data)
assert np.allclose(g_jt.numpy(), g_tc.data)
self.test_cumprod_cpu()
if __name__ == "__main__":
unittest.main()
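The refactor above (and the same change in the searchsorted tests below) replaces a duplicated GPU test body with jt.flag_scope applied as a decorator; a minimal sketch of the pattern with a toy assertion:
import unittest
import numpy as np
import jittor as jt
class TestPattern(unittest.TestCase):
    def test_cpu(self):
        x = np.random.rand(10)
        assert np.allclose((jt.array(x) * 2).numpy(), x * 2)
    @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
    @jt.flag_scope(use_cuda=1)
    def test_gpu(self):
        # same body, executed while use_cuda=1 is in effect
        self.test_cpu()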

View File

@ -8,7 +8,7 @@
# ***************************************************************
import unittest
import jittor as jt
from jittor.dataset.dataset import ImageFolder
from jittor.dataset.dataset import ImageFolder, Dataset
import jittor.transform as transform
import jittor as jt
@ -76,5 +76,32 @@ class TestDataset(unittest.TestCase):
assert isinstance(batch[1], np.ndarray)
class TestDataset2(unittest.TestCase):
def test_dataset_use_jittor(self):
class YourDataset(Dataset):
def __init__(self):
super().__init__()
self.set_attrs(total_len=10240)
def __getitem__(self, k):
x = jt.array(k)
y = x
for i in range(10):
for j in range(i+2):
y = y + j - j
y.stop_fuse()
return x, y
dataset = YourDataset().set_attrs(batch_size=256, shuffle=True, num_workers=4)
for x, y in dataset:
# dataset.display_worker_status()
pass
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
@jt.flag_scope(use_cuda=1)
def test_dataset_use_jittor_cuda(self):
self.test_dataset_use_jittor()
if __name__ == "__main__":
unittest.main()

View File

@ -21,7 +21,7 @@ def test_ring_buffer():
print("test send recv", type(data))
buffer.push(data)
recv = buffer.pop()
if isinstance(data, np.ndarray):
if isinstance(data, (np.ndarray, jt.Var)):
assert (recv == data).all()
else:
assert data == recv
@ -63,6 +63,8 @@ def test_ring_buffer():
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
test_send_recv(test_ring_buffer)
test_send_recv(jt.array(np.random.rand(10,10)))
expect_error(lambda: test_send_recv(np.random.rand(10,1000)))

View File

@ -19,6 +19,7 @@ try:
import torch
import torch.nn as tnn
import torchvision
from torch.autograd import Variable
except:
torch = None
tnn = None
@ -26,7 +27,7 @@ except:
skip_this_test = True
# TODO: more tests
# @unittest.skipIf(skip_this_test, "No Torch found")
@unittest.skipIf(skip_this_test, "No Torch found")
class TestSearchSorted(unittest.TestCase):
def test_origin(self):
sorted = jt.array([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]])
@ -46,6 +47,28 @@ class TestSearchSorted(unittest.TestCase):
def test_cuda(self):
self.test_origin()
def test_searchsorted_cpu(self):
for i in range(1,3):
s = np.sort(np.random.rand(*((10,)*i)),-1)
v = np.random.rand(*((10,)*i))
s_jt = jt.array(s)
v_jt = jt.array(v)
s_tc = torch.from_numpy(s)
v_tc = torch.from_numpy(v)
y_jt = jt.searchsorted(s_jt, v_jt, right=True)
y_tc = torch.searchsorted(s_tc, v_tc, right=True)
assert np.allclose(y_jt.numpy(), y_tc.data)
y_jt = jt.searchsorted(s_jt, v_jt, right=False)
y_tc = torch.searchsorted(s_tc, v_tc, right=False)
assert np.allclose(y_jt.numpy(), y_tc.data)
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
@jt.flag_scope(use_cuda=1)
def test_searchsorted_gpu(self):
self.test_searchsorted_cpu()
if __name__ == "__main__":

View File

@ -16,8 +16,6 @@ from torch.autograd import Variable
class TestSearchsorted(unittest.TestCase):
def test_searchsorted_cpu(self):
jt.flags.use_cuda = 0
for i in range(1,3):
s = np.sort(np.random.rand(*((10,)*i)),-1)
v = np.random.rand(*((10,)*i))
@ -26,30 +24,17 @@ class TestSearchsorted(unittest.TestCase):
s_tc = torch.from_numpy(s)
v_tc = torch.from_numpy(v)
y_jt = jt.searchsorted(s_jt, v_jt, right=True)
y_tc = torch.searchsorted(s_tc, v_tc, right=True)
y_jt = jt.searchsorted(s_jt, v_jt, right=True)
assert np.allclose(y_jt.numpy(), y_tc.data)
y_jt = jt.searchsorted(s_jt, v_jt, right=False)
y_tc = torch.searchsorted(s_tc, v_tc, right=False)
assert np.allclose(y_jt.numpy(), y_tc.data)
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
@jt.flag_scope(use_cuda=1)
def test_searchsorted_gpu(self):
jt.flags.use_cuda = 1
for i in range(1,3):
s = np.sort(np.random.rand(*((10,)*i)),-1)
v = np.random.rand(*((10,)*i))
s_jt = jt.array(s)
v_jt = jt.array(v)
s_tc = torch.from_numpy(s)
v_tc = torch.from_numpy(v)
y_jt = jt.searchsorted(s_jt, v_jt, right=True)
y_tc = torch.searchsorted(s_tc, v_tc, right=True)
assert np.allclose(y_jt.numpy(), y_tc.data)
y_jt = jt.searchsorted(s_jt, v_jt, right=False)
y_tc = torch.searchsorted(s_tc, v_tc, right=False)
assert np.allclose(y_jt.numpy(), y_tc.data)
self.test_searchsorted_cpu()
if __name__ == "__main__":
unittest.main()

View File

For other OS, using Jittor may be risky.
If you insist on installing, please set the environment variable: export FORCE_INSTALL=1
We strongly recommend the docker installation:
# CPU only
# CPU only (Linux)
>>> docker run -it --network host jittor/jittor
# CPU and CUDA
# CPU and CUDA (Linux)
>>> docker run -it --network host jittor/jittor-cuda
# CPU only (Mac and Windows)
>>> docker run -it -p 8888:8888 jittor/jittor
Reference:
1. Windows/Mac/Linux install Jittor in Docker: https://cg.cs.tsinghua.edu.cn/jittor/tutorial/2020-5-15-00-00-docker/

View File

@ -7,6 +7,7 @@
#ifdef HAS_CUDA
#include <cuda_runtime.h>
#include "helper_cuda.h"
#include "misc/cuda_flags.h"
#endif
#include <random>
@ -14,6 +15,7 @@
#include "ops/op_register.h"
#include "var.h"
#include "op.h"
#include "executor.h"
namespace jittor {
@ -78,4 +80,11 @@ void add_set_seed_callback(set_seed_callback callback) {
std::default_random_engine* get_random_engine() { return eng.get(); }
void jt_init_subprocess() {
#ifdef HAS_CUDA
use_cuda = 0;
exe.last_is_cuda = false;
#endif
}
}
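jt_init_subprocess is exported to Python via the @pyjt annotation in the header below, and the dataset diff earlier in this commit calls it from a freshly forked worker before any op executes; a hedged sketch of that call order:
import jittor as jt
def worker_entry(worker_id):
    # run first in the forked child: forces use_cuda = 0 and resets
    # the executor's last_is_cuda state (see init.cc above)
    jt.jt_init_subprocess()
    y = jt.array([1.0, 2.0, 3.0]).sum()
    return y.numpy()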

View File

@ -25,4 +25,7 @@ std::default_random_engine* get_random_engine();
// @pyjt(cleanup)
void cleanup();
// @pyjt(jt_init_subprocess)
void jt_init_subprocess();
} // jittor

View File

@ -158,7 +158,8 @@ void send_log(std::ostringstream&& out) {
mwsr_list_log::push(move(out));
} else {
std::lock_guard<std::mutex> lk(sync_log_m);
std::cerr << "[SYNC]" << out.str();
// std::cerr << "[SYNC]";
std::cerr << out.str();
std::cerr.flush();
}
}
@ -256,7 +257,7 @@ void stream_hash(uint64_t& hash, char c) {
DEFINE_FLAG(int, log_silent, 0, "The log will be completely silent.");
DEFINE_FLAG(int, log_v, 0, "Verbose level of logging");
DEFINE_FLAG(int, log_sync, 0, "Set log printed synchronously.");
DEFINE_FLAG(int, log_sync, 1, "Set log printed synchronously.");
DEFINE_FLAG_WITH_SETTER(string, log_vprefix, "",
"Verbose level of logging prefix\n"
"example: log_vprefix='op=1,node=2,executor.cc:38$=1000'");