add and fix OpInfo tests for the default partitioner (#165372)

I noticed that the default partitioner was breaking in some dynamic-shape tests, so, prior to turning off functionalization, I want to tweak it so that it passes all of our OpInfo tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165372
Approved by: https://github.com/ezyang
ghstack dependencies: #165327
This commit is contained in:
Brian Hirsh
2025-10-14 07:53:21 -07:00
committed by PyTorch MergeBot
parent d2e1dbc8f2
commit bcfea48ab7
3 changed files with 50 additions and 11 deletions

View File

@ -3,7 +3,7 @@
import torch
import torch.utils._pytree as pytree
from torch.testing._utils import wrapper_set_seed
from functorch.compile import compiled_function, min_cut_rematerialization_partition, nop
from functorch.compile import compiled_function, min_cut_rematerialization_partition, default_partition, nop
from .make_fx import randomize
import re
@ -38,6 +38,7 @@ def aot_autograd_check(
assert_equals_fn=torch.testing.assert_close,
check_gradients=True,
try_check_data_specialization=False,
use_min_cut=True,
skip_correctness_check=False):
"""Compares func(*args, **kwargs) in eager-mode to under AOTAutograd.
@ -63,14 +64,24 @@ def aot_autograd_check(
c_args, c_kwargs = pytree.tree_unflatten(reconstructed_flat_args, args_spec)
return func(*c_args, **c_kwargs)
compiled_f = compiled_function(
func_no_tensors,
nop,
nop,
dynamic=dynamic,
partition_fn=min_cut_rematerialization_partition,
keep_inference_input_mutations=True
)
if use_min_cut:
compiled_f = compiled_function(
func_no_tensors,
nop,
nop,
dynamic=dynamic,
partition_fn=min_cut_rematerialization_partition,
keep_inference_input_mutations=True
)
else:
compiled_f = compiled_function(
func_no_tensors,
nop,
nop,
dynamic=dynamic,
partition_fn=default_partition,
keep_inference_input_mutations=True
)
out = wrapper_set_seed(func_no_tensors, args)
if check_gradients == "auto":