mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Replace manual cache in _python_dispatch.get_alias_info with functools.cache (#161286)
In addition to being more code, the manual cache was doing an extra dictionary lookup on each cache hit. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161286 Approved by: https://github.com/wconstab
This commit is contained in:
committed by
PyTorch MergeBot
parent
9de9d25f8d
commit
80bf883d21
@ -1,10 +1,11 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
import functools
|
||||
import warnings
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, overload, Protocol, Union
|
||||
from typing import Optional, overload, Protocol, Union
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
import torch
|
||||
@ -527,15 +528,10 @@ class SchemaInfo:
|
||||
outs: list[AliasInfo]
|
||||
|
||||
|
||||
# Can't import torch._ops.OpOverload due to circular reference
|
||||
parsed_schema_map: dict[Any, SchemaInfo] = {}
|
||||
|
||||
|
||||
# Given an OpOverload, returns schema information on it.
|
||||
# This is cached for efficiency, since it can involve running torchgen
|
||||
@functools.cache
|
||||
def get_alias_info(func) -> SchemaInfo:
|
||||
if func in parsed_schema_map:
|
||||
return parsed_schema_map[func]
|
||||
# For ATen ops: use torchgen (since torchscript parser doesn't handle alias annotations
|
||||
# properly for some ops that output tensorlists)
|
||||
if func.namespace == "aten":
|
||||
@ -598,7 +594,6 @@ def get_alias_info(func) -> SchemaInfo:
|
||||
for a in func._schema.returns
|
||||
]
|
||||
schema_info = SchemaInfo(args=arg_schemas, outs=out_schemas)
|
||||
parsed_schema_map[func] = schema_info
|
||||
return schema_info
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user