Mirror of
https://github.com/pytorch/pytorch.git
Synced 2025-10-21 05:34:18 +08:00
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145176 Approved by: https://github.com/bobrenjc93
41 lines
889 B
Python
41 lines
889 B
Python
from enum import Enum
|
|
from typing import TypeVar
|
|
from typing_extensions import assert_never, assert_type, ParamSpec
|
|
|
|
import pytest
|
|
|
|
from torch import jit, nn, ScriptDict, ScriptFunction, ScriptList
|
|
|
|
|
|
P = ParamSpec("P")
|
|
R = TypeVar("R", covariant=True)
|
|
|
|
|
|
class Color(Enum):
    """Small integer-valued enum used as an input to the jit.script type checks."""

    RED = 1
    GREEN = 2
    BLUE = 3


# Each check below pins the static return type a type checker should infer for
# torch.jit.script on one kind of input. At runtime assert_type simply returns
# its argument, so every jit.script call here is also executed for real.

# Script Enum: scripting an Enum class is typed as returning the class itself.
assert_type(jit.script(Color), type[Color])

# ScriptDict: scripting a dict literal.
assert_type(jit.script({1: 1}), ScriptDict)

# ScriptList: scripting a list literal.
assert_type(jit.script([0]), ScriptList)

# ScriptModule: scripting an nn.Module *instance*.
scripted_module = jit.script(nn.Linear(2, 2))
assert_type(scripted_module, jit.RecursiveScriptModule)

# ScriptFunction
# NOTE: assert_type can't be used here because of parameter names (presumably
#   the inferred callable type carries relu's parameter names — TODO confirm),
#   so the result is checked by assignment to an annotated variable instead.
# NOTE: generic usage (presumably ScriptFunction[P, R]) is only possible with
#   Python 3.9 — TODO(review): confirm the exact version constraint.
relu: ScriptFunction = jit.script(nn.functional.relu)

# An nn.Module *class* (as opposed to an instance) can't be scripted: the call
# is expected to raise RuntimeError at runtime, and assert_never requires the
# type checker to infer the call's result as Never.
with pytest.raises(RuntimeError):
    assert_never(jit.script(nn.Linear))