mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add and fix OpInfo tests for the default partitioner (#165372)
I noticed the default partitioner was breaking in some dynamic shape tests, so prior to turning off functionalization I want to tweak it to pass all of our OpInfo tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/165372 Approved by: https://github.com/ezyang ghstack dependencies: #165327
This commit is contained in:
committed by
PyTorch MergeBot
parent
d2e1dbc8f2
commit
bcfea48ab7
@ -3,7 +3,7 @@
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch.testing._utils import wrapper_set_seed
|
||||
from functorch.compile import compiled_function, min_cut_rematerialization_partition, nop
|
||||
from functorch.compile import compiled_function, min_cut_rematerialization_partition, default_partition, nop
|
||||
from .make_fx import randomize
|
||||
import re
|
||||
|
||||
@ -38,6 +38,7 @@ def aot_autograd_check(
|
||||
assert_equals_fn=torch.testing.assert_close,
|
||||
check_gradients=True,
|
||||
try_check_data_specialization=False,
|
||||
use_min_cut=True,
|
||||
skip_correctness_check=False):
|
||||
"""Compares func(*args, **kwargs) in eager-mode to under AOTAutograd.
|
||||
|
||||
@ -63,14 +64,24 @@ def aot_autograd_check(
|
||||
c_args, c_kwargs = pytree.tree_unflatten(reconstructed_flat_args, args_spec)
|
||||
return func(*c_args, **c_kwargs)
|
||||
|
||||
compiled_f = compiled_function(
|
||||
func_no_tensors,
|
||||
nop,
|
||||
nop,
|
||||
dynamic=dynamic,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
keep_inference_input_mutations=True
|
||||
)
|
||||
if use_min_cut:
|
||||
compiled_f = compiled_function(
|
||||
func_no_tensors,
|
||||
nop,
|
||||
nop,
|
||||
dynamic=dynamic,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
keep_inference_input_mutations=True
|
||||
)
|
||||
else:
|
||||
compiled_f = compiled_function(
|
||||
func_no_tensors,
|
||||
nop,
|
||||
nop,
|
||||
dynamic=dynamic,
|
||||
partition_fn=default_partition,
|
||||
keep_inference_input_mutations=True
|
||||
)
|
||||
|
||||
out = wrapper_set_seed(func_no_tensors, args)
|
||||
if check_gradients == "auto":
|
||||
|
Reference in New Issue
Block a user