Fix mm pad regression - more conservative estimation of plannable inputs (#128909)

- More conservative estimation of plannable inputs
- Consider constant_pad_nd as pointwise node in concat lowering
- Use aten.cat instead of constant_pad_nd when padding just a single dimension, because it can be memory-planned away (see the sketch below)
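Illustration of the last point (not part of this diff): padding a single dimension can be expressed as a concatenation with a zero block, which the inductor cat lowering can plan away, whereas constant_pad_nd always writes a fresh buffer. A minimal sketch of such a helper, in the spirit of the pad_dim call used in the pad_mm.py hunk below (the in-tree implementation may differ):

import torch

def pad_dim(x, padded_length, dim):
    # Pad one dimension by concatenating a block of zeros instead of calling
    # constant_pad_nd; a cat like this can often be memory-planned away.
    if padded_length == 0:
        return x
    pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1:])
    return torch.cat([x, pad], dim=dim)

a = torch.randn(6, 8)
# Same numerical result as constant_pad_nd; [0, 0, 0, 2] pads only dim 0 by 2 on the right.
assert torch.equal(pad_dim(a, 2, 0), torch.constant_pad_nd(a, [0, 0, 0, 2]))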

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128909
Approved by: https://github.com/Chillee
eellison
2024-07-16 12:51:23 -07:00
committed by PyTorch MergeBot
parent 27ded03545
commit 16aaff7783
7 changed files with 97 additions and 8 deletions


@@ -422,6 +422,28 @@ class PadMMTest(TestCase):
repr(local_cache)
)
@fresh_inductor_cache()
@inductor_config.patch(max_pointwise_cat_inputs=2)
def test_exclude_cat_padding(self):
@torch.compile()
def mm(inps, b):
return torch.cat(inps) @ b
inp = torch.rand([2046, 2046], device="cuda")
inp2 = torch.rand([2046, 2046], device="cuda")
inps = inp.chunk(3)
mm(inps, inp2)
FileCheck().check_count("exclude_pad:False", 2, exactly=True).run(
repr(get_pad_cache().get_local_cache())
)
inps = inp.chunk(2)
mm(inps, inp2)
FileCheck().check_count("exclude_pad:False", 3, exactly=True).run(
repr(get_pad_cache().get_local_cache())
)
if __name__ == "__main__":
if HAS_CUDA:


@@ -265,6 +265,13 @@ class NumBytesMetricTests(TestCase):
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """400""")
def f(a, b):
a = a @ a
return torch.constant_pad_nd(torch.cat([a, b]), [2, 2], 0.5)
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """680""")
@patch.object(config, "split_cat_fx_passes", False)
@patch.object(
config,


@@ -1057,6 +1057,11 @@ class trace:
_save_config_ignore = [
# workaround: "Can't pickle <function ...>"
"trace.upload_tar",
"post_grad_custom_post_pass",
"post_grad_custom_pre_pass",
"joint_custom_pre_pass",
"joint_custom_post_pass",
"pre_grad_custom_pass",
]
_cache_config_ignore_prefix = [


@@ -318,6 +318,30 @@ def should_exclude_padding_time(match, arg_name):
if not fetch_fake_tensors(match, (arg_name,))[0].is_contiguous():
return False
# TODO - see issue https://github.com/pytorch/pytorch/issues/128889
# We would only be able to completely plan these out if we were only doing
# first-dimension padding; for non-first dimensions we would still need a copy
# because these outputs are fixed dense.
cannot_plan_output = [
aten.mm.default,
aten.convolution.default,
aten.convolution_backward.default,
aten.bmm.default,
aten.addmm.default,
aten._scaled_dot_product_flash_attention.default,
aten._scaled_dot_product_efficient_attention.default,
]
if node_def.target in cannot_plan_output:
return False
if (
node_def.target == aten.cat.default
and len(node_def.all_input_nodes)
> torch._inductor.config.max_pointwise_cat_inputs
):
return False
# optimistically assume we should be able to memory-plan away
# all non-input nodes
return node_def.op != "placeholder"
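Illustrative sketch (not part of this diff) of why the producer check above matters: if the matrix fed to an mm is itself produced by one of the fixed-dense ops in cannot_plan_output, its padding cannot be planned away. Tracing a small graph with make_fx shows the producer node a check like should_exclude_padding_time would inspect; on a recent build, matmul of 2-D tensors traces to aten.mm.default:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a, b, c):
    # the outer mm's first input is produced by another mm, a fixed-dense op
    return (a @ b) @ c

gm = make_fx(f)(*[torch.randn(8, 8) for _ in range(3)])
outer_mm = [n for n in gm.graph.nodes if n.target == torch.ops.aten.mm.default][-1]
producer = outer_mm.all_input_nodes[0]
print(producer.op, producer.target)  # call_function aten.mm.default -> padding time not excluded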
@@ -427,6 +451,7 @@ def should_pad_bench(
mat2_pad = mat2
is_bmm = op is torch.ops.aten.bmm
mat1_pre_padded = should_exclude_padding_time(match, "mat1")
fns = []
if mat1_pre_padded and (m_padded_length or k_padded_length):
@@ -664,24 +689,35 @@ def should_pad_mm(match: Match) -> bool:
def pad_mat1(mat1, *, m_padded_length, k_padded_length, is_bmm=False):
if m_padded_length == 0 and k_padded_length == 0:
return mat1
elif k_padded_length != 0 and m_padded_length != 0:
# dim order is reversed for constant_pad_nd, for every dim we specify right and left padding
pad_arg = [0, k_padded_length, 0, m_padded_length]
if is_bmm:
pad_arg.extend((0, 0))
return aten.constant_pad_nd(mat1, pad_arg)
elif m_padded_length != 0:
return pad_dim(mat1, m_padded_length, 0 if not is_bmm else 1)
else:
assert k_padded_length != 0
return pad_dim(mat1, k_padded_length, 1 if not is_bmm else 2)
def pad_mat2(mat2, *, k_padded_length, n_padded_length, is_bmm=False):
if k_padded_length == 0 and n_padded_length == 0:
return mat2
elif k_padded_length != 0 and n_padded_length != 0:
# dim order is reversed for constant_pad_nd, for every dim we specify right and left padding
pad_arg = [0, n_padded_length, 0, k_padded_length]
if is_bmm:
pad_arg.extend((0, 0))
return aten.constant_pad_nd(mat2, pad_arg)
elif k_padded_length != 0:
return pad_dim(mat2, k_padded_length, 0 if not is_bmm else 1)
else:
assert n_padded_length != 0
return pad_dim(mat2, n_padded_length, 1 if not is_bmm else 2)
def pad_mm(

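A quick reminder of the constant_pad_nd argument order referenced in the comments above (illustrative, not part of the diff): the pad list is consumed as (left, right) pairs starting from the last dimension, which is why [0, k_padded_length, 0, m_padded_length] pads K on dim 1 and M on dim 0 of a 2-D mat1:

import torch

mat1 = torch.randn(6, 8)                    # (M, K) = (6, 8)
m_padded_length, k_padded_length = 2, 4
# first pair applies to the last dim (K), second pair to dim 0 (M)
padded = torch.constant_pad_nd(mat1, [0, k_padded_length, 0, m_padded_length])
print(padded.shape)  # torch.Size([8, 12])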

@@ -1429,11 +1429,17 @@ def cat(inputs, dim=0):
MAX_COMPLEX_POINTWISE_CAT = 8
MAX_SIMPLE_OP_COUNT = 2
def additional_pointwise_ops(op: torch._ops.OpOverload):
return op in (aten.cat.default, aten.constant_pad_nd.default)
if len(inputs) <= MAX_COMPLEX_POINTWISE_CAT or (
(len(inputs) <= config.max_pointwise_cat_inputs)
and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs)
):
pointwise_uses = all(
is_pointwise_use(use, additional_pointwise_ops)
for use in V.current_node.users
)
# fuse in case we will be used in a pointwise node, and there are any inputs
# we can prevent materialization of.
fuse_pointwise_use = (

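This is the pattern the lowering change targets (the same shape as the new perf test above): the cat's only user is a constant_pad_nd, which additional_pointwise_ops now lets count as a pointwise use, so the cat can stay unmaterialized and fuse into the pad. A minimal repro, illustrative only:

import torch

@torch.compile
def f(a, b):
    a = a @ a
    # the cat feeds only constant_pad_nd, so it can take the pointwise-cat path
    return torch.constant_pad_nd(torch.cat([a, b]), [2, 2], 0.5)

out = f(torch.randn(10, 10), torch.randn(10, 10))
print(out.shape)  # torch.Size([20, 14])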

@@ -307,7 +307,15 @@ def is_view(op: torch._ops.OpOverload):
return any(a.alias_info is not None for a in op._schema.arguments)
def is_pointwise_use(
use, is_pointwise_fn: Optional[Callable[[torch._ops.OpOverload], bool]] = None
):
"""
Do all uses of this op have torch.Tag.pointwise, or return True from the optional `is_pointwise_fn`?
Uses in view ops are followed through to the view's own uses.
"""
if not use.op == "call_function":
return False
@@ -317,9 +325,11 @@ def is_pointwise_use(use):
return False
if use.target is operator.getitem or is_view(use.target):
return all(is_pointwise_use(u, is_pointwise_fn) for u in use.users)
return torch.Tag.pointwise in use.target.tags or (
is_pointwise_fn is not None and is_pointwise_fn(use.target)
)
def gen_gm_and_inputs(target, args, kwargs):

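For reference, a sketch of the check this helper performs, written against public FX tracing only rather than the inductor helper itself: walk a node's users and test whether each target carries torch.Tag.pointwise, with the optional callback covering extra ops such as aten.cat and aten.constant_pad_nd:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a, b):
    return torch.relu(torch.cat([a, b]))

gm = make_fx(f)(torch.randn(4, 4), torch.randn(4, 4))
cat_node = next(n for n in gm.graph.nodes if n.target == torch.ops.aten.cat.default)
for user in cat_node.users:
    # relu carries the pointwise tag, so a check like is_pointwise_use succeeds here
    print(user.target, torch.Tag.pointwise in user.target.tags)  # aten.relu.default True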

@@ -8,6 +8,7 @@ import io
import pickle
import tokenize
import unittest
import warnings
from types import FunctionType, ModuleType
from typing import Any, Dict, Optional, Set, Union
from typing_extensions import deprecated
@@ -178,6 +179,8 @@ class ConfigModule(ModuleType):
mod = self.__name__
for k, v in self._config.items():
if k in self._config.get("_save_config_ignore", ()):
if v != self._default[k]:
warnings.warn(f"Skipping serialization of {k} value {v}")
continue
if v == self._default[k]:
continue
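A standalone sketch of the serialization behavior added above (hypothetical names, not the ConfigModule API): ignored keys are still skipped, but a warning is now raised whenever a skipped key holds a non-default value, so the omission is no longer silent:

import warnings

def codegen_config(config, defaults, save_config_ignore):
    out = {}
    for k, v in config.items():
        if k in save_config_ignore:
            if v != defaults[k]:
                warnings.warn(f"Skipping serialization of {k} value {v}")
            continue
        if v == defaults[k]:
            continue
        out[k] = v
    return out

# warns about the ignored custom pass and serializes only the changed, non-ignored key
print(codegen_config(
    {"cpp.threads": 8, "post_grad_custom_post_pass": object()},
    {"cpp.threads": -1, "post_grad_custom_post_pass": None},
    {"post_grad_custom_post_pass"},
))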