[dynamo] Allow guards to be dropped with custom filter functions. (#150936)
Summary: A follow-up to https://github.com/pytorch/pytorch/pull/150689.

Test Plan: test_dynamo -k test_guard_filter_fn

Differential Revision: D72722322

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150936
Approved by: https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: 4b0cf9fc00
Commit: 86370fd658
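For orientation before the diff: the new `guard_filter_fn` option is passed through `torch.compile(options=...)`, receives one `GuardFilterEntry` per guard that Dynamo wants to install, and returns one boolean per entry (False drops that guard). Below is a minimal sketch distilled from the tests added in this PR; the `x + 1` function and its inputs are illustrative only.

```python
import torch

# The filter callback gets a list of GuardFilterEntry objects and must return
# a list of booleans of the same length; False drops the corresponding guard.
def guard_filter_fn(entries):
    # Illustrative policy: keep everything except ID_MATCH guards.
    return [entry.guard_type != "ID_MATCH" for entry in entries]

@torch.compile(fullgraph=True, options={"guard_filter_fn": guard_filter_fn})
def fn(x):
    return x + 1

fn(torch.randn(3, 2))
```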
@@ -108,6 +108,8 @@ T = typing.TypeVar("T")

# Defined in CPython's Include/object.h
TPFLAGS_MAPPING = 1 << 6

GLOBAL_INT = 1

# Specializes a test to run only if translation validation is set.
def onlyIfTranslationValidation(fn: typing.Callable) -> typing.Callable:
@@ -11942,6 +11944,57 @@ fn
        x = torch.randn(4)
        self.assertEqual(fn(x), opt_fn(x))

    def test_guard_filter_fn_by_id(self):
        def guard_filter_fn(entries):
            return [entry.guard_type != "ID_MATCH" for entry in entries]

        @torch.compile(fullgraph=True, options={"guard_filter_fn": guard_filter_fn})
        def fn(x):
            return id(x)

        inputs = (torch.randn(3, 2),)
        fn(*inputs)

        inputs_1 = (torch.randn(3, 2),)
        with torch.compiler.set_stance("fail_on_recompile"):
            self.assertEqual(fn(*inputs_1), id(inputs[0]))

    def test_guard_filter_fn_by_is_global(self):
        def guard_filter_fn(entries):
            return [not entry.is_global for entry in entries]

        global GLOBAL_INT

        @torch.compile(fullgraph=True, options={"guard_filter_fn": guard_filter_fn})
        def fn(x):
            return x + GLOBAL_INT

        GLOBAL_INT = 1
        fn(torch.randn(3, 2))

        GLOBAL_INT = 2
        inputs = (torch.randn(3, 2),)
        with torch.compiler.set_stance("fail_on_recompile"):
            self.assertEqual(fn(*inputs), inputs[0] + 1)

    def test_guard_filter_fn_by_name_and_value(self):
        def guard_filter_fn(entries):
            return [
                not (entry.name == "y" and entry.value is None) for entry in entries
            ]

        @torch.compile(fullgraph=True, options={"guard_filter_fn": guard_filter_fn})
        def fn(x, y):
            if y is not None:
                x += y
            return x

        fn(torch.randn(3, 2), None)

        inputs = (torch.randn(3, 2), torch.tensor(1))
        with torch.compiler.set_stance("fail_on_recompile"):
            self.assertEqual(fn(*inputs), inputs[0])


class TestTracer(JitTestCase):
    def test_jit_save(self):
@@ -2424,7 +2424,9 @@ def compile(
    dynamic: _Optional[builtins.bool] = None,
    backend: _Union[str, _Callable] = "inductor",
    mode: _Union[str, None] = None,
    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
    options: _Optional[
        dict[str, _Union[str, builtins.int, builtins.bool, _Callable]]
    ] = None,
    disable: builtins.bool = False,
) -> _Callable[_InputT, _RetT]: ...

@@ -2437,7 +2439,9 @@ def compile(
    dynamic: _Optional[builtins.bool] = None,
    backend: _Union[str, _Callable] = "inductor",
    mode: _Union[str, None] = None,
    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
    options: _Optional[
        dict[str, _Union[str, builtins.int, builtins.bool, _Callable]]
    ] = None,
    disable: builtins.bool = False,
) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...

@@ -2449,7 +2453,9 @@ def compile(
    dynamic: _Optional[builtins.bool] = None,
    backend: _Union[str, _Callable] = "inductor",
    mode: _Union[str, None] = None,
    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
    options: _Optional[
        dict[str, _Union[str, builtins.int, builtins.bool, _Callable]]
    ] = None,
    disable: builtins.bool = False,
) -> _Union[
    _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]],

@@ -2585,6 +2591,10 @@ def compile(
    if bisect_backend := CompilerBisector.get_backend():
        backend = bisect_backend

    guard_filter_fn = None
    if options and isinstance(options, dict):
        guard_filter_fn = options.pop("guard_filter_fn", None)

    if backend == "inductor":
        backend = _TorchCompileInductorWrapper(mode, options, dynamic)
    else:

@@ -2595,6 +2605,7 @@ def compile(
        nopython=fullgraph,
        dynamic=dynamic,
        disable=disable,
        guard_filter_fn=guard_filter_fn,
    )(model)  # type: ignore[return-value]
@@ -926,6 +926,7 @@ def _compile(
            output,
            cache_entry,
            hooks.guard_fail_fn if hooks else None,
            hooks.guard_filter_fn if hooks else None,
        )

        compile_id_str = str(compile_id) if compile_id is not None else "Unknown"

@@ -969,6 +969,7 @@ def _optimize(
    nopython=False,
    guard_export_fn=None,
    guard_fail_fn=None,
    guard_filter_fn=None,
    disable=False,
    dynamic=None,
) -> Union[OptimizeContext, _NullDecorator]:

@@ -1004,7 +1005,11 @@ def _optimize(
    # There is some prior art around this, w/r/t nesting backend calls are enforced to be the same
    # compiler, however, this feels onerous for callback and hooks, and it feels better to give our users an
    # easier to understand UX at the cost of a little more plumbing on our end.
    hooks = Hooks(guard_export_fn=guard_export_fn, guard_fail_fn=guard_fail_fn)
    hooks = Hooks(
        guard_export_fn=guard_export_fn,
        guard_fail_fn=guard_fail_fn,
        guard_filter_fn=guard_filter_fn,
    )
    torch._C._log_api_usage_once("torch._dynamo.optimize")
    if (
        disable

@@ -1866,7 +1871,7 @@ def export(
def optimize_assert(
    backend,
    *,
    hooks=Hooks(None, None),
    hooks=Hooks(None, None, None),
    export=False,
    export_constraints=None,
    dynamic=None,
@@ -59,6 +59,7 @@ from torch._C._dynamo.guards import (
from torch._dynamo.source import (
    IndexedSource,
    is_from_flatten_script_object_source,
    is_from_global_source,
    is_from_local_source,
    is_from_optimizer_source,
    TensorProperty,

@@ -129,6 +130,7 @@ from .types import (  # noqa: F401
    ExtraState,
    GuardedCode,
    GuardFail,
    GuardFilterEntry,
    GuardFn,
)
from .utils import (
@@ -2223,9 +2225,7 @@ class GuardBuilder(GuardBuilderBase):
            # We consider TENSOR_MATCH guard to be important enough to be
            # included in diff guard manager by default.
            if not isinstance(value, torch.nn.Parameter):
                self.check_fn_manager.guard_manager.diff_guard_sources.add(
                    guard.name
                )
                self.guard_manager.diff_guard_sources.add(guard.name)

    # A frame is valid for reuse with dynamic dimensions if the new
    # (user-requested) dynamic dimensions are a subset of the old
@@ -2464,6 +2464,9 @@ class CheckFunctionManager:
        output_graph=None,
        cache_entry=None,
        guard_fail_fn: Optional[Callable[[GuardFail], None]] = None,
        guard_filter_fn: Optional[
            Callable[[list[GuardFilterEntry]], list[bool]]
        ] = None,
    ):
        guards = output_graph.guards if output_graph else None
        self._weakrefs: dict[int, ReferenceType[object]] = {}

@@ -2471,10 +2474,7 @@ class CheckFunctionManager:
        existing_diff_guard_sources = (
            update_diff_guard_managers_for_existing_cache_entries(cache_entry)
        )
        self.guard_manager = GuardManagerWrapper()
        self.guard_manager.diff_guard_sources = existing_diff_guard_sources
        self.output_graph = output_graph
        w_builder = None

        # NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing
        # in case a set default device call was made in the graph.
@@ -2482,58 +2482,52 @@ class CheckFunctionManager:
            output_graph.torch_function_mode_stack if output_graph else None
        )

        def source_ref(source):
            guard_source = source.guard_source()
            if guard_source is GuardSource.CONSTANT:
                # No need to track constants
                return source.name()
            assert w_builder
            r_builder = w_builder()
            assert r_builder is not None
            return r_builder.arg_ref(source.name())

        builder = GuardBuilder(
            f_code,
            self.id_ref,
            source_ref,
            self.lookup_weakrefs,
            output_graph.local_scope,
            output_graph.global_scope,
            self.guard_manager,
            self,
        )

        # Break retain cycle. See test_release_scope_memory
        def cleanup_builder(weak_b):
            b = weak_b()
            if b:
                b.scope = None

        # Break retain cycle. See test_release_input_memory
        w_builder = weakref.ref(builder, cleanup_builder)

        guard_on_nn_modules = config.guard_nn_modules and justknobs_check(
            "pytorch/compiler:guard_nn_modules"
        )

        if not justknobs_check("pytorch/compiler:guard_nn_modules"):
            log.warning("guard_nn_modules is turned off using justknobs killswitch")

        for guard in sorted(guards or (), key=Guard.sort_key):
            if (
                not guard_on_nn_modules
                and guard.is_specialized_nn_module()
                # Default func args must be guarded on.
                # TODO: we could make use of 'DefaultsSource' and offer a .guard.is_defaults() API
                and "__defaults__" not in guard.name
                and "__kwdefaults__" not in guard.name
                and (config.skip_nnmodule_hook_guards or "hooks" not in guard.name)
            ):
                continue
        sorted_guards = sorted(guards or (), key=Guard.sort_key)
        builder, guard_manager = self.build_guards(
            sorted_guards, existing_diff_guard_sources, f_code, output_graph
        )

            guard.create(builder)
        if guard_filter_fn:

        self.compile_check_fn(builder, guards, guard_fail_fn)
            def make_guard_filter_entry(guard):
                name = strip_local_scope(guard.name)
                if name == "":
                    value = None
                else:
                    value = builder.get(guard.name)
                is_global = is_from_global_source(guard.originating_source)
                guard_fn = guard.create_fn
                if isinstance(guard_fn, functools.partial):
                    guard_fn = guard.create_fn.func
                return GuardFilterEntry(
                    name=name,
                    value=value,
                    guard_type=guard_fn.__name__,
                    derived_guard_types=tuple(guard.guard_types)
                    if guard.guard_types
                    else (),
                    is_global=is_global,
                    orig_guard=guard,
                )

            filter_results = guard_filter_fn(
                [make_guard_filter_entry(guard) for guard in sorted_guards]
            )
            assert len(filter_results) == len(sorted_guards)
            assert all(type(x) == bool for x in filter_results)
            sorted_guards = [
                guard for i, guard in enumerate(sorted_guards) if filter_results[i]
            ]
            # Redo the guards because filtering relies on the results from the last guard builder.
            builder, guard_manager = self.build_guards(
                sorted_guards, existing_diff_guard_sources, f_code, output_graph
            )

        self.guard_manager = guard_manager
        self.compile_check_fn(builder, sorted_guards, guard_fail_fn)

        # Keep track of weak references of objects with ID_MATCH guard. This
        # info is stored alongside optimized_code and guard_manager and is used to
@@ -2602,6 +2596,67 @@ class CheckFunctionManager:
        self._weakrefs.clear()
        self.output_graph = None

    def build_guards(
        self,
        sorted_guards,
        existing_diff_guard_sources,
        f_code,
        output_graph,
    ):
        guard_manager = GuardManagerWrapper()
        guard_manager.diff_guard_sources = existing_diff_guard_sources

        w_builder = None

        def source_ref(source):
            guard_source = source.guard_source()
            if guard_source is GuardSource.CONSTANT:
                # No need to track constants
                return source.name()
            assert w_builder
            r_builder = w_builder()
            assert r_builder is not None
            return r_builder.arg_ref(source.name())

        builder = GuardBuilder(
            f_code,
            self.id_ref,
            source_ref,
            self.lookup_weakrefs,
            output_graph.local_scope,
            output_graph.global_scope,
            guard_manager,
            self,
        )

        # Break retain cycle. See test_release_scope_memory
        def cleanup_builder(weak_b):
            b = weak_b()
            if b:
                b.scope = None

        # Break retain cycle. See test_release_input_memory
        w_builder = weakref.ref(builder, cleanup_builder)

        guard_on_nn_modules = config.guard_nn_modules and justknobs_check(
            "pytorch/compiler:guard_nn_modules"
        )

        for guard in sorted_guards:
            if (
                not guard_on_nn_modules
                and guard.is_specialized_nn_module()
                # Default func args must be guarded on.
                # TODO: we could make use of 'DefaultsSource' and offer a .guard.is_defaults() API
                and "__defaults__" not in guard.name
                and "__kwdefaults__" not in guard.name
                and (config.skip_nnmodule_hook_guards or "hooks" not in guard.name)
            ):
                continue

            guard.create(builder)
        return builder, guard_manager

    def compile_check_fn(self, builder, guards_out, guard_fail_fn):
        # see parallel handling of ".0" / "___implicit0" in _eval_frame.c
        largs = builder.argnames
@@ -15,10 +15,11 @@ from typing import Callable, Optional

from torch._guards import GuardsSet

from .types import GuardFail
from .types import GuardFail, GuardFilterEntry


@dataclasses.dataclass
class Hooks:
    guard_export_fn: Optional[Callable[[GuardsSet], None]] = None
    guard_fail_fn: Optional[Callable[[GuardFail], None]] = None
    guard_filter_fn: Optional[Callable[[list[GuardFilterEntry]], list[bool]]] = None
@@ -845,6 +845,12 @@ def is_from_local_source(source: Source, *, only_allow_input=False):
    return True


def is_from_global_source(source: Source):
    if isinstance(source, ChainedSource):
        return is_from_global_source(source.base)
    return isinstance(source, GlobalSource)


def is_from_source(source: Source, target: Source):
    if isinstance(source, ChainedSource):
        return is_from_source(source.base, target)
@@ -23,7 +23,7 @@ from torch._C._dynamo.eval_frame import (
    _FrameExecStrategy as FrameExecStrategy,
    _PyInterpreterFrame as DynamoFrameType,
)
from torch._guards import CompileId
from torch._guards import CompileId, Guard


# We use a dict to store additional data per frame.

@@ -37,6 +37,16 @@ class GuardFail(NamedTuple):
    orig_code: types.CodeType


@dataclasses.dataclass(frozen=True)
class GuardFilterEntry:
    name: str
    value: object
    guard_type: str
    derived_guard_types: tuple[str, ...]
    is_global: bool
    orig_guard: Guard


class GuardFn(Protocol):
    closure_vars: dict[str, object]
    args: list[str]
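As a closing illustration, a filter can combine several of the `GuardFilterEntry` fields defined above. The hypothetical policy below is adapted from the `is_global` and name/value tests earlier in the diff; the function name and the policy itself are illustrative, only the field names come from this PR.

```python
def drop_global_and_unused_none_guards(entries):
    # Each entry carries the guarded name, its current value, the guard type,
    # and whether it originates from a global; return False to drop that guard.
    return [
        not (entry.is_global or (entry.name == "y" and entry.value is None))
        for entry in entries
    ]
```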