mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[ci] remove remaining RDS dependency"
This reverts commit 964d5059587c017b02f1d7d62587722a38683c63. Reverted https://github.com/pytorch/pytorch/pull/79370 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
This commit is contained in:
@ -1,11 +1,28 @@
|
||||
# Owner(s): ["module: ci"]
|
||||
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
import subprocess
|
||||
import sys
|
||||
import unittest
|
||||
import pathlib
|
||||
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, IS_LINUX, IS_CI
|
||||
|
||||
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
||||
|
||||
try:
|
||||
# Just in case PyTorch was not built in 'develop' mode
|
||||
sys.path.append(str(REPO_ROOT))
|
||||
from tools.stats.scribe import rds_write, register_rds_schema
|
||||
except ImportError:
|
||||
register_rds_schema = None
|
||||
rds_write = None
|
||||
|
||||
|
||||
# these tests could eventually be changed to fail if the import/init
|
||||
# time is greater than a certain threshold, but for now we just use them
|
||||
# as a way to track the duration of `import torch`.
|
||||
# as a way to track the duration of `import torch` in our ossci-metrics
|
||||
# S3 bucket (see tools/stats/print_test_stats.py)
|
||||
class TestImportTime(TestCase):
|
||||
def test_time_import_torch(self):
|
||||
TestCase.runWithPytorchAPIUsageStderr("import torch")
|
||||
@ -15,6 +32,43 @@ class TestImportTime(TestCase):
|
||||
"import torch; torch.cuda.device_count()",
|
||||
)
|
||||
|
||||
@unittest.skipIf(not IS_LINUX, "Memory test is only implemented for Linux")
|
||||
@unittest.skipIf(not IS_CI, "Memory test only runs in CI")
|
||||
@unittest.skipIf(rds_write is None, "Cannot import rds_write from tools.stats.scribe")
|
||||
def test_peak_memory(self):
|
||||
def profile(module, name):
|
||||
command = f"import {module}; import resource; print(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)"
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-c", command],
|
||||
stdout=subprocess.PIPE,
|
||||
)
|
||||
max_rss = int(result.stdout.decode().strip())
|
||||
|
||||
return {
|
||||
"test_name": name,
|
||||
"peak_memory_bytes": max_rss,
|
||||
}
|
||||
|
||||
data = profile("torch", "pytorch")
|
||||
baseline = profile("sys", "baseline")
|
||||
try:
|
||||
rds_write("import_stats", [data, baseline])
|
||||
except Exception as e:
|
||||
raise unittest.SkipTest(f"Failed to record import_stats: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if register_rds_schema and IS_CI:
|
||||
try:
|
||||
register_rds_schema(
|
||||
"import_stats",
|
||||
{
|
||||
"test_name": "string",
|
||||
"peak_memory_bytes": "int",
|
||||
"time_ms": "int",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to register RDS schema: {e}")
|
||||
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user