Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[hoo] Add with_effects to handle side effectful ops (#120296)
Proposal: https://docs.google.com/document/d/179QyhicGzTXJ5jvTAoAosP_Nzgf3PpgZwU_E3VV9PlM/edit#heading=h.bnm38nu3yfno

Implementation discussion: https://docs.google.com/document/d/179QyhicGzTXJ5jvTAoAosP_Nzgf3PpgZwU_E3VV9PlM/edit#heading=h.bj61609o1buq

Result with print:
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %with_effects : [num_users=1] = call_function[target=torch._higher_order_ops.effects.with_effects](args = (%arg0_1, aten.print.default, moo), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 0), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg1_1, %arg1_1), kwargs = {})
    return [getitem, add]
```

Follow ups:
* Add handling to auto_functionalize
* Add support for tokens on the export side
* Add support for tokens on the inductor side

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120296
Approved by: https://github.com/zou3519
Committed by: PyTorch MergeBot
Parent: 29976519a1
Commit: a7e93c341f
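
For context, here is a minimal sketch (assuming a build that contains this commit) of how the `with_effects`-wrapped graph shown above can be reproduced; it mirrors the `test_print` test added below.

```python
# Minimal sketch, assuming this commit is available: when a module containing a
# side-effectful op (aten._print) is functionalized via aot_export_module, the
# print is wrapped in the with_effects HOP and an effect token is threaded
# through the graph signature.
import torch
from torch._functorch.aot_autograd import aot_export_module


class M(torch.nn.Module):
    def forward(self, x):
        torch.ops.aten._print("moo")  # ordered side effect
        return (x + x,)


gm, gs = aot_export_module(M(), (torch.randn(3),), trace_joint=False)
print(gm.code)               # contains with_effects(..., torch.ops.aten._print.default, 'moo')
print(len(gs.input_tokens))  # 1 -> one token input added to the graph signature
```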
test/higher_order_ops/test_with_effects.py (new file, 208 lines)

```python
# Owner(s): ["module: functorch"]
import unittest

import torch
import torch._dynamo
import torch._inductor
import torch._inductor.decomposition
from torch._functorch.aot_autograd import aot_export_module
from torch._higher_order_ops.effects import with_effects
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    find_library_location,
    IS_FBCODE,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)


@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't support")
class TestWithEffects(TestCase):
    def setUp(self):
        if IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        elif IS_SANDCASTLE or IS_FBCODE:
            torch.ops.load_library(
                "//caffe2/test/cpp/jit:test_custom_class_registrations"
            )
        elif IS_WINDOWS:
            lib_file_path = find_library_location("torchbind_test.dll")
            torch.ops.load_library(str(lib_file_path))
        else:
            lib_file_path = find_library_location("libtorchbind_test.so")
            torch.ops.load_library(str(lib_file_path))

    def test_print(self):
        class M(torch.nn.Module):
            def forward(self, x):
                torch.ops.aten._print("moo")
                res = x + x
                torch.ops.aten._print("moo")
                return (res,)

        inputs = (torch.randn(3),)

        # Without functionalization, print should just appear in the graph directly
        gm = make_fx(M())(*inputs)
        FileCheck().check_count("torch.ops.aten._print.default", 2, exactly=True).run(
            gm.code
        )

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops.aten._print.default, 'moo'); arg0_1 = None
    getitem = with_effects[0]; with_effects = None
    add = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
    getitem_2 = with_effects_1[0]; with_effects_1 = None
    return (getitem_2, add)""",
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)

    @unittest.expectedFailure  # Will enable this once we enable tokens in export
    def test_torchbind_custom_op(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return (x + torch.ops._TorchScriptTesting.takes_foo(self.attr, x),)

        with enable_torchbind_tracing():
            gm, gs = aot_export_module(M(), (torch.ones(2, 3),), trace_joint=False)

        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1):
    _tensor_constant0 = self._tensor_constant0
    takes_foo = torch.ops._TorchScriptTesting.takes_foo.default(_tensor_constant0, arg0_1); _tensor_constant0 = None
    add = torch.ops.aten.add.Tensor(arg0_1, takes_foo); arg0_1 = takes_foo = None
    return (add,)""",  # noqa: B950
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)

    def test_print_with_buffer_mutations(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buf", torch.ones(3))

            def forward(self, x):
                torch.ops.aten._print("moo")
                res = x + x
                self.buf.add_(res)
                res = self.buf + x
                torch.ops.aten._print("moo")
                return (res,)

        inputs = (torch.randn(3),)

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops.aten._print.default, 'moo'); arg0_1 = None
    getitem = with_effects[0]; with_effects = None
    add = torch.ops.aten.add.Tensor(arg2_1, arg2_1)
    add_1 = torch.ops.aten.add.Tensor(arg1_1, add); arg1_1 = add = None
    add_2 = torch.ops.aten.add.Tensor(add_1, arg2_1); arg2_1 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
    getitem_2 = with_effects_1[0]; with_effects_1 = None
    return (getitem_2, add_1, add_2)""",
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)
        self.assertEqual(len(gs.buffers_to_mutate), 1)

    def test_print_with_input_mutations(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                torch.ops.aten._print("moo")
                res = x + x
                x.add_(res)
                res = x + x
                torch.ops.aten._print("moo")
                return (res,)

        inputs = (torch.randn(3),)

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)
        self.assertEqual(len(gs.user_inputs_to_mutate), 1)

    def test_alias_op(self):
        def f(token, x):
            token, out = with_effects(token, torch.ops.aten.absolute_.default, x)
            return token, out

        with self.assertRaisesRegex(
            AssertionError, r"Ops with aliasing is not supported"
        ):
            make_fx(f)(torch.tensor([]), torch.tensor(4))

    def test_compile_aot_eager(self):
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        inputs = (torch.randn(2, 3),)

        res = torch.compile(f, backend="aot_eager")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))

    @skipIfTorchDynamo(
        "We're testing if the test works with inductor, which it currently"
        "doesn't, so we expectedFailure-d the test, but the Dynamo tests"
        "override the backend, causing an unexpected success"
    )
    @unittest.expectedFailure  # NYI: AssertionError: with_effects is not an OpOverload
    def test_compile_inductor(self):
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        inputs = (torch.randn(2, 3),)

        res = torch.compile(f, backend="inductor")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))

    def test_compile_aot_eager_requires_grad(self):
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        inputs = (torch.randn(2, 3, requires_grad=True),)

        res = torch.compile(f, backend="aot_eager")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))


if __name__ == "__main__":
    run_tests()
```
```diff
@@ -105,7 +105,8 @@ def run_functionalized_fw_and_collect_metadata(
 
         # It doesn't matter if we run this under predispatch or not because it is
         # only for figuring out metadata
-        with disable_above, FunctionalTensorMode():
+        mode = FunctionalTensorMode(_allow_token_discovery=True)
+        with disable_above, mode:
             # precondition: The passed in function already handles unflattening inputs + flattening outputs
             flat_f_args = pytree.tree_map(_to_fun, flat_args)
             flat_f_outs = f(*flat_f_args)
```
```diff
@@ -618,6 +619,7 @@ from a multi-output view call"
             subclass_tangent_meta=create_subclass_meta(traced_tangents),
             is_train=is_train,
             grad_enabled_mutation=grad_enabled_mutation,
+            tokens=mode._tokens,
         )
         return metadata
```
```diff
@@ -374,9 +374,10 @@ def create_graph_signature(
     graph_output_names = _graph_output_names(fx_g)
 
     num_params_buffers = len(param_names) + len(buffer_names)
+    num_tokens = len(fw_metadata.tokens)
     # We have enough restrictions on the graph (no de-duping, synthetic bases, etc),
     # Such that # graph inps = # user inps + # params + # buffers
-    num_user_args = len(graph_input_names) - num_params_buffers
+    num_user_args = len(graph_input_names) - num_params_buffers - num_tokens
 
     if trace_joint:
         assert num_user_fw_outs is not None
```
```diff
@@ -411,7 +412,9 @@ def create_graph_signature(
     else:
         backward_signature = None
         num_user_fw_outs = (
-            len(graph_output_names) - fw_metadata.num_mutated_inp_runtime_indices
+            len(graph_output_names)
+            - fw_metadata.num_mutated_inp_runtime_indices
+            - num_tokens
         )
 
     return GraphSignature.from_tracing_metadata(
```
```diff
@@ -222,6 +222,9 @@ def aot_dispatch_autograd(
             + inner_meta.num_outputs
             + inner_meta.num_intermediate_bases
             + inner_meta.num_outputs_rng_offset
+            + len(
+                fw_metadata.tokens
+            )  # See Note [Side-Effectful Tokens in AOTAutograd]
         )
         fw_module, bw_module = aot_config.partition_fn(
             fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs
```
```diff
@@ -493,7 +496,7 @@ def aot_dispatch_autograd(
                 args = (*args, seed, offset)
             # There is a pretty complicated calling convention around what the compiled fw returns.
             # The full list of outputs and their relative order is:
-            # (*mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints)
+            # (*tokens, *mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints)
             # - Note that in the synthetic bases case, mutated_inputs will correspond to an updated version
             # of the original view, and not the synthetic base
 
```
```diff
@@ -514,6 +517,7 @@ def aot_dispatch_autograd(
             num_mutated_runtime_inps = (
                 CompiledFunction.metadata.num_mutated_inp_runtime_indices
             )
+            num_tokens = len(CompiledFunction.metadata.tokens)
             num_forward_returns = CompiledFunction.metadata.num_forward_returns
             num_forward = CompiledFunction.metadata.num_forward
 
```
```diff
@@ -538,7 +542,7 @@ def aot_dispatch_autograd(
             ), str([type(x) for x in symint_outs])
             ctx.symints = symint_outs
 
-            raw_returns = fw_outs[0:num_forward_returns]
+            raw_returns = fw_outs[0 : num_forward_returns + num_tokens]
 
             # Wrap all autograd.Function.forward() outputs that are aliases
             # so that autograd.Function doesn't treat them as tensors
```
```diff
@@ -69,10 +69,15 @@ def create_runtime_wrapper(
     keep_input_mutations: bool,
     disable_amp: bool,
 ):
+    num_tokens = len(runtime_metadata.tokens)
+
     if not hasattr(compiled_fn, "_boxed_call"):
         compiled_fn = make_boxed_func(compiled_fn)
 
     def runtime_wrapper(*args):
+        # Pass in effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
+        args = (*[torch.tensor([])] * num_tokens, *args)
+
         if trace_joint:
             args_ = list(args)
             # See Note [Detaching inputs that never need gradients]
```
```diff
@@ -120,8 +125,12 @@ def create_runtime_wrapper(
             == num_mutated_runtime_inps
             + runtime_metadata.num_outputs
             + num_intermediate_bases
+            + num_tokens
         )
 
+        # Toss out the effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
+        all_outs = all_outs[num_tokens:]
+
         # Step 3: After running the compiled fw, apply updates to mutated inputs
         num_mutations_to_apply = runtime_metadata.num_mutated_inp_runtime_indices
         if num_mutations_to_apply > 0:
```
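
Taken together, the two `create_runtime_wrapper` hunks above mean the runtime wrapper feeds dummy effect tokens into the compiled graph and strips them back out of its outputs. A standalone toy version of that plumbing (illustrative only; `compiled_graph` and `num_tokens` are placeholders, not the real AOTAutograd objects):

```python
import torch


def wrap_with_tokens(compiled_graph, num_tokens):
    # Illustrative sketch of the hunks above: prepend empty-tensor effect
    # tokens on the way in, and drop the returned tokens on the way out,
    # so the caller never sees them.
    def runtime_wrapper(*user_args):
        all_outs = compiled_graph(*([torch.tensor([])] * num_tokens), *user_args)
        return all_outs[num_tokens:]  # (*tokens, *outputs) -> (*outputs,)

    return runtime_wrapper


# Toy "compiled graph" that returns (token, out):
outs = wrap_with_tokens(lambda t, x: (t, x * 2), num_tokens=1)(torch.ones(3))
```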
```diff
@@ -5,7 +5,7 @@ input/output types, metadata, config, function signatures etc.
 
 import collections
 import functools
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
 from typing import Any, Callable, Dict, List, NewType, Optional, Set, Tuple, Union
 
```
```diff
@@ -267,6 +267,11 @@ class ViewAndMutationMeta:
     # raised
     deterministic: Optional[bool] = None
 
+    # Map of effect type (ex. _EffectType.ORDERED) to token. If there are
+    # side-effectful operators, FunctionalTensorMode will populate this
+    # dictionary telling us how many tokens we will need during tracing.
+    tokens: Dict[Any, torch.Tensor] = field(default_factory=dict)
+
     def __post_init__(self):
         # pre-compute the indices of the inputs that are mutated.
         # When keep_input_mutations is set, we don't need to worry about our epilogue
```
```diff
@@ -549,6 +554,9 @@ class GraphSignature:
 
     backward_signature: Optional[BackwardSignature]
 
+    input_tokens: List[GraphInputName]
+    output_tokens: List[GraphOutputName]
+
     @classmethod
     def from_tracing_metadata(
         cls,
```
```diff
@@ -569,35 +577,54 @@ class GraphSignature:
         graph_outputs = graph_output_names
         parameters = list(named_parameters)
         buffers = list(named_buffers)
+        num_tokens = len(view_mutation_metadata.tokens)
 
         # Calling convention assumptions:
-        # (1) graph inputs = (params, buffers, user_inputs)
-        # (2) graph outputs = (mutated_inputs, user_outs, param_gradients)
+        # (1) graph inputs = (input_tokens, params, buffers, user_inputs)
+        # (2) graph outputs = (output_tokens, mutated_inputs, user_outs, param_gradients)
         # (If we are capturing an inference graph, this convention is identical
         # except that param_gradients is empty)
-        user_inputs = graph_inputs[len(parameters) + len(buffers) :]
-        assert num_user_inputs == len(user_inputs)
-        assert len(graph_inputs) == (len(parameters) + len(buffers) + len(user_inputs))
+        # See Note [Side-Effectful Tokens in AOTAutograd] for information on tokens
 
-        inputs_to_parameters = dict(zip(graph_inputs[: len(parameters)], parameters))
+        # Address input calling conventions:
+        start, stop = 0, num_tokens
+        input_tokens = graph_inputs[start:stop]
+
+        start, stop = stop, stop + len(parameters)
+        inputs_to_parameters = dict(zip(graph_inputs[start:stop], parameters))
+
+        start, stop = stop, stop + len(buffers)
         inputs_to_buffers = dict(
             zip(
-                graph_inputs[len(parameters) : len(parameters) + len(buffers)],
+                graph_inputs[start:stop],
                 buffers,
             )
         )
 
-        names = [*parameters, *buffers, *user_inputs]
+        start, stop = stop, stop + num_user_inputs
+        user_inputs = graph_inputs[start:stop]
+
+        # We should've gone through all the inputs now
+        assert len(graph_inputs) - stop == 0
+
+        # Address output calling conventions:
+        start, stop = 0, num_tokens
+        output_tokens = graph_outputs[start:stop]
+
+        names = [*input_tokens, *parameters, *buffers, *user_inputs]
         mutations = []
         for idx, input_info in enumerate(view_mutation_metadata.input_info):
             if input_info.mutates_data:
                 # Only buffers can be mutated, not parameters
                 assert idx >= len(parameters)
-                mutations.append(names[idx])
+                mutations.append(names[idx + num_tokens])
 
         assert len(mutations) == view_mutation_metadata.num_mutated_inp_runtime_indices
 
-        start, stop = 0, view_mutation_metadata.num_mutated_inp_runtime_indices
+        start, stop = (
+            stop,
+            stop + view_mutation_metadata.num_mutated_inp_runtime_indices,
+        )
         outputs_to_mutations = dict(zip(graph_outputs[start:stop], mutations))
 
         user_inputs_to_mutate = {}
```
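
A toy walk-through (hypothetical names and sizes, not real graph inputs) of the slicing convention used in the hunk above, where graph inputs are consumed in the order (input_tokens, parameters, buffers, user_inputs):

```python
# Hypothetical graph input names, just to show the slicing order used above.
graph_inputs = ["token0", "p0", "p1", "b0", "x0", "x1"]
num_tokens, num_params, num_buffers, num_user_inputs = 1, 2, 1, 2

start, stop = 0, num_tokens
input_tokens = graph_inputs[start:stop]    # ["token0"]

start, stop = stop, stop + num_params
parameters = graph_inputs[start:stop]      # ["p0", "p1"]

start, stop = stop, stop + num_buffers
buffers = graph_inputs[start:stop]         # ["b0"]

start, stop = stop, stop + num_user_inputs
user_inputs = graph_inputs[start:stop]     # ["x0", "x1"]

assert len(graph_inputs) - stop == 0       # every input accounted for
```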
```diff
@@ -631,6 +658,8 @@ class GraphSignature:
             in_spec=in_spec,
             out_spec=out_spec,
             backward_signature=backward_signature,
+            input_tokens=input_tokens,  # type: ignore[arg-type]
+            output_tokens=output_tokens,  # type: ignore[arg-type]
         )
 
```
```diff
@@ -23,7 +23,6 @@ from torch import Tensor
 from torch._decomp.decompositions_for_rng import PhiloxStateTracker
 from torch._guards import detect_fake_mode
 from torch._prims_common import CUDARngStateHelper
-from torch._subclasses.functional_tensor import FunctionalTensorMode
 from torch.fx.experimental.symbolic_shapes import definitely_false, sym_eq
 from torch.nn.utils import stateless
 
```
```diff
@@ -350,12 +349,43 @@ def create_functionalized_fn(
         disable_above = torch._C._ExcludeDispatchKeyGuard(
             torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
         )
-        with disable_above, FunctionalTensorMode(aot_config.pre_dispatch):
+
+        # See Note [Side-Effectful Tokens in AOTAutograd]
+        if trace_joint:
+            assert (
+                isinstance(args, tuple)
+                and len(args) == 2
+                and isinstance(args[0], (list, tuple))
+            )
+            tokens = args[0][: len(meta.tokens)]
+            actual_args = args[0][len(meta.tokens) :]
+            args = (actual_args, args[1])
+        else:
+            tokens = args[: len(meta.tokens)]
+            args = args[len(meta.tokens) :]
+        assert all(token.numel() == 0 for token in tokens)
+
+        with disable_above:
             # Wrap inputs into functional wrappers
             f_args = pytree.tree_map(to_fun, args)
+            f_tokens = pytree.tree_map(to_fun, tokens)
+
+            # Populate the current FunctionalTensorMode with the tokens per
+            # operator. See Note [FunctionalTensorMode is Stateful]
+            functional_tensor_mode = (
+                torch.utils._python_dispatch._detect_functional_mode()
+            )
+            assert functional_tensor_mode is not None
+            for i, k in enumerate(meta.tokens.keys()):
+                functional_tensor_mode._tokens[k] = f_tokens[i]
+
             # Run the joint
             f_outs = fn(*f_args)
+
+            # Return both the tokens and the outputs
+            # See Note [Side-Effectful Tokens in AOTAutograd]
+            f_outs = (*functional_tensor_mode._tokens.values(), *f_outs)
 
         if trace_joint:
             # We support a limited amount of mutation of graph inputs during the backward pass.
             # (This is used e.g. by Float8, which needs to update buffers during the backward pass)
```
|
|||||||
# Setup the wrapper for functionalization of rng ops
|
# Setup the wrapper for functionalization of rng ops
|
||||||
helper, args = create_functionalized_rng_ops_wrapper(helper, args, trace_joint)
|
helper, args = create_functionalized_rng_ops_wrapper(helper, args, trace_joint)
|
||||||
|
|
||||||
|
# Additionally pass in tokens as inputs
|
||||||
|
# See Note [Side-Effectful Tokens in AOTAutograd]
|
||||||
|
additional_token_inputs = [torch.tensor([])] * len(meta.tokens)
|
||||||
|
if trace_joint:
|
||||||
|
args = ([*additional_token_inputs, *args[0]], *args[1:])
|
||||||
|
else:
|
||||||
|
args = [*additional_token_inputs, *args]
|
||||||
|
|
||||||
return helper, args
|
return helper, args
|
||||||
|
|
||||||
|
|
||||||
|
```diff
@@ -375,6 +375,31 @@ AOT_COUNTER = itertools.count()
 # To work around this, we view every forward output when creating out tangent
 # tensors so that tangents can never be the same as forward inputs even if
 # forward inputs alias forward outputs.
+
+# Note [Side-Effectful Tokens in AOTAutograd]
+#
+# We allow some side-effectful operators in
+# the post-AOTAutograd (functional) graph, such as prints and torchbind operations.
+# To ensure that these side effects are compatible with future graph passes that
+# assume that the graph is functional, we will thread "effect tokens" to show
+# data dependence between these side-effectful operators. Practically speaking,
+# effect tokens are just dummy values (torch.tensor([])). The graph would look
+# like the following:
+#
+# def gm(self, token0, reader):
+#     token1, frame = with_token(ordered_effect_op, (reader,), token0)
+#     frame = frame * 2
+#     token2, frame2 = with_token(ordered_effect_op, (reader,), token1)
+#     frame2 = frame2 * 2
+#     return token2, frame, frame2
+#
+# We will pass the token as an input to the graph, thread it through
+# side-effectful operators using the `with_effects` higher order operator, and then
+# return the updated token as an output.
+# So the signature of the graph input would look something like
+# (*tokens, *params_buffers, *user_inputs), and the signature of the graph
+# output would look something like (*tokens, *outputs).
+
 #
 #
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
```
torch/_higher_order_ops/effects.py (new file, 206 lines)

```python
from enum import Enum
from typing import Any, Dict, Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)


class _EffectType(Enum):
    ORDERED = "Ordered"


SIDE_EFFECTS: Dict[torch._ops.OpOverload, _EffectType] = {
    torch.ops.aten._print.default: _EffectType.ORDERED,
}


class WithEffects(HigherOrderOperator):
    """
    with_effects(token, op, args, kwargs) -> (new_token, op_results)

    This HOP helps ensure ordering between side effectful ops like prints or ops
    using torchbind objects. This is needed to ensure a traced graph from
    AOTAutograd is functional so that future optimization passes do not reorder
    these operators. This is done through threading "effect tokens" through the
    graph to enforce data dependence between side effectful ops.

    The tokens are basically dummy values (torch.tensor([])). We create a token
    per "effect type", which are enumerated in the _EffectType enum.
    """

    def __init__(self):
        super().__init__("with_effects")

    def __call__(
        self,
        token,
        op: torch._ops.OpOverload,
        *args: Tuple[Any, ...],
        **kwargs: Dict[str, Any],
    ) -> Tuple[Any, ...]:
        assert isinstance(op, torch._ops.OpOverload)
        assert not has_aliasing(op), "Ops with aliasing is not supported"
        assert has_effects(op, args, kwargs)
        assert isinstance(kwargs, dict)
        return super().__call__(token, op, *args, **kwargs)


with_effects = WithEffects()


def has_aliasing(op: torch._ops.OpOverload):
    for arg in op._schema.arguments:
        if arg.alias_info is not None:
            return True
    for arg in op._schema.returns:
        if arg.alias_info is not None:
            return True
    return False


def has_effects(op, args, kwargs) -> bool:
    return (
        isinstance(op, torch._ops.OpOverload)
        and not has_aliasing(op)
        and get_effect_key(op, args, kwargs) is not None
    )


def get_effect_key(op, args, kwargs) -> Optional[_EffectType]:
    if op in SIDE_EFFECTS:
        return SIDE_EFFECTS[op]

    # TODO(angelayi): Enable this when enabling tokens with export -- this will
    # break some existing export tests right now
    # for arg in args:
    #     if isinstance(arg, torch.ScriptObject):
    #         return _EffectType.ORDERED

    return None


@with_effects.py_impl(DispatchKey.CompositeExplicitAutograd)
def with_effects_dense(
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    out = op(*args, **kwargs)
    new_token = torch.tensor([])
    if isinstance(out, tuple):
        return (new_token, *out)
    return (new_token, out)


@with_effects.py_impl(FakeTensorMode)
def with_effects_fake(
    mode,
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    with mode:
        result = with_effects_dense(token, op, *args, **kwargs)
        return result


@with_effects.py_impl(ProxyTorchDispatchMode)
def with_effects_proxy(
    mode,
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    if not mode.enable_tracing:
        return with_effects(token, op, *args, **kwargs)

    with disable_proxy_modes_tracing():
        out = with_effects(token, op, *args, **kwargs)

    proxy_token = mode.tracer.unwrap_proxy(token)
    proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)

    out_proxy = mode.tracer.create_proxy(
        "call_function",
        with_effects,
        (proxy_token, op, *proxy_args),
        proxy_kwargs,
    )
    result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    return result


with_effects.fallthrough(DispatchKey.AutogradCPU)
with_effects.fallthrough(DispatchKey.AutogradCUDA)


def handle_effects(
    allow_token_discovery: bool,
    tokens: Dict[_EffectType, torch.Tensor],
    op: torch._ops.OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """
    Args:
        allow_token_discovery: Whether or not we are discovering tokens. If this
        is true, we will create a token for every side effect type seen that
        does not have a token assigned yet. If this is false, the tokens
        should've all been created ahead of time, so we will error if there is
        no token mapping to every effect type.

        tokens: Map of effect type to tokens. This is to chain operators of the
        same effects together so that they do not get reordered in later
        optimization passes.
    """

    # Get a token. We can't do `tokens.get(op, torch.tensor([]))` because
    # this will create an empty tensor during proxy mode tracing if the token
    # doesn't exist. But the tokens should always exist during proxy mode tracing.
    key = get_effect_key(op, args, kwargs)
    assert key is not None
    if key not in tokens:
        assert allow_token_discovery, f"Could not find a token for effect {key}"
        tokens[key] = torch.tensor([])
    token = tokens[key]

    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    unwrapped_token = ctx.unwrap_tensors([token])[0]  # type: ignore[arg-type]
    unwrapped_args = ctx.unwrap_tensors(args)  # type: ignore[arg-type]
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)  # type: ignore[arg-type]
    with ctx.redispatch_to_next():
        (new_token, *unwrapped_outs) = with_effects(
            unwrapped_token, op, *unwrapped_args, **unwrapped_kwargs  # type: ignore[arg-type]
        )

    if len(op._schema.returns) == 0:
        assert unwrapped_outs[0] is None
        unwrapped_outs = None  # type: ignore[assignment]
    elif len(op._schema.returns) == 1:
        assert len(unwrapped_outs) == 1
        unwrapped_outs = unwrapped_outs[0]
    else:
        assert len(unwrapped_outs) == len(op._schema.returns)

    # Add the newly created token into the tokens map for a following call to
    # use this token.
    wrapped_token = ctx.wrap_tensors(new_token)
    assert isinstance(wrapped_token, torch.Tensor)
    tokens[key] = wrapped_token

    return ctx.wrap_tensors(unwrapped_outs)  # type: ignore[arg-type]
```
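
A small sketch of how the helpers in this new file classify ops, using `aten._print` (the one entry currently in `SIDE_EFFECTS`) and the aliasing check that `test_alias_op` exercises:

```python
import torch
from torch._higher_order_ops.effects import (
    _EffectType,
    get_effect_key,
    has_aliasing,
    has_effects,
)

print_op = torch.ops.aten._print.default
# _print is registered in SIDE_EFFECTS and has no aliasing, so it is treated
# as an ordered side-effectful op and gets wrapped in with_effects.
assert get_effect_key(print_op, (), {}) is _EffectType.ORDERED
assert has_effects(print_op, (), {})

# In-place ops alias their input, so with_effects refuses to wrap them
# (this is what test_alias_op asserts).
assert has_aliasing(torch.ops.aten.absolute_.default)
assert not has_effects(torch.ops.aten.absolute_.default, (), {})
```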
```diff
@@ -1,6 +1,6 @@
 import contextlib
 from abc import ABC, abstractmethod
-from typing import Any, Callable, ContextManager, Optional, Tuple
+from typing import Any, Callable, ContextManager, Dict, Optional, Tuple
 
 import torch
 import torch.utils._pytree as pytree
```
```diff
@@ -215,7 +215,7 @@ class FunctionalTensor(torch.Tensor):
 
 
 class FunctionalTensorMode(TorchDispatchMode):
-    def __init__(self, pre_dispatch=False, export=False):
+    def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=False):
         self.export = export
         self.is_on_stack = False
         self.enter_stack = []
```
```diff
@@ -225,6 +225,18 @@ class FunctionalTensorMode(TorchDispatchMode):
         self.pre_dispatch = pre_dispatch
         # This will be turned off later for pre-dispatch functionalization
         self._dispatch_key = torch._C.DispatchKey.PreDispatch if pre_dispatch else None  # type: ignore[attr-defined]
+        # Map of effect type (ex. _EffectType.ORDERED) to a token. The tokens help keep
+        # track of the ordering between side effectful operations.
+        self._tokens: Dict[Any, torch.Tensor] = {}
+
+        # Functionalization runs twice in AOTAutograd, once in
+        # `run_functionalized_fw_and_collect_metadata` to collect metadata to
+        # see which tensors need to be functionalized and discover how many
+        # tokens we need, and another time in `make_fx` which does the actual
+        # tracing to replace ops with their functional variants and handle
+        # side-effectful ops. In the second stage there should be no token
+        # discovery. This flag distinguishes between the two stages.
+        self._allow_token_discovery = _allow_token_discovery
 
     # No-op if FunctionalTensorMode is already in use
     def __enter__(self):
```
```diff
@@ -338,6 +350,16 @@ class FunctionalTensorMode(TorchDispatchMode):
             )
             return do_auto_functionalize(func, args, kwargs)
 
+        from torch._higher_order_ops.effects import handle_effects, has_effects
+
+        if has_effects(func, args, kwargs):
+            assert not torch._C._dispatch_has_kernel_for_dispatch_key(
+                func.name(), torch._C.DispatchKey.Functionalize
+            )
+            return handle_effects(
+                self._allow_token_discovery, self._tokens, func, args, kwargs
+            )
+
         args_unwrapped, kwargs_unwrapped = pytree.tree_map_only(
             FunctionalTensor, unwrap, (args, kwargs)
         )
```
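
To summarize the two-stage flow described in the comments above (a sketch assuming this commit is applied; it only constructs the modes and is not a full AOTAutograd run): token discovery is enabled only for the metadata-collection pass.

```python
import torch
from torch._subclasses.functional_tensor import FunctionalTensorMode

# Stage 1 (run_functionalized_fw_and_collect_metadata): side-effectful ops may
# be seen for the first time, so new tokens may be discovered and recorded
# in mode._tokens.
analysis_mode = FunctionalTensorMode(_allow_token_discovery=True)

# Stage 2 (the actual make_fx trace): every needed token is already known and
# pre-populated, so the default _allow_token_discovery=False applies and a
# missing token would trip the assertion in handle_effects.
tracing_mode = FunctionalTensorMode()
```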