Make require_stride_order peek into AliasedLayout (#111681)

Summary:

`require_stride_order` doesn't know how to handle storage with `AliasedLayout`: it always falls back to a copy, even when the view refers to a storage with a `FixedLayout` that already satisfies the requested stride order. This causes an unnecessary allocation + copy for collective outputs. Peeking into `AliasedLayout` in `require_stride_order` seems to be the proper way to address the issue.
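
The gist of the fix, as a self-contained toy sketch (the classes below are hypothetical stand-ins rather than Inductor's real layout/buffer types; the actual two-line change to `ExternKernel.require_stride_order` is in the diff at the bottom): a view carrying an `AliasedLayout` should be judged by the layout of the storage it aliases, instead of unconditionally falling back to an allocation + copy.

```python
# Hypothetical stand-in classes -- NOT Inductor's real FixedLayout/AliasedLayout/Buffer.
class FixedLayout:
    def __init__(self, stride):
        self.stride = stride


class AliasedLayout:
    def __init__(self, view):
        self.view = view  # the buffer whose storage this layout aliases


class Buffer:
    def __init__(self, layout):
        self.layout = layout

    def get_layout(self):
        return self.layout


def needs_copy(buf, required_stride):
    # The gist of the fix: peel off AliasedLayout wrappers before inspecting the
    # layout, so a view of an already-fixed storage does not trigger a copy.
    while isinstance(buf.get_layout(), AliasedLayout):
        buf = buf.get_layout().view
    layout = buf.get_layout()
    return not (isinstance(layout, FixedLayout) and layout.stride == required_stride)


storage = Buffer(FixedLayout(stride=(4, 1)))       # e.g. the all_reduce buffer
wait_output = Buffer(AliasedLayout(view=storage))  # wait_tensor output aliasing it
print(needs_copy(wait_output, required_stride=(4, 1)))  # False: no allocation + copy needed
```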

Original program:
```python
import tempfile

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import *  # noqa
from torch._inductor.utils import run_and_get_triton_code

def func(arg: torch.Tensor) -> torch.Tensor:
    buf0 = arg + 42
    out0 = torch.ops.c10d_functional.all_reduce(buf0, "avg", "default", [0], 1)
    out0 = torch.ops.c10d_functional.wait_tensor(out0)
    return out0

if __name__ == "__main__":
    with tempfile.NamedTemporaryFile(delete=False) as tmpf:
        dist.init_process_group(
            backend="nccl", init_method=f"file://{tmpf.name}", rank=0, world_size=1
        )
        device = torch.device("cuda:0")

        compiled = torch.compile(func)
        print(run_and_get_triton_code(compiled, torch.rand(4, 4, device=device)))

        torch.cuda.synchronize()
        dist.destroy_process_group()
```

Before (note the extra `buf4 = empty_strided(...)` allocation and the `triton_poi_fused_wait_tensor_1` copy of the wait output before it is returned):
```python
def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (4, 4), (4, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf0 = empty_strided((4, 4), (4, 1), device='cuda', dtype=torch.float32)
        # Source Nodes: [buf0], Original ATen: [aten.add]
        stream0 = get_cuda_stream(0)
        triton_poi_fused_add_0.run(arg0_1, buf0, 16, grid=grid(16), stream=stream0)
        del arg0_1
        buf1 = buf0; del buf0  # reuse
        buf2_pg = c10d._find_or_create_pg_by_ranks_and_tag('default', [0], 1)
        buf2 = buf1
        buf2_work = dist.all_reduce(buf2, async_op=True, group=buf2_pg, op=fun_col_impl._str_to_reduce_op('avg'))
        fun_col_impl._register_tensor_work(buf2, buf2_work)
        buf1 = _wait_tensor(buf1)
        buf3 = buf1
        buf4 = empty_strided((4, 4), (4, 1), device='cuda', dtype=torch.float32)
        # Source Nodes: [out0_1], Original ATen: [c10d_functional.wait_tensor]
        triton_poi_fused_wait_tensor_1.run(buf3, buf4, 16, grid=grid(16), stream=stream0)
        del buf1
        del buf3
        return (buf4, )
```

After (the extra allocation and copy kernel are gone; the wait output `buf1` is returned directly):
```python
def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (4, 4), (4, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf0 = empty_strided((4, 4), (4, 1), device='cuda', dtype=torch.float32)
        # Source Nodes: [buf0], Original ATen: [aten.add]
        stream0 = get_cuda_stream(0)
        triton_poi_fused_add_0.run(arg0_1, buf0, 16, grid=grid(16), stream=stream0)
        del arg0_1
        buf1 = buf0; del buf0  # reuse
        buf2_pg = c10d._find_or_create_pg_by_ranks_and_tag('default', [0], 1)
        buf2 = buf1
        buf2_work = dist.all_reduce(buf2, async_op=True, group=buf2_pg, op=fun_col_impl._str_to_reduce_op('avg'))
        fun_col_impl._register_tensor_work(buf2, buf2_work)
        buf1 = _wait_tensor(buf1)
        buf3 = buf1
        del buf3
        return (buf1, )
```
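
This is the kind of assertion the updated tests below make against the generated wrapper; a minimal sketch, assuming the output of `run_and_get_triton_code` is captured in a variable `code` instead of printed:

```python
from torch.testing import FileCheck

# The wait output should be returned directly, with no extra allocation or copy
# kernel between the wait and the return.
FileCheck() \
    .check("buf1 = _wait_tensor(buf1)") \
    .check_not("triton_poi_fused_wait_tensor") \
    .check("return (buf1, )") \
    .run(code)
```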

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111681
Approved by: https://github.com/jansel
Author: Yifu Wang (2023-10-25 00:51:52 -07:00), committed by PyTorch MergeBot
Commit: 6fd3659391, parent: ac08b10d60
2 changed files with 35 additions and 23 deletions

Test changes:

```diff
@@ -478,14 +478,16 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         compiled = torch.compile(func)
         out = compiled(inputs, **self.get_world_trs())
         code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
+        # NOTE: Make sure we are not unneccessarily copying the outputs of
+        # wait_tensors before they are returned from the graph.
         FileCheck() \
             .check("buf0 = empty(") \
             .check("buf0.copy_(arg0_1)") \
             .check("buf1 = buf0") \
             .check("buf1_work = dist.all_reduce(buf1") \
             .check("fun_col_impl._register_tensor_work(buf1, buf1_work)") \
-            .check("_wait_tensor(buf0)") \
-            .check("return (buf3, )") \
+            .check("buf0 = _wait_tensor(buf0)") \
+            .check("return (buf0, )") \
             .run(code)
         correct = func(inputs, **self.get_world_trs())
         self.assertTrue(same(out, correct))
@@ -510,16 +512,18 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         compiled = torch.compile(func)
         code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
+        # NOTE: Make sure we are not unneccessarily copying the outputs of
+        # wait_tensors before they are returned from the graph.
         FileCheck() \
             .check("buf1 = buf0; del buf0 # reuse") \
             .check_not("buf1.copy_(") \
             .check("buf2 = buf1") \
             .check("buf2_work = dist.all_reduce(buf2") \
             .check("fun_col_impl._register_tensor_work(buf2, buf2_work)") \
-            .check("_wait_tensor(buf1)") \
+            .check("buf1 = _wait_tensor(buf1)") \
             .check("buf3 = buf1") \
-            .check("buf4 = empty(") \
-            .check("return (buf4, buf5") \
+            .check("buf4 = empty") \
+            .check("return (buf1, buf4") \
             .run(code)
         out = compiled(inputs, **self.get_world_trs())
         correct = func(inputs, **self.get_world_trs())
@@ -546,18 +550,20 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         compiled = torch.compile(func)
         code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
+        # NOTE: Make sure we are not unneccessarily copying the outputs of
+        # wait_tensors before they are returned from the graph.
         FileCheck() \
             .check("buf0 = empty(") \
-            .check("buf5 = empty(") \
-            .check("triton_poi__0.run(arg0_1, buf0, buf5") \
+            .check("buf4 = empty(") \
+            .check("triton_poi__0.run(arg0_1, buf0, buf4") \
             .check_not("copy_(") \
             .check("buf1 = buf0; del buf0 # reuse") \
             .check("buf2 = buf1") \
             .check("buf2_work = dist.all_reduce(buf2") \
             .check("fun_col_impl._register_tensor_work(buf2, buf2_work)") \
-            .check("_wait_tensor(buf1)") \
+            .check("buf1 = _wait_tensor(buf1)") \
             .check("buf3 = buf1") \
-            .check("return (buf4, buf5, buf6") \
+            .check("return (buf1, buf4, buf5") \
             .run(code)
         out = compiled(inputs, **self.get_world_trs())
         correct = func(inputs, **self.get_world_trs())
@@ -787,10 +793,12 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         compiled = torch.compile(func)
         code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
+        # NOTE: Make sure we are not unneccessarily copying the outputs of
+        # wait_tensors before they are returned from the graph.
         FileCheck() \
             .check("buf0 = empty(") \
-            .check("buf6 = empty(") \
-            .check("triton_poi__0.run(arg0_1, buf0, buf6") \
+            .check("buf5 = empty(") \
+            .check("triton_poi__0.run(arg0_1, buf0, buf5") \
             .check("buf1 = empty(") \
             .check("buf2 = empty(") \
             .check_not("copy_(") \
@@ -799,12 +807,12 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
             .check("buf3_work = fun_col_impl._all_gather_into_tensor_coalesced_fallback("
                    "output_tensors=buf3, input_tensors=buf3_inputs") \
             .check("fun_col_impl._register_tensor_work(buf3, buf3_work)") \
-            .check("_wait_tensor(buf1)") \
+            .check("buf1 = _wait_tensor(buf1)") \
             .check("buf4 = buf1") \
-            .check("buf5 = buf0; del buf0 # reuse") \
-            .check("_wait_tensor(buf2)") \
-            .check("buf8 = buf2") \
-            .check("return (buf5, buf6, buf7, buf9") \
+            .check("buf6 = buf0; del buf0 # reuse") \
+            .check("buf2 = _wait_tensor(buf2)") \
+            .check("buf7 = buf2") \
+            .check("return (buf1, buf5, buf6, buf2") \
             .run(code)
         out = compiled(inputs, **self.get_world_trs())
         correct = func(inputs, **self.get_world_trs())
@@ -832,10 +840,12 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         compiled = torch.compile(func)
         code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
+        # NOTE: The first return value should be the output of the first wait_tensor.
+        # We want to make sure no unneccessary copy is made.
         FileCheck() \
             .check("buf0 = empty(") \
-            .check("buf6 = empty(") \
-            .check("triton_poi__0.run(arg0_1, buf0, buf6") \
+            .check("buf5 = empty(") \
+            .check("triton_poi__0.run(arg0_1, buf0, buf5") \
             .check("buf1 = empty(") \
             .check("buf2 = empty(") \
             .check_not("copy_(") \
@@ -843,12 +853,12 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
             .check("buf3_work = fun_col_impl._reduce_scatter_tensor_coalesced_fallback("
                    "output_tensors=buf3, input_tensors=buf3_inputs") \
             .check("fun_col_impl._register_tensor_work(buf3, buf3_work)") \
-            .check("_wait_tensor(buf1)") \
+            .check("buf1 = _wait_tensor(buf1)") \
             .check("buf4 = buf1") \
-            .check("buf5 = buf0; del buf0 # reuse") \
-            .check("_wait_tensor(buf2)") \
-            .check("buf8 = buf2") \
-            .check("return (buf5, buf6, buf7, buf9") \
+            .check("buf6 = buf0; del buf0 # reuse") \
+            .check("buf2 = _wait_tensor(buf2)") \
+            .check("buf7 = buf2") \
+            .check("return (buf1, buf5, buf6, buf2") \
             .run(code)
         out = compiled(inputs, **self.get_world_trs())
         correct = func(inputs, **self.get_world_trs())
```

Inductor change (`ExternKernel.require_stride_order`):

```diff
@@ -3490,6 +3490,8 @@ class ExternKernel(InputsKernel):
         # require x to have the layout as strided_ordered as order
         if is_storage_and_layout(x):
+            while isinstance(x.get_layout(), AliasedLayout):
+                x = x.get_layout().view
             if isinstance(x.get_layout(), FlexibleLayout):
                 # fix flexiblelayout to be FixedLayout with stride_order
                 as_storage_and_layout(
```