Compare commits

...

1 Commit

2 changed files with 62 additions and 0 deletions

View File

@ -3231,6 +3231,58 @@ if HAS_CUDA_AND_TRITON:
# splitting on 1 custom gives 2 cudagraphs
self.assertEqual(self.get_manager().new_graph_id().id, 2)
@config.patch(implicit_fallbacks=True)
@torch._inductor.config.patch("graph_partition", True)
def test_graph_partition_custom_op_mutation_late_free(self):
    """Check that eager and compiled runs agree when custom ops split the graph.

    The traced function interleaves ordinary tensor arithmetic with two
    custom ops: ``mylib::op1``, whose schema declares its input as mutated,
    and ``mylib::cg_unsafe_op``, which carries the ``cudagraph_unsafe`` tag.
    It is compiled with ``mode="reduce-overhead"`` while inductor's
    ``graph_partition`` option is patched on, and the compiled output is
    compared against the eager output.

    NOTE(review): judging by the test name, this presumably guards the case
    where a buffer is freed in a later partition than the one that
    allocated it -- confirm against the scheduler change it ships with.
    """

    # The schema marks `x` as written to (`Tensor(a!)?`), matching
    # `mutates_args`. The Python body itself only builds new tensors and
    # never writes into `x` in place.
    @torch.library.custom_op(
        "mylib::op1",
        mutates_args=["x"],
        schema="(Tensor(a!)? x) -> (Tensor, Tensor)",
        device_types="cuda",
    )
    def op1(x) -> tuple[torch.Tensor, torch.Tensor]:
        bumped = x + 1
        return bumped + 1, bumped + 2

    @op1.register_fake
    def _(x) -> tuple[torch.Tensor, torch.Tensor]:
        # Fake implementation: two outputs shaped like the input.
        return torch.empty_like(x), torch.empty_like(x)

    # Op tagged `cudagraph_unsafe`; it takes four tensors and returns
    # their sum.
    @torch.library.custom_op(
        "mylib::cg_unsafe_op",
        mutates_args=[],
        schema="(Tensor x, Tensor y, Tensor x1, Tensor y1) -> Tensor",
        device_types="cuda",
        tags=(torch._C.Tag.cudagraph_unsafe,),
    )
    def cg_unsafe_op(x0, x1, y0, y1) -> torch.Tensor:
        # Same left-to-right accumulation order as a single chained sum.
        total = x0 + x1
        total = total + y0
        return total + y1

    @cg_unsafe_op.register_fake
    def _(x0, x1, y0, y1) -> torch.Tensor:
        return torch.empty_like(x0)

    def f(x):
        x = x + 1
        # First op1 call: both outputs are used before and after the
        # first cudagraph-unsafe op.
        head0, head1 = op1(x)
        inc0 = head0 + 1
        inc1 = head1 + 1
        unsafe_out = cg_unsafe_op(head0, head1, inc0, inc1)
        mixed = unsafe_out + head0 + head1
        # Second op1 call feeds the second cudagraph-unsafe op, which also
        # reuses the result of the first one.
        tail0, tail1 = op1(mixed)
        tail_sum = tail0 + tail1
        return cg_unsafe_op(tail_sum, tail_sum, unsafe_out, unsafe_out)

    inp = torch.randn(2, 2, device="cuda")
    # Clone before the eager run so both runs start from identical data.
    inp_copy = inp.clone()
    expected = f(inp)

    compiled_f = torch.compile(f, mode="reduce-overhead")
    actual = compiled_f(inp_copy)

    self.assertEqual(expected, actual)
@config.patch(implicit_fallbacks=True)
@torch._inductor.config.patch("graph_partition", True)
def test_graph_partition_custom_op_dynamoc_shapes(self):

View File

@ -4893,6 +4893,16 @@ class Scheduler:
for node in partition:
buffer_names_to_free.update(node.last_usage)
# buffer_names_to_free may contain buffers allocated in previous
# graph partitions. These buffers must also be inputs of this
# partition.
extra_input_names = [
name
for name in (buffer_names_to_free - output_names)
if name in name_to_node
]
partition_input_names.update(extra_input_names)
input_nodes = {
name: name_to_node[name]
for name in partition_input_names