mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This word appears often in class descriptions and is not consistently spelled. Update comments and some function names to use the correct spelling consistently. Facilitates searching the codebase. Pull Request resolved: https://github.com/pytorch/pytorch/pull/155944 Approved by: https://github.com/Skylion007
764 lines
27 KiB
Python
764 lines
27 KiB
Python
# mypy: ignore-errors
|
|
|
|
"""TorchDynamo support for __torch_function__ tensor subclasses.
|
|
|
|
This module implements support for tensor subclasses with __torch_function__ overrides.
|
|
A tensor subclass instance is represented as a TensorWithTFOverrideVariable, which handles
|
|
dispatching __torch_function__ on attribute accesses, method calls, and torch API calls.
|
|
|
|
Unsupported features:
|
|
- Triggering __torch_function__ on tensor subclass non-tensor custom attributes
|
|
- Graph breaking on mutating guardable tensor properties within a __torch_function__ context
|
|
(can cause excessive recompiles in certain cases)
|
|
- Matching exact eager behavior of ignoring __torch_function__ objects in non-tensor
|
|
argument positions of Torch API calls
|
|
|
|
Supported features:
|
|
- Static method implementations of __torch_function__ on custom objects (triggers on torch
|
|
API calls with the object as any argument)
|
|
- Triggering __torch_function__ on torch API calls with tensor subclass arguments
|
|
- __torch_function__ calls on base tensor attribute access and method calls for tensor
|
|
subclass instances
|
|
- Matches dispatch ordering behavior of eager __torch_function__ with subclass/object
|
|
arguments in any position
|
|
|
|
See https://docs.google.com/document/d/1WBxBSvW3NXhRp9ncmtokJloMLCtF4AYNhJaffvHe8Kw/edit#heading=h.vacn73lozd9w
|
|
for more information on the design.
|
|
"""
|
|
|
|
import collections
|
|
import contextlib
|
|
import functools
|
|
import inspect
|
|
import operator
|
|
from typing import TYPE_CHECKING
|
|
|
|
import torch._C
|
|
import torch.utils._pytree as pytree
|
|
from torch._guards import Source
|
|
from torch.overrides import (
|
|
_get_overloaded_args,
|
|
BaseTorchFunctionMode,
|
|
get_default_nowrap_functions,
|
|
TorchFunctionMode,
|
|
)
|
|
from torch.utils._device import DeviceContext
|
|
|
|
from .. import graph_break_hints
|
|
from ..exc import unimplemented_v2
|
|
from ..guards import GuardBuilder, install_guard
|
|
from ..polyfills import NoEnterTorchFunctionMode
|
|
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
|
|
from ..utils import (
|
|
class_has_getattribute,
|
|
clear_torch_function_mode_stack,
|
|
get_safe_global_name,
|
|
has_torch_function,
|
|
is_tensor_base_attr_getter,
|
|
set_torch_function_mode_stack,
|
|
)
|
|
from .base import VariableTracker
|
|
from .constant import ConstantVariable
|
|
from .ctx_manager import GenericContextWrappingVariable
|
|
from .functions import UserMethodVariable
|
|
from .lazy import LazyVariableTracker
|
|
from .lists import TupleVariable
|
|
from .tensor import TensorSubclassVariable, TensorVariable
|
|
from .user_defined import UserDefinedObjectVariable
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from torch._dynamo.codegen import PyCodegen
|
|
from torch._dynamo.symbolic_convert import InstructionTranslator
|
|
|
|
|
|
# Operator tables consumed by populate_builtin_to_tensor_fn_map. Each table
# groups the builtin operators that are probed with the same kind of sample
# inputs. The order of the entries determines the insertion order of the maps
# defined at the bottom of this section.

# Binary operators (and their in-place forms) probed with two tensors created
# by torch.ones(1).
bin_ops = [
    operator.pow,
    operator.mul,
    operator.matmul,
    operator.floordiv,
    operator.truediv,
    operator.mod,
    operator.add,
    operator.lt,
    operator.gt,
    operator.ge,
    operator.le,
    operator.ne,
    operator.eq,
    operator.sub,
    operator.ipow,
    operator.imul,
    operator.imatmul,
    operator.ifloordiv,
    operator.itruediv,
    operator.imod,
    operator.iadd,
    operator.isub,
]

# Bitwise binary operators (and their in-place forms) probed with two int32
# tensors.
bin_int_ops = [
    operator.and_,
    operator.or_,
    operator.xor,
    operator.iand,
    operator.ixor,
    operator.ior,
]

# Unary operators probed with a single int32 tensor.
un_int_ops = [operator.invert]

# Operators probed with an int32 tensor and a plain Python int.
tensor_and_int_ops = [
    operator.lshift,
    operator.rshift,
    operator.ilshift,
    operator.irshift,
    operator.getitem,
]

# Unary operators probed with a single tensor created by torch.ones(1).
un_ops = [
    operator.abs,
    operator.pos,
    operator.neg,
    operator.not_,  # Note: this has a local scalar dense call
    operator.length_hint,
]

# Builtin operator -> function observed by a torch function mode when the
# operator is applied with a tensor as the first operand. Filled in by
# populate_builtin_to_tensor_fn_map.
BUILTIN_TO_TENSOR_FN_MAP = {}

# These functions represent the r* versions of the above ops
# Basically, if __add__(1, Tensor) is called, it is translated
# to __radd__(Tensor, 1).
# In the builtin var, we check if there is a tensor in the first args position,
# if not, we swap the args and use the r* version of the op.
BUILTIN_TO_TENSOR_RFN_MAP = {}
|
|
|
|
|
|
def populate_builtin_to_tensor_fn_map():
    """Fill BUILTIN_TO_TENSOR_FN_MAP and BUILTIN_TO_TENSOR_RFN_MAP.

    Every operator in the op tables is executed once on small sample inputs
    while a torch function mode is active. The mode records the function it
    was invoked with, and that function is stored as the tensor-level
    counterpart of the builtin operator. A second pass repeats the probing
    with a Python int as the first operand to discover the reflected (r*)
    functions.
    """
    # The map is only mutated through item assignment below, never rebound.
    global BUILTIN_TO_TENSOR_FN_MAP

    # The function most recently passed to GetMethodMode.__torch_function__.
    most_recent_func = None

    class GetMethodMode(BaseTorchFunctionMode):
        """
        Mode to extract the correct methods from torch function invocations
        (Used to get the correct torch.Tensor methods from builtins)
        """

        def __torch_function__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            nonlocal most_recent_func
            # Remember what was dispatched, then run it unchanged.
            most_recent_func = func
            return func(*args, **kwargs)

    # Sample operands: two tensors from torch.ones(1) and two int32 tensors.
    inp0 = torch.ones(1)
    inp1 = torch.ones(1)
    inp0_int = torch.ones(1, dtype=torch.int32)
    inp1_int = torch.ones(1, dtype=torch.int32)
    with GetMethodMode():
        # Each entry pairs "how to invoke an operator" with the table of
        # operators that should be invoked that way.
        setups_and_oplists = [
            (lambda o: o(inp0), un_ops),
            (lambda o: o(inp0_int), un_int_ops),
            (lambda o: o(inp0, inp1), bin_ops),
            (lambda o: o(inp0_int, inp1_int), bin_int_ops),
            (lambda o: o(inp0_int, 0), tensor_and_int_ops),
        ]
        for setup_fn, op_list in setups_and_oplists:
            for op in op_list:
                setup_fn(op)
                # NOTE(review): most_recent_func is never reset between
                # probes, so this assert can only fail for the very first
                # operator; a later probe that did not reach the mode would
                # silently reuse the previous operator's function.
                assert most_recent_func is not None
                BUILTIN_TO_TENSOR_FN_MAP[op] = most_recent_func

        # gather the reverse functions
        rsetups_and_oplists = [
            (
                lambda o: o(1, inp1),
                bin_ops,
            ),  # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int))
            (lambda o: o(1, inp1_int), bin_int_ops),
            (lambda o: o(0, inp0_int), tensor_and_int_ops),
        ]

        # Operators that are not probed with a non-tensor first operand.
        rskips = {operator.matmul, operator.imatmul, operator.getitem}
        for setup_fn, op_list in rsetups_and_oplists:
            for op in op_list:
                if op in rskips:
                    continue
                setup_fn(op)
                assert most_recent_func is not None
                # Only record a reflected entry when it differs from the
                # forward function captured above.
                if most_recent_func != BUILTIN_TO_TENSOR_FN_MAP[op]:
                    BUILTIN_TO_TENSOR_RFN_MAP[op] = most_recent_func


# Populate both maps once, when this module is imported.
populate_builtin_to_tensor_fn_map()
|
|
|
|
# Attribute names taken from the default "no wrap" functions that
# is_tensor_base_attr_getter accepts. TensorWithTFOverrideVariable.var_getattr
# graph-breaks when one of these names is accessed.
banned_attrs = [
    getter.__self__.__name__
    for getter in get_default_nowrap_functions()
    if is_tensor_base_attr_getter(getter)
]
|
|
|
|
|
|
@functools.cache
def get_prev_stack_var_name():
    """Return the variable name used for the previous torch function mode stack.

    The name is produced by ``unique_id`` from the
    ``___prev_torch_function_mode_stack`` prefix. Because the function takes
    no arguments and is wrapped in ``functools.cache``, the name is generated
    on the first call and the same string is returned afterwards.
    """
    # Imported lazily, at call time rather than at module import time.
    from ..bytecode_transformation import unique_id

    stack_var_name = unique_id("___prev_torch_function_mode_stack")
    return stack_var_name
|
|
|
|
|
|
class TorchFunctionModeStackStateManager:
    """Clears and restores the Python torch function mode stack.

    Used as a context manager: entering saves the current mode stack and
    clears it, exiting reinstalls the saved stack. ``temp_restore_stack``
    temporarily reinstalls the saved stack while the manager is active.
    """

    def __init__(self):
        # Modes captured by __enter__; empty whenever the manager is inactive.
        self.stack = []

    def __enter__(self):
        """Save the active mode stack, then clear it. Returns None."""
        captured = torch.overrides._get_current_function_mode_stack()
        self.stack = captured
        clear_torch_function_mode_stack()

    def __exit__(self, exc_type, exc_value, traceback):
        """Reinstall the saved mode stack and forget it."""
        set_torch_function_mode_stack(self.stack)
        self.stack = []

    @contextlib.contextmanager
    def temp_restore_stack(self):
        """Context manager that reinstalls the saved stack for its duration.

        Whatever stack was active on entry is put back on exit, even if the
        body raises.
        """
        active_on_entry = torch.overrides._get_current_function_mode_stack()
        set_torch_function_mode_stack(self.stack)
        try:
            yield
        finally:
            set_torch_function_mode_stack(active_on_entry)


# Module-level singleton instance of the manager.
torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()
|
|
|
|
|
|
class SymbolicTorchFunctionState:
    """Symbolic snapshot of the torch function enablement flags and mode stack."""

    def __init__(self, py_stack):
        """Capture the enablement flags and wrap each entry of ``py_stack``.

        Args:
            py_stack: iterable of mode objects; entry ``i`` is wrapped lazily
                with a ``TorchFunctionModeStackSource(i)`` source.
        """
        # The torch function subclass + mode C API is awkward to reason
        # about. Two C knobs are exposed as contexts:
        # torch._C.DisableTorchFunction and
        # torch._C.DisableTorchFunctionSubclass.
        #   1) torch._C._is_torch_function_enabled is True only when neither
        #      knob has been entered.
        #   2) torch._C._is_torch_function_mode_enabled bundles the mode
        #      stack length together with whether
        #      torch._C.DisableTorchFunction has been entered.
        # To disambiguate, a separate C API reports whether all torch
        # function concepts (modes and subclasses) are disabled, i.e. whether
        # torch._C.DisableTorchFunction has been entered. That lets us keep
        # the stack length separate from the enablement state, so when a mode
        # is pushed during tracing we know whether modes are enabled and
        # whether it should be traced.
        self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled()

        # Unlike the C API of the same name, this flag is False iff
        # torch._C.DisableTorchFunction has been entered; it ignores the mode
        # stack length.
        everything_disabled = torch._C._is_torch_function_all_disabled()
        self.torch_function_mode_enabled = not everything_disabled

        # Mode currently being inlined (see _pop_mode_for_inlining).
        self.cur_mode = None

        TorchFunctionModeStackVariable.reset()

        self.mode_stack: collections.deque[TorchFunctionModeVariable] = (
            collections.deque(
                LazyVariableTracker.create(
                    mode, source=TorchFunctionModeStackSource(index)
                )
                for index, mode in enumerate(py_stack)
            )
        )

    def in_torch_function_mode(self):
        """Return True when at least one mode is on the symbolic stack."""
        return bool(self.mode_stack)

    def pop_torch_function_mode(self):
        """Remove and return the top (right-most) mode."""
        top = self.mode_stack.pop()
        return top

    def push_torch_function_mode(self, mode_var):
        """Push ``mode_var`` onto the top of the symbolic stack."""
        self.mode_stack.append(mode_var)

    def call_torch_function_mode(self, tx, fn, types, args, kwargs):
        """Call the top mode's torch function handler with that mode popped."""
        with self._pop_mode_for_inlining() as active_mode:
            return active_mode.call_torch_function(tx, fn, types, args, kwargs)

    @contextlib.contextmanager
    def _pop_mode_for_inlining(self):
        """Pop the top mode for the duration of the ``with`` body.

        The popped mode is exposed as ``cur_mode`` and yielded. On exit the
        previous ``cur_mode`` is restored and the mode is pushed back.
        """
        outer_mode = self.cur_mode
        self.cur_mode = self.pop_torch_function_mode()
        try:
            yield self.cur_mode
        finally:
            inlined_mode, self.cur_mode = self.cur_mode, outer_mode
            self.push_torch_function_mode(inlined_mode)
|
|
|
|
|
|
class TorchFunctionModeStackVariable(VariableTracker):
    """Placeholder VT that signals a mutation of the torch function mode stack."""

    # Single shared key standing in for the global torch function mode stack
    # (the real stack lives in C++).
    stack_value_singleton = object()

    # Tracks whether a device context has been inserted or removed. A device
    # context always sits at the bottom of the stack. If one is inserted, the
    # graph performs that mutation, so the indices of the other modes must be
    # shifted right by one (+1) when they are reconstructed. Conversely, if a
    # device context was on the stack and the graph removes it (default
    # device set to None), the other indices shift left by one (-1).
    offset = 0

    def __init__(self, source, symbolic_stack):
        self.source = source
        self.symbolic_stack = symbolic_stack

    @classmethod
    def reset(cls):
        """Reset the device-context index offset to zero."""
        cls.offset = 0

    @classmethod
    def register_mutation(cls, tx: "InstructionTranslator"):
        """Record a mode stack mutation in ``tx.output.side_effects``.

        The tracking variable is created the first time a mutation is
        registered; later calls are no-ops.
        """
        side_effects = tx.output.side_effects
        if cls.stack_value_singleton in side_effects:
            return

        stack_var = cls(
            source=Source(),
            symbolic_stack=tx.symbolic_torch_function_state.mode_stack,
        )
        side_effects.track_mutable(cls.stack_value_singleton, stack_var)
        side_effects.mutation(stack_var)

    @classmethod
    def register_device_context_insertion(cls, tx: "InstructionTranslator"):
        """Insert a placeholder device context at the bottom of the stack.

        Nothing happens when the bottom entry already is a device context.
        """
        stack = tx.symbolic_torch_function_state.mode_stack
        if stack and cls.is_device_context(stack[0]):
            return

        cls.offset += 1
        placeholder = TorchFunctionModeVariable(
            None, source=TorchFunctionModeStackSource(-cls.offset)
        )
        stack.insert(0, placeholder)

    @classmethod
    def clear_default_device(cls, tx: "InstructionTranslator"):
        """Drop the device context at the bottom of the stack, if present."""
        stack = tx.symbolic_torch_function_state.mode_stack
        if not stack or not cls.is_device_context(stack[0]):
            return

        stack.popleft()
        cls.offset -= 1

    @staticmethod
    def is_device_context(var):
        """Return True if ``var`` wraps a DeviceContext or a ``None`` placeholder."""
        wrapped = var.value
        return isinstance(wrapped, DeviceContext) or wrapped is None

    @classmethod
    def get_mode_index(cls, ind):
        """Translate ``ind`` by the current device-context offset."""
        return cls.offset + ind
|
|
|
|
|
|
class TorchFunctionModeVariable(GenericContextWrappingVariable):
    """Variable tracker wrapping a torch function mode object."""

    @staticmethod
    def is_supported_torch_function_mode(ty):
        """Return whether graph breaks are supported under a mode of type ``ty``.

        "Supported" here means we can handle graph breaks inside the context.
        Custom modes can be traced, but when they define their own
        __enter__/__exit__ we do not handle graph breaks under them, for the
        same reason generic context managers are not handled: there may be
        side effects that are affected by executing the function across two
        frames instead of one. Today the enter/exit of the default
        TorchFunctionMode is supported, as well as DeviceContext (which is
        used for set_default_device).
        """
        if issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)):
            return True
        if class_has_getattribute(ty):
            return False
        if inspect.getattr_static(ty, "__enter__") != TorchFunctionMode.__enter__:
            return False
        return inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__

    def __init__(self, value, source=None, **kwargs):
        # A ``None`` value is used for placeholder entries; the parent
        # initializer is skipped for those.
        if value is not None:
            super().__init__(value, **kwargs)
        self.value = value
        # Kept for backward compatibility with the context-manager code that
        # calls enter.
        self.cm_obj = value
        self.source = source

    def reconstruct(self, codegen: "PyCodegen"):
        """Reconstruct the mode from its source; a source is required."""
        assert self.source
        self.source.reconstruct(codegen)

    def module_name(self):
        """Return ``__module__`` as read from the wrapped mode object."""
        return self.value.__module__

    def fn_name(self):
        """Return the class name of the wrapped mode object."""
        return type(self.value).__name__

    def python_type(self):
        """Return the type of the wrapped mode object."""
        return type(self.value)

    def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
        """Invoke this mode's ``__torch_function__`` on the given call."""
        handler = get_torch_function_fn(tx, self)
        return call_torch_function(tx, handler, fn, types, args, kwargs)

    def enter(self, tx):
        """Trace entering the mode by pushing it onto the torch function stack.

        NoEnterTorchFunctionMode instances are not pushed.
        """
        from .torch import TorchInGraphFunctionVariable

        if not isinstance(self.value, NoEnterTorchFunctionMode):
            push_fn = TorchInGraphFunctionVariable(
                torch._C._push_on_torch_function_stack
            )
            push_fn.call_function(tx, [self], {})
        return ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        """Trace exiting the mode by popping the torch function stack."""
        from .torch import TorchInGraphFunctionVariable

        pop_fn = TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack)
        pop_fn.call_function(tx, [], {})
        return ConstantVariable.create(None)

    def reconstruct_type(self, codegen: "PyCodegen"):
        """Emit code that loads NoEnterTorchFunctionMode from its module."""
        mode_cls = NoEnterTorchFunctionMode
        module_source = codegen.tx.import_source(mode_cls.__module__)
        codegen(AttrSource(module_source, mode_cls.__name__))

    def supports_graph_breaks(self):
        """Always True for this variable."""
        return True

    def exit_on_graph_break(self):
        """Always False for this variable."""
        return False
|
|
|
|
|
|
def _get_all_args(args, kwargs):
    """Return every argument variable found in ``args`` and ``kwargs``.

    The Python-level containers are flattened with pytree first; the
    resulting leaves are then expanded further by ``_flatten_vts``.
    """
    leaves = pytree.arg_tree_leaves(*args, **kwargs)
    return _flatten_vts(leaves)
|
|
|
|
|
|
def _flatten_vts(vts):
    """Flatten variable trackers, expanding list and dict variables.

    Returns a list containing every input variable plus, recursively, the
    items of each ``ListVariable`` and the values of each
    ``ConstDictVariable``. A container appears in the output immediately
    before its contents.

    The result is in left-to-right (pre-order) argument order. This matters
    because ``dispatch_torch_function`` hands the filtered list to
    ``_get_overloaded_args``, which breaks ties between unrelated types by
    input order. Consuming the work list from the right, as this function
    previously did, yielded the arguments in reverse and therefore a
    different dispatch order than the one the arguments were passed in.

    Args:
        vts: iterable of variable trackers.

    Returns:
        list of variable trackers in left-to-right pre-order.
    """
    from collections import deque

    from .dicts import ConstDictVariable
    from .lists import ListVariable

    pending = deque(vts)
    output = []

    while pending:
        # Always take the left-most pending entry to preserve argument order.
        vt = pending.popleft()

        # Unrealized containers have to be realized before their contents can
        # be inspected; other unrealized variables are left untouched.
        if not vt.is_realized() and vt.peek_type() in (dict, list, tuple):
            vt.realize()

        if vt.is_realized():
            # extendleft() inserts one element at a time at the front, so the
            # children are fed in reverse to keep their original order.
            if isinstance(vt, ListVariable):
                pending.extendleft(reversed(vt.items))
            elif isinstance(vt, ConstDictVariable):
                pending.extendleft(reversed(list(vt.items.values())))

        output.append(vt)

    return output
|
|
|
|
|
|
def _get_subclass_type(var):
    """Return the Python type represented by ``var``.

    ``var`` must be a TensorWithTFOverrideVariable or a
    UserDefinedObjectVariable.
    """
    accepted = (TensorWithTFOverrideVariable, UserDefinedObjectVariable)
    assert isinstance(var, accepted)
    return var.python_type()
|
|
|
|
|
|
def _get_subclass_type_var(tx: "InstructionTranslator", var):
    """Return a variable tracker for the Python type represented by ``var``.

    Tensor subclass variables supply their own class variable. For
    user-defined objects the type is built from ``var.python_type()``, using
    a ``TypeSource`` derived from the object's source when it has one.
    """
    accepted = (TensorWithTFOverrideVariable, UserDefinedObjectVariable)
    assert isinstance(var, accepted)

    if isinstance(var, TensorWithTFOverrideVariable):
        return var.class_type_var(tx)

    if isinstance(var, UserDefinedObjectVariable):
        # A falsy source is passed through unchanged.
        type_source = var.source and TypeSource(var.source)
        return VariableTracker.build(tx, var.python_type(), type_source)
|
|
|
|
|
|
def _is_attr_overridden(tx: "InstructionTranslator", var, name):
|
|
import torch
|
|
|
|
overridden = False
|
|
try:
|
|
attr_val = inspect.getattr_static(var.python_type(), name)
|
|
overridden |= attr_val != getattr(torch.Tensor, name)
|
|
except AttributeError:
|
|
pass
|
|
|
|
return overridden
|
|
|
|
|
|
def call_torch_function(tx, torch_function_var, fn, types, args, kwargs):
    """Call ``torch_function_var`` with the ``__torch_function__`` arguments.

    This emulates calling __torch_function__, whose signature is

        def __torch_function__(cls, func, types, args=(), kwargs=None):

    Note that ``cls`` is not passed explicitly, matching the reference
    implementations:
    1. https://github.com/pytorch/pytorch/blob/8d81806211bc3c0ee6c2ef235017bacf1d775a85/torch/csrc/utils/python_arg_parser.cpp#L368-L374  # noqa: B950
    2. https://github.com/pytorch/pytorch/blob/8d81806211bc3c0ee6c2ef235017bacf1d775a85/torch/overrides.py#L1741-L1743

    ``args`` is converted to a tuple and, like ``kwargs``, wrapped with
    ``VariableTracker.build`` before the call.
    """
    args_var = VariableTracker.build(tx, tuple(args))
    kwargs_var = VariableTracker.build(tx, kwargs)
    return torch_function_var.call_function(
        tx, [fn, types, args_var, kwargs_var], {}
    )
|
|
|
|
|
|
def get_torch_function_fn(tx: "InstructionTranslator", vt):
    """Return the variable for ``vt.__torch_function__``.

    The lookup is traced as a ``getattr`` builtin call. The underlying
    function may be a classmethod, staticmethod, regular function or a
    C-implemented function; that does not matter as long as it satisfies the
    calling convention used by ``call_torch_function``.
    """
    from .builtin import BuiltinVariable

    attr_name = ConstantVariable("__torch_function__")
    return BuiltinVariable(getattr).call_function(tx, [vt, attr_name], {})
|
|
|
|
|
|
def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
    """Return whether a call with these arguments should go through torch function.

    Dispatch is possible when an argument carries a torch function override
    and subclass dispatch is enabled, or when torch function modes are
    enabled and at least one mode is on the symbolic stack.
    """
    has_overridden_args = any(
        has_torch_function(candidate) for candidate in _get_all_args(args, kwargs)
    )
    tf_state = tx.symbolic_torch_function_state

    via_subclass = has_overridden_args and tf_state.torch_function_subclass_enabled
    return via_subclass or (
        tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode()
    )
|
|
|
|
|
|
def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
    """Dispatch ``fn`` through the available ``__torch_function__`` handlers.

    Arguments carrying a torch function override are collected and ordered
    with ``_get_overloaded_args``. If a mode is on the symbolic stack it is
    tried first; afterwards each overloaded argument is tried in order. The
    first result that is not ``NotImplemented`` is returned. If every handler
    returns ``NotImplemented``, ``unimplemented_v2`` is called.
    """

    def returned_not_implemented(result):
        # A handler declines by producing the constant NotImplemented.
        return isinstance(result, ConstantVariable) and result.value is NotImplemented

    flat_args = _get_all_args(args, kwargs)
    candidates = [arg for arg in flat_args if has_torch_function(arg)]
    overloaded_args = _get_overloaded_args(candidates, _get_subclass_type)

    types = TupleVariable(
        [_get_subclass_type_var(tx, overloaded) for overloaded in overloaded_args]
    )

    tf_state = tx.symbolic_torch_function_state
    if tf_state.in_torch_function_mode():
        mode_result = tf_state.call_torch_function_mode(tx, fn, types, args, kwargs)
        if not returned_not_implemented(mode_result):
            return mode_result

    for overloaded in overloaded_args:
        override_result = overloaded.call_torch_function(tx, fn, types, args, kwargs)
        if not returned_not_implemented(override_result):
            return override_result

    unimplemented_v2(
        gb_type="All __torch_function__ overrides returned NotImplemented due to TypeError from user code",
        context=f"{fn=}, {args=}, {kwargs=}",
        explanation=f"All __torch_function__ overrides for for function {fn} returned NotImplemented",
        hints=[
            *graph_break_hints.USER_ERROR,
        ],
    )
|
|
|
|
|
|
class TensorWithTFOverrideVariable(TensorVariable):
    """
    Variable tracker for an instance of a tensor subclass that overrides
    __torch_function__.
    """

    @classmethod
    def from_tensor_var(cls, tx, tensor_var, class_type, cls_source):
        """Build a subclass variable from an existing tensor variable.

        [Note: __torch_function__] coerce `tensor_var` into a
        TensorWithTFOverrideVariable. In eager, this is just a type change.
        ``cls_source`` is accepted but not used here.
        """
        import torch

        # Copying the attribute dict simulates a shallow copy of the tensor
        # object.
        fields = dict(tensor_var.__dict__)
        base_type = fields.pop("class_type")
        assert base_type in (torch.Tensor, torch.nn.Parameter), (
            f"invalid class type {base_type} in TensorWithTFOverrideVariable.from_tensor_var"
        )
        subclass_var = cls(class_type=class_type, **fields)
        subclass_var.install_global(tx)
        return subclass_var

    def install_global(self, tx):
        """Install the subclass type as a global of the output graph.

        The type is stashed so an output tensor can be rewrapped if needed:
        the actual type must be available every time the compiled artifact
        runs and outputs a wrapped tensor.
        """
        if self.global_mangled_class_name(tx) in tx.output.global_scope:
            return

        # Safe because global_mangled_class_name figures it out
        tx.output.install_global_unsafe(
            self.global_mangled_class_name(tx), self.class_type
        )

    def python_type(self):
        """Return the tensor subclass type."""
        return self.class_type

    def class_type_var(self, tx):
        """Return a variable for the subclass type, sourced from its installed global."""
        global_source = GlobalSource(self.global_mangled_class_name(tx))
        return TensorSubclassVariable(self.class_type, source=global_source)

    def global_mangled_class_name(self, tx):
        """Return the global name under which the subclass type is stored."""
        prefix = f"__subclass_{self.class_type.__name__}"
        return get_safe_global_name(tx, prefix, self.class_type)

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Trace ``getattr(self, name)``.

        [Note: __torch_function__] We currently only support attributes that
        are defined on base tensors, custom attribute accesses will graph
        break.
        """
        import torch

        # Probably only `_base` actually breaks, because the view
        # relationship is not modelled perfectly in some scenarios.
        if name in banned_attrs:
            unimplemented_v2(
                gb_type="Unsupported tensor subclass attribute access",
                context=f"{name}",
                explanation="`torch.compile` currently can't trace this",
                hints=[
                    f"Avoid accessing {name} of tensor subclass in torch.compile region",
                    *graph_break_hints.SUPPORTABLE,
                ],
            )

        attr_is_overridden = _is_attr_overridden(tx, self, name)
        inherited_unchanged = (
            hasattr(torch.Tensor, name)
            and not attr_is_overridden
            and not inspect.ismethoddescriptor(getattr(torch.Tensor, name))
        )

        if inherited_unchanged:
            # Non-overridden attribute inherited from `torch.Tensor`: route
            # the descriptor's __get__ through __torch_function__ when
            # dispatch is possible.
            args, kwargs = [self], {}
            if can_dispatch_torch_function(tx, args, kwargs):
                if self.source:
                    class_attr_source = AttrSource(
                        AttrSource(self.source, "__class__"), name
                    )
                    install_guard(
                        class_attr_source.make_guard(GuardBuilder.FUNCTION_MATCH)
                    )
                get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__)

                return self.call_torch_function(
                    tx,
                    get_fn,
                    TupleVariable([self.class_type_var(tx)]),
                    args,
                    kwargs,
                )
        else:
            # `TensorVariable.var_getattr` doesn't handle user-defined
            # functions/attributes well, so they are handled explicitly here.
            #
            # TODO move this logic into `TensorVariable`, or try to merge it
            # with similar logic in `UserDefinedObjectVariable`.
            try:
                attr = inspect.getattr_static(self.class_type, name)
            except AttributeError:
                pass
            else:
                import types

                cls_source = GlobalSource(self.global_mangled_class_name(tx))
                attr_source = AttrSource(cls_source, name)

                if isinstance(attr, types.FunctionType):
                    # Plain function defined on the subclass: bind it to self.
                    install_guard(attr_source.make_guard(GuardBuilder.FUNCTION_MATCH))
                    return UserMethodVariable(attr, self)

                if isinstance(attr, property):
                    # Property: trace a call to its getter bound to self.
                    getter_source = AttrSource(attr_source, "fget")
                    bound_getter = UserMethodVariable(
                        attr.fget, self, source=getter_source
                    )
                    return bound_getter.call_function(tx, [], {})

                if isinstance(attr, classmethod):
                    # Classmethod: bind the underlying function to the class.
                    return UserMethodVariable(
                        attr.__func__, self.class_type_var(tx), source=attr_source
                    )

                if attr_is_overridden:
                    unimplemented_v2(
                        gb_type="Unsupported tensor subclass overridden attribute access",
                        context=f"{name}",
                        explanation="`torch.compile` only support tracing certain types of overridden tensor subclass attributes",
                        hints=[
                            f"Avoid accessing {name} of tensor subclass in torch.compile region",
                            f"Renaming attribute `{name}` of type {self.class_type}",
                            *graph_break_hints.SUPPORTABLE,
                        ],
                    )

        return super().var_getattr(tx, name)

    def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
        """Call this subclass's ``__torch_function__`` for ``fn``.

        The handler variable is looked up once and cached on the instance.
        NOTE this assumes `__torch_function__` isn't modified during tracing.
        """
        if not hasattr(self, "torch_function_fn"):
            self.torch_function_fn = get_torch_function_fn(tx, self)

        handler = self.torch_function_fn
        return call_torch_function(tx, handler, fn, types, args, kwargs)

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Trace ``self.name(*args, **kwargs)``.

        When torch function dispatch is possible, the method call is inlined
        through the __torch_function__ override; otherwise the parent
        implementation handles it.
        """
        tf_args = [self] + args
        if not can_dispatch_torch_function(tx, tf_args, kwargs):
            return super().call_method(tx, name, args, kwargs)

        import torch

        if _is_attr_overridden(tx, self, name):
            unimplemented_v2(
                gb_type="Tensor subclass overridden method call",
                context=f"{name}",
                explanation="`torch.compile` currently can't trace this",
                hints=[
                    f"Avoid calling {name} of tensor subclass in torch.compile region",
                    f"Renaming method `{name}` of type {self.class_type}",
                    *graph_break_hints.SUPPORTABLE,
                ],
            )

        # [Note: __torch_function__] Currently only methods defined on tensor
        # are supported; other cases graph break. Supporting more would need
        # a bigger overhaul of how methods are extracted and compared for
        # equality. The check above established that the method is not
        # overridden, so we guard that it is the implementation defined on
        # tensor and retrieve it.
        if self.source:
            source = AttrSource(AttrSource(self.source, "__class__"), name)
            value = inspect.getattr_static(self.python_type(), name)
        else:
            source = None
            value = getattr(torch.Tensor, name)

        func_var = VariableTracker.build(tx, value, source)
        return dispatch_torch_function(tx, func_var, tf_args, kwargs)
|