From 16aaff77832526e913bfa3afdac7a16ff1341c39 Mon Sep 17 00:00:00 2001
From: eellison
Date: Tue, 16 Jul 2024 12:51:23 -0700
Subject: [PATCH] Fix mm pad regression - more conservative estimation of
 plannable inputs (#128909)

- More conservative estimation of plannable inputs
- Consider constant_pad_nd as a pointwise node in concat lowering
- Use aten.cat instead of constant_pad_nd when padding just a single
  dimension, because it can be memory-planned away

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128909
Approved by: https://github.com/Chillee
---
 test/inductor/test_pad_mm.py        | 22 +++++++++++++++
 test/inductor/test_perf.py          |  7 +++++
 torch/_inductor/config.py           |  5 ++++
 torch/_inductor/fx_passes/pad_mm.py | 44 ++++++++++++++++++++++++++---
 torch/_inductor/lowering.py         |  8 +++++-
 torch/_inductor/utils.py            | 16 +++++++++--
 torch/utils/_config_module.py       |  3 ++
 7 files changed, 97 insertions(+), 8 deletions(-)

diff --git a/test/inductor/test_pad_mm.py b/test/inductor/test_pad_mm.py
index f107f42a0459..c8b7bb4f4544 100644
--- a/test/inductor/test_pad_mm.py
+++ b/test/inductor/test_pad_mm.py
@@ -422,6 +422,28 @@ class PadMMTest(TestCase):
             repr(local_cache)
         )
 
+    @fresh_inductor_cache()
+    @inductor_config.patch(max_pointwise_cat_inputs=2)
+    def test_exclude_cat_padding(self):
+        @torch.compile()
+        def mm(inps, b):
+            return torch.cat(inps) @ b
+
+        inp = torch.rand([2046, 2046], device="cuda")
+        inp2 = torch.rand([2046, 2046], device="cuda")
+
+        inps = inp.chunk(3)
+        mm(inps, inp2)
+        FileCheck().check_count("exclude_pad:False", 2, exactly=True).run(
+            repr(get_pad_cache().get_local_cache())
+        )
+
+        inps = inp.chunk(2)
+        mm(inps, inp2)
+        FileCheck().check_count("exclude_pad:False", 3, exactly=True).run(
+            repr(get_pad_cache().get_local_cache())
+        )
+
 
 if __name__ == "__main__":
     if HAS_CUDA:
diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py
index 1d8805c109bd..632b47082c1d 100644
--- a/test/inductor/test_perf.py
+++ b/test/inductor/test_perf.py
@@ -265,6 +265,13 @@ class NumBytesMetricTests(TestCase):
         inp = (T(10, 10), T(10, 10))
         self.assertExpectedInline(count_numel(f, *inp), """400""")
 
+        def f(a, b):
+            a = a @ a
+            return torch.constant_pad_nd(torch.cat([a, b]), [2, 2], 0.5)
+
+        inp = (T(10, 10), T(10, 10))
+        self.assertExpectedInline(count_numel(f, *inp), """680""")
+
     @patch.object(config, "split_cat_fx_passes", False)
     @patch.object(
         config,
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 80630a6c8ce7..f378242305d3 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1057,6 +1057,11 @@ class trace:
 _save_config_ignore = [
     # workaround: "Can't pickle <function ...>"
     "trace.upload_tar",
+    "post_grad_custom_post_pass",
+    "post_grad_custom_pre_pass",
+    "joint_custom_pre_pass",
+    "joint_custom_post_pass",
+    "pre_grad_custom_pass",
 ]
 
 _cache_config_ignore_prefix = [
diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py
index ad8dcde91f72..f8e5f4f550d1 100644
--- a/torch/_inductor/fx_passes/pad_mm.py
+++ b/torch/_inductor/fx_passes/pad_mm.py
@@ -318,6 +318,30 @@ def should_exclude_padding_time(match, arg_name):
     if not fetch_fake_tensors(match, (arg_name,))[0].is_contiguous():
         return False
 
+    # TODO - see issue https://github.com/pytorch/pytorch/issues/128889
+    # We would only be able to completely plan these out if we were only doing
+    # first dimension padding. For non-first dims we would still need a copy
+    # because these outputs are fixed dense.
+    cannot_plan_output = [
+        aten.mm.default,
+        aten.convolution.default,
+        aten.convolution_backward.default,
+        aten.bmm.default,
+        aten.addmm.default,
+        aten._scaled_dot_product_flash_attention.default,
+        aten._scaled_dot_product_efficient_attention.default,
+    ]
+
+    if node_def.target in cannot_plan_output:
+        return False
+
+    if (
+        node_def.target == aten.cat.default
+        and len(node_def.all_input_nodes)
+        > torch._inductor.config.max_pointwise_cat_inputs
+    ):
+        return False
+
     # optimistically assume we should be able to memory plan away
     # all non inputs
     return node_def.op != "placeholder"
@@ -427,6 +451,7 @@ def should_pad_bench(
         mat2_pad = mat2
 
         is_bmm = op is torch.ops.aten.bmm
+
         mat1_pre_padded = should_exclude_padding_time(match, "mat1")
         fns = []
         if mat1_pre_padded and (m_padded_length or k_padded_length):
@@ -664,24 +689,35 @@ def should_pad_mm(match: Match) -> bool:
 
 
 def pad_mat1(mat1, *, m_padded_length, k_padded_length, is_bmm=False):
-    if k_padded_length != 0 or m_padded_length != 0:
+    if m_padded_length == 0 and k_padded_length == 0:
+        return mat1
+    elif k_padded_length != 0 and m_padded_length != 0:
         # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding
         pad_arg = [0, k_padded_length, 0, m_padded_length]
         if is_bmm:
             pad_arg.extend((0, 0))
         return aten.constant_pad_nd(mat1, pad_arg)
-    return mat1
+    elif m_padded_length != 0:
+        return pad_dim(mat1, m_padded_length, 0 if not is_bmm else 1)
+    else:
+        assert k_padded_length != 0
+        return pad_dim(mat1, k_padded_length, 1 if not is_bmm else 2)
 
 
 def pad_mat2(mat2, *, k_padded_length, n_padded_length, is_bmm=False):
-    if k_padded_length != 0 or n_padded_length != 0:
+    if k_padded_length == 0 and n_padded_length == 0:
+        return mat2
+    elif k_padded_length != 0 and n_padded_length != 0:
         # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding
         pad_arg = [0, n_padded_length, 0, k_padded_length]
         if is_bmm:
             pad_arg.extend((0, 0))
         return aten.constant_pad_nd(mat2, pad_arg)
+    elif k_padded_length != 0:
+        return pad_dim(mat2, k_padded_length, 0 if not is_bmm else 1)
     else:
-        return mat2
+        assert n_padded_length != 0
+        return pad_dim(mat2, n_padded_length, 1 if not is_bmm else 2)
 
 
 def pad_mm(
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 9fdf841f8bfb..f63d0d554154 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -1429,11 +1429,17 @@ def cat(inputs, dim=0):
     MAX_COMPLEX_POINTWISE_CAT = 8
     MAX_SIMPLE_OP_COUNT = 2
 
+    def additional_pointwise_ops(op: torch._ops.OpOverload):
+        return op in (aten.cat.default, aten.constant_pad_nd.default)
+
     if len(inputs) <= MAX_COMPLEX_POINTWISE_CAT or (
         (len(inputs) <= config.max_pointwise_cat_inputs)
         and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs)
     ):
-        pointwise_uses = all(is_pointwise_use(use) for use in V.current_node.users)
+        pointwise_uses = all(
+            is_pointwise_use(use, additional_pointwise_ops)
+            for use in V.current_node.users
+        )
         # fuse in case we will be used in a pointwise node, and there are any inputs we
         # we can prevent materialization of.
         fuse_pointwise_use = (
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 0d476c95d129..e2eac1aa144c 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -307,7 +307,15 @@ def is_view(op: torch._ops.OpOverload):
     return any(a.alias_info is not None for a in op._schema.arguments)
 
 
-def is_pointwise_use(use):
+def is_pointwise_use(
+    use, is_pointwise_fn: Optional[Callable[[torch._ops.OpOverload], bool]] = None
+):
+    """
+    Do all uses of this op have torch.Tag.pointwise, or return True for the optional `is_pointwise_fn`?
+
+    Uses in view ops are followed through to the view's own uses.
+    """
+
     if not use.op == "call_function":
         return False
 
@@ -317,9 +325,11 @@ def is_pointwise_use(use):
         return False
 
     if use.target is operator.getitem or is_view(use.target):
-        return all(is_pointwise_use(u) for u in use.users)
+        return all(is_pointwise_use(u, is_pointwise_fn) for u in use.users)
 
-    return torch.Tag.pointwise in use.target.tags
+    return torch.Tag.pointwise in use.target.tags or (
+        is_pointwise_fn is not None and is_pointwise_fn(use.target)
+    )
 
 
 def gen_gm_and_inputs(target, args, kwargs):
diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py
index 95b93df2c842..ab4e2034bfb9 100644
--- a/torch/utils/_config_module.py
+++ b/torch/utils/_config_module.py
@@ -8,6 +8,7 @@ import io
 import pickle
 import tokenize
 import unittest
+import warnings
 from types import FunctionType, ModuleType
 from typing import Any, Dict, Optional, Set, Union
 from typing_extensions import deprecated
@@ -178,6 +179,8 @@ class ConfigModule(ModuleType):
         mod = self.__name__
         for k, v in self._config.items():
             if k in self._config.get("_save_config_ignore", ()):
+                if v != self._default[k]:
+                    warnings.warn(f"Skipping serialization of {k} value {v}")
                 continue
             if v == self._default[k]:
                 continue
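
To make the single-dimension padding path concrete, here is a minimal sketch. It is illustrative only: pad_dim_sketch below is a hypothetical stand-in for the existing pad_dim helper in pad_mm.py, whose definition is not part of this diff. The point it illustrates is the one stated in the commit message: padding a single dimension can be expressed as torch.cat with a zero block, which the concat lowering can memory-plan away, rather than as aten.constant_pad_nd over every dimension.

import torch


def pad_dim_sketch(x: torch.Tensor, padded_length: int, dim: int) -> torch.Tensor:
    # Zero-pad `x` at the end of dimension `dim`, expressed as a concat.
    # (Hypothetical sketch, not the pad_dim implementation in pad_mm.py.)
    if padded_length == 0:
        return x
    pad_shape = list(x.shape)
    pad_shape[dim] = padded_length
    # A single-dimension cat can be planned as writes into a pre-allocated
    # output; constant_pad_nd over several dimensions needs an extra copy.
    return torch.cat([x, x.new_zeros(pad_shape)], dim=dim)


if __name__ == "__main__":
    mat1 = torch.rand(10, 12)
    padded = pad_dim_sketch(mat1, 6, 0)  # pad only the m dimension (dim 0)
    assert padded.shape == (16, 12)
    assert torch.equal(padded[:10], mat1)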