Compare commits

...

1 Commits

Author SHA1 Message Date
5bc8d92294 Enable PyTorch OSS numerics changes, inductor heuristics (#167799)
Summary:

Enable heuristic updates for performance improvements internally

Test Plan:
`buck2 test mode/opt aps_models/ads/launchers/fm/tests/ne/e2e_deterministic_tests:`

`buck2 test mode/opt aps_models/ads/launchers/fm/tests/ne/model_deterministic_tests:`

Reviewed By: njriasan, eellison

Differential Revision: D86211542
2025-11-15 22:02:49 -08:00

View File

@ -22,7 +22,6 @@ from typing import Any, Generic, Literal, TYPE_CHECKING, TypeVar, Union
import torch
from torch._dynamo.utils import counters, set_feature_use
from torch._environment import is_fbcode
from torch._inductor import metrics
from torch._prims_common import compute_required_storage_length
from torch.utils._debug_mode import get_active_debug_mode
@ -2470,9 +2469,8 @@ def triton_config_reduction(
rnumels[prefix] *= 2
if num_warps is None:
if reduction_hint == ReductionHint.INNER and not is_fbcode():
# r is contiguous, so ensure that each thread has 8 elements for
# vectorized loads, assuming bf16/fp16
if reduction_hint == ReductionHint.INNER:
# r is contiguous, ensure at least 8 elements per thread
# xblock is usually 1-2, default to giving each thread more work
num_warps = r // 128
else:
@ -2942,7 +2940,7 @@ def _reduction_configs(
)
contiguous_config = make_config(
2 if rnumel <= 2048 and not is_fbcode() else 1, # 1024 or less is persistent
2 if rnumel <= 2048 else 1, # 1024 or less is persistent
min(rnumel, MAX_R0_BLOCK),
register_intensive=register_intensive,
)
@ -2955,7 +2953,7 @@ def _reduction_configs(
outer_config = make_config(64, 8, register_intensive=register_intensive)
# TODO (paulzhan): Test heuristic on AMD and internal testing
# for correctness
if not torch.version.hip and not is_fbcode():
if not torch.version.hip:
outer_config = outer_config_opt()
configs = []