[torchgen] Improve schema parsing with regex for numeric ranges (#140210)

Replaces the hardcoded string replacement for numeric ranges with a more robust regex pattern that handles any combination of positive and negative numbers in default value ranges.
Fixes #135470

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140210
Approved by: https://github.com/ezyang
This commit is contained in:
Aki
2024-11-14 23:28:25 +00:00
committed by PyTorch MergeBot
parent e90888a93d
commit 9c818c880f
2 changed files with 11 additions and 2 deletions

View File

@ -2740,6 +2740,15 @@ class TestWrapperSubclassAliasing(TestCase):
kwargs = {"out": torch.empty(4)}
self._test_wrapper_subclass_aliasing(torch.ops.aten.add.out, args, kwargs)
def test_wrapper_subclass_aliasing_fft_fft2(self, device):
args = (torch.randn(4, 4),)
kwargs = {}
# fft_fft2 has a default arg 'int[1] dim=[-2,-1]',
# Make sure that _return_and_correct_aliasing can handle this case
# (I'm using inference_mode to make sure fft_fft2 doesn't decompose and goes to torch_dispatch)
with torch.inference_mode():
self._test_wrapper_subclass_aliasing(torch.ops.aten.fft_fft2, args, kwargs)
instantiate_device_type_tests(TestWrapperSubclassAliasing, globals())

View File

@ -549,8 +549,8 @@ def get_alias_info(func) -> SchemaInfo:
# which torchgen chokes on.
torchgen_schema_str = re.sub(r"=\[[0, ]+\]", "=0", torchgen_schema_str)
torchgen_schema_str = re.sub(r"=\[[1, ]+\]", "=1", torchgen_schema_str)
# for aten::rot90
torchgen_schema_str = torchgen_schema_str.replace("=[0, 1]", "=[0,1]")
# for aten::rot90 / aten:fft_*
torchgen_schema_str = re.sub(r"=\[(-?[0-9]+), (-?[0-9]+)\]", r"=[\1,\2]", torchgen_schema_str)
torchgen_schema = torchgen.model.FunctionSchema.parse(torchgen_schema_str)
arg_schemas = [
AliasInfo(