mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283 Test plan: dmypy restart && python3 scripts/lintrunner.py -a pyrefly check --- step 1: uncomment lines in the `pyrefly.toml` file. Before: https://gist.github.com/maggiemoss/911b4d0bc88bf8cf3ab91f67184e9d46 After: ``` INFO Checking project configured at `/Users/maggiemoss/python_projects/pytorch/pyrefly.toml` INFO 0 errors (1,152 ignored) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/164513 Approved by: https://github.com/oulgen
		
			
				
	
	
		
			70 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			70 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# mypy: allow-untyped-defs
 | 
						|
import torch.cuda
 | 
						|
 | 
						|
 | 
						|
try:
 | 
						|
    from torch._C import _cudnn
 | 
						|
except ImportError:
 | 
						|
    # Uses of all the functions below should be guarded by torch.backends.cudnn.is_available(),
 | 
						|
    # so it's safe to not emit any checks here.
 | 
						|
    _cudnn = None  # type: ignore[assignment]
 | 
						|
 | 
						|
 | 
						|
def get_cudnn_mode(mode):
    """Map an RNN mode name to the integer value of cuDNN's ``RNNMode`` member.

    Args:
        mode: one of ``"RNN_RELU"``, ``"RNN_TANH"``, ``"LSTM"`` or ``"GRU"``.
            Matching uses ``==``, so it is case-sensitive.

    Returns:
        ``int(_cudnn.RNNMode.<member>)`` for the matching member.

    Raises:
        Exception: if ``mode`` matches none of the names above.
    """
    # Candidates are checked in a fixed order. The enum member is looked up
    # only after a match, so an unknown mode never touches ``_cudnn`` (which
    # is None if its import failed).
    known_modes = (
        ("RNN_RELU", "rnn_relu"),
        ("RNN_TANH", "rnn_tanh"),
        ("LSTM", "lstm"),
        ("GRU", "gru"),
    )
    for mode_name, member_name in known_modes:
        if mode == mode_name:
            # pyrefly: ignore  # missing-attribute
            return int(getattr(_cudnn.RNNMode, member_name))
    raise Exception(f"Unknown mode: {mode}")  # noqa: TRY002
 | 
						|
 | 
						|
 | 
						|
# NB: We don't actually need this class anymore (in fact, we could serialize the
 | 
						|
# dropout state for even better reproducibility), but it is kept for backwards
 | 
						|
# compatibility for old models.
 | 
						|
class Unserializable:
    """Wrap an object so that pickling drops it.

    Pickling stores only a fixed placeholder string. After unpickling,
    ``inner`` is ``None``.
    """

    def __init__(self, inner):
        # The wrapped payload; reset to None when restored from a pickle.
        self.inner = inner

    def get(self):
        """Return the wrapped object, or ``None`` if this was unpickled."""
        payload = self.inner
        return payload

    def __getstate__(self):
        # A truthy placeholder is used on purpose rather than {}: a falsy
        # state may cause __setstate__ not to be called on unpickling (this
        # was the case on Python 2).
        placeholder = "<unserializable>"
        return placeholder

    def __setstate__(self, state):
        # The pickled state is only the placeholder; the payload is gone.
        self.inner = None
 | 
						|
 | 
						|
 | 
						|
def init_dropout_state(dropout, train, dropout_seed, dropout_state):
    """Return the cuDNN dropout state for the current CUDA device.

    ``dropout_state`` is a dict-like cache. Entries are keyed by
    ``"desc_<current CUDA device index>"`` and hold ``Unserializable``
    wrappers. An entry is (re)built when it is missing or wraps ``None``.
    When the effective dropout probability is 0, the new wrapper holds
    ``None`` and ``torch._cudnn_init_dropout_state`` is not called.

    Args:
        dropout: dropout probability; only used when ``train`` is truthy.
        train: training flag; also forwarded to the state initializer.
        dropout_seed: seed forwarded to ``torch._cudnn_init_dropout_state``.
        dropout_state: cache mapping that is read and updated in place.

    Returns:
        The value wrapped by the cached entry for the current device. This is
        ``None`` if the entry was (re)built while the effective dropout
        probability was 0.
    """
    device_key = f"desc_{torch.cuda.current_device()!s}"
    effective_p = dropout if train else 0

    # Rebuild when there is no entry yet, or when the stored wrapper is empty
    # (e.g. it was created with p == 0 or was restored from a pickle).
    needs_init = (device_key not in dropout_state) or (
        dropout_state[device_key].get() is None
    )
    if needs_init:
        if effective_p == 0:
            fresh_state = None
        else:
            fresh_state = torch._cudnn_init_dropout_state(  # type: ignore[call-arg]
                effective_p,
                train,
                dropout_seed,
                # pyrefly: ignore  # unexpected-keyword
                self_ty=torch.uint8,
                device=torch.device("cuda"),
            )
        dropout_state[device_key] = Unserializable(fresh_state)

    return dropout_state[device_key].get()
 |