mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[torchgen] Improve schema parsing with regex for numeric ranges (#140210)
Replaces the hardcoded string replacement for numeric ranges with a more robust regex pattern that handles any combination of positive and negative numbers in default value ranges. Fixes #135470 Pull Request resolved: https://github.com/pytorch/pytorch/pull/140210 Approved by: https://github.com/ezyang
This commit is contained in:
@ -2740,6 +2740,15 @@ class TestWrapperSubclassAliasing(TestCase):
|
||||
kwargs = {"out": torch.empty(4)}
|
||||
self._test_wrapper_subclass_aliasing(torch.ops.aten.add.out, args, kwargs)
|
||||
|
||||
def test_wrapper_subclass_aliasing_fft_fft2(self, device):
|
||||
args = (torch.randn(4, 4),)
|
||||
kwargs = {}
|
||||
# fft_fft2 has a default arg 'int[1] dim=[-2,-1]',
|
||||
# Make sure that _return_and_correct_aliasing can handle this case
|
||||
# (I'm using inference_mode to make sure fft_fft2 doesn't decompose and goes to torch_dispatch)
|
||||
with torch.inference_mode():
|
||||
self._test_wrapper_subclass_aliasing(torch.ops.aten.fft_fft2, args, kwargs)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestWrapperSubclassAliasing, globals())
|
||||
|
||||
|
@ -549,8 +549,8 @@ def get_alias_info(func) -> SchemaInfo:
|
||||
# which torchgen chokes on.
|
||||
torchgen_schema_str = re.sub(r"=\[[0, ]+\]", "=0", torchgen_schema_str)
|
||||
torchgen_schema_str = re.sub(r"=\[[1, ]+\]", "=1", torchgen_schema_str)
|
||||
# for aten::rot90
|
||||
torchgen_schema_str = torchgen_schema_str.replace("=[0, 1]", "=[0,1]")
|
||||
# for aten::rot90 / aten:fft_*
|
||||
torchgen_schema_str = re.sub(r"=\[(-?[0-9]+), (-?[0-9]+)\]", r"=[\1,\2]", torchgen_schema_str)
|
||||
torchgen_schema = torchgen.model.FunctionSchema.parse(torchgen_schema_str)
|
||||
arg_schemas = [
|
||||
AliasInfo(
|
||||
|
Reference in New Issue
Block a user