Files
pytorch/torch/fx/experimental/recording.py
2025-01-20 18:11:54 +00:00

514 lines
19 KiB
Python

# mypy: allow-untyped-defs
import functools
import inspect
import itertools
import logging
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
import torch
import torch.utils._pytree as pytree
# Module-level logger; used to report failures while running or replaying events.
log = logging.getLogger(__name__)
# Artifact logger that receives the per-call trace lines ("call ..." / "-> ...")
# emitted by the wrapper created in record_shapeenv_event.
trace_shape_events_log = torch._logging.getArtifactLogger(
    __name__, "trace_shape_events"
)
# Names exported by this module.
__all__ = [
    "ShapeEnvEvent",
    "record_shapeenv_event",
    "replay_shape_env_events",
    "FakeTensorMeta",
    "shape_env_check_state_equal",
    "NotEqualError",
]
# [Note: Recording ShapeEnv Events]
# =================================
#
# What is a ShapeEnv event?
# -------------------------
# We consider a ShapeEnv event every function call (ShapeEnv method or
# independent function) that modifies the state of the ShapeEnv instance.
# Such calls are recorded alongside their positional and keyword arguments,
# so that it may be replayed over a different ShapeEnv instance.
#
# See [Note: ShapeEnv State Equality] for what is considered the state
# of a ShapeEnv instance.
#
# What is it for?
# ---------------
# ShapeEnv events recording is used for reconstructing the ShapeEnv in an
# arbitrary state in time.
#
# Being able to arbitrarily replay events like so is useful, mainly for
# translation validation bisection. i.e. if a ValidationException has been
# raised, find the earliest point in time where the translation validation
# fails.
#
# Besides that, it also allows us to inspect the given instance and,
# for example, check the guards that would actually be issued at that point.
#
# What kind of arguments can be stored in an event?
# -------------------------------------------------
# There's no specific rule for what cannot be used as an argument.
# That said, pay special attention to the following cases:
#
# 1. Tensor inputs: there are some tests that check whether the inputs
# were garbage collected after execution. These will fail if there's
# an event that is holding a reference to those inputs.
#
# 2. ShapeEnv arguments: if there is an argument of ShapeEnv type, that
# will be automatically replaced by the new given ShapeEnv instance.
#
# 3. SymTypes arguments: they also hold references to ShapeEnv. So,
# whenever we see them, we create a new instance, replacing the
# ShapeEnv reference.
#
# 4. FX nodes: specifically, FX nodes from the FX graph for symbolic
# shapes. That argument must be replaced when replaying the event at
# ShapeEnvEvent.run, since it has to reference a node from the given
# instance, and not from the recorded instance.
# Event class for reconstructing ShapeEnv at arbitrary time.
#
# Represents a method call that mutates ShapeEnv in a way that affects the
# issued guards, when ShapeEnv.produce_guards is called.
@dataclass
class ShapeEnvEvent:
# ShapeEnv method.
f: Callable
# Arguments and keyword arguments called with.
args: Optional[list[Any]] = None
kwargs: Optional[dict[str, Any]] = None
# List of tracked_fakes at the time the method was called.
tracked_fakes: Optional[list[Any]] = None
# Name of the captured event.
# Used for special handling of particular methods.
name: Optional[str] = None
# Replay itself, but using shape_env as self.
def run(self, shape_env=None) -> Any:
from torch.fx.experimental.symbolic_shapes import (
is_symbolic,
ShapeEnv,
SymTypes,
)
# Special handling for the constructor event.
if self.f is ShapeEnv:
assert shape_env is None and self.args is None and self.kwargs is not None
return ShapeEnv(**self.kwargs)
assert shape_env is not None
args = list(self.args or [])
kwargs = dict(self.kwargs or {})
# Replace any argument of type ShapeEnv by the given one.
args, kwargs = pytree.tree_map_only(
ShapeEnv, lambda _: shape_env, (args, kwargs)
)
# Replace any argument of type SymTypes by a new instance,
# replacing its ShapeEnv reference.
args, kwargs = pytree.tree_map_only(
lambda x: isinstance(x, SymTypes) and is_symbolic(x),
lambda a: type(a)(a.node.with_shape_env(shape_env)),
(args, kwargs),
)
# Converts FX nodes using the mapping argument.
def maybe_convert_node(x: Any) -> Any:
if not isinstance(x, torch.fx.Node):
# Don't do anything to x if it's not an FX node.
return x
# If, at some point, we created an FX node, it means that translation validation is on.
# It also means we are building an FX graph for symbolic shapes at shape_env.graph, and
# we are tracking node names at shape_env.name_to_node.
assert hasattr(shape_env, "name_to_node")
name_to_node = shape_env.name_to_node # type: ignore[attr-defined]
assert x.name in name_to_node
return name_to_node[x.name]
# Replaces the value of an specific argument by the result of fn.
def replacearg(index: int, key: str, fn: Callable):
if index < len(args):
args[index] = fn(args[index])
if key in kwargs:
kwargs[key] = fn(kwargs[key])
if self.is_create_fx_call_function():
# ShapeEnv.create_fx_call_function:
# "args" parameter is a tuple of FX nodes from the FX graph of the old ShapeEnv.
# They must be replaced, since a "call_function" FX node with this tuple as argument
# will be added to the FX graph of the new shape_env.
replacearg(
index=2,
key="args",
fn=lambda args: tuple(maybe_convert_node(a) for a in args),
)
if self.is_evaluate_expr() or self.is_defer_runtime_assert():
# ShapeEnv.evaluate_expr and ShapeEnv.defer_runtime_assert:
# "fx_node" parameter is an (optional) FX node that represents the evaluate expression.
# They must be replaced, since it will be part of a "call_function" FX node for
# torch._assert, which will be added to the FX graph of the new shape_env.
replacearg(index=3, key="fx_node", fn=maybe_convert_node)
# Actually call the method with the converted arguments.
return self.f(*args, **kwargs)
def __str__(self) -> str:
name = self.name if self.name is not None else self.f.__name__
return f"event: {name} ({self.args}, {self.kwargs})"
def is_create_fx_call_function(self) -> bool:
return self.name == "_create_fx_call_function"
def is_evaluate_expr(self) -> bool:
return self.name == "evaluate_expr"
def is_defer_runtime_assert(self) -> bool:
return self.name == "defer_runtime_assert"
# Current nesting depth of calls going through the record_shapeenv_event
# wrapper. It is incremented on entry and decremented on exit, and is only
# used to indent the trace_shape_events debug output.
NEST = 0
# Extracts a ShapeEnv instance inside args and kwargs.
# Specifically, it looks for:
# 1. ShapeEnv arguments
# 2. SymInt, SymFloat, or SymBool arguments
# If we find more than one object of any of the above types, we
# also check that the ShapeEnv instance is the same for all of them.
def _extract_shape_env_and_assert_equal(args, kwargs):
from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes
def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv:
if old is not None:
assert old is new, "call with different ShapeEnv"
return new
shape_env = None
for val in itertools.chain(args, kwargs.values()):
if isinstance(val, ShapeEnv):
shape_env = assert_equal(shape_env, val)
if isinstance(val, SymTypes) and is_symbolic(val):
shape_env = assert_equal(shape_env, val.node.shape_env)
return shape_env
# Decorator for recording the given function as a replayable event.
#
# This decorator should be used at every function that mutates the state of
# ShapeEnv in some way that affects the resulting issued guards (i.e. when
# ShapeEnv.produce_guards is called).
#
# save_tracked_fakes: saves a snapshot of the TrackedFake list.
# This is used when calling ShapeEnv.produce_guards at arbitrary points in time.
#
# When to save the list of TrackedFake?
# =====================================
# We should save the list of TrackedFake whenever the translation validation
# bisection may actually stop and call the produce_guards method at the moment
# right after the recorded function was played. In other words, since the
# bisection bisects through torch._assert calls, we should save in all methods
# that add a torch._assert call to the symbolic shapes FX graph.
#
# At the moment, there are 2 methods that save the list:
# - ShapeEnv.evaluate_expr
# - ShapeEnv.defer_runtime_assert
def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
    """Build a decorator that records calls to a ShapeEnv method as events.

    Args:
        save_tracked_fakes: when true, the result of the ShapeEnv's
            ``_snapshot_tracked_fakes()`` is stored on the recorded event.

    Returns:
        A decorator. It asserts that the decorated callable takes ``self`` as
        its first parameter and returns a wrapper that, when recording applies,
        appends a ShapeEnvEvent to ``events`` on the ShapeEnv before running it.
    """

    def decorator(fn: Callable) -> Callable:
        assert callable(fn)
        param_names = inspect.getfullargspec(fn).args
        assert param_names and param_names[0] == "self", (
            "record_shapeenv_event should only wrap methods on ShapeEnv; refactor your "
            "code so that it calls into a method on ShapeEnv"
        )
        name = fn.__name__

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            from torch.fx.experimental.symbolic_shapes import ShapeEnv

            assert isinstance(args[0], ShapeEnv)

            global NEST

            trace_shape_events_log.debug(
                "%scall %s(*%r, **%r)", " " * NEST, name, args[1:], kwargs
            )
            NEST += 1

            def retlog(r):
                # NEST was already incremented for this call, hence the -1 to
                # line the result up with its "call" line.
                trace_shape_events_log.debug("%s-> %s", " " * (NEST - 1), r)
                return r

            try:
                receiver = args[0]
                # Recording only starts on the outermost recorded call of a
                # ShapeEnv that has recording enabled.
                may_record = (
                    receiver.should_record_events and not receiver.is_recording  # type: ignore[has-type]
                )

                # Retrieve the ShapeEnv referenced by the arguments.
                # Assumption: the collection of args and kwargs may not
                # reference different ShapeEnv instances.
                #
                # NB: when not recording, we skip the check of whether all
                # ShapeEnv instances are equal, in favor of a faster dispatch.
                env = (
                    _extract_shape_env_and_assert_equal(args, kwargs)
                    if may_record
                    else None
                )

                # Either recording does not apply, or no ShapeEnv instance is
                # alive in the arguments: just call the original.
                if env is None:
                    return retlog(fn(*args, **kwargs))

                # Otherwise, start recording and call the function.
                with env._recording():
                    # Take a snapshot of the current tracked_fakes, if requested.
                    snapshot = (
                        env._snapshot_tracked_fakes() if save_tracked_fakes else None
                    )
                    event = ShapeEnvEvent(
                        fn, list(args), kwargs, snapshot, name=fn.__name__
                    )
                    # NB: the event has to be stored before it is run, because
                    # running it can trigger internal events that must be
                    # ordered after it. If running fails, the event must NOT
                    # stay in the list, so it is popped before re-raising.
                    env.events.append(event)
                    try:
                        return retlog(event.run(env))
                    except Exception:
                        env.events.pop()
                        raise

            except Exception:
                log.error(  # noqa: G201
                    "failed while running %s(*%s, **%s)",
                    name,
                    args[1:],
                    kwargs,
                    exc_info=log.isEnabledFor(logging.INFO),
                )
                raise

            finally:
                NEST -= 1

        return wrapper

    return decorator
# Replays the ShapeEnvEvents list.
# It assumes the first event is the constructor call.
#
# FX node arguments of the recorded events are remapped by ShapeEnvEvent.run
# to the nodes of the newly created ShapeEnv.
def replay_shape_env_events(events):
    """Rebuild a ShapeEnv by replaying ``events`` in order.

    The first event must be the ShapeEnv constructor call; its result is the
    instance every following event is run against.

    Returns:
        The object produced by the constructor event, after all remaining
        events have been run on it.

    Raises:
        Whatever a replayed event raises; the failing event is logged first.
    """
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    first = events[0]
    assert first.f == ShapeEnv

    # Running the constructor event creates the new ShapeEnv.
    rebuilt = first.run()

    for recorded in events[1:]:
        try:
            # Replay the event against the instance being rebuilt.
            recorded.run(rebuilt)
        except Exception:
            log.error("failed when running event: %s", recorded)
            raise

    return rebuilt
# FakeTensor metadata.
# This is to be used in place of FakeTensor placeholders when calling
# ShapeEnv.produce_guards.
@dataclass
class FakeTensorMeta:
    """Size, stride, storage offset and nestedness captured from a tensor-like.

    The accessor methods mirror the ones read by ``from_fake`` (``size``,
    ``stride``, ``storage_offset``), plus ``dim``, so that an instance can be
    queried the same way as the object it was built from.
    """

    tensor_size: tuple[Union[int, torch.SymInt], ...]
    tensor_stride: tuple[Union[int, torch.SymInt], ...]
    tensor_storage_offset: Union[int, torch.SymInt]
    is_nested: bool

    def size(self) -> tuple[Union[int, torch.SymInt], ...]:
        """Return the stored size."""
        return self.tensor_size

    def stride(self) -> tuple[Union[int, torch.SymInt], ...]:
        """Return the stored stride."""
        return self.tensor_stride

    def storage_offset(self) -> Union[int, torch.SymInt]:
        """Return the stored storage offset."""
        return self.tensor_storage_offset

    def dim(self) -> int:
        """Return the number of entries in the stored size."""
        return len(self.tensor_size)

    @staticmethod
    def from_fake(fake) -> "FakeTensorMeta":
        """Capture the metadata of ``fake`` into a new FakeTensorMeta."""
        return FakeTensorMeta(
            tensor_size=fake.size(),
            tensor_stride=fake.stride(),
            tensor_storage_offset=fake.storage_offset(),
            is_nested=fake.is_nested,
        )
# [Note: ShapeEnv State Equality]
# ===============================
#
# What is considered ShapeEnv state?
# ----------------------------------
# We consider to be the state of a ShapeEnv instance everything that
# is not in the inline tuple inside remove_nonstate_variables function.
# That is: the fields within ShapeEnv that modify the flow of execution
# of the program.
#
# So, for example: the replacements field might influence on how an
# expression is simplified. That, in turn, may result in a guard being
# statically known (i.e. not added).
#
# On the other hand, var_to_stack only changes what is printed
# on the screen, i.e. it is used only for debugging purposes. Therefore, we
# should not consider it when comparing states.
#
# What to do on NotEqualError?
# ----------------------------
# Here are a few possible causes for getting a NotEqualError raised:
#
# 1. New field that does not belong in the ShapeEnv state.
# For example: log field of type ShapeEnvLoggerAdapter. Different
# ShapeEnv instances will always have different ShapeEnvLoggerAdapter
# instances, i.e. equality comparison would fail.
# Solution: add it to the inlined tuple inside remove_nonstate_variables
# function inside check_equal method.
#
# 2. New field that is not directly comparable across instances.
# For example: guards field of type List[ShapeGuard]. More specifically,
# the ShapeGuard type holds an expression and a stack information
# for debugging purposes. When replaying the event on a new ShapeEnv
# instance, the stack would be different, which would trigger this error.
# Solution: add a special case to the map_value function inside
# check_equal function.
#
# 3. Mutation of ShapeEnv on some not recorded function.
# If a mutation of the state of ShapeEnv happens inside a function
# that is not recorded (or that no caller in the stack is recorded),
# then, the replayed ShapeEnv won't catch that.
# Solution: decorate the function with record_shapeenv_event.
# Checks whether the state of two ShapeEnv are equal w.r.t. the guards
# returned by ShapeEnv.produce_guards.
def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value):
    """Raise NotEqualError unless ``env1`` and ``env2`` hold the same state.

    Args:
        env1, env2: objects whose instance attributes (``vars``) are compared.
        non_state_variable_names: attribute names ignored on both sides.
        map_value: callable ``(name, value) -> comparable`` applied to every
            remaining attribute of both objects before they are compared.

    Returns:
        None when every compared attribute matches.

    Raises:
        NotEqualError: if the two objects do not have the same attribute
            names, or if any pair of mapped values differs.
    """
    # Collect the variables that represent the state. New dictionaries are
    # built so that the instances themselves are never modified.
    ignored = set(non_state_variable_names)
    env1_vars = {k: v for k, v in vars(env1).items() if k not in ignored}
    env2_vars = {k: v for k, v in vars(env2).items() if k not in ignored}

    # Function for transforming the mismatched values into string.
    # Needed, since dict and set entries order might not be the same every time.
    def value_to_str(value: Any) -> str:
        if isinstance(value, dict):
            return (
                "{"
                + ", ".join(f"{k}: {value[k]}" for k in sorted(value.keys(), key=str))
                + "}"
            )
        if isinstance(value, set):
            try:
                ordered = sorted(value)
            except TypeError:
                # Mixed or otherwise unorderable elements: fall back to ordering
                # by string form (as done for dict keys above), so that building
                # the report never hides the mismatch behind a TypeError.
                ordered = sorted(value, key=str)
            return "{" + ", ".join(f"{v}" for v in ordered) + "}"
        return str(value)

    # First, compare the set of keys in each vars dictionary.
    env1_set, env2_set = set(env1_vars), set(env2_vars)
    if env1_set != env2_set:
        raise NotEqualError(
            "field set mismatch:",
            [
                (
                    "found unique fields:",
                    str(sorted(env1_set - env2_set)),
                    str(sorted(env2_set - env1_set)),
                ),
            ],
        )

    # Then, in sorted key order, map the values of each key on both sides.
    # Mapping lets fields that are not directly comparable across instances
    # be normalized before the comparison.
    mapped = [
        (k, map_value(k, env1_vars[k]), map_value(k, env2_vars[k]))
        for k in sorted(env1_set)
    ]

    # Accumulate the mismatching fields alongside their mapped values.
    errors = [
        (f"{k}: values don't match.", value_to_str(val1), value_to_str(val2))
        for k, val1, val2 in mapped
        if val1 != val2
    ]
    if errors:
        raise NotEqualError("field values don't match:", errors)
class NotEqualError(Exception):
    """Error whose message reports a list of mismatching left/right values.

    Args:
        msg: summary appended to the "ShapeEnv not equal:" headline.
        mismatched: ``(description, left, right)`` string triples, each
            rendered as three lines of the message.
    """

    def __init__(
        self,
        msg: str,
        mismatched: list[tuple[str, str, str]],
    ) -> None:
        # One flat list of lines: three per mismatch, in the given order.
        report_lines: list[str] = []
        for inner_msg, str1, str2 in mismatched:
            report_lines.extend(
                (
                    f"==> {inner_msg}",
                    f" > Left: {str1}",
                    f" > Right: {str2}",
                )
            )
        details = "\n".join(report_lines)

        super().__init__(
            f"""\
ShapeEnv not equal: {msg}
{details}
"""
        )