From 6b1211df294e57d59c1e1717b1fedc671ec5bd5a Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Fri, 6 Jun 2025 13:28:05 +0000 Subject: [PATCH] [BE]: Backport runtime_checkable perf improvements/behavior from 3.12 (#155130) Backports some behavior changes and performance improvements with runtime_checkable in 3.12 to older versions of Python. Should be free performance improvement on typing checking protocols since everything works on Python 3.12. The difference between the two versions of runtime_checkable is [these lines](https://github.com/python/typing_extensions/blob/40e22ebb2ca5747eaa9405b152c43a294ac3af37/src/typing_extensions.py#L800-L823). Pull Request resolved: https://github.com/pytorch/pytorch/pull/155130 Approved by: https://github.com/rec, https://github.com/aorenste --- torch/_C/__init__.pyi.in | 4 +--- torch/distributed/_checkpointable.py | 2 +- torch/distributed/checkpoint/staging.py | 4 ++-- torch/distributed/checkpoint/stateful.py | 4 ++-- torch/onnx/_internal/fx/type_utils.py | 3 ++- torch/onnx/_internal/io_adapter.py | 3 ++- torch/utils/benchmark/utils/_stubs.py | 3 ++- 7 files changed, 12 insertions(+), 11 deletions(-) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 344f18717a91..6a4650fd50d1 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -16,12 +16,10 @@ from typing import ( Literal, NamedTuple, overload, - Protocol, - runtime_checkable, SupportsIndex, TypeVar, ) -from typing_extensions import ParamSpec, Self, TypeAlias +from typing_extensions import ParamSpec, Protocol, runtime_checkable, Self, TypeAlias import numpy diff --git a/torch/distributed/_checkpointable.py b/torch/distributed/_checkpointable.py index bc0a288f1291..0594c20337b3 100644 --- a/torch/distributed/_checkpointable.py +++ b/torch/distributed/_checkpointable.py @@ -1,5 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Protocol, runtime_checkable +from typing_extensions import Protocol, runtime_checkable import torch diff --git a/torch/distributed/checkpoint/staging.py b/torch/distributed/checkpoint/staging.py index 9f3233ad06d5..05f38e34ae33 100644 --- a/torch/distributed/checkpoint/staging.py +++ b/torch/distributed/checkpoint/staging.py @@ -1,5 +1,5 @@ -from typing import Optional, runtime_checkable -from typing_extensions import Protocol +from typing import Optional +from typing_extensions import Protocol, runtime_checkable from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE diff --git a/torch/distributed/checkpoint/stateful.py b/torch/distributed/checkpoint/stateful.py index 95cbb1873d64..15e227d92fb5 100644 --- a/torch/distributed/checkpoint/stateful.py +++ b/torch/distributed/checkpoint/stateful.py @@ -1,5 +1,5 @@ -from typing import Any, runtime_checkable, TypeVar -from typing_extensions import Protocol +from typing import Any, TypeVar +from typing_extensions import Protocol, runtime_checkable __all__ = ["Stateful", "StatefulT"] diff --git a/torch/onnx/_internal/fx/type_utils.py b/torch/onnx/_internal/fx/type_utils.py index 4a6e508c1a38..968f69328011 100644 --- a/torch/onnx/_internal/fx/type_utils.py +++ b/torch/onnx/_internal/fx/type_utils.py @@ -4,7 +4,8 @@ from __future__ import annotations from collections.abc import Mapping, Sequence -from typing import Any, Optional, Protocol, runtime_checkable, TYPE_CHECKING, Union +from typing import Any, Optional, TYPE_CHECKING, Union +from typing_extensions import Protocol, runtime_checkable import onnx diff --git a/torch/onnx/_internal/io_adapter.py b/torch/onnx/_internal/io_adapter.py index 1ed96e62ffad..6c414e8d54e7 100644 --- a/torch/onnx/_internal/io_adapter.py +++ b/torch/onnx/_internal/io_adapter.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs from __future__ import annotations -from typing import Any, Callable, Protocol, runtime_checkable, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING +from typing_extensions import Protocol, runtime_checkable import torch import torch.export as torch_export diff --git a/torch/utils/benchmark/utils/_stubs.py b/torch/utils/benchmark/utils/_stubs.py index 60861d1f412a..068e62ec87a3 100644 --- a/torch/utils/benchmark/utils/_stubs.py +++ b/torch/utils/benchmark/utils/_stubs.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Protocol, runtime_checkable +from typing import Any, Callable +from typing_extensions import Protocol, runtime_checkable class TimerClass(Protocol):