# Owner(s): ["module: pt2-dispatcher"]

import torch

from functorch.compile import min_cut_rematerialization_partition

from torch._C import FileCheck
from torch._inductor.custom_graph_pass import CustomPartitionerFn, get_hash_for_files
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_fw_bw_and_get_code
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


class MyCustomPartitionerFn(CustomPartitionerFn):
    """
    A custom partitioner function that overwrites static_lifetime_input_indices.
    """

    def __init__(self):
        super().__init__()
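        # Track whether this partitioner was actually invoked (asserted in the test).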
        self.called = False

    def __call__(self, gm, joint_inputs, **kwargs):
        self.called = True
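        # Mark graph inputs 0 and 1 (`a` and `b`, i.e., `primals_1`/`primals_2`)
        # as static-lifetime, so the partitioner saves them for backward
        # instead of the intermediate `buf0`.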
kwargs["static_lifetime_input_indices"] = [0, 1]
|
|
return min_cut_rematerialization_partition(gm, joint_inputs, **kwargs)
|
|
|
|
def uuid(self):
|
|
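        # Hash this file's contents; Inductor uses the uuid as a cache key
        # identifying this partitioner implementation.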
        return get_hash_for_files((__file__,))


class TestCustomPartitionerFn(TestCase):
    def test_custom_partitioner_fn(self):
        """
        For a function f(a, b), the default partitioner in the compile_fx
        stack saves the addition `a + b` (equivalently `buf0`) for backward.
        With the custom partitioner function, we indicate that `a` and `b`
        (equivalently `primals_1` and `primals_2`) take no additional memory,
        so they are saved for backward instead.
        """
        # initialization
        @torch.compile
        def f(a, b):
            return (a + b).cos().cos()

        a = torch.randn((2, 2), requires_grad=True, device=GPU_TYPE)
        b = torch.randn((2, 2), requires_grad=True, device=GPU_TYPE)

        # CASE 1 -- default
        # addition `a + b` (i.e., `buf0`) is saved for backward.
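        # run_fw_bw_and_get_code returns (result, source_codes);
        # source_codes[0] is the generated forward code.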
        code_og = run_fw_bw_and_get_code(lambda: f(a, b))
        fwd_code_og = code_og[1][0]
        FileCheck().check("return (buf1, buf0, )").run(fwd_code_og)

        # CASE 2 -- custom partitioner function
        # `a` and `b` (i.e., `primals_1` and `primals_2`) are saved for backward.
        custom_partitioner_fn = MyCustomPartitionerFn()
        self.assertFalse(custom_partitioner_fn.called)
        self.assertIsNotNone(custom_partitioner_fn.uuid())

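        # register the custom partitioner via the Inductor config for this block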
        with torch._inductor.config.patch(custom_partitioner_fn=custom_partitioner_fn):
            code_cp = run_fw_bw_and_get_code(lambda: f(a, b))
            fwd_code_cp = code_cp[1][0]
            FileCheck().check("return (buf0, primals_1, primals_2, )").run(fwd_code_cp)

        # make sure the custom partitioner function is indeed invoked
        self.assertTrue(custom_partitioner_fn.called)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
        run_tests()