import inspect
import logging
import math
import re
from typing import Dict, List

from torch._streambase import _StreamBase
from ..guards import install_guard

try:
    import numpy as np
except ModuleNotFoundError:
    np = None

import torch._C
import torch._refs
import torch.fx
import torch.nn
import torch.onnx.operators
from torch._logging import warning_once

from .. import config, polyfill, variables
from ..device_interface import get_registered_device_interfaces
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..utils import (
    check_constant_args,
    check_unspec_python_args,
    guard_if_dyn,
    has_torch_function,
    product,
    proxy_args_kwargs,
)
from .base import VariableTracker
from .ctx_manager import (
    AutocastModeVariable,
    NullContextVariable,
    TorchFunctionDisableVariable,
)
from .distributed import is_constant_pg_functions, is_from_local, ProcessGroupVariable
from .higher_order_ops import TorchHigherOrderOperatorVariable
from .lists import ListVariable, TupleVariable
from .torch_function import can_dispatch_torch_function, dispatch_torch_function

log = logging.getLogger(__name__)
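
# Ops that dynamo rewrites into a Tensor.size() call while tracing (see the
# matching branch in TorchInGraphFunctionVariable.call_function below).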
REWRITE_OPS_TO_TENSOR_SIZE_METHOD = [
    torch.onnx.operators.shape_as_tensor,
    torch._shape_as_tensor,
]
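
# Calls to these functions depend only on global state that dynamo guards on,
# so their results can be computed at trace time and baked into the graph as
# constants (see BaseTorchVariable.can_constant_fold_through).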
constant_fold_functions = [
    torch._assert,
    torch._utils._get_device_index,
    torch.cuda.is_available,
    torch.distributed.is_available,
    torch.get_autocast_gpu_dtype,
    torch.get_default_dtype,
    torch.is_autocast_cache_enabled,
    torch.is_autocast_cpu_enabled,
    torch.is_autocast_enabled,
    torch.is_complex,
    torch.is_floating_point,
    torch.nn.functional._Reduction.get_enum,
    torch.promote_types,
    torch._C._get_privateuse1_backend_name,
]


if torch.distributed.is_available():
    constant_fold_functions.extend(
        [
            torch.distributed.is_initialized,
            torch.distributed.get_rank,
            torch.distributed.get_world_size,
        ]
    )
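
# Each tracing-state query below folds to the given constant under
# torch.compile, e.g. torch.jit.is_scripting() reports False while
# torch._utils.is_compiling() reports True.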
tracing_state_functions = {
    torch.jit.is_scripting: False,
    torch.jit.is_tracing: False,
    torch._C._get_tracing_state: None,
    torch.fx._symbolic_trace.is_fx_tracing: False,
    torch.onnx.is_in_onnx_export: False,
    torch._dynamo.external_utils.is_compiling: True,
    torch._utils.is_compiling: True,
}


class BaseTorchVariable(VariableTracker):
    """common base for all torch.* functions, classes, modules and other things"""

    @classmethod
    def create_with_source(cls, value, source):
        install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
        return cls(
            value,
            source=source,
        )

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def reconstruct(self, codegen):
        try:
            name = f"{self.value.__module__}.{self.value.__name__}"
        except Exception:
            name = f"torch_obj_{id(self.value)}"
        unique_var_name = "__" + re.sub(r"[^a-zA-Z0-9_]+", "_", name)
        return codegen.setup_globally_cached(unique_var_name, self.value, False)

    def as_proxy(self):
        return self.value

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    def call_hasattr(self, tx, name):
        result = hasattr(self.value, name)
        return variables.ConstantVariable.create(result)

    def can_constant_fold_through(self):
        if self.value in constant_fold_functions:
            return True
        return getattr(self.value, "__module__", None) == "math"


class TorchCtxManagerClassVariable(BaseTorchVariable):
    """Points to a context manager class in torch.* for which dynamo has implementations"""

    def __repr__(self):
        return f"TorchCtxManagerClassVariable({self.value})"

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from . import (
            GradModeVariable,
            InferenceModeVariable,
            StreamVariable,
            VmapIncrementNestingCtxManagerVariable,
        )

        if self.value is torch.no_grad:
            if len(args) == 1 and isinstance(
                args[0], variables.functions.BaseUserFunctionVariable
            ):
                ctx = GradModeVariable.create(tx, False)
                return ctx.call_function(tx, args, kwargs)
            else:
                return GradModeVariable.create(tx, False)
        elif self.value is torch.enable_grad:
            if len(args) == 1 and isinstance(
                args[0], variables.functions.BaseUserFunctionVariable
            ):
                ctx = GradModeVariable.create(tx, True)
                return ctx.call_function(tx, args, kwargs)
            return GradModeVariable.create(tx, True)
        elif self.value is torch.set_grad_enabled and len(args) == 1:
            return GradModeVariable.create(
                tx, args[0].as_python_constant(), initialized=True
            )
        elif self.value is torch.inference_mode:
            return InferenceModeVariable.create(tx, args[0].as_python_constant())
        elif inspect.isclass(self.value) and issubclass(self.value, _StreamBase):
            from torch._dynamo.variables.builder import wrap_fx_proxy_cls

            return wrap_fx_proxy_cls(
                StreamVariable,
                tx,
                tx.output.create_proxy(
                    "call_function",
                    self.value,
                    (),
                    {},
                ),
            )
        elif self.value in (
            torch.amp.autocast_mode.autocast,
            torch.cuda.amp.autocast,
            torch.cpu.amp.autocast,
        ):
            return AutocastModeVariable.create(self.value, args, kwargs)
        elif self.value in (
            torch.profiler.profile,
            torch.profiler.record_function,
            torch.autograd.profiler.profile,
            torch.autograd.profiler.record_function,
        ):
            warning_once(log, "Profiler function %s will be ignored", self.value)
            return NullContextVariable()
        elif self.value is torch._C.DisableTorchFunctionSubclass:
            assert not (args or kwargs)
            return TorchFunctionDisableVariable.create(tx)
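        # torch.func.vmap support: per pytorch/pytorch#116050, dynamo traces
        # vmap by inlining the torch._functorch frames into the FX graph; only
        # the nesting-level context manager needs special handling here.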
        elif self.value is torch._functorch.vmap.vmap_increment_nesting:
            assert len(args) == 2
            return VmapIncrementNestingCtxManagerVariable.create(
                tx,
                [guard_if_dyn(x) for x in args],
            )


class TorchInGraphFunctionVariable(BaseTorchVariable):
    """Points to a torch function/method that should be put in the FX graph"""

    def __repr__(self):
        return f"TorchInGraphFunctionVariable({self.value})"

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from . import (
            ConstantVariable,
            DeterministicAlgorithmsVariable,
            DisabledSavedTensorsHooksVariable,
            GradModeVariable,
            SDPAParamsVariable,
            StreamContextVariable,
            SymNodeVariable,
            TensorVariable,
            UserDefinedObjectVariable,
        )

        from .builder import wrap_fx_proxy, wrap_fx_proxy_cls

        constant_args = check_constant_args(args, kwargs)
        unspec_python_args = check_unspec_python_args(args, kwargs)
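
        # Both checks guarantee that every input's concrete Python value is
        # recoverable at trace time, so the call below can be evaluated eagerly
        # and its result recorded as a constant.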
        if self.can_constant_fold_through() and (constant_args or unspec_python_args):
            # constant fold
            return ConstantVariable.create(
                self.as_python_constant()(
                    *[x.as_python_constant() for x in args],
                    **{k: v.as_python_constant() for k, v in kwargs.items()},
                ),
            )
        elif self.value in tracing_state_functions:
            assert not args and not kwargs
            # See: https://github.com/pytorch/pytorch/issues/110765
            if self.value in (
                torch._utils.is_compiling,
                torch._dynamo.external_utils.is_compiling,
            ):
                tx.mark_inconsistent_side_effects()
            return ConstantVariable.create(tracing_state_functions[self.value])
        elif self.value in (torch._functorch.eager_transforms.grad_impl,):
            return TorchHigherOrderOperatorVariable.make(
                self.value,
                source=self.source,
            ).call_function(tx, args, kwargs)
        elif self.value is torch.overrides.get_default_nowrap_functions:
            # [Note: __torch_function__] we return empty here because we restrict
            # the set of functions that we trace __torch_function__ on to
            # functions outside of the actual set. Implementing this properly will
            # require implementing some variable types to track and compare
            # tensor getset descriptors.
            from .builder import SourcelessBuilder

            return SourcelessBuilder()(
                tx, torch.overrides.get_default_nowrap_functions()
            )
        elif self.value == math.radians and not (constant_args or unspec_python_args):
            # Use polyfill to convert math.radians(x) into math.pi * x / 180.0
            from .builder import SourcelessBuilder

            return tx.inline_user_function_return(
                SourcelessBuilder()(tx, polyfill.radians), args, kwargs
            )
        elif self.value in (torch.is_tensor, torch.overrides.is_tensor_like):
            assert len(args) == 1
            if isinstance(args[0], TensorVariable) or (
                self.value is torch.overrides.is_tensor_like
                and isinstance(args[0], UserDefinedObjectVariable)
                and hasattr(args[0].value, "__torch_function__")
            ):
                return ConstantVariable.create(True)
            else:
                return ConstantVariable.create(False)
        elif self.value in (
            torch.is_floating_point,
            torch.is_complex,
        ):
            input_arg = None
            if args:
                input_arg = args[0]
            else:
                assert "input" in kwargs
                input_arg = kwargs["input"]
            if isinstance(input_arg, TensorVariable) and input_arg.dtype is not None:
                if self.value is torch.is_floating_point:
                    return ConstantVariable.create(input_arg.dtype.is_floating_point)
                elif self.value is torch.is_complex:
                    return ConstantVariable.create(input_arg.dtype.is_complex)
                else:
                    raise AssertionError(f"calling {self.value}")
        elif (
            self.value is torch.numel
            and isinstance(args[0], TensorVariable)
            and args[0].size is not None
        ):
            return ConstantVariable.create(product(args[0].size))
        elif self.value in REWRITE_OPS_TO_TENSOR_SIZE_METHOD:
            assert len(args) == 1
            assert isinstance(args[0], TensorVariable)
            return args[0].call_method(tx, "size", [], {})
        elif self.value in (
            torch.nn.modules.utils._single,
            torch.nn.modules.utils._pair,
            torch.nn.modules.utils._triple,
            torch.nn.modules.utils._quadruple,
            torch.nn.modules.utils._ntuple,
        ):
            return self._call_ntuple(tx, args, kwargs)
        elif self.value is torch.is_grad_enabled:
            assert not (args or kwargs)
            install_guard(GradModeVariable._guards_singleton)
            return ConstantVariable.create(torch.is_grad_enabled())
        elif self.value is torch.use_deterministic_algorithms and len(args) == 1:
            return DeterministicAlgorithmsVariable.create(
                tx, args[0].as_python_constant()
            )
        elif self.value is torch.are_deterministic_algorithms_enabled:
            assert not (args or kwargs)
            install_guard(DeterministicAlgorithmsVariable._guards_singleton)
            return ConstantVariable.create(torch.are_deterministic_algorithms_enabled())
        elif self.value is torch.autograd.graph.disable_saved_tensors_hooks:
            assert len(args) == 1
            return DisabledSavedTensorsHooksVariable.create(
                tx, args[0].as_python_constant()
            )
        elif self.value is torch._C._is_torch_function_enabled:
            assert not (args or kwargs)
            install_guard(TorchFunctionDisableVariable._guards_singleton)
            return ConstantVariable.create(tx.output.torch_function_enabled)
        elif self.value in (
            torch.overrides.has_torch_function,
            torch.overrides.has_torch_function_variadic,
            torch.overrides.has_torch_function_unary,
        ):
            assert not kwargs
            return ConstantVariable.create(
                any(has_torch_function(a) for a in args),
            )
        elif any(
            self.value is method
            for method in [
                device_interface.stream
                for _, device_interface in get_registered_device_interfaces()
            ]
        ):
            assert len(args) == 1
            return StreamContextVariable.create(tx, args[0])
        elif self.value is torch.from_numpy:
            if not config.trace_numpy:
                unimplemented("torch.from_numpy. config.trace_numpy is False")
            if not np:
                unimplemented("torch.from_numpy. NumPy is not available")
            return wrap_fx_proxy_cls(
                target_cls=TensorVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    torch.as_tensor,
                    *proxy_args_kwargs(args, {}),
                ),
                example_value=None,
            )
        elif can_dispatch_torch_function(tx, args, kwargs):
            return dispatch_torch_function(tx, self, args, kwargs)
        elif self.value is torch.jit.annotate:
            assert len(args) == 2
            return args[1]
        elif self.value is torch.backends.cudnn.is_acceptable:
            # is_acceptable(tensor) returns true if
            #   (a) tensor dtype/device are supported by cudnn
            #   (b) cudnn is available
            #   (c) some initialization has completed
            # technically, it depends on some global state from (c) (torch.backends.cudnn.__cudnn_version)
            assert (
                len(args) == 1 or "tensor" in kwargs
            ), "Expect 1 input to cudnn.is_acceptable"
            tensor_variable = args[0] if len(args) > 0 else kwargs["tensor"]
            assert isinstance(
                tensor_variable, TensorVariable
            ), "Expect input to cudnn.is_acceptable to be a tensor"
            tensor_inp = torch.tensor(
                0, dtype=tensor_variable.dtype, device=tensor_variable.device
            )
            return ConstantVariable.create(
                torch.backends.cudnn.is_acceptable(tensor_inp)
            )
        elif (
            self.value == torch.numel
            and len(args) == 1
            and isinstance(args[0], TensorVariable)
            and len(kwargs) == 0
        ):
            # TODO(voz): This is rewritten as a call_method because
            # torch.numel(x) w/ sym shapes raises a RuntimeError and x.numel() does not
            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method",
                    "numel",
                    *proxy_args_kwargs(args, kwargs),
                ),
            )
        # TODO: These special cases shouldn't be necessary; we should
        # generically support torch.ops that return int
        elif (
            self.value in (torch.ops.aten.sym_size, torch.ops.aten.sym_size.int)
            and len(args) == 2
            and len(kwargs) == 0
            and isinstance(args[0], TensorVariable)
        ):
            # we see this when retracing already traced code
            return args[0].call_method(tx, "size", [args[1]], {})
        elif (
            self.value in (torch.ops.aten.sym_stride, torch.ops.aten.sym_stride.int)
            and len(args) == 2
            and len(kwargs) == 0
            and isinstance(args[0], TensorVariable)
        ):
            return args[0].call_method(tx, "stride", [args[1]], {})
        elif (
            self.value == torch.addcdiv
            and len(args) == 3
            and "value" in kwargs
            and len(kwargs) == 1
        ):
            # decompose addcdiv into constituent ops, prevents a graph break due to converting
            # value to a scalar
            result = TorchInGraphFunctionVariable(torch.div).call_function(
                tx, args[1:], {}
            )
            result = TorchInGraphFunctionVariable(torch.mul).call_function(
                tx, [result, kwargs["value"]], {}
            )
            return TorchInGraphFunctionVariable(torch.add).call_function(
                tx, [args[0], result], {}
            )
        elif (
            self.value is torch._assert
            and len(args) >= 1
            and (
                (args[0].is_python_constant() and args[0].as_python_constant())
                or (
                    isinstance(args[0], variables.SymNodeVariable)
                    and args[0].evaluate_expr()
                )
            )
        ):
            return ConstantVariable(None)
        elif SDPAParamsVariable.is_sdpa_params(self.value):
            return wrap_fx_proxy(
                tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    torch._C._SDPAParams,
                    *proxy_args_kwargs(args, kwargs),
                ),
                param_vars=args,
            )
        elif is_constant_pg_functions(self.value):
            # because the input is a "ProcessGroupVariable", we'll be guarding on its
            # ID_MATCH based on how it was constructed.

            # We desugar it at trace time into ranks by directly calling the util and
            # baking the result into the trace.
            assert len(args) == 1, "Expected one arg (pg)"
            assert isinstance(args[0], ProcessGroupVariable)

            invocation_result = self.value(args[0].as_python_constant())
            # Note - while we *could* cook up sources around invocations, like a FunctionSource
            # the space of invoking functions in the middle of the guard chain is very iffy. As such,
            # guard propagation via options is the best we can do.
            from .builder import SourcelessBuilder

            return SourcelessBuilder()(tx, invocation_result)
        elif is_from_local(self.value):
            # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
            # and rewrite args to have only proxyable args, then insert call_function
            args_as_value = [x.as_python_constant() for x in args[1:]]
            kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()}

            def fn_with_prim_types(x):
                return self.value(x, *args_as_value, **kwargs_as_value)

            # attach the same function name for better debugging
            fn_with_prim_types.__name__ = "prim " + self.value.__name__

            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    fn_with_prim_types,
                    *proxy_args_kwargs([args[0]], {}),
                ),
            )
        elif (
            self.value is torch.nested.nested_tensor
            and kwargs.get("layout", torch.strided) == torch.strided
        ):
            raise unimplemented("torch.compile does not support strided NestedTensor")
        elif self.value is torch.nn.functional.one_hot and (
            len(args) + len(kwargs) == 1
            or (
                len(args) == 2
                and args[1].is_python_constant()
                and args[1].as_python_constant() == -1
            )
        ):
            raise unimplemented(
                "torch.nn.functional.one_hot with data-dependent output shape"
            )
        else:
            any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
            all_ints_or_floats = all(
                isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
                for x in args
            )
            bin_ops = {"add", "sub", "mul", "div", "sqrt"}
            if (
                getattr(self.value, "__module__", "") == "torch"
                and self.value.__name__ in bin_ops
                and any_symints_or_symfloats
                and all_ints_or_floats
            ):
                msg = f"""\
Calling {str(self.value)} on only torch.SymInt arguments is not yet supported.
To support this behavior, we need to allow const-propping tensors that store symint data.
For now, dynamo will explicitly graph break when it encounters user code with this behavior.
"""
                log.warning(msg)
                raise unimplemented(msg)

            # TODO(voz): Replace w/ dynamic shape rewrite table.
            # Ideally, we would be able to do this at ctor time, but alas we need a combination
            # of value + args to determine this.
            fn_ = self.value
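            # e.g. math.sqrt on a SymInt input is rerouted to its torch._sym_*
            # counterpart (torch._sym_sqrt) when one exists, so the op stays
            # symbolic in the graph instead of specializing on a concrete value.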
            if any(isinstance(x, SymNodeVariable) for x in args):
                torch_sym_op = f"_sym_{self.value.__name__}"
                if getattr(self.value, "__module__", None) == "math" and hasattr(
                    torch, torch_sym_op
                ):
                    fn_ = getattr(torch, torch_sym_op)
            if fn_ is torch.tensor:

                def check_any_unspec(x):
                    # NB: This includes UnspecializedPythonVariable
                    if isinstance(x, (TensorVariable, SymNodeVariable)):
                        return True
                    elif isinstance(x, (ListVariable, TupleVariable)):
                        return any(check_any_unspec(y) for y in x.items)
                    # TODO: there may be other recursive structures you need to
                    # check
                    else:
                        return False

                data_arg = None
                if args:
                    data_arg = args[0]
                elif "data" in kwargs:
                    data_arg = kwargs["data"]

                # NB: OK to pass torch.tensor(tensor), this will trace fine
                if not isinstance(data_arg, TensorVariable) and check_any_unspec(
                    data_arg
                ):
                    # This is slower and less canonical, so only use it if we
                    # have to
                    fn_ = torch._refs.tensor

            tensor_variable = wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    fn_,
                    *proxy_args_kwargs(args, kwargs),
                ),
            )

            if (
                isinstance(tensor_variable, TensorVariable)
                and "requires_grad" in kwargs
                and kwargs["requires_grad"].as_python_constant()
            ):
                unimplemented(
                    """factory functions that return tensors that require grad are not supported.
Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
                )

            if "out" in kwargs and not (
                isinstance(kwargs["out"], variables.ConstantVariable)
                and kwargs["out"].as_python_constant() is None
            ):
                # out variants of torch operators like torch.sort and
                # torch.sigmoid mutate the tensors in the out field. Track such
                # tensors and rewrite the symbolic locals.
                if isinstance(tensor_variable, TupleVariable):
                    assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
                    output_tensor_names = [
                        tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
                    ]
                    for idx, name in enumerate(output_tensor_names):
                        if name in tx.symbolic_locals:
                            tx.symbolic_locals[name] = tensor_variable.items[idx]
                elif isinstance(tensor_variable, TensorVariable):
                    assert isinstance(kwargs["out"], TensorVariable)
                    if (
                        kwargs["out"].source
                        and kwargs["out"] in tx.output.graphargs
                        and kwargs["out"].size != tensor_variable.size
                    ):
                        # It's hard to get out variants with resizing on graph inputs to work
                        # properly across dynamo/aot/inductor, just fall back.
                        unimplemented("out variants with resizing on graph inputs")
                    assert "example_value" in kwargs["out"].proxy.node.meta
                    if not torch._prims_common.is_contiguous(
                        kwargs["out"].proxy.node.meta["example_value"]
                    ):
                        # It's difficult to handle strides correctly in functionalization
                        # when calling an out= op with a non-contiguous out argument
                        unimplemented(
                            "out= op was called where output tensor was non-contiguous"
                        )
                    name = tx.find_symbolic_locals_name(kwargs["out"])
                    if name in tx.symbolic_locals:
                        tx.symbolic_locals[name] = tensor_variable
                else:
                    unimplemented(f"out variant of {type(kwargs['out'])}")

            return tensor_variable
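
    # _single/_pair/_triple/_quadruple are precomputed _ntuple(n) closures, so
    # the repetition count lives in a closure cell rather than in an argument;
    # _call_ntuple recovers it from either place.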
    def _call_ntuple(self, tx, args, kwargs):
        """inline behavior of torch.nn.modules.utils._ntuple"""
        if self.value is torch.nn.modules.utils._ntuple:
            count = args[0].as_python_constant()
        else:
            count = self.value.__closure__[0].cell_contents
        assert isinstance(count, int)
        assert not kwargs

        def handle_ntuple(value):
            if value.has_unpack_var_sequence(tx):
                return variables.TupleVariable(
                    list(value.unpack_var_sequence(tx)),
                )
            elif value.is_python_constant():
                # constant prop through it
                return variables.ConstantVariable.create(
                    torch.nn.modules.utils._ntuple(count)(value.as_python_constant()),
                )
            else:
                unimplemented(f"torch.nn.modules.utils._ntuple({value})")

        if self.value is torch.nn.modules.utils._ntuple:
            return variables.LambdaVariable(handle_ntuple)
        else:
            return handle_ntuple(args[0])