mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Inductor][CPP] Select tiling factor for lower precision data types (#133830)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133830
Approved by: https://github.com/jgong5, https://github.com/jansel
This commit is contained in:
@ -4098,13 +4098,19 @@ class CPUReproTests(TestCase):
|
||||
|
||||
funcs.append(func2)
|
||||
|
||||
# test small shapes
|
||||
funcs.append(func2)
|
||||
small_size = cpu_vec_isa.pick_vec_isa().nelements(dtype=torch.bfloat16) // 2
|
||||
|
||||
example_shapes = [
|
||||
[(10, 32, 20, 20), (10, 32, 20, 20)],
|
||||
[(10, 32, 20, 20)],
|
||||
[(10, 32, 20, 20), (10, 32, 20, 20)],
|
||||
# test small shapes
|
||||
[(small_size), (small_size)],
|
||||
]
|
||||
mixed_types = [False, False, True]
|
||||
check_vecns = [True, True, True]
|
||||
mixed_types = [False, False, True, False]
|
||||
check_vecns = [True, True, True, False]
|
||||
|
||||
for dtype in [torch.bfloat16, torch.float16]:
|
||||
for func, shapes, mixed, check_vecn in zip(
|
||||
|
@ -3260,7 +3260,13 @@ class TilingSelect:
|
||||
tiling_indices = self._select_tiling_indices(
|
||||
fn_list, var_sizes_list, tiling_factor
|
||||
)
|
||||
|
||||
if tiling_indices:
|
||||
group, reduction_group = max(
|
||||
var_sizes_list, key=lambda sizes: len(sizes[1])
|
||||
)
|
||||
call_ranges = tuple(group) + tuple(reduction_group)
|
||||
|
||||
if config.cpp.enable_tiling_heuristics:
|
||||
|
||||
def _try_get_stride(
|
||||
@ -3296,10 +3302,6 @@ class TilingSelect:
|
||||
< len(itervars)
|
||||
)
|
||||
|
||||
group, reduction_group = max(
|
||||
var_sizes_list, key=lambda sizes: len(sizes[1])
|
||||
)
|
||||
call_ranges = tuple(group) + tuple(reduction_group)
|
||||
itervars = [
|
||||
sympy_index_symbol_with_prefix(SymT.XBLOCK, n)
|
||||
for n in range(len(call_ranges))
|
||||
@ -3376,6 +3378,27 @@ class TilingSelect:
|
||||
# when needed.
|
||||
return [], []
|
||||
|
||||
if dtype in DTYPE_LOWP_FP:
|
||||
# For lower-precision data types, if the call range at a tiling index is shorter than factor_lowp,
|
||||
# use factor_lowp // 2 as the tiling factor for better performance
|
||||
factor_lowp = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype)
|
||||
for tiling_indice in tiling_indices:
|
||||
if tiling_indice < 0:
|
||||
tiling_indice = tiling_indice + len(call_ranges)
|
||||
if tiling_indice < 0 or tiling_indice >= len(call_ranges):
|
||||
continue
|
||||
if has_free_symbols(call_ranges):
|
||||
call_range = V.graph.sizevars.size_hint(
|
||||
call_ranges[tiling_indice], fallback=0
|
||||
)
|
||||
if call_range < factor_lowp:
|
||||
V.graph.sizevars.guard_lt(call_range, factor_lowp)
|
||||
tiling_factor = factor_lowp // 2
|
||||
break
|
||||
elif call_ranges[tiling_indice] < factor_lowp:
|
||||
tiling_factor = factor_lowp // 2
|
||||
break
|
||||
|
||||
if len(tiling_indices) == 1:
|
||||
return [tiling_factor], tiling_indices
|
||||
if len(tiling_indices) == 2:
|
||||
|
Reference in New Issue
Block a user