Revert "add and fix OpInfo tests for the default partitioner (#165372)"

This reverts commit bcfea48ab7fd489218289693b98c1a6a6582d079.

Reverted https://github.com/pytorch/pytorch/pull/165372 on behalf of https://github.com/malfet due to Looks like it broke slow jobs, see 331b7cc054/1 ([comment](https://github.com/pytorch/pytorch/pull/165372#issuecomment-3407567748))
commit b509fb9b5d (parent 331b7cc054)
Author: PyTorch MergeBot
Date:   2025-10-15 17:38:52 +00:00

3 changed files with 11 additions and 50 deletions

@@ -8044,7 +8044,7 @@ symbolic_aot_autograd_failures = {
 }
 
 
-def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False, use_min_cut=True):
+def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False):
     if not op.supports_autograd:
         self.skipTest("Op does not support autograd")
@@ -8075,7 +8075,6 @@ def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False, use_min_cut=True):
                 check_gradients=True,
                 try_check_data_specialization=try_check_data_specialization,
                 skip_correctness_check=op.skip_correctness_check_compile_vs_eager,
-                use_min_cut=use_min_cut,
             )
         except DynamicOutputShapeException:
             self.skipTest("Dynamic output shape operation in trace")
@@ -8176,29 +8175,6 @@ class TestEagerFusionOpInfo(AOTTestCase):
     def test_aot_autograd_symbolic_exhaustive(self, device, dtype, op):
         _test_aot_autograd_helper(self, device, dtype, op, dynamic=True)
-
-    @ops(op_db + hop_db, allowed_dtypes=(torch.float,))
-    @skipOps(
-        "TestEagerFusionOpInfo",
-        "test_aot_autograd_default_partition_exhaustive",
-        aot_autograd_failures,
-    )
-    def test_aot_autograd_default_partition_exhaustive(self, device, dtype, op):
-        _test_aot_autograd_helper(self, device, dtype, op, use_min_cut=False)
-
-    @ops(op_db + hop_db, allowed_dtypes=(torch.float,))
-    @patch("functorch.compile.config.debug_assert", True)
-    @skipOps(
-        "TestEagerFusionOpInfo",
-        "test_aot_autograd_symbolic_default_partition_exhaustive",
-        aot_autograd_failures | symbolic_aot_autograd_failures,
-    )
-    def test_aot_autograd_symbolic_default_partition_exhaustive(
-        self, device, dtype, op
-    ):
-        _test_aot_autograd_helper(
-            self, device, dtype, op, dynamic=True, use_min_cut=False
-        )
 
 
 aot_autograd_module_failures = set(
     {

@@ -1025,11 +1025,7 @@ def default_partition(
             # Symints must be kept separate from tensors so that PythonFunction only calls
             # save_for_backward on tensors and stashes symints in autograd .ctx
             saved_sym_nodes.append(node)
-        elif (
-            "tensor_meta" not in node.meta
-            and node.op == "call_function"
-            and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
-        ):
+        elif "tensor_meta" not in node.meta and node.op == "call_function":
             # Since we can't save tuple of tensor values, we need to flatten out what we're saving
             users = node.users
             assert all(user.target == operator.getitem for user in users)
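
Not part of the diff: a minimal sketch of the FakeTensor check that the revert removes from the condition above. The tensor names are illustrative; only the isinstance test mirrors the removed line.

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

# Tensors produced under FakeTensorMode are FakeTensor instances, which is
# what isinstance(node.meta.get("val"), torch._subclasses.FakeTensor) tested.
with FakeTensorMode():
    fake = torch.empty(2, 3)

real = torch.empty(2, 3)
print(isinstance(fake, FakeTensor))  # True
print(isinstance(real, FakeTensor))  # False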

@@ -3,7 +3,7 @@
 import torch
 import torch.utils._pytree as pytree
 from torch.testing._utils import wrapper_set_seed
-from functorch.compile import compiled_function, min_cut_rematerialization_partition, default_partition, nop
+from functorch.compile import compiled_function, min_cut_rematerialization_partition, nop
 from .make_fx import randomize
 import re
@@ -38,7 +38,6 @@ def aot_autograd_check(
         assert_equals_fn=torch.testing.assert_close,
         check_gradients=True,
         try_check_data_specialization=False,
-        use_min_cut=True,
         skip_correctness_check=False):
     """Compares func(*args, **kwargs) in eager-mode to under AOTAutograd.
@@ -64,24 +63,14 @@ def aot_autograd_check(
         c_args, c_kwargs = pytree.tree_unflatten(reconstructed_flat_args, args_spec)
         return func(*c_args, **c_kwargs)
 
-    if use_min_cut:
-        compiled_f = compiled_function(
-            func_no_tensors,
-            nop,
-            nop,
-            dynamic=dynamic,
-            partition_fn=min_cut_rematerialization_partition,
-            keep_inference_input_mutations=True
-        )
-    else:
-        compiled_f = compiled_function(
-            func_no_tensors,
-            nop,
-            nop,
-            dynamic=dynamic,
-            partition_fn=default_partition,
-            keep_inference_input_mutations=True
-        )
+    compiled_f = compiled_function(
+        func_no_tensors,
+        nop,
+        nop,
+        dynamic=dynamic,
+        partition_fn=min_cut_rematerialization_partition,
+        keep_inference_input_mutations=True
+    )
 
     out = wrapper_set_seed(func_no_tensors, args)
     if check_gradients == "auto":
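
Not part of the diff: a minimal standalone sketch of the toggle being reverted. With use_min_cut=False, aot_autograd_check had routed compilation through default_partition; after the revert it always uses min_cut_rematerialization_partition. The function fn below is an illustrative stand-in for an OpInfo sample.

import torch
from functorch.compile import (
    compiled_function,
    default_partition,
    min_cut_rematerialization_partition,
    nop,
)

def fn(x):
    # Stand-in for an OpInfo sample under test.
    return torch.sin(x) * torch.cos(x)

# Same call shape as aot_autograd_check above; only partition_fn differs.
for partition_fn in (min_cut_rematerialization_partition, default_partition):
    compiled_fn = compiled_function(
        fn,
        nop,
        nop,
        partition_fn=partition_fn,
        keep_inference_input_mutations=True,
    )
    x = torch.randn(8, requires_grad=True)
    out = compiled_fn(x)
    torch.testing.assert_close(out, fn(x))  # compiled output matches eager
    out.sum().backward()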