Replace manual cache in _python_dispatch.get_alias_info with functools.cache (#161286)

In addition to being more code, the manual cache was doing an extra dictionary lookup on each cache hit.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161286
Approved by: https://github.com/wconstab
This commit is contained in:
Scott Wolchok
2025-08-25 13:02:21 -07:00
committed by PyTorch MergeBot
parent 9de9d25f8d
commit 80bf883d21

View File

@ -1,10 +1,11 @@
# mypy: allow-untyped-defs
import contextlib
import functools
import warnings
from collections import deque
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Optional, overload, Protocol, Union
from typing import Optional, overload, Protocol, Union
from typing_extensions import TypeIs
import torch
@ -527,15 +528,10 @@ class SchemaInfo:
outs: list[AliasInfo]
# Can't import torch._ops.OpOverload due to circular reference
parsed_schema_map: dict[Any, SchemaInfo] = {}
# Given an OpOverload, returns schema information on it.
# This is cached for efficiency, since it can involve running torchgen
@functools.cache
def get_alias_info(func) -> SchemaInfo:
if func in parsed_schema_map:
return parsed_schema_map[func]
# For ATen ops: use torchgen (since torchscript parser doesn't handle alias annotations
# properly for some ops that output tensorlists)
if func.namespace == "aten":
@ -598,7 +594,6 @@ def get_alias_info(func) -> SchemaInfo:
for a in func._schema.returns
]
schema_info = SchemaInfo(args=arg_schemas, outs=out_schemas)
parsed_schema_map[func] = schema_info
return schema_info