[dynamo] Allow guards to be dropped with custom filter functions. (#150936)

Summary: A follow-up to https://github.com/pytorch/pytorch/pull/150689.

Test Plan: test_dynamo -k test_guard_filter_fn

Differential Revision: D72722322

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150936
Approved by: https://github.com/jansel
Author:    Zhengxu Chen
Date:      2025-04-11 03:06:34 +00:00
Committer: PyTorch MergeBot
Parent:    4b0cf9fc00
Commit:    86370fd658

8 changed files with 202 additions and 60 deletions
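
For quick orientation, a minimal usage sketch distilled from the tests added in this diff (the tensor shape and the "drop ID_MATCH" policy come from those tests and are illustrative, not required):

import torch

# The callback receives a list of GuardFilterEntry and must return one bool
# per entry, in the same order (True = keep the guard, False = drop it).
def guard_filter_fn(entries):
    return [entry.guard_type != "ID_MATCH" for entry in entries]

@torch.compile(fullgraph=True, options={"guard_filter_fn": guard_filter_fn})
def fn(x):
    return id(x)

fn(torch.randn(3, 2))

# With the ID_MATCH guard dropped, a fresh tensor object reuses the cached
# compiled code instead of triggering a recompile.
with torch.compiler.set_stance("fail_on_recompile"):
    fn(torch.randn(3, 2))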

View File

@@ -108,6 +108,8 @@ T = typing.TypeVar("T")
# Defined in CPython's Include/object.h
TPFLAGS_MAPPING = 1 << 6

GLOBAL_INT = 1

# Specializes a test to run only if translation validation is set.
def onlyIfTranslationValidation(fn: typing.Callable) -> typing.Callable:
@@ -11942,6 +11944,57 @@ fn
        x = torch.randn(4)
        self.assertEqual(fn(x), opt_fn(x))

    def test_guard_filter_fn_by_id(self):
        def guard_filter_fn(entries):
            return [entry.guard_type != "ID_MATCH" for entry in entries]

        @torch.compile(fullgraph=True, options={"guard_filter_fn": guard_filter_fn})
        def fn(x):
            return id(x)

        inputs = (torch.randn(3, 2),)
        fn(*inputs)

        inputs_1 = (torch.randn(3, 2),)
        with torch.compiler.set_stance("fail_on_recompile"):
            self.assertEqual(fn(*inputs_1), id(inputs[0]))

    def test_guard_filter_fn_by_is_global(self):
        def guard_filter_fn(entries):
            return [not entry.is_global for entry in entries]

        global GLOBAL_INT

        @torch.compile(fullgraph=True, options={"guard_filter_fn": guard_filter_fn})
        def fn(x):
            return x + GLOBAL_INT

        GLOBAL_INT = 1
        fn(torch.randn(3, 2))

        GLOBAL_INT = 2
        inputs = (torch.randn(3, 2),)
        with torch.compiler.set_stance("fail_on_recompile"):
            self.assertEqual(fn(*inputs), inputs[0] + 1)

    def test_guard_filter_fn_by_name_and_value(self):
        def guard_filter_fn(entries):
            return [
                not (entry.name == "y" and entry.value is None) for entry in entries
            ]

        @torch.compile(fullgraph=True, options={"guard_filter_fn": guard_filter_fn})
        def fn(x, y):
            if y is not None:
                x += y
            return x

        fn(torch.randn(3, 2), None)

        inputs = (torch.randn(3, 2), torch.tensor(1))
        with torch.compiler.set_stance("fail_on_recompile"):
            self.assertEqual(fn(*inputs), inputs[0])


class TestTracer(JitTestCase):
    def test_jit_save(self):

View File

@@ -2424,7 +2424,9 @@ def compile(
    dynamic: _Optional[builtins.bool] = None,
    backend: _Union[str, _Callable] = "inductor",
    mode: _Union[str, None] = None,
    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
    options: _Optional[
        dict[str, _Union[str, builtins.int, builtins.bool, _Callable]]
    ] = None,
    disable: builtins.bool = False,
) -> _Callable[_InputT, _RetT]: ...
@@ -2437,7 +2439,9 @@ def compile(
    dynamic: _Optional[builtins.bool] = None,
    backend: _Union[str, _Callable] = "inductor",
    mode: _Union[str, None] = None,
    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
    options: _Optional[
        dict[str, _Union[str, builtins.int, builtins.bool, _Callable]]
    ] = None,
    disable: builtins.bool = False,
) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...
@@ -2449,7 +2453,9 @@ def compile(
    dynamic: _Optional[builtins.bool] = None,
    backend: _Union[str, _Callable] = "inductor",
    mode: _Union[str, None] = None,
    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
    options: _Optional[
        dict[str, _Union[str, builtins.int, builtins.bool, _Callable]]
    ] = None,
    disable: builtins.bool = False,
) -> _Union[
    _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]],
@@ -2585,6 +2591,10 @@ def compile(
    if bisect_backend := CompilerBisector.get_backend():
        backend = bisect_backend

    guard_filter_fn = None
    if options and isinstance(options, dict):
        guard_filter_fn = options.pop("guard_filter_fn", None)

    if backend == "inductor":
        backend = _TorchCompileInductorWrapper(mode, options, dynamic)
    else:
@@ -2595,6 +2605,7 @@
        nopython=fullgraph,
        dynamic=dynamic,
        disable=disable,
        guard_filter_fn=guard_filter_fn,
    )(model)  # type: ignore[return-value]

View File

@@ -926,6 +926,7 @@ def _compile(
            output,
            cache_entry,
            hooks.guard_fail_fn if hooks else None,
            hooks.guard_filter_fn if hooks else None,
        )

        compile_id_str = str(compile_id) if compile_id is not None else "Unknown"

View File

@@ -969,6 +969,7 @@ def _optimize(
    nopython=False,
    guard_export_fn=None,
    guard_fail_fn=None,
    guard_filter_fn=None,
    disable=False,
    dynamic=None,
) -> Union[OptimizeContext, _NullDecorator]:
@@ -1004,7 +1005,11 @@
    # There is some prior art around this, w/r/t nesting backend calls are enforced to be the same
    # compiler, however, this feels onerous for callback and hooks, and it feels better to give our users an
    # easier to understand UX at the cost of a little more plumbing on our end.
    hooks = Hooks(guard_export_fn=guard_export_fn, guard_fail_fn=guard_fail_fn)
    hooks = Hooks(
        guard_export_fn=guard_export_fn,
        guard_fail_fn=guard_fail_fn,
        guard_filter_fn=guard_filter_fn,
    )
    torch._C._log_api_usage_once("torch._dynamo.optimize")
    if (
        disable
@@ -1866,7 +1871,7 @@ def export(
def optimize_assert(
    backend,
    *,
    hooks=Hooks(None, None),
    hooks=Hooks(None, None, None),
    export=False,
    export_constraints=None,
    dynamic=None,

View File

@@ -59,6 +59,7 @@ from torch._C._dynamo.guards import (
from torch._dynamo.source import (
    IndexedSource,
    is_from_flatten_script_object_source,
    is_from_global_source,
    is_from_local_source,
    is_from_optimizer_source,
    TensorProperty,
@@ -129,6 +130,7 @@ from .types import (  # noqa: F401
    ExtraState,
    GuardedCode,
    GuardFail,
    GuardFilterEntry,
    GuardFn,
)
from .utils import (
@@ -2223,9 +2225,7 @@ class GuardBuilder(GuardBuilderBase):
            # We consider TENSOR_MATCH guard to be important enough to be
            # included in diff guard manager by default.
            if not isinstance(value, torch.nn.Parameter):
                self.check_fn_manager.guard_manager.diff_guard_sources.add(
                    guard.name
                )
                self.guard_manager.diff_guard_sources.add(guard.name)

        # A frame is valid for reuse with dynamic dimensions if the new
        # (user-requested) dynamic dimensions are a subset of the old
@@ -2464,6 +2464,9 @@
        output_graph=None,
        cache_entry=None,
        guard_fail_fn: Optional[Callable[[GuardFail], None]] = None,
        guard_filter_fn: Optional[
            Callable[[list[GuardFilterEntry]], list[bool]]
        ] = None,
    ):
        guards = output_graph.guards if output_graph else None
        self._weakrefs: dict[int, ReferenceType[object]] = {}
@@ -2471,10 +2474,7 @@
        existing_diff_guard_sources = (
            update_diff_guard_managers_for_existing_cache_entries(cache_entry)
        )
        self.guard_manager = GuardManagerWrapper()
        self.guard_manager.diff_guard_sources = existing_diff_guard_sources
        self.output_graph = output_graph
        w_builder = None

        # NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing
        # in case a set default device call was made in the graph.
@@ -2482,58 +2482,52 @@
            output_graph.torch_function_mode_stack if output_graph else None
        )

        def source_ref(source):
            guard_source = source.guard_source()
            if guard_source is GuardSource.CONSTANT:
                # No need to track constants
                return source.name()
            assert w_builder
            r_builder = w_builder()
            assert r_builder is not None
            return r_builder.arg_ref(source.name())

        builder = GuardBuilder(
            f_code,
            self.id_ref,
            source_ref,
            self.lookup_weakrefs,
            output_graph.local_scope,
            output_graph.global_scope,
            self.guard_manager,
            self,
        )

        # Break retain cycle. See test_release_scope_memory
        def cleanup_builder(weak_b):
            b = weak_b()
            if b:
                b.scope = None

        # Break retain cycle. See test_release_input_memory
        w_builder = weakref.ref(builder, cleanup_builder)

        guard_on_nn_modules = config.guard_nn_modules and justknobs_check(
            "pytorch/compiler:guard_nn_modules"
        )

        if not justknobs_check("pytorch/compiler:guard_nn_modules"):
            log.warning("guard_nn_modules is turned off using justknobs killswitch")

        for guard in sorted(guards or (), key=Guard.sort_key):
            if (
                not guard_on_nn_modules
                and guard.is_specialized_nn_module()
                # Default func args must be guarded on.
                # TODO: we could make use of 'DefaultsSource' and offer a .guard.is_defaults() API
                and "__defaults__" not in guard.name
                and "__kwdefaults__" not in guard.name
                and (config.skip_nnmodule_hook_guards or "hooks" not in guard.name)
            ):
                continue
        sorted_guards = sorted(guards or (), key=Guard.sort_key)
        builder, guard_manager = self.build_guards(
            sorted_guards, existing_diff_guard_sources, f_code, output_graph
        )

            guard.create(builder)
        if guard_filter_fn:

        self.compile_check_fn(builder, guards, guard_fail_fn)

            def make_guard_filter_entry(guard):
                name = strip_local_scope(guard.name)
                if name == "":
                    value = None
                else:
                    value = builder.get(guard.name)
                is_global = is_from_global_source(guard.originating_source)
                guard_fn = guard.create_fn
                if isinstance(guard_fn, functools.partial):
                    guard_fn = guard.create_fn.func
                return GuardFilterEntry(
                    name=name,
                    value=value,
                    guard_type=guard_fn.__name__,
                    derived_guard_types=tuple(guard.guard_types)
                    if guard.guard_types
                    else (),
                    is_global=is_global,
                    orig_guard=guard,
                )

            filter_results = guard_filter_fn(
                [make_guard_filter_entry(guard) for guard in sorted_guards]
            )
            assert len(filter_results) == len(sorted_guards)
            assert all(type(x) == bool for x in filter_results)
            sorted_guards = [
                guard for i, guard in enumerate(sorted_guards) if filter_results[i]
            ]
            # Redo the guards because filtering relies on the results from the last guard builder.
            builder, guard_manager = self.build_guards(
                sorted_guards, existing_diff_guard_sources, f_code, output_graph
            )

        self.guard_manager = guard_manager
        self.compile_check_fn(builder, sorted_guards, guard_fail_fn)

        # Keep track of weak references of objects with ID_MATCH guard. This
        # info is stored alongside optimized_code and guard_manager and is used to
@@ -2602,6 +2596,67 @@ class CheckFunctionManager:
        self._weakrefs.clear()
        self.output_graph = None

    def build_guards(
        self,
        sorted_guards,
        existing_diff_guard_sources,
        f_code,
        output_graph,
    ):
        guard_manager = GuardManagerWrapper()
        guard_manager.diff_guard_sources = existing_diff_guard_sources

        w_builder = None

        def source_ref(source):
            guard_source = source.guard_source()
            if guard_source is GuardSource.CONSTANT:
                # No need to track constants
                return source.name()
            assert w_builder
            r_builder = w_builder()
            assert r_builder is not None
            return r_builder.arg_ref(source.name())

        builder = GuardBuilder(
            f_code,
            self.id_ref,
            source_ref,
            self.lookup_weakrefs,
            output_graph.local_scope,
            output_graph.global_scope,
            guard_manager,
            self,
        )

        # Break retain cycle. See test_release_scope_memory
        def cleanup_builder(weak_b):
            b = weak_b()
            if b:
                b.scope = None

        # Break retain cycle. See test_release_input_memory
        w_builder = weakref.ref(builder, cleanup_builder)

        guard_on_nn_modules = config.guard_nn_modules and justknobs_check(
            "pytorch/compiler:guard_nn_modules"
        )

        for guard in sorted_guards:
            if (
                not guard_on_nn_modules
                and guard.is_specialized_nn_module()
                # Default func args must be guarded on.
                # TODO: we could make use of 'DefaultsSource' and offer a .guard.is_defaults() API
                and "__defaults__" not in guard.name
                and "__kwdefaults__" not in guard.name
                and (config.skip_nnmodule_hook_guards or "hooks" not in guard.name)
            ):
                continue

            guard.create(builder)
        return builder, guard_manager

    def compile_check_fn(self, builder, guards_out, guard_fail_fn):
        # see parallel handling of ".0" / "___implicit0" in _eval_frame.c
        largs = builder.argnames

View File

@@ -15,10 +15,11 @@ from typing import Callable, Optional
from torch._guards import GuardsSet

from .types import GuardFail
from .types import GuardFail, GuardFilterEntry


@dataclasses.dataclass
class Hooks:
    guard_export_fn: Optional[Callable[[GuardsSet], None]] = None
    guard_fail_fn: Optional[Callable[[GuardFail], None]] = None
    guard_filter_fn: Optional[Callable[[list[GuardFilterEntry]], list[bool]]] = None
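
For reference, a small sketch (not part of the diff) of constructing Hooks with the new field; the keep-everything lambda is purely illustrative:

from torch._dynamo.hooks import Hooks

# guard_filter_fn receives a list of GuardFilterEntry and must return one
# bool per entry, in the same order (True = keep the guard, False = drop it).
hooks = Hooks(
    guard_export_fn=None,
    guard_fail_fn=None,
    guard_filter_fn=lambda entries: [True for _ in entries],  # keep all guards
)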

View File

@@ -845,6 +845,12 @@ def is_from_local_source(source: Source, *, only_allow_input=False):
    return True


def is_from_global_source(source: Source):
    if isinstance(source, ChainedSource):
        return is_from_global_source(source.base)
    return isinstance(source, GlobalSource)


def is_from_source(source: Source, target: Source):
    if isinstance(source, ChainedSource):
        return is_from_source(source.base, target)

View File

@@ -23,7 +23,7 @@ from torch._C._dynamo.eval_frame import (
    _FrameExecStrategy as FrameExecStrategy,
    _PyInterpreterFrame as DynamoFrameType,
)
from torch._guards import CompileId
from torch._guards import CompileId, Guard


# We use a dict to store additional data per frame.
@@ -37,6 +37,16 @@ class GuardFail(NamedTuple):
    orig_code: types.CodeType


@dataclasses.dataclass(frozen=True)
class GuardFilterEntry:
    name: str
    value: object
    guard_type: str
    derived_guard_types: tuple[str, ...]
    is_global: bool
    orig_guard: Guard


class GuardFn(Protocol):
    closure_vars: dict[str, object]
    args: list[str]
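
Closing note: a sketch (not part of the diff) of a filter callback that inspects each GuardFilterEntry field defined above, passed through the torch.compile option added earlier in this commit; the keep/drop policy shown is purely illustrative:

import torch

def verbose_guard_filter(entries):
    keep = []
    for entry in entries:
        # Each entry describes one guard Dynamo wants to install.
        print(
            f"{entry.name!r}: type={entry.guard_type}, "
            f"derived={entry.derived_guard_types}, "
            f"is_global={entry.is_global}, value={entry.value!r}"
        )
        # Illustrative policy: drop ID_MATCH guards and guards on globals,
        # keep everything else.
        keep.append(entry.guard_type != "ID_MATCH" and not entry.is_global)
    return keep

@torch.compile(options={"guard_filter_fn": verbose_guard_filter})
def fn(x):
    return x + 1

fn(torch.randn(3, 2))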