mpi code polish

This commit is contained in:
Dun Liang 2020-06-07 13:34:13 +08:00
parent 3bc1f3cac6
commit a2736164fb
5 changed files with 16 additions and 13 deletions

View File

@ -66,7 +66,7 @@ class Optimizer(object):
g.assign(g.mpi_all_reduce("mean"))
if self.n_step % self.param_sync_iter == 0:
for p in params:
p.assign(p.mpi_all_reduce("mean"))
p.assign(p.mpi_broadcast())
self.n_step += 1
# set up grads in param_groups

View File

@ -72,6 +72,10 @@ def run_mpi_test(num_procs, name):
class TestMpiEntry(unittest.TestCase):
def test_entry(self):
run_mpi_test(2, "test_mpi")
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
def test_mpi_resnet_entry(self):
run_mpi_test(2, "test_resnet")
if __name__ == "__main__":
unittest.main()

View File

@ -76,7 +76,7 @@ class TestResnet(unittest.TestCase):
# print train info
global prev
pred = np.argmax(output, axis=1)
acc = np.sum(target==pred)/self.batch_size
acc = np.mean(target==pred)
loss_list.append(loss[0])
acc_list.append(acc)
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
@ -113,10 +113,14 @@ class TestResnet(unittest.TestCase):
# Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000
# Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000
assert jt.core.number_of_lived_vars() < 3500
if jt.mpi:
assert jt.core.number_of_lived_vars() < 3900, jt.core.number_of_lived_vars()
else:
assert jt.core.number_of_lived_vars() < 3500, jt.core.number_of_lived_vars()
jt.sync_all(True)
assert np.mean(loss_list[-50:])<0.3
assert np.mean(acc_list[-50:])>0.8
if __name__ == "__main__":
unittest.main()

View File

@ -10,12 +10,13 @@ import unittest
import os, sys
import jittor as jt
import numpy as np
from jittor.test.test_mpi import run_mpi_test
mpi = jt.compile_extern.mpi
from jittor.dataset.mnist import MNIST
dataloader = MNIST(train=False).set_attrs(batch_size=16)
def val1():
dataloader = MNIST(train=False).set_attrs(batch_size=16)
for i, (imgs, labels) in enumerate(dataloader):
assert(imgs.shape[0]==8)
if i == 5:
@ -23,6 +24,7 @@ def val1():
@jt.single_process_scope(rank=0)
def val2():
dataloader = MNIST(train=False).set_attrs(batch_size=16)
for i, (imgs, labels) in enumerate(dataloader):
assert(imgs.shape[0]==16)
if i == 5:
@ -34,17 +36,10 @@ class TestSingleProcessScope(unittest.TestCase):
val1()
val2()
def run_single_process_scope_test(num_procs, name):
if not jt.compile_extern.inside_mpi():
mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
cmd = f"{mpirun_path} -np {num_procs} {sys.executable} -m jittor.test.{name} -v"
print("run cmd:", cmd)
assert os.system(cmd)==0, "run cmd failed: "+cmd
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
class TestSingleProcessScopeEntry(unittest.TestCase):
def test_entry(self):
run_single_process_scope_test(2, "test_single_process_scope")
run_mpi_test(2, "test_single_process_scope")
if __name__ == "__main__":
unittest.main()

View File

@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:
setuptools.setup(
name='jittor',
version='1.1.4.2',
version='1.1.4.3',
# scripts=[],
author="Jittor Group",
author_email="ran.donglang@gmail.com",