import copyreg
import functools
import logging
import sys
import traceback
import warnings
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Generic, List, Optional

from typing_extensions import ParamSpec

import torch


def _type(self, dtype=None, non_blocking=False, **kwargs):
    """Returns the type if `dtype` is not provided, else casts this object to
    the specified type.

    If this is already of the correct type, no copy is performed and the
    original object is returned.

    Args:
        dtype (type or string): The desired type
        non_blocking (bool): If ``True``, and the source is in pinned memory
            and destination is on the GPU or vice versa, the copy is performed
            asynchronously with respect to the host. Otherwise, the argument
            has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument. The ``async`` arg is deprecated.
    """
    non_blocking = _get_async_or_non_blocking("type", non_blocking, kwargs)
    if dtype is None:
        return self.__module__ + "." + self.__class__.__name__

    if isinstance(dtype, str):
        dtype = _import_dotted_name(dtype)
    if dtype == type(self):
        return self
    if self.is_sparse:
        if not dtype.is_sparse:
            raise RuntimeError("Cannot cast sparse tensor to dense tensor")
        new_module_name = dtype.__module__.replace(".sparse", "")
        new_values_type_name = new_module_name + "." + dtype.__name__
        new_values = torch.Tensor._values(self).type(new_values_type_name, non_blocking)
        new_indices_type_name = new_module_name + ".LongTensor"
        new_indices = torch.Tensor._indices(self).type(
            new_indices_type_name, non_blocking
        )
        return dtype(new_indices, new_values, self.size())
    if dtype.is_sparse:
        raise RuntimeError("Cannot cast dense tensor to sparse tensor")
    return dtype(self.size()).copy_(self, non_blocking)
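# Illustrative behaviour (a sketch; this helper is bound as a ``.type()`` method
# on tensor/storage classes elsewhere in torch):
#
#     >>> torch.ones(3).type()                              # no dtype: qualified class name
#     'torch.FloatTensor'
#     >>> torch.ones(3).type("torch.DoubleTensor").dtype    # cast via dotted name
#     torch.float64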


def _hpu(self, device=None, non_blocking=False, **kwargs):
    """Returns a copy of this object in HPU memory.

    If this object is already in HPU memory and on the correct device, then
    no copy is performed and the original object is returned.

    Args:
        device (int): The destination HPU id. Defaults to the current device.
        non_blocking (bool): If ``True`` and the source is in pinned memory,
            the copy will be asynchronous with respect to the host. Otherwise,
            the argument has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument.
    """
    non_blocking = _get_async_or_non_blocking("hpu", non_blocking, kwargs)
    hpu = getattr(torch, "hpu", None)
    assert hpu is not None, "HPU device module is not loaded"
    if self.is_hpu:
        if device is None:
            device = hpu.current_device()
        if self.get_device() == device:
            return self
    else:
        if device is None:
            device = -1
    with hpu.device(device):
        assert not self.is_sparse, "sparse storage is not supported for HPU tensors"
        untyped_storage = torch.UntypedStorage(self.size(), device=torch.device("hpu"))
        untyped_storage.copy_(self, non_blocking)
        return untyped_storage


def _cuda(self, device=None, non_blocking=False, **kwargs):
    """Returns a copy of this object in CUDA memory.

    If this object is already in CUDA memory and on the correct device, then
    no copy is performed and the original object is returned.

    Args:
        device (int): The destination GPU id. Defaults to the current device.
        non_blocking (bool): If ``True`` and the source is in pinned memory,
            the copy will be asynchronous with respect to the host. Otherwise,
            the argument has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument.
    """
    non_blocking = _get_async_or_non_blocking("cuda", non_blocking, kwargs)
    if self.is_cuda:
        if device is None:
            device = torch.cuda.current_device()
        if self.get_device() == device:
            return self
    else:
        if device is None:
            device = -1
    with torch.cuda.device(device):
        if self.is_sparse:
            new_type = getattr(torch.cuda.sparse, self.__class__.__name__)
            indices = torch.Tensor._indices(self).cuda(device, non_blocking)
            values = torch.Tensor._values(self).cuda(device, non_blocking)
            return new_type(indices, values, self.size())
        else:
            untyped_storage = torch.UntypedStorage(
                self.size(), device=torch.device("cuda")
            )
            untyped_storage.copy_(self, non_blocking)
            return untyped_storage


def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
    """Return the non-blocking flag given the function name and kwargs.

    Args:
        function_name (str): the name of the function being used.
        non_blocking (bool): the default value.
        **kwargs (dict): the kwargs passed to the function.
    """
    if not kwargs:
        return non_blocking
    if len(kwargs) != 1 or "async" not in kwargs:
        message = "{}() got an unexpected keyword argument '{}'"
        argument = list(kwargs.keys()).pop()
        raise TypeError(message.format(function_name, argument))
    warnings.warn("'async' is deprecated; use 'non_blocking'")
    return kwargs["async"]
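# Illustrative behaviour (a sketch): ``async`` is a reserved word in Python 3,
# so the legacy spelling can only reach these functions through ``**kwargs``
# expansion; it is mapped onto ``non_blocking`` and a warning is emitted.
#
#     >>> _get_async_or_non_blocking("cuda", False, {"async": True})   # warns
#     True
#     >>> _get_async_or_non_blocking("cuda", True, {})
#     True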


# Note [Don't serialize hooks]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Since time immemorial, we have serialized the backward hooks associated with
# variables.  This kind of half-worked--Python can pickle global functions
# (but not closures!)--but there were problems.
#
#   - It's fragile.  If you serialize a backward hook into a saved
#     model, and then you rename the function associated with the hook,
#     now your saved model is broken and you can't load it anymore.
#
#   - It's not actually used.  The standard recommendation is to
#     serialize the *state_dict* of a model, not the model itself
#     (since this is more stable to code changes affecting the model
#     serialization), and the state dict saves "data" only, thus
#     stripping the backward hooks.  In some cases, hooks are
#     essential to the well-functioning of a model (e.g., DDP),
#     but DDP already manages re-adding the hooks!
#
#   - We didn't serialize them in many cases.  Prior to #10220, we
#     were dropping backward hooks in ForkingPickler.  We "fixed" this
#     to be consistent with other serialization sites, but the lack of
#     serializing backward hooks wasn't actually the root cause of
#     the bug.
#
# With these cases in mind, we have decided that a better strategy
# is to just NOT serialize hooks at all.
#
# Since this is a BC-breaking change, we should warn when we previously
# serialized a hook, but no longer do so. This will be done by adding a special
# sentinel property to hooks that will be used to suppress this warning. If a hook
# has the property _torch_serialize_ignore, we will not emit a warning if we
# attempt to serialize a Tensor with this hook attached to it.
#
# By the way, when _backward_hooks is skipped, we must give an EMPTY
# OrderedDict(); if you pass a None you'll run afoul of #12219.
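# Illustrative sketch of the sentinel described above (hypothetical user code,
# not executed here; the exact warning path lives in the serialization code):
#
#     >>> def grad_hook(grad):
#     ...     return grad
#     >>> grad_hook._torch_serialize_ignore = True     # opt out of the warning
#     >>> t = torch.ones(2, requires_grad=True)
#     >>> handle = t.register_hook(grad_hook)
#     >>> torch.save(t, "t.pt")   # the hook is dropped, but no warning is emitted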


# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
def _rebuild_tensor(storage, storage_offset, size, stride):
    # first construct a tensor with the correct dtype/device
    t = torch.empty((0,), dtype=storage.dtype, device=storage._untyped_storage.device)
    return t.set_(storage._untyped_storage, storage_offset, size, stride)


def get_tensor_metadata(tensor):
    # Tensor's Metadata for serializing.
    # Currently, this only returns a dict[string, bool] specifying whether the
    # `conj` or `neg` bit is set.
    assert isinstance(tensor, torch.Tensor)
    return torch._C._get_tensor_metadata(tensor)  # type: ignore[attr-defined]


def set_tensor_metadata(tensor, metadata):
    # See `get_tensor_metadata` above
    assert isinstance(metadata, dict)
    assert isinstance(tensor, torch.Tensor)
    torch._C._set_tensor_metadata(tensor, metadata)  # type: ignore[attr-defined]


def _rebuild_tensor_v2(
    storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None
):
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    if metadata:
        set_tensor_metadata(tensor, metadata)

    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    tensor._backward_hooks = backward_hooks
    return tensor
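# Illustrative round trip (a sketch): for a plain dense tensor,
# ``Tensor.__reduce_ex__`` typically emits ``_rebuild_tensor_v2`` together with
# its arguments, so a pickle round trip passes back through this function.
#
#     >>> import pickle
#     >>> t = torch.arange(4.0).reshape(2, 2)
#     >>> pickle.loads(pickle.dumps(t)).equal(t)
#     True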


def _rebuild_tensor_v3(
    storage,
    storage_offset,
    size,
    stride,
    requires_grad,
    backward_hooks,
    dtype,
    metadata=None,
):
    t = torch.empty(
        (0,),
        dtype=dtype,
        device=storage._untyped_storage.device,
        requires_grad=requires_grad,
    )
    t.set_(storage._untyped_storage, storage_offset, size, stride)
    if metadata:
        set_tensor_metadata(t, metadata)
    t._backward_hooks = backward_hooks
    return t


_sparse_tensors_to_validate: List["torch.Tensor"] = []


# In _legacy_load() in serialization.py we unpickle storages after the sparse
# tensors have already been unpickled. Those storages contain data necessary for
# validating sparse tensors: indices and values. That's why sparse tensors are
# first unpickled without any validation, and then this function is called just
# before _legacy_load() returns, so that all the sparse tensors can be validated
# in bulk.
#
# The same procedure must be followed by _load() in serialization.py because, due
# to Pickler semantics, we have to use the same (non-validating) function for
# unpickling sparse tensors, regardless of the caller.
def _validate_loaded_sparse_tensors():
    try:
        for t in _sparse_tensors_to_validate:
            if t.layout is torch.sparse_coo:
                torch._validate_sparse_coo_tensor_args(
                    t._indices(), t._values(), t.size(), t.is_coalesced()
                )
            elif t.layout in {
                torch.sparse_csr,
                torch.sparse_csc,
                torch.sparse_bsr,
                torch.sparse_bsc,
            }:
                # TODO: Validation currently involves an expensive traversal
                # on CPU, which may include a device transfer.
                if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
                    compressed_indices, plain_indices = (
                        t.crow_indices(),
                        t.col_indices(),
                    )
                else:
                    compressed_indices, plain_indices = (
                        t.ccol_indices(),
                        t.row_indices(),
                    )
                torch._validate_sparse_compressed_tensor_args(
                    compressed_indices, plain_indices, t.values(), t.size(), t.layout
                )
            else:
                raise NotImplementedError(
                    f"_validate_loaded_sparse_tensors for layout `{t.layout}`"
                )

    finally:
        _sparse_tensors_to_validate.clear()


def _rebuild_sparse_tensor(layout, data):
    """
    Rebuilds a sparse tensor from its sparse storage representation.

    Args:
        layout (torch.layout): The sparse storage layout of the tensor.
        data (tuple): The tensor's sparse storage representation.
    """
    if layout == torch.sparse_coo:
        if len(data) == 3:
            # For BC:
            indices, values, size = data
            is_coalesced = None
        else:
            indices, values, size, is_coalesced = data
        result = torch.sparse_coo_tensor(
            indices, values, size, check_invariants=False, is_coalesced=is_coalesced
        )
        _sparse_tensors_to_validate.append(result)
        return result

    elif layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        compressed_indices, plain_indices, values, size = data
        result = torch.sparse_compressed_tensor(
            compressed_indices,
            plain_indices,
            values,
            size,
            layout=layout,
            check_invariants=False,
        )
        _sparse_tensors_to_validate.append(result)
        return result

    raise NotImplementedError(f"rebuilding sparse tensor for layout {layout}")
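# Illustrative data layouts (a sketch): COO tensors arrive either as the legacy
# 3-tuple or as a 4-tuple carrying ``is_coalesced``; compressed layouts always
# use (compressed_indices, plain_indices, values, size).
#
#     >>> i = torch.tensor([[0, 1], [1, 0]])
#     >>> v = torch.tensor([3.0, 4.0])
#     >>> s = _rebuild_sparse_tensor(torch.sparse_coo, (i, v, (2, 2)))        # legacy
#     >>> s = _rebuild_sparse_tensor(torch.sparse_coo, (i, v, (2, 2), None))  # current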


def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
    return torch._nested_view_from_buffer(buffer, sizes, strides, storage_offsets)


def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
    tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
    tensor.requires_grad = requires_grad
    return tensor


# Should not be used, only here to be able to load Tensors serialized with older versions of pytorch
_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy


def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
    return torch.empty_strided(
        size, stride, dtype=dtype, device="meta", requires_grad=requires_grad
    )


def _rebuild_wrapper_subclass(
    cls, dtype, size, stride, storage_offset, layout, device, requires_grad
):
    return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
        cls,
        size,
        strides=stride,
        storage_offset=storage_offset,
        layout=layout,
        device=device,
        requires_grad=requires_grad,
    )


# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
def _rebuild_qtensor(
    storage,
    storage_offset,
    size,
    stride,
    quantizer_params,
    requires_grad,
    backward_hooks,
):
    qscheme = quantizer_params[0]
    if qscheme == torch.per_tensor_affine:
        _, scale, zero_point = quantizer_params
        tensor = torch._empty_affine_quantized(
            size,
            scale=scale,
            zero_point=zero_point,
            dtype=storage.dtype,
            device=storage.device,
        )
    elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
        _, scales, zero_points, axis = quantizer_params
        if type(scales) is list and type(zero_points) is list:
            if qscheme == torch.per_channel_affine:
                scales = torch.tensor(scales, dtype=torch.double, device=storage.device)
                zero_points = torch.tensor(
                    zero_points, dtype=torch.long, device=storage.device
                )
            else:
                scales = torch.tensor(scales, dtype=torch.float, device=storage.device)
                zero_points = torch.tensor(
                    zero_points, dtype=torch.float, device=storage.device
                )
        tensor = torch._empty_per_channel_affine_quantized(
            size,
            scales=scales,
            zero_points=zero_points,
            axis=axis,
            dtype=storage.dtype,
            device=storage.device,
        )
    else:
        raise RuntimeError(f"Can't deserialize quantized tensor with qscheme {qscheme}")
    tensor.set_(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    tensor._backward_hooks = backward_hooks
    return tensor


def _rebuild_parameter(data, requires_grad, backward_hooks):
    param = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    param._backward_hooks = backward_hooks

    return param


def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state):
    param = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    param._backward_hooks = backward_hooks

    # Restore state on Parameter like python attr.
    param = _set_obj_state(param, state)
    return param


def _get_obj_state(obj):
    # Get the state of the python subclass
    # This loosely mimics the function on the object class but since Tensors do
    # not inherit from it, we cannot call that function directly
    # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891
    # Note that starting with Python 3.11, this `__getstate__` is always defined and thus
    # the else branch will never be taken.
    getstate_fn = getattr(obj, "__getstate__", None)
    if getstate_fn:
        state = getstate_fn()
    else:
        slots_to_save = copyreg._slotnames(obj.__class__)  # type: ignore[attr-defined]
        if slots_to_save:
            state = (
                obj.__dict__,
                {
                    name: getattr(obj, name)
                    for name in slots_to_save
                    if hasattr(obj, name)
                },
            )
        else:
            state = obj.__dict__

    return state


def _set_obj_state(obj, state):
    if isinstance(state, tuple):
        if not len(state) == 2:
            raise RuntimeError(f"Invalid serialized state: {state}")
        dict_state = state[0]
        slots_state = state[1]
    else:
        dict_state = state
        slots_state = None

    # Starting with Python 3.11, the __dict__ attribute is lazily created
    # and is serialized as None when not needed.
    if dict_state:
        for k, v in dict_state.items():
            setattr(obj, k, v)

    if slots_state:
        for k, v in slots_state.items():
            setattr(obj, k, v)
    return obj


def _import_dotted_name(name):
    components = name.split(".")
    obj = __import__(components[0])
    for component in components[1:]:
        obj = getattr(obj, component)
    return obj
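# Illustrative behaviour:
#
#     >>> _import_dotted_name("torch.nn.Linear") is torch.nn.Linear
#     True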


def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
    same dense type.

    Since inputs are dense, the resulting tensor will be a concatenated 1D
    buffer. Element-wise operation on this buffer will be equivalent to
    operating individually.

    Args:
        tensors (Iterable[Tensor]): dense tensors to flatten.

    Returns:
        A contiguous 1D buffer containing input tensors.
    """
    return torch._C._nn.flatten_dense_tensors(tensors)


def _flatten_sparse_tensors(tensors):
    """Flatten sparse tensors into two contiguous 1D buffers, one of indices and
    one of values. Assume tensors are of same sparse type.

    Args:
        tensors (Iterable[Tensor]): sparse tensors to flatten.

    Returns:
        A tuple of two contiguous 1D buffers, one containing input tensors'
        indices and the other containing the values.
    """
    flat_indices = torch._C._nn.flatten_dense_tensors(
        [torch.Tensor._indices(t) for t in tensors]
    )
    flat_values = torch._C._nn.flatten_dense_tensors(
        [torch.Tensor._values(t) for t in tensors]
    )
    return flat_indices, flat_values


def _unflatten_dense_tensors(flat, tensors):
    """View a flat buffer using the sizes of tensors. Assume that tensors are of
    same dense type, and that flat is given by _flatten_dense_tensors.

    Args:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
          unflatten flat.

    Returns:
        Unflattened dense tensors with sizes same as tensors and values from
        flat.
    """
    return torch._C._nn.unflatten_dense_tensors(flat, tensors)
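# Illustrative round trip (a sketch):
#
#     >>> a, b = torch.ones(2, 2), torch.zeros(3)
#     >>> flat = _flatten_dense_tensors([a, b])          # 1D buffer with 7 elements
#     >>> outs = _unflatten_dense_tensors(flat, [a, b])
#     >>> [tuple(o.shape) for o in outs]
#     [(2, 2), (3,)]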


def _unflatten_sparse_tensors(flat, tensors):
    """View flat buffer (containing indices and values) using the sizes of
    tensors. Assume that tensors are of same sparse type, and that flat is given
    by _flatten_sparse_tensors.

    Args:
        flat (tuple(Tensor, Tensor)): flattened indices and values of sparse
          tensors to unflatten.
        tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to
          unflatten flat.

    Returns:
        Unflattened sparse tensors with sizes same as tensors and values from
        flat.
    """
    flat_indices, flat_values = flat
    indices = torch._C._nn.unflatten_dense_tensors(
        flat_indices, [torch.Tensor._indices(t) for t in tensors]
    )
    values = torch._C._nn.unflatten_dense_tensors(
        flat_values, [torch.Tensor._values(t) for t in tensors]
    )
    outputs = []
    for t, i, v in zip(tensors, indices, values):
        outputs.append(t.new(i, v, t.size()))
    return tuple(outputs)


def _reorder_tensors_as(tensors, ordered_tensors):
    """Assume that tensors are of same order as ordered_tensors within their
    types, e.g., from _take_tensors. Reorder them to be of same order as
    ordered_tensors.

    Args:
        tensors (Iterable[Tensor]): tensors to be reordered. They should be of
          the same order as ordered_tensors within their own types.
        ordered_tensors (Iterable[Tensor]): tensors whose order will be the
          reference.

    Returns:
        Ordered tuple of tensors with contents from tensors and order of
        ordered_tensors.
    """
    type_dict = defaultdict(list)
    for tensor in tensors:
        type_dict[tensor.type()].append(tensor)
    type_dict_ = {t: iter(coll) for t, coll in type_dict.items()}
    return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors)


def _take_tensors(tensors, size_limit):
    """Group tensors into chunks. This generator yields one chunk at a time,
    each containing tensors of the same type, up to a certain byte limit in
    total size.

    Args:
        tensors (Sequence): A sequence of tensors to be separated into chunks.
        size_limit (int): The limit of each chunk in bytes.

    Yields:
        Blocks of tensors of same type and within size_limit. The yielded
        tensors are only ordered as the original sequence within its types.
    """
    buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0])
    for tensor in tensors:
        t = tensor.type()
        if tensor.is_sparse:
            indices = torch.Tensor._indices(tensor)
            values = torch.Tensor._values(tensor)
            size = (
                indices.numel() * indices.element_size()
                + values.numel() * values.element_size()
            )
        else:
            size = tensor.numel() * tensor.element_size()
        buf_and_size = buf_dict[t]
        if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
            yield buf_and_size[0]
            buf_and_size = buf_dict[t] = [[], 0]
        buf_and_size[0].append(tensor)
        buf_and_size[1] += size
    for buf, _ in buf_dict.values():
        if len(buf) > 0:
            yield buf
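# Illustrative chunking (a sketch): with a 16-byte limit, four float32 scalars
# (4 bytes each) of the same type fill one chunk, and a fifth starts a new one.
#
#     >>> ts = [torch.zeros(1) for _ in range(5)]
#     >>> [len(chunk) for chunk in _take_tensors(ts, 16)]
#     [4, 1]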


# annotation decorator to get annotations in a way that is compatible
# with both Python 2 and 3
def annotate(ret, **kwargs):
    def dec(fun):
        fun.__annotations__ = dict(kwargs)
        fun.__annotations__["return"] = ret
        return fun

    return dec
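# Illustrative usage (a sketch):
#
#     >>> @annotate(torch.Tensor, x=torch.Tensor, n=int)
#     ... def repeat(x, n):
#     ...     return x.repeat(n)
#     >>> repeat.__annotations__["return"] is torch.Tensor
#     True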


def render_call(fn, args, kwargs):
    str_fn = torch.overrides.resolve_name(fn)
    if str_fn is None:
        str_fn = str(fn)

    str_args: List[str] = []
    with torch._tensor_str.printoptions(threshold=0, edgeitems=0):
        str_args.extend(repr(a) for a in args)
        str_args.extend(f"{k}={repr(v)}" for k, v in kwargs.items())
        r = f"{str_fn}({', '.join(str_args)})"
    return r


# NOTE [ Python Traceback Reference Cycle Problem ]
#
# When using sys.exc_info(), it is important to **not** store exc_info[2],
# which is the traceback, because otherwise you will run into the traceback
# reference cycle problem, i.e., the traceback holds a reference to the frame,
# and the frame (which holds references to all the objects in its temporary
# scope) holds a reference to the traceback.


class KeyErrorMessage(str):
    r"""str subclass that returns itself in repr"""

    def __repr__(self):
        return self


class ExceptionWrapper:
    r"""Wraps an exception plus traceback to communicate across threads"""

    def __init__(self, exc_info=None, where="in background"):
        # It is important that we don't store exc_info, see
        # NOTE [ Python Traceback Reference Cycle Problem ]
        if exc_info is None:
            exc_info = sys.exc_info()
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))
        self.where = where

    def reraise(self):
        r"""Reraises the wrapped exception in the current thread"""
        # Format a message such as: "Caught ValueError in DataLoader worker
        # process 2. Original Traceback:", followed by the traceback.
        msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}"
        if self.exc_type == KeyError:
            # KeyError calls repr() on its argument (usually a dict key). This
            # makes stack traces unreadable. It will not be changed in Python
            # (https://bugs.python.org/issue2651), so we work around it.
            msg = KeyErrorMessage(msg)
        elif getattr(self.exc_type, "message", None):
            # Some exceptions have first argument as non-str but explicitly
            # have message field
            raise self.exc_type(message=msg)
        try:
            exception = self.exc_type(msg)
        except TypeError:
            # If the exception takes multiple arguments, don't try to
            # instantiate since we don't know how to
            raise RuntimeError(msg) from None
        raise exception
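# Illustrative usage (a sketch of the DataLoader-style pattern):
#
#     >>> try:
#     ...     {}["missing"]
#     ... except Exception:
#     ...     wrapped = ExceptionWrapper(where="in background thread")
#     >>> wrapped.reraise()   # re-raises KeyError, prefixed with the original traceback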


def _get_available_device_type():
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # type: ignore[attr-defined]
        return "xpu"
    custom_backend_name = torch._C._get_privateuse1_backend_name()
    custom_device_mod = getattr(torch, custom_backend_name, None)
    if custom_device_mod and custom_device_mod.is_available():
        return custom_backend_name
    # add more available device types here
    return None


def _get_device_attr(get_member):
    device_type = _get_available_device_type()
    if device_type and device_type.lower() == "cuda":
        return get_member(torch.cuda)
    if device_type and device_type.lower() == "xpu":
        return get_member(torch.xpu)  # type: ignore[attr-defined]
    if device_type == torch._C._get_privateuse1_backend_name():
        return get_member(getattr(torch, device_type))
    # add more available device types here
    return None


def _get_current_device_index():
    # current device index
    return _get_device_attr(lambda m: m.current_device())


def _get_all_device_indices():
    # all device indices
    return _get_device_attr(lambda m: list(range(m.device_count())))


def _get_devices_properties(device_ids):
    # all device properties
    return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids]


def get_current_device_index() -> int:
    r"""Checks if there are CUDA devices available and
    returns the device index of the current default CUDA device.
    Returns -1 in case there are no CUDA devices available.
    Arguments: ``None``
    """
    if torch.cuda.device_count() > 0:
        return torch.cuda.current_device()
    return -1


def _get_device_index(
    device: Any, optional: bool = False, allow_cpu: bool = False
) -> int:
    r"""Gets the device index from :attr:`device`, which can be a torch.device
    object, a Python integer, or ``None``.

    If :attr:`device` is a torch.device object, returns the device index if it
    has an index. Note that for a device without a specified index,
    i.e., ``torch.device('xxx')``, this will return the current default
    device of that type if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
    CPU devices will be accepted and ``-1`` will be returned in this case.

    If :attr:`device` is a Python integer, it is returned as is.

    If :attr:`device` is ``None``, this will return the current default
    device of the supported runtime platform if :attr:`optional` is ``True``,
    e.g., the current default CUDA device will be returned if the CUDA runtime is supported.
    """
    if isinstance(device, str):
        device = torch.device(device)
    device_idx: Optional[int] = None
    if isinstance(device, torch.device):
        if not allow_cpu and device.type == "cpu":
            raise ValueError(f"Expected a non cpu device, but got: {device}")
        device_idx = -1 if device.type == "cpu" else device.index
    if isinstance(device, int):
        device_idx = device
    if device_idx is None:
        if optional:
            # The eager API _get_current_device_index uses `lambda` functions which are
            # not supported in JIT and hence not scriptable. The JIT equivalent API to get
            # the current device index is `get_current_device_index()` which can
            # be scripted. We use is_scripting to check the mode we are in and call the
            # appropriate API.
            if torch.jit.is_scripting():
                device_idx = get_current_device_index()
            else:
                device_idx = _get_current_device_index()
        else:
            raise ValueError(
                f"Expected a torch.device with a specified index or an integer, but got: {device}"
            )
    return device_idx
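# Illustrative behaviour (a sketch):
#
#     >>> _get_device_index(torch.device("cuda:1"))
#     1
#     >>> _get_device_index("cpu", allow_cpu=True)
#     -1
#     >>> idx = _get_device_index(None, optional=True)   # falls back to the current device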


def _handle_complex(tensor):
    """
    Returns a real view of a tensor if it has a complex dtype, else the tensor
    itself. We need to check for UninitializedParameter first because calling
    is_complex on one raises an error for a LazyModule.
    """
    return (
        torch.view_as_real(tensor)
        if not isinstance(tensor, torch.nn.UninitializedParameter)
        and tensor.is_complex()
        else tensor
    )


def _element_size(dtype):
    """
    Returns the element size for a dtype, in bytes
    """
    if not isinstance(dtype, torch.dtype):
        raise RuntimeError(f"expected torch.dtype, but got {type(dtype)}")

    if dtype.is_complex:
        return torch.finfo(dtype).bits >> 2
    elif dtype.is_floating_point:
        return torch.finfo(dtype).bits >> 3
    elif dtype == torch.bool:
        # NOTE: torch.bool is not supported in torch.iinfo()
        return 1
    else:
        return torch.iinfo(dtype).bits >> 3
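# Illustrative values (a sketch): ``torch.finfo`` reports per-component bits for
# complex dtypes, which is why the complex branch shifts by 2 instead of 3.
#
#     >>> _element_size(torch.float32), _element_size(torch.complex64)
#     (4, 8)
#     >>> _element_size(torch.bool), _element_size(torch.int64)
#     (1, 8)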


class _ClassPropertyDescriptor:
    def __init__(self, fget, fset=None):
        self.fget = fget

    def __get__(self, instance, owner=None):
        if owner is None:
            owner = type(instance)
        return self.fget.__get__(instance, owner)()


def classproperty(func):
    if not isinstance(func, (classmethod, staticmethod)):
        func = classmethod(func)
    return _ClassPropertyDescriptor(func)
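# Illustrative usage (a sketch):
#
#     >>> class Accelerator:
#     ...     @classproperty
#     ...     def device_count(cls):
#     ...         return torch.cuda.device_count()
#     >>> isinstance(Accelerator.device_count, int)   # computed on class access
#     True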


def is_compiling() -> bool:
    """
    Indicates whether we are tracing/compiling with torch.compile() or torch.export().

    TODO(khabinov): we should deprecate this function and use torch.compiler.is_compiling().
    """
    return torch.compiler.is_compiling()


def _functionalize_sync(t):
    # This code lives in python instead of C++ since conditioning on a certain python subclass
    # is much more of a pain in C++.
    from torch._subclasses.functional_tensor import FunctionalTensor

    if isinstance(t, FunctionalTensor):
        # If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called
        # when we sync our inner tensor.
        # Why?
        # (1) If there are input mutations in the graph, then they will be re-applied during
        #     AOTAutograd when we call _sync() from inside of our functionalization kernels.
        # (2) _sync() causes us to regenerate the updated tensor from the updated base,
        #     which dispatches to a bunch of view ops
        # (3) The input to these view ops is our inner FunctionalTensorWrapper
        #     (since the sync was called from C++), not the python FunctionalTensor
        # (4) if a python FunctionalTensorMode is active, it will complain when it intercepts
        #     the view op, since it will see an input that is a C++ FunctionalTensorWrapper
        #     (aka a normal torch.Tensor) instead of a python `FunctionalTensor`.
        maybe_functional_mode = torch._C._unset_dispatch_mode(
            torch._C._TorchDispatchModeKey.FUNCTIONAL
        )
        try:
            torch._functionalize_sync(t.elem)  # type: ignore[attr-defined]
        finally:
            if maybe_functional_mode is not None:
                torch._C._set_dispatch_mode(maybe_functional_mode)
    else:
        torch._functionalize_sync(t)  # type: ignore[attr-defined]


@functools.lru_cache(2)
def _get_device_module(device_type: str):
    device_module = getattr(torch, device_type, None)
    if device_module is None:
        raise RuntimeError(
            f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'."
        )
    return device_module


def _dummy_type(name: str) -> type:
    def get_err_fn(is_init: bool):
        def err_fn(obj, *args, **kwargs):
            if is_init:
                class_name = obj.__class__.__name__
            else:
                class_name = obj.__name__
            raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")

        return err_fn

    return type(
        name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
    )


class _LazySeedTracker:
    # Since seeding is memory-less, only track the latest seed.
    # Note: `manual_seed_all` followed by `manual_seed` overwrites
    # the seed on the current device. We track the order of the **latest**
    # calls between these two APIs.
    def __init__(self):
        self.manual_seed_all_cb = None
        self.manual_seed_cb = None
        self.call_order = []

    def queue_seed_all(self, cb, traceback):
        self.manual_seed_all_cb = (cb, traceback)
        # update seed_all to be latest
        self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]

    def queue_seed(self, cb, traceback):
        self.manual_seed_cb = (cb, traceback)
        # update seed to be latest
        self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]

    def get_calls(self) -> List:
        return self.call_order


logger = logging.getLogger(__name__)
P = ParamSpec("P")


class CallbackRegistry(Generic[P]):
    def __init__(self, name: str):
        self.name = name
        self.callback_list: List[Callable[P, None]] = []

    def add_callback(self, cb: Callable[P, None]) -> None:
        self.callback_list.append(cb)

    def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
        for cb in self.callback_list:
            try:
                cb(*args, **kwargs)
            except Exception:
                logger.exception(
                    "Exception in callback for %s registered with gpu trace", self.name
                )
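# Illustrative usage (a sketch):
#
#     >>> registry = CallbackRegistry("demo events")
#     >>> registry.add_callback(lambda n: print(f"saw {n}"))
#     >>> registry.fire_callbacks(42)   # exceptions are logged, not propagated
#     saw 42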