Enable epilogue fusion benchmarking internally (#125455)

Differential Revision: [D56920738](https://our.internmc.facebook.com/intern/diff/D56920738)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125455
Approved by: https://github.com/Chillee
This commit is contained in:
eellison
2024-05-13 21:15:40 +00:00
committed by PyTorch MergeBot
parent e046c59e5b
commit 328b75d1a0
3 changed files with 7 additions and 9 deletions

View File

@ -203,7 +203,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
{
"benchmark_kernel": True,
"benchmark_fusion": True,
"benchmark_multi_templates": True,
"benchmark_epilogue_fusion": True,
}
)
)
@ -231,7 +231,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
torch._dynamo.reset()
with unittest.mock.patch.object(
torch._inductor.config, "benchmark_multi_templates", False
torch._inductor.config, "benchmark_epilogue_fusion", False
):
foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)
with torch.no_grad():

View File

@ -318,15 +318,13 @@ debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1"
benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1"
enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "")
benchmark_multi_templates = (
os.environ.get(
"TORCHINDUCTOR_BENCHMARK_MULTI_TEMPLATES", "0" if is_fbcode() else "1"
)
== "1"
# For Triton Templates, select fastest of best template + epilogue vs best template + separate epilogue kernel
benchmark_epilogue_fusion = (
os.environ.get("TORCHINDUCTOR_BENCHMARK_EPILOGUE_FUSION", "1") == "1"
)
# Take how many of the top triton kernels to benchmark epilogue
max_epilogue_benchmarked_choices = 3
max_epilogue_benchmarked_choices = 1
# how many nodes to allow into a single fusion
max_fusion_size = 64

View File

@ -1499,7 +1499,7 @@ def autotune_select_algorithm(*args, **kwargs):
if "return_multi_template" not in kwargs:
kwargs[
"return_multi_template"
] = torch._inductor.config.benchmark_multi_templates
] = torch._inductor.config.benchmark_epilogue_fusion
return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)