[torchgen] Improve schema parsing with regex for numeric ranges (#140210)

Replaces the hardcoded string replacement for numeric ranges with a more robust regex pattern that handles any combination of positive and negative numbers in default value ranges.
Fixes #135470

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140210
Approved by: https://github.com/ezyang
This commit is contained in:
Aki
2024-11-14 23:28:25 +00:00
committed by PyTorch MergeBot
parent e90888a93d
commit 9c818c880f
2 changed files with 11 additions and 2 deletions

View File

@ -549,8 +549,8 @@ def get_alias_info(func) -> SchemaInfo:
# which torchgen chokes on.
torchgen_schema_str = re.sub(r"=\[[0, ]+\]", "=0", torchgen_schema_str)
torchgen_schema_str = re.sub(r"=\[[1, ]+\]", "=1", torchgen_schema_str)
# for aten::rot90
torchgen_schema_str = torchgen_schema_str.replace("=[0, 1]", "=[0,1]")
# for aten::rot90 / aten:fft_*
torchgen_schema_str = re.sub(r"=\[(-?[0-9]+), (-?[0-9]+)\]", r"=[\1,\2]", torchgen_schema_str)
torchgen_schema = torchgen.model.FunctionSchema.parse(torchgen_schema_str)
arg_schemas = [
AliasInfo(