add the option to disable functionalization in AOTDispatcher (#164577)
I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization-related changes from the original version:
(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: https://github.com/pytorch/pytorch/pull/164939)
(2) preventing python-dispatcher-based decomps above autograd from running; I'm not doing this for now and will likely do it in a followup.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164577
Approved by: https://github.com/ezyang
ghstack dependencies: #165372
committed by: PyTorch MergeBot
parent: f33c7e1a43
commit: ed74dc054d
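For context, the new option is exercised roughly as in the sketch below. This is a minimal sketch mirroring the updated aot_autograd_check helper and the new tests in this diff: the imports are the ones those tests use from functorch.compile, the function f and tensor shapes are illustrative only, and the disable_functionalization argument exists only once this patch is applied.

    # Minimal sketch: trace a function through AOTDispatcher with functionalization turned off.
    import torch
    from functorch.compile import compiled_function, default_partition, nop

    def f(a, b):
        # illustrative function; any traceable fn works here
        return torch.sigmoid(a @ b) * b

    compiled_f = compiled_function(
        f,
        nop,  # fw_compiler: returns the traced forward graph unchanged
        nop,  # bw_compiler: returns the traced backward graph unchanged
        partition_fn=default_partition,  # matches the updated aot_autograd_check helper
        keep_inference_input_mutations=True,
        disable_functionalization=True,  # the option added by this PR
    )

    a = torch.randn(3, 3, requires_grad=True)
    b = torch.randn(3, 3)
    out = compiled_f(a, b)
    out.sum().backward()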
@@ -838,6 +838,55 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization(
        self, device
    ):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
            ]
            return create_selective_checkpoint_contexts(
                _get_custom_policy(no_recompute_list=no_recompute_list)
            )

        def gn(x, y):
            return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                x,
                y,
                use_reentrant=False,
                context_fn=selective_checkpointing_context_fn,
            )

        x = torch.randn(4, 4, requires_grad=True, device=device)
        y = torch.randn(4, 4, requires_grad=True, device=device)

        fw_compiler = functools.partial(
            count_ops,
            freq=1,
            op=torch.ops.aten.sigmoid.default,
        )
        bw_compiler = functools.partial(
            count_ops,
            # Main check here is just that sigmoid is properly recomputed
            # (we will see a sigmoid() and sigmoid_backward() in the bw graph)
            freq=1,
            op=torch.ops.aten.sigmoid.default,
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            disable_functionalization=True,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_triton_kernel(self, device):
@@ -2605,6 +2605,170 @@ def forward(self, primals_1, primals_2):
        ]
        self.verify_aot_autograd(f, inp_grad, test_mutation=True)

    def test_fw_bw_mutation_no_functionalization1(self):
        class FwBwMutation(torch.autograd.Function):
            @staticmethod
            def forward(ctx, a, b):
                # input mutation
                torch._foreach_mul_([b], [2])
                x = b + 1
                # intermediate mutation
                torch._foreach_mul_([x], [3])
                ctx.save_for_backward(x)
                return x * a

            @staticmethod
            def backward(ctx, grad_output):
                (x,) = ctx.saved_tensors
                # bw mutation
                torch._foreach_mul_([x], [4])
                return grad_output * x, grad_output * x

        def f(a, b):
            return FwBwMutation.apply(a, b)

        inps = [
            torch.ones(3, 3, requires_grad=True),
            torch.ones(3, 3, requires_grad=False),
        ]
        inps_ref = [
            torch.ones(3, 3, requires_grad=True),
            torch.ones(3, 3, requires_grad=False),
        ]

        fw_graph = [None]
        bw_graph = [None]

        def fw_compiler(gm, example_inputs):
            fw_graph[0] = gm
            return gm

        def bw_compiler(gm, example_inputs):
            bw_graph[0] = gm
            return gm

        compiled_f = compiled_function(
            f,
            fw_compiler,
            bw_compiler,
            dynamic=False,
            partition_fn=default_partition,
            keep_inference_input_mutations=True,
            disable_functionalization=True,
        )

        out_ref = f(*inps_ref)
        out = compiled_f(*inps)
        self.assertEqual(out, out_ref)

        out_ref.sum().backward()
        out.sum().backward()
        self.assertEqual(inps_ref[0].grad, inps[0].grad)

        # important bit: there are 2 mutations in the fw
        self.assertExpectedInline(
            fw_graph[0].code.strip(),
            """\
def forward(self, primals_1, primals_2):
    _foreach_mul_ = torch.ops.aten._foreach_mul_.ScalarList([primals_2], [2]); _foreach_mul_ = None
    add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
    _foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None
    mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
    return (mul, add)""",
        )

        # important bit: there is 1 mutation in the bw
        self.assertExpectedInline(
            bw_graph[0].code.strip(),
            """\
def forward(self, add, tangents_1):
    _foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None
    mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None
    return (mul_1, None)""",
        )

    def test_fw_bw_mutation_no_functionalization2(self):
        class FwBwMutation(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # input mutation
                torch._foreach_mul_([x], [2])
                x = x + 1
                # intermediate mutation
                torch._foreach_mul_([x], [3])
                ctx.save_for_backward(x)
                return x

            @staticmethod
            def backward(ctx, grad_output):
                (x,) = ctx.saved_tensors
                # bw mutation
                torch._foreach_mul_([x], [4])
                return grad_output * x

        def f(a, b):
            out = FwBwMutation.apply(b)
            return out * a

        inps = [
            torch.ones(3, 3, requires_grad=True),
            torch.ones(3, 3, requires_grad=False),
        ]
        inps_ref = [
            torch.ones(3, 3, requires_grad=True),
            torch.ones(3, 3, requires_grad=False),
        ]

        fw_graph = [None]
        bw_graph = [None]

        def fw_compiler(gm, example_inputs):
            fw_graph[0] = gm
            return gm

        def bw_compiler(gm, example_inputs):
            bw_graph[0] = gm
            return gm

        compiled_f = compiled_function(
            f,
            fw_compiler,
            bw_compiler,
            dynamic=False,
            partition_fn=default_partition,
            keep_inference_input_mutations=True,
            disable_functionalization=True,
        )

        out_ref = f(*inps_ref)
        out = compiled_f(*inps)
        self.assertEqual(out, out_ref)

        out_ref.sum().backward()
        out.sum().backward()
        self.assertEqual(inps_ref[0].grad, inps[0].grad)

        # important bit: there are 2 mutations in the fw
        # (the mutation on an activation doesn't get moved to bw)
        self.assertExpectedInline(
            fw_graph[0].code.strip(),
            """\
def forward(self, primals_1, primals_2):
    _foreach_mul_ = torch.ops.aten._foreach_mul_.ScalarList([primals_2], [2]); _foreach_mul_ = None
    add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
    _foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None
    mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
    return (mul, add)""",
        )

        self.assertExpectedInline(
            bw_graph[0].code.strip(),
            """\
def forward(self, add, tangents_1):
    mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None
    return (mul_1, None)""",
        )

    def test_backward_mutation_metadata(self):
        class BwMutation(torch.autograd.Function):
            @staticmethod
@@ -8046,7 +8210,14 @@ symbolic_aot_autograd_failures = {
}


def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False, use_min_cut=True):
def _test_aot_autograd_helper(
    self,
    device,
    dtype,
    op,
    dynamic=False,
    disable_functionalization=False,
):
    if not op.supports_autograd:
        self.skipTest("Op does not support autograd")

@@ -8077,7 +8248,7 @@ def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False, use_min_cu
                check_gradients=True,
                try_check_data_specialization=try_check_data_specialization,
                skip_correctness_check=op.skip_correctness_check_compile_vs_eager,
                use_min_cut=use_min_cut,
                disable_functionalization=disable_functionalization,
            )
        except DynamicOutputShapeException:
            self.skipTest("Dynamic output shape operation in trace")
@@ -8181,24 +8352,31 @@ class TestEagerFusionOpInfo(AOTTestCase):
    @ops(op_db + hop_db, allowed_dtypes=(torch.float,))
    @skipOps(
        "TestEagerFusionOpInfo",
        "test_aot_autograd_default_partition_exhaustive",
        "test_aot_autograd_disable_functionalization_exhaustive",
        aot_autograd_failures,
    )
    def test_aot_autograd_default_partition_exhaustive(self, device, dtype, op):
        _test_aot_autograd_helper(self, device, dtype, op, use_min_cut=False)
    def test_aot_autograd_disable_functionalization_exhaustive(self, device, dtype, op):
        _test_aot_autograd_helper(
            self, device, dtype, op, disable_functionalization=True
        )

    @ops(op_db + hop_db, allowed_dtypes=(torch.float,))
    @patch("functorch.compile.config.debug_assert", True)
    @skipOps(
        "TestEagerFusionOpInfo",
        "test_aot_autograd_symbolic_default_partition_exhaustive",
        "test_aot_autograd_disable_functionalization_symbolic_exhaustive",
        aot_autograd_failures | symbolic_aot_autograd_failures,
    )
    def test_aot_autograd_symbolic_default_partition_exhaustive(
    def test_aot_autograd_disable_functionalization_symbolic_exhaustive(
        self, device, dtype, op
    ):
        _test_aot_autograd_helper(
            self, device, dtype, op, dynamic=True, use_min_cut=False
            self,
            device,
            dtype,
            op,
            dynamic=True,
            disable_functionalization=True,
        )

@@ -4,6 +4,7 @@ This module dispatches the graphs to either the forward-only or joint compilatio
pathways, taking into account the AOTConfig and the collected ViewAndMutationMetadata.
"""

import contextlib
import dataclasses
from typing import Any, Optional

@@ -70,14 +71,19 @@ def _create_graph(
        out, out_descs = call_and_expect_output_descs(f, args)
        return out

    with (
        enable_python_dispatcher(),
        FunctionalTensorMode(
    if aot_config.disable_functionalization:
        ctx = contextlib.nullcontext()
    else:
        ctx = FunctionalTensorMode( # type: ignore[assignment]
            pre_dispatch=aot_config.pre_dispatch,
            export=aot_config.is_export,
            # Allow token discovery for joint fn tracing as tokens can be used in backward.
            _allow_token_discovery=True,
        ),
        )

    with (
        enable_python_dispatcher(),
        ctx,
    ):
        fx_g = make_fx(
            inner_f,
@@ -162,14 +168,22 @@ def aot_dispatch_base_graph(
        keep_data_input_mutations=aot_config.keep_inference_input_mutations,
    )

    fn_to_trace, updated_flat_args, updated_flat_args_descs = create_functionalized_fn(
        fn_to_trace,
        flat_args,
        flat_args_descs,
        meta=fw_metadata,
        aot_config=aot_config,
        trace_joint=False,
    )
    if aot_config.disable_functionalization:
        updated_flat_args, updated_flat_args_descs = (
            flat_args,
            flat_args_descs,
        )
    else:
        fn_to_trace, updated_flat_args, updated_flat_args_descs = (
            create_functionalized_fn(
                fn_to_trace,
                flat_args,
                flat_args_descs,
                meta=fw_metadata,
                aot_config=aot_config,
                trace_joint=False,
            )
        )

    # TODO: replace with AOTDispatchSubclassWrapper once we refactor
    # fn_input_mutations_to_outputs and create_functionalized_fn
@@ -188,17 +202,18 @@ def aot_dispatch_base_graph(
        fw_only=flat_fn,
    )

    (
        fn_to_trace,
        updated_flat_args_subclasses_desugared,
        updated_flat_args_subclasses_desugared_descs,
    ) = handle_effect_tokens_fn(
        fn_to_trace,
        updated_flat_args_subclasses_desugared,
        updated_flat_args_subclasses_desugared_descs,
        meta=fw_metadata,
        trace_joint=False,
    )
    if not aot_config.disable_functionalization:
        (
            fn_to_trace,
            updated_flat_args_subclasses_desugared,
            updated_flat_args_subclasses_desugared_descs,
        ) = handle_effect_tokens_fn(
            fn_to_trace,
            updated_flat_args_subclasses_desugared,
            updated_flat_args_subclasses_desugared_descs,
            meta=fw_metadata,
            trace_joint=False,
        )

    aot_graphs_log.debug(
        "aot_config id: %s, fw_metadata=%s,subclass_metadata=%s",
@@ -265,12 +280,15 @@ def aot_dispatch_base_graph(

    # As long as we opted to remove input mutations, then
    # there should be *NO* mutating ops in the graph at this point.
    copy_count = assert_functional_graph(fw_module.graph)
    fw_module.graph.eliminate_dead_code()
    fw_module.recompile()

    copy_count2 = assert_functional_graph(fw_module.graph)
    propagate_input_mutation_stacktraces(fw_module.graph)
    if not aot_config.disable_functionalization:
        copy_count = assert_functional_graph(fw_module.graph)
        fw_module.graph.eliminate_dead_code()
        fw_module.recompile()
        copy_count2 = assert_functional_graph(fw_module.graph)
        propagate_input_mutation_stacktraces(fw_module.graph)
        assert copy_count == copy_count2
    else:
        fw_module.graph.eliminate_dead_code()

    # See Note [Side-Effectful Tokens in AOTAutograd]
    num_tokens = len(fw_metadata.tokens)
@@ -283,8 +301,6 @@ def aot_dispatch_base_graph(
        saved_updated_flat_args_subclasses_desugared_descs[num_tokens:]
    )

    assert copy_count == copy_count2

    if aot_config.enable_log:
        aot_graphs_log.info(
            "%s",
@@ -369,23 +385,30 @@ def aot_dispatch_autograd_graph(
        flat_fn,
        flat_args_descs,
        fw_metadata,
        aot_config,
    )
    joint_fn_to_trace = create_joint(
        fn_prepared_for_autograd, flat_args_descs, aot_config=aot_config
    )
    joint_fn_handle = joint_fn_to_trace.handle

    joint_fn_to_trace, updated_joint_inputs, updated_joint_inputs_descs = (
        create_functionalized_fn(
            joint_fn_to_trace,
    if aot_config.disable_functionalization:
        updated_joint_inputs, updated_joint_inputs_descs = (
            joint_inputs,
            joint_inputs_descs,
            meta=fw_metadata,
            aot_config=aot_config,
            trace_joint=True,
            joint_fn_handle=joint_fn_handle,
        )
    )
    else:
        joint_fn_to_trace, updated_joint_inputs, updated_joint_inputs_descs = (
            create_functionalized_fn(
                joint_fn_to_trace,
                joint_inputs,
                joint_inputs_descs,
                meta=fw_metadata,
                aot_config=aot_config,
                trace_joint=True,
                joint_fn_handle=joint_fn_handle,
            )
        )

    # TODO: replace with AOTDispatchSubclassWrapper once we refactor
    # fn_input_mutations_to_outputs and create_functionalized_fn
@@ -403,15 +426,16 @@ def aot_dispatch_autograd_graph(
    updated_joint_inputs = subclass_tracing_info.plain_tensor_args
    updated_joint_inputs_descs = subclass_tracing_info.plain_tensor_args_descs

    (joint_fn_to_trace, updated_joint_inputs, updated_joint_inputs_descs) = (
        handle_effect_tokens_fn(
            joint_fn_to_trace,
            updated_joint_inputs,
            updated_joint_inputs_descs,
            meta=fw_metadata,
            trace_joint=True,
    if not aot_config.disable_functionalization:
        (joint_fn_to_trace, updated_joint_inputs, updated_joint_inputs_descs) = (
            handle_effect_tokens_fn(
                joint_fn_to_trace,
                updated_joint_inputs,
                updated_joint_inputs_descs,
                meta=fw_metadata,
                trace_joint=True,
            )
        )
    )

    # When we call _create_graph, this may mutate the metadata of joint
    # inputs. But callers are expecting to get the original joint inputs. So
@@ -440,14 +464,15 @@ def aot_dispatch_autograd_graph(
        aot_config=aot_config,
    )

    # There should be *NO* mutating ops in the graph at this point.
    assert_functional_graph(fx_g.graph)

    # Redundant with the check above, but worth having in case tracing introduced
    # a fake tensor. Unlikely.
    # See Note: [Fake Modules and AOTAutograd]
    torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g)
    fx_g.graph.eliminate_dead_code()
    if not aot_config.disable_functionalization:
        # There should be *NO* mutating ops in the graph at this point.
        assert_functional_graph(fx_g.graph)

    copy_fwd_metadata_to_bw_nodes(fx_g)
    fx_g.recompile()

@@ -15,7 +15,7 @@ import warnings
from collections.abc import Callable
from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
from dataclasses import dataclass
from typing import Any, cast, Optional, TypeVar, Union
from typing import Any, Optional, TypeVar, Union
from unittest.mock import patch

import torch
@@ -160,6 +160,7 @@ def fn_prepped_for_autograd(
    fn: TraceFn,
    args_descs: list[AOTInput],
    meta: ViewAndMutationMeta,
    aot_config: AOTConfig,
) -> PreppedForAutogradTraceFn:
    @simple_wraps(fn)
    def inner_fn(*args):
@@ -240,10 +241,11 @@ def fn_prepped_for_autograd(
        # This is annoying: our joint function needs to be aware of functionalization
        # (syncing mutated inputs before calling autograd.grad())
        # In theory, we could make the autograd engine do this automatically, although that probably isn't any cleaner.
        for arg in args_maybe_cloned:
            if not isinstance(arg, Tensor):
                continue
            sync_functional_tensor(arg)
        if not aot_config.disable_functionalization:
            for arg in args_maybe_cloned:
                if not isinstance(arg, Tensor):
                    continue
                sync_functional_tensor(arg)

        return (fw_outs_to_return, out_grad_mask), (
            fw_outs_to_return_descs,
@@ -430,9 +432,12 @@ def create_joint(
            with torch.autograd.detect_anomaly(check_nan=False):
                return inner_fn(primals, tangents)

    inner_fn_with_anomaly.handle = joint_fn_handle # type: ignore[attr-defined]
    def joint_helper(primals, tangents):
        return inner_fn_with_anomaly(primals, tangents)

    return cast(JointTraceFn, inner_fn_with_anomaly) # deal with 'handle' property
    joint_helper.handle = joint_fn_handle # type: ignore[attr-defined]

    return joint_helper


def create_functionalized_rng_ops_wrapper(
@@ -973,6 +973,7 @@ class AOTConfig:
    # This config makes sure to check certain things like
    # mutating input with req_grad in export joint tracing.
    export_trace_joint: bool = False
    disable_functionalization: bool = False

    def __post_init__(self):
        if self.pre_dispatch:
@@ -723,6 +723,7 @@ def aot_function(
    # Whether or not to trace with dynamic shapes
    dynamic=False,
    enable_log=True,
    disable_functionalization=False,
) -> Callable:
    """
    Traces the forward and backward graph of :attr:`fn` using torch dispatch
@@ -790,6 +791,7 @@ def aot_function(
        is_export=False,
        no_tangents=False,
        enable_log=enable_log,
        disable_functionalization=disable_functionalization,
    )
    cached_res = None

@@ -902,6 +904,7 @@ def prepare_aot_module_simplified(
    flatten: bool,
    *,
    force_non_lazy_backward_lowering: bool = False,
    disable_functionalization: bool = False,
):
    if not flatten:
        assert kwargs is None
@@ -992,6 +995,7 @@ def prepare_aot_module_simplified(
        ignore_shape_env=ignore_shape_env,
        precompile_backend_id=getattr(mod, "_backend_id", None),
        force_non_lazy_backward_lowering=force_non_lazy_backward_lowering,
        disable_functionalization=False,
    )
    fake_mode, shape_env = construct_fake_mode(full_args, aot_config)
    # NB: full_args_descs not needed here, fake_flat_args is 1:1 with full_args
@@ -1028,6 +1032,7 @@ def aot_module_simplified(
    cudagraphs: Optional[BoxedBool] = None,
    boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
    ignore_shape_env: bool = False,
    disable_functionalization: bool = False,
) -> nn.Module:
    """
    This is the simplified or low overhead version of aot_module. For frontends
@@ -1066,6 +1071,7 @@ def aot_module_simplified(
        ignore_shape_env,
        flatten=False,
        force_non_lazy_backward_lowering=config.force_non_lazy_backward_lowering,
        disable_functionalization=disable_functionalization,
    )

    compiled_fn = None
@@ -1168,6 +1174,7 @@ def aot_export_joint_with_descriptors(
    decompositions: Optional[dict] = None,
    keep_inference_input_mutations=False,
    ignore_shape_env=False,
    disable_functionalization=False,
) -> JointWithDescriptors:
    """
    This API captures the joint graph for an nn.Module. However, unlike
@@ -1257,6 +1264,7 @@ def aot_export_joint_with_descriptors(
        # Metric(s) {'is_forward'} have already been set in the current
        # context.
        force_non_lazy_backward_lowering=True,
        disable_functionalization=disable_functionalization,
    )

    # TODO: Maybe this should be in create_aot_state? Not sure, that would
@@ -312,6 +312,9 @@ graphsafe_rng_functionalization = True
# through compile_fx, we can remove this
force_non_lazy_backward_lowering = False

# only for testing, used to turn functionalization off in AOTDispatcher
_test_disable_functionalization = True

# Error on BypassAOTAutogradCache instead of just a warning
# Used for tests
strict_autograd_cache = False
@@ -203,11 +203,11 @@ def _extract_graph_with_inputs_outputs(
            env[node] = new_node

    for node in joint_graph.nodes:
        if _must_be_in_backward(node) and subgraph != "backward":
        if _must_be_in_backward(node) and subgraph != "backward" and node not in inputs:
            env[node] = InvalidNode # type: ignore[assignment]
            continue

        if _must_be_in_forward(node) and subgraph != "forward":
        if _must_be_in_forward(node) and subgraph != "forward" and node not in inputs:
            env[node] = InvalidNode # type: ignore[assignment]
            continue

@@ -296,13 +296,27 @@ def _has_tag_must_be_in_backward(node: fx.Node) -> bool:


def _must_be_in_forward(node: fx.Node) -> bool:
    return _has_tag_must_be_in_forward(node)
    if _has_tag_must_be_in_forward(node):
        return True
    is_mutable = is_with_effects(node) or (
        isinstance(node.target, torch._ops.OpOverload)
        and node.target._schema.is_mutable
    )
    return (
        not _has_tag_is_backward(node)
        and not _has_tag_must_be_in_backward(node)
        and is_mutable
    )


def _must_be_in_backward(node: fx.Node) -> bool:
    return _has_tag_must_be_in_backward(node) or (
        _has_tag_is_backward(node) and is_with_effects(node)
    if _has_tag_must_be_in_backward(node):
        return True
    is_mutable = is_with_effects(node) or (
        isinstance(node.target, torch._ops.OpOverload)
        and node.target._schema.is_mutable
    )
    return _has_tag_is_backward(node) and is_mutable


def _extract_fwd_bwd_outputs(
@@ -1015,11 +1029,50 @@ def default_partition(
    forward_node_names = OrderedSet(
        node.name for node in forward_only_graph.nodes if node.op != "output"
    )
    order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
    saved_values = []
    saved_sym_nodes = []

    def is_mutated_later_in_fw(node):
        if _has_tag_is_backward(node):
            return False
        tensor_arg_aliases = [
            x
            for x in node.args
            if isinstance(x, fx.Node)
            and "val" in x.meta
            and isinstance(x.meta["val"], torch.Tensor)
        ]
        while len(tensor_arg_aliases) > 0:
            a = tensor_arg_aliases.pop()
            for u in a.users:
                if not isinstance(u.target, torch._ops.OpOverload):
                    continue
                # If we witness a mutation on our node later, and that mutation is not "must be in backward",
                # then our node needs to be computed in the forward (otherwise we will compute it on the mutated values)
                if (
                    # one of the args was mutated
                    u.target._schema.is_mutable
                    # and the mutation happens "later"
                    and order[u] > order[node]
                    # and the mutation happened during the forward
                    and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u))
                ):
                    for idx, alias_info in enumerate(u.target._schema.arguments):
                        if alias_info.is_write and u.args[idx] is a:
                            return True
                elif u.target.is_view:
                    tensor_arg_aliases.append(u)
        return False

    for node in joint_module.graph.nodes:
        if node.name not in forward_node_names:
            # if a node isn't "required" to be in the forward, but any of its arguments
            # are later mutated in the forward, then it must have been run in the forward
            # (if not, and the node's arg was saved for backward, we would have mutated a saved value)
            # NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated
            if is_mutated_later_in_fw(node):
                saved_values.append(node)
                continue
            if is_sym_node(node):
                # Symints must be kept separate from tensors so that PythonFunction only calls
@@ -38,8 +38,8 @@ def aot_autograd_check(
        assert_equals_fn=torch.testing.assert_close,
        check_gradients=True,
        try_check_data_specialization=False,
        use_min_cut=True,
        skip_correctness_check=False):
        skip_correctness_check=False,
        disable_functionalization=False):
    """Compares func(*args, **kwargs) in eager-mode to under AOTAutograd.

    Compares outputs and (if check_gradients=True) gradients produced by
@@ -64,14 +64,16 @@ def aot_autograd_check(
        c_args, c_kwargs = pytree.tree_unflatten(reconstructed_flat_args, args_spec)
        return func(*c_args, **c_kwargs)

    if use_min_cut:
    # cannot use the min cut partitioner without functionalization
    if disable_functionalization:
        compiled_f = compiled_function(
            func_no_tensors,
            nop,
            nop,
            dynamic=dynamic,
            partition_fn=min_cut_rematerialization_partition,
            keep_inference_input_mutations=True
            partition_fn=default_partition,
            keep_inference_input_mutations=True,
            disable_functionalization=True
        )
    else:
        compiled_f = compiled_function(
@@ -79,8 +81,9 @@ def aot_autograd_check(
            nop,
            nop,
            dynamic=dynamic,
            partition_fn=default_partition,
            keep_inference_input_mutations=True
            partition_fn=min_cut_rematerialization_partition,
            keep_inference_input_mutations=True,
            disable_functionalization=False
        )

    out = wrapper_set_seed(func_no_tensors, args)
@@ -1187,7 +1187,7 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
def _is_compiling(func, args, kwargs):
    # Check if we are under AOTAutograd tracing
    # Checking that a functional mode is active should always do what we want
    return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) is not None
    return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) is not None


class _VersionWrapper: