forked from maxjhandsome/jittor
mpi code polish
This commit is contained in:
parent
3bc1f3cac6
commit
a2736164fb
|
@ -66,7 +66,7 @@ class Optimizer(object):
|
|||
g.assign(g.mpi_all_reduce("mean"))
|
||||
if self.n_step % self.param_sync_iter == 0:
|
||||
for p in params:
|
||||
p.assign(p.mpi_all_reduce("mean"))
|
||||
p.assign(p.mpi_broadcast())
|
||||
self.n_step += 1
|
||||
|
||||
# set up grads in param_groups
|
||||
|
|
|
@ -72,6 +72,10 @@ def run_mpi_test(num_procs, name):
|
|||
class TestMpiEntry(unittest.TestCase):
|
||||
def test_entry(self):
|
||||
run_mpi_test(2, "test_mpi")
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
def test_mpi_resnet_entry(self):
|
||||
run_mpi_test(2, "test_resnet")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -76,7 +76,7 @@ class TestResnet(unittest.TestCase):
|
|||
# print train info
|
||||
global prev
|
||||
pred = np.argmax(output, axis=1)
|
||||
acc = np.sum(target==pred)/self.batch_size
|
||||
acc = np.mean(target==pred)
|
||||
loss_list.append(loss[0])
|
||||
acc_list.append(acc)
|
||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
|
||||
|
@ -113,10 +113,14 @@ class TestResnet(unittest.TestCase):
|
|||
# Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000
|
||||
# Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000
|
||||
|
||||
assert jt.core.number_of_lived_vars() < 3500
|
||||
if jt.mpi:
|
||||
assert jt.core.number_of_lived_vars() < 3900, jt.core.number_of_lived_vars()
|
||||
else:
|
||||
assert jt.core.number_of_lived_vars() < 3500, jt.core.number_of_lived_vars()
|
||||
|
||||
jt.sync_all(True)
|
||||
assert np.mean(loss_list[-50:])<0.3
|
||||
assert np.mean(acc_list[-50:])>0.8
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -10,12 +10,13 @@ import unittest
|
|||
import os, sys
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
from jittor.test.test_mpi import run_mpi_test
|
||||
mpi = jt.compile_extern.mpi
|
||||
|
||||
from jittor.dataset.mnist import MNIST
|
||||
dataloader = MNIST(train=False).set_attrs(batch_size=16)
|
||||
|
||||
def val1():
|
||||
dataloader = MNIST(train=False).set_attrs(batch_size=16)
|
||||
for i, (imgs, labels) in enumerate(dataloader):
|
||||
assert(imgs.shape[0]==8)
|
||||
if i == 5:
|
||||
|
@ -23,6 +24,7 @@ def val1():
|
|||
|
||||
@jt.single_process_scope(rank=0)
|
||||
def val2():
|
||||
dataloader = MNIST(train=False).set_attrs(batch_size=16)
|
||||
for i, (imgs, labels) in enumerate(dataloader):
|
||||
assert(imgs.shape[0]==16)
|
||||
if i == 5:
|
||||
|
@ -34,17 +36,10 @@ class TestSingleProcessScope(unittest.TestCase):
|
|||
val1()
|
||||
val2()
|
||||
|
||||
def run_single_process_scope_test(num_procs, name):
|
||||
if not jt.compile_extern.inside_mpi():
|
||||
mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
|
||||
cmd = f"{mpirun_path} -np {num_procs} {sys.executable} -m jittor.test.{name} -v"
|
||||
print("run cmd:", cmd)
|
||||
assert os.system(cmd)==0, "run cmd failed: "+cmd
|
||||
|
||||
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
|
||||
class TestSingleProcessScopeEntry(unittest.TestCase):
|
||||
def test_entry(self):
|
||||
run_single_process_scope_test(2, "test_single_process_scope")
|
||||
run_mpi_test(2, "test_single_process_scope")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue