add system test

This commit is contained in:
Dun Liang 2020-12-07 22:48:25 +08:00
parent 6e761f89da
commit 30fc5b32fd
23 changed files with 498 additions and 16 deletions

View File

@ -26,3 +26,4 @@ venv/
python/jittor.egg-info python/jittor.egg-info
dist/ dist/
!doc/source/* !doc/source/*
__data__

2
MANIFEST.in Normal file
View File

@ -0,0 +1,2 @@
exclude __data__
exclude __pycache__

View File

@ -10,6 +10,8 @@
#include "cub_where_op.h" #include "cub_where_op.h"
#ifdef JIT_cuda #ifdef JIT_cuda
#include "executor.h" #include "executor.h"
#include <cuda_runtime.h>
#include "helper_cuda.h"
#include <assert.h> #include <assert.h>
#include <executor.h> #include <executor.h>
#include <cub/cub.cuh> #include <cub/cub.cuh>

View File

@ -9,6 +9,7 @@
#include "cutt.h" #include "cutt.h"
#include "cutt_warper.h" #include "cutt_warper.h"
#include "misc/stack_vector.h" #include "misc/stack_vector.h"
#include "helper_cuda.h"
namespace jittor { namespace jittor {

View File

@ -7,7 +7,7 @@
# This file is subject to the terms and conditions defined in # This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
__version__ = '1.2.1.3' __version__ = '1.2.2.0'
from . import lock from . import lock
with lock.lock_scope(): with lock.lock_scope():
ori_int = int ori_int = int
@ -92,6 +92,7 @@ class log_capture_scope(_call_no_record_scope):
print(logs) print(logs)
""" """
def __init__(self, **jt_flags): def __init__(self, **jt_flags):
jt_flags["use_parallel_op_compiler"] = 0
self.fs = flag_scope(**jt_flags) self.fs = flag_scope(**jt_flags)
def __enter__(self): def __enter__(self):

View File

@ -78,9 +78,9 @@ def setup_mkl():
def install_cub(root_folder): def install_cub(root_folder):
url = "https://github.com/NVIDIA/cub/archive/1.11.0-rc1.tar.gz" url = "https://github.com/NVIDIA/cub/archive/1.11.0.tar.gz"
filename = "cub-1.11.0-rc1.tgz" filename = "cub-1.11.0.tgz"
md5 = "f395687060bed7eaeb5fa8a689276ede" md5 = "97196a885598e40592100e1caaf3d5ea"
fullname = os.path.join(root_folder, filename) fullname = os.path.join(root_folder, filename)
dirname = os.path.join(root_folder, filename.replace(".tgz","")) dirname = os.path.join(root_folder, filename.replace(".tgz",""))

View File

@ -33,7 +33,7 @@ def __iter__(x):
return result.__iter__() return result.__iter__()
jt.Var.__iter__ = __iter__ jt.Var.__iter__ = __iter__
def all(x,dim): def all(x, dim=[]):
return x.all_(dim).bool() return x.all_(dim).bool()
jt.Var.all = all jt.Var.all = all

View File

@ -175,10 +175,10 @@ class Res2Net(Module):
x = self.layer4(x) x = self.layer4(x)
return x, low_level_feat return x, low_level_feat
def res2net50(output_stride): def res2net50(output_stride=16):
model = Res2Net(Bottle2neck, [3,4,6,3], output_stride) model = Res2Net(Bottle2neck, [3,4,6,3], output_stride)
return model return model
def res2net101(output_stride): def res2net101(output_stride=16):
model = Res2Net(Bottle2neck, [3,4,23,3], output_stride) model = Res2Net(Bottle2neck, [3,4,23,3], output_stride)
return model return model

View File

@ -0,0 +1,221 @@
import sys, os
suffix = ""
import jittor as jt
import time
from pathlib import Path
home_path = str(Path.home())
perf_path = os.path.join(home_path, ".cache", "jittor_perf")
def main():
os.makedirs(perf_path+"/src/jittor", exist_ok=True)
os.makedirs(perf_path+"/src/jittor_utils", exist_ok=True)
os.system(f"cp -rL {jt.flags.jittor_path} {perf_path+'/src/'}")
os.system(f"cp -rL {jt.flags.jittor_path}/../jittor_utils {perf_path+'/src/'}")
use_torch_1_4 = os.environ.get("use_torch_1_4", "0") == "1"
dockerfile_src = r"""
FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04
RUN echo \
"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list
# RUN rm -rf /var/lib/apt/lists/*
RUN apt update || true
RUN apt install wget \
python3.7 python3.7-dev \
g++ build-essential -y
WORKDIR /usr/src
RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \
&& wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7
# change tsinghua mirror
RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
RUN pip3 install \
pybind11 \
numpy \
tqdm \
pillow \
astunparse
RUN pip3 install torch torchvision
"""
global suffix
if use_torch_1_4:
suffix = "_1_4"
dockerfile_src = dockerfile_src.replace("torch ", "torch==1.4.0 ")
dockerfile_src = dockerfile_src.replace("torchvision", "torchvision==0.5.0")
with open("/tmp/perf_dockerfile", 'w') as f:
f.write(dockerfile_src)
assert os.system("sudo nvidia-smi -lgc 1500") == 0
assert os.system(f"sudo docker build --tag jittor/jittor-perf{suffix} -f /tmp/perf_dockerfile .") == 0
# run once for compile source
jt_fps = test_main("jittor", "resnet50", 1)
logs = ""
# resnext50_32x4d with bs=8 cannot pass this test
#### inference test
for model_name in ["resnet50", "wide_resnet50_2", # "resnext50_32x4d",
"resnet152", "wide_resnet101_2", "resnext101_32x8d",
"alexnet", "vgg11", "squeezenet1_1", "mobilenet_v2",
"densenet121", "densenet169", "densenet201",
"res2net50", "res2net101"]:
for bs in [1, 2, 4, 8, 16, 32, 64, 128]:
jt_fps = test_main("jittor", model_name, bs)
logs += f"jittor-{model_name}-{bs} {jt_fps}\n"
tc_fps = test_main("torch", model_name, bs)
logs += f"torch-{model_name}-{bs} {tc_fps}\n"
logs += f"compare-{model_name}-{bs} {jt_fps/tc_fps}\n"
print(logs)
#### train test
for model_name in ["train_resnet50", "train_resnet101"
]:
for bs in [1, 2, 4, 8, 16, 32, 64, 128]:
jt_fps = test_main("jittor", model_name, bs)
logs += f"jittor-{model_name}-{bs} {jt_fps}\n"
tc_fps = test_main("torch", model_name, bs)
logs += f"torch-{model_name}-{bs} {tc_fps}\n"
logs += f"compare-{model_name}-{bs} {jt_fps/tc_fps}\n"
print(logs)
with open(f"{perf_path}/jittor-perf{suffix}-latest.txt", "w") as f:
f.write(logs)
from datetime import datetime
with open(f"{perf_path}/jittor-perf{suffix}-{datetime.now()}.txt", "w") as f:
f.write(logs)
def test_main(name, model_name, bs):
cmd = f"sudo docker run --gpus all --rm -v {perf_path}:/root/.cache/jittor --network host jittor/jittor-perf{suffix} bash -c 'PYTHONPATH=/root/.cache/jittor/src python3.7 /root/.cache/jittor/src/jittor/test/perf/perf.py {name} {model_name} {bs}'"
fps = -1
try:
print("run cmd:", cmd)
if os.system(cmd) == 0:
with open(f"{perf_path}/{name}-{model_name}-{bs}.txt", 'r') as f:
fps = float(f.read().split()[3])
except:
pass
return fps
def time_iter(duration=2, min_iter=5):
start = time.time()
for i in range(10000000):
yield i
end = time.time()
if end-start>duration and i>=min_iter:
return
def test(name, model_name, bs):
print("hello", name, model_name, bs)
import numpy as np
import time
is_train = False
_model_name = model_name
if model_name.startswith("train_"):
is_train = True
model_name = model_name[6:]
if name == "torch":
import torch
import torchvision.models as tcmodels
from torch import optim
from torch import nn
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
model = tcmodels.__dict__[model_name]()
model = model.cuda()
else:
import jittor as jt
from jittor import optim
from jittor import nn
jt.flags.use_cuda = 1
jt.cudnn.set_algorithm_cache_size(10000)
import jittor.models as jtmodels
model = jtmodels.__dict__[model_name]()
if (model == "resnet152" or model == "resnet101") and bs == 128 and is_train:
jt.cudnn.set_max_workspace_ratio(0.05)
if is_train:
model.train()
else:
model.eval()
img_size = 224
if model_name == "inception_v3":
img_size = 300
test_img = np.random.random((bs, 3, img_size, img_size)).astype("float32")
if is_train:
label = (np.random.random((bs,)) * 1000).astype("int32")
if name == "torch":
test_img = torch.Tensor(test_img).cuda()
if is_train:
label = torch.LongTensor(label).cuda()
opt = optim.SGD(model.parameters(), 0.001)
sync = lambda: torch.cuda.synchronize()
jt = torch
else:
test_img = jt.array(test_img).stop_grad()
if is_train:
label = jt.array(label).stop_grad()
opt = optim.SGD(model.parameters(), 0.001)
sync = lambda: jt.sync_all(True)
sync()
use_profiler = os.environ.get("use_profiler", "0") == "1"
if hasattr(jt, "nograd"):
ng = jt.no_grad()
ng.__enter__()
def iter():
x = model(test_img)
if isinstance(x, tuple):
x = x[0]
if is_train:
loss = nn.CrossEntropyLoss()(x, label)
if name == "jittor":
opt.step(loss)
else:
opt.zero_grad()
loss.backward()
opt.step()
else:
x.sync()
sync()
for i in time_iter():
iter()
sync()
for i in time_iter():
iter()
sync()
if use_profiler:
if name == "torch":
prof = torch.autograd.profiler.profile(use_cuda=True)
else:
prof = jt.profile_scope()
prof.__enter__()
if name == "jittor":
if hasattr(jt.flags, "use_parallel_op_compiler"):
jt.flags.use_parallel_op_compiler = 0
start = time.time()
for i in time_iter(10):
iter()
sync()
end = time.time()
if use_profiler:
prof.__exit__(None,None,None)
if name == "torch":
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
total_iter = i+1
print("duration:", end-start, "FPS:", total_iter*bs/(end-start))
fpath = f"{home_path}/.cache/jittor/{name}-{_model_name}-{bs}.txt"
with open(fpath, 'w') as f:
f.write(f"duration: {end-start} FPS: {total_iter*bs/(end-start)}")
os.chmod(fpath, 0x666)
if len(sys.argv) <= 1:
main()
else:
name, model, bs = sys.argv[1:]
bs = int(bs)
test(name, model, bs)

View File

@ -0,0 +1,6 @@
bash python/jittor/test/system/test_cuda10.0_ubuntu16.04.sh
bash python/jittor/test/system/test_cuda10.0_ubuntu18.04.sh
bash python/jittor/test/system/test_cuda11.1_ubuntu16.04.sh
bash python/jittor/test/system/test_cuda11.1_ubuntu18.04.sh
bash python/jittor/test/system/test_cuda11.1_ubuntu20.04.sh
bash python/jittor/test/system/test_nocuda_ubuntu18.04.sh

View File

@ -0,0 +1,41 @@
cat > /tmp/cuda10.0-ubuntu16.04.dockerfile <<\EOF
FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04
RUN apt update && apt install ca-certificates -y
RUN echo \
"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list
# RUN rm -rf /var/lib/apt/lists/*
RUN apt update || true
RUN apt install wget \
python3.7 python3.7-dev \
g++ build-essential -y
WORKDIR /usr/src
RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \
&& wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7
# change tsinghua mirror
RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
RUN pip3 uninstall jittor -y
COPY . jittor
RUN python3.7 -m pip install jittor
RUN python3.7 -m jittor.test.test_core
EOF
sudo docker build --tag jittor/jittor-cuda:10.0-16.04 -f /tmp/cuda10.0-ubuntu16.04.dockerfile .
sudo docker run --gpus all --rm jittor/jittor-cuda:10.0-18.04 bash -c \
"python3.7 -m jittor.test.test_example && \
python3.7 -m jittor.test.test_resnet && \
python3.7 -m jittor.test.test_parallel_pass && \
python3.7 -m jittor.test.test_atomic_tuner && \
python3.7 -m jittor.test.test_where_op"

View File

@ -0,0 +1,41 @@
cat > /tmp/cuda10.0-ubuntu18.04.dockerfile <<\EOF
FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04
RUN apt update && apt install ca-certificates -y
RUN echo \
"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list
# RUN rm -rf /var/lib/apt/lists/*
RUN apt update || true
RUN apt install wget \
python3.7 python3.7-dev \
g++ build-essential -y
WORKDIR /usr/src
RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \
&& wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7
# change tsinghua mirror
RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
RUN pip3 uninstall jittor -y
COPY . jittor
RUN python3.7 -m pip install jittor
RUN python3.7 -m jittor.test.test_core
EOF
sudo docker build --tag jittor/jittor-cuda:10.0-18.04 -f /tmp/cuda10.0-ubuntu18.04.dockerfile .
sudo docker run --gpus all --rm jittor/jittor-cuda:10.0-18.04 bash -c \
"python3.7 -m jittor.test.test_example && \
python3.7 -m jittor.test.test_resnet && \
python3.7 -m jittor.test.test_parallel_pass && \
python3.7 -m jittor.test.test_atomic_tuner && \
python3.7 -m jittor.test.test_where_op"

View File

@ -0,0 +1,41 @@
cat > /tmp/cuda11.1-ubuntu16.04.dockerfile <<\EOF
FROM nvidia/cuda:11.1-cudnn8-devel-ubuntu16.04
RUN apt update && apt install ca-certificates -y
RUN echo \
"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list
# RUN rm -rf /var/lib/apt/lists/*
RUN apt update || true
RUN apt install wget \
python3.7 python3.7-dev \
g++ build-essential -y
WORKDIR /usr/src
RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \
&& wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7
# change tsinghua mirror
RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
RUN pip3 uninstall jittor -y
COPY . jittor
RUN python3.7 -m pip install jittor
RUN python3.7 -m jittor.test.test_core
EOF
sudo docker build --tag jittor/jittor-cuda:11.1-16.04 -f /tmp/cuda11.1-ubuntu16.04.dockerfile .
sudo docker run --gpus all --rm jittor/jittor-cuda:11.1-16.04 bash -c \
"python3.7 -m jittor.test.test_example && \
python3.7 -m jittor.test.test_resnet && \
python3.7 -m jittor.test.test_parallel_pass && \
python3.7 -m jittor.test.test_atomic_tuner && \
python3.7 -m jittor.test.test_where_op"

View File

@ -0,0 +1,41 @@
cat > /tmp/cuda11.1-ubuntu18.04.dockerfile <<\EOF
FROM nvidia/cuda:11.1-cudnn8-devel-ubuntu18.04
RUN apt update && apt install ca-certificates -y
RUN echo \
"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list
# RUN rm -rf /var/lib/apt/lists/*
RUN apt update || true
RUN apt install wget \
python3.7 python3.7-dev \
g++ build-essential -y
WORKDIR /usr/src
RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \
&& wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7
# change tsinghua mirror
RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
RUN pip3 uninstall jittor -y
COPY . jittor
RUN python3.7 -m pip install jittor
RUN python3.7 -m jittor.test.test_core
EOF
sudo docker build --tag jittor/jittor-cuda:11.1-18.04 -f /tmp/cuda11.1-ubuntu18.04.dockerfile .
sudo docker run --gpus all --rm jittor/jittor-cuda:11.1-18.04 bash -c \
"python3.7 -m jittor.test.test_example && \
python3.7 -m jittor.test.test_resnet && \
python3.7 -m jittor.test.test_parallel_pass && \
python3.7 -m jittor.test.test_atomic_tuner && \
python3.7 -m jittor.test.test_where_op"

View File

@ -0,0 +1,39 @@
cat > /tmp/cuda11.1-ubuntu20.04.dockerfile <<\EOF
FROM nvidia/cuda:11.1-devel-ubuntu20.04
RUN apt update && apt install ca-certificates -y
RUN echo \
"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-updates main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-backports main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-security main restricted universe multiverse" > /etc/apt/sources.list
# RUN rm -rf /var/lib/apt/lists/*
RUN apt update || true
RUN apt install g++ build-essential libomp-dev python3-dev python3-pip wget -y
RUN python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
WORKDIR /usr/src/
RUN wget https://developer.download.nvidia.cn/compute/cuda/repos/ubuntu2004/x86_64/libcudnn8_8.0.5.39-1+cuda11.1_amd64.deb && \
wget https://developer.download.nvidia.cn/compute/cuda/repos/ubuntu2004/x86_64/libcudnn8-dev_8.0.5.39-1+cuda11.1_amd64.deb && \
dpkg -i ./libcudnn8_8.0.5.39-1+cuda11.1_amd64.deb ./libcudnn8-dev_8.0.5.39-1+cuda11.1_amd64.deb && \
rm *.deb
RUN ls
RUN pip3 install jittor --timeout 100 && python3 -m jittor.test.test_example
RUN pip3 uninstall jittor -y
COPY . jittor
RUN python3 -m pip install jittor
RUN python3 -m jittor.test.test_core
EOF
sudo docker build --tag jittor/jittor-cuda:11.1-20.04 -f /tmp/cuda11.1-ubuntu20.04.dockerfile .
sudo docker run --gpus all --rm jittor/jittor-cuda:11.1-20.04 bash -c \
"python3 -m jittor.test.test_example && \
python3 -m jittor.test.test_resnet && \
python3 -m jittor.test.test_parallel_pass && \
python3 -m jittor.test.test_atomic_tuner && \
python3 -m jittor.test.test_where_op"

View File

@ -0,0 +1,40 @@
cat > /tmp/ubuntu18.04.dockerfile <<\EOF
FROM ubuntu:18.04
RUN apt update && apt install ca-certificates -y
RUN echo \
"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\
deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list
# RUN rm -rf /var/lib/apt/lists/*
RUN apt update
RUN apt install wget \
python3.7 python3.7-dev \
g++ build-essential -y
WORKDIR /usr/src
RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \
&& wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7
# change tsinghua mirror
RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
RUN pip3 uninstall jittor -y
COPY . jittor
RUN python3.7 -m pip install jittor
RUN python3.7 -m jittor.test.test_core
EOF
sudo docker build --tag jittor/jittor:18.04 -f /tmp/ubuntu18.04.dockerfile .
sudo docker run --gpus all --rm jittor/jittor:18.04 bash -c \
"python3.7 -m jittor.test.test_example && \
python3.7 -m jittor.test.test_parallel_pass && \
python3.7 -m jittor.test.test_atomic_tuner && \
python3.7 -m jittor.test.test_where_op"

View File

@ -120,7 +120,7 @@ class TestMklConvOp(unittest.TestCase):
with jt.flag_scope( with jt.flag_scope(
enable_tuner=0, enable_tuner=0,
compile_options={"test_mkl_conv":1} # compile_options={"test_mkl_conv":1}
): ):
c_jt = conv(a_jt, b_jt, 1, 1) * da c_jt = conv(a_jt, b_jt, 1, 1) * da
gs=jt.grad(c_jt,[a_jt,b_jt]) gs=jt.grad(c_jt,[a_jt,b_jt])
@ -166,7 +166,7 @@ class TestMklConvOp(unittest.TestCase):
with jt.flag_scope( with jt.flag_scope(
enable_tuner=0, enable_tuner=0,
compile_options={"test_mkl_conv":1} # compile_options={"test_mkl_conv":1}
): ):
c_jt = conv_nhwc_hwio(a_jt, b_jt, 1, 1) * da c_jt = conv_nhwc_hwio(a_jt, b_jt, 1, 1) * da
gs=jt.grad(c_jt,[a_jt,b_jt]) gs=jt.grad(c_jt,[a_jt,b_jt])

View File

@ -446,8 +446,10 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
// record trace data // record trace data
if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var==2)) { if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var==2)) {
trace_data.record_execution(op, is_fused_op, jkl); trace_data.record_execution(op, is_fused_op, jkl);
checkCudaErrors(cudaDeviceSynchronize()); #ifdef HAS_CUDA
if (use_cuda)
checkCudaErrors(cudaDeviceSynchronize());
#endif
} }
LOGvvv << "Finished Op(" >> op->name() << rid >> LOGvvv << "Finished Op(" >> op->name() << rid >>
"/" >> queue.size() >> ") output:" << op->outputs(); "/" >> queue.size() >> ") output:" << op->outputs();

View File

@ -7,10 +7,6 @@
#pragma once #pragma once
#include "common.h" #include "common.h"
#include "mem/allocator.h" #include "mem/allocator.h"
#ifdef HAS_CUDA
#include <cuda_runtime.h>
#include "helper_cuda.h"
#endif
namespace jittor { namespace jittor {

View File

@ -93,6 +93,9 @@ VarPtr dirty_clone_broadcast(Var* v) {
if (op && !v->is_finished() && v->shape.size() > 4 && op->type() == OpType::broadcast) { if (op && !v->is_finished() && v->shape.size() > 4 && op->type() == OpType::broadcast) {
auto vp = op->duplicate(); auto vp = op->duplicate();
if (vp) { if (vp) {
// TODO: loop options should be set to op, rather than var
if (v->loop_options)
vp->loop_options = v->loop_options;
return vp; return vp;
} }
} }

View File

@ -10,6 +10,8 @@
#ifdef JIT_cuda #ifdef JIT_cuda
#include "executor.h" #include "executor.h"
#include <assert.h> #include <assert.h>
#include <cuda_runtime.h>
#include "helper_cuda.h"
#endif #endif
namespace jittor { namespace jittor {

View File

@ -98,4 +98,6 @@ struct SimpleProfilerGuard {
}; };
DECLARE_FLAG(int, profiler_enable);
} // jittor } // jittor

View File

@ -153,7 +153,7 @@ void process(string src, vector<string>& input_names) {
while (l<src.size() && (src[l] != ' ' && src[l] != '\n')) l++; while (l<src.size() && (src[l] != ' ' && src[l] != '\n')) l++;
if (src[k] == '"' && src[l-1] == '"' && j-i==8 && src.substr(i,j-i) == "#include") { if (src[k] == '"' && src[l-1] == '"' && j-i==8 && src.substr(i,j-i) == "#include") {
auto inc = src.substr(k+1, l-k-2); auto inc = src.substr(k+1, l-k-2);
if (inc != "test.h") { if (inc != "test.h" && inc != "helper_cuda.h") {
LOGvvvv << "Found include" << inc; LOGvvvv << "Found include" << inc;
input_names.push_back(inc); input_names.push_back(inc);
} }