Files
pytorch/torch/backends/cudnn/rnn.py
Maggie Moss 1051c1de5c Add pyrefly suppressions 2/n (#164513)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

---
step 1: uncomment lines in the `pyrefly.toml` file
before: https://gist.github.com/maggiemoss/911b4d0bc88bf8cf3ab91f67184e9d46

after:
```
 INFO Checking project configured at `/Users/maggiemoss/python_projects/pytorch/pyrefly.toml`
 INFO 0 errors (1,152 ignored)
 ```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164513
Approved by: https://github.com/oulgen
2025-10-03 02:46:13 +00:00

70 lines
2.3 KiB
Python

# mypy: allow-untyped-defs
import torch.cuda
try:
from torch._C import _cudnn
except ImportError:
# Uses of all the functions below should be guarded by torch.backends.cudnn.is_available(),
# so it's safe to not emit any checks here.
_cudnn = None # type: ignore[assignment]
def get_cudnn_mode(mode):
    """Return the integer value of the ``_cudnn.RNNMode`` member for *mode*.

    Args:
        mode: One of ``"RNN_RELU"``, ``"RNN_TANH"``, ``"LSTM"`` or ``"GRU"``
            (compared with ``==``, so the match is case-sensitive).

    Returns:
        ``int(...)`` of the matching ``_cudnn.RNNMode`` member.

    Raises:
        Exception: If *mode* is not one of the four recognised names.
    """
    # (mode name, RNNMode member name), scanned in the order the modes are
    # checked. `_cudnn` is only touched once a name matches, so an unknown
    # mode is reported even when the cuDNN bindings failed to import.
    known_modes = (
        ("RNN_RELU", "rnn_relu"),
        ("RNN_TANH", "rnn_tanh"),
        ("LSTM", "lstm"),
        ("GRU", "gru"),
    )
    for mode_name, member_name in known_modes:
        if mode == mode_name:
            # pyrefly: ignore # missing-attribute
            return int(getattr(_cudnn.RNNMode, member_name))
    raise Exception(f"Unknown mode: {mode}")  # noqa: TRY002
# NB: We don't actually need this class anymore (in fact, we could serialize the
# dropout state for even better reproducibility), but it is kept for backwards
# compatibility for old models.
class Unserializable:
    """Holder for a value that is deliberately left out of pickled state.

    Pickling stores a fixed placeholder string instead of the wrapped value,
    and unpickling leaves the wrapped value as ``None``.
    """

    def __init__(self, inner):
        """Wrap *inner* so it can be retrieved later via :meth:`get`."""
        self.inner = inner

    def __getstate__(self):
        """Return a placeholder in place of the wrapped value."""
        # The placeholder must be truthy: a falsy state such as {} would make
        # python2 skip the __setstate__ call on load.
        return "<unserializable>"

    def __setstate__(self, state):
        """Ignore *state*; the wrapped value is reset to ``None``."""
        self.inner = None

    def get(self):
        """Return the wrapped value (``None`` after unpickling)."""
        return self.inner
def init_dropout_state(dropout, train, dropout_seed, dropout_state):
    """Fetch the dropout state for the current CUDA device, creating it if needed.

    The state is stored in *dropout_state* under the key
    ``"desc_<current device index>"``, wrapped in ``Unserializable``.

    Args:
        dropout: Dropout probability; only used when *train* is truthy.
        train: When falsy the effective dropout probability is 0.
        dropout_seed: Seed forwarded to ``torch._cudnn_init_dropout_state``.
        dropout_state: Mapping from descriptor name to ``Unserializable``
            wrappers (presumably a dict owned by the caller); it is updated
            in place when an entry is missing or wraps ``None``.

    Returns:
        The wrapped value for the current device: ``None`` when the effective
        dropout probability is 0 and no state existed, otherwise whatever
        ``torch._cudnn_init_dropout_state`` returned (or a previously stored
        value).
    """
    key = "desc_" + str(torch.cuda.current_device())
    effective_p = dropout if train else 0

    # Re-initialise when there is no entry yet, or when the existing entry
    # wraps None.
    needs_init = key not in dropout_state or dropout_state[key].get() is None
    if needs_init:
        if effective_p == 0:
            # No dropout requested: store an empty wrapper rather than
            # creating a cuDNN dropout state.
            new_state = None
        else:
            new_state = torch._cudnn_init_dropout_state(  # type: ignore[call-arg]
                effective_p,
                train,
                dropout_seed,
                # pyrefly: ignore # unexpected-keyword
                self_ty=torch.uint8,
                device=torch.device("cuda"),
            )
        dropout_state[key] = Unserializable(new_state)

    return dropout_state[key].get()