mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	When using dmypy, this setting is enabled and cannot be turned off. Force it for regular mypy too. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/118467 Approved by: https://github.com/Skylion007 ghstack dependencies: #118414, #118418, #118432
		
			
				
	
	
		
			92 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			92 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from typing import Optional
 | 
						|
import torch
 | 
						|
from torch.overrides import TorchFunctionMode
 | 
						|
from torch.utils._contextlib import context_decorator
 | 
						|
import functools
 | 
						|
 | 
						|
# Module-level record of the device installed by DeviceContext: set to the
# context's device in DeviceContext.__enter__ and put back to the previously
# stored value in DeviceContext.__exit__. Starts out as None.
CURRENT_DEVICE: Optional[torch.device] = None
 | 
						|
 | 
						|
@functools.lru_cache(1)
 | 
						|
def _device_constructors():
 | 
						|
    return {
 | 
						|
        # standard ones
 | 
						|
        torch.empty,
 | 
						|
        torch.empty_permuted,
 | 
						|
        torch.empty_strided,
 | 
						|
        torch.empty_quantized,
 | 
						|
        torch.ones,
 | 
						|
        torch.arange,
 | 
						|
        torch.bartlett_window,
 | 
						|
        torch.blackman_window,
 | 
						|
        torch.eye,
 | 
						|
        torch.fft.fftfreq,
 | 
						|
        torch.fft.rfftfreq,
 | 
						|
        torch.full,
 | 
						|
        torch.fill,
 | 
						|
        torch.hamming_window,
 | 
						|
        torch.hann_window,
 | 
						|
        torch.kaiser_window,
 | 
						|
        torch.linspace,
 | 
						|
        torch.logspace,
 | 
						|
        torch.nested.nested_tensor,
 | 
						|
        # This function doesn't actually take a device argument
 | 
						|
        # torch.normal,
 | 
						|
        torch.ones,
 | 
						|
        torch.rand,
 | 
						|
        torch.randn,
 | 
						|
        torch.randint,
 | 
						|
        torch.randperm,
 | 
						|
        torch.range,
 | 
						|
        torch.sparse_coo_tensor,
 | 
						|
        torch.sparse_compressed_tensor,
 | 
						|
        torch.sparse_csr_tensor,
 | 
						|
        torch.sparse_csc_tensor,
 | 
						|
        torch.sparse_bsr_tensor,
 | 
						|
        torch.sparse_bsc_tensor,
 | 
						|
        torch.tril_indices,
 | 
						|
        torch.triu_indices,
 | 
						|
        torch.vander,
 | 
						|
        torch.zeros,
 | 
						|
        torch.asarray,
 | 
						|
        # weird ones
 | 
						|
        torch.tensor,
 | 
						|
        torch.as_tensor,
 | 
						|
        torch.scalar_tensor,
 | 
						|
        torch.asarray,
 | 
						|
    }
 | 
						|
 | 
						|
# NB: This is directly called from C++ in torch/csrc/Device.cpp
 | 
						|
class DeviceContext(TorchFunctionMode):
    """Torch-function mode that supplies a default ``device`` to constructors.

    While the mode is active, any intercepted call to a function listed in
    ``_device_constructors()`` that has no ``device`` keyword (or has
    ``device=None``) is invoked with ``device=self.device`` instead.
    """

    def __init__(self, device):
        """Store the target device.

        Args:
            device: anything accepted by ``torch.device(...)``.
        """
        self.device = torch.device(device)

    def __enter__(self):
        """Record this context's device in ``CURRENT_DEVICE`` and enter the mode."""
        global CURRENT_DEVICE
        # Remember whatever was installed before so __exit__ can put it back.
        self.old_device = CURRENT_DEVICE
        CURRENT_DEVICE = self.device
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the previous ``CURRENT_DEVICE`` and exit the mode."""
        global CURRENT_DEVICE
        CURRENT_DEVICE = self.old_device
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        """Call ``func``, injecting ``device=self.device`` when appropriate.

        The device is injected only when ``func`` is in
        ``_device_constructors()`` and the caller did not pass a non-None
        ``device`` keyword; an explicit device always wins.
        """
        # None (or an empty mapping) becomes a fresh dict we can write into.
        kwargs = kwargs or {}
        if func in _device_constructors() and kwargs.get('device') is None:
            kwargs['device'] = self.device
        return func(*args, **kwargs)
 | 
						|
 | 
						|
# NB: This is directly called from C++ in torch/csrc/Device.cpp
 | 
						|
def device_decorator(device, func):
    """Wrap ``func`` via ``context_decorator`` using ``device`` as the context.

    Args:
        device: the object handed back by the context factory; callers in this
            file pass a ``torch.device``.
        func: the callable to wrap.

    Returns:
        Whatever ``context_decorator`` returns for ``func``.
    """
    # context_decorator is given a zero-argument factory rather than the
    # device itself; the factory simply returns the same device every time.
    return context_decorator(lambda: device, func)
 | 
						|
 | 
						|
def set_device(device):
    """
    Set the default device inside of the wrapped function by decorating it with this function.

    If you would like to use this as a context manager, use device as a
    context manager directly, e.g., ``with torch.device(device)``.

    Args:
        device: anything accepted by ``torch.device(...)``. The conversion
            happens when the returned decorator is applied to a function, not
            when ``set_device`` itself is called.

    Returns:
        A decorator taking a single function ``func`` and returning
        ``device_decorator(torch.device(device), func)``.
    """
    return lambda func: device_decorator(torch.device(device), func)
 |