[Inductor][CPP] Select tiling factor for lower precision data types (#133830)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133830
Approved by: https://github.com/jgong5, https://github.com/jansel
This commit is contained in:
CaoE
2024-09-05 18:11:00 -07:00
committed by PyTorch MergeBot
parent 60d98b4cfb
commit 758d515d98
2 changed files with 35 additions and 6 deletions

View File

@ -4098,13 +4098,19 @@ class CPUReproTests(TestCase):
funcs.append(func2)
# test small shapes
funcs.append(func2)
small_size = cpu_vec_isa.pick_vec_isa().nelements(dtype=torch.bfloat16) // 2
example_shapes = [
[(10, 32, 20, 20), (10, 32, 20, 20)],
[(10, 32, 20, 20)],
[(10, 32, 20, 20), (10, 32, 20, 20)],
# test small shapes
[(small_size), (small_size)],
]
mixed_types = [False, False, True]
check_vecns = [True, True, True]
mixed_types = [False, False, True, False]
check_vecns = [True, True, True, False]
for dtype in [torch.bfloat16, torch.float16]:
for func, shapes, mixed, check_vecn in zip(

View File

@ -3260,7 +3260,13 @@ class TilingSelect:
tiling_indices = self._select_tiling_indices(
fn_list, var_sizes_list, tiling_factor
)
if tiling_indices:
group, reduction_group = max(
var_sizes_list, key=lambda sizes: len(sizes[1])
)
call_ranges = tuple(group) + tuple(reduction_group)
if config.cpp.enable_tiling_heuristics:
def _try_get_stride(
@ -3296,10 +3302,6 @@ class TilingSelect:
< len(itervars)
)
group, reduction_group = max(
var_sizes_list, key=lambda sizes: len(sizes[1])
)
call_ranges = tuple(group) + tuple(reduction_group)
itervars = [
sympy_index_symbol_with_prefix(SymT.XBLOCK, n)
for n in range(len(call_ranges))
@ -3376,6 +3378,27 @@ class TilingSelect:
# when needed.
return [], []
if dtype in DTYPE_LOWP_FP:
# For lower precision data type, if the call_range is not long enough,
# use tiling_factor // 2 for better performance
factor_lowp = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype)
for tiling_indice in tiling_indices:
if tiling_indice < 0:
tiling_indice = tiling_indice + len(call_ranges)
if tiling_indice < 0 or tiling_indice >= len(call_ranges):
continue
if has_free_symbols(call_ranges):
call_range = V.graph.sizevars.size_hint(
call_ranges[tiling_indice], fallback=0
)
if call_range < factor_lowp:
V.graph.sizevars.guard_lt(call_range, factor_lowp)
tiling_factor = factor_lowp // 2
break
elif call_ranges[tiling_indice] < factor_lowp:
tiling_factor = factor_lowp // 2
break
if len(tiling_indices) == 1:
return [tiling_factor], tiling_indices
if len(tiling_indices) == 2: