mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 06:24:59 +08:00
Type coverage for torch/_dynamo/variables/sdpa.py
This commit is contained in:
@ -1,7 +1,9 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
from inspect import getattr_static
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Any, Iterable, TYPE_CHECKING, TypeGuard
|
||||
|
||||
from torch._guards import Source
|
||||
from torch.backends.cuda import SDPAParams
|
||||
from torch.fx.proxy import Proxy
|
||||
|
||||
from ..bytecode_transformation import create_call_function
|
||||
from ..exc import Unsupported
|
||||
@ -29,9 +31,9 @@ class SDPAParamsVariable(VariableTracker):
|
||||
This is a read-only container."""
|
||||
|
||||
@staticmethod
|
||||
def create(tx: "InstructionTranslator", value, source):
|
||||
from torch.backends.cuda import SDPAParams
|
||||
|
||||
def create(
|
||||
tx: "InstructionTranslator", value: Any, source: Source
|
||||
) -> VariableTracker:
|
||||
from .torch import TorchInGraphFunctionVariable
|
||||
|
||||
params = [
|
||||
@ -40,12 +42,14 @@ class SDPAParamsVariable(VariableTracker):
|
||||
]
|
||||
return TorchInGraphFunctionVariable(SDPAParams).call_function(tx, params, {})
|
||||
|
||||
def __init__(self, proxy, param_vars, **kwargs) -> None:
|
||||
def __init__(
|
||||
self, proxy: Proxy, param_vars: Iterable[VariableTracker], **kwargs: Any
|
||||
) -> None:
|
||||
self.proxy = proxy
|
||||
self.param_vars = param_vars
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
assert self.source is None
|
||||
assert self.param_vars is not None
|
||||
codegen.add_push_null(
|
||||
@ -54,7 +58,7 @@ class SDPAParamsVariable(VariableTracker):
|
||||
codegen.foreach(self.param_vars)
|
||||
codegen.extend_output(create_call_function(len(self.param_vars), False))
|
||||
|
||||
def as_proxy(self):
|
||||
def as_proxy(self) -> Proxy:
|
||||
return self.proxy
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
||||
@ -80,7 +84,5 @@ class SDPAParamsVariable(VariableTracker):
|
||||
return wrap_fx_proxy(tx=tx, proxy=proxy)
|
||||
|
||||
@staticmethod
|
||||
def is_sdpa_params(value):
|
||||
from torch.backends.cuda import SDPAParams
|
||||
|
||||
def is_sdpa_params(value: Any) -> TypeGuard["SDPAParams"]:
|
||||
return value is SDPAParams
|
||||
|
||||
Reference in New Issue
Block a user