mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo][sac] Support functools partial context_fn for sac (#164308)
Fixes https://github.com/pytorch/pytorch/issues/164300 Pull Request resolved: https://github.com/pytorch/pytorch/pull/164308 Approved by: https://github.com/Lucaskabela, https://github.com/soulitzer
This commit is contained in:
committed by
PyTorch MergeBot
parent
e0f118585f
commit
c66d18d24d
@ -759,6 +759,38 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
),
|
||||
)
|
||||
|
||||
def test_sac_with_partial_context_fn(self):
|
||||
class CustomPolicy:
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, ctx, out, func, *args, **kwargs):
|
||||
return CheckpointPolicy.MUST_SAVE
|
||||
|
||||
def f(x, y):
|
||||
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
|
||||
|
||||
context_fn1 = functools.partial(
|
||||
create_selective_checkpoint_contexts, CustomPolicy()
|
||||
)
|
||||
|
||||
def fn(x, y):
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
f,
|
||||
x,
|
||||
y,
|
||||
use_reentrant=False,
|
||||
context_fn=context_fn1,
|
||||
)
|
||||
|
||||
opt_fn = torch.compile(fn, backend="aot_eager_decomp_partition", fullgraph=True)
|
||||
a = torch.randn(4, 4, requires_grad=True, device="cpu")
|
||||
b = torch.randn(4, 4, requires_grad=True, device="cpu")
|
||||
|
||||
expected = fn(a, b)
|
||||
result = opt_fn(a, b)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):
|
||||
|
@ -2583,7 +2583,7 @@ class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
|
||||
elif isinstance(
|
||||
ctx, torch._dynamo.variables.functions.FunctoolsPartialVariable
|
||||
):
|
||||
context_fn = ctx.as_python_constant()
|
||||
context_fn = ctx.guard_as_python_constant()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"checkpoint not implemented for {type(ctx)} context_fn"
|
||||
|
Reference in New Issue
Block a user