Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Revert "Revert "[ci] remove remaining RDS dependency""
This reverts commit 21e32d5a0b3b8c4f93aae00761a6befd6ab47765.
test/test_import_time.py
@@ -1,28 +1,11 @@
 # Owner(s): ["module: ci"]
 
-import subprocess
-import sys
-import unittest
-import pathlib
-
-from torch.testing._internal.common_utils import TestCase, run_tests, IS_LINUX, IS_CI
-
-
-REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
-
-try:
-    # Just in case PyTorch was not built in 'develop' mode
-    sys.path.append(str(REPO_ROOT))
-    from tools.stats.scribe import rds_write, register_rds_schema
-except ImportError:
-    register_rds_schema = None
-    rds_write = None
+from torch.testing._internal.common_utils import TestCase, run_tests
 
 
 # these tests could eventually be changed to fail if the import/init
 # time is greater than a certain threshold, but for now we just use them
-# as a way to track the duration of `import torch` in our ossci-metrics
-# S3 bucket (see tools/stats/print_test_stats.py)
+# as a way to track the duration of `import torch`.
 class TestImportTime(TestCase):
     def test_time_import_torch(self):
         TestCase.runWithPytorchAPIUsageStderr("import torch")
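Note: the comment kept above says these tests could eventually be changed to fail when import/init time exceeds a threshold. A minimal sketch of what that might look like; the helper, the test, and the 60-second threshold below are hypothetical illustrations, not part of this commit:

import subprocess
import sys
import time


def time_import_torch() -> float:
    # Run `import torch` in a fresh interpreter so modules already loaded
    # in this process don't skew the measurement.
    start = time.monotonic()
    subprocess.run([sys.executable, "-c", "import torch"], check=True)
    return time.monotonic() - start


def test_import_time_under_threshold():
    # Placeholder threshold, for illustration only.
    assert time_import_torch() < 60.0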
@@ -32,43 +15,6 @@ class TestImportTime(TestCase):
             "import torch; torch.cuda.device_count()",
         )
 
-    @unittest.skipIf(not IS_LINUX, "Memory test is only implemented for Linux")
-    @unittest.skipIf(not IS_CI, "Memory test only runs in CI")
-    @unittest.skipIf(rds_write is None, "Cannot import rds_write from tools.stats.scribe")
-    def test_peak_memory(self):
-        def profile(module, name):
-            command = f"import {module}; import resource; print(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)"
-            result = subprocess.run(
-                [sys.executable, "-c", command],
-                stdout=subprocess.PIPE,
-            )
-            max_rss = int(result.stdout.decode().strip())
-
-            return {
-                "test_name": name,
-                "peak_memory_bytes": max_rss,
-            }
-
-        data = profile("torch", "pytorch")
-        baseline = profile("sys", "baseline")
-        try:
-            rds_write("import_stats", [data, baseline])
-        except Exception as e:
-            raise unittest.SkipTest(f"Failed to record import_stats: {e}")
-
 
 if __name__ == "__main__":
-    if register_rds_schema and IS_CI:
-        try:
-            register_rds_schema(
-                "import_stats",
-                {
-                    "test_name": "string",
-                    "peak_memory_bytes": "int",
-                    "time_ms": "int",
-                },
-            )
-        except Exception as e:
-            print(f"Failed to register RDS schema: {e}")
-
     run_tests()
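A note on the removed test_peak_memory: resource.getrusage reports ru_maxrss in kilobytes on Linux (and in bytes on macOS), so a field named peak_memory_bytes needs a scale factor. A sketch of the conversion, not code from this commit:

import resource
import sys

usage = resource.getrusage(resource.RUSAGE_SELF)
# Per getrusage(2): ru_maxrss is in kilobytes on Linux, bytes on macOS.
scale = 1 if sys.platform == "darwin" else 1024
peak_memory_bytes = usage.ru_maxrss * scale
print(peak_memory_bytes)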
tools/stats/scribe.py
@@ -2,15 +2,12 @@ import base64
 import bz2
 import os
 import json
-from typing import Dict, Any, List, Union, Optional
+from typing import Any
 
 
 _lambda_client = None
 
 
-IS_GHA = bool(os.getenv("GITHUB_ACTIONS"))
-
-
 def sprint(*args: Any) -> None:
     print("[scribe]", *args)
 
@@ -62,145 +59,3 @@ def _send_to_scribe_via_http(access_token: str, logs: str) -> str:
     )
     r.raise_for_status()
     return str(r.text)
-
-
-def invoke_rds(events: List[Dict[str, Any]]) -> Any:
-    if not IS_GHA:
-        sprint(f"Not invoking RDS lambda outside GitHub Actions:\n{events}")
-        return
-
-    return invoke_lambda("rds-proxy", events)
-
-
-def register_rds_schema(table_name: str, schema: Dict[str, str]) -> None:
-    """
-    Register a table in RDS so it can be written to later on with 'rds_write'.
-    'schema' should be a mapping of field names -> types, where supported types
-    are 'int' and 'string'.
-
-    Metadata fields such as pr, ref, branch, workflow_id, and build_environment
-    will be added automatically.
-    """
-    base = {
-        "pr": "string",
-        "ref": "string",
-        "branch": "string",
-        "workflow_id": "string",
-        "build_environment": "string",
-    }
-
-    event = [{"create_table": {"table_name": table_name, "fields": {**schema, **base}}}]
-
-    invoke_rds(event)
-
-
-def schema_from_sample(data: Dict[str, Any]) -> Dict[str, str]:
-    """
-    Extract a schema compatible with 'register_rds_schema' from data.
-    """
-    schema = {}
-    for key, value in data.items():
-        if isinstance(value, str):
-            schema[key] = "string"
-        elif isinstance(value, int):
-            schema[key] = "int"
-        elif isinstance(value, float):
-            schema[key] = "float"
-        else:
-            raise RuntimeError(f"Unsupported value type: {key}: {value}")
-    return schema
-
-
-Query = Dict[str, Any]
-
-
-def rds_query(queries: Union[Query, List[Query]]) -> Any:
-    """
-    Execute a simple read query on RDS. Queries should be of the form below,
-    where everything except 'table_name' and 'fields' is optional.
-
-    {
-        "table_name": "my_table",
-        "fields": ["something", "something_else"],
-        "where": [
-            {
-                "field": "something",
-                "value": 10
-            }
-        ],
-        "group_by": ["something"],
-        "order_by": ["something"],
-        "limit": 5,
-    }
-    """
-    if not isinstance(queries, list):
-        queries = [queries]
-
-    events = []
-    for query in queries:
-        events.append({"read": {**query}})
-
-    return invoke_rds(events)
-
-
-def rds_saved_query(query_names: Union[str, List[str]]) -> Any:
-    """
-    Execute a hardcoded RDS query by name. See
-    https://github.com/pytorch/test-infra/blob/main/aws/lambda/rds-proxy/lambda_function.py#L52
-    for available queries or submit a PR there to add a new one.
-    """
-    if not isinstance(query_names, list):
-        query_names = [query_names]
-
-    events = []
-    for name in query_names:
-        events.append({"read": {"saved_query_name": name}})
-
-    return invoke_rds(events)
-
-
-def rds_write(
-    table_name: str,
-    values_list: List[Dict[str, Any]],
-    only_on_master: bool = True,
-    only_on_jobs: Optional[List[str]] = None,
-) -> None:
-    """
-    Note: Only works from GitHub Actions CI runners
-
-    Write a set of entries to a particular RDS table. 'table_name' should be
-    a table registered via 'register_rds_schema' prior to calling rds_write.
-    'values_list' should be a list of dictionaries that map field names to
-    values.
-    """
-    sprint("Writing for", os.getenv("PR_NUMBER"))
-    is_master = os.getenv("PR_NUMBER", "").strip() == ""
-    if only_on_master and not is_master:
-        sprint("Skipping RDS write on PR")
-        return
-
-    pr = os.getenv("PR_NUMBER", None)
-    if pr is not None and pr.strip() == "":
-        pr = None
-
-    build_environment = os.environ.get("BUILD_ENVIRONMENT", "").split()[0]
-    if only_on_jobs is not None and build_environment not in only_on_jobs:
-        sprint(f"Skipping write since {build_environment} is not in {only_on_jobs}")
-        return
-
-    base = {
-        "pr": pr,
-        "ref": os.getenv("SHA1"),
-        "branch": os.getenv("BRANCH"),
-        "workflow_id": os.getenv("GITHUB_RUN_ID"),
-        "build_environment": build_environment,
-    }
-
-    events = []
-    for values in values_list:
-        events.append(
-            {"write": {"table_name": table_name, "values": {**values, **base}}}
-        )
-
-    sprint("Wrote stats for", table_name)
-    invoke_rds(events)
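For reference, the removed helpers were used together as in this sketch, reconstructed from the removed test code above (writes are no-ops outside GitHub Actions runners):

from tools.stats.scribe import register_rds_schema, rds_write

# One-time table registration; metadata fields (pr, ref, branch,
# workflow_id, build_environment) are appended automatically.
register_rds_schema(
    "import_stats",
    {"test_name": "string", "peak_memory_bytes": "int", "time_ms": "int"},
)

# Each dict maps registered field names to values; the same metadata
# fields are merged into every row before it is sent to the rds-proxy lambda.
rds_write("import_stats", [{"test_name": "pytorch", "peak_memory_bytes": 1234}])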