pytorch/test/inductor/test_custom_partitioner_fn.py

# Owner(s): ["module: pt2-dispatcher"]
import torch
from functorch.compile import min_cut_rematerialization_partition
from torch._C import FileCheck
from torch._inductor.custom_graph_pass import CustomPartitionerFn, get_hash_for_files
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_fw_bw_and_get_code
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


class MyCustomPartitionerFn(CustomPartitionerFn):
    """
    A custom partitioner function that overrides static_lifetime_input_indices.
    """

    def __init__(self):
        super().__init__()
        self.called = False

    def __call__(self, gm, joint_inputs, **kwargs):
        self.called = True
        # Mark the first two graph inputs as having a static lifetime.
        kwargs["static_lifetime_input_indices"] = [0, 1]
        return min_cut_rematerialization_partition(gm, joint_inputs, **kwargs)

    def uuid(self):
        # Identifies this partitioner (e.g., for caching) via a hash of this file.
        return get_hash_for_files((__file__,))
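
# Note: outside of this test, a custom partitioner like the one above is expected
# to be enabled through the Inductor config, e.g.
#   with torch._inductor.config.patch(custom_partitioner_fn=MyCustomPartitionerFn()):
#       ...
# which is exactly how the test below exercises it.
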
class TestCustomPartitionerFn(TestCase):
    def test_custom_partitioner_fn(self):
        """
        For f(a, b), the default partitioner in the compile_fx stack saves the
        intermediate `a + b` (i.e., `buf0`) for backward. With the custom
        partitioner function we mark `a` and `b` (i.e., `primals_1` and
        `primals_2`) as static-lifetime inputs, so saving them for backward takes
        no additional memory and they are saved instead.
        """
        # initialization
        @torch.compile
        def f(a, b):
            return (a + b).cos().cos()

        a = torch.randn((2, 2), requires_grad=True, device=GPU_TYPE)
        b = torch.randn((2, 2), requires_grad=True, device=GPU_TYPE)
        # CASE 1 -- default partitioner:
        # the addition `a + b` (i.e., `buf0`) is saved for backward.
        code_og = run_fw_bw_and_get_code(lambda: f(a, b))
        fwd_code_og = code_og[1][0]
        FileCheck().check("return (buf1, buf0, )").run(fwd_code_og)
        # CASE 2 -- custom partitioner function:
        # `a` and `b` (i.e., `primals_1` and `primals_2`) are saved for backward.
        custom_partitioner_fn = MyCustomPartitionerFn()
        self.assertFalse(custom_partitioner_fn.called)
        self.assertIsNotNone(custom_partitioner_fn.uuid())
        with torch._inductor.config.patch(custom_partitioner_fn=custom_partitioner_fn):
            code_cp = run_fw_bw_and_get_code(lambda: f(a, b))
        fwd_code_cp = code_cp[1][0]
        FileCheck().check("return (buf0, primals_1, primals_2, )").run(fwd_code_cp)
        # make sure the custom partitioner function is indeed invoked
        self.assertTrue(custom_partitioner_fn.called)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
        run_tests()
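
# Running this file directly (e.g., `python test/inductor/test_custom_partitioner_fn.py`)
# only starts the suite when HAS_GPU is true, since the test requires a GPU device.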