forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_import_stats.py
74 lines (60 loc) · 2.41 KB
/
test_import_stats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Owner(s): ["module: ci"]
import subprocess
import sys
import unittest
import pathlib
from torch.testing._internal.common_utils import TestCase, run_tests, IS_LINUX, IS_IN_CI
# Two directory levels above this (resolved) file.
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent

# Make the repository root importable, just in case PyTorch was not built in
# 'develop' mode; this lets the `tools` package be found from a source tree.
sys.path.append(str(REPO_ROOT))

try:
    from tools.stats.scribe import rds_write, register_rds_schema
except ImportError:
    # The stats helpers are optional: code below checks these names for None
    # before using them.
    register_rds_schema = rds_write = None
# these tests could eventually be changed to fail if the import/init
# time is greater than a certain threshold, but for now we just use them
# as a way to track the duration of `import torch` in our ossci-metrics
# S3 bucket (see tools/stats/print_test_stats.py)
class TestImportTime(TestCase):
    """Tests that exercise `import torch` so its cost can be tracked."""

    def test_time_import_torch(self):
        """Run ``import torch`` via ``TestCase.runWithPytorchAPIUsageStderr``."""
        TestCase.runWithPytorchAPIUsageStderr("import torch")

    def test_time_cuda_device_count(self):
        """Run ``import torch`` followed by ``torch.cuda.device_count()``."""
        TestCase.runWithPytorchAPIUsageStderr(
            "import torch; torch.cuda.device_count()",
        )

    @unittest.skipIf(not IS_LINUX, "Memory test is only implemented for Linux")
    @unittest.skipIf(not IS_IN_CI, "Memory test only runs in CI")
    @unittest.skipIf(rds_write is None, "Cannot import rds_write from tools.stats.scribe")
    def test_peak_memory(self):
        """Record the peak RSS of importing torch next to a trivial baseline.

        Each measurement runs in a fresh interpreter so that the figure is
        not affected by whatever this test process has already loaded. The
        two records are written to the "import_stats" table with
        ``rds_write``; a failure to write skips the test rather than failing
        it.

        Raises:
            unittest.SkipTest: if ``rds_write`` raises.
            subprocess.CalledProcessError: if a measuring interpreter exits
                with a non-zero status.
        """

        def profile(module, name):
            """Import ``module`` in a child interpreter and report its max RSS.

            Returns a dict with "test_name" set to ``name`` and
            "peak_memory_bytes" set to the ``ru_maxrss`` value the child
            printed.
            """
            # NOTE(review): on Linux getrusage(2) reports ru_maxrss in
            # kilobytes, so the value stored under "peak_memory_bytes" is
            # presumably KiB rather than bytes. It is left unscaled here so
            # previously recorded data stays comparable -- TODO confirm with
            # the consumers of import_stats.
            command = f"import {module}; import resource; print(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)"
            result = subprocess.run(
                [sys.executable, "-c", command],
                stdout=subprocess.PIPE,
                # Without this, a crashed child leaves stdout empty and the
                # failure shows up as an unrelated ValueError from int("").
                check=True,
            )
            max_rss = int(result.stdout.decode().strip())

            return {
                "test_name": name,
                "peak_memory_bytes": max_rss,
            }

        data = profile("torch", "pytorch")
        # Importing `sys` is effectively free, so this measures the bare
        # interpreter for comparison.
        baseline = profile("sys", "baseline")
        try:
            rds_write("import_stats", [data, baseline])
        except Exception as e:
            raise unittest.SkipTest(f"Failed to record import_stats: {e}") from e
if __name__ == "__main__":
    # Schema registration is best-effort: it is attempted only in CI and only
    # when the stats helper was importable, and a failure is printed instead
    # of preventing the tests from running.
    if register_rds_schema and IS_IN_CI:
        import_stats_columns = {
            "test_name": "string",
            "peak_memory_bytes": "int",
            "time_ms": "int",
        }
        try:
            register_rds_schema("import_stats", import_stats_columns)
        except Exception as err:
            print(f"Failed to register RDS schema: {err}")

    run_tests()