[inductor] add a threshold for membw saving during fusion (#136782)

Fix https://github.com/pytorch/pytorch/issues/133242. In that issue, Inductor fuses 2 nodes because they access the same scalar tensor. The saving from that fusion is tiny (4 bytes); if we ignore it, the two nodes can not be fused by default. But if we skip that fusion and loop ordering after fusion kicks in, we can reorder loops and fuse those 2 nodes, which gives a 33% memory bandwidth saving.

I think adding a threshold for membw savings during fusion is reasonable in general.

I'll run a perf test: https://github.com/pytorch/pytorch/actions/runs/11375421752
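
For anyone who wants to experiment with the new behavior, here is a minimal sketch (the config/env names below come from this PR and the existing loop-ordering toggle; the exact values are illustrative):

# Hedged sketch: opt into loop ordering after fusion and tweak the
# fusion-memory threshold added by this PR. Values are illustrative.
import os

# Existing toggle read by torch/_inductor/config.py.
os.environ["TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION"] = "1"

import torch
import torch._inductor.config as inductor_config

# New knob from this PR (defaults to 10 bytes): fusions that save less
# memory traffic than this are skipped, which lets loop ordering after
# fusion find a better fusion instead. Setting it to 1 effectively restores
# the old "any shared byte is enough" behavior, which is what the updated
# optimizer tests do via @config.patch("score_fusion_memory_threshold", 1).
inductor_config.score_fusion_memory_threshold = 1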

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136782
Approved by: https://github.com/jansel
Shunting Zhang
2024-10-16 16:22:38 -07:00
committed by PyTorch MergeBot
parent e8b1409dcf
commit 6647320de2
4 changed files with 61 additions and 2 deletions


@@ -358,6 +358,7 @@ def make_test(
     device="cuda",
     **kwargs,
 ):
+    @config.patch("score_fusion_memory_threshold", 1)
     def test_fn(self):
         stack = ExitStack()
         try:
@@ -442,6 +443,7 @@ def make_test(
 def make_recompile_test(optim_cls, closure=None, kernel_count=2, **kwargs):
+    @config.patch("score_fusion_memory_threshold", 1)
     @requires_gpu
     def test_fn(self):
         torch._dynamo.reset()


@@ -412,6 +412,46 @@ class LoopOrderingTest(TestCase):
         self.do_acc_test(f, x, scale)
         self.assertEqual(1, metrics.generated_kernel_count)
 
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+")
+    def test_fp8_pattern_2(self):
+        """
+        This test repros the fp8 fusion relation issue here:
+        https://github.com/pytorch/pytorch/issues/133242
+        """
+        ref_dtype = torch.bfloat16
+        M, K = 4096, 4096
+        input_tensor = torch.randn(
+            M, K, device="cuda", dtype=ref_dtype, requires_grad=False
+        )
+        scale = torch.Tensor([10.0]).to("cuda")
+
+        E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
+        E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
+
+        def test_pattern2(tensor_x_inp, scale_x):
+            tensor_x = tensor_x_inp * scale_x
+            tensor_x = tensor_x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
+            tensor_fp8 = tensor_x.to(torch.float8_e4m3fn)
+
+            tensor_x_t = (tensor_x_inp * scale_x).t()
+            tensor_x_t = tensor_x_t.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
+            tensor_fp8_t = tensor_x_t.to(torch.float8_e4m3fn)
+
+            tensor_fp8_t = tensor_fp8_t.contiguous().t()
+            return (tensor_fp8, tensor_fp8_t)
+
+        test_pattern = torch.compile(test_pattern2)
+        tensor_fp8, tensor_fp8_t = test_pattern(input_tensor, scale)
+
+        self.assertEqual(1, metrics.generated_kernel_count)
+
+        expected_numbytes = scale.nbytes  # scalar
+        expected_numbytes += input_tensor.nbytes  # input
+        expected_numbytes += tensor_fp8.nbytes + tensor_fp8_t.nbytes  # output
+        self.assertEqual(expected_numbytes, metrics.num_bytes_accessed)
+
 
 if __name__ == "__main__":
     if HAS_GPU:
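
As a rough sanity check on the expected_numbytes above and the ~33% figure from the issue (the two-kernel baseline below is an assumption about the unfused schedule, not something stated in this PR):

# Back-of-the-envelope arithmetic for the test above (M = K = 4096).
MiB = 2**20
input_bytes = 4096 * 4096 * 2        # bf16 input, 32 MiB
fp8_out_bytes = 4096 * 4096 * 1      # each fp8 output, 16 MiB
scale_bytes = 4                      # the scalar tensor

# One fused kernel reads the input once and writes both fp8 outputs: ~64 MiB.
fused = input_bytes + 2 * fp8_out_bytes + scale_bytes
# Assumed unfused baseline: two kernels, each re-reading the bf16 input: ~96 MiB.
unfused = 2 * (input_bytes + fp8_out_bytes + scale_bytes)
print((unfused - fused) / unfused)   # ~0.33, i.e. the ~33% membw saving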


@@ -438,6 +438,17 @@ loop_ordering_after_fusion = (
     os.environ.get("TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0") == "1"
 )
 
+# If fusing two nodes only saves less than score_fusion_memory_threshold bytes
+# of memory, we should not bother fusing the nodes.
+#
+# This is especially helpful to resolve https://github.com/pytorch/pytorch/issues/133242
+# Previously we fused two nodes because of a common read of a scalar tensor.
+# If we skip that fusion, the loop-ordering-after-fusion mechanism kicks in
+# and can bring more savings.
+#
+# For the cases where loop ordering after fusion does not help, we don't lose much.
+score_fusion_memory_threshold = 10
+
 # For Triton Templates, select fastest of best template + epilogue vs best template + separate epilogue kernel
 benchmark_epilogue_fusion = (
     os.environ.get("TORCHINDUCTOR_BENCHMARK_EPILOGUE_FUSION", "1") == "1"
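
For completeness, the threshold is an ordinary Inductor config knob, so it can be overridden the same way the updated tests do; a small sketch (the wrapped function and its arguments are placeholders):

import torch
from torch._inductor import config

# Sketch: lower the threshold to 1 byte so any shared access still counts,
# mirroring @config.patch("score_fusion_memory_threshold", 1) in the tests.
# `fn` and `inputs` stand in for whatever you want to compile.
@config.patch("score_fusion_memory_threshold", 1)
def compile_and_run(fn, *inputs):
    return torch.compile(fn)(*inputs)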


@@ -2922,7 +2922,10 @@ class Scheduler:
            node2.get_name(),
         )
-        return self.score_fusion_memory(node1, node2) > 0
+        return (
+            self.score_fusion_memory(node1, node2)
+            >= config.score_fusion_memory_threshold
+        )
 
     def unfusable_node(self, node: BaseSchedulerNode) -> bool:
         """
@@ -2990,7 +2993,10 @@ class Scheduler:
             return False
         del device2
 
-        no_shared_data = self.score_fusion_memory(node1, node2) == 0
+        no_shared_data = (
+            self.score_fusion_memory(node1, node2)
+            < config.score_fusion_memory_threshold
+        )
         if no_shared_data:
             no_shared_data = not self.has_shared_data_after_reordering_loop(
                 node1, node2