Compare commits

...

2 Commits

2 changed files with 26 additions and 0 deletions

@@ -2905,6 +2905,22 @@ if HAS_CUDA_AND_TRITON:
             # 2 graph partitions lead to 2 cudagraph
             self.assertEqual(self.get_manager().new_graph_id().id, 2)
 
+        def test_graph_partition_view_fallback(self):
+            def f(x):
+                y = x + 1
+                z = torch.ops.aten.view.dtype(y, torch.float8_e4m3fn)
+                z_cpu = z.cpu()
+                u_cuda = z_cpu.cuda()
+                return u_cuda
+
+            compiled_f = torch.compile(f, mode="reduce-overhead")
+
+            for _ in range(3):
+                x = torch.ones(2, dtype=torch.int32, device="cuda")
+                eager_out = f(x)
+                compiled_out = compiled_f(x)
+                self.assertEqual(eager_out, compiled_out)
+
         @torch._inductor.config.patch("graph_partition", True)
         def test_graph_partition_log_message(self):
            def foo(x, y):
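The new test exercises the fallback path end to end: torch.ops.aten.view.dtype reinterprets the int32 buffer as float8_e4m3fn, and the cpu()/cuda() round trip introduces device copies that cannot be captured inside a CUDA graph, so Inductor splits the compiled function into graph partitions. A standalone sketch of the same repro, runnable outside the test harness (assuming a CUDA build of PyTorch new enough for float8 dtypes and cudagraph partitioning):

import torch

def f(x):
    y = x + 1
    # Reinterpreting int32 storage as float8 and bouncing through the CPU
    # forces fallbacks that split the compiled graph into partitions.
    z = torch.ops.aten.view.dtype(y, torch.float8_e4m3fn)
    z_cpu = z.cpu()
    return z_cpu.cuda()

compiled_f = torch.compile(f, mode="reduce-overhead")

for _ in range(3):
    x = torch.ones(2, dtype=torch.int32, device="cuda")
    eager_out = f(x)
    compiled_out = compiled_f(x)
    # Compare bit patterns via a uint8 view, since elementwise eq is not
    # guaranteed to be implemented for float8 tensors.
    assert torch.equal(eager_out.view(torch.uint8), compiled_out.view(torch.uint8))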

@@ -5014,6 +5014,16 @@ class Scheduler:
             for node in partition:
                 buffer_names_to_free.update(node.last_usage)
 
+            # buffer_names_to_free may contain buffers allocated in previous
+            # graph partitions. These buffers should also be partition
+            # inputs.
+            extra_input_names = [
+                name
+                for name in (buffer_names_to_free - output_names)
+                if name in name_to_node
+            ]
+            partition_input_names.update(extra_input_names)
+
             input_nodes = {
                 name: name_to_node[name]
                 for name in partition_input_names
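The scheduler change itself is set bookkeeping: a buffer whose last use falls in this partition but which was allocated by an earlier partition is not produced locally, so it must be threaded through as a partition input for the free to have something to refer to. A minimal sketch of that filtering, using hypothetical buffer names in place of the scheduler's real name_to_node state:

# Hypothetical stand-ins for Scheduler state; only the set logic mirrors the diff.
buffer_names_to_free = {"buf0", "buf3", "buf7"}      # last used in this partition
output_names = {"buf7"}                              # produced by this partition
name_to_node = {"buf0": "node_a", "buf3": "node_b"}  # buffers the scheduler knows
partition_input_names = {"arg0"}

# Buffers freed here that were not produced here, and that the scheduler
# knows as real buffers, must arrive as extra partition inputs.
extra_input_names = [
    name
    for name in (buffer_names_to_free - output_names)
    if name in name_to_node
]
partition_input_names.update(extra_input_names)

print(sorted(partition_input_names))  # ['arg0', 'buf0', 'buf3']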