mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Summary: avoids PGO profile merges Test Plan: test_pgo Differential Revision: D83200714 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163810 Approved by: https://github.com/bobrenjc93
		
			
				
	
	
		
			980 lines
		
	
	
		
			35 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			980 lines
		
	
	
		
			35 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| Profile Guided Optimization (PGO) implementation for Dynamo.
 | |
| 
 | |
| This module provides functionality for caching and managing code state profiles
 | |
| that guide optimization decisions in Dynamo. It implements both local and remote
 | |
| caching mechanisms for storing profile information across runs, handles profile
 | |
| merging across distributed ranks, and manages the lifecycle of profile data
 | |
| during compilation. The profiles track dynamic vs static properties of tensors
 | |
| and help Dynamo make better specialization decisions.
 | |
| """
 | |
| 
 | |
| from __future__ import annotations
 | |
| 
 | |
| import base64
 | |
| import copy
 | |
| import dataclasses
 | |
| import enum
 | |
| import functools
 | |
| import logging
 | |
| import os
 | |
| import pickle
 | |
| import re
 | |
| import zlib
 | |
| from collections import defaultdict
 | |
| from typing import Optional, TYPE_CHECKING, TypeVar, Union
 | |
| from typing_extensions import override, Self
 | |
| 
 | |
| import torch._dynamo.config
 | |
| import torch._utils_internal
 | |
| import torch.compiler.config
 | |
| import torch.distributed as dist
 | |
| from torch._dynamo.utils import (
 | |
|     CompileEventLogger,
 | |
|     dynamo_timed,
 | |
|     set_feature_use,
 | |
|     warn_once,
 | |
| )
 | |
| from torch._environment import is_fbcode
 | |
| from torch._logging._internal import trace_structured_artifact
 | |
| from torch.compiler._cache import (
 | |
|     CacheArtifact,
 | |
|     CacheArtifactFactory,
 | |
|     CacheArtifactManager,
 | |
| )
 | |
| from torch.utils._ordered_set import OrderedSet
 | |
| 
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     import types
 | |
| 
 | |
|     from torch._dynamo.symbolic_convert import InstructionTranslator
 | |
|     from torch._inductor.remote_cache import JsonDataTy, RemoteCache
 | |
| 
 | |
| 
 | |
class ReservedWorkflowIdUserError(ValueError):
    """Raised when a user-supplied ``job_id`` uses a reserved prefix."""
 | |
| 
 | |
| 
 | |
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# NOTE(review): presumably the timeout (in seconds) for the file lock taken
# around local PGO cache writes; it is not used in this part of the file --
# confirm against the write path.
LOCK_TIMEOUT = 10
 | |
| 
 | |
| # How does in memory representation work?  Concretely, this module is
 | |
| # responsible for holding GLOBAL state representing the state it holds, no
 | |
| # other copies permitted.  So we retire frame_state entirely and store it
 | |
| # here.  This should be reset when Dynamo is reset.  We never GC information
 | |
| # (similar to how the filesystem doesn't get cleaned up except by tmp
 | |
| # cleaner), so the expectation is the information is relatively cheap and we
 | |
| # don't mind leaking it.
 | |
| 
 | |
| 
 | |
| # How exactly did we design the cache key?  Here are some of the questions:
 | |
| #
 | |
| # - JOB_ID: Do we have a unique identifier for the "training run"  (such that
 | |
| #   it stays the same if we're running the same code, and changes if we're
 | |
| #   running something different).
 | |
| #
 | |
| # - RANK: Are we sharing the cache across ranks, or does each rank get
 | |
| #   an individual cache?
 | |
| #
 | |
| # We choose to require job_id for PGO cache.  This is to prevent
 | |
| # situations where unrelated invocations of PyTorch unpredictably cause
 | |
| # changes to each other's behavior.  With a job_id, at least you know there
 | |
| # is some "state" associated with it.  (State dict might be another way to
 | |
| # tell if a run is related or not.)  You can opt in to "YOLO, everything
 | |
| # aliases everything" by passing a shared job_id for all your invocations.
 | |
| #
 | |
| # We choose to NOT share PGO cache across ranks.  With no RANK_SHARING, there
 | |
| # is never contention between runs, so we can leisurely update a bundle with
 | |
| # information we need.  Because we are grouped by job_id, we can have a single
 | |
| # consolidated bundle for everything (or not; maybe worry about O(n^2) IO if
 | |
| # we updated every compile--let's just instrument this.)  Can even take a
 | |
| # filelock for extra safety (expect no contention); expect 50ns overhead from
 | |
| # uncontended filelock.
 | |
| #
 | |
| # If we did share ranks, everyone is storming to modify the same cache files.
 | |
| # We can do this by having folks atomic write to a CAS-store and then having
 | |
| # readers do on-the-fly merging (this can be implemented in remote using
 | |
| # prefix iteration).  As an optional optimization, one rank can be elected to
 | |
| # handle bundling post facto (ideally, this is done async, after quiescence,
 | |
| # without compiler collectives needing to wait for everyone to finish writing
 | |
| # their bits.) Not sure how you can avoid a listdir because if some rank shows
 | |
| # up with some new entries we need to pull them in ASAP (unless you want to
 | |
| # delay bundling).
 | |
| #
 | |
| # But compiler collectives fill a similar niche:  compilers chat with each
 | |
| # other so rank 0 has collected everything.  So elect rank 0 only to write the
 | |
| # bundle.  Don't even need CAS-store atomic write; just one rank writing and
 | |
| # updating bundles.  The point is that we use compiler collectives to share
 | |
| # profiles across ranks, but use the PGO cache to persist profiles per rank
 | |
| # across attempts.  No need to have one mechanism to do everything.
 | |
| 
 | |
| 
 | |
| @functools.cache
 | |
| def _hash_containing_file(filepath: str) -> str:
 | |
|     # if the file does not exists we consider filepath to be the hash.
 | |
|     if not os.path.exists(filepath):
 | |
|         return filepath
 | |
| 
 | |
|     with open(filepath, "rb") as file:
 | |
|         content = file.read()
 | |
|         crc32_value = zlib.crc32(content)
 | |
|         hash = format(crc32_value & 0xFFFFFFFF, "08x")
 | |
|         return hash
 | |
| 
 | |
| 
 | |
@dataclasses.dataclass(frozen=True)
class CodeId:
    """
    Identifies a code object by (file content hash, first line number, name).

    ``filename`` is carried along for readability only; it takes no part in
    equality or hashing.
    """

    filename: str
    firstlineno: int
    name: str
    # When a job restarts, the code can be copied to a different path than in
    # the previous attempt, so ``filename`` may differ between attempts.  We do
    # not want that difference to matter, so the content of the file is hashed
    # and the hash is used to identify the file.
    #
    # ``filename`` is kept to give a readable pointer to the actual file; in a
    # local code state it refers to the first seen file path.
    file_hash: str

    def _identity(self) -> tuple[str, str, int]:
        # The fields that define identity; the file name is left out on purpose.
        return (self.file_hash, self.name, self.firstlineno)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, CodeId):
            return self._identity() == other._identity()
        return False

    # Equal CodeIds must hash equally, so the hash also excludes filename.
    def __hash__(self) -> int:
        return hash(self._identity())

    def __str__(self) -> str:
        location = f"{self.filename}:{self.firstlineno}:{self.name}"
        return f"hash({self.file_hash})" + location

    @staticmethod
    def make(code: types.CodeType) -> CodeId:
        """Build the CodeId of ``code``, hashing the file it was defined in."""
        source_file = code.co_filename
        return CodeId(
            filename=source_file,
            firstlineno=code.co_firstlineno,
            name=code.co_name,
            file_hash=_hash_containing_file(source_file),
        )
 | |
| 
 | |
| 
 | |
@dataclasses.dataclass
class CodeState:
    """Per-code-object profile state: one FrameStateSizeEntry per source name."""

    # Maps a source name to what has been observed for it.  A missing name is
    # created on first access as a default (all-unset) FrameStateSizeEntry.
    automatic_dynamic: defaultdict[str, FrameStateSizeEntry] = dataclasses.field(
        default_factory=lambda: defaultdict(FrameStateSizeEntry)
    )
 | |
| 
 | |
| 
 | |
# Deep copy of _CODE_STATE taken at the moment a cached code state was loaded
# (see hit()); None until then.
_INIT_CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None
# The live, process-global code state; None until it has been initialized.
_CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None
# Set once the dynamic-shapes insight has been reported, so that it is only
# emitted a single time (see log_frame_dynamic_whitelist()).
_LOGGED_DYNAMIC_ALLOWLIST: bool = False
 | |
| 
 | |
| 
 | |
@dataclasses.dataclass(frozen=True)
class InferStride:
    """
    Denotes the quantity stride[dim] * size[dim], which is what the stride would
    be for the next physical dimension that results in a contiguous layout.

    For example, given size = [2, 3], stride = [3, 1], we can replace this with
    stride = [InferStride(1), 1], because InferStride(1) = stride[1] * size[1] = 1 * 3 = 3

    Indirecting the representation in this way is important for the join
    operation on strides: if we join [2, 3][3, 1] and [2, 4][4, 1], we don't
    want [2, None][None, 1], which would eventually get symbolized into
    [2, s0][s1, 1] (notice that the relationship between s0 and s1 is broken).
    If we instead rewrite the strides in terms of InferStride, so that we have
    [2, 3][InferStride(1), 1] and [2, 4][InferStride(1), 1], the join is
    [2, None][InferStride(1), 1], which results in [2, s0][s0, 1], as desired.
    """

    # Index of the dimension whose stride[dim] * size[dim] this value denotes.
    dim: int
 | |
| 
 | |
| 
 | |
# Generic element type for the merge helpers on FrameStateSizeEntry.
_T = TypeVar("_T")
 | |
| 
 | |
| 
 | |
class AutoUnset(enum.Enum):
    """
    The identity element of our semilattice: a generic "don't know" element
    that is always subsumed when we get more information.
    """

    # The single member; use the module-level alias `auto_unset` to refer to it.
    token = 0


# Singleton instance, compared by identity (`x is auto_unset`).
auto_unset = AutoUnset.token
 | |
| 
 | |
| 
 | |
class AutoDynamic(enum.Enum):
    """
    The top element of our (bounded) semilattice: whenever you merge this with
    any other element, you always get it back.
    """

    # The single member; use the module-level alias `auto_dynamic` to refer to it.
    token = 0


# Singleton instance, compared by identity (`x is auto_dynamic`).
auto_dynamic = AutoDynamic.token
 | |
| 
 | |
| 
 | |
@dataclasses.dataclass
class FrameStateSizeEntry:
    """
    What has been observed so far for one source, as a scalar and/or a tensor.

    Each field is either ``auto_unset`` (nothing observed yet), ``auto_dynamic``
    (observations disagreed, treat as dynamic), or a concrete value.  Entries
    are combined with ``|=``, which joins the fields element-wise.
    """

    scalar: Union[int, AutoDynamic, AutoUnset] = dataclasses.field(default=auto_unset)
    # NB: We don't have cases where we have a known dimensionality but
    # we know NOTHING about the individual sizes
    size: Union[AutoDynamic, AutoUnset, tuple[Union[int, AutoDynamic], ...]] = (
        dataclasses.field(default=auto_unset)
    )
    stride: Union[
        AutoDynamic, AutoUnset, tuple[Union[int, AutoDynamic, InferStride], ...]
    ] = dataclasses.field(default=auto_unset)

    def render(self) -> str:
        """Return a short human-readable description of this entry."""

        # Special cases
        def render_single(s: Union[int, AutoDynamic, AutoUnset, InferStride]) -> str:
            if s is auto_dynamic:
                return "?"
            elif s is auto_unset:
                # This basically shouldn't happen, this is for debugging
                return "auto unset"
            elif isinstance(s, InferStride):
                return f"S({s.dim})"
            else:
                return str(s)

        def render_tuple(ss: tuple[Union[int, AutoDynamic, InferStride], ...]) -> str:
            return "[" + ", ".join(render_single(s) for s in ss) + "]"

        # Common cases
        if self.size is auto_dynamic and self.stride is auto_dynamic:
            if self.scalar is auto_dynamic:
                return "fully dynamic scalar or tensor"
            else:
                return f"scalar {self.scalar}"
        elif self.scalar is auto_dynamic:
            if isinstance(self.size, tuple) and isinstance(self.stride, tuple):
                return f"tensor size={render_tuple(self.size)} stride={render_tuple(self.stride)}"

        # Fallback: anything that is neither a plain scalar nor a plain tensor
        # entry.  This must be an f-string so the repr is actually interpolated.
        return f"unusual {repr(self)}"

    def __post_init__(self) -> None:
        # Entries must only hold concrete values or the auto_* markers;
        # symbolic ints are expected to have been replaced before construction.
        assert not isinstance(self.scalar, torch.SymInt), self.scalar
        if isinstance(self.size, tuple):
            for s in self.size:
                assert not isinstance(s, torch.SymInt), s
        if isinstance(self.stride, tuple):
            for s1 in self.stride:
                assert not isinstance(s1, torch.SymInt), s1

    def is_size_dynamic(self, dim: int) -> bool:
        """Return True if the size of dimension ``dim`` is marked dynamic."""
        if self.size is auto_dynamic:
            return True
        if self.size is auto_unset:
            return False
        return self.size[dim] is auto_dynamic

    def is_stride_dynamic(self, dim: int) -> bool:
        """
        Return True if the stride of dimension ``dim`` is marked dynamic.

        Always False unless every size is a concrete int (see below).
        """
        # At the moment, dynamic strides is a bit buggy.  Good test case
        # here is `PYTORCH_TEST_WITH_DYNAMO=1 python test/test_autograd.py
        # TestAutograd.test_gradcheck_jacobian_mismatch`
        #
        # This if statement preserves historical behavior, which is that we
        # ONLY make strides dynamic if the size is exactly static everywhere.
        # We could potentially relax this but in general we should be very
        # careful about when to infer dynamic strides.
        #
        # Actually, the existing algorithm is already somewhat problematic.
        # Suppose a tensor that is sometimes:
        # f32[2, 3, 5][15, 5, 1] and other times
        # f32[2, 3, 5][5, 10, 1] (specifically, dim 0 and 1 are physically transposed).
        # If we infer strides should be (DYNAMIC, DYNAMIC, 1).  But this is
        # silly: we really should have just guarded on dim order.
        if not (
            isinstance(self.size, tuple) and all(type(s) is int for s in self.size)
        ):
            return False
        if self.stride is auto_dynamic:
            return True
        if self.stride is auto_unset:
            return False
        return self.stride[dim] is auto_dynamic

    @staticmethod
    def _munge_symint(xs: tuple[int, ...]) -> tuple[Union[AutoDynamic, int], ...]:
        # Replace every symbolic int with auto_dynamic; leave other values as is.
        return tuple(auto_dynamic if isinstance(x, torch.SymInt) else x for x in xs)

    @classmethod
    def make_scalar(cls, x: int) -> FrameStateSizeEntry:
        """Build the entry for an observed scalar value ``x``."""
        return FrameStateSizeEntry(scalar=x, size=auto_dynamic, stride=auto_dynamic)

    @classmethod
    def make_tensor(
        cls, size: tuple[int, ...], stride: tuple[int, ...]
    ) -> FrameStateSizeEntry:
        """Build the entry for an observed tensor with ``size`` and ``stride``."""
        return FrameStateSizeEntry(
            scalar=auto_dynamic,
            size=cls._munge_symint(size),
            stride=cls._munge_symint(stride),
        )

    @classmethod
    def make_size(cls, size: tuple[int, ...]) -> FrameStateSizeEntry:
        """Build an entry that records only ``size``; scalar and stride stay unset."""
        return FrameStateSizeEntry(
            scalar=auto_unset,
            size=cls._munge_symint(size),
            stride=auto_unset,
        )

    @staticmethod
    def _merge_atom(x: _T, y: _T) -> Union[AutoDynamic, _T]:
        # Join of two atoms: unset yields the other side, equal values are
        # kept, anything else (including an existing auto_dynamic) is dynamic.
        if x is auto_unset:
            return y
        if y is auto_unset:
            return x
        if x is auto_dynamic or y is auto_dynamic or x != y:
            return auto_dynamic
        return x

    @classmethod
    def _merge_atom_tup(
        cls,
        xs: Union[AutoDynamic, AutoUnset, tuple[_T, ...]],
        ys: Union[AutoDynamic, AutoUnset, tuple[_T, ...]],
    ) -> Union[AutoDynamic, AutoUnset, tuple[Union[AutoDynamic, _T], ...]]:
        # Join of two tuples: same rules as _merge_atom for the tuple as a
        # whole, a length mismatch is dynamic, otherwise join element-wise.
        if xs is auto_unset:
            return ys
        if ys is auto_unset:
            return xs
        if xs is auto_dynamic or ys is auto_dynamic:
            return auto_dynamic
        if len(xs) != len(ys):
            return auto_dynamic
        return tuple(cls._merge_atom(x, y) for x, y in zip(xs, ys))

    def __ior__(self, other: Self) -> Self:
        """Merge ``other`` into this entry in place and return ``self``."""
        self.scalar = self._merge_atom(self.scalar, other.scalar)
        self.size = self._merge_atom_tup(self.size, other.size)
        self.stride = self._merge_atom_tup(self.stride, other.stride)
        return self
 | |
| 
 | |
| 
 | |
def update_automatic_dynamic(
    tx: InstructionTranslator,
    name: str,
    entry: FrameStateSizeEntry,
    *,
    is_unspecialized_nn_module: bool = False,
) -> FrameStateSizeEntry:
    """
    Record ``entry`` for source ``name`` in the code state of ``tx.f_code``.

    When ``torch._dynamo.config.automatic_dynamic_shapes`` is enabled, ``entry``
    is merged (``|=``) into the entry already stored for ``name``; if ``name``
    already had an entry, every field the merge changed is logged.  Otherwise
    the stored entry is simply overwritten with ``entry``.

    Args:
        tx: translator whose ``f_code`` selects the code state to update.
        name: source name the observation belongs to.
        entry: the new observation.
        is_unspecialized_nn_module: when True, a scalar change additionally
            logs a hint about ``torch._dynamo.mark_static``.

    Returns:
        The entry now stored for ``name`` (the merged entry, or ``entry``).
    """
    code_id = CodeId.make(tx.f_code)
    frame_state = get_code_state()[code_id]
    if torch._dynamo.config.automatic_dynamic_shapes:
        # Must be checked before the lookup below: automatic_dynamic is a
        # defaultdict, so indexing it creates the entry.
        is_update = name in frame_state.automatic_dynamic
        mut_entry = frame_state.automatic_dynamic[name]
        # Snapshot the pre-merge state so changes can be reported afterwards.
        old_entry = copy.copy(mut_entry)
        mut_entry |= entry

        # Do some logs (damn, I spend more code logging than I do actually doing
        # the updates lol)
        if is_update and old_entry.scalar != mut_entry.scalar:
            log.debug(
                "automatic dynamic int %s val %s != %s",
                name,
                entry.scalar,
                old_entry.scalar,
            )
            CompileEventLogger.instant(
                "automatic_dynamic",
                {
                    "name": name,
                    "dim_changed": "scalar",
                    "reason": "scalar change",
                    "cached": str(old_entry.scalar),
                    "new": str(entry.scalar),
                },
            )
            if is_unspecialized_nn_module:
                log.info(
                    "%s is converted to a symbolic integer. It is an attribute of a "
                    "user defined nn module class. If you wish to keep it static, you can "
                    "mark the nn module class as `torch._dynamo.mark_static`.",
                    name,
                )

        # Report a change of the `tup_name` field ("size" or "stride"): the
        # whole field when `i` is None, otherwise only element `i`.
        def log_tup(
            tup_name: str, short_reason: str, long_reason: str, i: Optional[int] = None
        ) -> None:
            entry_tup = (
                getattr(entry, tup_name) if i is None else getattr(entry, tup_name)[i]
            )
            old_entry_tup = (
                getattr(old_entry, tup_name)
                if i is None
                else getattr(old_entry, tup_name)[i]
            )
            log.debug(
                "automatic dynamic %s %s %s %s != %s",
                tup_name,
                name,
                short_reason,
                # NB: We used to only report len(...) here for dim mismatch
                entry_tup,
                old_entry_tup,
            )
            CompileEventLogger.instant(
                "automatic_dynamic",
                {
                    "name": name,
                    "dim_changed": "all" if i is None else i,
                    "reason": long_reason,
                    "cached": str(old_entry_tup),
                    "new": str(entry_tup),
                },
            )

        if is_update and old_entry.size != mut_entry.size:
            if isinstance(old_entry.size, tuple) and isinstance(entry.size, tuple):
                if len(old_entry.size) != len(entry.size):
                    log_tup("size", "dim", "dimensionality change")
                else:
                    # Same rank: report each dimension that differs.
                    for i in range(len(entry.size)):
                        if old_entry.size[i] != entry.size[i]:
                            log_tup("size", f"size({i})", "size change", i)
            else:
                # At least one side is not a tuple (auto_unset / auto_dynamic).
                log_tup("size", "other", "other")

        if is_update and old_entry.stride != mut_entry.stride:
            if isinstance(old_entry.stride, tuple) and isinstance(entry.stride, tuple):
                if len(old_entry.stride) != len(entry.stride):
                    log_tup("stride", "dim", "dimensionality change")
                else:
                    # Same rank: report each dimension that differs.
                    for i in range(len(entry.stride)):
                        if old_entry.stride[i] != entry.stride[i]:
                            log_tup("stride", f"stride({i})", "stride change", i)
            else:
                # At least one side is not a tuple (auto_unset / auto_dynamic).
                log_tup("stride", "other", "other")
    else:
        # The previous entry is read only for the log message below; a missing
        # name yields a default entry from the defaultdict.
        old_entry = frame_state.automatic_dynamic[name]
        log.debug(
            "automatic dynamic is off, overwriting int %s val %s -> %s",
            name,
            old_entry.scalar,
            entry.scalar,
        )
        frame_state.automatic_dynamic[name] = entry
        mut_entry = entry

    return mut_entry
 | |
| 
 | |
| 
 | |
def process_automatic_dynamic(
    tx: InstructionTranslator,
    name: str,
    entry: FrameStateSizeEntry,
    *,
    is_unspecialized_nn_module: bool = False,
) -> FrameStateSizeEntry:
    """
    Route an observation for ``name`` according to ``tx.distributed_state``.

    - No distributed state: update the code state directly.
    - Distributed state without ``all_states`` (preflight): only stash the
      entry in the local state and return it unchanged.
    - Otherwise: apply every state in ``all_states`` that has an entry for
      ``name`` and return the result of the last update.
    """
    st = tx.distributed_state
    if st is None:
        return update_automatic_dynamic(
            tx,
            name,
            entry,
            is_unspecialized_nn_module=is_unspecialized_nn_module,
        )

    if st.all_states is None:
        # Preflight, always pretend as if it's static.  The point here
        # is we want to get through the preflight quickly, and static
        # will run faster.  The preexisting frame state will get
        # applied anyway after we do compiler collectives.
        # TODO: I'm not sure if we should just bong the entire pgo
        # state here, it kind of depends if we're going to have other
        # things that talk in compiler collective.  Also, the PGO
        # state, if we've already inferred something is automatic
        # dynamic, will have lost the actual input sizes, which might
        # be useful for debugging purposes (e.g., observing 0/1
        # specialization).  Bonging the entire PGO state here would
        # let us delete this logic here; the compiler collective
        # would just directly update_automatic_dynamic
        st.local_state.automatic_dynamic[name] = entry
        return entry

    # Apply the updates.  NB: all_states includes the local state too.
    res = None
    for sub_state in st.all_states:
        recorded = sub_state.automatic_dynamic
        if name not in recorded:
            continue
        res = update_automatic_dynamic(
            tx,
            name,
            recorded[name],
            is_unspecialized_nn_module=is_unspecialized_nn_module,
        )
    assert res is not None
    return res
 | |
| 
 | |
| 
 | |
| def format_cache_key(key: str) -> str:
 | |
|     # NB: We always use global rank for keys, even though they are overkill
 | |
|     # for local only cache
 | |
|     rank = None
 | |
|     if dist.is_available() and dist.is_initialized():
 | |
|         rank = dist.get_rank()
 | |
| 
 | |
|     tag = torch.compiler.config.cache_key_tag
 | |
|     return f"{key}:{rank}:{tag}"
 | |
| 
 | |
| 
 | |
def get_cache_key() -> Optional[str]:
    """
    Return the PGO cache key for this process, or None if there is none.

    A user-configured ``torch.compiler.config.job_id`` takes precedence;
    otherwise the key is derived from the MAST job name and version when
    available.  Returns None when caches are force-disabled.

    Raises:
        ReservedWorkflowIdUserError: if the configured job id starts with the
            reserved ``"mast:"`` prefix.
    """
    # TODO: info versions of these logs that log only once
    if torch.compiler.config.force_disable_caches:
        warn_once(
            "dynamo_pgo force disabled by torch.compiler.config.force_disable_caches"
        )
        return None

    # NB: We namespace the cache keys so that only user-specified job id
    # can alias with each other.
    job_id = torch.compiler.config.job_id
    if job_id is not None:
        if job_id.startswith("mast:"):
            raise ReservedWorkflowIdUserError(
                "torch.compiler.config.job_id with prefix 'mast:' is reserved for "
                "automatically generated job id associated with a specific MAST job "
                "name and version."
            )
        return format_cache_key(job_id)

    name_version = torch._utils_internal.get_mast_job_name_version()
    if name_version is None:
        return None

    mast_job_name, mast_job_version = name_version
    return format_cache_key(f"mast:{mast_job_name}:{mast_job_version}")
 | |
| 
 | |
| 
 | |
def get_extra_cache_key(sticky_key: str) -> Optional[str]:
    """Return the formatted cache key for ``sticky_key``, or None if caches are force-disabled."""
    caches_disabled = torch.compiler.config.force_disable_caches
    if not caches_disabled:
        return format_cache_key(sticky_key)

    warn_once(
        "dynamo_pgo force disabled by torch.compiler.config.force_disable_caches"
    )
    return None
 | |
| 
 | |
| 
 | |
| # This solely controls local PGO
 | |
| def code_state_path(cache_key: str) -> Optional[str]:
 | |
|     if not torch._dynamo.config.automatic_dynamic_local_pgo:
 | |
|         log.debug("automatic_dynamic_local_pgo not enabled")
 | |
|         return None
 | |
| 
 | |
|     from torch._inductor.runtime.runtime_utils import cache_dir
 | |
| 
 | |
|     code_state_key = re.sub(r'[<>:"/\\|?*]', "_", f"code_state_{cache_key}.pkl")
 | |
|     return os.path.join(cache_dir(), "dynamo", code_state_key)
 | |
| 
 | |
| 
 | |
| def should_use_remote_dynamo_pgo_cache() -> bool:
 | |
|     if torch.compiler.config.force_disable_caches:
 | |
|         return False
 | |
| 
 | |
|     if (r := torch._dynamo.config.automatic_dynamic_remote_pgo) is not None:
 | |
|         return r
 | |
| 
 | |
|     if not is_fbcode():
 | |
|         return False
 | |
| 
 | |
|     if torch._utils_internal.is_fb_unit_test():
 | |
|         return False
 | |
| 
 | |
|     try:
 | |
|         from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION
 | |
|     except ModuleNotFoundError:
 | |
|         return False
 | |
| 
 | |
|     return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(
 | |
|         "pytorch/remote_cache:dynamo_pgo_version"
 | |
|     )
 | |
| 
 | |
| 
 | |
def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
    """Create the remote PGO cache, or return None if it should not be used."""
    from torch._inductor.remote_cache import create_cache

    if should_use_remote_dynamo_pgo_cache():
        return create_cache(
            "dynamo-pgo",
            is_fbcode(),
            "FbRemoteDynamoPGOCache",
            "RemoteDynamoPGOCache",
        )
    return None
 | |
| 
 | |
| 
 | |
def _collect_dynamic_sources(code_state: CodeState) -> OrderedSet[str]:
    """
    Return the source names in ``code_state`` that are considered dynamic.

    A source with a tuple ``size`` is dynamic if any dimension is
    ``auto_dynamic``; any other source is dynamic if its ``scalar`` is
    ``auto_dynamic``.
    """
    collected: OrderedSet[str] = OrderedSet()
    for src, fs in code_state.automatic_dynamic.items():
        if isinstance(fs.size, tuple):
            is_dynamic = auto_dynamic in fs.size  # type: ignore[operator]
        else:
            is_dynamic = fs.scalar == auto_dynamic
        if is_dynamic:
            collected.add(src)
    return collected
 | |
| 
 | |
| 
 | |
def log_frame_dynamic_whitelist(f_code: types.CodeType) -> None:
    """
    Log the dynamic sources recorded for ``f_code``, if there are any.

    The comma-joined list of dynamic sources is reported as a pt2 compile
    event.  The first time this happens in the process, an MLHub insight about
    dynamic shape recompilation is added as well.
    """
    global _LOGGED_DYNAMIC_ALLOWLIST
    code_id = CodeId.make(f_code)
    frame_state = get_code_state()[code_id]
    frame_whitelist = ",".join(_collect_dynamic_sources(frame_state))
    if not frame_whitelist:
        return

    with dynamo_timed(name := "pgo.dynamic_whitelist", log_pt2_compile_event=True):
        CompileEventLogger.pt2_compile(
            name, recompile_dynamic_whitelist=frame_whitelist
        )

    if not _LOGGED_DYNAMIC_ALLOWLIST:
        torch._utils_internal.add_mlhub_insight(
            category="dynamic_shapes_analysis",
            insight="Dynamic shape recompilation detected",
            # Adjacent literals are used instead of backslash continuations
            # inside one literal, which would embed the source indentation
            # into the message.
            insight_description=(
                "PGO detected a recompilation due to dynamic shapes. "
                "Please follow the instruction from the action link to reduce "
                "recompilation overhead."
            ),
        )
        # add mlhub insight only once per rank
        _LOGGED_DYNAMIC_ALLOWLIST = True
 | |
| 
 | |
| 
 | |
def render_code_state(cs: defaultdict[CodeId, CodeState]) -> str:
    """
    Render the whole code state as human-readable text.

    Each code id gets a header line followed by one indented line per source.
    If any source is dynamic, a hint naming all dynamic sources is appended.
    """
    sections = []
    for code_id, state in cs.items():
        source_lines = "\n".join(
            f"  {src}: {fs.render()}" for src, fs in state.automatic_dynamic.items()
        )
        sections.append(f"{code_id}:\n" + source_lines)
    rendered = "\n".join(sections)

    dynamic_sources: OrderedSet[str] = OrderedSet()
    for state in cs.values():
        dynamic_sources.update(_collect_dynamic_sources(state))
    if not dynamic_sources:
        return rendered

    joined_sources = ",".join(dynamic_sources)
    hint = (
        "\n\nPGO detected a recompilation due to dynamic shapes. "
        "To reduce shape recompilations by compiling dynamically to start, "
        f'set environment variable TORCH_COMPILE_DYNAMIC_SOURCES="{joined_sources}"'
    )
    return rendered + hint
 | |
| 
 | |
| 
 | |
@CacheArtifactFactory.register
class PGOCacheArtifact(CacheArtifact):
    """Cache artifact of type ``"pgo"`` holding a serialized code state."""

    @override
    def populate_cache(self) -> None:
        """Write this artifact's content to the local PGO cache."""
        local_key = self._rewrite_cache_key_for_mega_cache(self.key)
        meta = write_local_impl(local_key, self.content)
        assert meta is not None

    @override
    @staticmethod
    def type() -> str:
        return "pgo"

    @staticmethod
    def _rewrite_cache_key_for_mega_cache(original_key: str) -> str:
        """
        The PGO cache artifact key for a MAST job contains the job name and the version.
        When we want to use the cache artifact on a different MAST job, we need to
        update the key to use the new MAST job's name and version.
        """
        # Only auto-generated MAST keys are rewritten; a user-overridden key is
        # kept as is.
        if original_key.startswith("mast:"):
            current_key = get_cache_key()
            if current_key is not None:
                return current_key
        return original_key
 | |
| 
 | |
| 
 | |
def hit(key: str, ty: str) -> defaultdict[CodeId, CodeState]:
    """
    Record a cache hit for the already-loaded global code state.

    Logs the hit, emits the rendered code state as a structured trace
    artifact, marks the "pgo" feature as used, snapshots the loaded state into
    ``_INIT_CODE_STATE``, and returns the live ``_CODE_STATE``.

    Args:
        key: what was hit (used for logging only).
        ty: kind of hit; part of the trace artifact name.
    """
    global _INIT_CODE_STATE
    assert isinstance(_CODE_STATE, defaultdict)

    num_entries = len(_CODE_STATE)
    log.info("get_code_state %s hit %s, %d entries", key, ty, num_entries)

    artifact_name = f"get_{ty}_code_state"
    trace_structured_artifact(
        artifact_name,
        "string",
        lambda: render_code_state(_CODE_STATE),  # type: ignore[arg-type]
    )
    set_feature_use("pgo", True)

    # Keep an independent copy of the state as it was when loaded.
    _INIT_CODE_STATE = copy.deepcopy(_CODE_STATE)
    return _CODE_STATE
 | |
| 
 | |
| 
 | |
def get_local_code_state(cache_key: str) -> Optional[defaultdict[CodeId, CodeState]]:
    """
    Load the code state for ``cache_key`` from the local cache file.

    On success the unpickled state becomes the global ``_CODE_STATE``, the raw
    bytes are recorded as a PGO cache artifact, and the state is returned.
    Returns None if local PGO is disabled, the file does not exist, or reading
    fails (the failure is logged as a warning).
    """
    global _CODE_STATE
    path = code_state_path(cache_key)
    if path is None or not os.path.exists(path):
        return None

    event = "pgo.get_local_code_state"
    with dynamo_timed(event, log_pt2_compile_event=True):
        CompileEventLogger.pt2_compile(event, cache_key=cache_key)
        # Read lock not necessary as we always write atomically write to
        # the actual location
        with open(path, "rb") as f:
            try:
                content = f.read()
                # NOTE: pickle is only appropriate here because the file is
                # expected to come from this cache's own writer.
                _CODE_STATE = pickle.loads(content)
                CompileEventLogger.pt2_compile(event, cache_size_bytes=f.tell())
            except Exception:
                log.warning(
                    "get_code_state failed while reading %s", path, exc_info=True
                )
            else:
                CacheArtifactManager.record_artifact(
                    PGOCacheArtifact.type(), cache_key, content
                )
                return hit(path, "local")
    return None
 | |
| 
 | |
| 
 | |
def lookup_remote_cache_entry(
    remote_cache: RemoteCache[JsonDataTy],
    cache_key: str,
    event_name: Optional[str] = None,
) -> Optional[defaultdict[CodeId, CodeState]]:
    """
    Fetch and decode the code state stored under ``cache_key`` in ``remote_cache``.

    The remote entry is expected to be a dict whose ``"data"`` field is a
    base64-encoded pickle. On success the decoded bytes are recorded as a PGO
    cache artifact and the unpickled object is returned.

    Args:
        remote_cache: Cache to read from.
        cache_key: Key to look up.
        event_name: If given, the decoded payload size is logged against this
            compile event.

    Returns:
        The unpickled code state, or ``None`` on a read failure, a miss, or a
        parse failure (each of which is logged).
    """
    try:
        cache_data = remote_cache.get(cache_key)
    except Exception:
        log.warning("get_code_state failed remote read on %s", cache_key, exc_info=True)
        return None

    if cache_data is None:
        log.info("get_code_state remote miss on %s", cache_key)
        return None

    try:
        assert isinstance(cache_data, dict)
        encoded = cache_data["data"]
        assert isinstance(encoded, str)
        payload = base64.b64decode(encoded)
        if event_name is not None:
            CompileEventLogger.pt2_compile(event_name, cache_size_bytes=len(payload))
        code_state = pickle.loads(payload)
    except Exception:
        log.warning(
            "get_code_state failed parsing remote result on %s",
            cache_key,
            exc_info=True,
        )
        return None

    # Record the artifact only after the payload decoded and unpickled cleanly.
    CacheArtifactManager.record_artifact(PGOCacheArtifact.type(), cache_key, payload)
    return code_state
 | |
| 
 | |
| 
 | |
def get_remote_code_state(cache_key: str) -> Optional[defaultdict[CodeId, CodeState]]:
    """
    Try to load the PGO code state for ``cache_key`` from the remote cache.

    On success the fetched state replaces ``_CODE_STATE`` and the result of
    ``hit`` is returned. Returns ``None`` when no remote cache is available or
    the lookup yields nothing.
    """
    global _CODE_STATE
    remote_cache = get_remote_cache()
    if remote_cache is None:
        return None

    event = "pgo.get_remote_code_state"
    with dynamo_timed(
        event,
        log_pt2_compile_event=True,
        dynamo_compile_column_us="pgo_get_remote_code_state_time_us",
    ):
        CompileEventLogger.pt2_compile(event, cache_key=cache_key)
        fetched = lookup_remote_cache_entry(remote_cache, cache_key, event)
        if fetched is None:
            return None
        _CODE_STATE = fetched
        return hit(cache_key, "remote")
 | |
| 
 | |
| 
 | |
def get_extra_remote_code_state(cache_key: str) -> None:
    """
    Read an additional PGO profile from ``cache_key`` in the remote cache and,
    if one is found, install it as the current code state.

    The extra profile is NOT merged with an existing profile: the current
    ``_CODE_STATE`` must already be initialized and empty (this is asserted),
    and it is replaced wholesale by the fetched profile. ``_INIT_CODE_STATE``
    is left untouched by this function.

    Does nothing when no remote cache is available or the lookup returns
    ``None``; ``lookup_remote_cache_entry`` logs misses and failures itself.
    """
    global _CODE_STATE
    assert _CODE_STATE is not None

    remote_cache = get_remote_cache()
    if remote_cache is None:
        return

    with dynamo_timed(
        name := "pgo.get_extra_remote_code_state",
        log_pt2_compile_event=True,
        dynamo_compile_column_us="pgo_get_remote_code_state_time_us",
    ):
        CompileEventLogger.pt2_compile(name, cache_key=cache_key)
        code_state = lookup_remote_cache_entry(remote_cache, cache_key)
        if code_state is None:
            # Nothing usable was fetched; do not report a "hit".
            return

        log.info(
            "get_extra_code_state %s hit, %d entries",
            cache_key,
            len(code_state),
        )
        # The extra profile may only be used when nothing else was loaded.
        assert not _CODE_STATE
        _CODE_STATE = code_state
        # log to tlparse
        trace_structured_artifact(
            "get_extra_remote_code_state",
            "string",
            lambda: render_code_state(code_state),
        )
 | |
| 
 | |
| 
 | |
def get_code_state() -> defaultdict[CodeId, CodeState]:
    """
    Return the process-wide PGO code state, loading it on first use.

    If ``_CODE_STATE`` is already set it is returned as-is. Otherwise it is
    initialized to an empty ``defaultdict(CodeState)`` and, when a cache key is
    available, populated by trying in order: the local cache, the default
    remote cache, and (only if the state is still empty and
    ``torch.compiler.config.pgo_extra_read_key`` is set) the extra remote key.
    """
    global _CODE_STATE
    if _CODE_STATE is not None:
        return _CODE_STATE

    # Initialize it (even if we don't look up profile)
    _CODE_STATE = defaultdict(CodeState)

    cache_key = get_cache_key()
    if cache_key is None:
        return _CODE_STATE

    # Local first; fall back to the default remote only on a local miss.
    if get_local_code_state(cache_key) is None:
        get_remote_code_state(cache_key)

    # Attempt additional remote if neither local/default remote succeeded
    if not _CODE_STATE:
        sticky_read = torch.compiler.config.pgo_extra_read_key
        if sticky_read is not None:
            extra_read_key = get_extra_cache_key(sticky_read)
            if extra_read_key is not None:
                get_extra_remote_code_state(extra_read_key)

    log.info("get_code_state using default")

    assert _CODE_STATE is not None
    return _CODE_STATE
 | |
| 
 | |
| 
 | |
def put_code_state() -> None:
    """
    Persist the in-memory PGO code state to the local and remote caches.

    Nothing is written when the code state was never initialized, when it is
    equal to the snapshot taken at load time (``_INIT_CODE_STATE``), or when no
    cache key is available. Otherwise the state is written locally and to the
    default remote key, and additionally to the extra remote key derived from
    ``torch.compiler.config.pgo_extra_write_key`` when that is configured.
    """
    if _CODE_STATE is None:
        log.info("put_code_state: never initialized, will not write")
        return

    if _CODE_STATE == _INIT_CODE_STATE:
        log.info("put_code_state: no change, skipping")
        return

    cache_key = get_cache_key()
    if cache_key is None:
        log.info("put_code_state: no cache key, skipping")
        return

    put_local_code_state(cache_key)
    put_remote_code_state(cache_key)
    if (sticky_write := torch.compiler.config.pgo_extra_write_key) is not None:
        extra_write_key = get_extra_cache_key(sticky_write)
        if extra_write_key is not None:
            # Flag this as the extra write so put_remote_code_state uses its
            # "put_extra_remote_code_state" event name instead of reporting a
            # second write under the default event name.
            put_remote_code_state(extra_write_key, extra_code_state=True)
 | |
| 
 | |
| 
 | |
def write_local_impl(cache_key: str, pickled_code: bytes) -> Optional[tuple[str, int]]:
    """
    Write ``pickled_code`` to the local PGO cache file for ``cache_key``.

    The bytes are written to a ``.tmp`` sibling and then moved over the final
    path with ``os.replace``, all while holding a ``.lock`` file lock.

    Returns:
        ``(path, size)`` where ``size`` is the file position after the write,
        or ``None`` if ``code_state_path`` provides no local path.
    """
    path = code_state_path(cache_key)
    if path is None:
        return None

    # Callers using the API as intended should have exclusive access to this
    # directory; the lock is still taken because the shared tmp file could
    # otherwise be clobbered.
    # TODO: use a safe tempfile create to eliminate lock
    from torch.utils._filelock import FileLock

    os.makedirs(os.path.dirname(path), exist_ok=True)

    staging_path = f"{path}.tmp"
    lock_file = f"{path}.lock"
    with FileLock(lock_file, timeout=LOCK_TIMEOUT):
        with open(staging_path, "wb") as out:
            out.write(pickled_code)
            size = out.tell()
        # Move into place only after the staging file has been closed.
        os.replace(staging_path, path)
    return path, size
 | |
| 
 | |
| 
 | |
def put_local_code_state(cache_key: str) -> None:
    """
    Pickle the current code state, record it as a PGO cache artifact, and
    write it to the local cache file for ``cache_key``.

    The artifact is recorded even when the local write is skipped because
    ``write_local_impl`` reports that no local path is available.
    """
    event = "pgo.put_local_code_state"
    with dynamo_timed(event, log_pt2_compile_event=True):
        CompileEventLogger.pt2_compile(event, cache_key=cache_key)
        assert _CODE_STATE is not None

        serialized = pickle.dumps(_CODE_STATE)
        CacheArtifactManager.record_artifact(
            PGOCacheArtifact.type(), cache_key, serialized
        )

        write_result = write_local_impl(cache_key, serialized)
        if write_result is None:
            log.info("put_code_state: local cache disabled")
            return

        path, size = write_result
        CompileEventLogger.pt2_compile(event, cache_size_bytes=size)
        log.info("put_code_state: wrote local %s, %d entries", path, len(_CODE_STATE))
        trace_structured_artifact(
            "put_local_code_state",
            "string",
            lambda: render_code_state(_CODE_STATE),
        )
 | |
| 
 | |
| 
 | |
def put_remote_code_state(cache_key: str, extra_code_state: bool = False) -> None:
    """
    Pickle the current code state and store it in the remote cache.

    The payload is stored as ``{"data": <base64 of the pickle>}`` under
    ``cache_key``.

    Args:
        cache_key: Remote key to write to.
        extra_code_state: Selects the event name used for timing, logging and
            the trace artifact: "put_extra_remote_code_state" when true,
            "put_remote_code_state" otherwise.
    """
    if extra_code_state:
        event_name = "put_extra_remote_code_state"
    else:
        event_name = "put_remote_code_state"
    timed_name = f"pgo.{event_name}"

    with dynamo_timed(
        timed_name,
        log_pt2_compile_event=True,
        dynamo_compile_column_us="pgo_put_remote_code_state_time_us",
    ):
        CompileEventLogger.pt2_compile(timed_name, cache_key=cache_key)
        assert _CODE_STATE is not None

        remote_cache = get_remote_cache()
        if remote_cache is None:
            log.info("%s: remote cache disabled", event_name)
            return

        content = pickle.dumps(_CODE_STATE)
        CompileEventLogger.pt2_compile(timed_name, cache_size_bytes=len(content))

        encoded = base64.b64encode(content).decode("ascii")
        cache_data: JsonDataTy = {
            "data": encoded,
        }
        remote_cache.put(cache_key, cache_data)

        log.info(
            "%s: wrote remote %s, %d entries", event_name, cache_key, len(_CODE_STATE)
        )
        # TODO: don't log this multiple times
        trace_structured_artifact(
            event_name,
            "string",
            lambda: render_code_state(_CODE_STATE),
        )
 | |
| 
 | |
| 
 | |
def reset_code_state() -> None:
    """
    Clear the in-memory PGO state held in this module's globals.

    NB: this does NOT reset the cached code state on disk.
    """
    global _CODE_STATE, _INIT_CODE_STATE, _LOGGED_DYNAMIC_ALLOWLIST
    _CODE_STATE = _INIT_CODE_STATE = None
    _LOGGED_DYNAMIC_ALLOWLIST = False
 |