Make require_stride_order peek into AliasedLayout (#111681)
Summary: `require_stride_order` doesn't know how to handle storage with `AliasedLayout`. It always resorts to a copy, even when the view refers to a storage with `FixedLayout`. This causes an unnecessary allocation + copy for collective outputs. Peeking into `AliasedLayout` in `require_stride_order` seems to be the proper way to address the issue.

Original program:

```python
import tempfile

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import *  # noqa
from torch._inductor.utils import run_and_get_triton_code


def func(arg: torch.Tensor) -> torch.Tensor:
    buf0 = arg + 42
    out0 = torch.ops.c10d_functional.all_reduce(buf0, "avg", "default", [0], 1)
    out0 = torch.ops.c10d_functional.wait_tensor(out0)
    return out0


if __name__ == "__main__":
    with tempfile.NamedTemporaryFile(delete=False) as tmpf:
        dist.init_process_group(
            backend="nccl", init_method=f"file://{tmpf.name}", rank=0, world_size=1
        )
        device = torch.device("cuda:0")
        compiled = torch.compile(func)
        print(run_and_get_triton_code(compiled, torch.rand(4, 4, device=device)))
        torch.cuda.synchronize()
        dist.destroy_process_group()
```

Before:

```python
def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (4, 4), (4, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf0 = empty_strided((4, 4), (4, 1), device='cuda', dtype=torch.float32)
        # Source Nodes: [buf0], Original ATen: [aten.add]
        stream0 = get_cuda_stream(0)
        triton_poi_fused_add_0.run(arg0_1, buf0, 16, grid=grid(16), stream=stream0)
        del arg0_1
        buf1 = buf0; del buf0 # reuse
        buf2_pg = c10d._find_or_create_pg_by_ranks_and_tag('default', [0], 1)
        buf2 = buf1
        buf2_work = dist.all_reduce(buf2, async_op=True, group=buf2_pg, op=fun_col_impl._str_to_reduce_op('avg'))
        fun_col_impl._register_tensor_work(buf2, buf2_work)
        buf1 = _wait_tensor(buf1)
        buf3 = buf1
        buf4 = empty_strided((4, 4), (4, 1), device='cuda', dtype=torch.float32)
        # Source Nodes: [out0_1], Original ATen: [c10d_functional.wait_tensor]
        triton_poi_fused_wait_tensor_1.run(buf3, buf4, 16, grid=grid(16), stream=stream0)
        del buf1
        del buf3
        return (buf4, )
```

After:

```python
def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (4, 4), (4, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf0 = empty_strided((4, 4), (4, 1), device='cuda', dtype=torch.float32)
        # Source Nodes: [buf0], Original ATen: [aten.add]
        stream0 = get_cuda_stream(0)
        triton_poi_fused_add_0.run(arg0_1, buf0, 16, grid=grid(16), stream=stream0)
        del arg0_1
        buf1 = buf0; del buf0 # reuse
        buf2_pg = c10d._find_or_create_pg_by_ranks_and_tag('default', [0], 1)
        buf2 = buf1
        buf2_work = dist.all_reduce(buf2, async_op=True, group=buf2_pg, op=fun_col_impl._str_to_reduce_op('avg'))
        fun_col_impl._register_tensor_work(buf2, buf2_work)
        buf1 = _wait_tensor(buf1)
        buf3 = buf1
        del buf3
        return (buf1, )
```

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111681
Approved by: https://github.com/jansel
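To make the summary concrete, here is a minimal, illustrative-only sketch of the decision being changed. `FixedLayout`, `AliasedLayout`, and `needs_copy` below are stand-ins invented for this example, not Inductor's real IR classes or the actual `require_stride_order` logic; the real change is the two-line `while` loop in the `ExternKernel` hunk at the end of this diff.

```python
# Illustrative sketch only (stand-in classes, not Inductor's IR).
class FixedLayout:
    def __init__(self, stride):
        self.stride = stride


class AliasedLayout:
    """Marks a buffer that is merely a view of another buffer's storage."""
    def __init__(self, underlying_layout):
        self.underlying_layout = underlying_layout


def needs_copy(layout, required_stride, peek_through_aliases):
    if peek_through_aliases:
        # The fix: unwrap (possibly nested) aliasing wrappers before deciding.
        while isinstance(layout, AliasedLayout):
            layout = layout.underlying_layout
    # Copy only if the layout we can see is not already the required fixed layout.
    return not (isinstance(layout, FixedLayout) and layout.stride == required_stride)


wait_output = AliasedLayout(FixedLayout((4, 1)))  # e.g. a collective/wait_tensor output
print(needs_copy(wait_output, (4, 1), peek_through_aliases=False))  # True: extra allocation + copy
print(needs_copy(wait_output, (4, 1), peek_through_aliases=True))   # False: buffer is reused as-is
```

With the peek in place, a collective output whose underlying storage already satisfies the required stride order is returned directly instead of being copied into a fresh buffer.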
Committed by: PyTorch MergeBot
Parent: ac08b10d60
Commit: 6fd3659391
```diff
@@ -478,14 +478,16 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         compiled = torch.compile(func)
         out = compiled(inputs, **self.get_world_trs())
         code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
+        # NOTE: Make sure we are not unneccessarily copying the outputs of
+        # wait_tensors before they are returned from the graph.
         FileCheck() \
             .check("buf0 = empty(") \
             .check("buf0.copy_(arg0_1)") \
             .check("buf1 = buf0") \
             .check("buf1_work = dist.all_reduce(buf1") \
             .check("fun_col_impl._register_tensor_work(buf1, buf1_work)") \
-            .check("_wait_tensor(buf0)") \
-            .check("return (buf3, )") \
+            .check("buf0 = _wait_tensor(buf0)") \
+            .check("return (buf0, )") \
             .run(code)
         correct = func(inputs, **self.get_world_trs())
         self.assertTrue(same(out, correct))
```
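For context on the assertions above and in the following hunks: `FileCheck` (importable from `torch.testing`) asserts that a sequence of substrings occurs in order within a string of generated code, and `check_not` asserts that a substring does not occur between the neighbouring matches. A small self-contained example, run against a made-up snippet rather than real Inductor output:

```python
from torch.testing import FileCheck

# Made-up stand-in for the Triton wrapper code that the tests inspect.
generated_code = """
buf0 = empty((4, 4))
buf0 = _wait_tensor(buf0)
return (buf0, )
"""

# .check() matches substrings in order, .check_not() rejects a substring in
# between, and .run() raises if any expectation fails.
FileCheck() \
    .check("buf0 = empty(") \
    .check("buf0 = _wait_tensor(buf0)") \
    .check_not(".copy_(") \
    .check("return (buf0, )") \
    .run(generated_code)
```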
```diff
@@ -510,16 +512,18 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
 
         compiled = torch.compile(func)
         code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
+        # NOTE: Make sure we are not unneccessarily copying the outputs of
+        # wait_tensors before they are returned from the graph.
         FileCheck() \
             .check("buf1 = buf0; del buf0 # reuse") \
             .check_not("buf1.copy_(") \
             .check("buf2 = buf1") \
             .check("buf2_work = dist.all_reduce(buf2") \
             .check("fun_col_impl._register_tensor_work(buf2, buf2_work)") \
-            .check("_wait_tensor(buf1)") \
+            .check("buf1 = _wait_tensor(buf1)") \
             .check("buf3 = buf1") \
-            .check("buf4 = empty(") \
-            .check("return (buf4, buf5") \
+            .check("buf4 = empty") \
+            .check("return (buf1, buf4") \
             .run(code)
         out = compiled(inputs, **self.get_world_trs())
         correct = func(inputs, **self.get_world_trs())
```
```diff
@@ -546,18 +550,20 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
 
         compiled = torch.compile(func)
         code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
+        # NOTE: Make sure we are not unneccessarily copying the outputs of
+        # wait_tensors before they are returned from the graph.
         FileCheck() \
             .check("buf0 = empty(") \
-            .check("buf5 = empty(") \
-            .check("triton_poi__0.run(arg0_1, buf0, buf5") \
+            .check("buf4 = empty(") \
+            .check("triton_poi__0.run(arg0_1, buf0, buf4") \
             .check_not("copy_(") \
             .check("buf1 = buf0; del buf0 # reuse") \
             .check("buf2 = buf1") \
             .check("buf2_work = dist.all_reduce(buf2") \
             .check("fun_col_impl._register_tensor_work(buf2, buf2_work)") \
-            .check("_wait_tensor(buf1)") \
+            .check("buf1 = _wait_tensor(buf1)") \
             .check("buf3 = buf1") \
-            .check("return (buf4, buf5, buf6") \
+            .check("return (buf1, buf4, buf5") \
             .run(code)
         out = compiled(inputs, **self.get_world_trs())
         correct = func(inputs, **self.get_world_trs())
```
```diff
@@ -787,10 +793,12 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
 
         compiled = torch.compile(func)
         code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
+        # NOTE: Make sure we are not unneccessarily copying the outputs of
+        # wait_tensors before they are returned from the graph.
         FileCheck() \
             .check("buf0 = empty(") \
-            .check("buf6 = empty(") \
-            .check("triton_poi__0.run(arg0_1, buf0, buf6") \
+            .check("buf5 = empty(") \
+            .check("triton_poi__0.run(arg0_1, buf0, buf5") \
             .check("buf1 = empty(") \
             .check("buf2 = empty(") \
             .check_not("copy_(") \
```
```diff
@@ -799,12 +807,12 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
             .check("buf3_work = fun_col_impl._all_gather_into_tensor_coalesced_fallback("
                    "output_tensors=buf3, input_tensors=buf3_inputs") \
             .check("fun_col_impl._register_tensor_work(buf3, buf3_work)") \
-            .check("_wait_tensor(buf1)") \
+            .check("buf1 = _wait_tensor(buf1)") \
             .check("buf4 = buf1") \
-            .check("buf5 = buf0; del buf0 # reuse") \
-            .check("_wait_tensor(buf2)") \
-            .check("buf8 = buf2") \
-            .check("return (buf5, buf6, buf7, buf9") \
+            .check("buf6 = buf0; del buf0 # reuse") \
+            .check("buf2 = _wait_tensor(buf2)") \
+            .check("buf7 = buf2") \
+            .check("return (buf1, buf5, buf6, buf2") \
             .run(code)
         out = compiled(inputs, **self.get_world_trs())
         correct = func(inputs, **self.get_world_trs())
```
```diff
@@ -832,10 +840,12 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
 
         compiled = torch.compile(func)
         code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
+        # NOTE: The first return value should be the output of the first wait_tensor.
+        # We want to make sure no unneccessary copy is made.
         FileCheck() \
             .check("buf0 = empty(") \
-            .check("buf6 = empty(") \
-            .check("triton_poi__0.run(arg0_1, buf0, buf6") \
+            .check("buf5 = empty(") \
+            .check("triton_poi__0.run(arg0_1, buf0, buf5") \
             .check("buf1 = empty(") \
             .check("buf2 = empty(") \
             .check_not("copy_(") \
```
```diff
@@ -843,12 +853,12 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
             .check("buf3_work = fun_col_impl._reduce_scatter_tensor_coalesced_fallback("
                    "output_tensors=buf3, input_tensors=buf3_inputs") \
             .check("fun_col_impl._register_tensor_work(buf3, buf3_work)") \
-            .check("_wait_tensor(buf1)") \
+            .check("buf1 = _wait_tensor(buf1)") \
             .check("buf4 = buf1") \
-            .check("buf5 = buf0; del buf0 # reuse") \
-            .check("_wait_tensor(buf2)") \
-            .check("buf8 = buf2") \
-            .check("return (buf5, buf6, buf7, buf9") \
+            .check("buf6 = buf0; del buf0 # reuse") \
+            .check("buf2 = _wait_tensor(buf2)") \
+            .check("buf7 = buf2") \
+            .check("return (buf1, buf5, buf6, buf2") \
             .run(code)
         out = compiled(inputs, **self.get_world_trs())
         correct = func(inputs, **self.get_world_trs())
```
```diff
@@ -3490,6 +3490,8 @@ class ExternKernel(InputsKernel):
 
         # require x to have the layout as strided_ordered as order
         if is_storage_and_layout(x):
+            while isinstance(x.get_layout(), AliasedLayout):
+                x = x.get_layout().view
             if isinstance(x.get_layout(), FlexibleLayout):
                 # fix flexiblelayout to be FixedLayout with stride_order
                 as_storage_and_layout(
```