[aotd] Support HOP effects in backward (#132638)

Support for effectful operations in backward:

1/ AOTD collects metadata from the forward fn only, so the backward may use effectful ops that never appeared in the forward => allow token discovery while tracing the joint function.

FunctionalTensorMode holds `_tokens`; in the joint function, after tracing the forward, we memoize `_tokens` as `_tokens_forward_output` and reset `_tokens` before tracing the backward.
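
Roughly, the joint tracing path now does the following (a condensed sketch of the `create_joint` change; see the corresponding hunk below):

```
import torch

# Sketch: keep forward and backward token chains independent.
functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
    torch._C._TorchDispatchModeKey.FUNCTIONAL
)
if functional_tensor_mode is not None:
    # Remember the forward result tokens, then clear _tokens so every token
    # used in backward is freshly discovered while tracing the joint function.
    functional_tensor_mode._tokens_forward_output = functional_tensor_mode._tokens
    functional_tensor_mode._tokens = {}
# ... then trace the backward under set_partitioner_tag_is_backward() ...
```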

2/ Forward tokens are added as primal inputs in EffectTokensWrapper.
Tokens that will be used in backward end up among the partitioner's saved values; we do not control at which positions they land in the forward outputs.
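
At runtime the wrapper simply prepends placeholder values for the forward tokens (a sketch of the `EffectTokensWrapper.inner_fn` path in the runtime-wrappers hunk below; the helper name is illustrative):

```
from typing import Any, List

def prepend_forward_tokens(args: List[Any], num_tokens: int) -> List[Any]:
    # Forward effect tokens are passed positionally before the real inputs;
    # None placeholders are enough because tokens only exist to order
    # side-effectful ops, not to carry data.
    if num_tokens == 0:
        return args
    return [*([None] * num_tokens), *args]
```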

3/ If new tokens are discovered in backward while tracing joint_fn, they are manually appended to the end of the primals of the resulting graph (see _aot_autograd/utils.py).
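
When a token is first discovered while the backward is being traced, the proxy tracer is patched to add a `tangents_token` placeholder (a condensed sketch of the `handle_effects` change in the `torch._higher_order_ops.effects` hunk below; the helper name is illustrative):

```
import torch
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    track_tensor_tree,
)

def discover_backward_token(tokens, key, new_token_tensor):
    # A token first seen during backward tracing becomes a new
    # "tangents_token" placeholder input of the joint graph.
    proxy_tensor_mode = torch._C._get_dispatch_mode(
        torch._C._TorchDispatchModeKey.PROXY
    )
    if proxy_tensor_mode is None:
        tokens[key] = new_token_tensor()
        return
    tracer = proxy_tensor_mode.tracer
    with disable_proxy_modes_tracing():
        token_tensor = new_token_tensor()
    token_proxy = tracer.create_proxy(
        "placeholder", "tangents_token", (), {}, name="tangents_token"
    )
    track_tensor_tree(token_tensor, token_proxy, constant=None, tracer=tracer)
    tokens[key] = token_tensor
```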

4/ All effectful ops traced during backward are marked with the 'must_be_in_backward' partitioner_tag, to prevent the partitioner from placing them in the forward.

For that, functional_tensor_mode gets a new optional state `self._effects_partitioner_tag` for effectful ops, set after tracing the forward.

There are additional changes in the partitioner to improve the handling of 'must_be_in_backward'.
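
On the partitioner side, the check condenses to the predicate below (mirroring the partitioner hunk further down; the absolute import path of `is_with_effects` is assumed from the relative import added in the diff):

```
import torch.fx as fx
from torch._functorch._aot_autograd.utils import is_with_effects

def _has_tag_is_backward(node: fx.Node) -> bool:
    return node.meta.get("partitioner_tag", None) == "is_backward"

def _has_tag_must_be_in_backward(node: fx.Node) -> bool:
    return node.meta.get("partitioner_tag", None) == "must_be_in_backward"

def _must_be_in_backward(node: fx.Node) -> bool:
    # Effectful (with_effects) nodes traced under the backward tag must not be
    # hoisted into the forward graph by the partitioner.
    return _has_tag_must_be_in_backward(node) or (
        _has_tag_is_backward(node) and is_with_effects(node)
    )
```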

5/ Unlifting of tokens now runs for both forward and backward.
- Since tokens saved for backward do not land at static positions, we identify the input and output tokens to erase via the inputs and outputs of `with_effects` operations (see the sketch after this list).
- The forward can also have input tokens, discovered in backward, that are not used by any with_effects op in the forward but are saved for backward; those are identified by their position in the forward inputs.
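
For the first point, tokens are identified structurally rather than by fixed position (a condensed sketch of the new logic in `unlift_tokens` in `_aot_autograd/utils.py`; `find_token_nodes` is an illustrative helper, see the full hunk below):

```
import operator
import torch

def is_with_effects(node):
    return (
        node.op == "call_function"
        and node.target == torch.ops.higher_order.with_effects
    )

def find_token_nodes(graph: torch.fx.Graph):
    # Input tokens: placeholders used as the first argument of with_effects.
    # Output tokens: getitem(with_effects, 0) results returned by the graph.
    placeholders = [n for n in graph.nodes if n.op == "placeholder"]
    input_tokens = {
        n.args[0]
        for n in graph.nodes
        if is_with_effects(n) and n.args[0] in placeholders
    }
    output_node = next(iter(graph.find_nodes(op="output")))
    output_tokens = [
        out
        for out in output_node.args[0]
        if isinstance(out, torch.fx.Node)
        and out.op == "call_function"
        and out.target == operator.getitem
        and out.args[1] == 0
        and is_with_effects(out.args[0])
    ]
    return input_tokens, output_tokens
```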

6/ Add AOT debug logging for the graphs before token unlifting and before the additional primal for backward tokens is added.
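
These graphs are emitted under the `aot_graphs_effects` logging artifact registered in this PR, e.g.:
```
TORCH_LOGS="aot_graphs_effects" python test/higher_order_ops/test_with_effects.py
```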

Tests:
```
python test/higher_order_ops/test_with_effects.py
```
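
For reference, a minimal repro along the lines of the new `test_effectful_op_in_backward` (op and namespace names are taken from that test; it relies on the private effects API shown in the diff, so treat it as a sketch):

```
import torch
from torch._higher_order_ops.effects import _EffectType, _register_effectful_op

with torch.library._scoped_library("_mylib", "FRAGMENT") as lib:
    lib.define("foo(Tensor x) -> Tensor")

    def foo_impl(a):
        return a.clone()

    def foo_bwd(ctx, grad):
        # The backward itself calls the effectful op, so the backward graph
        # needs its own token chain (the tangents_token input).
        return torch.ops._mylib.foo(grad)

    for backend in ["CPU", "CUDA", "Meta"]:
        lib.impl("foo", foo_impl, backend)
    torch.library.register_autograd("_mylib::foo", foo_bwd, lib=lib)
    _register_effectful_op(torch.ops._mylib.foo.default, _EffectType.ORDERED)

    def fn(x, y):
        return torch.ops._mylib.foo(x) + y

    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    y = torch.tensor([4.0, 5.0, 6.0], requires_grad=True)
    out = torch.compile(fn, backend="inductor", fullgraph=True)(x, y)
    out.sum().backward()
```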

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132638
Approved by: https://github.com/bdhirsh
This commit is contained in:
IvanKobzarev
2024-08-22 14:50:16 -07:00
committed by PyTorch MergeBot
parent 7fd3b69886
commit 8ae4f82243
12 changed files with 559 additions and 95 deletions

View File

@ -697,6 +697,7 @@ exclusions = {
"fusion",
"overlap",
"aot_graphs",
"aot_graphs_effects",
"post_grad_graphs",
"compiled_autograd",
"compiled_autograd_verbose",

View File

@ -30,6 +30,7 @@ from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
from torch.testing._internal.common_utils import (
IS_WINDOWS,
run_tests,
skipIfTorchDynamo,
TEST_CUDA,
TEST_WITH_ROCM,
TestCase,
@ -132,8 +133,8 @@ def forward(self, arg1_1):
add = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
getitem_2 = with_effects_1[0]; with_effects_1 = None
_sink_tokens_default = torch.ops.prims._sink_tokens.default((getitem_2,)); getitem_2 = _sink_tokens_default = None
return (add,)""", # noqa: B950
_sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem_2]); getitem_2 = _sink_tokens_default = None
return [add]""", # noqa: B950
)
def test_torchbind_custom_op(self):
@ -429,14 +430,17 @@ def forward(self, arg0_1, arg1_1, arg2_1):
def test_effectful_custom_op_with_subclasses(self):
with torch.library._scoped_library("_mylib", "FRAGMENT") as lib:
lib.define("zoo(Tensor x) -> Tensor")
lib.define("zoo2(Tensor x) -> Tensor")
d = {"fw": 0}
d = {"fw": 0, "bw": 0}
def reset_counter():
d["fw"] = 0
d["bw"] = 0
def assert_counter(fw):
def assert_counter(fw, bw):
self.assertEqual(d["fw"], fw)
self.assertEqual(d["bw"], bw)
def foo_impl(a):
d["fw"] = d["fw"] + 1
@ -445,22 +449,24 @@ def forward(self, arg0_1, arg1_1, arg2_1):
def foo_meta(a):
return a.clone()
def foo_bwd(ctx, grad):
return grad.clone()
def foo2_impl(x):
d["bw"] = d["bw"] + 1
return x.clone()
def foo2_meta(a):
return a.clone()
for backend in ["CPU", "CUDA"]:
lib.impl("zoo", foo_impl, backend)
lib.impl("zoo2", foo2_impl, backend)
lib.impl("zoo", foo_meta, "Meta")
lib.impl("zoo2", foo2_meta, "Meta")
if torch._C._dispatch_has_kernel_for_dispatch_key(
"_mylib::zoo", "Autograd"
):
self.skipTest(
"Double registration of Autograd kernel for test custom op"
)
def foo_bwd(ctx, grad):
torch.ops._mylib.zoo2(grad)
return grad.clone()
torch.library.register_autograd("_mylib::zoo", foo_bwd)
torch.library.register_autograd("_mylib::zoo", foo_bwd, lib=lib)
from torch._higher_order_ops.effects import (
_EffectType,
@ -468,6 +474,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
)
_register_effectful_op(torch.ops._mylib.zoo.default, _EffectType.ORDERED)
_register_effectful_op(torch.ops._mylib.zoo2.default, _EffectType.ORDERED)
def fn(x, y):
return torch.ops._mylib.zoo(x) + y
@ -488,13 +495,13 @@ def forward(self, arg0_1, arg1_1, arg2_1):
):
reset_counter()
ref_out = fn(*ins_fn())
assert_counter(expected_fw_count)
assert_counter(expected_fw_count, 0)
compiled_fn = torch.compile(fn, backend="aot_eager")
out = compiled_fn(*ins_fn())
reset_counter()
out = compiled_fn(*ins_fn())
assert_counter(expected_fw_count)
assert_counter(expected_fw_count, 0)
self.assertEqual(ref_out, out)
@ -521,16 +528,17 @@ def forward(self, arg0_1, arg1_1, arg2_1):
(
expected_fw_count,
expected_fw_count_after_bw,
expected_bw_count_after_bw,
),
) in enumerate(
zip([ins_dense_req_grad, ins_sc_req_grad], [(1, 1), (2, 2)])
zip([ins_dense_req_grad, ins_sc_req_grad], [(1, 1, 1), (2, 2, 2)])
):
ref_ins = ins_fn_req_grad()
reset_counter()
ref_out = fn(*ref_ins)
assert_counter(expected_fw_count)
assert_counter(expected_fw_count, 0)
ref_out.sum().backward()
assert_counter(expected_fw_count_after_bw)
assert_counter(expected_fw_count_after_bw, expected_bw_count_after_bw)
compiled_fn = torch.compile(fn, fullgraph=True)
@ -538,10 +546,10 @@ def forward(self, arg0_1, arg1_1, arg2_1):
out = compiled_fn(*ins)
reset_counter()
out = compiled_fn(*ins)
assert_counter(expected_fw_count)
assert_counter(expected_fw_count, 0)
self.assertEqual(ref_out, out)
out.sum().backward()
assert_counter(expected_fw_count_after_bw)
assert_counter(expected_fw_count_after_bw, expected_bw_count_after_bw)
self.assertEqual(ref_ins[1].grad, ins[1].grad)
self.assertEqual(ref_ins[0].grad, ins[0].grad)
@ -563,10 +571,14 @@ def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5):
self.assertExpectedInline(
bw_graph.code.strip(),
"""\
def forward(self, tangents_1, tangents_2):
def forward(self, tangents_1, tangents_2, tangents_token):
with_effects_2 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.zoo2.default, tangents_1); tangents_token = None
getitem_4 = with_effects_2[0]; with_effects_2 = None
with_effects_3 = torch.ops.higher_order.with_effects(getitem_4, torch.ops._mylib.zoo2.default, tangents_2); getitem_4 = None
getitem_6 = with_effects_3[0]; with_effects_3 = None
clone = torch.ops.aten.clone.default(tangents_1)
clone_1 = torch.ops.aten.clone.default(tangents_2)
return (clone, clone_1, tangents_1, tangents_2)""",
return (clone, clone_1, tangents_1, tangents_2, getitem_6)""",
)
def test_effects_and_input_mutation_return(self):
@ -662,6 +674,234 @@ def forward(self, arg0_1, arg1_1):
return (getitem, mul, mul)""",
)
@skipIfTorchDynamo()
def test_effectful_op_in_backward(self):
with torch.library._scoped_library("_mylib", "FRAGMENT") as lib:
lib.define("foo(Tensor x) -> Tensor")
def foo_impl(a):
return a.clone()
def foo_bwd(ctx, grad):
return torch.ops._mylib.foo(grad)
for backend in ["CPU", "CUDA", "Meta"]:
lib.impl("foo", foo_impl, backend)
torch.library.register_autograd("_mylib::foo", foo_bwd, lib=lib)
from torch._higher_order_ops.effects import (
_deregister_effectful_op,
_EffectType,
_register_effectful_op,
)
_register_effectful_op(torch.ops._mylib.foo.default, _EffectType.ORDERED)
try:
def fn(x, y):
return torch.ops._mylib.foo(x) + y
def ins_dense_req_grad():
return (
torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
)
def ins_sc_req_grad():
return (
TwoTensor(
torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
),
torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
)
for i, ins_fn in enumerate([ins_dense_req_grad, ins_sc_req_grad]):
ref_ins = ins_fn()
ref_out = fn(*ref_ins)
ref_out.sum().backward()
compiled_fn = torch.compile(fn, backend="inductor", fullgraph=True)
ins = ins_fn()
out = compiled_fn(*ins)
self.assertEqual(ref_out, out)
out.sum().backward()
self.assertEqual(ref_ins[1].grad, ins[1].grad)
self.assertEqual(ref_ins[0].grad, ins[0].grad)
fw_graph, bw_graph = get_fw_bw_graph(fn, ins)
if i == 0:
self.assertExpectedInline(
fw_graph.code.strip(),
"""\
def forward(self, primals_1, primals_2, primals_3):
with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.foo.default, primals_2); primals_1 = primals_2 = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
add = torch.ops.aten.add.Tensor(getitem_1, primals_3); getitem_1 = primals_3 = None
return (getitem, add)""",
)
self.assertExpectedInline(
bw_graph.code.strip(),
"""\
def forward(self, tangents_1, tangents_token):
with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.foo.default, tangents_1); tangents_token = None
getitem_2 = with_effects_1[0]
getitem_3 = with_effects_1[1]; with_effects_1 = None
return (getitem_3, tangents_1, getitem_2)""",
)
elif i == 1:
self.assertExpectedInline(
fw_graph.code.strip(),
"""\
def forward(self, primals_1, primals_2, primals_3, primals_4):
with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.foo.default, primals_2); primals_1 = primals_2 = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._mylib.foo.default, primals_3); getitem = primals_3 = None
getitem_2 = with_effects_1[0]
getitem_3 = with_effects_1[1]; with_effects_1 = None
add = torch.ops.aten.add.Tensor(getitem_1, primals_4); getitem_1 = None
add_1 = torch.ops.aten.add.Tensor(getitem_3, primals_4); getitem_3 = primals_4 = None
return (getitem_2, add, add_1)""",
)
self.assertExpectedInline(
bw_graph.code.strip(),
"""\
def forward(self, tangents_1, tangents_2, tangents_token):
with_effects_2 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.foo.default, tangents_1); tangents_token = None
getitem_4 = with_effects_2[0]
getitem_5 = with_effects_2[1]; with_effects_2 = None
with_effects_3 = torch.ops.higher_order.with_effects(getitem_4, torch.ops._mylib.foo.default, tangents_2); getitem_4 = None
getitem_6 = with_effects_3[0]
getitem_7 = with_effects_3[1]; with_effects_3 = None
return (getitem_5, getitem_7, tangents_1, tangents_2, getitem_6)""",
)
else:
raise NotImplementedError
finally:
_deregister_effectful_op(torch.ops._mylib.foo.default)
@skipIfNoDynamoSupport
def test_regular_effectful_op_only_in_backward(self):
from torch._higher_order_ops.effects import (
_deregister_effectful_op,
_EffectType,
_register_effectful_op,
)
_register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
try:
def fn(x):
return x.sin()
def inps_fn():
return (torch.tensor([1.0, 2.0, 3.0], requires_grad=True),)
torch.compile(fn, backend="inductor", fullgraph=True)(*inps_fn())
fw_graph, bw_graph = get_fw_bw_graph(fn, inps_fn())
self.assertExpectedInline(
fw_graph.code.strip(),
"""\
def forward(self, primals_1):
sin = torch.ops.aten.sin.default(primals_1)
return (sin, primals_1)""",
)
self.assertExpectedInline(
bw_graph.code.strip(),
"""\
def forward(self, primals_1, tangents_1, tangents_token):
with_effects = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, primals_1); tangents_token = primals_1 = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_1); tangents_1 = getitem_1 = None
return (mul, getitem)""",
)
def inps_fn_sc():
return (
TwoTensor(
torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
),
)
torch.compile(fn, backend="inductor", fullgraph=True)(*inps_fn_sc())
fw_graph, bw_graph = get_fw_bw_graph(fn, inps_fn_sc())
self.assertExpectedInline(
fw_graph.code.strip(),
"""\
def forward(self, primals_1, primals_2):
sin = torch.ops.aten.sin.default(primals_1)
sin_1 = torch.ops.aten.sin.default(primals_2)
return (sin, sin_1, primals_1, primals_2)""",
)
self.assertExpectedInline(
bw_graph.code.strip(),
"""\
def forward(self, primals_1, primals_2, tangents_1, tangents_2, tangents_token):
with_effects = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, primals_1); tangents_token = primals_1 = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten.cos.default, primals_2); getitem = primals_2 = None
getitem_2 = with_effects_1[0]
getitem_3 = with_effects_1[1]; with_effects_1 = None
mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_1); tangents_1 = getitem_1 = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_2, getitem_3); tangents_2 = getitem_3 = None
return (mul, mul_1, getitem_2)""",
)
finally:
_deregister_effectful_op(torch.ops.aten.cos.default)
@skipIfNoDynamoSupport
def test_regular_effectful_op_in_forward_and_backward(self):
from torch._higher_order_ops.effects import (
_deregister_effectful_op,
_EffectType,
_register_effectful_op,
)
_register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
try:
def fn(x):
x = x.cos()
return x.sin()
inps = (torch.tensor([1.0, 2.0, 3.0], requires_grad=True),)
torch.compile(fn, backend="inductor", fullgraph=True)(*inps)
fw_graph, bw_graph = get_fw_bw_graph(fn, inps)
self.assertExpectedInline(
fw_graph.code.strip(),
"""\
def forward(self, primals_1, primals_2):
with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops.aten.cos.default, primals_2); primals_1 = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
sin = torch.ops.aten.sin.default(getitem_1)
return (getitem, sin, primals_2, getitem_1)""",
)
self.assertExpectedInline(
bw_graph.code.strip(),
"""\
def forward(self, primals_2, getitem_1, tangents_1, tangents_token):
with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, getitem_1); tangents_token = getitem_1 = None
getitem_2 = with_effects_1[0]
getitem_3 = with_effects_1[1]; with_effects_1 = None
mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_3); tangents_1 = getitem_3 = None
sin_1 = torch.ops.aten.sin.default(primals_2); primals_2 = None
neg = torch.ops.aten.neg.default(sin_1); sin_1 = None
mul_1 = torch.ops.aten.mul.Tensor(mul, neg); mul = neg = None
return (mul_1, getitem_2)""",
)
finally:
_deregister_effectful_op(torch.ops.aten.cos.default)
if __name__ == "__main__":
run_tests()

View File

@ -46,7 +46,10 @@ def _create_graph(f, args, *, aot_config: AOTConfig) -> torch.fx.GraphModule:
# FunctionalTensorMode must be enabled here.
# See Note [Accessing .grad_fn on FunctionalTensor]
with enable_python_dispatcher(), FunctionalTensorMode(
pre_dispatch=aot_config.pre_dispatch, export=aot_config.is_export
pre_dispatch=aot_config.pre_dispatch,
export=aot_config.is_export,
# Allow token discovery for joint fn tracing as tokens can be used in backward.
_allow_token_discovery=True,
):
fx_g = make_fx(
f,
@ -191,7 +194,7 @@ def aot_dispatch_base_graph(
# See Note [Side-Effectful Tokens in AOTAutograd]
num_tokens = len(fw_metadata.tokens)
if num_tokens != 0 and config.unlift_effect_tokens:
unlift_tokens(fw_module, fw_metadata)
unlift_tokens(fw_module, fw_metadata, aot_config)
saved_updated_flat_args_subclasses_desugared = (
saved_updated_flat_args_subclasses_desugared[num_tokens:]
)

View File

@ -384,10 +384,16 @@ def aot_dispatch_autograd(
)
# See Note [Side-Effectful Tokens in AOTAutograd]
if num_tokens != 0 and config.unlift_effect_tokens:
unlift_tokens(fw_module, fw_metadata)
if config.unlift_effect_tokens and (
num_tokens > 0 or fw_metadata.num_backward_tokens > 0
):
unlift_tokens(fw_module, fw_metadata, aot_config, bw_module)
num_inner_fwd_outputs -= num_tokens
joint_inputs = (joint_inputs[0][num_tokens:], joint_inputs[1])
joint_inputs = (
joint_inputs[0][num_tokens:],
joint_inputs[1],
)
fw_outs = next(iter(fw_module.graph.find_nodes(op="output"))).args[0]
# we only need to bookkeep the symints that are saved for bw, not any symints
@ -484,16 +490,21 @@ def aot_dispatch_autograd(
# (b) The grad_outputs that we AOT computed in our backward graph are the desugared tensor tensors,
# so we need to figure out which subclass fw inputs they map to.
if maybe_subclass_meta is None:
num_backward_tokens: int = inner_meta.num_backward_tokens
assert (
len(bw_outs)
== len(fw_metadata.input_info) + inner_meta.num_outputs_rng_offset
== len(fw_metadata.input_info)
+ inner_meta.num_outputs_rng_offset
+ num_backward_tokens
)
bw_outs_no_rng = bw_outs
if inner_meta.num_outputs_rng_offset > 0:
bw_outs_no_rng = bw_outs[: -inner_meta.num_outputs_rng_offset]
assert len(bw_outs_no_rng) == len(fw_metadata.input_info)
bw_outs_no_rng_no_tokens = bw_outs
if (inner_meta.num_outputs_rng_offset + num_backward_tokens) > 0:
bw_outs_no_rng_no_tokens = bw_outs[
: -(inner_meta.num_outputs_rng_offset + num_backward_tokens)
]
assert len(bw_outs_no_rng_no_tokens) == len(fw_metadata.input_info)
for i, (bw_out) in enumerate(bw_outs_no_rng):
for i, (bw_out) in enumerate(bw_outs_no_rng_no_tokens):
# If our input experiences a metadata mutation inside the graph (e.g. set_()),
# we *must* not detach, otherwise it will be the detach'd input that gets the metadata mutation
metadata_mutation_in_graph = (

View File

@ -659,7 +659,7 @@ class EffectTokensWrapper(CompilerWrapper):
@wraps(compiled_fn)
def inner_fn(args: List[Any]):
if num_tokens > 0:
# Pass in effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
# Pass in forward effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
old_args = args
args = [*([None] * num_tokens), *args]
old_args.clear()
@ -1730,20 +1730,23 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
# Add the seed and offset to args
rng_args = CUDARngStateHelper.get_torch_state_as_tuple()
bw_tokens = [None] * CompiledFunction.metadata.num_backward_tokens
# - note: donated buffer logic requires (*ctx.symints, *ctx.saved_tensors) showing up first
# in the bw output order.
# Every dereference of ctx.saved_tensors incurs saved_tensors_hooks calls
# There are tests that count these calls, saving to var.
ctx_saved_tensors = ctx.saved_tensors
num_ctx_saved_tensors = len(ctx_saved_tensors)
all_args = [
*ctx.symints,
*ctx.saved_tensors,
*ctx_saved_tensors,
*flat_bw_args_with_grads,
*bw_tokens,
*rng_args,
]
del flat_bw_args_with_grads
tangents_start_idx = (
len(all_args) - num_flat_bw_args_with_grads - len(rng_args)
)
tangents_end_idx = len(all_args) - len(rng_args)
del ctx_saved_tensors
# Note: [AOTAutograd Backward Guards]
# During AOTDispatch, we eagerly create and trace out a joint fw-bw graph.
@ -1771,9 +1774,8 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
len(CompiledFunction.metadata.output_types)
== num_flat_bw_args_with_grads
)
grad_output_types = [
type(x) for x in all_args[-num_flat_bw_args_with_grads:]
]
grad_output_types = [type(x) for x in flat_bw_args_with_grads]
# In general, we can add more asserts/guards here for when we partitioned
# with incorrect assumptions about the grad_outputs.
# Normalize FakeTensor -> torch.Tensor
@ -1791,6 +1793,17 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
Expected grad_output types: {str(CompiledFunction.metadata.output_types)}
Got grad_output types: {str(grad_output_types)}"""
del flat_bw_args_with_grads
tangents_start_idx = (
len(all_args)
- num_flat_bw_args_with_grads
- len(rng_args)
- len(bw_tokens)
)
assert tangents_start_idx == len(ctx.symints) + num_ctx_saved_tensors
tangents_end_idx = len(all_args) - len(rng_args) - len(bw_tokens)
# TODO: figure out how to refactor the backward properly
# so I can use aot_dispatch_subclass_wrapper() here.
if CompiledFunction.maybe_subclass_metadata is not None:
@ -1855,7 +1868,9 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
all_args = unwrap_tensor_subclasses(
all_args, is_joint_structure=False
)
tangents_start_idx = len(all_args) - len_tangents - len(rng_args)
tangents_start_idx = (
len(all_args) - len_tangents - len(rng_args) - len(bw_tokens)
)
tangents_end_idx = tangents_start_idx + len_tangents
# Make the tangents contiguous. Note that we must do this after subclass desugaring
@ -1968,6 +1983,12 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
steal_args=True,
disable_amp=disable_amp,
)
# Toss out the backward output tokens
num_bw_tokens = CompiledFunction.metadata.num_backward_tokens
if num_bw_tokens > 0:
out = out[:-num_bw_tokens]
# TODO: replace this with FunctionalizedRngRuntimeWrapper.post_compile
out = FunctionalizedRngRuntimeWrapper()._functionalized_rng_runtime_epilogue(
CompiledFunction.metadata, out, offset_index=len(out) - 1

View File

@ -351,6 +351,10 @@ class ViewAndMutationMeta:
# and backward output.
bw_donated_idxs: Optional[List[int]] = None
# Number of tokens used in backward, appended at the end of backward outputs.
# Filled after tracing joint function.
num_backward_tokens: int = 0
def __post_init__(self):
# pre-compute the indices of the inputs that are mutated.
# When keep_input_mutations is set, we don't need to worry about our epilogue
@ -566,6 +570,7 @@ class ViewAndMutationMeta:
x.shape == y.shape and x.dtype == y.dtype
for x, y, in zip(self.traced_tangents, other.traced_tangents)
)
and self.num_backward_tokens == other.num_backward_tokens
)

View File

@ -235,7 +235,24 @@ def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
backward_out: Tuple[Tensor, ...] = ()
# Call the backwards pass
if grad_primals:
with fx_traceback.preserve_node_meta():
functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
torch._C._TorchDispatchModeKey.FUNCTIONAL
)
if functional_tensor_mode is not None:
# Side-Effect Tokens:
# We want to have independent chains of tokens for forward and backward.
# functional_tensor_mode._tokens is used by both.
# We memoize the result tokens of forward in functional_tensor_mode._tokens_forward_output,
# to return them as joint graph outputs.
# We clean functional_tensor_mode._tokens before backward, to prevent reuse of forward tokens in backward.
# Joint graph tracing allows tokens discovery,
# So all the tokens in backward will be created and added as a graph inputs during tracing.
functional_tensor_mode._tokens_forward_output = (
functional_tensor_mode._tokens
)
functional_tensor_mode._tokens = {}
with set_partitioner_tag_is_backward(), fx_traceback.preserve_node_meta():
# for full graph export, we always export a joint graph where we assume no tangents are needed.
if aot_config.no_tangents:
assert len(needed_tangents) == 1 and needed_tangents[0].numel() == 1
@ -348,6 +365,14 @@ def set_partitioner_tag(tag: str):
fx_traceback.current_meta[meta_key] = original_val
def set_partitioner_tag_is_backward():
return set_partitioner_tag("is_backward")
def set_partitioner_tag_must_be_in_backward():
return set_partitioner_tag("must_be_in_backward")
# This creates the final function that we want to trace using make_fx(),
# in both aot_dispatch_autograd and aot_dispatch_base.
# Preconditions:
@ -439,9 +464,7 @@ def create_functionalized_fn(
# Not banning here mutations on inpt_info.requires_grad -
# we'll check at runtime and fail only when backward is under torch.is_grad_enabled (create_graph)
# Add node meta for copy_ for partitioner that this node should be in backward graph.
with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag(
"must_be_in_backward"
):
with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag_must_be_in_backward():
before.copy_(after)
meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append(
idx
@ -649,13 +672,13 @@ def handle_effect_tokens_fn(
if trace_joint:
assert isinstance(args, tuple) and isinstance(args[0], (list, tuple))
tokens = args[0][:num_tokens]
assert all(token.numel() == 0 for token in tokens)
args = (args[0][num_tokens:], *args[1:])
else:
tokens = args[:num_tokens]
assert all(token.numel() == 0 for token in tokens)
args = args[num_tokens:]
assert all(token.numel() == 0 for token in tokens)
# Populate the current FunctionalTensorMode with the tokens per
# operator. See Note [FunctionalTensorMode is Stateful]
functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
@ -671,17 +694,30 @@ def handle_effect_tokens_fn(
# Return both the tokens and the outputs
# See Note [Side-Effectful Tokens in AOTAutograd]
f_out_tokens = functional_tensor_mode._tokens.values()
out_tokens = [from_fun(t) for t in f_out_tokens]
if trace_joint:
assert len(outs) == 2
assert len(functional_tensor_mode._tokens_forward_output) == num_tokens
fwd_out_tokens = functional_tensor_mode._tokens_forward_output.values()
bwd_out_tokens = functional_tensor_mode._tokens.values()
f_fwd_out_tokens = [from_fun(t) for t in fwd_out_tokens]
f_bwd_out_tokens = [from_fun(t) for t in bwd_out_tokens]
meta.num_backward_tokens = len(bwd_out_tokens)
return ((*f_fwd_out_tokens, *outs[0]), (*outs[1], *f_bwd_out_tokens))
out_tokens = [from_fun(t) for t in functional_tensor_mode._tokens.values()]
return (*out_tokens, *outs)
# Additionally pass in tokens as inputs
# See Note [Side-Effectful Tokens in AOTAutograd]
additional_token_inputs = [torch.tensor([])] * len(meta.tokens)
additional_fwd_token_inputs = [torch.tensor([])] * num_tokens
if trace_joint:
args = ([*additional_token_inputs, *args[0]], *args[1:])
args = ([*additional_fwd_token_inputs, *args[0]], *args[1:])
else:
args = [*additional_token_inputs, *args]
args = [*additional_fwd_token_inputs, *args]
return inner_fn, args

View File

@ -13,6 +13,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union
import torch
import torch.utils._pytree as pytree
from torch._library.fake_class_registry import FakeScriptObject
from torch._logging import getArtifactLogger
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import py_sym_types
@ -32,6 +33,8 @@ KNOWN_TYPES = [
original_zip = zip
aot_graphs_effects_log = getArtifactLogger(__name__, "aot_graphs_effects")
def strict_zip(*iterables, strict=True, **kwargs):
if not strict:
@ -234,57 +237,154 @@ def maybe_to_fresh_input(idx, t, meta):
return t
def unlift_tokens(fw_module, fw_metadata):
def is_with_effects(node):
return (
node.op == "call_function"
and node.target == torch.ops.higher_order.with_effects
)
def is_with_effects_op(node, op):
return is_with_effects(node) and node.args[1] == op
def unlift_tokens(fw_module, fw_metadata, aot_config, bw_module=None):
# Remove the tokens from the inputs/outputs of the graph since inductor does
# not want these extra inputs/outputs, and replace them with
# _make_token() to create a token, and _sink_tokens() to collect the
# tokens. See Note [Side-Effectful Tokens in AOTAutograd]
num_tokens = len(fw_metadata.tokens)
# Logic:
# 1. Inputs identified as input tokens:
# - If used as a first argument in with_effects
#
# 2. Outputs identified as output tokens:
# - If Produced by getitem(with_effects, 0)
#
# 3. Checks invariants of number input output tokens:
# forward:
# expected_num_erased_inputs == len(fw_metadata.tokens)
# expected_num_erased_outputs == len(fw_metadata.tokens)
# backward:
# expected_num_erased_inputs == fw_metadata.num_backward_tokens
# expected_num_erased_outputs == fw_metadata.num_backward_tokens
num_forward_tokens = len(fw_metadata.tokens)
num_backward_tokens = fw_metadata.num_backward_tokens
input_token_nodes = []
for i, node in enumerate(fw_module.graph.nodes):
if i < num_tokens:
assert node.op == "placeholder"
input_token_nodes.append(node)
def rewrite_with_effects_input_token(module, node):
with module.graph.inserting_before(node):
new_token_node = module.graph.call_function(
torch.ops.prims._make_token.default, ()
)
new_token_node.meta["val"] = torch.tensor([])
new_token_node.meta["tensor_meta"] = torch.tensor([])
elif node.op == "call_function" and node.target.__name__ == "with_effects":
if node.args[0] in input_token_nodes:
with fw_module.graph.inserting_before(node):
new_token_node = fw_module.graph.call_function(
torch.ops.prims._make_token.default, ()
)
new_token_node.meta["val"] = torch.tensor([])
new_token_node.meta["tensor_meta"] = torch.tensor([])
args = list(node.args)
args[0] = new_token_node
node.args = tuple(args)
args = list(node.args)
args[0] = new_token_node
node.args = tuple(args)
def rewrite_output(module, node, output_token_nodes, other_output_args):
for output_token_node in output_token_nodes:
assert (
output_token_node.op == "call_function"
and output_token_node.target == operator.getitem
and output_token_node.args[1] == 0
)
with module.graph.inserting_before(node):
module.graph.call_function(
torch.ops.prims._sink_tokens.default,
(output_token_nodes,),
)
node.args = (other_output_args,)
elif node.op == "output":
output_token_nodes = node.args[0][:num_tokens]
other_output_args = node.args[0][num_tokens:]
def do(module, subgraph, expected_num_erased):
num_erased_inputs = 0
num_erased_outs = 0
input_nodes = []
input_token_nodes = set()
with_effect_nodes = []
output_token_nodes = []
other_output_nodes = []
for i, node in enumerate(module.graph.nodes):
if node.op == "placeholder":
input_nodes.append(node)
elif is_with_effects(node):
with_effect_nodes.append(node)
if node.args[0] in input_nodes:
input_token_nodes.add(node.args[0])
rewrite_with_effects_input_token(module, node)
elif node.op == "output":
outs = node.args[0]
for out in outs:
if (
isinstance(out, torch.fx.node.Node)
and out.op == "call_function"
and out.target == operator.getitem
and out.args[1] == 0
and out.args[0] in with_effect_nodes
):
output_token_nodes.append(out)
else:
other_output_nodes.append(out)
for output_token_node in output_token_nodes:
assert (
output_token_node.op == "call_function"
and output_token_node.target == operator.getitem
and output_token_node.args[1] == 0
)
with fw_module.graph.inserting_before(node):
sink_token_node = fw_module.graph.call_function(
torch.ops.prims._sink_tokens.default,
(output_token_nodes,),
)
node.args = (other_output_args,)
rewrite_output(module, node, output_token_nodes, other_output_nodes)
num_erased_outs = len(output_token_nodes)
for input_token_node in input_token_nodes:
fw_module.graph.erase_node(input_token_node)
for input_token_node in input_token_nodes:
module.graph.erase_node(input_token_node)
fw_module.recompile()
num_erased_inputs = len(input_token_nodes)
assert (
num_erased_inputs == expected_num_erased
), f"{subgraph} num_erased_inputs:{num_erased_inputs} {input_token_nodes}!=expected {expected_num_erased}"
assert (
num_erased_outs == expected_num_erased
), f"{subgraph} num_erased_outs:{num_erased_outs} {output_token_nodes}!=expected {expected_num_erased}"
module.recompile()
if num_forward_tokens > 0:
if aot_config.enable_log:
from torch._dynamo.utils import lazy_format_graph_code
aot_graphs_effects_log.debug(
"%s",
lazy_format_graph_code(
"Forward graph before unlifting tokens",
fw_module,
aot_config.aot_id,
include_stride=True,
include_device=True,
colored=True,
),
)
do(
fw_module,
"forward",
num_forward_tokens,
)
if bw_module is not None and num_backward_tokens > 0:
if aot_config.enable_log:
from torch._dynamo.utils import lazy_format_graph_code
aot_graphs_effects_log.debug(
"%s",
lazy_format_graph_code(
"Backward graph before unlifting tokens",
bw_module,
aot_config.aot_id,
include_stride=True,
include_device=True,
colored=True,
),
)
do(bw_module, "backward", num_backward_tokens)
# This is sad, but we need to update the metadata to get rid of
# the tokens.
fw_metadata.tokens = {}
fw_metadata.num_backward_tokens = 0
def root_module_when_exporting_non_strict(flat_fn):

View File

@ -29,6 +29,7 @@ from torch.utils.checkpoint import CheckpointPolicy
from . import config
from ._aot_autograd.logging_utils import get_aot_graph_name
from ._aot_autograd.utils import is_with_effects
from .compile_utils import fx_graph_cse, get_aten_target
@ -249,12 +250,18 @@ def _is_backward_state(node: fx.Node) -> bool:
return node.op == "placeholder" and isinstance(node.meta.get("val"), BackwardState)
def _has_tag_is_backward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_backward"
def _has_tag_must_be_in_backward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "must_be_in_backward"
def _must_be_in_backward(node: fx.Node) -> bool:
return _has_tag_must_be_in_backward(node)
return _has_tag_must_be_in_backward(node) or (
_has_tag_is_backward(node) and is_with_effects(node)
)
def _extract_fwd_bwd_outputs(

View File

@ -23,6 +23,7 @@ class _EffectType(Enum):
OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload]
# TODO(ivankobzarev): Make SIDE_EFFECTS dictionary WeakKeyDictionary as operator can go out of scope
SIDE_EFFECTS: Dict[OpType, _EffectType] = {
torch.ops.aten._print.default: _EffectType.ORDERED,
call_torchbind: _EffectType.ORDERED,
@ -41,6 +42,13 @@ def _register_effectful_op(op: OpType, effect: _EffectType):
SIDE_EFFECTS[op] = effect
def _deregister_effectful_op(op: OpType):
if op not in SIDE_EFFECTS:
raise RuntimeError(f"Op {op} is not registered as effectful")
del SIDE_EFFECTS[op]
class WithEffects(HigherOrderOperator):
"""
with_effects(token, op, args, kwargs) -> (new_token, op_results)
@ -221,7 +229,31 @@ def handle_effects(
assert (
allow_token_discovery
), f"Could not find a token for effect {key} which came from the function {op}"
tokens[key] = new_token_tensor()
proxy_tensor_mode = torch._C._get_dispatch_mode(
torch._C._TorchDispatchModeKey.PROXY
)
if proxy_tensor_mode is not None:
# If we discovered a new token during tracing, we are in backward.
# Then we patch the graph, adding additional tangents_token as input to the joint graph.
tracer = proxy_tensor_mode.tracer
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
track_tensor_tree,
)
with disable_proxy_modes_tracing():
token_tensor = new_token_tensor()
token_proxy = proxy_tensor_mode.tracer.create_proxy(
"placeholder", "tangents_token", (), {}, name="tangents_token"
)
track_tensor_tree(token_tensor, token_proxy, constant=None, tracer=tracer)
tokens[key] = token_tensor
else:
tokens[key] = new_token_tensor()
token = tokens[key]
from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

View File

@ -81,6 +81,11 @@ register_artifact(
"aot_joint_graph",
"Print FX joint graph from AOTAutograd, prior to partitioning. Useful for debugging partitioning",
)
register_artifact(
"aot_graphs_effects",
"Prints the FX forward and backward graph generated by AOTDispatch, useful for debugging effects processing.",
visible=True,
)
register_artifact(
"post_grad_graphs",
"Prints the FX graph generated by post grad passes. Useful to understand what's being given to Inductor after post grad passes",

View File

@ -299,6 +299,9 @@ class FunctionalTensorMode(TorchDispatchMode):
# track of the ordering between side effectful operations.
self._tokens: Dict[Any, torch.Tensor] = {}
# Filled after forward tracing.
self._tokens_forward_output: Dict[Any, torch.Tensor] = {}
# Functionalization runs twice in AOTAutograd, once in
# `run_functionalized_fw_and_collect_metadata` to collect metadata to
# see which tensors need to be functionalized and discover how many