forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
allcompare_test.py
87 lines (63 loc) · 2.2 KB
/
allcompare_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#!/usr/bin/env python3
from hypothesis import given, settings
import hypothesis.strategies as st
from multiprocessing import Process
import numpy as np
import tempfile
import shutil
import caffe2.python.hypothesis_test_util as hu
# Engine name passed as the ``engine`` entry of the rendezvous settings that
# allcompare_process attaches to its model.
op_engine = 'GLOO'
class TemporaryDirectory:
    """Context manager that creates a scratch directory and removes it on exit.

    Entering the context creates a fresh directory with ``tempfile.mkdtemp``,
    remembers its path on ``self.tmpdir`` and yields that path.  Leaving the
    context deletes the directory tree with ``shutil.rmtree``, whether or not
    the body raised; exceptions from the body are not suppressed.
    """

    def __enter__(self):
        """Create the directory and return its path."""
        created_path = tempfile.mkdtemp()
        self.tmpdir = created_path
        return created_path

    def __exit__(self, type, value, traceback):
        """Remove the directory tree created by ``__enter__``."""
        # Returns None implicitly, so any in-flight exception keeps propagating.
        shutil.rmtree(self.tmpdir)
def allcompare_process(filestore_dir, process_id, data, num_procs):
    """Run one shard of the allcompare check.

    Used as the ``target`` of each worker ``Process`` started by the test.

    Args:
        filestore_dir: Path handed to the ``FileStoreHandlerCreate`` operator;
            every shard is given the same directory by the caller.
        process_id: Value stored as ``shard_id`` in the rendezvous settings.
        data: Value fed into the workspace as the ``"test_data"`` blob.
        num_procs: Value stored as ``num_shards`` in the rendezvous settings.
    """
    # caffe2 is imported inside the function rather than at module level.
    # NOTE(review): presumably so each worker process loads caffe2 itself --
    # confirm before hoisting these imports.
    from caffe2.python import core, data_parallel_model, workspace, dyndep
    from caffe2.python.model_helper import ModelHelper
    from caffe2.proto import caffe2_pb2

    dyndep.InitOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")

    # Create the store handler blob that the rendezvous settings refer to.
    create_store_op = core.CreateOperator(
        "FileStoreHandlerCreate", [], ["store_handler"], path=filestore_dir
    )
    workspace.RunOperatorOnce(create_store_op)

    model = ModelHelper()
    model._rendezvous = {
        "kv_handler": "store_handler",
        "shard_id": process_id,
        "num_shards": num_procs,
        "engine": op_engine,
        "exit_nets": None,
    }

    blob_name = "test_data"
    workspace.FeedBlob(blob_name, data)

    # NOTE(review): _RunComparison is a private helper of data_parallel_model;
    # presumably it compares the blob across shards -- confirm against it.
    cpu_device = core.DeviceOption(caffe2_pb2.CPU, 0)
    data_parallel_model._RunComparison(model, blob_name, cpu_device)
class TestAllCompare(hu.HypothesisTestCase):
    """Multi-process test that runs ``allcompare_process`` on every shard."""

    @given(
        d=st.integers(1, 5), n=st.integers(2, 11), num_procs=st.integers(1, 8)
    )
    @settings(deadline=10000)
    def test_allcompare(self, d, n, num_procs):
        """Start ``num_procs`` workers on identical data and require all succeed.

        Args:
            d: Number of dimensions of the random test tensor.
            n: Exclusive upper bound for the extent of each dimension.
            num_procs: Number of worker processes (shards) to start.
        """
        # randint's ``high`` is exclusive, so each extent lies in [1, n - 1];
        # the strategy keeps n >= 2, which keeps that range non-empty.
        dims = [np.random.randint(1, high=n) for _ in range(d)]
        test_data = np.random.ranf(size=tuple(dims)).astype(np.float32)

        with TemporaryDirectory() as tempdir:
            processes = []
            for idx in range(num_procs):
                process = Process(
                    target=allcompare_process,
                    args=(tempdir, idx, test_data, num_procs)
                )
                processes.append(process)
                process.start()

            # Wait for every worker before the shared directory is removed.
            for process in processes:
                process.join()

        # A worker that raises exits with a non-zero code.  Without this check
        # the test would pass even if every worker failed, because join()
        # alone never reports the child's outcome.
        exit_codes = [process.exitcode for process in processes]
        self.assertEqual(
            exit_codes,
            [0] * num_procs,
            "one or more allcompare worker processes exited with a non-zero code",
        )
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner,
    # which discovers and runs the test cases defined in this module.
    import unittest

    unittest.main()