[hoo] Add with_effects to handle side effectful ops (#120296)

Proposal: https://docs.google.com/document/d/179QyhicGzTXJ5jvTAoAosP_Nzgf3PpgZwU_E3VV9PlM/edit#heading=h.bnm38nu3yfno
Implementation discussion: https://docs.google.com/document/d/179QyhicGzTXJ5jvTAoAosP_Nzgf3PpgZwU_E3VV9PlM/edit#heading=h.bj61609o1buq

Result with print:
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %with_effects : [num_users=1] = call_function[target=torch._higher_order_ops.effects.with_effects](args = (%arg0_1, aten.print.default, moo), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 0), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg1_1, %arg1_1), kwargs = {})
    return [getitem, add]
```
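
For reference, a graph like the one above comes from functionalizing a module that calls a print op. Below is a minimal sketch of how to reproduce it, based on the `test_print` test added in this PR (it uses `torch.ops.aten._print`; the exact placeholder names in the printed graph may differ):

```
import torch
from torch._functorch.aot_autograd import aot_export_module


class M(torch.nn.Module):
    def forward(self, x):
        torch.ops.aten._print("moo")  # side-effectful op, ordered via an effect token
        return (x + x,)


# trace_joint=False exports an inference graph; the extra first input/output
# is the effect token threaded through with_effects.
gm, gs = aot_export_module(M(), (torch.randn(3),), trace_joint=False)
print(gm.graph)  # the print op appears wrapped in with_effects
print(len(gs.input_tokens), len(gs.output_tokens))  # expected: 1 1
```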

Follow-ups:
* Add handling to auto_functionalize
* Add support for tokens on the export side
* Add support for tokens on the inductor side

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120296
Approved by: https://github.com/zou3519
Author: angelayi
Authored: 2024-03-04 16:28:02 -08:00
Committed by: PyTorch MergeBot
Parent: 29976519a1
Commit: a7e93c341f
10 changed files with 566 additions and 20 deletions

View File

@@ -0,0 +1,208 @@
# Owner(s): ["module: functorch"]
import unittest

import torch
import torch._dynamo
import torch._inductor
import torch._inductor.decomposition
from torch._functorch.aot_autograd import aot_export_module
from torch._higher_order_ops.effects import with_effects
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    find_library_location,
    IS_FBCODE,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)


@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestWithEffects(TestCase):
    def setUp(self):
        if IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        elif IS_SANDCASTLE or IS_FBCODE:
            torch.ops.load_library(
                "//caffe2/test/cpp/jit:test_custom_class_registrations"
            )
        elif IS_WINDOWS:
            lib_file_path = find_library_location("torchbind_test.dll")
            torch.ops.load_library(str(lib_file_path))
        else:
            lib_file_path = find_library_location("libtorchbind_test.so")
            torch.ops.load_library(str(lib_file_path))

    def test_print(self):
        class M(torch.nn.Module):
            def forward(self, x):
                torch.ops.aten._print("moo")
                res = x + x
                torch.ops.aten._print("moo")
                return (res,)

        inputs = (torch.randn(3),)

        # Without functionalization, print should just appear in the graph directly
        gm = make_fx(M())(*inputs)
        FileCheck().check_count("torch.ops.aten._print.default", 2, exactly=True).run(
            gm.code
        )

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops.aten._print.default, 'moo'); arg0_1 = None
    getitem = with_effects[0]; with_effects = None
    add = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
    getitem_2 = with_effects_1[0]; with_effects_1 = None
    return (getitem_2, add)""",
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)

    @unittest.expectedFailure  # Will enable this once we enable tokens in export
    def test_torchbind_custom_op(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return (x + torch.ops._TorchScriptTesting.takes_foo(self.attr, x),)

        with enable_torchbind_tracing():
            gm, gs = aot_export_module(M(), (torch.ones(2, 3),), trace_joint=False)

        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1):
    _tensor_constant0 = self._tensor_constant0
    takes_foo = torch.ops._TorchScriptTesting.takes_foo.default(_tensor_constant0, arg0_1); _tensor_constant0 = None
    add = torch.ops.aten.add.Tensor(arg0_1, takes_foo); arg0_1 = takes_foo = None
    return (add,)""",  # noqa: B950
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)

    def test_print_with_buffer_mutations(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buf", torch.ones(3))

            def forward(self, x):
                torch.ops.aten._print("moo")
                res = x + x
                self.buf.add_(res)
                res = self.buf + x
                torch.ops.aten._print("moo")
                return (res,)

        inputs = (torch.randn(3),)

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops.aten._print.default, 'moo'); arg0_1 = None
    getitem = with_effects[0]; with_effects = None
    add = torch.ops.aten.add.Tensor(arg2_1, arg2_1)
    add_1 = torch.ops.aten.add.Tensor(arg1_1, add); arg1_1 = add = None
    add_2 = torch.ops.aten.add.Tensor(add_1, arg2_1); arg2_1 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
    getitem_2 = with_effects_1[0]; with_effects_1 = None
    return (getitem_2, add_1, add_2)""",
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)
        self.assertEqual(len(gs.buffers_to_mutate), 1)

    def test_print_with_input_mutations(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                torch.ops.aten._print("moo")
                res = x + x
                x.add_(res)
                res = x + x
                torch.ops.aten._print("moo")
                return (res,)

        inputs = (torch.randn(3),)

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)
        self.assertEqual(len(gs.user_inputs_to_mutate), 1)

    def test_alias_op(self):
        def f(token, x):
            token, out = with_effects(token, torch.ops.aten.absolute_.default, x)
            return token, out

        with self.assertRaisesRegex(
            AssertionError, r"Ops with aliasing is not supported"
        ):
            make_fx(f)(torch.tensor([]), torch.tensor(4))

    def test_compile_aot_eager(self):
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        inputs = (torch.randn(2, 3),)

        res = torch.compile(f, backend="aot_eager")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))

    @skipIfTorchDynamo(
        "We're testing if the test works with inductor, which it currently "
        "doesn't, so we expectedFailure-d the test, but the Dynamo tests "
        "override the backend, causing an unexpected success"
    )
    @unittest.expectedFailure  # NYI: AssertionError: with_effects is not an OpOverload
    def test_compile_inductor(self):
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        inputs = (torch.randn(2, 3),)

        res = torch.compile(f, backend="inductor")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))

    def test_compile_aot_eager_requires_grad(self):
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        inputs = (torch.randn(2, 3, requires_grad=True),)

        res = torch.compile(f, backend="aot_eager")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))


if __name__ == "__main__":
    run_tests()

View File

@@ -105,7 +105,8 @@ def run_functionalized_fw_and_collect_metadata(
         # It doesn't matter if we run this under predispatch or not because it is
         # only for figuring out metadata
-        with disable_above, FunctionalTensorMode():
+        mode = FunctionalTensorMode(_allow_token_discovery=True)
+        with disable_above, mode:
             # precondition: The passed in function already handles unflattening inputs + flattening outputs
             flat_f_args = pytree.tree_map(_to_fun, flat_args)
             flat_f_outs = f(*flat_f_args)
@@ -618,6 +619,7 @@ from a multi-output view call"
             subclass_tangent_meta=create_subclass_meta(traced_tangents),
             is_train=is_train,
             grad_enabled_mutation=grad_enabled_mutation,
+            tokens=mode._tokens,
         )
         return metadata

View File

@@ -374,9 +374,10 @@ def create_graph_signature(
     graph_output_names = _graph_output_names(fx_g)

     num_params_buffers = len(param_names) + len(buffer_names)
+    num_tokens = len(fw_metadata.tokens)

     # We have enough restrictions on the graph (no de-duping, synthetic bases, etc),
     # Such that # graph inps = # user inps + # params + # buffers
-    num_user_args = len(graph_input_names) - num_params_buffers
+    num_user_args = len(graph_input_names) - num_params_buffers - num_tokens

     if trace_joint:
         assert num_user_fw_outs is not None
@@ -411,7 +412,9 @@ def create_graph_signature(
     else:
         backward_signature = None
         num_user_fw_outs = (
-            len(graph_output_names) - fw_metadata.num_mutated_inp_runtime_indices
+            len(graph_output_names)
+            - fw_metadata.num_mutated_inp_runtime_indices
+            - num_tokens
         )

     return GraphSignature.from_tracing_metadata(

View File

@@ -222,6 +222,9 @@ def aot_dispatch_autograd(
             + inner_meta.num_outputs
             + inner_meta.num_intermediate_bases
             + inner_meta.num_outputs_rng_offset
+            + len(
+                fw_metadata.tokens
+            )  # See Note [Side-Effectful Tokens in AOTAutograd]
         )
         fw_module, bw_module = aot_config.partition_fn(
             fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs
@@ -493,7 +496,7 @@ def aot_dispatch_autograd(
                 args = (*args, seed, offset)
             # There is a pretty complicated calling convention around what the compiled fw returns.
             # The full list of outputs and their relative order is:
-            # (*mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints)
+            # (*tokens, *mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints)
             # - Note that in the synthetic bases case, mutated_inputs will correspond to an updated version
             # of the original view, and not the synthetic base
@@ -514,6 +517,7 @@ def aot_dispatch_autograd(
             num_mutated_runtime_inps = (
                 CompiledFunction.metadata.num_mutated_inp_runtime_indices
             )
+            num_tokens = len(CompiledFunction.metadata.tokens)
             num_forward_returns = CompiledFunction.metadata.num_forward_returns
             num_forward = CompiledFunction.metadata.num_forward
@@ -538,7 +542,7 @@ def aot_dispatch_autograd(
                 ), str([type(x) for x in symint_outs])
                 ctx.symints = symint_outs

-            raw_returns = fw_outs[0:num_forward_returns]
+            raw_returns = fw_outs[0 : num_forward_returns + num_tokens]

             # Wrap all autograd.Function.forward() outputs that are aliases
             # so that autograd.Function doesn't treat them as tensors

View File

@@ -69,10 +69,15 @@ def create_runtime_wrapper(
     keep_input_mutations: bool,
     disable_amp: bool,
 ):
+    num_tokens = len(runtime_metadata.tokens)
+
     if not hasattr(compiled_fn, "_boxed_call"):
         compiled_fn = make_boxed_func(compiled_fn)

     def runtime_wrapper(*args):
+        # Pass in effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
+        args = (*[torch.tensor([])] * num_tokens, *args)
+
         if trace_joint:
             args_ = list(args)
             # See Note [Detaching inputs that never need gradients]
@@ -120,8 +125,12 @@ def create_runtime_wrapper(
                 == num_mutated_runtime_inps
                 + runtime_metadata.num_outputs
                 + num_intermediate_bases
+                + num_tokens
             )

+        # Toss out the effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
+        all_outs = all_outs[num_tokens:]
+
        # Step 3: After running the compiled fw, apply updates to mutated inputs
         num_mutations_to_apply = runtime_metadata.num_mutated_inp_runtime_indices
         if num_mutations_to_apply > 0:

View File

@@ -5,7 +5,7 @@ input/output types, metadata, config, function signatures etc.
 import collections
 import functools
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
 from typing import Any, Callable, Dict, List, NewType, Optional, Set, Tuple, Union
@@ -267,6 +267,11 @@ class ViewAndMutationMeta:
     # raised
     deterministic: Optional[bool] = None

+    # Map of effect type (ex. _EffectType.ORDERED) to token. If there are
+    # side-effectful operators, FunctionalTensorMode will populate this
+    # dictionary telling us how many tokens we will need during tracing.
+    tokens: Dict[Any, torch.Tensor] = field(default_factory=dict)
+
     def __post_init__(self):
         # pre-compute the indices of the inputs that are mutated.
         # When keep_input_mutations is set, we don't need to worry about our epilogue
@@ -549,6 +554,9 @@ class GraphSignature:
     backward_signature: Optional[BackwardSignature]

+    input_tokens: List[GraphInputName]
+    output_tokens: List[GraphOutputName]
+
     @classmethod
     def from_tracing_metadata(
         cls,
@@ -569,35 +577,54 @@ class GraphSignature:
         graph_outputs = graph_output_names
         parameters = list(named_parameters)
         buffers = list(named_buffers)
+        num_tokens = len(view_mutation_metadata.tokens)

         # Calling convention assumptions:
-        # (1) graph inputs = (params, buffers, user_inputs)
-        # (2) graph outputs = (mutated_inputs, user_outs, param_gradients)
+        # (1) graph inputs = (input_tokens, params, buffers, user_inputs)
+        # (2) graph outputs = (output_tokens, mutated_inputs, user_outs, param_gradients)
         # (If we are capturing an inference graph, this convention is identical
         # except that param_gradients is empty)
+        # See Note [Side-Effectful Tokens in AOTAutograd] for information on tokens

-        user_inputs = graph_inputs[len(parameters) + len(buffers) :]
-        assert num_user_inputs == len(user_inputs)
-        assert len(graph_inputs) == (len(parameters) + len(buffers) + len(user_inputs))
+        # Address input calling conventions:
+        start, stop = 0, num_tokens
+        input_tokens = graph_inputs[start:stop]

-        inputs_to_parameters = dict(zip(graph_inputs[: len(parameters)], parameters))
+        start, stop = stop, stop + len(parameters)
+        inputs_to_parameters = dict(zip(graph_inputs[start:stop], parameters))
+
+        start, stop = stop, stop + len(buffers)
         inputs_to_buffers = dict(
             zip(
-                graph_inputs[len(parameters) : len(parameters) + len(buffers)],
+                graph_inputs[start:stop],
                 buffers,
             )
         )

-        names = [*parameters, *buffers, *user_inputs]
+        start, stop = stop, stop + num_user_inputs
+        user_inputs = graph_inputs[start:stop]
+
+        # We should've gone through all the inputs now
+        assert len(graph_inputs) - stop == 0
+
+        # Address output calling conventions:
+        start, stop = 0, num_tokens
+        output_tokens = graph_outputs[start:stop]
+
+        names = [*input_tokens, *parameters, *buffers, *user_inputs]
         mutations = []
         for idx, input_info in enumerate(view_mutation_metadata.input_info):
             if input_info.mutates_data:
                 # Only buffers can be mutated, not parameters
                 assert idx >= len(parameters)
-                mutations.append(names[idx])
+                mutations.append(names[idx + num_tokens])

         assert len(mutations) == view_mutation_metadata.num_mutated_inp_runtime_indices

-        start, stop = 0, view_mutation_metadata.num_mutated_inp_runtime_indices
+        start, stop = (
+            stop,
+            stop + view_mutation_metadata.num_mutated_inp_runtime_indices,
+        )

         outputs_to_mutations = dict(zip(graph_outputs[start:stop], mutations))
         user_inputs_to_mutate = {}
@@ -631,6 +658,8 @@ class GraphSignature:
             in_spec=in_spec,
             out_spec=out_spec,
             backward_signature=backward_signature,
+            input_tokens=input_tokens,  # type: ignore[arg-type]
+            output_tokens=output_tokens,  # type: ignore[arg-type]
         )

View File

@@ -23,7 +23,6 @@ from torch import Tensor
 from torch._decomp.decompositions_for_rng import PhiloxStateTracker
 from torch._guards import detect_fake_mode
 from torch._prims_common import CUDARngStateHelper
-from torch._subclasses.functional_tensor import FunctionalTensorMode
 from torch.fx.experimental.symbolic_shapes import definitely_false, sym_eq
 from torch.nn.utils import stateless
@@ -350,12 +349,43 @@ def create_functionalized_fn(
         disable_above = torch._C._ExcludeDispatchKeyGuard(
             torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
         )
-        with disable_above, FunctionalTensorMode(aot_config.pre_dispatch):
+
+        # See Note [Side-Effectful Tokens in AOTAutograd]
+        if trace_joint:
+            assert (
+                isinstance(args, tuple)
+                and len(args) == 2
+                and isinstance(args[0], (list, tuple))
+            )
+            tokens = args[0][: len(meta.tokens)]
+            actual_args = args[0][len(meta.tokens) :]
+            args = (actual_args, args[1])
+        else:
+            tokens = args[: len(meta.tokens)]
+            args = args[len(meta.tokens) :]
+        assert all(token.numel() == 0 for token in tokens)
+
+        with disable_above:
             # Wrap inputs into functional wrappers
             f_args = pytree.tree_map(to_fun, args)
+            f_tokens = pytree.tree_map(to_fun, tokens)
+
+            # Populate the current FunctionalTensorMode with the tokens per
+            # operator. See Note [FunctionalTensorMode is Stateful]
+            functional_tensor_mode = (
+                torch.utils._python_dispatch._detect_functional_mode()
+            )
+            assert functional_tensor_mode is not None
+            for i, k in enumerate(meta.tokens.keys()):
+                functional_tensor_mode._tokens[k] = f_tokens[i]

             # Run the joint
             f_outs = fn(*f_args)

+            # Return both the tokens and the outputs
+            # See Note [Side-Effectful Tokens in AOTAutograd]
+            f_outs = (*functional_tensor_mode._tokens.values(), *f_outs)
+
         if trace_joint:
             # We support a limited amount of mutation of graph inputs during the backward pass.
             # (This is used e.g. by Float8, which needs to update buffers during the backward pass)
@@ -470,6 +500,14 @@ def create_functionalized_fn(
     # Setup the wrapper for functionalization of rng ops
     helper, args = create_functionalized_rng_ops_wrapper(helper, args, trace_joint)

+    # Additionally pass in tokens as inputs
+    # See Note [Side-Effectful Tokens in AOTAutograd]
+    additional_token_inputs = [torch.tensor([])] * len(meta.tokens)
+    if trace_joint:
+        args = ([*additional_token_inputs, *args[0]], *args[1:])
+    else:
+        args = [*additional_token_inputs, *args]
+
     return helper, args

View File

@@ -375,6 +375,31 @@ AOT_COUNTER = itertools.count()
 # To work around this, we view every forward output when creating out tangent
 # tensors so that tangents can never be the same as forward inputs even if
 # forward inputs alias forward outputs.
+
+# Note [Side-Effectful Tokens in AOTAutograd]
+#
+# We allow some side-effectful operators in the post-AOTAutograd (functional)
+# graph, such as prints and torchbind operations. To ensure that these
+# side effects are compatible with future graph passes that assume the graph
+# is functional, we thread "effect tokens" through the graph to express data
+# dependence between the side-effectful operators. Practically speaking,
+# effect tokens are just dummy values (torch.tensor([])). The graph would look
+# like the following:
+#
+# def gm(self, token0, reader):
+#    token1, frame = with_token(ordered_effect_op, (reader,), token0)
+#    frame = frame * 2
+#    token2, frame2 = with_token(ordered_effect_op, (reader,), token1)
+#    frame2 = frame2 * 2
+#    return token2, frame, frame2
+#
+# We pass the token as an input to the graph, thread it through the
+# side-effectful operators using the `with_effects` higher-order operator, and
+# then return the updated token as an output.
+# So the graph input signature looks something like
+# (*tokens, *params_buffers, *user_inputs), and the graph output signature
+# looks something like (*tokens, *outputs).

 #
 #
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@@ -0,0 +1,206 @@
from enum import Enum
from typing import Any, Dict, Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)


class _EffectType(Enum):
    ORDERED = "Ordered"


SIDE_EFFECTS: Dict[torch._ops.OpOverload, _EffectType] = {
    torch.ops.aten._print.default: _EffectType.ORDERED,
}


class WithEffects(HigherOrderOperator):
    """
    with_effects(token, op, args, kwargs) -> (new_token, op_results)

    This HOP helps ensure ordering between side effectful ops like prints or ops
    using torchbind objects. This is needed to ensure a traced graph from
    AOTAutograd is functional so that future optimization passes do not reorder
    these operators. This is done through threading "effect tokens" through the
    graph to enforce data dependence between side effectful ops.

    The tokens are basically dummy values (torch.tensor([])). We create a token
    per "effect type", which are enumerated in the _EffectType enum.
    """

    def __init__(self):
        super().__init__("with_effects")

    def __call__(
        self,
        token,
        op: torch._ops.OpOverload,
        *args: Tuple[Any, ...],
        **kwargs: Dict[str, Any],
    ) -> Tuple[Any, ...]:
        assert isinstance(op, torch._ops.OpOverload)
        assert not has_aliasing(op), "Ops with aliasing is not supported"
        assert has_effects(op, args, kwargs)
        assert isinstance(kwargs, dict)
        return super().__call__(token, op, *args, **kwargs)


with_effects = WithEffects()


def has_aliasing(op: torch._ops.OpOverload):
    for arg in op._schema.arguments:
        if arg.alias_info is not None:
            return True
    for arg in op._schema.returns:
        if arg.alias_info is not None:
            return True
    return False


def has_effects(op, args, kwargs) -> bool:
    return (
        isinstance(op, torch._ops.OpOverload)
        and not has_aliasing(op)
        and get_effect_key(op, args, kwargs) is not None
    )


def get_effect_key(op, args, kwargs) -> Optional[_EffectType]:
    if op in SIDE_EFFECTS:
        return SIDE_EFFECTS[op]

    # TODO(angelayi): Enable this when enabling tokens with export -- this will
    # break some existing export tests right now
    # for arg in args:
    #     if isinstance(arg, torch.ScriptObject):
    #         return _EffectType.ORDERED

    return None


@with_effects.py_impl(DispatchKey.CompositeExplicitAutograd)
def with_effects_dense(
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    out = op(*args, **kwargs)
    new_token = torch.tensor([])
    if isinstance(out, tuple):
        return (new_token, *out)
    return (new_token, out)


@with_effects.py_impl(FakeTensorMode)
def with_effects_fake(
    mode,
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    with mode:
        result = with_effects_dense(token, op, *args, **kwargs)
        return result


@with_effects.py_impl(ProxyTorchDispatchMode)
def with_effects_proxy(
    mode,
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    if not mode.enable_tracing:
        return with_effects(token, op, *args, **kwargs)

    with disable_proxy_modes_tracing():
        out = with_effects(token, op, *args, **kwargs)

    proxy_token = mode.tracer.unwrap_proxy(token)
    proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)

    out_proxy = mode.tracer.create_proxy(
        "call_function",
        with_effects,
        (proxy_token, op, *proxy_args),
        proxy_kwargs,
    )
    result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    return result


with_effects.fallthrough(DispatchKey.AutogradCPU)
with_effects.fallthrough(DispatchKey.AutogradCUDA)


def handle_effects(
    allow_token_discovery: bool,
    tokens: Dict[_EffectType, torch.Tensor],
    op: torch._ops.OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """
    Args:
        allow_token_discovery: Whether or not we are discovering tokens. If this
        is true, we will create a token for every side effect type seen that
        does not have a token assigned yet. If this is false, the tokens
        should've all been created ahead of time, so we will error if there is
        no token mapping to every effect type.

        tokens: Map of effect type to tokens. This is to chain operators of the
        same effects together so that they do not get reordered in later
        optimization passes.
    """
    # Get a token. We can't do `tokens.get(op, torch.tensor([]))` because
    # this will create an empty tensor during proxy mode tracing if the token
    # doesn't exist. But the tokens should always exist during proxy mode tracing.
    key = get_effect_key(op, args, kwargs)
    assert key is not None
    if key not in tokens:
        assert allow_token_discovery, f"Could not find a token for effect {key}"
        tokens[key] = torch.tensor([])
    token = tokens[key]

    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    unwrapped_token = ctx.unwrap_tensors([token])[0]  # type: ignore[arg-type]
    unwrapped_args = ctx.unwrap_tensors(args)  # type: ignore[arg-type]
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)  # type: ignore[arg-type]
    with ctx.redispatch_to_next():
        (new_token, *unwrapped_outs) = with_effects(
            unwrapped_token, op, *unwrapped_args, **unwrapped_kwargs  # type: ignore[arg-type]
        )

    if len(op._schema.returns) == 0:
        assert unwrapped_outs[0] is None
        unwrapped_outs = None  # type: ignore[assignment]
    elif len(op._schema.returns) == 1:
        assert len(unwrapped_outs) == 1
        unwrapped_outs = unwrapped_outs[0]
    else:
        assert len(unwrapped_outs) == len(op._schema.returns)

    # Add the newly created token into the tokens map for a following call to
    # use this token.
    wrapped_token = ctx.wrap_tensors(new_token)
    assert isinstance(wrapped_token, torch.Tensor)
    tokens[key] = wrapped_token

    return ctx.wrap_tensors(unwrapped_outs)  # type: ignore[arg-type]

View File

@@ -1,6 +1,6 @@
 import contextlib
 from abc import ABC, abstractmethod
-from typing import Any, Callable, ContextManager, Optional, Tuple
+from typing import Any, Callable, ContextManager, Dict, Optional, Tuple

 import torch
 import torch.utils._pytree as pytree
@@ -215,7 +215,7 @@ class FunctionalTensor(torch.Tensor):

 class FunctionalTensorMode(TorchDispatchMode):
-    def __init__(self, pre_dispatch=False, export=False):
+    def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=False):
         self.export = export
         self.is_on_stack = False
         self.enter_stack = []
@@ -225,6 +225,18 @@ class FunctionalTensorMode(TorchDispatchMode):
         self.pre_dispatch = pre_dispatch
         # This will be turned off later for pre-dispatch functionalization
         self._dispatch_key = torch._C.DispatchKey.PreDispatch if pre_dispatch else None  # type: ignore[attr-defined]
+        # Map of effect type (ex. _EffectType.ORDERED) to a token. The tokens help keep
+        # track of the ordering between side effectful operations.
+        self._tokens: Dict[Any, torch.Tensor] = {}
+
+        # Functionalization runs twice in AOTAutograd, once in
+        # `run_functionalized_fw_and_collect_metadata` to collect metadata to
+        # see which tensors need to be functionalized and discover how many
+        # tokens we need, and another time in `make_fx` which does the actual
+        # tracing to replace ops with their functional variants and handle
+        # side-effectful ops. In the second stage there should be no token
+        # discovery. This flag distinguishes between the two stages.
+        self._allow_token_discovery = _allow_token_discovery

     # No-op if FunctionalTensorMode is already in use
     def __enter__(self):
@@ -338,6 +350,16 @@ class FunctionalTensorMode(TorchDispatchMode):
                 )
                 return do_auto_functionalize(func, args, kwargs)

+            from torch._higher_order_ops.effects import handle_effects, has_effects
+
+            if has_effects(func, args, kwargs):
+                assert not torch._C._dispatch_has_kernel_for_dispatch_key(
+                    func.name(), torch._C.DispatchKey.Functionalize
+                )
+                return handle_effects(
+                    self._allow_token_discovery, self._tokens, func, args, kwargs
+                )
+
         args_unwrapped, kwargs_unwrapped = pytree.tree_map_only(
             FunctionalTensor, unwrap, (args, kwargs)
         )