mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[dynamo] Validate check_fn (#118448)
Fixes - https://github.com/pytorch/pytorch/issues/128090 Tracker issue here - https://github.com/pytorch/pytorch/issues/129937 Pull Request resolved: https://github.com/pytorch/pytorch/pull/118448 Approved by: https://github.com/jansel, https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
7192ee0735
commit
7ea8a3c9b8
@ -1363,6 +1363,9 @@ s1 > 3""",
|
||||
s = SubTensor(torch.randn(3, 10))
|
||||
f(s)
|
||||
|
||||
# Guard validation upsets the guard
|
||||
# https://github.com/pytorch/pytorch/issues/129936
|
||||
@unittest.expectedFailure
|
||||
def test_recompile_with_symbool_inputs(self):
|
||||
def f(pred: bool):
|
||||
if pred:
|
||||
|
@ -3601,6 +3601,8 @@ def forward(self, x):
|
||||
):
|
||||
torch.export.export(exported_v2.module(), (torch.randn(2, 2),))
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/129939
|
||||
@testing.expectedFailureNonStrict
|
||||
def test_export_cond(self):
|
||||
class A(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -4976,6 +4978,9 @@ graph():
|
||||
)
|
||||
)
|
||||
|
||||
# Guard validation upsets the guard
|
||||
# https://github.com/pytorch/pytorch/issues/129939
|
||||
@unittest.expectedFailure
|
||||
def test_cond_with_module_stack_export_with(self):
|
||||
class Bar(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -2346,6 +2346,7 @@ known_failing_tests = {
|
||||
"test_grad_nonleaf_register_hook", # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors)
|
||||
"test_unpack_hooks_exec_count", # pack/unpack saved tensor hooks firing more than once
|
||||
"test_scalar_grad_mixed_device", # Fake Tensors aren't propagating device properly for 0-dim grads
|
||||
"test_backward_twice_without_saved_values", # https://github.com/pytorch/pytorch/issues/129938
|
||||
}
|
||||
|
||||
if not HAS_CUDA:
|
||||
|
@ -2103,6 +2103,7 @@ class CheckFunctionManager:
|
||||
guard.create(builder)
|
||||
|
||||
self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn)
|
||||
|
||||
# Keep track of weak references of objects with ID_MATCH guard. This
|
||||
# info is stored alongside optimized_code and check_fn and is used to
|
||||
# limit the number of cache entries with same ID_MATCH'd object.
|
||||
@ -2123,6 +2124,18 @@ class CheckFunctionManager:
|
||||
self.guard_manager.id_matched_objs = builder.id_matched_objs
|
||||
self.check_fn = self.guard_manager
|
||||
|
||||
# Check that the guard returns True. False means that we will always
|
||||
# recompile.
|
||||
# TODO(anijain2305, ydwu4) - Skipping export because of following test
|
||||
# python -s test/dynamo/test_export.py -k test_export_with_symbool_inputs
|
||||
if not output_graph.export:
|
||||
if not self.guard_manager.check(output_graph.local_scope):
|
||||
reasons = get_guard_fail_reason_helper(
|
||||
self.guard_manager, # type: ignore[arg-type]
|
||||
output_graph.local_scope,
|
||||
)
|
||||
raise AssertionError(f"Guard check failed: {reasons}")
|
||||
|
||||
# NB - We have to very careful of cleaning up here. Because of the
|
||||
# invalidate function, we can create a weakref finalizer that keeps
|
||||
# `self` alive for very long. Sometimes by mistake, we can run
|
||||
@ -2456,9 +2469,8 @@ def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):
|
||||
return [f"Duplicate tensors found: {reason}"]
|
||||
|
||||
|
||||
def get_guard_fail_reason(
|
||||
def get_guard_fail_reason_helper(
|
||||
guard_fn: GuardFn,
|
||||
code: types.CodeType,
|
||||
f_locals: Dict[str, object],
|
||||
) -> str:
|
||||
"""
|
||||
@ -2525,6 +2537,15 @@ def get_guard_fail_reason(
|
||||
break
|
||||
|
||||
reason_str = "\n".join(reasons)
|
||||
return reason_str
|
||||
|
||||
|
||||
def get_guard_fail_reason(
|
||||
guard_fn: GuardFn,
|
||||
code: types.CodeType,
|
||||
f_locals: Dict[str, object],
|
||||
) -> str:
|
||||
reason_str = get_guard_fail_reason_helper(guard_fn, f_locals)
|
||||
guard_failures[orig_code_map[code]].append(reason_str)
|
||||
|
||||
try:
|
||||
|
Reference in New Issue
Block a user