mirror of https://github.com/pytorch/pytorch.git
[inductor] add a threshold for membw saving during fusion (#136782)
Fix https://github.com/pytorch/pytorch/issues/133242. In that issue, Inductor fuses 2 nodes because they access the same scalar tensor. That saving is very small (4 bytes). If we skip this fusion, loop ordering after fusion can kick in, reorder the loops, and fuse those 2 nodes anyway, and we get a 33% memory bandwidth saving. I think adding a threshold for the membw saving in general is not a bad idea. I'll run a perf test (https://github.com/pytorch/pytorch/actions/runs/11375421752).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136782
Approved by: https://github.com/jansel
committed by PyTorch MergeBot
parent e8b1409dcf
commit 6647320de2
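A minimal sketch of the scenario described in the commit message (illustrative only; not part of the commit's diff, and f is a made-up example function). Both branches read x, but with transposed indexing, so the only access they share exactly is the 4-byte scalar; with loop ordering after fusion enabled, skipping that tiny fusion lets Inductor reorder loops and fuse the two branches instead:

import torch
import torch._inductor.config as inductor_config

# Same switch as TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1
inductor_config.loop_ordering_after_fusion = True

def f(x, scale):
    a = x * scale                      # row-major branch
    b = (x * scale).t().contiguous()   # transposed branch
    return a, b

x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
scale = torch.tensor([10.0], device="cuda")  # 4-byte fp32 scalar read by both branches
compiled = torch.compile(f)
compiled(x, scale)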
@@ -358,6 +358,7 @@ def make_test(
     device="cuda",
     **kwargs,
 ):
+    @config.patch("score_fusion_memory_threshold", 1)
     def test_fn(self):
         stack = ExitStack()
         try:
@@ -442,6 +443,7 @@ def make_test(


 def make_recompile_test(optim_cls, closure=None, kernel_count=2, **kwargs):
+    @config.patch("score_fusion_memory_threshold", 1)
     @requires_gpu
     def test_fn(self):
         torch._dynamo.reset()
@@ -412,6 +412,46 @@ class LoopOrderingTest(TestCase):
         self.do_acc_test(f, x, scale)
         self.assertEqual(1, metrics.generated_kernel_count)

+    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+")
+    def test_fp8_pattern_2(self):
+        """
+        This test repros the fp8 fusion relation issue here:
+        https://github.com/pytorch/pytorch/issues/133242
+        """
+        ref_dtype = torch.bfloat16
+        M, K = 4096, 4096
+
+        input_tensor = torch.randn(
+            M, K, device="cuda", dtype=ref_dtype, requires_grad=False
+        )
+        scale = torch.Tensor([10.0]).to("cuda")
+
+        E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
+        E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
+
+        def test_pattern2(tensor_x_inp, scale_x):
+            tensor_x = tensor_x_inp * scale_x
+            tensor_x = tensor_x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
+            tensor_fp8 = tensor_x.to(torch.float8_e4m3fn)
+
+            tensor_x_t = (tensor_x_inp * scale_x).t()
+            tensor_x_t = tensor_x_t.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
+            tensor_fp8_t = tensor_x_t.to(torch.float8_e4m3fn)
+
+            tensor_fp8_t = tensor_fp8_t.contiguous().t()
+
+            return (tensor_fp8, tensor_fp8_t)
+
+        test_pattern = torch.compile(test_pattern2)
+        tensor_fp8, tensor_fp8_t = test_pattern(input_tensor, scale)
+
+        self.assertEqual(1, metrics.generated_kernel_count)
+
+        expected_numbytes = scale.nbytes  # scalar
+        expected_numbytes += input_tensor.nbytes  # input
+        expected_numbytes += tensor_fp8.nbytes + tensor_fp8_t.nbytes  # output
+        self.assertEqual(expected_numbytes, metrics.num_bytes_accessed)
+

 if __name__ == "__main__":
     if HAS_GPU:
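As a quick check of the expected_numbytes arithmetic in the test above (illustrative, not part of the diff): with M = K = 4096, the bf16 input contributes 2 bytes per element, each fp8 output 1 byte per element, plus the 4-byte fp32 scale.

M = K = 4096
scale_bytes = 4                  # torch.Tensor([10.0]) is fp32
input_bytes = M * K * 2          # bfloat16 input: 33_554_432 bytes
fp8_output_bytes = M * K * 1     # each float8_e4m3fn output: 16_777_216 bytes
total = scale_bytes + input_bytes + 2 * fp8_output_bytes
print(total)                     # 67_108_868 bytes accessed by the single fused kernel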
@@ -438,6 +438,17 @@ loop_ordering_after_fusion = (
     os.environ.get("TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0") == "1"
 )

+# If fusing two nodes only saves less than score_fusion_memory_threshold memory,
+# we should not bother fusing the nodes.
+#
+# This is especially helpful to resolve https://github.com/pytorch/pytorch/issues/133242
+# Previously we fused two nodes because of a common read of a scalar tensor.
+# If we skip that fusion, the loop ordering after fusion mechanism kicks in and can
+# bring more savings.
+#
+# For the cases where loop ordering after fusion does not help, we don't lose much.
+score_fusion_memory_threshold = 10
+
 # For Triton Templates, select fastest of best template + epilogue vs best template + separate epilogue kernel
 benchmark_epilogue_fusion = (
     os.environ.get("TORCHINDUCTOR_BENCHMARK_EPILOGUE_FUSION", "1") == "1"
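The new knob can also be overridden per scope, the same way the test hunks above pin it back to 1; a small sketch (assuming torch._inductor.config's patch helper works as a context manager as well as the decorator shown in the diff):

import torch
import torch._inductor.config as inductor_config

# Restore the pre-commit behavior (fuse on any nonzero saving) for one compile.
with inductor_config.patch("score_fusion_memory_threshold", 1):
    fn = torch.compile(lambda x: x.sin() + x.cos())
    fn(torch.randn(1024, device="cuda"))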
@@ -2922,7 +2922,10 @@ class Scheduler:
                 node2.get_name(),
             )

-        return self.score_fusion_memory(node1, node2) > 0
+        return (
+            self.score_fusion_memory(node1, node2)
+            >= config.score_fusion_memory_threshold
+        )

     def unfusable_node(self, node: BaseSchedulerNode) -> bool:
         """
@@ -2990,7 +2993,10 @@ class Scheduler:
             return False
         del device2

-        no_shared_data = self.score_fusion_memory(node1, node2) == 0
+        no_shared_data = (
+            self.score_fusion_memory(node1, node2)
+            < config.score_fusion_memory_threshold
+        )
         if no_shared_data:
             no_shared_data = not self.has_shared_data_after_reordering_loop(
                 node1, node2
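A standalone sketch of the scheduler change above (illustrative; worth_fusing and saved_bytes are made-up names standing in for the comparison against Scheduler.score_fusion_memory, which measures the memory traffic the two nodes have in common):

SCORE_FUSION_MEMORY_THRESHOLD = 10  # bytes; mirrors config.score_fusion_memory_threshold

def worth_fusing(saved_bytes: int) -> bool:
    # Old behavior: any nonzero saving (even a 4-byte shared scalar) justified fusion.
    # New behavior: tiny savings are skipped, leaving loop ordering after fusion
    # free to find a fusion with a larger saving.
    return saved_bytes >= SCORE_FUSION_MEMORY_THRESHOLD

assert not worth_fusing(4)          # only a shared scalar read -> do not fuse yet
assert worth_fusing(4096 * 4096)    # a genuinely shared buffer -> fuse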