diff --git a/test/test_import_stats.py b/test/test_import_stats.py index a0ce4707bd1a..bebd291dfa3a 100644 --- a/test/test_import_stats.py +++ b/test/test_import_stats.py @@ -1,28 +1,11 @@ # Owner(s): ["module: ci"] -import subprocess -import sys -import unittest -import pathlib - -from torch.testing._internal.common_utils import TestCase, run_tests, IS_LINUX, IS_CI - - -REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent - -try: - # Just in case PyTorch was not built in 'develop' mode - sys.path.append(str(REPO_ROOT)) - from tools.stats.scribe import rds_write, register_rds_schema -except ImportError: - register_rds_schema = None - rds_write = None +from torch.testing._internal.common_utils import TestCase, run_tests # these tests could eventually be changed to fail if the import/init # time is greater than a certain threshold, but for now we just use them -# as a way to track the duration of `import torch` in our ossci-metrics -# S3 bucket (see tools/stats/print_test_stats.py) +# as a way to track the duration of `import torch`. class TestImportTime(TestCase): def test_time_import_torch(self): TestCase.runWithPytorchAPIUsageStderr("import torch") @@ -32,43 +15,6 @@ class TestImportTime(TestCase): "import torch; torch.cuda.device_count()", ) - @unittest.skipIf(not IS_LINUX, "Memory test is only implemented for Linux") - @unittest.skipIf(not IS_CI, "Memory test only runs in CI") - @unittest.skipIf(rds_write is None, "Cannot import rds_write from tools.stats.scribe") - def test_peak_memory(self): - def profile(module, name): - command = f"import {module}; import resource; print(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)" - result = subprocess.run( - [sys.executable, "-c", command], - stdout=subprocess.PIPE, - ) - max_rss = int(result.stdout.decode().strip()) - - return { - "test_name": name, - "peak_memory_bytes": max_rss, - } - - data = profile("torch", "pytorch") - baseline = profile("sys", "baseline") - try: - rds_write("import_stats", [data, baseline]) - except Exception as e: - raise unittest.SkipTest(f"Failed to record import_stats: {e}") - if __name__ == "__main__": - if register_rds_schema and IS_CI: - try: - register_rds_schema( - "import_stats", - { - "test_name": "string", - "peak_memory_bytes": "int", - "time_ms": "int", - }, - ) - except Exception as e: - print(f"Failed to register RDS schema: {e}") - run_tests() diff --git a/tools/stats/scribe.py b/tools/stats/scribe.py index 80da95dd69bd..47a8a819206c 100644 --- a/tools/stats/scribe.py +++ b/tools/stats/scribe.py @@ -2,15 +2,12 @@ import base64 import bz2 import os import json -from typing import Dict, Any, List, Union, Optional +from typing import Any _lambda_client = None -IS_GHA = bool(os.getenv("GITHUB_ACTIONS")) - - def sprint(*args: Any) -> None: print("[scribe]", *args) @@ -62,145 +59,3 @@ def _send_to_scribe_via_http(access_token: str, logs: str) -> str: ) r.raise_for_status() return str(r.text) - - -def invoke_rds(events: List[Dict[str, Any]]) -> Any: - if not IS_GHA: - sprint(f"Not invoking RDS lambda outside GitHub Actions:\n{events}") - return - - return invoke_lambda("rds-proxy", events) - - -def register_rds_schema(table_name: str, schema: Dict[str, str]) -> None: - """ - Register a table in RDS so it can be written to later on with 'rds_write'. - 'schema' should be a mapping of field names -> types, where supported types - are 'int' and 'string'. - - Metadata fields such as pr, ref, branch, workflow_id, and build_environment - will be added automatically. - """ - base = { - "pr": "string", - "ref": "string", - "branch": "string", - "workflow_id": "string", - "build_environment": "string", - } - - event = [{"create_table": {"table_name": table_name, "fields": {**schema, **base}}}] - - invoke_rds(event) - - -def schema_from_sample(data: Dict[str, Any]) -> Dict[str, str]: - """ - Extract a schema compatible with 'register_rds_schema' from data. - """ - schema = {} - for key, value in data.items(): - if isinstance(value, str): - schema[key] = "string" - elif isinstance(value, int): - schema[key] = "int" - elif isinstance(value, float): - schema[key] = "float" - else: - raise RuntimeError(f"Unsupported value type: {key}: {value}") - return schema - - -Query = Dict[str, Any] - - -def rds_query(queries: Union[Query, List[Query]]) -> Any: - """ - Execute a simple read query on RDS. Queries should be of the form below, - where everything except 'table_name' and 'fields' is optional. - - { - "table_name": "my_table", - "fields": ["something", "something_else"], - "where": [ - { - "field": "something", - "value": 10 - } - ], - "group_by": ["something"], - "order_by": ["something"], - "limit": 5, - } - """ - if not isinstance(queries, list): - queries = [queries] - - events = [] - for query in queries: - events.append({"read": {**query}}) - - return invoke_rds(events) - - -def rds_saved_query(query_names: Union[str, List[str]]) -> Any: - """ - Execute a hardcoded RDS query by name. See - https://github.com/pytorch/test-infra/blob/main/aws/lambda/rds-proxy/lambda_function.py#L52 - for available queries or submit a PR there to add a new one. - """ - if not isinstance(query_names, list): - query_names = [query_names] - - events = [] - for name in query_names: - events.append({"read": {"saved_query_name": name}}) - - return invoke_rds(events) - - -def rds_write( - table_name: str, - values_list: List[Dict[str, Any]], - only_on_master: bool = True, - only_on_jobs: Optional[List[str]] = None, -) -> None: - """ - Note: Only works from GitHub Actions CI runners - - Write a set of entries to a particular RDS table. 'table_name' should be - a table registered via 'register_rds_schema' prior to calling rds_write. - 'values_list' should be a list of dictionaries that map field names to - values. - """ - sprint("Writing for", os.getenv("PR_NUMBER")) - is_master = os.getenv("PR_NUMBER", "").strip() == "" - if only_on_master and not is_master: - sprint("Skipping RDS write on PR") - return - - pr = os.getenv("PR_NUMBER", None) - if pr is not None and pr.strip() == "": - pr = None - - build_environment = os.environ.get("BUILD_ENVIRONMENT", "").split()[0] - if only_on_jobs is not None and build_environment not in only_on_jobs: - sprint(f"Skipping write since {build_environment} is not in {only_on_jobs}") - return - - base = { - "pr": pr, - "ref": os.getenv("SHA1"), - "branch": os.getenv("BRANCH"), - "workflow_id": os.getenv("GITHUB_RUN_ID"), - "build_environment": build_environment, - } - - events = [] - for values in values_list: - events.append( - {"write": {"table_name": table_name, "values": {**values, **base}}} - ) - - sprint("Wrote stats for", table_name) - invoke_rds(events)