Compare commits

...

1 Commit

Author SHA1 Message Date
4553a1e19c Testing metaconda
Summary: Testing metaconda
2025-09-30 10:21:16 -07:00

View File

@ -1,9 +1,13 @@
# mypy: allow-untyped-defs
import functools
import getpass
import json
import logging
import os
import socket
import sys
import tempfile
import time
import typing_extensions
from typing import Any, Callable, Optional, TypeVar
from typing_extensions import ParamSpec
@ -28,6 +32,12 @@ if os.environ.get("TORCH_COMPILE_STROBELIGHT", False):
log.info("Strobelight profiler is enabled via environment variable")
StrobelightCompileTimeProfiler.enable()
try:
from observability.structured_logging import Sample, ScubaData
except ImportError:
ScubaData = Sample = None
# this arbitrary-looking assortment of functionality is provided here
# to have a central place for overridable behavior. The motivating
# use is the FB build environment, where this source file is replaced
@ -121,8 +131,8 @@ def add_mlhub_insight(category: str, insight: str, insight_description: str):
pass
def log_compilation_event(metrics):
    """
    Record a compilation event.

    This variant only forwards ``metrics`` to ``log.info``, formatted with
    ``%s``; nothing is returned.
    """
    log.info("%s", metrics)
# def log_compilation_event(metrics):
# log.info("%s", metrics)
def upload_graph(graph):
@ -368,3 +378,105 @@ def find_compile_subproc_binary() -> Optional[str]:
Allows overriding the binary used for subprocesses
"""
return None
class CompilationEventScubaLogger:
    """
    Meta only: Facilitates logging compilation events to scuba in environments
    where we have observability.structured_logging available.

    Everything here is class-level: the table handle is created once when the
    class body runs, and the methods are static/class methods, so the class is
    never instantiated by the code in this file.
    """

    SCUBA_TABLE = "dynamo_compile"
    # Handle for SCUBA_TABLE, built once at class-definition time. It is None
    # when ScubaData is falsy, i.e. when the structured_logging import at the
    # top of the module fell back to ``ScubaData = Sample = None``.
    scuba_data: Optional[ScubaData] = ScubaData(SCUBA_TABLE) if ScubaData else None

    @staticmethod
    @functools.lru_cache(maxsize=1)
    def get_runtime_env_kwargs():
        """
        Helper to extract the various common fields from the environment.

        The result is computed on the first call and then served from the
        ``lru_cache``, so later changes to the environment are not picked up.
        Callers should treat the returned dict as read-only because the same
        object is handed out on every call.

        Returns:
            A dict mapping field name to a ``str``, ``int`` or ``None``.
            ``None`` means the value was missing or could not be parsed.
        """

        def get_as_int(name: str) -> Optional[int]:
            # Environment values are always strings, so the only parse failure
            # is ValueError; a missing variable is handled up front.
            raw = os.environ.get(name)
            if raw is None:
                return None
            try:
                return int(raw)
            except ValueError:
                return None

        tw_cluster = os.environ.get("TW_JOB_CLUSTER", None)
        tw_user = os.environ.get("TW_JOB_USER", None)
        tw_job = os.environ.get("TW_JOB_NAME", None)
        tw_task = os.environ.get("TW_TASK_ID", None)
        tw_task_handle = None
        # The handle is built as soon as ANY component is set; components that
        # are missing are rendered as the literal text "None".
        if tw_cluster or tw_user or tw_job or tw_task:
            tw_task_handle = f"{tw_cluster}/{tw_user}/{tw_job}/{tw_task}"

        # AI_TRAINING_METADATA is parsed as JSON; only a top-level object with
        # an "entitlement" key yields a value. Anything else (unset variable,
        # malformed JSON, non-object payload, missing key) gives None.
        entitlement = None
        metadata = os.environ.get("AI_TRAINING_METADATA")
        if metadata is not None:
            try:
                parsed = json.loads(metadata)
            except ValueError:  # json.JSONDecodeError is a ValueError
                parsed = None
            if isinstance(parsed, dict):
                entitlement = parsed.get("entitlement")

        # getpass.getuser() can raise when the user cannot be determined
        # (KeyError/ImportError before Python 3.13, OSError from 3.13 on).
        # Degrade to a missing field instead of failing the whole lookup.
        try:
            username = getpass.getuser()
        except (KeyError, OSError, ImportError):
            username = None

        return {
            "job_owner_unixname": os.environ.get("MAST_JOB_OWNER_UNIXNAME", None),
            "job_attempt": get_as_int("MAST_HPC_JOB_ATTEMPT_INDEX"),
            "global_rank": get_as_int("ROLE_RANK"),
            "local_rank": get_as_int("LOCAL_RANK"),
            "entitlement": entitlement,
            "mast_user": os.environ.get("MAST_JOB_OWNER_UNIXNAME", None),
            "mast_job_name": os.environ.get("MAST_HPC_JOB_NAME", None),
            "mast_job_version": os.environ.get("MAST_HPC_JOB_VERSION", None),
            "mast_job_attempt": os.environ.get("MAST_HPC_JOB_ATTEMPT_INDEX", None),
            "mast_job_type": os.environ.get("MAST_JOB_TYPE", None),
            "mast_cluster": os.environ.get("MAST_CLUSTER", None),
            "model_type_name": os.environ.get("MODEL_TYPE_NAME", None),
            "service_id": os.environ.get("FB_SERVICE_ID", None),
            "tw_task_handle": tw_task_handle,
            "tenant_priority": os.environ.get("TENANT_PRIORITY", None),
            "username": username,
            "hostname": socket.gethostname(),
        }

    @classmethod
    def sample_from_dict(cls, data: dict[str, Any]) -> Sample:
        """
        Create a sample from a dict of compilation metrics.

        A "time" field holding the current epoch second is always added first.
        Each entry of ``data`` is then added with the ``Sample`` method chosen
        by the value's Python type. Values of any other type (including
        ``None``) are skipped.

        Args:
            data: Mapping of column name to value.

        Returns:
            The populated ``Sample``.
        """
        sample = Sample()
        sample.add_int("time", int(time.time()))
        for k, v in data.items():
            # NOTE: bool is a subclass of int, so booleans take this branch.
            if isinstance(v, int):
                sample.add_int(k, v)
            elif isinstance(v, str):
                sample.add_normal(k, v)
            elif isinstance(v, float):
                sample.add_double(k, v)
            elif isinstance(v, list):
                sample.add_normvector(k, [str(e) for e in v])
            elif isinstance(v, set):
                sample.add_tags(k, {str(e) for e in v})
        return sample

    @classmethod
    def log(cls, metrics: dict[str, Any]) -> None:
        """
        Log compilation metrics to scuba.

        This is best effort: any failure is reported through the module
        logger and never raised to the caller.

        Args:
            metrics: Mapping of metric name to value. On a key collision these
                override the runtime environment fields.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, and a missing backend should not crash the caller.
        if cls.scuba_data is None:
            log.error(
                "XYXY Failed to log compilation event: %s",
                "scuba_data is not available",
            )
            return

        # Success-path tracing is kept at DEBUG so it does not pollute error logs.
        log.debug("XYXY: starting compilation event")
        try:
            data = {**cls.get_runtime_env_kwargs(), **metrics}
            sample = cls.sample_from_dict(data)
            cls.scuba_data.add_sample(sample)
            log.debug("XYXY: finished compilation event")
        except Exception as e:
            log.error("XYXY Failed to log compilation event: %s", e)
def log_compilation_event(metrics):
    """
    Record a compilation event.

    When ``ScubaData`` is falsy the metrics object is written to ``log.info``.
    Otherwise its attribute dict (``vars(metrics)``) is passed to
    ``CompilationEventScubaLogger.log``.
    """
    if not ScubaData:
        log.info("%s", metrics)
        return
    CompilationEventScubaLogger.log(vars(metrics))