From 30fc5b32fd98223f7544bf3a365eab08aaacb7cc Mon Sep 17 00:00:00 2001 From: Dun Liang Date: Mon, 7 Dec 2020 22:48:25 +0800 Subject: [PATCH] add system test --- .dockerignore | 1 + MANIFEST.in | 2 + extern/cuda/cub/ops/cub_where_op.cc | 2 + extern/cuda/cutt/ops/cutt_transpose_op.cc | 1 + python/jittor/__init__.py | 3 +- python/jittor/compile_extern.py | 6 +- python/jittor/misc.py | 2 +- python/jittor/models/res2net.py | 4 +- python/jittor/test/perf/perf.py | 221 ++++++++++++++++++ python/jittor/test/system/test_all.sh | 6 + .../test/system/test_cuda10.0_ubuntu16.04.sh | 41 ++++ .../test/system/test_cuda10.0_ubuntu18.04.sh | 41 ++++ .../test/system/test_cuda11.1_ubuntu16.04.sh | 41 ++++ .../test/system/test_cuda11.1_ubuntu18.04.sh | 41 ++++ .../test/system/test_cuda11.1_ubuntu20.04.sh | 39 ++++ .../test/system/test_nocuda_ubuntu18.04.sh | 40 ++++ python/jittor/test/test_mkl_conv_op.py | 4 +- src/executor.cc | 6 +- src/executor.h | 4 - src/ops/binary_op.cc | 3 + src/ops/where_op.cc | 2 + src/profiler/simple_profiler.h | 2 + src/utils/cache_compile.cc | 2 +- 23 files changed, 498 insertions(+), 16 deletions(-) create mode 100644 MANIFEST.in create mode 100644 python/jittor/test/perf/perf.py create mode 100644 python/jittor/test/system/test_all.sh create mode 100644 python/jittor/test/system/test_cuda10.0_ubuntu16.04.sh create mode 100644 python/jittor/test/system/test_cuda10.0_ubuntu18.04.sh create mode 100644 python/jittor/test/system/test_cuda11.1_ubuntu16.04.sh create mode 100644 python/jittor/test/system/test_cuda11.1_ubuntu18.04.sh create mode 100644 python/jittor/test/system/test_cuda11.1_ubuntu20.04.sh create mode 100644 python/jittor/test/system/test_nocuda_ubuntu18.04.sh diff --git a/.dockerignore b/.dockerignore index 26a123bd..775478a1 100644 --- a/.dockerignore +++ b/.dockerignore @@ -26,3 +26,4 @@ venv/ python/jittor.egg-info dist/ !doc/source/* +__data__ \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..d416f1cd --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +exclude __data__ +exclude __pycache__ \ No newline at end of file diff --git a/extern/cuda/cub/ops/cub_where_op.cc b/extern/cuda/cub/ops/cub_where_op.cc index deb6b08f..647e66e4 100644 --- a/extern/cuda/cub/ops/cub_where_op.cc +++ b/extern/cuda/cub/ops/cub_where_op.cc @@ -10,6 +10,8 @@ #include "cub_where_op.h" #ifdef JIT_cuda #include "executor.h" +#include +#include "helper_cuda.h" #include #include #include diff --git a/extern/cuda/cutt/ops/cutt_transpose_op.cc b/extern/cuda/cutt/ops/cutt_transpose_op.cc index e9b9582e..f6e14ddc 100644 --- a/extern/cuda/cutt/ops/cutt_transpose_op.cc +++ b/extern/cuda/cutt/ops/cutt_transpose_op.cc @@ -9,6 +9,7 @@ #include "cutt.h" #include "cutt_warper.h" #include "misc/stack_vector.h" +#include "helper_cuda.h" namespace jittor { diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 18fe895d..f2d02734 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -7,7 +7,7 @@ # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.2.1.3' +__version__ = '1.2.2.0' from . import lock with lock.lock_scope(): ori_int = int @@ -92,6 +92,7 @@ class log_capture_scope(_call_no_record_scope): print(logs) """ def __init__(self, **jt_flags): + jt_flags["use_parallel_op_compiler"] = 0 self.fs = flag_scope(**jt_flags) def __enter__(self): diff --git a/python/jittor/compile_extern.py b/python/jittor/compile_extern.py index ae07e47a..3f32d783 100644 --- a/python/jittor/compile_extern.py +++ b/python/jittor/compile_extern.py @@ -78,9 +78,9 @@ def setup_mkl(): def install_cub(root_folder): - url = "https://github.com/NVIDIA/cub/archive/1.11.0-rc1.tar.gz" - filename = "cub-1.11.0-rc1.tgz" - md5 = "f395687060bed7eaeb5fa8a689276ede" + url = "https://github.com/NVIDIA/cub/archive/1.11.0.tar.gz" + filename = "cub-1.11.0.tgz" + md5 = "97196a885598e40592100e1caaf3d5ea" fullname = os.path.join(root_folder, filename) dirname = os.path.join(root_folder, filename.replace(".tgz","")) diff --git a/python/jittor/misc.py b/python/jittor/misc.py index cdfcdda9..f026b9da 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -33,7 +33,7 @@ def __iter__(x): return result.__iter__() jt.Var.__iter__ = __iter__ -def all(x,dim): +def all(x, dim=[]): return x.all_(dim).bool() jt.Var.all = all diff --git a/python/jittor/models/res2net.py b/python/jittor/models/res2net.py index f10059f1..ce6569b9 100644 --- a/python/jittor/models/res2net.py +++ b/python/jittor/models/res2net.py @@ -175,10 +175,10 @@ class Res2Net(Module): x = self.layer4(x) return x, low_level_feat -def res2net50(output_stride): +def res2net50(output_stride=16): model = Res2Net(Bottle2neck, [3,4,6,3], output_stride) return model -def res2net101(output_stride): +def res2net101(output_stride=16): model = Res2Net(Bottle2neck, [3,4,23,3], output_stride) return model diff --git a/python/jittor/test/perf/perf.py b/python/jittor/test/perf/perf.py new file mode 100644 index 00000000..753b5383 --- /dev/null +++ b/python/jittor/test/perf/perf.py @@ -0,0 +1,221 @@ +import sys, os + +suffix = "" + +import jittor as jt +import time +from pathlib import Path +home_path = str(Path.home()) +perf_path = os.path.join(home_path, ".cache", "jittor_perf") + +def main(): + os.makedirs(perf_path+"/src/jittor", exist_ok=True) + os.makedirs(perf_path+"/src/jittor_utils", exist_ok=True) + os.system(f"cp -rL {jt.flags.jittor_path} {perf_path+'/src/'}") + os.system(f"cp -rL {jt.flags.jittor_path}/../jittor_utils {perf_path+'/src/'}") + use_torch_1_4 = os.environ.get("use_torch_1_4", "0") == "1" + dockerfile_src = r""" +FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 + +RUN echo \ +"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list + +# RUN rm -rf /var/lib/apt/lists/* +RUN apt update || true + +RUN apt install wget \ + python3.7 python3.7-dev \ + g++ build-essential -y + +WORKDIR /usr/src + +RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \ + && wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7 + +# change tsinghua mirror +RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + +RUN pip3 install \ + pybind11 \ + numpy \ + tqdm \ + pillow \ + astunparse + +RUN pip3 install torch torchvision +""" + global suffix + if use_torch_1_4: + suffix = "_1_4" + dockerfile_src = dockerfile_src.replace("torch ", "torch==1.4.0 ") + dockerfile_src = dockerfile_src.replace("torchvision", "torchvision==0.5.0") + with open("/tmp/perf_dockerfile", 'w') as f: + f.write(dockerfile_src) + assert os.system("sudo nvidia-smi -lgc 1500") == 0 + assert os.system(f"sudo docker build --tag jittor/jittor-perf{suffix} -f /tmp/perf_dockerfile .") == 0 + # run once for compile source + jt_fps = test_main("jittor", "resnet50", 1) + + logs = "" + # resnext50_32x4d with bs=8 cannot pass this test + #### inference test + for model_name in ["resnet50", "wide_resnet50_2", # "resnext50_32x4d", + "resnet152", "wide_resnet101_2", "resnext101_32x8d", + "alexnet", "vgg11", "squeezenet1_1", "mobilenet_v2", + "densenet121", "densenet169", "densenet201", + "res2net50", "res2net101"]: + for bs in [1, 2, 4, 8, 16, 32, 64, 128]: + jt_fps = test_main("jittor", model_name, bs) + logs += f"jittor-{model_name}-{bs} {jt_fps}\n" + tc_fps = test_main("torch", model_name, bs) + logs += f"torch-{model_name}-{bs} {tc_fps}\n" + logs += f"compare-{model_name}-{bs} {jt_fps/tc_fps}\n" + print(logs) + #### train test + for model_name in ["train_resnet50", "train_resnet101" + ]: + for bs in [1, 2, 4, 8, 16, 32, 64, 128]: + jt_fps = test_main("jittor", model_name, bs) + logs += f"jittor-{model_name}-{bs} {jt_fps}\n" + tc_fps = test_main("torch", model_name, bs) + logs += f"torch-{model_name}-{bs} {tc_fps}\n" + logs += f"compare-{model_name}-{bs} {jt_fps/tc_fps}\n" + print(logs) + with open(f"{perf_path}/jittor-perf{suffix}-latest.txt", "w") as f: + f.write(logs) + from datetime import datetime + with open(f"{perf_path}/jittor-perf{suffix}-{datetime.now()}.txt", "w") as f: + f.write(logs) + +def test_main(name, model_name, bs): + cmd = f"sudo docker run --gpus all --rm -v {perf_path}:/root/.cache/jittor --network host jittor/jittor-perf{suffix} bash -c 'PYTHONPATH=/root/.cache/jittor/src python3.7 /root/.cache/jittor/src/jittor/test/perf/perf.py {name} {model_name} {bs}'" + fps = -1 + try: + print("run cmd:", cmd) + if os.system(cmd) == 0: + with open(f"{perf_path}/{name}-{model_name}-{bs}.txt", 'r') as f: + fps = float(f.read().split()[3]) + except: + pass + return fps + +def time_iter(duration=2, min_iter=5): + start = time.time() + for i in range(10000000): + yield i + end = time.time() + if end-start>duration and i>=min_iter: + return + +def test(name, model_name, bs): + print("hello", name, model_name, bs) + import numpy as np + import time + is_train = False + _model_name = model_name + if model_name.startswith("train_"): + is_train = True + model_name = model_name[6:] + if name == "torch": + import torch + import torchvision.models as tcmodels + from torch import optim + from torch import nn + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = True + model = tcmodels.__dict__[model_name]() + model = model.cuda() + else: + import jittor as jt + from jittor import optim + from jittor import nn + jt.flags.use_cuda = 1 + jt.cudnn.set_algorithm_cache_size(10000) + import jittor.models as jtmodels + model = jtmodels.__dict__[model_name]() + if (model == "resnet152" or model == "resnet101") and bs == 128 and is_train: + jt.cudnn.set_max_workspace_ratio(0.05) + if is_train: + model.train() + else: + model.eval() + img_size = 224 + if model_name == "inception_v3": + img_size = 300 + test_img = np.random.random((bs, 3, img_size, img_size)).astype("float32") + if is_train: + label = (np.random.random((bs,)) * 1000).astype("int32") + if name == "torch": + test_img = torch.Tensor(test_img).cuda() + if is_train: + label = torch.LongTensor(label).cuda() + opt = optim.SGD(model.parameters(), 0.001) + sync = lambda: torch.cuda.synchronize() + jt = torch + else: + test_img = jt.array(test_img).stop_grad() + if is_train: + label = jt.array(label).stop_grad() + opt = optim.SGD(model.parameters(), 0.001) + sync = lambda: jt.sync_all(True) + + sync() + use_profiler = os.environ.get("use_profiler", "0") == "1" + if hasattr(jt, "nograd"): + ng = jt.no_grad() + ng.__enter__() + def iter(): + x = model(test_img) + if isinstance(x, tuple): + x = x[0] + if is_train: + loss = nn.CrossEntropyLoss()(x, label) + if name == "jittor": + opt.step(loss) + else: + opt.zero_grad() + loss.backward() + opt.step() + else: + x.sync() + sync() + for i in time_iter(): + iter() + sync() + for i in time_iter(): + iter() + sync() + if use_profiler: + if name == "torch": + prof = torch.autograd.profiler.profile(use_cuda=True) + else: + prof = jt.profile_scope() + prof.__enter__() + if name == "jittor": + if hasattr(jt.flags, "use_parallel_op_compiler"): + jt.flags.use_parallel_op_compiler = 0 + start = time.time() + for i in time_iter(10): + iter() + sync() + end = time.time() + if use_profiler: + prof.__exit__(None,None,None) + if name == "torch": + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30)) + total_iter = i+1 + print("duration:", end-start, "FPS:", total_iter*bs/(end-start)) + fpath = f"{home_path}/.cache/jittor/{name}-{_model_name}-{bs}.txt" + with open(fpath, 'w') as f: + f.write(f"duration: {end-start} FPS: {total_iter*bs/(end-start)}") + os.chmod(fpath, 0x666) + +if len(sys.argv) <= 1: + main() +else: + name, model, bs = sys.argv[1:] + bs = int(bs) + test(name, model, bs) \ No newline at end of file diff --git a/python/jittor/test/system/test_all.sh b/python/jittor/test/system/test_all.sh new file mode 100644 index 00000000..b4a2ddec --- /dev/null +++ b/python/jittor/test/system/test_all.sh @@ -0,0 +1,6 @@ +bash python/jittor/test/system/test_cuda10.0_ubuntu16.04.sh +bash python/jittor/test/system/test_cuda10.0_ubuntu18.04.sh +bash python/jittor/test/system/test_cuda11.1_ubuntu16.04.sh +bash python/jittor/test/system/test_cuda11.1_ubuntu18.04.sh +bash python/jittor/test/system/test_cuda11.1_ubuntu20.04.sh +bash python/jittor/test/system/test_nocuda_ubuntu18.04.sh diff --git a/python/jittor/test/system/test_cuda10.0_ubuntu16.04.sh b/python/jittor/test/system/test_cuda10.0_ubuntu16.04.sh new file mode 100644 index 00000000..66aea83f --- /dev/null +++ b/python/jittor/test/system/test_cuda10.0_ubuntu16.04.sh @@ -0,0 +1,41 @@ +cat > /tmp/cuda10.0-ubuntu16.04.dockerfile <<\EOF +FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04 + +RUN apt update && apt install ca-certificates -y + +RUN echo \ +"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list + +# RUN rm -rf /var/lib/apt/lists/* +RUN apt update || true + +RUN apt install wget \ + python3.7 python3.7-dev \ + g++ build-essential -y + +WORKDIR /usr/src + +RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \ + && wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7 + +# change tsinghua mirror +RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + +RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example +RUN pip3 uninstall jittor -y + +COPY . jittor +RUN python3.7 -m pip install jittor +RUN python3.7 -m jittor.test.test_core +EOF + +sudo docker build --tag jittor/jittor-cuda:10.0-16.04 -f /tmp/cuda10.0-ubuntu16.04.dockerfile . +sudo docker run --gpus all --rm jittor/jittor-cuda:10.0-18.04 bash -c \ +"python3.7 -m jittor.test.test_example && \ +python3.7 -m jittor.test.test_resnet && \ +python3.7 -m jittor.test.test_parallel_pass && \ +python3.7 -m jittor.test.test_atomic_tuner && \ +python3.7 -m jittor.test.test_where_op" \ No newline at end of file diff --git a/python/jittor/test/system/test_cuda10.0_ubuntu18.04.sh b/python/jittor/test/system/test_cuda10.0_ubuntu18.04.sh new file mode 100644 index 00000000..02a6ef42 --- /dev/null +++ b/python/jittor/test/system/test_cuda10.0_ubuntu18.04.sh @@ -0,0 +1,41 @@ +cat > /tmp/cuda10.0-ubuntu18.04.dockerfile <<\EOF +FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04 + +RUN apt update && apt install ca-certificates -y + +RUN echo \ +"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list + +# RUN rm -rf /var/lib/apt/lists/* +RUN apt update || true + +RUN apt install wget \ + python3.7 python3.7-dev \ + g++ build-essential -y + +WORKDIR /usr/src + +RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \ + && wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7 + +# change tsinghua mirror +RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + +RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example +RUN pip3 uninstall jittor -y + +COPY . jittor +RUN python3.7 -m pip install jittor +RUN python3.7 -m jittor.test.test_core +EOF + +sudo docker build --tag jittor/jittor-cuda:10.0-18.04 -f /tmp/cuda10.0-ubuntu18.04.dockerfile . +sudo docker run --gpus all --rm jittor/jittor-cuda:10.0-18.04 bash -c \ +"python3.7 -m jittor.test.test_example && \ +python3.7 -m jittor.test.test_resnet && \ +python3.7 -m jittor.test.test_parallel_pass && \ +python3.7 -m jittor.test.test_atomic_tuner && \ +python3.7 -m jittor.test.test_where_op" \ No newline at end of file diff --git a/python/jittor/test/system/test_cuda11.1_ubuntu16.04.sh b/python/jittor/test/system/test_cuda11.1_ubuntu16.04.sh new file mode 100644 index 00000000..ad25f725 --- /dev/null +++ b/python/jittor/test/system/test_cuda11.1_ubuntu16.04.sh @@ -0,0 +1,41 @@ +cat > /tmp/cuda11.1-ubuntu16.04.dockerfile <<\EOF +FROM nvidia/cuda:11.1-cudnn8-devel-ubuntu16.04 + +RUN apt update && apt install ca-certificates -y + +RUN echo \ +"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list + +# RUN rm -rf /var/lib/apt/lists/* +RUN apt update || true + +RUN apt install wget \ + python3.7 python3.7-dev \ + g++ build-essential -y + +WORKDIR /usr/src + +RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \ + && wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7 + +# change tsinghua mirror +RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + +RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example +RUN pip3 uninstall jittor -y + +COPY . jittor +RUN python3.7 -m pip install jittor +RUN python3.7 -m jittor.test.test_core +EOF + +sudo docker build --tag jittor/jittor-cuda:11.1-16.04 -f /tmp/cuda11.1-ubuntu16.04.dockerfile . +sudo docker run --gpus all --rm jittor/jittor-cuda:11.1-16.04 bash -c \ +"python3.7 -m jittor.test.test_example && \ +python3.7 -m jittor.test.test_resnet && \ +python3.7 -m jittor.test.test_parallel_pass && \ +python3.7 -m jittor.test.test_atomic_tuner && \ +python3.7 -m jittor.test.test_where_op" \ No newline at end of file diff --git a/python/jittor/test/system/test_cuda11.1_ubuntu18.04.sh b/python/jittor/test/system/test_cuda11.1_ubuntu18.04.sh new file mode 100644 index 00000000..9688dc93 --- /dev/null +++ b/python/jittor/test/system/test_cuda11.1_ubuntu18.04.sh @@ -0,0 +1,41 @@ +cat > /tmp/cuda11.1-ubuntu18.04.dockerfile <<\EOF +FROM nvidia/cuda:11.1-cudnn8-devel-ubuntu18.04 + +RUN apt update && apt install ca-certificates -y + +RUN echo \ +"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list + +# RUN rm -rf /var/lib/apt/lists/* +RUN apt update || true + +RUN apt install wget \ + python3.7 python3.7-dev \ + g++ build-essential -y + +WORKDIR /usr/src + +RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \ + && wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7 + +# change tsinghua mirror +RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + +RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example +RUN pip3 uninstall jittor -y + +COPY . jittor +RUN python3.7 -m pip install jittor +RUN python3.7 -m jittor.test.test_core +EOF + +sudo docker build --tag jittor/jittor-cuda:11.1-18.04 -f /tmp/cuda11.1-ubuntu18.04.dockerfile . +sudo docker run --gpus all --rm jittor/jittor-cuda:11.1-18.04 bash -c \ +"python3.7 -m jittor.test.test_example && \ +python3.7 -m jittor.test.test_resnet && \ +python3.7 -m jittor.test.test_parallel_pass && \ +python3.7 -m jittor.test.test_atomic_tuner && \ +python3.7 -m jittor.test.test_where_op" \ No newline at end of file diff --git a/python/jittor/test/system/test_cuda11.1_ubuntu20.04.sh b/python/jittor/test/system/test_cuda11.1_ubuntu20.04.sh new file mode 100644 index 00000000..bd776827 --- /dev/null +++ b/python/jittor/test/system/test_cuda11.1_ubuntu20.04.sh @@ -0,0 +1,39 @@ +cat > /tmp/cuda11.1-ubuntu20.04.dockerfile <<\EOF +FROM nvidia/cuda:11.1-devel-ubuntu20.04 + +RUN apt update && apt install ca-certificates -y + +RUN echo \ +"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-updates main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-backports main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-security main restricted universe multiverse" > /etc/apt/sources.list + +# RUN rm -rf /var/lib/apt/lists/* +RUN apt update || true +RUN apt install g++ build-essential libomp-dev python3-dev python3-pip wget -y +RUN python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple +WORKDIR /usr/src/ + +RUN wget https://developer.download.nvidia.cn/compute/cuda/repos/ubuntu2004/x86_64/libcudnn8_8.0.5.39-1+cuda11.1_amd64.deb && \ + wget https://developer.download.nvidia.cn/compute/cuda/repos/ubuntu2004/x86_64/libcudnn8-dev_8.0.5.39-1+cuda11.1_amd64.deb && \ + dpkg -i ./libcudnn8_8.0.5.39-1+cuda11.1_amd64.deb ./libcudnn8-dev_8.0.5.39-1+cuda11.1_amd64.deb && \ + rm *.deb +RUN ls + + +RUN pip3 install jittor --timeout 100 && python3 -m jittor.test.test_example +RUN pip3 uninstall jittor -y + +COPY . jittor +RUN python3 -m pip install jittor +RUN python3 -m jittor.test.test_core +EOF + +sudo docker build --tag jittor/jittor-cuda:11.1-20.04 -f /tmp/cuda11.1-ubuntu20.04.dockerfile . +sudo docker run --gpus all --rm jittor/jittor-cuda:11.1-20.04 bash -c \ +"python3 -m jittor.test.test_example && \ +python3 -m jittor.test.test_resnet && \ +python3 -m jittor.test.test_parallel_pass && \ +python3 -m jittor.test.test_atomic_tuner && \ +python3 -m jittor.test.test_where_op" \ No newline at end of file diff --git a/python/jittor/test/system/test_nocuda_ubuntu18.04.sh b/python/jittor/test/system/test_nocuda_ubuntu18.04.sh new file mode 100644 index 00000000..0cbfe42f --- /dev/null +++ b/python/jittor/test/system/test_nocuda_ubuntu18.04.sh @@ -0,0 +1,40 @@ +cat > /tmp/ubuntu18.04.dockerfile <<\EOF +FROM ubuntu:18.04 + +RUN apt update && apt install ca-certificates -y + +RUN echo \ +"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list + +# RUN rm -rf /var/lib/apt/lists/* +RUN apt update + +RUN apt install wget \ + python3.7 python3.7-dev \ + g++ build-essential -y + +WORKDIR /usr/src + +RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \ + && wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7 + +# change tsinghua mirror +RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + +RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example +RUN pip3 uninstall jittor -y + +COPY . jittor +RUN python3.7 -m pip install jittor +RUN python3.7 -m jittor.test.test_core +EOF + +sudo docker build --tag jittor/jittor:18.04 -f /tmp/ubuntu18.04.dockerfile . +sudo docker run --gpus all --rm jittor/jittor:18.04 bash -c \ +"python3.7 -m jittor.test.test_example && \ +python3.7 -m jittor.test.test_parallel_pass && \ +python3.7 -m jittor.test.test_atomic_tuner && \ +python3.7 -m jittor.test.test_where_op" \ No newline at end of file diff --git a/python/jittor/test/test_mkl_conv_op.py b/python/jittor/test/test_mkl_conv_op.py index cfb8e8dd..26ced31c 100644 --- a/python/jittor/test/test_mkl_conv_op.py +++ b/python/jittor/test/test_mkl_conv_op.py @@ -120,7 +120,7 @@ class TestMklConvOp(unittest.TestCase): with jt.flag_scope( enable_tuner=0, - compile_options={"test_mkl_conv":1} + # compile_options={"test_mkl_conv":1} ): c_jt = conv(a_jt, b_jt, 1, 1) * da gs=jt.grad(c_jt,[a_jt,b_jt]) @@ -166,7 +166,7 @@ class TestMklConvOp(unittest.TestCase): with jt.flag_scope( enable_tuner=0, - compile_options={"test_mkl_conv":1} + # compile_options={"test_mkl_conv":1} ): c_jt = conv_nhwc_hwio(a_jt, b_jt, 1, 1) * da gs=jt.grad(c_jt,[a_jt,b_jt]) diff --git a/src/executor.cc b/src/executor.cc index 29d976c0..978a444e 100644 --- a/src/executor.cc +++ b/src/executor.cc @@ -446,8 +446,10 @@ void Executor::run_sync(vector vars, bool device_sync) { // record trace data if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var==2)) { trace_data.record_execution(op, is_fused_op, jkl); - checkCudaErrors(cudaDeviceSynchronize()); - + #ifdef HAS_CUDA + if (use_cuda) + checkCudaErrors(cudaDeviceSynchronize()); + #endif } LOGvvv << "Finished Op(" >> op->name() << rid >> "/" >> queue.size() >> ") output:" << op->outputs(); diff --git a/src/executor.h b/src/executor.h index 3c2801ac..6f93cd17 100644 --- a/src/executor.h +++ b/src/executor.h @@ -7,10 +7,6 @@ #pragma once #include "common.h" #include "mem/allocator.h" -#ifdef HAS_CUDA -#include -#include "helper_cuda.h" -#endif namespace jittor { diff --git a/src/ops/binary_op.cc b/src/ops/binary_op.cc index 303fc363..ee01de0c 100644 --- a/src/ops/binary_op.cc +++ b/src/ops/binary_op.cc @@ -93,6 +93,9 @@ VarPtr dirty_clone_broadcast(Var* v) { if (op && !v->is_finished() && v->shape.size() > 4 && op->type() == OpType::broadcast) { auto vp = op->duplicate(); if (vp) { + // TODO: loop options should be set to op, rather than var + if (v->loop_options) + vp->loop_options = v->loop_options; return vp; } } diff --git a/src/ops/where_op.cc b/src/ops/where_op.cc index e1605da3..6e7e2e9b 100644 --- a/src/ops/where_op.cc +++ b/src/ops/where_op.cc @@ -10,6 +10,8 @@ #ifdef JIT_cuda #include "executor.h" #include +#include +#include "helper_cuda.h" #endif namespace jittor { diff --git a/src/profiler/simple_profiler.h b/src/profiler/simple_profiler.h index 4f54cb45..7949c01d 100644 --- a/src/profiler/simple_profiler.h +++ b/src/profiler/simple_profiler.h @@ -98,4 +98,6 @@ struct SimpleProfilerGuard { }; +DECLARE_FLAG(int, profiler_enable); + } // jittor \ No newline at end of file diff --git a/src/utils/cache_compile.cc b/src/utils/cache_compile.cc index dc516016..735fedbb 100644 --- a/src/utils/cache_compile.cc +++ b/src/utils/cache_compile.cc @@ -153,7 +153,7 @@ void process(string src, vector& input_names) { while (l