Fix mm pad regression - more conservative estimation of plannable inputs (#128909)
- More conservative estimation of plannable inputs
- Consider constant_pad_nd as a pointwise node in concat lowering
- Use aten.cat instead of constant_pad_nd when padding just a single dimension, because it can be memory-planned away

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128909
Approved by: https://github.com/Chillee
Committed by: PyTorch MergeBot
Parent: 27ded03545
Commit: 16aaff7783
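For background, the pass this commit tunes (pad_mm) relies on the fact that zero-padding the reduction dimension of a matmul does not change its result, so operands can be padded to friendlier sizes and the padding later discarded. A minimal eager-mode illustration (the shapes and pad amount are made up for the example):

import torch

a = torch.randn(128, 250)
b = torch.randn(250, 128)

# constant_pad_nd takes [left, right] pairs, last dimension first:
# pad K (250 -> 256) on the last dim of `a` and the first dim of `b`
a_pad = torch.constant_pad_nd(a, [0, 6])        # (128, 256)
b_pad = torch.constant_pad_nd(b, [0, 0, 0, 6])  # (256, 128)

# the zero rows/columns contribute nothing to the product
torch.testing.assert_close(a_pad @ b_pad, a @ b)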
@@ -422,6 +422,28 @@ class PadMMTest(TestCase):
             repr(local_cache)
         )
 
+    @fresh_inductor_cache()
+    @inductor_config.patch(max_pointwise_cat_inputs=2)
+    def test_exclude_cat_padding(self):
+        @torch.compile()
+        def mm(inps, b):
+            return torch.cat(inps) @ b
+
+        inp = torch.rand([2046, 2046], device="cuda")
+        inp2 = torch.rand([2046, 2046], device="cuda")
+
+        inps = inp.chunk(3)
+        mm(inps, inp2)
+        FileCheck().check_count("exclude_pad:False", 2, exactly=True).run(
+            repr(get_pad_cache().get_local_cache())
+        )
+
+        inps = inp.chunk(2)
+        mm(inps, inp2)
+        FileCheck().check_count("exclude_pad:False", 3, exactly=True).run(
+            repr(get_pad_cache().get_local_cache())
+        )
+
 
 if __name__ == "__main__":
     if HAS_CUDA:
@@ -265,6 +265,13 @@ class NumBytesMetricTests(TestCase):
         inp = (T(10, 10), T(10, 10))
         self.assertExpectedInline(count_numel(f, *inp), """400""")
 
+        def f(a, b):
+            a = a @ a
+            return torch.constant_pad_nd(torch.cat([a, b]), [2, 2], 0.5)
+
+        inp = (T(10, 10), T(10, 10))
+        self.assertExpectedInline(count_numel(f, *inp), """680""")
+
     @patch.object(config, "split_cat_fx_passes", False)
     @patch.object(
         config,
@@ -1057,6 +1057,11 @@ class trace:
 _save_config_ignore = [
     # workaround: "Can't pickle <function ...>"
     "trace.upload_tar",
+    "post_grad_custom_post_pass",
+    "post_grad_custom_pre_pass",
+    "joint_custom_pre_pass",
+    "joint_custom_post_pass",
+    "pre_grad_custom_pass",
 ]
 
 _cache_config_ignore_prefix = [
@@ -318,6 +318,30 @@ def should_exclude_padding_time(match, arg_name):
     if not fetch_fake_tensors(match, (arg_name,))[0].is_contiguous():
         return False
 
+    # TODO - see issue https://github.com/pytorch/pytorch/issues/128889
+    # We would only be able to completely plan these out if we were only doing
+    # first-dimension padding; for non-first dimensions we would still need a copy
+    # because these outputs are fixed dense.
+    cannot_plan_output = [
+        aten.mm.default,
+        aten.convolution.default,
+        aten.convolution_backward.default,
+        aten.bmm.default,
+        aten.addmm.default,
+        aten._scaled_dot_product_flash_attention.default,
+        aten._scaled_dot_product_efficient_attention.default,
+    ]
+
+    if node_def.target in cannot_plan_output:
+        return False
+
+    if (
+        node_def.target == aten.cat.default
+        and len(node_def.all_input_nodes)
+        > torch._inductor.config.max_pointwise_cat_inputs
+    ):
+        return False
+
     # optimistically assume we should be able to memory plan away
     # all non-inputs
     return node_def.op != "placeholder"
@@ -427,6 +451,7 @@ def should_pad_bench(
         mat2_pad = mat2
 
         is_bmm = op is torch.ops.aten.bmm
 
         mat1_pre_padded = should_exclude_padding_time(match, "mat1")
         fns = []
         if mat1_pre_padded and (m_padded_length or k_padded_length):
@@ -664,24 +689,35 @@ def should_pad_mm(match: Match) -> bool:
 
 
 def pad_mat1(mat1, *, m_padded_length, k_padded_length, is_bmm=False):
-    if k_padded_length != 0 or m_padded_length != 0:
+    if m_padded_length == 0 and k_padded_length == 0:
+        return mat1
+    elif k_padded_length != 0 and m_padded_length != 0:
         # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding
         pad_arg = [0, k_padded_length, 0, m_padded_length]
         if is_bmm:
             pad_arg.extend((0, 0))
         return aten.constant_pad_nd(mat1, pad_arg)
-    return mat1
+    elif m_padded_length != 0:
+        return pad_dim(mat1, m_padded_length, 0 if not is_bmm else 1)
+    else:
+        assert k_padded_length != 0
+        return pad_dim(mat1, k_padded_length, 1 if not is_bmm else 2)
 
 
 def pad_mat2(mat2, *, k_padded_length, n_padded_length, is_bmm=False):
-    if k_padded_length != 0 or n_padded_length != 0:
+    if k_padded_length == 0 and n_padded_length == 0:
+        return mat2
+    elif k_padded_length != 0 and n_padded_length != 0:
         # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding
         pad_arg = [0, n_padded_length, 0, k_padded_length]
         if is_bmm:
             pad_arg.extend((0, 0))
         return aten.constant_pad_nd(mat2, pad_arg)
+    elif k_padded_length != 0:
+        return pad_dim(mat2, k_padded_length, 0 if not is_bmm else 1)
     else:
-        return mat2
+        assert n_padded_length != 0
+        return pad_dim(mat2, n_padded_length, 1 if not is_bmm else 2)
 
 
 def pad_mm(
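The single-dimension branches above call pad_dim, a helper defined earlier in pad_mm.py and not shown in this hunk. Roughly (a sketch of its behavior, not the verbatim upstream definition), it appends zeros along one dimension using torch.cat, which is the form the concat lowering and memory planner can fold away:

import torch

def pad_dim(x: torch.Tensor, padded_length: int, dim: int) -> torch.Tensor:
    # approximate behavior: right-pad `dim` with `padded_length` zeros via cat
    if padded_length == 0:
        return x
    pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :])
    return torch.cat([x, pad], dim=dim)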
@@ -1429,11 +1429,17 @@ def cat(inputs, dim=0):
     MAX_COMPLEX_POINTWISE_CAT = 8
     MAX_SIMPLE_OP_COUNT = 2
 
+    def additional_pointwise_ops(op: torch._ops.OpOverload):
+        return op in (aten.cat.default, aten.constant_pad_nd.default)
+
     if len(inputs) <= MAX_COMPLEX_POINTWISE_CAT or (
         (len(inputs) <= config.max_pointwise_cat_inputs)
         and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs)
     ):
-        pointwise_uses = all(is_pointwise_use(use) for use in V.current_node.users)
+        pointwise_uses = all(
+            is_pointwise_use(use, additional_pointwise_ops)
+            for use in V.current_node.users
+        )
         # fuse in case we will be used in a pointwise node, and there are any inputs we
         # can prevent materialization of.
         fuse_pointwise_use = (
@@ -307,7 +307,15 @@ def is_view(op: torch._ops.OpOverload):
     return any(a.alias_info is not None for a in op._schema.arguments)
 
 
-def is_pointwise_use(use):
+def is_pointwise_use(
+    use, is_pointwise_fn: Optional[Callable[[torch._ops.OpOverload], bool]] = None
+):
+    """
+    Do all uses of this op have torch.Tag.pointwise, or return True for the optional `is_pointwise_fn`?
+
+    Uses in view ops will follow the view's uses.
+    """
+
     if not use.op == "call_function":
         return False
 
@@ -317,9 +325,11 @@ def is_pointwise_use(use):
         return False
 
     if use.target is operator.getitem or is_view(use.target):
-        return all(is_pointwise_use(u) for u in use.users)
+        return all(is_pointwise_use(u, is_pointwise_fn) for u in use.users)
 
-    return torch.Tag.pointwise in use.target.tags
+    return torch.Tag.pointwise in use.target.tags or (
+        is_pointwise_fn is not None and is_pointwise_fn(use.target)
+    )
 
 
 def gen_gm_and_inputs(target, args, kwargs):
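A rough usage sketch of the extended predicate (the traced function and the extra predicate are illustrative; it assumes is_pointwise_use is importable from torch._inductor.utils, the file this hunk belongs to):

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._inductor.utils import is_pointwise_use

aten = torch.ops.aten

def f(x, y):
    z = torch.cat([x, y])
    return torch.cat([z, z])  # the only consumer of z is another cat

gm = make_fx(f)(torch.randn(2, 3), torch.randn(2, 3))
first_cat = next(n for n in gm.graph.nodes if n.target is aten.cat.default)

# aten.cat is not tagged torch.Tag.pointwise, so the plain check should fail ...
print(all(is_pointwise_use(u) for u in first_cat.users))         # expected: False
# ... while the optional predicate lets a caller opt extra ops in,
# mirroring additional_pointwise_ops in the cat lowering above
extra = lambda op: op in (aten.cat.default, aten.constant_pad_nd.default)
print(all(is_pointwise_use(u, extra) for u in first_cat.users))  # expected: True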
@@ -8,6 +8,7 @@ import io
 import pickle
 import tokenize
 import unittest
+import warnings
 from types import FunctionType, ModuleType
 from typing import Any, Dict, Optional, Set, Union
 from typing_extensions import deprecated
@@ -178,6 +179,8 @@ class ConfigModule(ModuleType):
         mod = self.__name__
         for k, v in self._config.items():
             if k in self._config.get("_save_config_ignore", ()):
+                if v != self._default[k]:
+                    warnings.warn(f"Skipping serialization of {k} value {v}")
                 continue
             if v == self._default[k]:
                 continue
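The two added lines make config serialization warn when a key in _save_config_ignore holds a non-default value instead of dropping it silently. In isolation the control flow is roughly the following (a standalone sketch with made-up keys and values, not the real ConfigModule internals):

import warnings

config   = {"trace.upload_tar": print, "verbose": True}   # current values (illustrative)
defaults = {"trace.upload_tar": None,  "verbose": False}  # defaults (illustrative)
save_config_ignore = {"trace.upload_tar"}                  # keys never serialized

lines = []
for k, v in config.items():
    if k in save_config_ignore:
        if v != defaults[k]:
            # new behavior: tell the user their non-default setting will not round-trip
            warnings.warn(f"Skipping serialization of {k} value {v}")
        continue
    if v == defaults[k]:
        continue
    lines.append(f"{k} = {v!r}")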