add the option to disable functionalization in AOTDispatcher (#164577)

I'm cleaning this PR up to be a proper way of disabling functionalization via a config in AOTDispatcher. I removed the non-functionalization-related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA (CompositeImplicitAutograd) ops (Ed has a PR for this here: https://github.com/pytorch/pytorch/pull/164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now; I will likely do it in a follow-up.
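
As a rough usage sketch (mirroring the new tests below; compiled_function, nop, and default_partition are existing functorch.compile helpers, while the toy function f here is mine and not part of the PR):

import torch
from functorch.compile import compiled_function, default_partition, nop

def f(a, b):
    # Input mutation: with functionalization disabled, this stays in the
    # captured forward graph as a mutable _foreach_mul_ op (see the
    # expected graphs in the tests below).
    torch._foreach_mul_([b], [2])
    return (b + 1) * a

compiled_f = compiled_function(
    f,
    nop,  # forward compiler: run the captured FX graph as-is
    nop,  # backward compiler
    partition_fn=default_partition,  # the min-cut partitioner needs functionalization
    keep_inference_input_mutations=True,
    disable_functionalization=True,  # the new flag added in this PR
)

a = torch.ones(3, 3, requires_grad=True)
b = torch.ones(3, 3)  # the mutated input does not require grad
compiled_f(a, b).sum().backward()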

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164577
Approved by: https://github.com/ezyang
ghstack dependencies: #165372
Brian Hirsh
2025-10-15 18:54:44 -07:00
committed by PyTorch MergeBot
parent f33c7e1a43
commit ed74dc054d
10 changed files with 403 additions and 78 deletions

View File

@@ -838,6 +838,55 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization(
self, device
):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
]
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x, y):
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
fw_compiler = functools.partial(
count_ops,
freq=1,
op=torch.ops.aten.sigmoid.default,
)
bw_compiler = functools.partial(
count_ops,
# Main check here is just that sigmoid is properly recomputed
# (we will see a sigmoid() and sigmoid_backward() in the bw graph)
freq=1,
op=torch.ops.aten.sigmoid.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
disable_functionalization=True,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_triton_kernel(self, device):

View File

@@ -2605,6 +2605,170 @@ def forward(self, primals_1, primals_2):
]
self.verify_aot_autograd(f, inp_grad, test_mutation=True)
def test_fw_bw_mutation_no_functionalization1(self):
class FwBwMutation(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b):
# input mutation
torch._foreach_mul_([b], [2])
x = b + 1
# intermediate mutation
torch._foreach_mul_([x], [3])
ctx.save_for_backward(x)
return x * a
@staticmethod
def backward(ctx, grad_output):
(x,) = ctx.saved_tensors
# bw mutation
torch._foreach_mul_([x], [4])
return grad_output * x, grad_output * x
def f(a, b):
return FwBwMutation.apply(a, b)
inps = [
torch.ones(3, 3, requires_grad=True),
torch.ones(3, 3, requires_grad=False),
]
inps_ref = [
torch.ones(3, 3, requires_grad=True),
torch.ones(3, 3, requires_grad=False),
]
fw_graph = [None]
bw_graph = [None]
def fw_compiler(gm, example_inputs):
fw_graph[0] = gm
return gm
def bw_compiler(gm, example_inputs):
bw_graph[0] = gm
return gm
compiled_f = compiled_function(
f,
fw_compiler,
bw_compiler,
dynamic=False,
partition_fn=default_partition,
keep_inference_input_mutations=True,
disable_functionalization=True,
)
out_ref = f(*inps_ref)
out = compiled_f(*inps)
self.assertEqual(out, out_ref)
out_ref.sum().backward()
out.sum().backward()
self.assertEqual(inps_ref[0].grad, inps[0].grad)
# important bit: there are 2 mutations in the fw
self.assertExpectedInline(
fw_graph[0].code.strip(),
"""\
def forward(self, primals_1, primals_2):
_foreach_mul_ = torch.ops.aten._foreach_mul_.ScalarList([primals_2], [2]); _foreach_mul_ = None
add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
_foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None
mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
return (mul, add)""",
)
# important bit: there is 1 mutation in the bw
self.assertExpectedInline(
bw_graph[0].code.strip(),
"""\
def forward(self, add, tangents_1):
_foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None
return (mul_1, None)""",
)
def test_fw_bw_mutation_no_functionalization2(self):
class FwBwMutation(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
# input mutation
torch._foreach_mul_([x], [2])
x = x + 1
# intermediate mutation
torch._foreach_mul_([x], [3])
ctx.save_for_backward(x)
return x
@staticmethod
def backward(ctx, grad_output):
(x,) = ctx.saved_tensors
# bw mutation
torch._foreach_mul_([x], [4])
return grad_output * x
def f(a, b):
out = FwBwMutation.apply(b)
return out * a
inps = [
torch.ones(3, 3, requires_grad=True),
torch.ones(3, 3, requires_grad=False),
]
inps_ref = [
torch.ones(3, 3, requires_grad=True),
torch.ones(3, 3, requires_grad=False),
]
fw_graph = [None]
bw_graph = [None]
def fw_compiler(gm, example_inputs):
fw_graph[0] = gm
return gm
def bw_compiler(gm, example_inputs):
bw_graph[0] = gm
return gm
compiled_f = compiled_function(
f,
fw_compiler,
bw_compiler,
dynamic=False,
partition_fn=default_partition,
keep_inference_input_mutations=True,
disable_functionalization=True,
)
out_ref = f(*inps_ref)
out = compiled_f(*inps)
self.assertEqual(out, out_ref)
out_ref.sum().backward()
out.sum().backward()
self.assertEqual(inps_ref[0].grad, inps[0].grad)
# important bit: there are 2 mutations in the fw
# (the mutation on an activation doesn't get moved to bw)
self.assertExpectedInline(
fw_graph[0].code.strip(),
"""\
def forward(self, primals_1, primals_2):
_foreach_mul_ = torch.ops.aten._foreach_mul_.ScalarList([primals_2], [2]); _foreach_mul_ = None
add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
_foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None
mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
return (mul, add)""",
)
self.assertExpectedInline(
bw_graph[0].code.strip(),
"""\
def forward(self, add, tangents_1):
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None
return (mul_1, None)""",
)
def test_backward_mutation_metadata(self):
class BwMutation(torch.autograd.Function):
@staticmethod
@@ -8046,7 +8210,14 @@ symbolic_aot_autograd_failures = {
}
def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False, use_min_cut=True):
def _test_aot_autograd_helper(
self,
device,
dtype,
op,
dynamic=False,
disable_functionalization=False,
):
if not op.supports_autograd:
self.skipTest("Op does not support autograd")
@@ -8077,7 +8248,7 @@ def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False, use_min_cu
check_gradients=True,
try_check_data_specialization=try_check_data_specialization,
skip_correctness_check=op.skip_correctness_check_compile_vs_eager,
use_min_cut=use_min_cut,
disable_functionalization=disable_functionalization,
)
except DynamicOutputShapeException:
self.skipTest("Dynamic output shape operation in trace")
@@ -8181,24 +8352,31 @@ class TestEagerFusionOpInfo(AOTTestCase):
@ops(op_db + hop_db, allowed_dtypes=(torch.float,))
@skipOps(
"TestEagerFusionOpInfo",
"test_aot_autograd_default_partition_exhaustive",
"test_aot_autograd_disable_functionalization_exhaustive",
aot_autograd_failures,
)
def test_aot_autograd_default_partition_exhaustive(self, device, dtype, op):
_test_aot_autograd_helper(self, device, dtype, op, use_min_cut=False)
def test_aot_autograd_disable_functionalization_exhaustive(self, device, dtype, op):
_test_aot_autograd_helper(
self, device, dtype, op, disable_functionalization=True
)
@ops(op_db + hop_db, allowed_dtypes=(torch.float,))
@patch("functorch.compile.config.debug_assert", True)
@skipOps(
"TestEagerFusionOpInfo",
"test_aot_autograd_symbolic_default_partition_exhaustive",
"test_aot_autograd_disable_functionalization_symbolic_exhaustive",
aot_autograd_failures | symbolic_aot_autograd_failures,
)
def test_aot_autograd_symbolic_default_partition_exhaustive(
def test_aot_autograd_disable_functionalization_symbolic_exhaustive(
self, device, dtype, op
):
_test_aot_autograd_helper(
self, device, dtype, op, dynamic=True, use_min_cut=False
self,
device,
dtype,
op,
dynamic=True,
disable_functionalization=True,
)

View File

@@ -4,6 +4,7 @@ This module dispatches the graphs to either the forward-only or joint compilatio
pathways, taking into account the AOTConfig and the collected ViewAndMutationMetadata.
"""
import contextlib
import dataclasses
from typing import Any, Optional
@@ -70,14 +71,19 @@ def _create_graph(
out, out_descs = call_and_expect_output_descs(f, args)
return out
with (
enable_python_dispatcher(),
FunctionalTensorMode(
if aot_config.disable_functionalization:
ctx = contextlib.nullcontext()
else:
ctx = FunctionalTensorMode( # type: ignore[assignment]
pre_dispatch=aot_config.pre_dispatch,
export=aot_config.is_export,
# Allow token discovery for joint fn tracing as tokens can be used in backward.
_allow_token_discovery=True,
),
)
with (
enable_python_dispatcher(),
ctx,
):
fx_g = make_fx(
inner_f,
@@ -162,14 +168,22 @@ def aot_dispatch_base_graph(
keep_data_input_mutations=aot_config.keep_inference_input_mutations,
)
fn_to_trace, updated_flat_args, updated_flat_args_descs = create_functionalized_fn(
fn_to_trace,
flat_args,
flat_args_descs,
meta=fw_metadata,
aot_config=aot_config,
trace_joint=False,
)
if aot_config.disable_functionalization:
updated_flat_args, updated_flat_args_descs = (
flat_args,
flat_args_descs,
)
else:
fn_to_trace, updated_flat_args, updated_flat_args_descs = (
create_functionalized_fn(
fn_to_trace,
flat_args,
flat_args_descs,
meta=fw_metadata,
aot_config=aot_config,
trace_joint=False,
)
)
# TODO: replace with AOTDispatchSubclassWrapper once we refactor
# fn_input_mutations_to_outputs and create_functionalized_fn
@@ -188,17 +202,18 @@ def aot_dispatch_base_graph(
fw_only=flat_fn,
)
(
fn_to_trace,
updated_flat_args_subclasses_desugared,
updated_flat_args_subclasses_desugared_descs,
) = handle_effect_tokens_fn(
fn_to_trace,
updated_flat_args_subclasses_desugared,
updated_flat_args_subclasses_desugared_descs,
meta=fw_metadata,
trace_joint=False,
)
if not aot_config.disable_functionalization:
(
fn_to_trace,
updated_flat_args_subclasses_desugared,
updated_flat_args_subclasses_desugared_descs,
) = handle_effect_tokens_fn(
fn_to_trace,
updated_flat_args_subclasses_desugared,
updated_flat_args_subclasses_desugared_descs,
meta=fw_metadata,
trace_joint=False,
)
aot_graphs_log.debug(
"aot_config id: %s, fw_metadata=%s,subclass_metadata=%s",
@@ -265,12 +280,15 @@ def aot_dispatch_base_graph(
# As long as we opted to remove input mutations, then
# there should be *NO* mutating ops in the graph at this point.
copy_count = assert_functional_graph(fw_module.graph)
fw_module.graph.eliminate_dead_code()
fw_module.recompile()
copy_count2 = assert_functional_graph(fw_module.graph)
propagate_input_mutation_stacktraces(fw_module.graph)
if not aot_config.disable_functionalization:
copy_count = assert_functional_graph(fw_module.graph)
fw_module.graph.eliminate_dead_code()
fw_module.recompile()
copy_count2 = assert_functional_graph(fw_module.graph)
propagate_input_mutation_stacktraces(fw_module.graph)
assert copy_count == copy_count2
else:
fw_module.graph.eliminate_dead_code()
# See Note [Side-Effectful Tokens in AOTAutograd]
num_tokens = len(fw_metadata.tokens)
@@ -283,8 +301,6 @@ def aot_dispatch_base_graph(
saved_updated_flat_args_subclasses_desugared_descs[num_tokens:]
)
assert copy_count == copy_count2
if aot_config.enable_log:
aot_graphs_log.info(
"%s",
@@ -369,23 +385,30 @@ def aot_dispatch_autograd_graph(
flat_fn,
flat_args_descs,
fw_metadata,
aot_config,
)
joint_fn_to_trace = create_joint(
fn_prepared_for_autograd, flat_args_descs, aot_config=aot_config
)
joint_fn_handle = joint_fn_to_trace.handle
joint_fn_to_trace, updated_joint_inputs, updated_joint_inputs_descs = (
create_functionalized_fn(
joint_fn_to_trace,
if aot_config.disable_functionalization:
updated_joint_inputs, updated_joint_inputs_descs = (
joint_inputs,
joint_inputs_descs,
meta=fw_metadata,
aot_config=aot_config,
trace_joint=True,
joint_fn_handle=joint_fn_handle,
)
)
else:
joint_fn_to_trace, updated_joint_inputs, updated_joint_inputs_descs = (
create_functionalized_fn(
joint_fn_to_trace,
joint_inputs,
joint_inputs_descs,
meta=fw_metadata,
aot_config=aot_config,
trace_joint=True,
joint_fn_handle=joint_fn_handle,
)
)
# TODO: replace with AOTDispatchSubclassWrapper once we refactor
# fn_input_mutations_to_outputs and create_functionalized_fn
@@ -403,15 +426,16 @@ def aot_dispatch_autograd_graph(
updated_joint_inputs = subclass_tracing_info.plain_tensor_args
updated_joint_inputs_descs = subclass_tracing_info.plain_tensor_args_descs
(joint_fn_to_trace, updated_joint_inputs, updated_joint_inputs_descs) = (
handle_effect_tokens_fn(
joint_fn_to_trace,
updated_joint_inputs,
updated_joint_inputs_descs,
meta=fw_metadata,
trace_joint=True,
if not aot_config.disable_functionalization:
(joint_fn_to_trace, updated_joint_inputs, updated_joint_inputs_descs) = (
handle_effect_tokens_fn(
joint_fn_to_trace,
updated_joint_inputs,
updated_joint_inputs_descs,
meta=fw_metadata,
trace_joint=True,
)
)
)
# When we call _create_graph, this may mutate the metadata of joint
# inputs. But callers are expecting to get the original joint inputs. So
@@ -440,14 +464,15 @@ def aot_dispatch_autograd_graph(
aot_config=aot_config,
)
# There should be *NO* mutating ops in the graph at this point.
assert_functional_graph(fx_g.graph)
# Redundant with the check above, but worth having in case tracing introduced
# a fake tensor. Unlikely.
# See Note: [Fake Modules and AOTAutograd]
torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g)
fx_g.graph.eliminate_dead_code()
if not aot_config.disable_functionalization:
# There should be *NO* mutating ops in the graph at this point.
assert_functional_graph(fx_g.graph)
copy_fwd_metadata_to_bw_nodes(fx_g)
fx_g.recompile()

View File

@@ -15,7 +15,7 @@ import warnings
from collections.abc import Callable
from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
from dataclasses import dataclass
from typing import Any, cast, Optional, TypeVar, Union
from typing import Any, Optional, TypeVar, Union
from unittest.mock import patch
import torch
@@ -160,6 +160,7 @@ def fn_prepped_for_autograd(
fn: TraceFn,
args_descs: list[AOTInput],
meta: ViewAndMutationMeta,
aot_config: AOTConfig,
) -> PreppedForAutogradTraceFn:
@simple_wraps(fn)
def inner_fn(*args):
@@ -240,10 +241,11 @@ def fn_prepped_for_autograd(
# This is annoying: our joint function needs to be aware of functionalization
# (syncing mutated inputs before calling autograd.grad())
# In theory, we could make the autograd engine do this automatically, although that probably isn't any cleaner.
for arg in args_maybe_cloned:
if not isinstance(arg, Tensor):
continue
sync_functional_tensor(arg)
if not aot_config.disable_functionalization:
for arg in args_maybe_cloned:
if not isinstance(arg, Tensor):
continue
sync_functional_tensor(arg)
return (fw_outs_to_return, out_grad_mask), (
fw_outs_to_return_descs,
@@ -430,9 +432,12 @@ def create_joint(
with torch.autograd.detect_anomaly(check_nan=False):
return inner_fn(primals, tangents)
inner_fn_with_anomaly.handle = joint_fn_handle # type: ignore[attr-defined]
def joint_helper(primals, tangents):
return inner_fn_with_anomaly(primals, tangents)
return cast(JointTraceFn, inner_fn_with_anomaly) # deal with 'handle' property
joint_helper.handle = joint_fn_handle # type: ignore[attr-defined]
return joint_helper
def create_functionalized_rng_ops_wrapper(

View File

@@ -973,6 +973,7 @@ class AOTConfig:
# This config makes sure to check certain things like
# mutating input with req_grad in export joint tracing.
export_trace_joint: bool = False
disable_functionalization: bool = False
def __post_init__(self):
if self.pre_dispatch:

View File

@@ -723,6 +723,7 @@ def aot_function(
# Whether or not to trace with dynamic shapes
dynamic=False,
enable_log=True,
disable_functionalization=False,
) -> Callable:
"""
Traces the forward and backward graph of :attr:`fn` using torch dispatch
@@ -790,6 +791,7 @@
is_export=False,
no_tangents=False,
enable_log=enable_log,
disable_functionalization=disable_functionalization,
)
cached_res = None
@@ -902,6 +904,7 @@ def prepare_aot_module_simplified(
flatten: bool,
*,
force_non_lazy_backward_lowering: bool = False,
disable_functionalization: bool = False,
):
if not flatten:
assert kwargs is None
@@ -992,6 +995,7 @@
ignore_shape_env=ignore_shape_env,
precompile_backend_id=getattr(mod, "_backend_id", None),
force_non_lazy_backward_lowering=force_non_lazy_backward_lowering,
disable_functionalization=False,
)
fake_mode, shape_env = construct_fake_mode(full_args, aot_config)
# NB: full_args_descs not needed here, fake_flat_args is 1:1 with full_args
@@ -1028,6 +1032,7 @@
cudagraphs: Optional[BoxedBool] = None,
boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
ignore_shape_env: bool = False,
disable_functionalization: bool = False,
) -> nn.Module:
"""
This is the simplified or low overhead version of aot_module. For frontends
@@ -1066,6 +1071,7 @@
ignore_shape_env,
flatten=False,
force_non_lazy_backward_lowering=config.force_non_lazy_backward_lowering,
disable_functionalization=disable_functionalization,
)
compiled_fn = None
@@ -1168,6 +1174,7 @@ def aot_export_joint_with_descriptors(
decompositions: Optional[dict] = None,
keep_inference_input_mutations=False,
ignore_shape_env=False,
disable_functionalization=False,
) -> JointWithDescriptors:
"""
This API captures the joint graph for an nn.Module. However, unlike
@@ -1257,6 +1264,7 @@ def aot_export_joint_with_descriptors(
# Metric(s) {'is_forward'} have already been set in the current
# context.
force_non_lazy_backward_lowering=True,
disable_functionalization=disable_functionalization,
)
# TODO: Maybe this should be in create_aot_state? Not sure, that would

View File

@@ -312,6 +312,9 @@ graphsafe_rng_functionalization = True
# through compile_fx, we can remove this
force_non_lazy_backward_lowering = False
# only for testing, used to turn functionalization off in AOTDispatcher
_test_disable_functionalization = True
# Error on BypassAOTAutogradCache instead of just a warning
# Used for tests
strict_autograd_cache = False

View File

@@ -203,11 +203,11 @@ def _extract_graph_with_inputs_outputs(
env[node] = new_node
for node in joint_graph.nodes:
if _must_be_in_backward(node) and subgraph != "backward":
if _must_be_in_backward(node) and subgraph != "backward" and node not in inputs:
env[node] = InvalidNode # type: ignore[assignment]
continue
if _must_be_in_forward(node) and subgraph != "forward":
if _must_be_in_forward(node) and subgraph != "forward" and node not in inputs:
env[node] = InvalidNode # type: ignore[assignment]
continue
@@ -296,13 +296,27 @@ def _has_tag_must_be_in_backward(node: fx.Node) -> bool:
def _must_be_in_forward(node: fx.Node) -> bool:
return _has_tag_must_be_in_forward(node)
if _has_tag_must_be_in_forward(node):
return True
is_mutable = is_with_effects(node) or (
isinstance(node.target, torch._ops.OpOverload)
and node.target._schema.is_mutable
)
return (
not _has_tag_is_backward(node)
and not _has_tag_must_be_in_backward(node)
and is_mutable
)
def _must_be_in_backward(node: fx.Node) -> bool:
return _has_tag_must_be_in_backward(node) or (
_has_tag_is_backward(node) and is_with_effects(node)
if _has_tag_must_be_in_backward(node):
return True
is_mutable = is_with_effects(node) or (
isinstance(node.target, torch._ops.OpOverload)
and node.target._schema.is_mutable
)
return _has_tag_is_backward(node) and is_mutable
def _extract_fwd_bwd_outputs(
@@ -1015,11 +1029,50 @@ def default_partition(
forward_node_names = OrderedSet(
node.name for node in forward_only_graph.nodes if node.op != "output"
)
order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
saved_values = []
saved_sym_nodes = []
def is_mutated_later_in_fw(node):
if _has_tag_is_backward(node):
return False
tensor_arg_aliases = [
x
for x in node.args
if isinstance(x, fx.Node)
and "val" in x.meta
and isinstance(x.meta["val"], torch.Tensor)
]
while len(tensor_arg_aliases) > 0:
a = tensor_arg_aliases.pop()
for u in a.users:
if not isinstance(u.target, torch._ops.OpOverload):
continue
# If we witness a mutation on our node later, and that mutation is not "must be in backward",
# then our node needs to be computed in the forward (otherwise we will compute it on the mutated values)
if (
# one of the args was mutated
u.target._schema.is_mutable
# and the mutation happens "later"
and order[u] > order[node]
# and the mutation happened during the forward
and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u))
):
for idx, alias_info in enumerate(u.target._schema.arguments):
if alias_info.is_write and u.args[idx] is a:
return True
elif u.target.is_view:
tensor_arg_aliases.append(u)
return False
for node in joint_module.graph.nodes:
if node.name not in forward_node_names:
# if a node isn't "required" to be in the forward, but any of its arguments
# are later mutated in the forward, then it must have been run in the forward
# (if not, and the node's arg was saved for backward, we would have mutated a saved value)
# NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated
if is_mutated_later_in_fw(node):
saved_values.append(node)
continue
if is_sym_node(node):
# Symints must be kept separate from tensors so that PythonFunction only calls

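To make the new is_mutated_later_in_fw check concrete, here is a tiny eager-mode sketch of the hazard it guards against (a toy example of mine, not code from this PR):

import torch

x = torch.ones(3)
a = x + 1    # a node the partitioner might otherwise choose to recompute in the backward
x.mul_(2)    # ...but one of its arguments is mutated later in the forward
# Recomputing `a` after the mutation would observe the new value of x:
assert not torch.equal(a, x + 1)
# So such a node must be computed in the forward and saved, rather than
# recomputed in the backward from its (now mutated) argument.
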
View File

@@ -38,8 +38,8 @@ def aot_autograd_check(
assert_equals_fn=torch.testing.assert_close,
check_gradients=True,
try_check_data_specialization=False,
use_min_cut=True,
skip_correctness_check=False):
skip_correctness_check=False,
disable_functionalization=False):
"""Compares func(*args, **kwargs) in eager-mode to under AOTAutograd.
Compares outputs and (if check_gradients=True) gradients produced by
@@ -64,14 +64,16 @@ def aot_autograd_check(
c_args, c_kwargs = pytree.tree_unflatten(reconstructed_flat_args, args_spec)
return func(*c_args, **c_kwargs)
if use_min_cut:
# cannot use the min cut partitioner without functionalization
if disable_functionalization:
compiled_f = compiled_function(
func_no_tensors,
nop,
nop,
dynamic=dynamic,
partition_fn=min_cut_rematerialization_partition,
keep_inference_input_mutations=True
partition_fn=default_partition,
keep_inference_input_mutations=True,
disable_functionalization=True
)
else:
compiled_f = compiled_function(
@@ -79,8 +81,9 @@
nop,
nop,
dynamic=dynamic,
partition_fn=default_partition,
keep_inference_input_mutations=True
partition_fn=min_cut_rematerialization_partition,
keep_inference_input_mutations=True,
disable_functionalization=False
)
out = wrapper_set_seed(func_no_tensors, args)

View File

@@ -1187,7 +1187,7 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
def _is_compiling(func, args, kwargs):
# Check if we are under AOTAutograd tracing
# Checking that a functional mode is active should always do what we want
return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) is not None
return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) is not None
class _VersionWrapper:
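
For context on the change above: with functionalization disabled, FunctionalTensorMode is no longer pushed during AOTDispatcher tracing (see the nullcontext branch in _create_graph earlier in this commit), while the proxy mode installed by make_fx is still active, so the PROXY key is the reliable signal. A minimal sketch of the two queries under plain make_fx tracing (peek is a throwaway function of mine):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def peek(x):
    # Under make_fx tracing the proxy mode is on the dispatch-mode stack,
    # but nothing here has pushed a FunctionalTensorMode.
    assert torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) is not None
    assert torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) is None
    return x + 1

make_fx(peek)(torch.ones(3))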