[dynamo][sac] Support functools partial context_fn for sac (#164308)

Fixes https://github.com/pytorch/pytorch/issues/164300
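Previously, passing a `functools.partial` as the `context_fn` for selective activation checkpointing (SAC) failed under `torch.compile`, because Dynamo's `FunctoolsPartialVariable` could not be folded back into a Python constant. A minimal sketch of the now-supported usage, assuming a build that includes this fix (the `save_everything` policy below is illustrative, not part of this PR):

    import functools

    import torch
    from torch.utils.checkpoint import (
        CheckpointPolicy,
        checkpoint,
        create_selective_checkpoint_contexts,
    )

    # Illustrative SAC policy: keep every op's output during the forward pass.
    def save_everything(ctx, op, *args, **kwargs):
        return CheckpointPolicy.MUST_SAVE

    # The context_fn is a functools.partial rather than a plain function;
    # this is the case that used to fail to compile (gh-164300).
    context_fn = functools.partial(
        create_selective_checkpoint_contexts, save_everything
    )

    def f(x, y):
        return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y

    @torch.compile(fullgraph=True)
    def fn(x, y):
        return checkpoint(f, x, y, use_reentrant=False, context_fn=context_fn)

    x = torch.randn(4, 4, requires_grad=True)
    y = torch.randn(4, 4, requires_grad=True)
    fn(x, y).sum().backward()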

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164308
Approved by: https://github.com/Lucaskabela, https://github.com/soulitzer
Author: Animesh Jain
Date: 2025-09-30 16:21:15 -07:00
Committed by: PyTorch MergeBot
Parent: e0f118585f
Commit: c66d18d24d

2 changed files with 33 additions and 1 deletion


@@ -759,6 +759,38 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
             ),
         )
 
+    def test_sac_with_partial_context_fn(self):
+        class CustomPolicy:
+            def __init__(self):
+                super().__init__()
+
+            def __call__(self, ctx, out, func, *args, **kwargs):
+                return CheckpointPolicy.MUST_SAVE
+
+        def f(x, y):
+            return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
+
+        context_fn1 = functools.partial(
+            create_selective_checkpoint_contexts, CustomPolicy()
+        )
+
+        def fn(x, y):
+            return torch.utils.checkpoint.checkpoint(
+                f,
+                x,
+                y,
+                use_reentrant=False,
+                context_fn=context_fn1,
+            )
+
+        opt_fn = torch.compile(fn, backend="aot_eager_decomp_partition", fullgraph=True)
+
+        a = torch.randn(4, 4, requires_grad=True, device="cpu")
+        b = torch.randn(4, 4, requires_grad=True, device="cpu")
+        expected = fn(a, b)
+        result = opt_fn(a, b)
+        self.assertEqual(result, expected)
+
     @requires_cuda_and_triton
     @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
     def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):


@@ -2583,7 +2583,7 @@ class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
         elif isinstance(
             ctx, torch._dynamo.variables.functions.FunctoolsPartialVariable
         ):
-            context_fn = ctx.as_python_constant()
+            context_fn = ctx.guard_as_python_constant()
         else:
             raise NotImplementedError(
                 f"checkpoint not implemented for {type(ctx)} context_fn"