mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144553 Approved by: https://github.com/ezyang ghstack dependencies: #144551
530 lines
19 KiB
Python
530 lines
19 KiB
Python
# mypy: allow-untyped-defs
|
|
import functools
|
|
import inspect
|
|
import itertools
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable, Optional, Union
|
|
|
|
import torch
|
|
import torch.utils._pytree as pytree
|
|
|
|
|
|
log = logging.getLogger(__name__)
|
|
trace_shape_events_log = torch._logging.getArtifactLogger(
|
|
__name__, "trace_shape_events"
|
|
)
|
|
|
|
|
|
__all__ = [
|
|
"ShapeEnvEvent",
|
|
"record_shapeenv_event",
|
|
"replay_shape_env_events",
|
|
"FakeTensorMeta",
|
|
"shape_env_check_state_equal",
|
|
"NotEqualError",
|
|
]
|
|
|
|
# [Note: Recording ShapeEnv Events]
|
|
# =================================
|
|
#
|
|
# What is a ShapeEnv event?
|
|
# -------------------------
|
|
# We consider a ShapeEnv event every function call (ShapeEnv method or
|
|
# independent function) that modifies the state of the ShapeEnv instance.
|
|
# Such calls are recorded alongside their positional and keyword arguments,
|
|
# so that it may be replayed over a different ShapeEnv instance.
|
|
#
|
|
# See [Note: ShapeEnv State Equality] for what is considered the state
|
|
# of a ShapeEnv instance.
|
|
#
|
|
# What is it for?
|
|
# ---------------
|
|
# ShapeEnv events recording is used for reconstructing the ShapeEnv in an
|
|
# arbitrary state in time.
|
|
#
|
|
# Being able to arbitrarily replay events like so is useful, mainly for
|
|
# translation validation bisection. i.e. if a ValidationException has been
|
|
# raised, find the earliest point in time where the translation validation
|
|
# fails.
|
|
#
|
|
# Besides that, it also allows us to inspect the given instance and,
|
|
# for example, check the guards that would actually be issued at that point.
|
|
#
|
|
# What kind of arguments can be stored in an event?
|
|
# -------------------------------------------------
|
|
# There's no specific rule for what cannot be used as an argument.
|
|
# That said, pay special attention to the following cases:
|
|
#
|
|
# 1. Tensor inputs: there are some tests that check whether the inputs
|
|
# were garbage collected after execution. These will fail if there's
|
|
# an event that is holding a reference to those inputs.
|
|
#
|
|
# 2. ShapeEnv arguments: if there is an argument of ShapeEnv type, that
|
|
# will be automatically replaced by the new given ShapeEnv instance.
|
|
#
|
|
# 3. SymTypes arguments: they also hold references to ShapeEnv. So,
|
|
# whenever we see them, we create a new instance, replacing the
|
|
# ShapeEnv reference.
|
|
#
|
|
# 4. FX nodes: specifically, FX nodes from the FX graph for symbolic
|
|
# shapes. That argument must be replaced when replaying the event at
|
|
# ShapeEnvEvent.run, since it has to reference a node from the given
|
|
# instance, and not from the recorded instance.
|
|
|
|
|
|
# Event class for reconstructing ShapeEnv at arbitrary time.
|
|
#
|
|
# Represents a method call that mutates ShapeEnv in a way that affects the
|
|
# issued guards, when ShapeEnv.produce_guards is called.
|
|
@dataclass
|
|
class ShapeEnvEvent:
|
|
# ShapeEnv method.
|
|
f: Callable
|
|
|
|
# Arguments and keyword arguments called with.
|
|
args: Optional[list[Any]] = None
|
|
kwargs: Optional[dict[str, Any]] = None
|
|
|
|
# List of tracked_fakes at the time the method was called.
|
|
tracked_fakes: Optional[list[Any]] = None
|
|
|
|
# Name of the captured event.
|
|
# Used for special handling of particular methods.
|
|
name: Optional[str] = None
|
|
|
|
# Replay itself, but using shape_env as self.
|
|
def run(self, shape_env=None) -> Any:
|
|
from torch.fx.experimental.symbolic_shapes import (
|
|
is_symbolic,
|
|
ShapeEnv,
|
|
SymTypes,
|
|
)
|
|
|
|
# Special handling for the constructor event.
|
|
if self.f is ShapeEnv:
|
|
assert shape_env is None and self.args is None and self.kwargs is not None
|
|
return ShapeEnv(**self.kwargs)
|
|
|
|
assert shape_env is not None
|
|
args = list(self.args or [])
|
|
kwargs = dict(self.kwargs or {})
|
|
|
|
# Replace any argument of type ShapeEnv by the given one.
|
|
args, kwargs = pytree.tree_map_only(
|
|
ShapeEnv, lambda _: shape_env, (args, kwargs)
|
|
)
|
|
|
|
# Replace any argument of type SymTypes by a new instance,
|
|
# replacing its ShapeEnv reference.
|
|
args, kwargs = pytree.tree_map_only(
|
|
lambda x: isinstance(x, SymTypes) and is_symbolic(x),
|
|
lambda a: type(a)(a.node.with_shape_env(shape_env)),
|
|
(args, kwargs),
|
|
)
|
|
|
|
# Converts FX nodes using the mapping argument.
|
|
def maybe_convert_node(x: Any) -> Any:
|
|
if not isinstance(x, torch.fx.Node):
|
|
# Don't do anything to x if it's not an FX node.
|
|
return x
|
|
|
|
# If, at some point, we created an FX node, it means that translation validation is on.
|
|
# It also means we are building an FX graph for symbolic shapes at shape_env.graph, and
|
|
# we are tracking node names at shape_env.name_to_node.
|
|
assert hasattr(shape_env, "name_to_node")
|
|
name_to_node = shape_env.name_to_node # type: ignore[attr-defined]
|
|
assert x.name in name_to_node
|
|
return name_to_node[x.name]
|
|
|
|
# Replaces the value of an specific argument by the result of fn.
|
|
def replacearg(index: int, key: str, fn: Callable):
|
|
if index < len(args):
|
|
args[index] = fn(args[index])
|
|
if key in kwargs:
|
|
kwargs[key] = fn(kwargs[key])
|
|
|
|
if self.is_create_fx_call_function():
|
|
# ShapeEnv.create_fx_call_function:
|
|
# "args" parameter is a tuple of FX nodes from the FX graph of the old ShapeEnv.
|
|
# They must be replaced, since a "call_function" FX node with this tuple as argument
|
|
# will be added to the FX graph of the new shape_env.
|
|
replacearg(
|
|
index=2,
|
|
key="args",
|
|
fn=lambda args: tuple(maybe_convert_node(a) for a in args),
|
|
)
|
|
if self.is_evaluate_expr() or self.is_defer_runtime_assert():
|
|
# ShapeEnv.evaluate_expr and ShapeEnv.defer_runtime_assert:
|
|
# "fx_node" parameter is an (optional) FX node that represents the evaluate expression.
|
|
# They must be replaced, since it will be part of a "call_function" FX node for
|
|
# torch._assert, which will be added to the FX graph of the new shape_env.
|
|
replacearg(index=3, key="fx_node", fn=maybe_convert_node)
|
|
|
|
# Actually call the method with the converted arguments.
|
|
return self.f(*args, **kwargs)
|
|
|
|
def __str__(self) -> str:
|
|
name = self.name if self.name is not None else self.f.__name__
|
|
return f"event: {name} ({self.args}, {self.kwargs})"
|
|
|
|
def is_create_fx_call_function(self) -> bool:
|
|
return self.name == "_create_fx_call_function"
|
|
|
|
def is_evaluate_expr(self) -> bool:
|
|
return self.name == "evaluate_expr"
|
|
|
|
def is_defer_runtime_assert(self) -> bool:
|
|
return self.name == "defer_runtime_assert"
|
|
|
|
|
|
NEST = 0
|
|
|
|
|
|
# Extracts a ShapeEnv instance inside args and kwargs.
|
|
# Specifically, it looks for:
|
|
# 1. ShapeEnv arguments
|
|
# 2. SymInt, SymFloat, or SymBool arguments
|
|
# If we find more than one object of any of the above types, we
|
|
# also check that the ShapeEnv instance is the same for all of them.
|
|
def _extract_shape_env_and_assert_equal(args, kwargs):
|
|
from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes
|
|
|
|
def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv:
|
|
if old is not None:
|
|
assert old is new, "call with different ShapeEnv"
|
|
return new
|
|
|
|
shape_env = None
|
|
for val in itertools.chain(args, kwargs.values()):
|
|
if isinstance(val, ShapeEnv):
|
|
shape_env = assert_equal(shape_env, val)
|
|
if isinstance(val, SymTypes) and is_symbolic(val):
|
|
shape_env = assert_equal(shape_env, val.node.shape_env)
|
|
|
|
return shape_env
|
|
|
|
|
|
# Decorator for recording the given function as a replayable event.
|
|
#
|
|
# This decorator should be used at every function that mutates the state of
|
|
# ShapeEnv in some way that affects the resulting issued guards (i.e. when
|
|
# ShapeEnv.produce_guards is called).
|
|
#
|
|
# save_tracked_fakes: saves a snapshot of the TrackedFake list.
|
|
# This is used when calling ShapeEnv.produce_guards at arbitrary points in time.
|
|
#
|
|
# name: the name of the function being recorded. Normally (and by default) this
|
|
# is taken from the decorated function but can be set if you need to override
|
|
# it.
|
|
#
|
|
# When to save the list of TrackedFake?
|
|
# =====================================
|
|
# We should save the list of TrackedFake whenever the translation validation
|
|
# bisection may actually stop and call the produce_guards method at the moment
|
|
# right after the recorded function was played. In other words, since the
|
|
# bisection bisects through torch._assert calls, we should save in all methods
|
|
# that adds a torch._assert call to the symbolic shapes FX graph.
|
|
#
|
|
# At the moment, there are 2 methods that save the list:
|
|
# - ShapeEnv.evaluate_expr
|
|
# - ShapeEnv.defer_runtime_assert
|
|
def record_shapeenv_event(
|
|
*, save_tracked_fakes: bool = False, name: Optional[str] = None
|
|
) -> Callable:
|
|
def decorator(fn: Callable) -> Callable:
|
|
assert callable(fn)
|
|
args = inspect.getfullargspec(fn).args
|
|
assert args and args[0] == "self", (
|
|
"record_shapeenv_event should only wrap methods on ShapeEnv; refactor your "
|
|
"code so that it calls into a method on ShapeEnv"
|
|
)
|
|
nonlocal name
|
|
if name is None:
|
|
name = fn.__name__
|
|
|
|
@functools.wraps(fn)
|
|
def wrapper(*args, **kwargs):
|
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
|
|
|
assert isinstance(args[0], ShapeEnv)
|
|
|
|
global NEST
|
|
|
|
trace_shape_events_log.debug(
|
|
"%scall %s(*%r, **%r)", " " * NEST, name, args[1:], kwargs
|
|
)
|
|
NEST += 1
|
|
|
|
def retlog(r):
|
|
trace_shape_events_log.debug("%s-> %s", " " * (NEST - 1), r)
|
|
return r
|
|
|
|
shape_env = args[0]
|
|
|
|
try:
|
|
if not shape_env.should_record_events or shape_env.is_recording: # type: ignore[has-type]
|
|
# If ShapeEnv is already recording an event, call the wrapped
|
|
# function directly.
|
|
#
|
|
# NB: here, we skip the check of whether all ShapeEnv instances
|
|
# are equal, in favor of a faster dispatch.
|
|
return retlog(fn(*args, **kwargs))
|
|
|
|
# Retrieve an instance of ShapeEnv.
|
|
# Assumption: the collection of args and kwargs may not reference
|
|
# different ShapeEnv instances.
|
|
self = _extract_shape_env_and_assert_equal(args, kwargs)
|
|
|
|
# If we are calling this function without any ShapeEnv instance
|
|
# alive in its arguments, we don't record and call the original.
|
|
if self is None:
|
|
return retlog(fn(*args, **kwargs))
|
|
|
|
# Otherwise, start recording and call the function.
|
|
with self._recording():
|
|
# Take a snapshot of the current tracked_fakes.
|
|
tracked_fakes = (
|
|
self._snapshot_tracked_fakes() if save_tracked_fakes else None
|
|
)
|
|
# Record the event for 'fn'.
|
|
event = ShapeEnvEvent(
|
|
fn,
|
|
list(args),
|
|
kwargs,
|
|
tracked_fakes,
|
|
name=name,
|
|
)
|
|
# Play the event on this ShapeEnv.
|
|
# NB: It's important to put the event first, because running
|
|
# the event can trigger internal events that must be ordered
|
|
# after this event. However, if an exception happens, we do
|
|
# NOT want to have the event in the list, so pop it off from
|
|
# the record if an error happened
|
|
self.events.append(event)
|
|
try:
|
|
return retlog(event.run(self))
|
|
except Exception:
|
|
self.events.pop()
|
|
raise
|
|
|
|
except Exception:
|
|
if not shape_env.should_record_events or shape_env.is_recording:
|
|
# If ShapeEnv is disabled or already recording an event, re-raise the exception without logging.
|
|
raise
|
|
log.error( # noqa: G201
|
|
"failed while running %s(*%s, **%s)",
|
|
name,
|
|
args[1:],
|
|
kwargs,
|
|
exc_info=log.isEnabledFor(logging.INFO),
|
|
)
|
|
raise
|
|
|
|
finally:
|
|
NEST -= 1
|
|
|
|
return wrapper
|
|
|
|
return decorator
|
|
|
|
|
|
# Replays the ShapeEnvEvents list.
|
|
# It assumes the first event is the constructor call.
|
|
#
|
|
# fn: transforms an old FX node into one corresponding to the newly created ShapeEnv.
|
|
def replay_shape_env_events(events):
|
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
|
|
|
constructor_event = events[0]
|
|
assert constructor_event.f == ShapeEnv
|
|
|
|
# Constructs the new ShapeEnv.
|
|
shape_env = constructor_event.run()
|
|
|
|
for event in events[1:]:
|
|
try:
|
|
# Actually replays each event.
|
|
# We need to call create_mapping_fn every time, since the node list might
|
|
# change after each event is replayed.
|
|
event.run(shape_env)
|
|
except Exception:
|
|
log.error("failed when running event: %s", event)
|
|
raise
|
|
|
|
return shape_env
|
|
|
|
|
|
# FakeTensor metadata.
|
|
# This is to be used in place of FakeTensor placeholders when calling
|
|
# ShapeEnv.produce_guards.
|
|
@dataclass
|
|
class FakeTensorMeta:
|
|
tensor_size: tuple[Union[int, torch.SymInt], ...]
|
|
tensor_stride: tuple[Union[int, torch.SymInt], ...]
|
|
tensor_storage_offset: Union[int, torch.SymInt]
|
|
is_nested: bool
|
|
|
|
def size(self) -> tuple[Union[int, torch.SymInt], ...]:
|
|
return self.tensor_size
|
|
|
|
def stride(self) -> tuple[Union[int, torch.SymInt], ...]:
|
|
return self.tensor_stride
|
|
|
|
def storage_offset(self) -> Union[int, torch.SymInt]:
|
|
return self.tensor_storage_offset
|
|
|
|
def dim(self) -> int:
|
|
return len(self.tensor_size)
|
|
|
|
@staticmethod
|
|
def from_fake(fake) -> "FakeTensorMeta":
|
|
return FakeTensorMeta(
|
|
fake.size(), fake.stride(), fake.storage_offset(), fake.is_nested
|
|
)
|
|
|
|
|
|
# [Note: ShapeEnv State Equality]
|
|
# ===============================
|
|
#
|
|
# What is considered ShapeEnv state?
|
|
# ----------------------------------
|
|
# We consider to be the state of a ShapeEnv instance everything that
|
|
# is not in the inline tuple inside remove_nonstate_variables function.
|
|
# That is: the fields within ShapeEnv that modify the flow of execution
|
|
# of the program.
|
|
#
|
|
# So, for example: the replacements field might influence on how an
|
|
# expression is simplified. That, in turn, may result in a guard being
|
|
# statically known (i.e. not added).
|
|
#
|
|
# On the other hand, var_to_stack serves only changes what is printed
|
|
# in the screen, i.e. used only for debugging purposes. Therefore, we
|
|
# should not consider it when comparing states.
|
|
#
|
|
# What to do on NotEqualError?
|
|
# ----------------------------
|
|
# Here are a few possible causes for getting a NotEqualError raised:
|
|
#
|
|
# 1. New field that does not belong in the ShapeEnv state.
|
|
# For example: log field of type ShapeEnvLoggerAdapter. Different
|
|
# ShapeEnv instances will always have different ShapeEnvLoggerAdapter
|
|
# instances, i.e. equality comparison would fail.
|
|
# Solution: add it to the inlined tuple inside remove_nonstate_variables
|
|
# function inside check_equal method.
|
|
#
|
|
# 2. New field that is not directly comparable across instances.
|
|
# For example: guards field of type List[ShapeGuard]. More specifically,
|
|
# the ShapeGuard type holds an expression and a stack information
|
|
# for debugging purposes. When replaying the even on a new ShapeEnv
|
|
# instance, the stack would be different, which would trigger this error.
|
|
# Solution: add a special case to the map_value function inside
|
|
# check_equal function.
|
|
#
|
|
# 3. Mutation of ShapeEnv on some not recorded function.
|
|
# If a mutation of the state of ShapeEnv happens inside a function
|
|
# that is not recorded (or that no caller in the stack is recorded),
|
|
# then, the replayed ShapeEnv won't catch that.
|
|
# Solution: decorate the function with record_shape_env_event.
|
|
|
|
|
|
# Checks whether the state of two ShapeEnv are equal w.r.t. the guards
|
|
# returned by ShapeEnv.produce_guards.
|
|
def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value):
|
|
# Collect and remove variables that don't necessarily represent the state
|
|
# of a ShapeEnv. Note: we copy the dictionary so that we don't modify the
|
|
# instance itself.
|
|
env1_vars = vars(env1).copy()
|
|
env2_vars = vars(env2).copy()
|
|
|
|
for v in non_state_variable_names:
|
|
if v in env1_vars:
|
|
env1_vars.pop(v)
|
|
if v in env2_vars:
|
|
env2_vars.pop(v)
|
|
|
|
# Function for transforming the mismatched values into string.
|
|
# Needed, since dict and set entries order might not be the same every time.
|
|
def value_to_str(value: Any) -> str:
|
|
if isinstance(value, dict):
|
|
return (
|
|
"{"
|
|
+ ", ".join(f"{k}: {value[k]}" for k in sorted(value.keys(), key=str))
|
|
+ "}"
|
|
)
|
|
if isinstance(value, set):
|
|
return "{" + ", ".join(f"{v}" for v in sorted(value)) + "}"
|
|
return str(value)
|
|
|
|
# Compares env1_vars with env2_vars.
|
|
# Here, we allow the value of each field to be mapped, so that we appropriately
|
|
# compare the two values.
|
|
def compare_vars(
|
|
map_value: Callable[[str, Any], Any],
|
|
) -> list[tuple[str, str, str]]:
|
|
env1_set, env2_set = set(env1_vars), set(env2_vars)
|
|
|
|
# First, compare the set of keys in each vars dictionary.
|
|
if env1_set != env2_set:
|
|
raise NotEqualError(
|
|
"field set mismatch:",
|
|
[
|
|
(
|
|
"found unique fields:",
|
|
str(sorted(env1_set - env2_set)),
|
|
str(sorted(env2_set - env1_set)),
|
|
),
|
|
],
|
|
)
|
|
|
|
# Then, sort the keys, and compare the mapped values of each key.
|
|
sorted_keys = list(env1_set)
|
|
sorted_keys.sort()
|
|
|
|
mapped_dict = [
|
|
(k, map_value(k, env1_vars[k]), map_value(k, env2_vars[k]))
|
|
for k in sorted_keys
|
|
]
|
|
|
|
# Return a list of tuples representing the fields that did not match
|
|
# alongside their respective mapped values.
|
|
return [
|
|
(f"{k}: values don't match.", value_to_str(val1), value_to_str(val2))
|
|
for k, val1, val2 in mapped_dict
|
|
if val1 != val2
|
|
]
|
|
|
|
# Accumulate the mismatching fields.
|
|
errors = compare_vars(map_value)
|
|
|
|
if len(errors) > 0:
|
|
raise NotEqualError("field values don't match:", errors)
|
|
|
|
|
|
class NotEqualError(Exception):
|
|
def __init__(
|
|
self,
|
|
msg: str,
|
|
mismatched: list[tuple[str, str, str]],
|
|
) -> None:
|
|
details = "\n".join(
|
|
[
|
|
"\n".join(
|
|
[
|
|
f"==> {inner_msg}",
|
|
f" > Left: {str1}",
|
|
f" > Right: {str2}",
|
|
]
|
|
)
|
|
for inner_msg, str1, str2 in mismatched
|
|
]
|
|
)
|
|
|
|
super().__init__(
|
|
f"""\
|
|
ShapeEnv not equal: {msg}
|
|
|
|
{details}
|
|
"""
|
|
)
|