Fix mm pad regresion - more conservative estimation of plannable inputs (#128909)

- More conservative estimation of plannable inputs
- Consider constant_pad_nd as pointwise node in concat lowering
- Use aten.cat instead of constant pad ndwhen padding just a single dimension because it can be memory-planned away

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128909
Approved by: https://github.com/Chillee
This commit is contained in:
eellison
2024-07-16 12:51:23 -07:00
committed by PyTorch MergeBot
parent 27ded03545
commit 16aaff7783
7 changed files with 97 additions and 8 deletions

View File

@ -8,6 +8,7 @@ import io
import pickle
import tokenize
import unittest
import warnings
from types import FunctionType, ModuleType
from typing import Any, Dict, Optional, Set, Union
from typing_extensions import deprecated
@ -178,6 +179,8 @@ class ConfigModule(ModuleType):
mod = self.__name__
for k, v in self._config.items():
if k in self._config.get("_save_config_ignore", ()):
if v != self._default[k]:
warnings.warn(f"Skipping serialization of {k} value {v}")
continue
if v == self._default[k]:
continue