Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Update base for Update on "Solve for tilings"
Find variables that coalesce the reads and writes, and score the total size. If uncoalesced memory expressions are found, look for additional tilings of variables that will coalesce the memory accesses. For instance, for the expression `(32*p0) // 2048`, tiling p0 by 64 makes the expression coalesced.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
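As an illustration of the check described above, the sketch below (plain sympy, not Inductor's implementation; the symbol names `p0_outer` / `p0_inner` are made up) substitutes a 64-wide tiling into `(32*p0) // 2048` and verifies that a unit step of the new outer variable moves the index by exactly 1, i.e. the access becomes coalesced:

```python
import sympy

p0 = sympy.Symbol("p0", integer=True, nonnegative=True)
index = (32 * p0) // 2048  # simplifies to floor(p0 / 64)

# Untiled: stepping p0 by 1 does not move the index at all -> uncoalesced.
print(index.subs(p0, 1) - index.subs(p0, 0))  # 0

# Tile p0 by 64: p0 = 64 * p0_outer + p0_inner.
p0_outer, p0_inner = sympy.symbols(
    "p0_outer p0_inner", integer=True, nonnegative=True
)
tiled = index.subs(p0, 64 * p0_outer + p0_inner)

# Holding the inner index fixed, a unit step of the outer variable now moves
# the index by exactly 1 -> the expression is coalesced in p0_outer.
step = tiled.subs({p0_outer: 1, p0_inner: 0}) - tiled.subs({p0_outer: 0, p0_inner: 0})
print(step)  # 1
```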
@@ -444,9 +444,9 @@ class LoopOrderingTest(TestCase):
         M, K = 4096, 4096

         input_tensor = torch.randn(
-            M, K, device="cuda", dtype=ref_dtype, requires_grad=False
+            M, K, device=GPU_TYPE, dtype=ref_dtype, requires_grad=False
         )
-        scale = torch.Tensor([10.0]).to("cuda")
+        scale = torch.Tensor([10.0]).to(GPU_TYPE)

         E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
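The test hunks above and below replace the hardcoded "cuda" device string with the GPU_TYPE constant so the tests run on whichever GPU backend is available. A minimal sketch of that pattern, assuming GPU_TYPE and HAS_GPU come from torch.testing._internal.inductor_utils as in other Inductor tests:

```python
import torch
# Assumed import; Inductor tests commonly take their device constants from here.
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU

if HAS_GPU:
    x = torch.rand([4, 4], device=GPU_TYPE)
    y = torch.rand([4, 4], device=GPU_TYPE).T  # transposed, non-contiguous view
    print((x + y).device)  # cuda:0, xpu:0, ... depending on the backend
```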
@@ -637,7 +637,10 @@ class MemoryCoalescingTest(MockSchedulerTest):
         def foo(x, y):
             return x + y

-        foo(torch.rand([4, 4], device="cuda"), torch.rand([4, 4], device="cuda").T)
+        foo(
+            torch.rand([4, 4], device=GPU_TYPE),
+            torch.rand([4, 4], device=GPU_TYPE).T,
+        )

     def test_remapped_reads_split(self):
         from torch._inductor import tiling_utils
@@ -676,7 +679,10 @@ class MemoryCoalescingTest(MockSchedulerTest):
                 (y.T + 1).flatten()
             )

-        foo(torch.rand([6, 6], device="cuda"), torch.rand([6, 6], device="cuda").T)
+        foo(
+            torch.rand([6, 6], device=GPU_TYPE),
+            torch.rand([6, 6], device=GPU_TYPE).T,
+        )

     def test_reduction_pointwise(self):
         # test one pw var, one red var
@@ -717,7 +723,8 @@ class MemoryCoalescingTest(MockSchedulerTest):
             return out.sum(dim=1)

         foo(
-            torch.rand(256, 256, device="cuda"), torch.rand(256, 256, device="cuda")
+            torch.rand(256, 256, device=GPU_TYPE),
+            torch.rand(256, 256, device=GPU_TYPE),
         )

     def test_reduction_no_pointwise(self):
@@ -740,7 +747,7 @@ class MemoryCoalescingTest(MockSchedulerTest):
         def foo(x):
             return x.sum()

-        foo(torch.rand(1024, device="cuda"))
+        foo(torch.rand(1024, device=GPU_TYPE))

     def test_coalescing(self):
         from torch._inductor import tiling_utils
@@ -805,8 +812,8 @@ class MemoryCoalescingTest(MockSchedulerTest):

         y_dtype = torch.float if not downcast_transposed_v else torch.float64
         foo(
-            torch.rand(256, 256, device="cuda"),
-            torch.rand(256, 256, device="cuda", dtype=y_dtype).T,
+            torch.rand(256, 256, device=GPU_TYPE),
+            torch.rand(256, 256, device=GPU_TYPE, dtype=y_dtype).T,
         )
@@ -714,7 +714,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
         "Fill in the reduction numel of lengths if missing"
         sizevars = V.graph.sizevars
         if len(lengths[1]) == 0 and (
-            not sizevars.is_expr_static_and_true(reduction_numel == sympy.S.One)
+            not sizevars.statically_known_equals(reduction_numel, sympy.S.One)
             and sizevars.statically_known_equals(
                 sympy_product(groups),
                 sympy_product(lengths[0]) * reduction_numel,
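One plausible reading of this change, illustrated with plain sympy below: `reduction_numel == sympy.S.One` is a structural comparison that sympy resolves eagerly to a Python bool, so it can report False for expressions that are mathematically (or statically, given size hints) equal to one, whereas `statically_known_equals` asks the size-var system directly. The snippet only illustrates the structural-vs-mathematical distinction; it does not use Inductor's sizevars API:

```python
import sympy

x = sympy.Symbol("x")

# Structural comparison: the two sides are mathematically equal but are not the
# same expression tree, so == returns False without any simplification.
print((x + 1) ** 2 == x**2 + 2 * x + 1)  # False

# Mathematical comparison: simplify the difference instead.
print(sympy.simplify((x + 1) ** 2 - (x**2 + 2 * x + 1)) == 0)  # True
```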
@@ -56,6 +56,7 @@ def find_coalesced_var(
         try:
             new_val = sympy_subs(index, variables)
         except ZeroDivisionError:
+            loop_tiling_log.info("zero division error %s %s", index, variables)
             continue
         if new_val - zero_index == 1:
             return v
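The new log line records indices whose substitution blows up instead of silently skipping them. A small sketch of how that can happen, assuming (as the except clause suggests) that torch's symbolic FloorDiv helper raises ZeroDivisionError when a substitution drives its divisor to zero; this mirrors the hunk above but is not its code:

```python
import sympy
from torch.utils._sympy.functions import FloorDiv  # assumed helper, used across Inductor

x, s = sympy.symbols("x s", integer=True)
index = FloorDiv(x, s)  # symbolic x // s

try:
    index.subs(s, 0)  # divisor becomes zero during substitution
except ZeroDivisionError as exc:
    # find_coalesced_var now logs this case and moves on to the next variable.
    print("caught:", exc)
```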
@@ -231,6 +232,7 @@ else:
         """
         if len(it1) != len(it2):
+            raise ValueError(f"Lengths differ: {len(it1)} != {len(it2)}")
         return zip(it1, it2)


 def apply_var_mapping(