diff --git a/.lintrunner.toml b/.lintrunner.toml index ed18baffbd8d..9f62931ce57c 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1549,8 +1549,6 @@ exclude_patterns = [ 'torch/_custom_op/autograd.py', 'torch/_custom_op/functional.py', 'torch/_custom_op/impl.py', - 'torch/_dispatch/__init__.py', - 'torch/_dispatch/python.py', 'torch/_export/__init__.py', 'torch/_export/constraints.py', 'torch/_export/db/__init__.py', @@ -1629,10 +1627,6 @@ exclude_patterns = [ 'torch/_higher_order_ops/__init__.py', 'torch/_higher_order_ops/out_dtype.py', 'torch/_higher_order_ops/wrap.py', - 'torch/_prims_common/__init__.py', - 'torch/_prims_common/wrappers.py', - 'torch/amp/__init__.py', - 'torch/amp/autocast_mode.py', 'torch/ao/__init__.py', 'torch/ao/nn/__init__.py', 'torch/ao/nn/intrinsic/__init__.py', @@ -1823,68 +1817,10 @@ exclude_patterns = [ 'torch/ao/quantization/quantize_pt2e.py', 'torch/ao/quantization/stubs.py', 'torch/ao/quantization/utils.py', - 'torch/autograd/__init__.py', - 'torch/autograd/_functions/__init__.py', - 'torch/autograd/_functions/tensor.py', - 'torch/autograd/_functions/utils.py', - 'torch/autograd/anomaly_mode.py', - 'torch/autograd/forward_ad.py', - 'torch/autograd/function.py', - 'torch/autograd/functional.py', - 'torch/autograd/grad_mode.py', - 'torch/autograd/gradcheck.py', - 'torch/autograd/graph.py', - 'torch/autograd/profiler.py', - 'torch/autograd/profiler_legacy.py', - 'torch/autograd/profiler_util.py', - 'torch/autograd/variable.py', - 'torch/backends/__init__.py', - 'torch/backends/_coreml/__init__.py', - 'torch/backends/_coreml/preprocess.py', - 'torch/backends/_nnapi/__init__.py', - 'torch/backends/_nnapi/prepare.py', - 'torch/backends/_nnapi/serializer.py', - 'torch/backends/cpu/__init__.py', - 'torch/backends/cuda/__init__.py', - 'torch/backends/cudnn/__init__.py', - 'torch/backends/cudnn/rnn.py', - 'torch/backends/mkl/__init__.py', - 'torch/backends/mkldnn/__init__.py', - 'torch/backends/mps/__init__.py', - 'torch/backends/openmp/__init__.py', - 'torch/backends/opt_einsum/__init__.py', - 'torch/backends/quantized/__init__.py', - 'torch/backends/xeon/__init__.py', - 'torch/backends/xeon/run_cpu.py', - 'torch/backends/xnnpack/__init__.py', 'torch/compiler/__init__.py', 'torch/contrib/__init__.py', 'torch/contrib/_tensorboard_vis.py', - 'torch/cpu/__init__.py', - 'torch/cpu/amp/__init__.py', - 'torch/cpu/amp/autocast_mode.py', - 'torch/csrc/jit/tensorexpr/codegen_external.py', - 'torch/csrc/jit/tensorexpr/scripts/bisect.py', - 'torch/csrc/lazy/test_mnist.py', - 'torch/cuda/__init__.py', - 'torch/cuda/_memory_viz.py', - 'torch/cuda/_sanitizer.py', - 'torch/cuda/_utils.py', - 'torch/cuda/amp/__init__.py', - 'torch/cuda/amp/autocast_mode.py', - 'torch/cuda/amp/common.py', - 'torch/cuda/amp/grad_scaler.py', - 'torch/cuda/comm.py', - 'torch/cuda/error.py', - 'torch/cuda/graphs.py', - 'torch/cuda/jiterator.py', - 'torch/cuda/memory.py', - 'torch/cuda/nccl.py', - 'torch/cuda/nvtx.py', - 'torch/cuda/profiler.py', - 'torch/cuda/random.py', - 'torch/cuda/sparse.py', - 'torch/cuda/streams.py', + 'torch/cuda/_memory_viz.py', # mypy: Value of type "object" is not indexable 'torch/distributed/__init__.py', 'torch/distributed/_composable_state.py', 'torch/distributed/_functional_collectives.py', @@ -2092,52 +2028,7 @@ exclude_patterns = [ 'torch/distributed/tensor/parallel/multihead_attention_tp.py', 'torch/distributed/tensor/parallel/style.py', 'torch/distributed/utils.py', - 'torch/distributions/__init__.py', - 'torch/distributions/bernoulli.py', - 'torch/distributions/beta.py', - 'torch/distributions/binomial.py', - 'torch/distributions/categorical.py', - 'torch/distributions/cauchy.py', - 'torch/distributions/chi2.py', - 'torch/distributions/constraint_registry.py', - 'torch/distributions/constraints.py', - 'torch/distributions/continuous_bernoulli.py', - 'torch/distributions/dirichlet.py', - 'torch/distributions/distribution.py', - 'torch/distributions/exp_family.py', - 'torch/distributions/exponential.py', - 'torch/distributions/fishersnedecor.py', - 'torch/distributions/gamma.py', - 'torch/distributions/geometric.py', - 'torch/distributions/gumbel.py', - 'torch/distributions/half_cauchy.py', - 'torch/distributions/half_normal.py', - 'torch/distributions/independent.py', - 'torch/distributions/kl.py', - 'torch/distributions/kumaraswamy.py', - 'torch/distributions/laplace.py', - 'torch/distributions/lkj_cholesky.py', - 'torch/distributions/log_normal.py', - 'torch/distributions/logistic_normal.py', - 'torch/distributions/lowrank_multivariate_normal.py', - 'torch/distributions/mixture_same_family.py', - 'torch/distributions/multinomial.py', - 'torch/distributions/multivariate_normal.py', - 'torch/distributions/negative_binomial.py', - 'torch/distributions/normal.py', - 'torch/distributions/one_hot_categorical.py', - 'torch/distributions/pareto.py', - 'torch/distributions/poisson.py', - 'torch/distributions/relaxed_bernoulli.py', - 'torch/distributions/relaxed_categorical.py', - 'torch/distributions/studentT.py', - 'torch/distributions/transformed_distribution.py', - 'torch/distributions/transforms.py', - 'torch/distributions/uniform.py', - 'torch/distributions/utils.py', - 'torch/distributions/von_mises.py', - 'torch/distributions/weibull.py', - 'torch/distributions/wishart.py', + 'torch/distributions/distribution.py', # Use f-string instead of `format` call. 'torch/fft/__init__.py', 'torch/func/__init__.py', 'torch/functional.py', @@ -2229,37 +2120,11 @@ exclude_patterns = [ 'torch/fx/tensor_type.py', 'torch/fx/traceback.py', 'torch/hub.py', - 'torch/jit/__init__.py', - 'torch/jit/_async.py', - 'torch/jit/_await.py', - 'torch/jit/_builtins.py', - 'torch/jit/_check.py', - 'torch/jit/_dataclass_impls.py', - 'torch/jit/_decomposition_utils.py', - 'torch/jit/_decompositions.py', - 'torch/jit/_freeze.py', - 'torch/jit/_fuser.py', - 'torch/jit/_ir_utils.py', - 'torch/jit/_logging.py', - 'torch/jit/_monkeytype_config.py', - 'torch/jit/_passes/__init__.py', - 'torch/jit/_passes/_property_propagation.py', - 'torch/jit/_pickle.py', - 'torch/jit/_recursive.py', - 'torch/jit/_script.py', - 'torch/jit/_serialization.py', - 'torch/jit/_shape_functions.py', - 'torch/jit/_state.py', - 'torch/jit/_trace.py', - 'torch/jit/annotations.py', - 'torch/jit/frontend.py', - 'torch/jit/generate_bytecode.py', - 'torch/jit/mobile/__init__.py', - 'torch/jit/quantized.py', - 'torch/jit/supported_ops.py', - 'torch/jit/unsupported_tensor_ops.py', + 'torch/jit/_script.py', # "Callable[[], Any]" has no attribute "__func__" + 'torch/jit/frontend.py', # "expr" has no attribute "id" 'torch/library.py', 'torch/linalg/__init__.py', + # UFMT causes import cycle on masked 'torch/masked/__init__.py', 'torch/masked/_docs.py', 'torch/masked/_ops.py', @@ -2272,14 +2137,6 @@ exclude_patterns = [ 'torch/masked/maskedtensor/reductions.py', 'torch/masked/maskedtensor/unary.py', 'torch/monitor/__init__.py', - 'torch/mps/__init__.py', - 'torch/mps/profiler.py', - 'torch/multiprocessing/__init__.py', - 'torch/multiprocessing/_atfork.py', - 'torch/multiprocessing/pool.py', - 'torch/multiprocessing/queue.py', - 'torch/multiprocessing/reductions.py', - 'torch/multiprocessing/spawn.py', 'torch/nested/__init__.py', 'torch/nn/__init__.py', 'torch/nn/_reduction.py', @@ -2421,40 +2278,6 @@ exclude_patterns = [ 'torch/optim/sparse_adam.py', 'torch/optim/swa_utils.py', 'torch/overrides.py', - 'torch/profiler/__init__.py', - 'torch/profiler/_memory_profiler.py', - 'torch/profiler/_pattern_matcher.py', - 'torch/profiler/_utils.py', - 'torch/profiler/itt.py', - 'torch/profiler/profiler.py', - 'torch/profiler/python_tracer.py', - 'torch/quantization/__init__.py', - 'torch/quantization/_numeric_suite.py', - 'torch/quantization/_numeric_suite_fx.py', - 'torch/quantization/fake_quantize.py', - 'torch/quantization/fuse_modules.py', - 'torch/quantization/fuser_method_mappings.py', - 'torch/quantization/fx/__init__.py', - 'torch/quantization/fx/_equalize.py', - 'torch/quantization/fx/convert.py', - 'torch/quantization/fx/fuse.py', - 'torch/quantization/fx/fusion_patterns.py', - 'torch/quantization/fx/graph_module.py', - 'torch/quantization/fx/match_utils.py', - 'torch/quantization/fx/pattern_utils.py', - 'torch/quantization/fx/prepare.py', - 'torch/quantization/fx/quantization_patterns.py', - 'torch/quantization/fx/quantization_types.py', - 'torch/quantization/fx/utils.py', - 'torch/quantization/observer.py', - 'torch/quantization/qconfig.py', - 'torch/quantization/quant_type.py', - 'torch/quantization/quantization_mappings.py', - 'torch/quantization/quantize.py', - 'torch/quantization/quantize_fx.py', - 'torch/quantization/quantize_jit.py', - 'torch/quantization/stubs.py', - 'torch/quantization/utils.py', 'torch/quasirandom.py', 'torch/random.py', 'torch/return_types.py', diff --git a/torch/_dispatch/python.py b/torch/_dispatch/python.py index cc1e53371353..ef770d601e75 100644 --- a/torch/_dispatch/python.py +++ b/torch/_dispatch/python.py @@ -1,13 +1,14 @@ -import torch._C -from contextlib import contextmanager -import unittest.mock -import torch -import torch.utils._pytree as pytree import itertools +import unittest.mock +from contextlib import contextmanager from typing import Iterator -import torch._ops -__all__ = ['enable_python_dispatcher', 'no_python_dispatcher', 'enable_pre_dispatch'] +import torch +import torch._C +import torch._ops +import torch.utils._pytree as pytree + +__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"] no_python_dispatcher = torch._C._DisablePythonDispatcher enable_python_dispatcher = torch._C._EnablePythonDispatcher @@ -15,6 +16,7 @@ enable_pre_dispatch = torch._C._EnablePreDispatch CROSSREF_FUNCTIONALIZE = False + def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]: """ Warning: the set of overloads this will report is very subtle. It is precisely @@ -40,9 +42,12 @@ def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]: for overload in packet: yield getattr(packet, overload) + @contextmanager def suspend_functionalization(): - f_tls = torch._C._dispatch_tls_is_dispatch_key_included(torch._C.DispatchKey.Functionalize) + f_tls = torch._C._dispatch_tls_is_dispatch_key_included( + torch._C.DispatchKey.Functionalize + ) f_rv = torch._C._functionalization_reapply_views_tls() if f_tls: torch._disable_functionalization() @@ -52,12 +57,18 @@ def suspend_functionalization(): if f_tls: torch._enable_functionalization(reapply_views=f_rv) + def check_tensor_metadata_matches(nv, rv, desc): assert callable(desc) assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}" assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}" - same_strides, idx = torch._prims_common.check_significant_strides(nv, rv, only_cuda=False) - assert same_strides, f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})" + same_strides, idx = torch._prims_common.check_significant_strides( + nv, rv, only_cuda=False + ) + assert ( + same_strides + ), f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})" + def check_metadata_matches(n, r, desc): assert callable(desc) @@ -71,6 +82,7 @@ def check_metadata_matches(n, r, desc): continue check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}") + class Lit: def __init__(self, s): self.s = s @@ -78,14 +90,19 @@ class Lit: def __repr__(self): return self.s + def _fmt(a: object) -> object: if isinstance(a, torch.Tensor): - return Lit(f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})") + return Lit( + f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})" + ) else: return a + def make_crossref_functionalize(op, final_key): from torch._subclasses.fake_tensor import FakeTensorMode + # This case is pretty weird, suppress it for now if op == torch.ops.aten.lift_fresh.default: return final_key @@ -117,7 +134,9 @@ def make_crossref_functionalize(op, final_key): with suspend_functionalization(): f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs)) - orig_f_args, orig_f_kwargs = pytree.tree_map(maybe_detach, (f_args, f_kwargs)) + orig_f_args, orig_f_kwargs = pytree.tree_map( + maybe_detach, (f_args, f_kwargs) + ) with fake_mode: f_r = op(*f_args, **f_kwargs) r = op._op_dk(final_key, *args, **kwargs) @@ -126,14 +145,20 @@ def make_crossref_functionalize(op, final_key): fmt_args = ", ".join( itertools.chain( (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args), - (f"{k}={pytree.tree_map(_fmt, v)}" for k, v in orig_f_kwargs.items()), + ( + f"{k}={pytree.tree_map(_fmt, v)}" + for k, v in orig_f_kwargs.items() + ), ) ) return f"{op}({fmt_args})" + check_metadata_matches(f_r, r, desc) return r + return handler + # NB: enabling this is slow, don't do it in a hot loop. This is purely # for debugging purposes. @contextmanager @@ -142,7 +167,8 @@ def enable_crossref_functionalize(): op._uncache_dispatch(torch._C.DispatchKey.Functionalize) try: with enable_python_dispatcher(), unittest.mock.patch( - 'torch._dispatch.python.CROSSREF_FUNCTIONALIZE', True): + "torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True + ): yield finally: for op in all_py_loaded_overloads(): diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 67f82934b234..dd45b0f1e51f 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -1,13 +1,27 @@ from __future__ import annotations -from contextlib import nullcontext -from typing import Any, Union, Sequence, Optional, Tuple, List, Callable, Type, overload, cast -from enum import Enum -from functools import reduce, cmp_to_key import operator -import sympy -import weakref import warnings +import weakref + +from contextlib import nullcontext +from enum import Enum +from functools import cmp_to_key, reduce +from typing import ( + Any, + Callable, + cast, + List, + Optional, + overload, + Sequence, + Tuple, + Type, + Union, +) + +import sympy + import torch from torch import sym_float, sym_int, sym_max @@ -142,15 +156,11 @@ def compare_tensor_meta(a: TensorLikeType, b: TensorLikeType, check_strides=Fals if check_strides: same_strides, idx = check_significant_strides(a, b) if not same_strides: - msg = ( - f"Stride mismatch! Strides are {a.stride()} and {b.stride()} (mismatched at {idx})!" - ) + msg = f"Stride mismatch! Strides are {a.stride()} and {b.stride()} (mismatched at {idx})!" raise RuntimeError(msg) if a.storage_offset() != b.storage_offset(): - msg = ( - f"Storage offset mismatch! Storage offsets are {a.storage_offset()} and {b.storage_offset()}!" - ) + msg = f"Storage offset mismatch! Storage offsets are {a.storage_offset()} and {b.storage_offset()}!" raise RuntimeError(msg) if a.is_conj() != b.is_conj(): @@ -171,7 +181,9 @@ def _check_strides_helper( # See https://github.com/pytorch/pytorch/issues/77553 # Only compares strides that are "meaningful" -- strides for dimensions with length > 1 # and for tensors with more than one element - if (not only_cuda or a.device.type == "cuda" or b.device.type == "cuda") and a.numel() > 0: + if ( + not only_cuda or a.device.type == "cuda" or b.device.type == "cuda" + ) and a.numel() > 0: for idx in range(a.ndim): check = not significant_only or a.shape[idx] > 1 if a.stride()[idx] != b.stride()[idx] and check: @@ -179,11 +191,13 @@ def _check_strides_helper( return True, None + def check_significant_strides( a: TensorLikeType, b: TensorLikeType, *, only_cuda=True ) -> Tuple[bool, Optional[int]]: return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=True) + def check_all_strides( a: TensorLikeType, b: TensorLikeType, *, only_cuda=True ) -> Tuple[bool, Optional[int]]: @@ -222,7 +236,6 @@ def is_channels_last_contiguous_2d(a: Tensor) -> bool: expected_stride = 1 for idx in (1, 3, 2, 0): - length = a.shape[idx] if length == 1: continue @@ -243,7 +256,6 @@ def is_channels_last_contiguous_3d(a: Tensor) -> bool: expected_stride = 1 for idx in (1, 4, 3, 2, 0): - length = a.shape[idx] if length == 1: continue @@ -331,13 +343,10 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool: # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous # Sorts (length, stride) pairs by stride - lengths_and_strides = sorted( - zip(a.shape, a.stride()), key=operator.itemgetter(1) - ) + lengths_and_strides = sorted(zip(a.shape, a.stride()), key=operator.itemgetter(1)) expected_stride = 1 for length, stride in lengths_and_strides: - if length == 1: continue @@ -357,7 +366,9 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool: # non overlapping and dense strides. # This is also INCORRECT because it does not model TensorIterator's # short-circuit, which can cause different strides. -def compute_elementwise_output_logical_to_physical_perm(*tensors, _skip_checks=False) -> List[int]: +def compute_elementwise_output_logical_to_physical_perm( + *tensors, _skip_checks=False +) -> List[int]: if not _skip_checks and len(tensors) == 0: msg = "Can't compute elementwise output strides for zero tensors!" raise ValueError(msg) @@ -368,7 +379,9 @@ def compute_elementwise_output_logical_to_physical_perm(*tensors, _skip_checks=F # Filters the tensors to actual tensors if not _skip_checks: tensors = tuple( - a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a) + a + for a in tensors + if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a) ) # Short-circuits for CPU scalar case @@ -388,7 +401,9 @@ def compute_elementwise_output_logical_to_physical_perm(*tensors, _skip_checks=F # TODO: do channels last too is_contiguous = True for t in tensors: - is_contiguous = is_contiguous and t.is_contiguous(memory_format=torch.contiguous_format) + is_contiguous = is_contiguous and t.is_contiguous( + memory_format=torch.contiguous_format + ) if is_contiguous: return list(range(ndim)) @@ -471,7 +486,9 @@ def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]: permuted_shape = apply_perm(shape, logical_to_physical_perm) # to physical new_strides = make_contiguous_strides_for(permuted_shape) - permuted_strides = apply_perm(new_strides, invert_perm(logical_to_physical_perm)) # to logical + permuted_strides = apply_perm( + new_strides, invert_perm(logical_to_physical_perm) + ) # to logical return tuple(permuted_strides) @@ -589,7 +606,9 @@ def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int: # Takes a dimension or sequence of dimensions and "wraps" them, # mapping negative offsets to positive ones @overload -def canonicalize_dims(rank: int, indices: Sequence[int], wrap_scalar: bool = True) -> Tuple[int, ...]: +def canonicalize_dims( + rank: int, indices: Sequence[int], wrap_scalar: bool = True +) -> Tuple[int, ...]: pass @@ -740,7 +759,9 @@ def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]: # Extracts dimensions that might be passed either as a list/tuple or as varargs. # A typical case is Tensor.permute . -def extract_dims_from_varargs(dims: Union[DimsSequenceType, Tuple[DimsSequenceType, ...]]) -> DimsSequenceType: +def extract_dims_from_varargs( + dims: Union[DimsSequenceType, Tuple[DimsSequenceType, ...]] +) -> DimsSequenceType: if dims and isinstance(dims[0], Sequence): assert len(dims) == 1 dims = cast(Tuple[DimsSequenceType], dims) @@ -805,8 +826,10 @@ def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]: shape = list(shape) torch._check( newsize != 0, - lambda: (f"cannot reshape tensor of 0 elements into shape {shape} because the " - f"unspecified dimension size -1 can be any value and is ambiguous"), + lambda: ( + f"cannot reshape tensor of 0 elements into shape {shape} because the " + f"unspecified dimension size -1 can be any value and is ambiguous" + ), ) shape[dim] = numel // newsize return tuple(shape) @@ -1054,11 +1077,15 @@ def get_higher_dtype( def check_pin_memory(pin_memory: bool): - torch._check_not_implemented(not pin_memory, lambda: "PrimTorch does not support pinned memory") + torch._check_not_implemented( + not pin_memory, lambda: "PrimTorch does not support pinned memory" + ) def check_layout(layout: torch.layout): - torch._check_not_implemented(layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}") + torch._check_not_implemented( + layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}" + ) # TODO: maybe unify with can_cast_to? @@ -1172,6 +1199,7 @@ _computation_dtype_map = { def get_computation_dtype(dtype: torch.dtype) -> torch.dtype: return _computation_dtype_map.get(dtype, dtype) + _cpu_acc_type_map = { torch.bfloat16: torch.float64, torch.float16: torch.float64, @@ -1180,6 +1208,7 @@ _cpu_acc_type_map = { torch.complex64: torch.complex128, } + def get_acc_type(dtype: torch.dtype, device: torch.device) -> torch.dtype: # Equivalent to at::toAccumulateType, prefer computation_dtype where possible if device.type == "cpu": @@ -1331,9 +1360,7 @@ def elementwise_dtypes( highest_type: type = bool for x in args: if not isinstance(x, (Number, TensorLike, sympy.Symbol)): - msg = ( - f"Unexpected type {str(type(x))} when computing elementwise type promotion!" - ) + msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!" raise ValueError(msg) if isinstance(x, Number): @@ -1413,9 +1440,7 @@ def elementwise_dtypes( elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL: return get_computation_dtype(result_dtype), torch.bool else: - raise ValueError( - f"Unknown type promotion kind {str(type_promotion_kind)}" - ) + raise ValueError(f"Unknown type promotion kind {str(type_promotion_kind)}") def reduction_dtypes( @@ -1444,6 +1469,7 @@ def reduction_dtypes( result_dtype = torch.bool return computation_dtype, result_dtype + # This function's logic is borrowed from the following functions defined in C++: # batched_matrix_contiguous_strides and contiguous_strides def make_contiguous_strides_for( @@ -1661,9 +1687,12 @@ def check( .. note:: This function is planned for removal in the future. Please use `torch._check*` functions instead. """ - warnings.warn(DeprecationWarning( - "'torch._prims_common.check' will be removed in the future. Please use " - "'torch._check*' functions instead")) + warnings.warn( + DeprecationWarning( + "'torch._prims_common.check' will be removed in the future. Please use " + "'torch._check*' functions instead" + ) + ) torch._check_with(exc_type, b, s) @@ -1752,8 +1781,8 @@ def get_aten_op(fn: Callable, name: str): """ module = fn.__module__ prefix = "torch._refs" - assert(module.startswith(prefix)) - module = module[len(prefix):] + assert module.startswith(prefix) + module = module[len(prefix) :] # We want to go from .special / .nn.functional # to special and special_ / nn_functional_ if module: @@ -1786,12 +1815,18 @@ def clone_preserve_strides(x): # We should revisit this when we add a compositional as_strided op, # and also as part of https://github.com/pytorch/pytorch/issues/90507 try: - old = torch._C._dispatch_tls_is_dispatch_key_excluded(torch._C.DispatchKey.ADInplaceOrView) - torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.ADInplaceOrView, True) + old = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView + ) + torch._C._dispatch_tls_set_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView, True + ) buffer = torch.as_strided(x, (needed_size,), (1,), 0).clone() return torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset()) finally: - torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.ADInplaceOrView, old) + torch._C._dispatch_tls_set_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView, old + ) def alert_not_deterministic(caller: str): @@ -1801,16 +1836,20 @@ def alert_not_deterministic(caller: str): f"{caller} does not have a deterministic implementation, but you set " f"'torch.use_deterministic_algorithms(True, warn_only=True)'. " f"You can file an issue at https://github.com/pytorch/pytorch/issues " - f"to help us prioritize adding deterministic support for this operation.") + f"to help us prioritize adding deterministic support for this operation." + ) else: torch._check( False, - lambda: (f"{caller} does not have a deterministic implementation, but you set " - f"'torch.use_deterministic_algorithms(True)'. You can turn off " - f"determinism just for this operation, or you can use the " - f"'warn_only=True' option, if that's acceptable for your application. " - f"You can also file an issue at https://github.com/pytorch/pytorch/issues " - f"to help us prioritize adding deterministic support for this operation.")) + lambda: ( + f"{caller} does not have a deterministic implementation, but you set " + f"'torch.use_deterministic_algorithms(True)'. You can turn off " + f"determinism just for this operation, or you can use the " + f"'warn_only=True' option, if that's acceptable for your application. " + f"You can also file an issue at https://github.com/pytorch/pytorch/issues " + f"to help us prioritize adding deterministic support for this operation." + ), + ) class CUDARngStateHelper: diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py index c696bb2085c4..f3c9ce12eba2 100644 --- a/torch/_prims_common/wrappers.py +++ b/torch/_prims_common/wrappers.py @@ -1,37 +1,43 @@ +import inspect +import warnings +from functools import wraps +from itertools import chain + +from typing import Callable, NamedTuple, Optional, overload, Sequence, Tuple + import torch +import torch._prims_common as utils from torch._prims_common import ( + ELEMENTWISE_TYPE_PROMOTION_KIND, Number, NumberType, + ShapeType, TensorLike, TensorLikeType, - ShapeType, - ELEMENTWISE_TYPE_PROMOTION_KIND, ) -import torch._prims_common as utils from torch.utils._pytree import tree_flatten, tree_unflatten -from typing import Callable, Sequence, Tuple, NamedTuple, Optional, overload -import inspect -from functools import wraps -import warnings -from itertools import chain @overload def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType: pass + @overload def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType: pass + @overload def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence: pass + @overload def _maybe_convert_to_dtype(a: None, dtype: torch.dtype) -> None: pass + # TODO: implement ref.cast with an option to enforce safe casting def _maybe_convert_to_dtype(a, dtype): if isinstance(a, TensorLike): @@ -47,9 +53,7 @@ def _maybe_convert_to_dtype(a, dtype): if a is None: return None - raise ValueError( - f"Received type {type(a)} that is neither a tensor or a number!" - ) + raise ValueError(f"Received type {type(a)} that is neither a tensor or a number!") def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType: @@ -289,7 +293,9 @@ def out_wrapper(*out_names: str, exact_dtype: bool = False): def backwards_not_supported(prim): def redispatch_prim(args, kwargs): with torch._C._AutoDispatchBelowAutograd(): - old = torch._C._dispatch_tls_is_dispatch_key_excluded(torch._C.DispatchKey.ADInplaceOrView) + old = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView + ) return prim(*args, **kwargs) class BackwardsNotSupported(torch.autograd.Function): @@ -305,7 +311,9 @@ def backwards_not_supported(prim): @wraps(prim) def _autograd_impl(*args, **kwargs): flat_args, args_spec = tree_flatten((args, kwargs)) - if torch.is_grad_enabled() and any(a.requires_grad for a in flat_args if isinstance(a, torch.Tensor)): + if torch.is_grad_enabled() and any( + a.requires_grad for a in flat_args if isinstance(a, torch.Tensor) + ): # TODO: There is a subtle bug here: prims like copy_to # return their input argument after mutating it; and custom # autograd function will incorrectly turn the result into diff --git a/torch/amp/__init__.py b/torch/amp/__init__.py index 955505b6033b..f080d3a978d3 100644 --- a/torch/amp/__init__.py +++ b/torch/amp/__init__.py @@ -1 +1 @@ -from .autocast_mode import autocast, _enter_autocast, _exit_autocast +from .autocast_mode import _enter_autocast, _exit_autocast, autocast diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py index 52cbe7a823eb..018108818040 100644 --- a/torch/amp/autocast_mode.py +++ b/torch/amp/autocast_mode.py @@ -1,20 +1,24 @@ -import torch import functools import warnings from typing import Any, Optional + +import torch from torch.types import _dtype -__all__ = ['autocast_decorator', 'autocast'] +__all__ = ["autocast_decorator", "autocast"] + def autocast_decorator(autocast_instance, func): @functools.wraps(func) def decorate_autocast(*args, **kwargs): with autocast_instance: return func(*args, **kwargs) - decorate_autocast.__script_unsupported = '@autocast() decorator is not supported in script mode' # type: ignore[attr-defined] + + decorate_autocast.__script_unsupported = "@autocast() decorator is not supported in script mode" # type: ignore[attr-defined] return decorate_autocast + class autocast: r""" Instances of :class:`autocast` serve as context managers or decorators that @@ -179,10 +183,14 @@ class autocast: cache_enabled(bool, optional): Whether the weight cache inside autocast should be enabled. Default: ``True`` """ - def __init__(self, device_type : str, - dtype : Optional[_dtype] = None, - enabled : bool = True, - cache_enabled : Optional[bool] = None): + + def __init__( + self, + device_type: str, + dtype: Optional[_dtype] = None, + enabled: bool = True, + cache_enabled: Optional[bool] = None, + ): if torch._jit_internal.is_scripting(): self._enabled = enabled self.device = device_type @@ -192,71 +200,90 @@ class autocast: return self.device = device_type self.custom_backend_name = torch._C._get_privateuse1_backend_name() - if self.device == 'cuda': + if self.device == "cuda": self.fast_dtype = torch.get_autocast_gpu_dtype() - elif self.device == 'cpu': + elif self.device == "cpu": self.fast_dtype = torch.get_autocast_cpu_dtype() - elif self.device == 'xpu': + elif self.device == "xpu": self.fast_dtype = torch.xpu.get_autocast_xpu_dtype() # type: ignore[attr-defined] - elif self.device == 'ipu': + elif self.device == "ipu": self.fast_dtype = torch.get_autocast_ipu_dtype() # type: ignore[attr-defined] - elif self.device == 'hpu': + elif self.device == "hpu": self.fast_dtype = torch.hpu.get_autocast_hpu_dtype() # type: ignore[attr-defined] - elif self.device == 'xla': + elif self.device == "xla": self.fast_dtype = torch.get_autocast_xla_dtype() # type: ignore[attr-defined] elif self.device == self.custom_backend_name: - necessary_funcs = ['is_autocast_enabled', 'set_autocast_enabled', 'get_autocast_dtype', - 'set_autocast_dtype', 'get_amp_supported_dtype'] + necessary_funcs = [ + "is_autocast_enabled", + "set_autocast_enabled", + "get_autocast_dtype", + "set_autocast_dtype", + "get_amp_supported_dtype", + ] message = f"Tried to use AMP with the `{self.custom_backend_name}` backend, but the backend has not " message += "registered a module or the module miss some necessary funcs. The backend should register " message += "a module by `torch._register_device_module`, and the module must have these funcs: \n" message += "`is_autocast_enabled() -> bool`, `set_autocast_enabled(bool) -> None`, " message += "`get_autocast_dtype() -> torch.dtype`, `set_autocast_dtype(torch.dtype) " - message += "-> None` and `get_amp_supported_dtype() -> List[torch.dtype]`. \n" + message += ( + "-> None` and `get_amp_supported_dtype() -> List[torch.dtype]`. \n" + ) assert hasattr(torch, self.custom_backend_name), message self.custom_device_mod = getattr(torch, self.custom_backend_name) for func in necessary_funcs: - assert hasattr(self.custom_device_mod, func), message + f"But the func `{func}` is missing. \n" + assert hasattr(self.custom_device_mod, func), ( + message + f"But the func `{func}` is missing. \n" + ) self.fast_dtype = self.custom_device_mod.get_autocast_dtype() else: - raise RuntimeError(f'User specified an unsupported autocast device_type \'{self.device}\'') + raise RuntimeError( + f"User specified an unsupported autocast device_type '{self.device}'" + ) self._cache_enabled = torch.is_autocast_cache_enabled() - if enabled and torch.cuda.amp.common.amp_definitely_not_available() and self.device == 'cuda': - warnings.warn('User provided device_type of \'cuda\', but CUDA is not available. Disabling') + if ( + enabled + and torch.cuda.amp.common.amp_definitely_not_available() + and self.device == "cuda" + ): + warnings.warn( + "User provided device_type of 'cuda', but CUDA is not available. Disabling" + ) enabled = False if dtype is not None: self.fast_dtype = dtype if cache_enabled is not None: self._cache_enabled = cache_enabled - if self.device == 'cpu': + if self.device == "cpu": supported_dtype = [torch.bfloat16] if self.fast_dtype not in supported_dtype: - error_message = 'In CPU autocast, but the target dtype is not supported. Disabling autocast.\n' - error_message += 'CPU Autocast only supports dtype of torch.bfloat16 currently.' + error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += ( + "CPU Autocast only supports dtype of torch.bfloat16 currently." + ) warnings.warn(error_message) enabled = False - elif self.device == 'xpu': + elif self.device == "xpu": supported_dtype = [torch.bfloat16, torch.float16] if self.fast_dtype not in supported_dtype: - error_message = 'In XPU autocast, but the target dtype is not supported. Disabling autocast.\n' - error_message += 'XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.' + error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently." warnings.warn(error_message) enabled = False - elif self.device == 'ipu': + elif self.device == "ipu": supported_dtypes = [torch.bfloat16, torch.float16] if self.fast_dtype not in supported_dtypes: - error_message = 'In IPU autocast, but the target dtype is not supported. Disabling autocast.\n' - error_message += 'IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.' + error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently." warnings.warn(error_message) enabled = False - elif self.device == 'hpu': + elif self.device == "hpu": supported_dtype = [torch.bfloat16, torch.float16] if self.fast_dtype not in supported_dtype: - error_message = 'In HPU autocast, but the target dtype is not supported. Disabling autocast.\n' - error_message += 'HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.' + error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently." warnings.warn(error_message) enabled = False elif self.device == self.custom_backend_name: @@ -264,17 +291,27 @@ class autocast: if self.fast_dtype not in supported_dtype: error_message = f"In {self.custom_backend_name} autocast, but the target dtype is not supported. " error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of " - error_message += ", ".join(str(dtype) for dtype in supported_dtype) + " currently." + error_message += ( + ", ".join(str(dtype) for dtype in supported_dtype) + " currently." + ) warnings.warn(error_message) enabled = False - elif self.device == 'cuda': - if enabled and self.fast_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): - raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.') - elif self.device == 'xla': + elif self.device == "cuda": + if ( + enabled + and self.fast_dtype == torch.bfloat16 + and not torch.cuda.is_bf16_supported() + ): + raise RuntimeError( + "Current CUDA Device does not support bfloat16. Please switch dtype to float16." + ) + elif self.device == "xla": supported_dtype = [torch.bfloat16] if self.fast_dtype not in supported_dtype: - error_message = 'In XLA autocast, but the target dtype is not supported. Disabling autocast.\n' - error_message += 'XLA Autocast only supports dtype of torch.bfloat16 currently.' + error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += ( + "XLA Autocast only supports dtype of torch.bfloat16 currently." + ) warnings.warn(error_message) enabled = False self._enabled = enabled @@ -285,31 +322,31 @@ class autocast: return self self.prev_cache_enabled = torch.is_autocast_cache_enabled() - if self.device == 'cpu': + if self.device == "cpu": self.prev = torch.is_autocast_cpu_enabled() self.prev_fastdtype = torch.get_autocast_cpu_dtype() torch.set_autocast_cpu_enabled(self._enabled) torch.set_autocast_cpu_dtype(self.fast_dtype) # type: ignore[arg-type] torch.autocast_increment_nesting() - elif self.device == 'xpu': - self.prev = torch.xpu.is_autocast_xpu_enabled() # type: ignore[attr-defined] + elif self.device == "xpu": + self.prev = torch.xpu.is_autocast_xpu_enabled() # type: ignore[attr-defined] self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype() # type: ignore[attr-defined] torch.xpu.set_autocast_xpu_enabled(self._enabled) # type: ignore[attr-defined] torch.xpu.set_autocast_xpu_dtype(self.fast_dtype) # type: ignore[attr-defined] torch.autocast_increment_nesting() - elif self.device == 'ipu': - self.prev = torch.is_autocast_ipu_enabled() # type: ignore[attr-defined] + elif self.device == "ipu": + self.prev = torch.is_autocast_ipu_enabled() # type: ignore[attr-defined] self.prev_fastdtype = torch.get_autocast_ipu_dtype() # type: ignore[attr-defined] torch.set_autocast_ipu_enabled(self._enabled) # type: ignore[attr-defined] torch.set_autocast_ipu_dtype(self.fast_dtype) # type: ignore[attr-defined] torch.autocast_increment_nesting() - elif self.device == 'hpu': - self.prev = torch.hpu.is_autocast_hpu_enabled() # type: ignore[attr-defined] + elif self.device == "hpu": + self.prev = torch.hpu.is_autocast_hpu_enabled() # type: ignore[attr-defined] self.prev_fastdtype = torch.hpu.get_autocast_hpu_dtype() # type: ignore[attr-defined] torch.hpu.set_autocast_hpu_enabled(self._enabled) # type: ignore[attr-defined] torch.hpu.set_autocast_hpu_dtype(self.fast_dtype) # type: ignore[attr-defined] torch.autocast_increment_nesting() - elif self.device == 'xla': + elif self.device == "xla": self.prev = torch.is_autocast_xla_enabled() # type: ignore[attr-defined] self.prev_fastdtype = torch.get_autocast_xla_dtype() # type: ignore[attr-defined] torch.set_autocast_xla_enabled(self._enabled) # type: ignore[attr-defined] @@ -334,31 +371,31 @@ class autocast: return # Drop the cache when we exit to a nesting level that's outside any instance of autocast. - if self.device == 'cpu': + if self.device == "cpu": if torch.autocast_decrement_nesting() == 0: torch.clear_autocast_cache() torch.set_autocast_cpu_enabled(self.prev) torch.set_autocast_cpu_dtype(self.prev_fastdtype) - elif self.device == 'xpu': + elif self.device == "xpu": if torch.autocast_decrement_nesting() == 0: torch.clear_autocast_cache() - torch.xpu.set_autocast_xpu_enabled(self.prev) # type: ignore[attr-defined] - torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype) # type: ignore[attr-defined] - elif self.device == 'ipu': + torch.xpu.set_autocast_xpu_enabled(self.prev) # type: ignore[attr-defined] + torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype) # type: ignore[attr-defined] + elif self.device == "ipu": if torch.autocast_decrement_nesting() == 0: torch.clear_autocast_cache() - torch.set_autocast_ipu_enabled(self.prev) # type: ignore[attr-defined] - torch.set_autocast_ipu_dtype(self.prev_fastdtype) # type: ignore[attr-defined] - elif self.device == 'hpu': + torch.set_autocast_ipu_enabled(self.prev) # type: ignore[attr-defined] + torch.set_autocast_ipu_dtype(self.prev_fastdtype) # type: ignore[attr-defined] + elif self.device == "hpu": if torch.autocast_decrement_nesting() == 0: torch.clear_autocast_cache() - torch.hpu.set_autocast_hpu_enabled(self.prev) # type: ignore[attr-defined] - torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype) # type: ignore[attr-defined] - elif self.device == 'xla': + torch.hpu.set_autocast_hpu_enabled(self.prev) # type: ignore[attr-defined] + torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype) # type: ignore[attr-defined] + elif self.device == "xla": if torch.autocast_decrement_nesting() == 0: torch.clear_autocast_cache() - torch.set_autocast_xla_enabled(self.prev) # type: ignore[attr-defined] - torch.set_autocast_xla_dtype(self.prev_fastdtype) # type: ignore[attr-defined] + torch.set_autocast_xla_enabled(self.prev) # type: ignore[attr-defined] + torch.set_autocast_xla_dtype(self.prev_fastdtype) # type: ignore[attr-defined] elif self.device == self.custom_backend_name: if torch.autocast_decrement_nesting() == 0: torch.clear_autocast_cache() @@ -377,13 +414,16 @@ class autocast: return func return autocast_decorator(self, func) + # These functions aren't meant for public usage. # They are what we trace into a graph during pre_dispatch tracing # when we encounter an autocast context manager. def _enter_autocast(*vals): # For pre-dispatch tracing, if a TorchFunction mode is active, we'll want to trace this into a graph. if torch._C._is_torch_function_mode_enabled(): - return torch.overrides.handle_torch_function(torch.amp._enter_autocast, [], *vals) + return torch.overrides.handle_torch_function( + torch.amp._enter_autocast, [], *vals + ) mode = torch.amp.autocast(*vals) mode.__enter__() return mode diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index dd814990e6fc..269cb724121a 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -6,34 +6,39 @@ for which gradients should be computed with the ``requires_grad=True`` keyword. As of now, we only support autograd for floating point :class:`Tensor` types ( half, float, double and bfloat16) and complex :class:`Tensor` types (cfloat, cdouble). """ -import torch import warnings +from typing import Any, Callable, cast, List, Optional, Sequence, Tuple, Union -from torch.types import _TensorOrTensors, _size -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast +import torch + +from torch.types import _size, _TensorOrTensors +from .. import _vmap_internals +from ..overrides import handle_torch_function, has_torch_function, is_tensor_like +from . import forward_ad, functional, graph +from .anomaly_mode import detect_anomaly, set_detect_anomaly +from .function import Function, NestedIOFunction +from .grad_mode import ( + _force_original_view_tracking, + _unsafe_preserve_version_counter, + enable_grad, + inference_mode, + no_grad, + set_grad_enabled, + set_multithreading_enabled, +) +from .gradcheck import gradcheck, gradgradcheck from .variable import Variable -from .function import Function, NestedIOFunction -from .gradcheck import gradcheck, gradgradcheck -from .grad_mode import ( - no_grad, enable_grad, set_grad_enabled, inference_mode, set_multithreading_enabled, _force_original_view_tracking, - _unsafe_preserve_version_counter -) -from .anomaly_mode import detect_anomaly, set_detect_anomaly -from ..overrides import has_torch_function, handle_torch_function, is_tensor_like -from . import functional -from . import forward_ad -from . import graph -from .. import _vmap_internals -__all__ = ['Variable', 'Function', 'backward', 'grad_mode'] +__all__ = ["Variable", "Function", "backward", "grad_mode"] _OptionalTensor = Optional[torch.Tensor] _ShapeorNestedShape = Union[_size, Sequence[_size], torch.Tensor] -def _calculate_shape(output: torch.Tensor, grad: torch.Tensor, - is_grads_batched: bool) -> Tuple[_ShapeorNestedShape, _ShapeorNestedShape]: +def _calculate_shape( + output: torch.Tensor, grad: torch.Tensor, is_grads_batched: bool +) -> Tuple[_ShapeorNestedShape, _ShapeorNestedShape]: # is_same_size ensures that both tensors are either nested or non nested if output.is_nested: if is_grads_batched: @@ -47,63 +52,97 @@ def _calculate_shape(output: torch.Tensor, grad: torch.Tensor, reg_grad_shape = grad.shape if not is_grads_batched else grad.shape[1:] return reg_out_shape, reg_grad_shape -def _make_grads(outputs: Sequence[torch.Tensor], grads: Sequence[_OptionalTensor], - is_grads_batched: bool) -> Tuple[_OptionalTensor, ...]: + +def _make_grads( + outputs: Sequence[torch.Tensor], + grads: Sequence[_OptionalTensor], + is_grads_batched: bool, +) -> Tuple[_OptionalTensor, ...]: new_grads: List[_OptionalTensor] = [] for out, grad in zip(outputs, grads): if isinstance(grad, torch.Tensor): first_grad = grad if not is_grads_batched else grad[0] if not torch.is_same_size(out, first_grad): - out_shape, grad_shape = _calculate_shape(out, first_grad, is_grads_batched) + out_shape, grad_shape = _calculate_shape( + out, first_grad, is_grads_batched + ) if is_grads_batched: - raise RuntimeError("If `is_grads_batched=True`, we interpret the first " - "dimension of each grad_output as the batch dimension. " - "The sizes of the remaining dimensions are expected to match " - "the shape of corresponding output, but a mismatch " - "was detected: grad_output[" - + str(grads.index(grad)) + "] has a shape of " - + str(grad_shape) + " and output[" - + str(outputs.index(out)) + "] has a shape of " - + str(out_shape) + ". " - "If you only want some tensors in `grad_output` to be considered " - "batched, consider using vmap.") + raise RuntimeError( + "If `is_grads_batched=True`, we interpret the first " + "dimension of each grad_output as the batch dimension. " + "The sizes of the remaining dimensions are expected to match " + "the shape of corresponding output, but a mismatch " + "was detected: grad_output[" + + str(grads.index(grad)) + + "] has a shape of " + + str(grad_shape) + + " and output[" + + str(outputs.index(out)) + + "] has a shape of " + + str(out_shape) + + ". " + "If you only want some tensors in `grad_output` to be considered " + "batched, consider using vmap." + ) else: - raise RuntimeError("Mismatch in shape: grad_output[" - + str(grads.index(grad)) + "] has a shape of " - + str(grad_shape) + " and output[" - + str(outputs.index(out)) + "] has a shape of " - + str(out_shape) + ".") + raise RuntimeError( + "Mismatch in shape: grad_output[" + + str(grads.index(grad)) + + "] has a shape of " + + str(grad_shape) + + " and output[" + + str(outputs.index(out)) + + "] has a shape of " + + str(out_shape) + + "." + ) if out.dtype.is_complex != grad.dtype.is_complex: - raise RuntimeError("For complex Tensors, both grad_output and output" - " are required to have the same dtype." - " Mismatch in dtype: grad_output[" - + str(grads.index(grad)) + "] has a dtype of " - + str(grad.dtype) + " and output[" - + str(outputs.index(out)) + "] has a dtype of " - + str(out.dtype) + ".") + raise RuntimeError( + "For complex Tensors, both grad_output and output" + " are required to have the same dtype." + " Mismatch in dtype: grad_output[" + + str(grads.index(grad)) + + "] has a dtype of " + + str(grad.dtype) + + " and output[" + + str(outputs.index(out)) + + "] has a dtype of " + + str(out.dtype) + + "." + ) new_grads.append(grad) elif grad is None: if out.requires_grad: if out.numel() != 1: - raise RuntimeError("grad can be implicitly created only for scalar outputs") + raise RuntimeError( + "grad can be implicitly created only for scalar outputs" + ) if not out.dtype.is_floating_point: - msg = ("grad can be implicitly created only for real scalar outputs" - f" but got {out.dtype}") + msg = ( + "grad can be implicitly created only for real scalar outputs" + f" but got {out.dtype}" + ) raise RuntimeError(msg) - new_grads.append(torch.ones_like(out, memory_format=torch.preserve_format)) + new_grads.append( + torch.ones_like(out, memory_format=torch.preserve_format) + ) else: new_grads.append(None) else: - raise TypeError("gradients can be either Tensors or None, but got " + - type(grad).__name__) + raise TypeError( + "gradients can be either Tensors or None, but got " + + type(grad).__name__ + ) return tuple(new_grads) -def _tensor_or_tensors_to_tuple(tensors: Optional[_TensorOrTensors], length: int) -> Tuple[_OptionalTensor, ...]: +def _tensor_or_tensors_to_tuple( + tensors: Optional[_TensorOrTensors], length: int +) -> Tuple[_OptionalTensor, ...]: if tensors is None: - return (None, ) * length + return (None,) * length if isinstance(tensors, torch.Tensor): - return (tensors, ) + return (tensors,) return tuple(tensors) @@ -176,22 +215,30 @@ def backward( raise RuntimeError( "backward() called inside a functorch transform. This is not " "supported, please use functorch.grad or functorch.vjp instead " - "or call backward() outside of functorch transforms.") + "or call backward() outside of functorch transforms." + ) if grad_variables is not None: warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.") if grad_tensors is None: grad_tensors = grad_variables else: - raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) " - "arguments both passed to backward(). Please only " - "use 'grad_tensors'.") + raise RuntimeError( + "'grad_tensors' and 'grad_variables' (deprecated) " + "arguments both passed to backward(). Please only " + "use 'grad_tensors'." + ) if inputs is not None and len(inputs) == 0: raise RuntimeError("'inputs' argument to backward() cannot be empty.") tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors) - inputs = (inputs,) if isinstance(inputs, torch.Tensor) else \ - tuple(inputs) if inputs is not None else tuple() + inputs = ( + (inputs,) + if isinstance(inputs, torch.Tensor) + else tuple(inputs) + if inputs is not None + else tuple() + ) grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors)) grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False) @@ -202,8 +249,15 @@ def backward( # some Python versions print out the first line of a multi-line function # calls in the traceback and some print out the last line Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass - tensors, grad_tensors_, retain_graph, create_graph, inputs, - allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass + tensors, + grad_tensors_, + retain_graph, + create_graph, + inputs, + allow_unreachable=True, + accumulate_grad=True, + ) # Calls into the C++ engine to run the backward pass + def grad( outputs: _TensorOrTensors, @@ -273,12 +327,19 @@ def grad( """ if materialize_grads and allow_unused is False: - raise ValueError("Expected allow_unused to be True or not passed when materialize_grads=True, " - "but got: allow_unused=False.") + raise ValueError( + "Expected allow_unused to be True or not passed when materialize_grads=True, " + "but got: allow_unused=False." + ) if allow_unused is None: allow_unused = materialize_grads - t_outputs = cast(Tuple[torch.Tensor, ...], (outputs,) if is_tensor_like(outputs) else tuple(outputs)) - t_inputs = cast(Tuple[torch.Tensor, ...], (inputs,) if is_tensor_like(inputs) else tuple(inputs)) + t_outputs = cast( + Tuple[torch.Tensor, ...], + (outputs,) if is_tensor_like(outputs) else tuple(outputs), + ) + t_inputs = cast( + Tuple[torch.Tensor, ...], (inputs,) if is_tensor_like(inputs) else tuple(inputs) + ) overridable_args = t_outputs + t_inputs if has_torch_function(overridable_args): return handle_torch_function( @@ -296,12 +357,16 @@ def grad( ) if not only_inputs: - warnings.warn("only_inputs argument is deprecated and is ignored now " - "(defaults to True). To accumulate gradient for other " - "parts of the graph, please use torch.autograd.backward.") + warnings.warn( + "only_inputs argument is deprecated and is ignored now " + "(defaults to True). To accumulate gradient for other " + "parts of the graph, please use torch.autograd.backward." + ) grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(t_outputs)) - grad_outputs_ = _make_grads(t_outputs, grad_outputs_, is_grads_batched=is_grads_batched) + grad_outputs_ = _make_grads( + t_outputs, grad_outputs_, is_grads_batched=is_grads_batched + ) if retain_graph is None: retain_graph = create_graph @@ -310,18 +375,38 @@ def grad( # some Python versions print out the first line of multi-line function # calls in the traceback and some print out the last line if is_grads_batched: + def vjp(gO): return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass - t_outputs, gO, retain_graph, create_graph, t_inputs, - allow_unused, accumulate_grad=False) # Calls into the C++ engine to run the backward pass - result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(grad_outputs_) + t_outputs, + gO, + retain_graph, + create_graph, + t_inputs, + allow_unused, + accumulate_grad=False, + ) # Calls into the C++ engine to run the backward pass + + result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)( + grad_outputs_ + ) else: result = Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass - t_outputs, grad_outputs_, retain_graph, create_graph, t_inputs, - allow_unused, accumulate_grad=False) # Calls into the C++ engine to run the backward pass + t_outputs, + grad_outputs_, + retain_graph, + create_graph, + t_inputs, + allow_unused, + accumulate_grad=False, + ) # Calls into the C++ engine to run the backward pass if materialize_grads: - result = tuple(output if output is not None else torch.zeros_like(input, requires_grad=True) - for (output, input) in zip(result, t_inputs)) + result = tuple( + output + if output is not None + else torch.zeros_like(input, requires_grad=True) + for (output, input) in zip(result, t_inputs) + ) return result @@ -343,7 +428,10 @@ def _is_checkpoint_valid(): def variable(*args, **kwargs): - raise RuntimeError("torch.autograd.variable(...) is deprecated, use torch.tensor(...) instead") + raise RuntimeError( + "torch.autograd.variable(...) is deprecated, use torch.tensor(...) instead" + ) + # Monkey patching variable.Variable to fix FX codegen. FX generates a call by roughly doing # f"{fn.__module__}.{fn.__name__}(...). This yields torch.autograd.variable.Variable(...) in the @@ -383,6 +471,7 @@ from torch._C._profiler import ProfilerActivity, ProfilerConfig, ProfilerState from . import profiler + def _register_py_tensor_class_for_device(device, cls): if not isinstance(cls, type): raise RuntimeError("cls isn't a typeinfo object") @@ -390,4 +479,6 @@ def _register_py_tensor_class_for_device(device, cls): is_multithreading_enabled = torch._C._is_multithreading_enabled -torch._C._add_docstr(is_multithreading_enabled, "Returns True if multithreading is currently enabled.") +torch._C._add_docstr( + is_multithreading_enabled, "Returns True if multithreading is currently enabled." +) diff --git a/torch/autograd/_functions/tensor.py b/torch/autograd/_functions/tensor.py index 8e334d5dc40e..6f3f45c5ad65 100644 --- a/torch/autograd/_functions/tensor.py +++ b/torch/autograd/_functions/tensor.py @@ -1,16 +1,18 @@ -from functools import reduce import warnings +from functools import reduce + import torch import torch._utils from ..function import Function class Type(Function): - @staticmethod def forward(ctx, i, dest_type): - warnings.warn("torch.autograd._functions.Type is deprecated as of PyTorch 2.1, please use " - "torch.tensor.to(dtype=dtype) instead.") + warnings.warn( + "torch.autograd._functions.Type is deprecated as of PyTorch 2.1, please use " + "torch.tensor.to(dtype=dtype) instead." + ) ctx.input_type = type(i) ctx.input_device = -1 if not i.is_cuda else i.get_device() return i.type(dest_type) @@ -26,18 +28,24 @@ class Type(Function): # TODO: deprecate this class Resize(Function): - @staticmethod def forward(ctx, tensor, sizes): ctx.sizes = sizes ctx.numel = reduce(lambda x, y: x * y, sizes, 1) if tensor.numel() != ctx.numel: - raise RuntimeError(("requested resize to {} ({} elements in total), " - "but the given tensor has a size of {} ({} elements). " - "autograd's resize can only change the shape of a given " - "tensor, while preserving the number of elements. ").format( - 'x'.join(map(str, sizes)), ctx.numel, - 'x'.join(map(str, tensor.size())), tensor.numel())) + raise RuntimeError( + ( + "requested resize to {} ({} elements in total), " + "but the given tensor has a size of {} ({} elements). " + "autograd's resize can only change the shape of a given " + "tensor, while preserving the number of elements. " + ).format( + "x".join(map(str, sizes)), + ctx.numel, + "x".join(map(str, tensor.size())), + tensor.numel(), + ) + ) ctx.input_sizes = tensor.size() if tensor.is_quantized: tensor.copy_(tensor) diff --git a/torch/autograd/_functions/utils.py b/torch/autograd/_functions/utils.py index 89e88e4af39a..735b6240a49b 100644 --- a/torch/autograd/_functions/utils.py +++ b/torch/autograd/_functions/utils.py @@ -11,9 +11,13 @@ def maybe_unexpand(tensor, old_size, check_same_size=True): if check_same_size and tensor.size() == old_size: return tensor num_unsqueezed = tensor.dim() - len(old_size) - expanded_dims = [dim for dim, (expanded, original) - in enumerate(zip(tensor.size()[num_unsqueezed:], old_size)) - if expanded != original] + expanded_dims = [ + dim + for dim, (expanded, original) in enumerate( + zip(tensor.size()[num_unsqueezed:], old_size) + ) + if expanded != original + ] for _ in range(num_unsqueezed): tensor = tensor.sum(0, keepdim=False) @@ -42,7 +46,7 @@ def check_onnx_broadcast(dims1, dims2): supported = False elif len1 > len2: broadcast = True - if numel2 != 1 and dims1[len1 - len2:] != dims2: + if numel2 != 1 and dims1[len1 - len2 :] != dims2: supported = False else: if dims1 != dims2: @@ -51,5 +55,7 @@ def check_onnx_broadcast(dims1, dims2): supported = False if not supported: - raise ValueError(f"Numpy style broadcasting is not supported in ONNX. Input dims are: {dims1}, {dims2}") + raise ValueError( + f"Numpy style broadcasting is not supported in ONNX. Input dims are: {dims1}, {dims2}" + ) return broadcast diff --git a/torch/autograd/anomaly_mode.py b/torch/autograd/anomaly_mode.py index c0eb56f234f2..5bfd4ec8227a 100644 --- a/torch/autograd/anomaly_mode.py +++ b/torch/autograd/anomaly_mode.py @@ -1,8 +1,9 @@ -import torch import warnings from typing import Any +import torch + __all__ = ["detect_anomaly", "set_detect_anomaly"] @@ -77,9 +78,12 @@ class detect_anomaly: self.prev = torch.is_anomaly_enabled() self.check_nan = check_nan self.prev_check_nan = torch.is_anomaly_check_nan_enabled() - warnings.warn('Anomaly Detection has been enabled. ' - 'This mode will increase the runtime ' - 'and should only be enabled for debugging.', stacklevel=2) + warnings.warn( + "Anomaly Detection has been enabled. " + "This mode will increase the runtime " + "and should only be enabled for debugging.", + stacklevel=2, + ) def __enter__(self) -> None: torch.set_anomaly_enabled(True, self.check_nan) diff --git a/torch/autograd/forward_ad.py b/torch/autograd/forward_ad.py index 440497bea35f..5c0be8e89d7f 100644 --- a/torch/autograd/forward_ad.py +++ b/torch/autograd/forward_ad.py @@ -1,11 +1,19 @@ -import torch import os -from .grad_mode import _DecoratorContextManager from collections import namedtuple from typing import Any -__all__ = ["UnpackedDualTensor", "enter_dual_level", "exit_dual_level", "make_dual", "unpack_dual", "dual_level"] +import torch +from .grad_mode import _DecoratorContextManager + +__all__ = [ + "UnpackedDualTensor", + "enter_dual_level", + "exit_dual_level", + "make_dual", + "unpack_dual", + "dual_level", +] # Global variable used to make the python API simpler to use _current_level = -1 @@ -22,8 +30,10 @@ def enter_dual_level(): global _current_level new_level = torch._C._enter_dual_level() if new_level != _current_level + 1: - raise RuntimeError("Entering a new forward AD level but the current level " - "is not valid. Make sure you did not modified it directly.") + raise RuntimeError( + "Entering a new forward AD level but the current level " + "is not valid. Make sure you did not modified it directly." + ) _current_level = new_level return new_level @@ -40,8 +50,10 @@ def exit_dual_level(*, level=None): if level is None: level = _current_level if level != _current_level: - raise RuntimeError("Trying to exit a forward AD level that was not the last one " - "that was created. This is not supported.") + raise RuntimeError( + "Trying to exit a forward AD level that was not the last one " + "that was created. This is not supported." + ) torch._C._exit_dual_level(level=level) _current_level = level - 1 @@ -93,16 +105,23 @@ def make_dual(tensor, tangent, *, level=None): level = _current_level if level < 0: - raise RuntimeError("Trying to create a dual Tensor for forward AD but no level " - "exists, make sure to enter_dual_level() first.") + raise RuntimeError( + "Trying to create a dual Tensor for forward AD but no level " + "exists, make sure to enter_dual_level() first." + ) if not (tensor.is_floating_point() or tensor.is_complex()): - raise ValueError(f"Expected primal to be floating point or complex, but got: {tensor.dtype}") + raise ValueError( + f"Expected primal to be floating point or complex, but got: {tensor.dtype}" + ) if not (tangent.is_floating_point() or tangent.is_complex()): - raise ValueError(f"Expected tangent to be floating point or complex, but got: {tangent.dtype}") + raise ValueError( + f"Expected tangent to be floating point or complex, but got: {tangent.dtype}" + ) return torch._VF._make_dual(tensor, tangent, level=level) -_UnpackedDualTensor = namedtuple('_UnpackedDualTensor', ['primal', 'tangent']) + +_UnpackedDualTensor = namedtuple("_UnpackedDualTensor", ["primal", "tangent"]) class UnpackedDualTensor(_UnpackedDualTensor): @@ -176,15 +195,18 @@ class dual_level(_DecoratorContextManager): Please see the `forward-mode AD tutorial `__ for detailed steps on how to use this API. """ + def __enter__(self): return enter_dual_level() def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: exit_dual_level() + # Private helper functions _is_fwd_grad_enabled = torch._C._is_fwd_grad_enabled + # Private helper function to enable or disable fwd grad. # If you're a user and want to use this, please file an issue to discuss the use case. class _set_fwd_grad_enabled(_DecoratorContextManager): diff --git a/torch/autograd/function.py b/torch/autograd/function.py index 954dee2e591f..77e8e56763be 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -1,20 +1,29 @@ -import torch -import torch._C as _C -from torch._C import _functions -import torch._functorch as _functorch -import torch.utils.hooks as hooks import functools import warnings from collections import OrderedDict from typing import Any, List, Optional, Tuple + +import torch +import torch._C as _C +import torch._functorch as _functorch +import torch.utils.hooks as hooks +from torch._C import _functions from torch._functorch.autograd_function import custom_function_call -__all__ = ["FunctionCtx", "BackwardCFunction", "FunctionMeta", "Function", "once_differentiable", "traceable", - "InplaceFunction", "NestedIOFunction"] +__all__ = [ + "FunctionCtx", + "BackwardCFunction", + "FunctionMeta", + "Function", + "once_differentiable", + "traceable", + "InplaceFunction", + "NestedIOFunction", +] + # Formerly known as: _ContextMethodMixin class FunctionCtx: - def save_for_backward(self, *tensors: torch.Tensor): r"""Saves given tensors for a future call to :func:`~Function.backward`. @@ -122,7 +131,8 @@ class FunctionCtx: for tensor in tensors: assert isinstance(tensor, torch.Tensor) or tensor is None, ( "save_for_forward expects all arguments to be tensors; you should " - "save non-tensors as attributes on ctx.") + "save non-tensors as attributes on ctx." + ) self.saved_for_forward = tensors @@ -165,9 +175,10 @@ class FunctionCtx: def mark_shared_storage(self, *pairs): warnings.warn( - 'mark_shared_storage is deprecated. ' - 'Tensors with shared storages are automatically tracked. Note ' - 'that calls to `set_()` are not tracked') + "mark_shared_storage is deprecated. " + "Tensors with shared storages are automatically tracked. Note " + "that calls to `set_()` are not tracked" + ) def mark_non_differentiable(self, *args: torch.Tensor): r"""Marks outputs as non-differentiable. @@ -246,11 +257,12 @@ class FunctionCtx: """ self.materialize_grads = value + # DO NOT USE: This is only defined to be able to load old serialized models _ContextMethodMixin = FunctionCtx -class _HookMixin: +class _HookMixin: @staticmethod def _register_hook(backward_hooks, hook): if backward_hooks is None: @@ -267,9 +279,11 @@ class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin): backward_fn = self._forward_cls.backward # type: ignore[attr-defined] vjp_fn = self._forward_cls.vjp # type: ignore[attr-defined] if backward_fn is not Function.backward and vjp_fn is not Function.vjp: - raise RuntimeError("Implementing both 'backward' and 'vjp' for a custom " - "Function is not allowed. You should only implement one " - "of them.") + raise RuntimeError( + "Implementing both 'backward' and 'vjp' for a custom " + "Function is not allowed. You should only implement one " + "of them." + ) user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn return user_fn(self, *args) @@ -289,14 +303,19 @@ class FunctionMeta(type): version of this function (which is generated on the fly by this metaclass). """ + def __init__(cls, name, bases, attrs): - backward_fn = type(name + 'Backward', (BackwardCFunction,), {'_forward_cls': cls}) + backward_fn = type( + name + "Backward", (BackwardCFunction,), {"_forward_cls": cls} + ) cls._backward_cls = backward_fn super().__init__(name, bases, attrs) -class _SingleLevelFunction(_C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta): +class _SingleLevelFunction( + _C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta +): @staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: r""" @@ -338,8 +357,9 @@ class _SingleLevelFunction(_C._FunctionBase, FunctionCtx, _HookMixin, metaclass= ``backward`` (equivalently, ``vjp``) or :func:`ctx.save_for_forward` if they are intended to be used for in ``jvp``. """ - raise NotImplementedError("You must implement the forward function for custom" - " autograd.Function.") + raise NotImplementedError( + "You must implement the forward function for custom" " autograd.Function." + ) @staticmethod def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> Any: @@ -381,9 +401,11 @@ class _SingleLevelFunction(_C._FunctionBase, FunctionCtx, _HookMixin, metaclass= first input to :func:`forward` needs gradient computed w.r.t. the output. """ - raise NotImplementedError("You must implement either the backward or vjp method for " - "your custom autograd.Function to use it with backward " - "mode AD.") + raise NotImplementedError( + "You must implement either the backward or vjp method for " + "your custom autograd.Function to use it with backward " + "mode AD." + ) # vjp and backward are alias of each other vjp = backward @@ -406,8 +428,10 @@ class _SingleLevelFunction(_C._FunctionBase, FunctionCtx, _HookMixin, metaclass= You can use the :attr:`ctx` object to pass any value from the forward to this functions. """ - raise NotImplementedError("You must implement the jvp function for custom " - "autograd.Function to use it with forward mode AD.") + raise NotImplementedError( + "You must implement the jvp function for custom " + "autograd.Function to use it with forward mode AD." + ) class Function(_SingleLevelFunction): @@ -443,18 +467,23 @@ class Function(_SingleLevelFunction): >>> # xdoctest: +SKIP >>> output = Exp.apply(input) """ + def __init__(self, *args, **kwargs): cls = self.__class__ - warnings.warn(f"{cls} should not be instantiated. Methods on autograd functions" - "are all static, so you should invoke them on the class itself. " - "Instantiating an autograd function will raise an " - "error in a future version of PyTorch.", DeprecationWarning) + warnings.warn( + f"{cls} should not be instantiated. Methods on autograd functions" + "are all static, so you should invoke them on the class itself. " + "Instantiating an autograd function will raise an " + "error in a future version of PyTorch.", + DeprecationWarning, + ) def __call__(self, *args, **kwargs): raise RuntimeError( "Legacy autograd function with non-static forward method is deprecated. " "Please use new-style autograd function with static forward method. " - "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)") + "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)" + ) # for the tracer is_traceable = False @@ -499,7 +528,8 @@ class Function(_SingleLevelFunction): """ raise NotImplementedError( "To use autograd.Function with vmap, you must either override the " - "vmap staticmethod or set generate_vmap_rule=True.") + "vmap staticmethod or set generate_vmap_rule=True." + ) @classmethod def apply(cls, *args, **kwargs): @@ -510,15 +540,16 @@ class Function(_SingleLevelFunction): if cls.setup_context == _SingleLevelFunction.setup_context: raise RuntimeError( - 'In order to use an autograd.Function with functorch transforms ' - '(vmap, grad, jvp, jacrev, ...), it must override the setup_context ' - 'staticmethod. For more details, please see ' - 'https://pytorch.org/docs/master/notes/extending.func.html') + "In order to use an autograd.Function with functorch transforms " + "(vmap, grad, jvp, jacrev, ...), it must override the setup_context " + "staticmethod. For more details, please see " + "https://pytorch.org/docs/master/notes/extending.func.html" + ) return custom_function_call(cls, *args, **kwargs) -def once_differentiable(fn): +def once_differentiable(fn): @functools.wraps(fn) def wrapper(ctx, *args): with torch.no_grad(): @@ -536,8 +567,9 @@ def once_differentiable(fn): # Unfortunately, this leads to unexpected error messages ("no nodes # require computing gradients"), but I don't have a better idea. # These functions would raise an error in backward anyway. - requires_grad = any(isinstance(arg, torch.Tensor) and arg.requires_grad - for arg in args) + requires_grad = any( + isinstance(arg, torch.Tensor) and arg.requires_grad for arg in args + ) if not requires_grad: return outputs @@ -546,7 +578,9 @@ def once_differentiable(fn): err_fn = _functions.DelayedError( b"trying to differentiate twice a function that was marked " - b"with @once_differentiable", len(outputs)) + b"with @once_differentiable", + len(outputs), + ) # Create aliases of each output that has requires_grad=True. We need # at least one of the inputs to err_fn to require grad so that the @@ -558,6 +592,7 @@ def once_differentiable(fn): return var return err_fn(*[fake_requires_grad(v) for v in outputs]) + return wrapper @@ -577,7 +612,6 @@ def traceable(fn_cls): class InplaceFunction(Function): - def __init__(self, inplace=False): super().__init__() self.inplace = inplace @@ -591,18 +625,23 @@ def _nested_map(condition, fn, condition_msg=None): return None elif isinstance(obj, (list, tuple)): mapped = (_map(x) for x in obj) - if hasattr(obj, '_fields'): + if hasattr(obj, "_fields"): # obj is namedtuple return type(obj)(*mapped) return type(obj)(mapped) elif isinstance(obj, dict): - return {x : _map(obj[x]) for x in obj} + return {x: _map(obj[x]) for x in obj} else: - raise ValueError("Auto nesting doesn't know how to process " - "an input object of type " + torch.typename(obj) + - (". Accepted types: " + condition_msg + - ", or lists/tuples of them" - if condition_msg else "")) + raise ValueError( + "Auto nesting doesn't know how to process " + "an input object of type " + + torch.typename(obj) + + ( + ". Accepted types: " + condition_msg + ", or lists/tuples of them" + if condition_msg + else "" + ) + ) return _map @@ -613,8 +652,7 @@ def _jit_unwrap_structured(obj): return obj -def _iter_filter(condition, allow_unknown=False, condition_msg=None, - conversion=None): +def _iter_filter(condition, allow_unknown=False, condition_msg=None, conversion=None): def _iter(obj): if conversion is not None: obj = conversion(obj) @@ -632,11 +670,16 @@ def _iter_filter(condition, allow_unknown=False, condition_msg=None, elif allow_unknown: yield obj else: - raise ValueError("Auto nesting doesn't know how to process " - "an input object of type " + torch.typename(obj) + - (". Accepted types: " + condition_msg + - ", or lists/tuples of them" - if condition_msg else "")) + raise ValueError( + "Auto nesting doesn't know how to process " + "an input object of type " + + torch.typename(obj) + + ( + ". Accepted types: " + condition_msg + ", or lists/tuples of them" + if condition_msg + else "" + ) + ) return _iter @@ -661,17 +704,26 @@ def _unflatten(input, proto): return unflatten_helper(input, proto)[0] -_iter_jit_values = _iter_filter(lambda o: o is None or isinstance(o, torch._C.Value), - condition_msg="jit's Values or None") -_iter_tensors = _iter_filter(lambda x: isinstance(x, torch.Tensor), condition_msg="Tensors", - conversion=_jit_unwrap_structured) -_iter_tensors_permissive = _iter_filter(lambda x: isinstance(x, torch.Tensor), - allow_unknown=True, - condition_msg="Tensors (permissive)") -_iter_None_tensors = _iter_filter(lambda o: o is None or isinstance(o, torch.Tensor), - condition_msg="Tensors or None") -_map_tensor_data = _nested_map(lambda x: isinstance(x, torch.Tensor), lambda o: o.data, - condition_msg="Tensors") +_iter_jit_values = _iter_filter( + lambda o: o is None or isinstance(o, torch._C.Value), + condition_msg="jit's Values or None", +) +_iter_tensors = _iter_filter( + lambda x: isinstance(x, torch.Tensor), + condition_msg="Tensors", + conversion=_jit_unwrap_structured, +) +_iter_tensors_permissive = _iter_filter( + lambda x: isinstance(x, torch.Tensor), + allow_unknown=True, + condition_msg="Tensors (permissive)", +) +_iter_None_tensors = _iter_filter( + lambda o: o is None or isinstance(o, torch.Tensor), condition_msg="Tensors or None" +) +_map_tensor_data = _nested_map( + lambda x: isinstance(x, torch.Tensor), lambda o: o.data, condition_msg="Tensors" +) class NestedIOFunction(Function): diff --git a/torch/autograd/functional.py b/torch/autograd/functional.py index 109da01fb755..755494a88ade 100644 --- a/torch/autograd/functional.py +++ b/torch/autograd/functional.py @@ -1,7 +1,8 @@ +from typing import List, Tuple + import torch -from typing import Tuple, List -from . import forward_ad as fwAD from torch._vmap_internals import _vmap +from . import forward_ad as fwAD __all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"] @@ -14,7 +15,7 @@ def _as_tuple_nocheck(x): elif isinstance(x, list): return tuple(x) else: - return x, + return (x,) def _as_tuple(inp, arg_name=None, fn_name=None): @@ -31,11 +32,19 @@ def _as_tuple(inp, arg_name=None, fn_name=None): for i, el in enumerate(inp): if not isinstance(el, torch.Tensor): if is_inp_tuple: - raise TypeError("The {} given to {} must be either a Tensor or a tuple of Tensors but the" - " value at index {} has type {}.".format(arg_name, fn_name, i, type(el))) + raise TypeError( + "The {} given to {} must be either a Tensor or a tuple of Tensors but the" + " value at index {} has type {}.".format( + arg_name, fn_name, i, type(el) + ) + ) else: - raise TypeError("The {} given to {} must be either a Tensor or a tuple of Tensors but the" - " given {} has type {}.".format(arg_name, fn_name, arg_name, type(el))) + raise TypeError( + "The {} given to {} must be either a Tensor or a tuple of Tensors but the" + " given {} has type {}.".format( + arg_name, fn_name, arg_name, type(el) + ) + ) return is_inp_tuple, inp @@ -98,7 +107,9 @@ def _validate_v(v, other, is_other_tuple): # Both are assumed to be tuples of Tensors if len(other) != len(v): if is_other_tuple: - raise RuntimeError(f"v is a tuple of invalid length: should be {len(other)} but got {len(v)}.") + raise RuntimeError( + f"v is a tuple of invalid length: should be {len(other)} but got {len(v)}." + ) else: raise RuntimeError("The given v should contain a single Tensor.") @@ -107,7 +118,9 @@ def _validate_v(v, other, is_other_tuple): prepend = "" if is_other_tuple: prepend = f"Entry {idx} in " - raise RuntimeError(f"{prepend}v has invalid size: should be {el_other.size()} but got {el_v.size()}.") + raise RuntimeError( + f"{prepend}v has invalid size: should be {el_other.size()} but got {el_v.size()}." + ) def _check_requires_grad(inputs, input_type, strict): @@ -120,30 +133,49 @@ def _check_requires_grad(inputs, input_type, strict): for i, inp in enumerate(inputs): if inp is None: # This can only be reached for grad_inputs. - raise RuntimeError("The output of the user-provided function is independent of input {}." - " This is not allowed in strict mode.".format(i)) + raise RuntimeError( + "The output of the user-provided function is independent of input {}." + " This is not allowed in strict mode.".format(i) + ) if not inp.requires_grad: if input_type == "hessian": - raise RuntimeError("The hessian of the user-provided function with respect to input {}" - " is independent of the input. This is not allowed in strict mode." - " You should ensure that your function is thrice differentiable and that" - " the hessian depends on the inputs.".format(i)) + raise RuntimeError( + "The hessian of the user-provided function with respect to input {}" + " is independent of the input. This is not allowed in strict mode." + " You should ensure that your function is thrice differentiable and that" + " the hessian depends on the inputs.".format(i) + ) elif input_type == "jacobian": - raise RuntimeError("While computing the hessian, found that the jacobian of the user-provided" - " function with respect to input {} is independent of the input. This is not" - " allowed in strict mode. You should ensure that your function is twice" - " differentiable and that the jacobian depends on the inputs (this would be" - " violated by a linear function for example).".format(i)) + raise RuntimeError( + "While computing the hessian, found that the jacobian of the user-provided" + " function with respect to input {} is independent of the input. This is not" + " allowed in strict mode. You should ensure that your function is twice" + " differentiable and that the jacobian depends on the inputs (this would be" + " violated by a linear function for example).".format(i) + ) elif input_type == "grad_inputs": - raise RuntimeError("The gradient with respect to input {} is independent of the inputs of the" - " user-provided function. This is not allowed in strict mode.".format(i)) + raise RuntimeError( + "The gradient with respect to input {} is independent of the inputs of the" + " user-provided function. This is not allowed in strict mode.".format( + i + ) + ) else: - raise RuntimeError("Output {} of the user-provided function does not require gradients." - " The outputs must be computed in a differentiable manner from the input" - " when running in strict mode.".format(i)) + raise RuntimeError( + "Output {} of the user-provided function does not require gradients." + " The outputs must be computed in a differentiable manner from the input" + " when running in strict mode.".format(i) + ) -def _autograd_grad(outputs, inputs, grad_outputs=None, create_graph=False, retain_graph=None, is_grads_batched=False): +def _autograd_grad( + outputs, + inputs, + grad_outputs=None, + create_graph=False, + retain_graph=None, + is_grads_batched=False, +): # Version of autograd.grad that accepts `None` in outputs and do not compute gradients for them. # This has the extra constraint that inputs has to be a tuple assert isinstance(outputs, tuple) @@ -163,9 +195,15 @@ def _autograd_grad(outputs, inputs, grad_outputs=None, create_graph=False, retai # No differentiable output, we don't need to call the autograd engine return (None,) * len(inputs) else: - return torch.autograd.grad(new_outputs, inputs, new_grad_outputs, allow_unused=True, - create_graph=create_graph, retain_graph=retain_graph, - is_grads_batched=is_grads_batched) + return torch.autograd.grad( + new_outputs, + inputs, + new_grad_outputs, + allow_unused=True, + create_graph=create_graph, + retain_graph=retain_graph, + is_grads_batched=is_grads_batched, + ) def _fill_in_zeros(grads, refs, strict, create_graph, stage): @@ -181,30 +219,48 @@ def _fill_in_zeros(grads, refs, strict, create_graph, stage): if grads_i is None: if strict: if stage == "back": - raise RuntimeError("The output of the user-provided function is independent of " - "input {}. This is not allowed in strict mode.".format(i)) + raise RuntimeError( + "The output of the user-provided function is independent of " + "input {}. This is not allowed in strict mode.".format(i) + ) elif stage == "back_trick": - raise RuntimeError("The gradient with respect to the input is independent of entry {}" - " in the grad_outputs when using the double backward trick to compute" - " forward mode gradients. This is not allowed in strict mode.".format(i)) + raise RuntimeError( + "The gradient with respect to the input is independent of entry {}" + " in the grad_outputs when using the double backward trick to compute" + " forward mode gradients. This is not allowed in strict mode.".format( + i + ) + ) elif stage == "double_back": - raise RuntimeError("The jacobian of the user-provided function is independent of " - "input {}. This is not allowed in strict mode.".format(i)) + raise RuntimeError( + "The jacobian of the user-provided function is independent of " + "input {}. This is not allowed in strict mode.".format(i) + ) else: - raise RuntimeError("The hessian of the user-provided function is independent of " - "entry {} in the grad_jacobian. This is not allowed in strict " - "mode as it prevents from using the double backward trick to " - "replace forward mode AD.".format(i)) + raise RuntimeError( + "The hessian of the user-provided function is independent of " + "entry {} in the grad_jacobian. This is not allowed in strict " + "mode as it prevents from using the double backward trick to " + "replace forward mode AD.".format(i) + ) grads_i = torch.zeros_like(refs[i]) else: if strict and create_graph and not grads_i.requires_grad: if "double" not in stage: - raise RuntimeError("The jacobian of the user-provided function is independent of " - "input {}. This is not allowed in strict mode when create_graph=True.".format(i)) + raise RuntimeError( + "The jacobian of the user-provided function is independent of " + "input {}. This is not allowed in strict mode when create_graph=True.".format( + i + ) + ) else: - raise RuntimeError("The hessian of the user-provided function is independent of " - "input {}. This is not allowed in strict mode when create_graph=True.".format(i)) + raise RuntimeError( + "The hessian of the user-provided function is independent of " + "input {}. This is not allowed in strict mode when create_graph=True.".format( + i + ) + ) res += (grads_i,) @@ -213,6 +269,7 @@ def _fill_in_zeros(grads, refs, strict, create_graph, stage): # Public API + def vjp(func, inputs, v=None, create_graph=False, strict=False): r"""Function that computes the dot product between a vector ``v`` and the Jacobian of the given function at the point given by the inputs. @@ -279,7 +336,9 @@ def vjp(func, inputs, v=None, create_graph=False, strict=False): inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) outputs = func(*inputs) - is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "vjp") + is_outputs_tuple, outputs = _as_tuple( + outputs, "outputs of the user-provided function", "vjp" + ) _check_requires_grad(outputs, "outputs", strict=strict) if v is not None: @@ -288,9 +347,11 @@ def vjp(func, inputs, v=None, create_graph=False, strict=False): _validate_v(v, outputs, is_outputs_tuple) else: if len(outputs) != 1 or outputs[0].nelement() != 1: - raise RuntimeError("The vector v can only be None if the " - "user-provided function returns " - "a single Tensor with a single element.") + raise RuntimeError( + "The vector v can only be None if the " + "user-provided function returns " + "a single Tensor with a single element." + ) enable_grad = True if create_graph else torch.is_grad_enabled() with torch.set_grad_enabled(enable_grad): @@ -301,7 +362,9 @@ def vjp(func, inputs, v=None, create_graph=False, strict=False): outputs = _grad_postprocess(outputs, create_graph) vjp = _grad_postprocess(vjp, create_graph) - return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(vjp, is_inputs_tuple) + return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess( + vjp, is_inputs_tuple + ) def jvp(func, inputs, v=None, create_graph=False, strict=False): @@ -377,37 +440,51 @@ def jvp(func, inputs, v=None, create_graph=False, strict=False): _validate_v(v, inputs, is_inputs_tuple) else: if len(inputs) != 1 or inputs[0].nelement() != 1: - raise RuntimeError("The vector v can only be None if the input to " - "the user-provided function is a single Tensor " - "with a single element.") + raise RuntimeError( + "The vector v can only be None if the input to " + "the user-provided function is a single Tensor " + "with a single element." + ) outputs = func(*inputs) - is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "jvp") + is_outputs_tuple, outputs = _as_tuple( + outputs, "outputs of the user-provided function", "jvp" + ) _check_requires_grad(outputs, "outputs", strict=strict) # The backward is linear so the value of grad_outputs is not important as # it won't appear in the double backward graph. We only need to ensure that # it does not contain inf or nan. - grad_outputs = tuple(torch.zeros_like(out, requires_grad=True) for out in outputs) + grad_outputs = tuple( + torch.zeros_like(out, requires_grad=True) for out in outputs + ) grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True) _check_requires_grad(grad_inputs, "grad_inputs", strict=strict) if create_graph: with torch.enable_grad(): - grad_res = _autograd_grad(grad_inputs, grad_outputs, v, create_graph=create_graph) + grad_res = _autograd_grad( + grad_inputs, grad_outputs, v, create_graph=create_graph + ) jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick") else: - grad_res = _autograd_grad(grad_inputs, grad_outputs, v, create_graph=create_graph) + grad_res = _autograd_grad( + grad_inputs, grad_outputs, v, create_graph=create_graph + ) jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick") # Cleanup objects and return them to the user outputs = _grad_postprocess(outputs, create_graph) jvp = _grad_postprocess(jvp, create_graph) - return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(jvp, is_outputs_tuple) + return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess( + jvp, is_outputs_tuple + ) -def _construct_standard_basis_for(tensors: Tuple[torch.Tensor, ...], tensor_numels: Tuple[int, ...]) -> Tuple[torch.Tensor, ...]: +def _construct_standard_basis_for( + tensors: Tuple[torch.Tensor, ...], tensor_numels: Tuple[int, ...] +) -> Tuple[torch.Tensor, ...]: # This function: # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix. # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`. @@ -429,8 +506,10 @@ def _construct_standard_basis_for(tensors: Tuple[torch.Tensor, ...], tensor_nume assert len(tensors) == len(tensor_numels) assert len(tensors) > 0 total_numel = sum(tensor_numels) - chunks = tuple(tensor.new_zeros(total_numel, tensor_numel) - for tensor, tensor_numel in zip(tensors, tensor_numels)) + chunks = tuple( + tensor.new_zeros(total_numel, tensor_numel) + for tensor, tensor_numel in zip(tensors, tensor_numels) + ) diag_start_idx = 0 for chunk, numel in zip(chunks, tensor_numels): chunk.diagonal(diag_start_idx).fill_(1) @@ -440,10 +519,12 @@ def _construct_standard_basis_for(tensors: Tuple[torch.Tensor, ...], tensor_nume def _jacfwd(func, inputs, strict=False, vectorize=False): if strict: - raise RuntimeError('torch.autograd.functional.jacobian: `strict=True` ' - 'and `strategy="forward-mode"` are not supported together (yet). ' - 'Please either set `strict=False` or ' - '`strategy="reverse-mode"`.') + raise RuntimeError( + "torch.autograd.functional.jacobian: `strict=True` " + 'and `strategy="forward-mode"` are not supported together (yet). ' + "Please either set `strict=False` or " + '`strategy="reverse-mode"`.' + ) is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian") output_info = [] @@ -458,8 +539,12 @@ def _jacfwd(func, inputs, strict=False, vectorize=False): def jvp(tangents): with fwAD.dual_level(): dual_inputs = tuple( - fwAD.make_dual(input, tangent.view_as(input)) for input, tangent in zip(inputs, tangents)) - _is_outputs_tuple, dual_outputs = _as_tuple(func(*dual_inputs), "outputs") + fwAD.make_dual(input, tangent.view_as(input)) + for input, tangent in zip(inputs, tangents) + ) + _is_outputs_tuple, dual_outputs = _as_tuple( + func(*dual_inputs), "outputs" + ) output_info.append(_is_outputs_tuple) jv = [] primal_outs = [] @@ -482,20 +567,32 @@ def _jacfwd(func, inputs, strict=False, vectorize=False): for jac, input_j in zip(jac.split(input_numels, dim=0), inputs): # We need to transpose the Jacobian because in forward AD, the # batch dimension represents that of the inputs - jacobian_input_i_output_j = jac.permute(*range(1, jac.ndim), 0) \ - .reshape(tuple([*output_i.shape, *input_j.shape])) # noqa: C409 + jacobian_input_i_output_j = jac.permute(*range(1, jac.ndim), 0).reshape( + (*output_i.shape, *input_j.shape) + ) # noqa: C409 jacobian_output_i_output.append(jacobian_input_i_output_j) jacobian_input_output.append(jacobian_output_i_output) # Omit [Step 4] because everything is already transposed w/ forward AD - return _tuple_postprocess(jacobian_input_output, (is_outputs_tuple, is_inputs_tuple)) + return _tuple_postprocess( + jacobian_input_output, (is_outputs_tuple, is_inputs_tuple) + ) else: - raise NotImplementedError("Computing Jacobian using forward-AD or forward-over-reverse Hessian is" - "only implemented for `vectorize=True`.") + raise NotImplementedError( + "Computing Jacobian using forward-AD or forward-over-reverse Hessian is" + "only implemented for `vectorize=True`." + ) -def jacobian(func, inputs, create_graph=False, strict=False, vectorize=False, strategy="reverse-mode"): +def jacobian( + func, + inputs, + create_graph=False, + strict=False, + vectorize=False, + strategy="reverse-mode", +): r"""Function that computes the Jacobian of a given function. Args: @@ -574,13 +671,16 @@ def jacobian(func, inputs, create_graph=False, strict=False, vectorize=False, st assert strategy in ("forward-mode", "reverse-mode"), ( 'Expected strategy to be either "forward-mode" or "reverse-mode". Hint: If your ' 'function has more outputs than inputs, "forward-mode" tends to be more performant. ' - 'Otherwise, prefer to use "reverse-mode".') + 'Otherwise, prefer to use "reverse-mode".' + ) if strategy == "forward-mode": if create_graph: - raise NotImplementedError('torch.autograd.functional.jacobian: `create_graph=True` ' - 'and `strategy="forward-mode"` are not supported together (yet). ' - 'Please either set `create_graph=False` or ' - '`strategy="reverse-mode"`.') + raise NotImplementedError( + "torch.autograd.functional.jacobian: `create_graph=True` " + 'and `strategy="forward-mode"` are not supported together (yet). ' + "Please either set `create_graph=False` or " + '`strategy="reverse-mode"`.' + ) return _jacfwd(func, inputs, strict, vectorize) with torch.enable_grad(): @@ -588,17 +688,19 @@ def jacobian(func, inputs, create_graph=False, strict=False, vectorize=False, st inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) outputs = func(*inputs) - is_outputs_tuple, outputs = _as_tuple(outputs, - "outputs of the user-provided function", - "jacobian") + is_outputs_tuple, outputs = _as_tuple( + outputs, "outputs of the user-provided function", "jacobian" + ) _check_requires_grad(outputs, "outputs", strict=strict) if vectorize: if strict: - raise RuntimeError('torch.autograd.functional.jacobian: `strict=True` ' - 'and `vectorized=True` are not supported together. ' - 'Please either set `strict=False` or ' - '`vectorize=False`.') + raise RuntimeError( + "torch.autograd.functional.jacobian: `strict=True` " + "and `vectorized=True` are not supported together. " + "Please either set `strict=False` or " + "`vectorize=False`." + ) # NOTE: [Computing jacobian with vmap and grad for multiple outputs] # # Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3). @@ -646,11 +748,21 @@ def jacobian(func, inputs, create_graph=False, strict=False, vectorize=False, st # Step 2: Call vmap + autograd.grad def vjp(grad_output): - vj = list(_autograd_grad(flat_outputs, inputs, grad_output, create_graph=create_graph, is_grads_batched=True)) + vj = list( + _autograd_grad( + flat_outputs, + inputs, + grad_output, + create_graph=create_graph, + is_grads_batched=True, + ) + ) for el_idx, vj_el in enumerate(vj): if vj_el is not None: continue - vj[el_idx] = torch.zeros_like(inputs[el_idx]).expand((sum(output_numels),) + inputs[el_idx].shape) + vj[el_idx] = torch.zeros_like(inputs[el_idx]).expand( + (sum(output_numels),) + inputs[el_idx].shape + ) return tuple(vj) jacobians_of_flat_output = vjp(grad_outputs) @@ -672,44 +784,70 @@ def jacobian(func, inputs, create_graph=False, strict=False, vectorize=False, st # before returning. jacobian_output_input = tuple(zip(*jacobian_input_output)) - jacobian_output_input = _grad_postprocess(jacobian_output_input, create_graph) - return _tuple_postprocess(jacobian_output_input, (is_outputs_tuple, is_inputs_tuple)) + jacobian_output_input = _grad_postprocess( + jacobian_output_input, create_graph + ) + return _tuple_postprocess( + jacobian_output_input, (is_outputs_tuple, is_inputs_tuple) + ) jacobian: Tuple[torch.Tensor, ...] = tuple() for i, out in enumerate(outputs): - # mypy complains that expression and variable have different types due to the empty list jac_i: Tuple[List[torch.Tensor]] = tuple([] for _ in range(len(inputs))) # type: ignore[assignment] for j in range(out.nelement()): - vj = _autograd_grad((out.reshape(-1)[j],), inputs, - retain_graph=True, create_graph=create_graph) + vj = _autograd_grad( + (out.reshape(-1)[j],), + inputs, + retain_graph=True, + create_graph=create_graph, + ) - for el_idx, (jac_i_el, vj_el, inp_el) in enumerate(zip(jac_i, vj, inputs)): + for el_idx, (jac_i_el, vj_el, inp_el) in enumerate( + zip(jac_i, vj, inputs) + ): if vj_el is not None: if strict and create_graph and not vj_el.requires_grad: - msg = ("The jacobian of the user-provided function is " - "independent of input {}. This is not allowed in " - "strict mode when create_graph=True.".format(i)) + msg = ( + "The jacobian of the user-provided function is " + "independent of input {}. This is not allowed in " + "strict mode when create_graph=True.".format(i) + ) raise RuntimeError(msg) jac_i_el.append(vj_el) else: if strict: - msg = ("Output {} of the user-provided function is " - "independent of input {}. This is not allowed in " - "strict mode.".format(i, el_idx)) + msg = ( + "Output {} of the user-provided function is " + "independent of input {}. This is not allowed in " + "strict mode.".format(i, el_idx) + ) raise RuntimeError(msg) jac_i_el.append(torch.zeros_like(inp_el)) - jacobian += (tuple(torch.stack(jac_i_el, dim=0).view(out.size() # type: ignore[operator] - + inputs[el_idx].size()) for (el_idx, jac_i_el) in enumerate(jac_i)), ) + jacobian += ( + tuple( + torch.stack(jac_i_el, dim=0).view( + out.size() + inputs[el_idx].size() # type: ignore[operator] + ) + for (el_idx, jac_i_el) in enumerate(jac_i) + ), + ) jacobian = _grad_postprocess(jacobian, create_graph) return _tuple_postprocess(jacobian, (is_outputs_tuple, is_inputs_tuple)) -def hessian(func, inputs, create_graph=False, strict=False, vectorize=False, outer_jacobian_strategy="reverse-mode"): +def hessian( + func, + inputs, + create_graph=False, + strict=False, + vectorize=False, + outer_jacobian_strategy="reverse-mode", +): r"""Function that computes the Hessian of a given scalar function. Args: @@ -797,19 +935,27 @@ def hessian(func, inputs, create_graph=False, strict=False, vectorize=False, out """ is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hessian") - assert outer_jacobian_strategy in ("forward-mode", "reverse-mode"), ( - 'Expected strategy to be either "forward-mode" or "reverse-mode".') + assert outer_jacobian_strategy in ( + "forward-mode", + "reverse-mode", + ), 'Expected strategy to be either "forward-mode" or "reverse-mode".' def ensure_single_output_function(*inp): out = func(*inp) - is_out_tuple, t_out = _as_tuple(out, "outputs of the user-provided function", "hessian") + is_out_tuple, t_out = _as_tuple( + out, "outputs of the user-provided function", "hessian" + ) _check_requires_grad(t_out, "outputs", strict=strict) if is_out_tuple or not isinstance(out, torch.Tensor): - raise RuntimeError("The function given to hessian should return a single Tensor") + raise RuntimeError( + "The function given to hessian should return a single Tensor" + ) if out.nelement() != 1: - raise RuntimeError("The Tensor returned by the function given to hessian should contain a single element") + raise RuntimeError( + "The Tensor returned by the function given to hessian should contain a single element" + ) return out.squeeze() @@ -822,8 +968,14 @@ def hessian(func, inputs, create_graph=False, strict=False, vectorize=False, out _check_requires_grad(jac, "jacobian", strict=strict) return jac - res = jacobian(jac_func, inputs, create_graph=create_graph, strict=strict, vectorize=vectorize, - strategy=outer_jacobian_strategy) + res = jacobian( + jac_func, + inputs, + create_graph=create_graph, + strict=strict, + vectorize=vectorize, + strategy=outer_jacobian_strategy, + ) return _tuple_postprocess(res, (is_inputs_tuple, is_inputs_tuple)) @@ -894,17 +1046,25 @@ def vhp(func, inputs, v=None, create_graph=False, strict=False): _validate_v(v, inputs, is_inputs_tuple) else: if len(inputs) != 1 or inputs[0].nelement() != 1: - raise RuntimeError("The vector v can only be None if the input to the user-provided function " - "is a single Tensor with a single element.") + raise RuntimeError( + "The vector v can only be None if the input to the user-provided function " + "is a single Tensor with a single element." + ) outputs = func(*inputs) - is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "vhp") + is_outputs_tuple, outputs = _as_tuple( + outputs, "outputs of the user-provided function", "vhp" + ) _check_requires_grad(outputs, "outputs", strict=strict) if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor): - raise RuntimeError("The function given to vhp should return a single Tensor") + raise RuntimeError( + "The function given to vhp should return a single Tensor" + ) if outputs[0].nelement() != 1: - raise RuntimeError("The Tensor returned by the function given to vhp should contain a single element") + raise RuntimeError( + "The Tensor returned by the function given to vhp should contain a single element" + ) jac = _autograd_grad(outputs, inputs, create_graph=True) _check_requires_grad(jac, "jacobian", strict=strict) @@ -917,7 +1077,9 @@ def vhp(func, inputs, v=None, create_graph=False, strict=False): outputs = _grad_postprocess(outputs, create_graph) vhp = _grad_postprocess(vhp, create_graph) - return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(vhp, is_inputs_tuple) + return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess( + vhp, is_inputs_tuple + ) def hvp(func, inputs, v=None, create_graph=False, strict=False): @@ -996,17 +1158,25 @@ def hvp(func, inputs, v=None, create_graph=False, strict=False): _validate_v(v, inputs, is_inputs_tuple) else: if len(inputs) != 1 or inputs[0].nelement() != 1: - raise RuntimeError("The vector v can only be None if the input to the user-provided function " - "is a single Tensor with a single element.") + raise RuntimeError( + "The vector v can only be None if the input to the user-provided function " + "is a single Tensor with a single element." + ) outputs = func(*inputs) - is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "hvp") + is_outputs_tuple, outputs = _as_tuple( + outputs, "outputs of the user-provided function", "hvp" + ) _check_requires_grad(outputs, "outputs", strict=strict) if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor): - raise RuntimeError("The function given to hvp should return a single Tensor") + raise RuntimeError( + "The function given to hvp should return a single Tensor" + ) if outputs[0].nelement() != 1: - raise RuntimeError("The Tensor returned by the function given to hvp should contain a single element") + raise RuntimeError( + "The Tensor returned by the function given to hvp should contain a single element" + ) jac = _autograd_grad(outputs, inputs, create_graph=True) _check_requires_grad(jac, "jacobian", strict=strict) @@ -1019,9 +1189,13 @@ def hvp(func, inputs, v=None, create_graph=False, strict=False): enable_grad = True if create_graph else torch.is_grad_enabled() with torch.set_grad_enabled(enable_grad): grad_res = _autograd_grad(double_back, grad_jac, v, create_graph=create_graph) - hvp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "double_back_trick") + hvp = _fill_in_zeros( + grad_res, inputs, strict, create_graph, "double_back_trick" + ) outputs = _grad_postprocess(outputs, create_graph) hvp = _grad_postprocess(hvp, create_graph) - return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(hvp, is_inputs_tuple) + return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess( + hvp, is_inputs_tuple + ) diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py index 1b4980ff04bf..22a09dc9ce39 100644 --- a/torch/autograd/grad_mode.py +++ b/torch/autograd/grad_mode.py @@ -1,10 +1,17 @@ -import torch from typing import Any, Optional +import torch + from torch.utils._contextlib import _DecoratorContextManager -__all__ = ['no_grad', 'enable_grad', 'set_grad_enabled', - 'inference_mode', 'set_multithreading_enabled'] +__all__ = [ + "no_grad", + "enable_grad", + "set_grad_enabled", + "inference_mode", + "set_multithreading_enabled", +] + class no_grad(_DecoratorContextManager): r"""Context-manager that disables gradient calculation. @@ -53,6 +60,7 @@ class no_grad(_DecoratorContextManager): >>> a.requires_grad True """ + def __init__(self) -> None: if not torch._jit_internal.is_scripting(): super().__init__() @@ -105,6 +113,7 @@ class enable_grad(_DecoratorContextManager): True """ + def __enter__(self) -> None: self.prev = torch.is_grad_enabled() torch._C._set_grad_enabled(True) @@ -213,6 +222,7 @@ class inference_mode(_DecoratorContextManager): False """ + def __init__(self, mode: bool = True) -> None: if not torch._jit_internal.is_scripting(): super().__init__() @@ -290,7 +300,9 @@ class _force_original_view_tracking(_DecoratorContextManager): self.mode = mode def __enter__(self) -> None: - self._force_original_view_tracking_context = torch._C._ViewReplayEnabled(self.mode) + self._force_original_view_tracking_context = torch._C._ViewReplayEnabled( + self.mode + ) self._force_original_view_tracking_context.__enter__() def __exit__(self, *args) -> None: @@ -299,6 +311,7 @@ class _force_original_view_tracking(_DecoratorContextManager): def clone(self): return self.__class__(self.mode) + class _unsafe_preserve_version_counter(_DecoratorContextManager): r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING! diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index bf605a561e34..954490d8141b 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -1,28 +1,40 @@ -import torch -from torch.types import _TensorOrTensors -import torch.testing -from torch.overrides import is_tensor_like import collections -from itertools import product -import warnings -from typing import Callable, Union, Optional, Iterable, List, Tuple, Dict -from torch._vmap_internals import vmap, _vmap import functools +import warnings +from itertools import product +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.testing +from torch._vmap_internals import _vmap, vmap +from torch.overrides import is_tensor_like +from torch.types import _TensorOrTensors # Note: `get_*_jacobian` functions are added here even though we didn't intend to make them public # since they have been exposed from before we added `__all__` and we already maintain BC for them # We should eventually deprecate them and remove them from `__all__` -__all__ = ["gradcheck", "gradgradcheck", "GradcheckError", "get_numerical_jacobian", - "get_analytical_jacobian", "get_numerical_jacobian_wrt_specific_input"] +__all__ = [ + "gradcheck", + "gradgradcheck", + "GradcheckError", + "get_numerical_jacobian", + "get_analytical_jacobian", + "get_numerical_jacobian_wrt_specific_input", +] + class GradcheckError(RuntimeError): r"""Error raised by :func:`gradcheck` and :func:`gradgradcheck`""" pass - def _is_sparse_compressed_tensor(obj: torch.Tensor): - return obj.layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} + return obj.layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + } def _is_sparse_any_tensor(obj: torch.Tensor): @@ -33,7 +45,9 @@ def _is_float_or_complex_tensor(obj): return is_tensor_like(obj) and (obj.is_floating_point() or obj.is_complex()) -def _allocate_jacobians_with_inputs(input_tensors: Tuple, numel_output) -> Tuple[torch.Tensor, ...]: +def _allocate_jacobians_with_inputs( + input_tensors: Tuple, numel_output +) -> Tuple[torch.Tensor, ...]: # Makes zero-filled tensors from inputs. If `numel_output` is not None, for # each tensor in `input_tensors`, returns a new zero-filled tensor with height # of `t.numel` and width of `numel_output`. Otherwise, for each tensor, returns @@ -46,8 +60,9 @@ def _allocate_jacobians_with_inputs(input_tensors: Tuple, numel_output) -> Tuple return tuple(out) -def _allocate_jacobians_with_outputs(output_tensors: Tuple, numel_input, dtype=None, - device=None) -> Tuple[torch.Tensor, ...]: +def _allocate_jacobians_with_outputs( + output_tensors: Tuple, numel_input, dtype=None, device=None +) -> Tuple[torch.Tensor, ...]: # Makes zero-filled tensors from outputs. If `dim` is not None, for each tensor # in `output_tensors`, returns a new zero-filled tensor with height of `dim` and # width of `t.numel`. Otherwise, for each tensor, returns a 1-d tensor with size @@ -60,8 +75,9 @@ def _allocate_jacobians_with_outputs(output_tensors: Tuple, numel_input, dtype=N return tuple(out) -def _iter_tensors(x: Union[torch.Tensor, Iterable[torch.Tensor]], - only_requiring_grad: bool = False) -> Iterable[torch.Tensor]: +def _iter_tensors( + x: Union[torch.Tensor, Iterable[torch.Tensor]], only_requiring_grad: bool = False +) -> Iterable[torch.Tensor]: if is_tensor_like(x): # mypy doesn't narrow type of `x` to torch.Tensor if x.requires_grad or not only_requiring_grad: # type: ignore[union-attr] @@ -81,27 +97,52 @@ def _densify(x): elif x.layout is torch.sparse_coo: device = x.device indices_dtype = x._indices().dtype - tmp = torch.ones(x.shape[:x.sparse_dim()], dtype=torch.int8, device=device) + tmp = torch.ones(x.shape[: x.sparse_dim()], dtype=torch.int8, device=device) indices = tmp.nonzero().t().to(dtype=indices_dtype) - values = torch.zeros((tmp.numel(), *x.shape[x.sparse_dim():]), dtype=x.dtype, device=device) + values = torch.zeros( + (tmp.numel(), *x.shape[x.sparse_dim() :]), dtype=x.dtype, device=device + ) x_coalesced = x.detach().coalesce() if x_coalesced.numel() > 0: stride = tmp.stride() - flat_indices = x_coalesced.indices().mul( - torch.tensor(stride, dtype=indices_dtype, device=device).unsqueeze(1)).sum(0) + flat_indices = ( + x_coalesced.indices() + .mul( + torch.tensor(stride, dtype=indices_dtype, device=device).unsqueeze( + 1 + ) + ) + .sum(0) + ) values[flat_indices] = x_coalesced.values() - return torch.sparse_coo_tensor(indices, values, x.shape)._coalesced_(True).requires_grad_(x.requires_grad) + return ( + torch.sparse_coo_tensor(indices, values, x.shape) + ._coalesced_(True) + .requires_grad_(x.requires_grad) + ) elif _is_sparse_compressed_tensor(x): - blocksize = x.values().shape[1:3] if x.layout in {torch.sparse_bsr, torch.sparse_bsc} else None - compressed_indices = x.crow_indices() if x.layout in {torch.sparse_csr, torch.sparse_bsr} else x.ccol_indices() + blocksize = ( + x.values().shape[1:3] + if x.layout in {torch.sparse_bsr, torch.sparse_bsc} + else None + ) + compressed_indices = ( + x.crow_indices() + if x.layout in {torch.sparse_csr, torch.sparse_bsr} + else x.ccol_indices() + ) # We'll use intermediate sparse COO for simplicity - r = _densify(x.detach().to_sparse(layout=torch.sparse_coo)).to_sparse(layout=x.layout, blocksize=blocksize) + r = _densify(x.detach().to_sparse(layout=torch.sparse_coo)).to_sparse( + layout=x.layout, blocksize=blocksize + ) # Check that all elements are specified also after `to_sparse` op: dense_numel = r.values().numel() // max(1, r.values().shape[0]) batch_numel = compressed_indices.numel() // compressed_indices.shape[-1] sparse_numel = r.numel() // max(1, dense_numel * batch_numel) if sparse_numel != r._nnz(): - raise AssertionError(f'{x.layout} densify failed: expected nnz={sparse_numel} but got {r._nnz()}') + raise AssertionError( + f"{x.layout} densify failed: expected nnz={sparse_numel} but got {r._nnz()}" + ) return r.requires_grad_(x.requires_grad) elif _is_sparse_any_tensor(x): raise NotImplementedError(x.layout) @@ -126,6 +167,7 @@ def _iter_tensor(x_tensor): # where x is the t.data of the original tensor. Perturbing the entry of x # at index (1, 1) yields the 3rd column of the overall Jacobian matrix. if _is_sparse_any_tensor(x_tensor): + def get_stride(size): dim = len(size) tmp = 1 @@ -134,37 +176,60 @@ def _iter_tensor(x_tensor): stride[i] = tmp tmp *= size[i] return stride + x_nnz = x_tensor._nnz() x_size = list(x_tensor.size()) if x_tensor.layout is torch.sparse_coo: x_indices = x_tensor._indices().t() x_values = x_tensor._values() elif x_tensor.layout is torch.sparse_csr: - x_indices = torch._convert_indices_from_csr_to_coo(x_tensor.crow_indices(), x_tensor.col_indices()).t() + x_indices = torch._convert_indices_from_csr_to_coo( + x_tensor.crow_indices(), x_tensor.col_indices() + ).t() x_values = x_tensor.values() elif x_tensor.layout is torch.sparse_csc: - x_indices = torch._convert_indices_from_csr_to_coo(x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True).t() + x_indices = torch._convert_indices_from_csr_to_coo( + x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True + ).t() x_values = x_tensor.values() elif x_tensor.layout is torch.sparse_bsr: x_block_values = x_tensor.values() x_blocksize = x_block_values.size()[1:3] - x_indices = torch._convert_indices_from_csr_to_coo(x_tensor.crow_indices(), x_tensor.col_indices()) \ - .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1) \ - .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1)) \ - .add_(torch.stack(torch.where(torch.ones(x_blocksize, device=x_tensor.device))).repeat(1, x_nnz)).t() + x_indices = ( + torch._convert_indices_from_csr_to_coo( + x_tensor.crow_indices(), x_tensor.col_indices() + ) + .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1) + .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1)) + .add_( + torch.stack( + torch.where(torch.ones(x_blocksize, device=x_tensor.device)) + ).repeat(1, x_nnz) + ) + .t() + ) x_values = x_block_values.flatten(0, 2) x_nnz = x_values.size(0) elif x_tensor.layout is torch.sparse_bsc: x_block_values = x_tensor.values() x_blocksize = x_block_values.size()[1:3] - x_indices = torch._convert_indices_from_csr_to_coo(x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True) \ - .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1) \ - .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1)) \ - .add_(torch.stack(torch.where(torch.ones(x_blocksize, device=x_tensor.device))).repeat(1, x_nnz)).t() + x_indices = ( + torch._convert_indices_from_csr_to_coo( + x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True + ) + .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1) + .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1)) + .add_( + torch.stack( + torch.where(torch.ones(x_blocksize, device=x_tensor.device)) + ).repeat(1, x_nnz) + ) + .t() + ) x_values = x_block_values.flatten(0, 2) x_nnz = x_values.size(0) else: - raise NotImplementedError(f'_iter_tensor for {x_tensor.layout} input') + raise NotImplementedError(f"_iter_tensor for {x_tensor.layout} input") x_stride = get_stride(x_size) # Use .data here to get around the version check x_values = x_values.data @@ -187,8 +252,9 @@ def _iter_tensor(x_tensor): yield x_tensor, x_idx, d_idx -def _get_numerical_jacobian(fn, inputs, outputs=None, target=None, eps=1e-3, - is_forward_ad=False) -> List[Tuple[torch.Tensor, ...]]: +def _get_numerical_jacobian( + fn, inputs, outputs=None, target=None, eps=1e-3, is_forward_ad=False +) -> List[Tuple[torch.Tensor, ...]]: """Computes the numerical Jacobian of `fn(inputs)` with respect to `target`. If not specified, targets are the input. Returns M * N Jacobians where N is the number of tensors in target that require grad and M is the number of non-integral @@ -214,14 +280,27 @@ def _get_numerical_jacobian(fn, inputs, outputs=None, target=None, eps=1e-3, if outputs is None: outputs = _as_tuple(fn(*_as_tuple(inputs))) if not is_forward_ad and any(o.is_complex() for o in outputs): - raise ValueError("Expected output to be non-complex. get_numerical_jacobian no " - "longer supports functions that return complex outputs.") + raise ValueError( + "Expected output to be non-complex. get_numerical_jacobian no " + "longer supports functions that return complex outputs." + ) if target is None: target = inputs - inp_indices = [i for i, a in enumerate(target) if is_tensor_like(a) and a.requires_grad] + inp_indices = [ + i for i, a in enumerate(target) if is_tensor_like(a) and a.requires_grad + ] for i, (inp, inp_idx) in enumerate(zip(_iter_tensors(target, True), inp_indices)): - jacobians += [get_numerical_jacobian_wrt_specific_input(fn, inp_idx, inputs, outputs, eps, - input=inp, is_forward_ad=is_forward_ad)] + jacobians += [ + get_numerical_jacobian_wrt_specific_input( + fn, + inp_idx, + inputs, + outputs, + eps, + input=inp, + is_forward_ad=is_forward_ad, + ) + ] return jacobians @@ -242,17 +321,24 @@ def get_numerical_jacobian(fn, inputs, target=None, eps=1e-3, grad_out=1.0): Note that `target` may not even be part of `input` to `fn`, so please be **very careful** in this to not clone `target`. """ - warnings.warn("get_numerical_jacobian was part of PyTorch's private API and not " - "meant to be exposed. We are deprecating it and it will be removed " - "in a future version of PyTorch. If you have a specific use for " - "this or feature request for this to be a stable API, please file " - "us an issue at https://github.com/pytorch/pytorch/issues/new") - if grad_out != 1.0: # grad_out param is only kept for backward compatibility reasons - raise ValueError("Expected grad_out to be 1.0. get_numerical_jacobian no longer " - "supports values of grad_out != 1.0.") + warnings.warn( + "get_numerical_jacobian was part of PyTorch's private API and not " + "meant to be exposed. We are deprecating it and it will be removed " + "in a future version of PyTorch. If you have a specific use for " + "this or feature request for this to be a stable API, please file " + "us an issue at https://github.com/pytorch/pytorch/issues/new" + ) + if ( + grad_out != 1.0 + ): # grad_out param is only kept for backward compatibility reasons + raise ValueError( + "Expected grad_out to be 1.0. get_numerical_jacobian no longer " + "supports values of grad_out != 1.0." + ) def fn_pack_inps(*inps): return fn(inps) + jacobians = _get_numerical_jacobian(fn_pack_inps, inputs, None, target, eps) return tuple(jacobian_for_each_output[0] for jacobian_for_each_output in jacobians) @@ -289,8 +375,9 @@ def _compute_numerical_gradient(fn, entry, v, norm_v, nbhd_checks_fn): return tuple(compute(a, b) for (a, b) in zip(outa, outb)) -def _compute_numerical_jvps_wrt_specific_input(jvp_fn, delta, input_is_complex, - is_forward_ad=False) -> List[torch.Tensor]: +def _compute_numerical_jvps_wrt_specific_input( + jvp_fn, delta, input_is_complex, is_forward_ad=False +) -> List[torch.Tensor]: # Computing the jacobian only works for real delta # For details on the algorithm used here, refer: # Section 3.5.3 https://arxiv.org/pdf/1701.00392.pdf @@ -300,32 +387,38 @@ def _compute_numerical_jvps_wrt_specific_input(jvp_fn, delta, input_is_complex, ds_dx_tup = jvp_fn(delta[0] if isinstance(delta, tuple) else delta) if input_is_complex: # C -> R - ds_dy_tup = jvp_fn(delta[1] * 1j) if isinstance(delta, tuple) else jvp_fn(delta * 1j) + ds_dy_tup = ( + jvp_fn(delta[1] * 1j) if isinstance(delta, tuple) else jvp_fn(delta * 1j) + ) for ds_dx, ds_dy in zip(ds_dx_tup, ds_dy_tup): - assert(not ds_dx.is_complex()) + assert not ds_dx.is_complex() # conjugate wirtinger derivative conj_w_d = ds_dx + ds_dy * 1j jvps.append(conj_w_d) else: for ds_dx in ds_dx_tup: # R -> R or (R -> C for the forward AD case) - assert(is_forward_ad or not ds_dx.is_complex()) + assert is_forward_ad or not ds_dx.is_complex() jvps.append(ds_dx) return jvps -def _combine_jacobian_cols(jacobians_cols: Dict[int, List[torch.Tensor]], outputs, input, - numel) -> Tuple[torch.Tensor, ...]: +def _combine_jacobian_cols( + jacobians_cols: Dict[int, List[torch.Tensor]], outputs, input, numel +) -> Tuple[torch.Tensor, ...]: # jacobian_cols maps column_idx -> output_idx -> single column of jacobian Tensor # we return a list that maps output_idx -> full jacobian Tensor - jacobians = _allocate_jacobians_with_outputs(outputs, numel, dtype=input.dtype if input.dtype.is_complex else None) + jacobians = _allocate_jacobians_with_outputs( + outputs, numel, dtype=input.dtype if input.dtype.is_complex else None + ) for i, jacobian in enumerate(jacobians): for k, v in jacobians_cols.items(): jacobian[k] = v[i] return jacobians -def _prepare_input(input: torch.Tensor, maybe_perturbed_input: Optional[torch.Tensor], - fast_mode=False) -> torch.Tensor: +def _prepare_input( + input: torch.Tensor, maybe_perturbed_input: Optional[torch.Tensor], fast_mode=False +) -> torch.Tensor: # Prepares the inputs to be passed into the function while including the new # modified input. if input.layout == torch._mkldnn: # type: ignore[attr-defined] # no attr _mkldnn @@ -351,18 +444,21 @@ def _check_outputs_same_dtype_and_shape(output1, output2, eps, idx=None) -> None # Check that the returned outputs don't have different dtype or shape when you # perturb the input on_index = "on index {idx} " if idx is not None else "" - assert output1.shape == output2.shape, \ - (f"Expected `func` to return outputs with the same shape" - f" when inputs are perturbed {on_index}by {eps}, but got:" - f" shapes {output1.shape} and {output2.shape}.") - assert output1.dtype == output2.dtype, \ - (f"Expected `func` to return outputs with the same dtype" - f" when inputs are perturbed {on_index}by {eps}, but got:" - f" dtypes {output1.dtype} and {output2.dtype}.") + assert output1.shape == output2.shape, ( + f"Expected `func` to return outputs with the same shape" + f" when inputs are perturbed {on_index}by {eps}, but got:" + f" shapes {output1.shape} and {output2.shape}." + ) + assert output1.dtype == output2.dtype, ( + f"Expected `func` to return outputs with the same dtype" + f" when inputs are perturbed {on_index}by {eps}, but got:" + f" dtypes {output1.dtype} and {output2.dtype}." + ) -def get_numerical_jacobian_wrt_specific_input(fn, input_idx, inputs, outputs, eps, - input=None, is_forward_ad=False) -> Tuple[torch.Tensor, ...]: +def get_numerical_jacobian_wrt_specific_input( + fn, input_idx, inputs, outputs, eps, input=None, is_forward_ad=False +) -> Tuple[torch.Tensor, ...]: # Computes the numerical jacobians wrt to a single input. Returns N jacobian # tensors, where N is the number of outputs. We use a dictionary for # jacobian_cols because indices aren't necessarily consecutive for sparse inputs @@ -374,13 +470,21 @@ def get_numerical_jacobian_wrt_specific_input(fn, input_idx, inputs, outputs, ep for x, idx, d_idx in _iter_tensor(input): wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, x) input_to_perturb = x[idx] - nbhd_checks_fn = functools.partial(_check_outputs_same_dtype_and_shape, idx=idx, eps=eps) - jvp_fn = _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn) - jacobian_cols[d_idx] = _compute_numerical_jvps_wrt_specific_input(jvp_fn, eps, x.is_complex(), is_forward_ad) + nbhd_checks_fn = functools.partial( + _check_outputs_same_dtype_and_shape, idx=idx, eps=eps + ) + jvp_fn = _get_numerical_jvp_fn( + wrapped_fn, input_to_perturb, eps, nbhd_checks_fn + ) + jacobian_cols[d_idx] = _compute_numerical_jvps_wrt_specific_input( + jvp_fn, eps, x.is_complex(), is_forward_ad + ) return _combine_jacobian_cols(jacobian_cols, outputs, input, input.numel()) -def _get_analytical_jacobian_forward_ad(fn, inputs, outputs, *, check_grad_dtypes=False, - all_u=None) -> Tuple[Tuple[torch.Tensor, ...], ...]: + +def _get_analytical_jacobian_forward_ad( + fn, inputs, outputs, *, check_grad_dtypes=False, all_u=None +) -> Tuple[Tuple[torch.Tensor, ...], ...]: """Computes the analytical Jacobian using forward mode AD of `fn(inputs)` using forward mode AD with respect to `target`. Returns N * M Jacobians where N is the number of tensors in target that require grad and M is the number of non-integral outputs. @@ -404,12 +508,18 @@ def _get_analytical_jacobian_forward_ad(fn, inputs, outputs, *, check_grad_dtype tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad) if any(i.is_complex() for i in tensor_inputs): - raise ValueError("Expected inputs to be non-complex for _get_analytical_jacobian_forward_ad.") + raise ValueError( + "Expected inputs to be non-complex for _get_analytical_jacobian_forward_ad." + ) if all_u: - jacobians = tuple(_allocate_jacobians_with_outputs(outputs, 1) for i in tensor_inputs) + jacobians = tuple( + _allocate_jacobians_with_outputs(outputs, 1) for i in tensor_inputs + ) else: - jacobians = tuple(_allocate_jacobians_with_outputs(outputs, i.numel()) for i in tensor_inputs) + jacobians = tuple( + _allocate_jacobians_with_outputs(outputs, i.numel()) for i in tensor_inputs + ) with fwAD.dual_level(): fw_grads = [] @@ -417,7 +527,9 @@ def _get_analytical_jacobian_forward_ad(fn, inputs, outputs, *, check_grad_dtype for i, inp in enumerate(inputs): if is_tensor_like(inp) and inp.requires_grad: if inp.layout == torch._mkldnn: # type: ignore[attr-defined] - raise ValueError("MKLDNN inputs are not support for forward AD gradcheck.") + raise ValueError( + "MKLDNN inputs are not support for forward AD gradcheck." + ) inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp)) # If inp is a differentiable view, the dual might not be the tangent given to @@ -434,8 +546,12 @@ def _get_analytical_jacobian_forward_ad(fn, inputs, outputs, *, check_grad_dtype dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs) for index_o, d_o in enumerate(dual_outputs): val, res = fwAD.unpack_dual(d_o) - if check_grad_dtypes and res is not None and val.is_complex() != res.is_complex(): - raise GradcheckError('Forward AD gradient has dtype mismatch.') + if ( + check_grad_dtypes + and res is not None + and val.is_complex() != res.is_complex() + ): + raise GradcheckError("Forward AD gradient has dtype mismatch.") # Remove extra dimension of size 1 corresponding to the reduced input jacobians[i][index_o].squeeze_(0) @@ -447,23 +563,32 @@ def _get_analytical_jacobian_forward_ad(fn, inputs, outputs, *, check_grad_dtype else: # Reconstruct the full Jacobian column by column for i, fw_grad in enumerate(fw_grads): - for lin_idx, grad_idx in enumerate(product(*[range(m) for m in fw_grad.size()])): - fw_grad[grad_idx] = 1. + for lin_idx, grad_idx in enumerate( + product(*[range(m) for m in fw_grad.size()]) + ): + fw_grad[grad_idx] = 1.0 raw_outputs = _as_tuple(fn(*dual_inputs)) dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs) for index_o, d_o in enumerate(dual_outputs): val, res = fwAD.unpack_dual(d_o) - if check_grad_dtypes and res is not None and val.is_complex() != res.is_complex(): - raise GradcheckError('Forward AD gradient has dtype mismatch.') + if ( + check_grad_dtypes + and res is not None + and val.is_complex() != res.is_complex() + ): + raise GradcheckError( + "Forward AD gradient has dtype mismatch." + ) if res is None: jacobians[i][index_o][lin_idx].zero_() else: jacobians[i][index_o][lin_idx].copy_(res.reshape(-1)) - fw_grad[grad_idx] = 0. + fw_grad[grad_idx] = 0.0 return jacobians + def _get_input_to_perturb(input): # Prepare the input so that it can be modified in-place and do certain # operations that require the tensor to have strides. If fast_mode=False, @@ -483,16 +608,24 @@ def _get_input_to_perturb(input): def _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, fast_mode=False): # Wraps `fn` so that its inputs are already supplied def wrapped_fn(): - inp = tuple(_prepare_input(a, input_to_perturb if i == input_idx else None, fast_mode) - if is_tensor_like(a) else a for i, a in enumerate(_as_tuple(inputs))) + inp = tuple( + _prepare_input(a, input_to_perturb if i == input_idx else None, fast_mode) + if is_tensor_like(a) + else a + for i, a in enumerate(_as_tuple(inputs)) + ) return tuple(a.clone() for a in _as_tuple(fn(*inp))) + return wrapped_fn def _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn): # Wraps jvp_fn so that certain arguments are already supplied def jvp_fn(delta): - return _compute_numerical_gradient(wrapped_fn, input_to_perturb, delta, eps, nbhd_checks_fn) + return _compute_numerical_gradient( + wrapped_fn, input_to_perturb, delta, eps, nbhd_checks_fn + ) + return jvp_fn @@ -514,7 +647,9 @@ def _mul_tensor_or_tuple(u, k): return k * u -def _get_numerical_jvp_wrt_specific_input(fn, input_idx, inputs, u, eps, is_forward_ad=False) -> List[torch.Tensor]: +def _get_numerical_jvp_wrt_specific_input( + fn, input_idx, inputs, u, eps, is_forward_ad=False +) -> List[torch.Tensor]: input = inputs[input_idx] input_to_perturb = _get_input_to_perturb(input) wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, True) @@ -522,14 +657,20 @@ def _get_numerical_jvp_wrt_specific_input(fn, input_idx, inputs, u, eps, is_forw jvp_fn = _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn) u = _reshape_tensor_or_tuple(u, input_to_perturb.shape) u = _mul_tensor_or_tuple(u, eps) - return _compute_numerical_jvps_wrt_specific_input(jvp_fn, u, input.is_complex(), is_forward_ad) + return _compute_numerical_jvps_wrt_specific_input( + jvp_fn, u, input.is_complex(), is_forward_ad + ) -def _get_numerical_vJu(fn, inputs, inp_indices, func_out, all_u, all_v, eps, is_forward_ad): +def _get_numerical_vJu( + fn, inputs, inp_indices, func_out, all_u, all_v, eps, is_forward_ad +): # Note that all_v can also be None, in that case, this function only computes Ju. reduced_jacobians: List[List[torch.Tensor]] = [] for i, (inp_idx, u) in enumerate(zip(inp_indices, all_u)): - all_Ju = _get_numerical_jvp_wrt_specific_input(fn, inp_idx, inputs, u, eps, is_forward_ad) + all_Ju = _get_numerical_jvp_wrt_specific_input( + fn, inp_idx, inputs, u, eps, is_forward_ad + ) # Filter out the Ju for non floating point outputs filtered_Ju = [] func_out = _as_tuple(func_out) @@ -559,8 +700,9 @@ def _check_jacobians_equal(j1, j2, atol): return True -def _stack_and_check_tensors(list_of_list_of_tensors, inputs, - numel_outputs) -> Tuple[Tuple[torch.Tensor, ...], bool, bool]: +def _stack_and_check_tensors( + list_of_list_of_tensors, inputs, numel_outputs +) -> Tuple[Tuple[torch.Tensor, ...], bool, bool]: # For the ith tensor in the inner list checks whether it has the same size and # dtype as the ith differentiable input. out_jacobians = _allocate_jacobians_with_inputs(inputs, numel_outputs) @@ -578,7 +720,9 @@ def _stack_and_check_tensors(list_of_list_of_tensors, inputs, if tensor is None: out_jacobian[:, j].zero_() else: - dense = tensor.to_dense() if not tensor.layout == torch.strided else tensor + dense = ( + tensor.to_dense() if not tensor.layout == torch.strided else tensor + ) assert out_jacobian[:, j].numel() == dense.numel() out_jacobian[:, j] = dense.reshape(-1) return out_jacobians, correct_grad_sizes, correct_grad_types @@ -601,8 +745,9 @@ If the test """ -def _check_analytical_jacobian_attributes(inputs, output, nondet_tol, check_grad_dtypes, - fast_mode=False, v=None) -> Tuple[torch.Tensor, ...]: +def _check_analytical_jacobian_attributes( + inputs, output, nondet_tol, check_grad_dtypes, fast_mode=False, v=None +) -> Tuple[torch.Tensor, ...]: # This is used by both fast and slow mode: # - For slow mode, vjps[i][j] is the jth row the Jacobian wrt the ith # input. @@ -611,8 +756,10 @@ def _check_analytical_jacobian_attributes(inputs, output, nondet_tol, check_grad diff_input_list = list(_iter_tensors(inputs, True)) def vjp_fn(grad_output): - return torch.autograd.grad(output, diff_input_list, grad_output, - retain_graph=True, allow_unused=True) + return torch.autograd.grad( + output, diff_input_list, grad_output, retain_graph=True, allow_unused=True + ) + # Compute everything twice to check for nondeterminism (which we call reentrancy) if fast_mode: vjps1 = _get_analytical_vjps_wrt_specific_output(vjp_fn, output.clone(), v) @@ -622,28 +769,34 @@ def _check_analytical_jacobian_attributes(inputs, output, nondet_tol, check_grad vjps2 = _compute_analytical_jacobian_rows(vjp_fn, output.clone()) output_numel = output.numel() if not fast_mode else 1 - jacobians1, types_ok, sizes_ok = _stack_and_check_tensors(vjps1, inputs, output_numel) + jacobians1, types_ok, sizes_ok = _stack_and_check_tensors( + vjps1, inputs, output_numel + ) jacobians2, _, _ = _stack_and_check_tensors(vjps2, inputs, output_numel) reentrant = _check_jacobians_equal(jacobians1, jacobians2, nondet_tol) if not types_ok and check_grad_dtypes: - raise GradcheckError('Gradient has dtype mismatch') + raise GradcheckError("Gradient has dtype mismatch") if not sizes_ok: - raise GradcheckError('Analytical gradient has incorrect size') + raise GradcheckError("Analytical gradient has incorrect size") if not reentrant: - raise GradcheckError('Backward is not reentrant, i.e., running backward with ' - 'same input and grad_output multiple times gives different values, ' - 'although analytical gradient matches numerical gradient.' - f'The tolerance for nondeterminism was {nondet_tol}.' + - FAILED_NONDET_MSG) + raise GradcheckError( + "Backward is not reentrant, i.e., running backward with " + "same input and grad_output multiple times gives different values, " + "although analytical gradient matches numerical gradient." + f"The tolerance for nondeterminism was {nondet_tol}." + FAILED_NONDET_MSG + ) return jacobians1 -def _get_analytical_vJu_backward_mode(inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u): +def _get_analytical_vJu_backward_mode( + inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u +): reduced_jacobians: List[List[torch.Tensor]] = [] for output, v in zip(outputs, all_v): - all_vJ = _check_analytical_jacobian_attributes(inputs, output, nondet_tol, check_grad_dtypes, - fast_mode=True, v=v) + all_vJ = _check_analytical_jacobian_attributes( + inputs, output, nondet_tol, check_grad_dtypes, fast_mode=True, v=v + ) jacobian_scalars: List[torch.Tensor] = [] for vJ, u in zip(all_vJ, all_u): # Why do we need squeeze here? vJ is a 2-d tensor so that we can reuse @@ -659,31 +812,44 @@ def _get_analytical_vJu_backward_mode(inputs, outputs, nondet_tol, check_grad_dt reduced_jacobians.append(jacobian_scalars) return reduced_jacobians + def get_analytical_jacobian(inputs, output, nondet_tol=0.0, grad_out=1.0): # Replicates the behavior of the old get_analytical_jacobian before the refactor # This shares much of its code with _check_analytical_jacobian_attributes - warnings.warn("get_analytical_jacobian was part of PyTorch's private API and not " - "meant to be exposed. We are deprecating it and it will be removed " - "in a future version of PyTorch. If you have a specific use for " - "this or feature request for this to be a stable API, please file " - "us an issue at https://github.com/pytorch/pytorch/issues/new") - if grad_out != 1.0: # grad_out param is only kept for backward compatibility reasons - raise ValueError("Expected grad_out to be 1.0. get_analytical_jacobian no longer " - "supports values of grad_out != 1.0.") + warnings.warn( + "get_analytical_jacobian was part of PyTorch's private API and not " + "meant to be exposed. We are deprecating it and it will be removed " + "in a future version of PyTorch. If you have a specific use for " + "this or feature request for this to be a stable API, please file " + "us an issue at https://github.com/pytorch/pytorch/issues/new" + ) + if ( + grad_out != 1.0 + ): # grad_out param is only kept for backward compatibility reasons + raise ValueError( + "Expected grad_out to be 1.0. get_analytical_jacobian no longer " + "supports values of grad_out != 1.0." + ) if output.is_complex(): - raise ValueError("Expected output to be non-complex. get_analytical_jacobian no " - "longer supports functions that return complex outputs.") + raise ValueError( + "Expected output to be non-complex. get_analytical_jacobian no " + "longer supports functions that return complex outputs." + ) diff_input_list = list(_iter_tensors(inputs, True)) def vjp_fn(grad_output): - return torch.autograd.grad(output, diff_input_list, grad_output, - retain_graph=True, allow_unused=True) + return torch.autograd.grad( + output, diff_input_list, grad_output, retain_graph=True, allow_unused=True + ) + # Compute everything twice to check for nondeterminism (which we call reentrancy) vjps1 = _compute_analytical_jacobian_rows(vjp_fn, output.clone()) vjps2 = _compute_analytical_jacobian_rows(vjp_fn, output.clone()) output_numel = output.numel() - jacobians1, types_ok, sizes_ok = _stack_and_check_tensors(vjps1, inputs, output_numel) + jacobians1, types_ok, sizes_ok = _stack_and_check_tensors( + vjps1, inputs, output_numel + ) jacobians2, _, _ = _stack_and_check_tensors(vjps2, inputs, output_numel) reentrant = _check_jacobians_equal(jacobians1, jacobians2, nondet_tol) @@ -693,17 +859,22 @@ def get_analytical_jacobian(inputs, output, nondet_tol=0.0, grad_out=1.0): def _get_analytical_jacobian(inputs, outputs, input_idx, output_idx): # Computes the analytical Jacobian in slow mode for a single input-output pair. # Forgoes performing checks on dtype, shape, and reentrancy. - jacobians = _check_analytical_jacobian_attributes(inputs, outputs[output_idx], - nondet_tol=float('inf'), check_grad_dtypes=False) + jacobians = _check_analytical_jacobian_attributes( + inputs, outputs[output_idx], nondet_tol=float("inf"), check_grad_dtypes=False + ) return jacobians[input_idx] -def _compute_analytical_jacobian_rows(vjp_fn, sample_output) -> List[List[Optional[torch.Tensor]]]: +def _compute_analytical_jacobian_rows( + vjp_fn, sample_output +) -> List[List[Optional[torch.Tensor]]]: # Computes Jacobian row-by-row using backward function `vjp_fn` = v^T J # NB: this function does not assume vjp_fn(v) to return tensors with the same # number of elements for different v. This is checked when we later combine the # rows into a single tensor. - grad_out_base = torch.zeros_like(sample_output, memory_format=torch.legacy_contiguous_format) + grad_out_base = torch.zeros_like( + sample_output, memory_format=torch.legacy_contiguous_format + ) flat_grad_out = grad_out_base.view(-1) # jacobians_rows[i][j] represents the jth row of the ith input jacobians_rows: List[List[Optional[torch.Tensor]]] = [] @@ -714,11 +885,15 @@ def _compute_analytical_jacobian_rows(vjp_fn, sample_output) -> List[List[Option for i, d_x in enumerate(grad_inputs): if j == 0: jacobians_rows.append([]) - jacobians_rows[i] += [d_x.clone() if isinstance(d_x, torch.Tensor) else None] + jacobians_rows[i] += [ + d_x.clone() if isinstance(d_x, torch.Tensor) else None + ] return jacobians_rows -def _get_analytical_vjps_wrt_specific_output(vjp_fn, sample_output, v) -> List[List[Optional[torch.Tensor]]]: +def _get_analytical_vjps_wrt_specific_output( + vjp_fn, sample_output, v +) -> List[List[Optional[torch.Tensor]]]: vjps: List[List[Optional[torch.Tensor]]] = [] grad_inputs = vjp_fn(v.reshape(sample_output.shape)) for vjp in grad_inputs: @@ -733,10 +908,11 @@ def _check_inputs(tupled_inputs) -> bool: if is_tensor_like(inp) and inp.requires_grad: if not (inp.dtype == torch.float64 or inp.dtype == torch.complex128): warnings.warn( - f'Input #{idx} requires gradient and ' - 'is not a double precision floating point or complex. ' - 'This check will likely fail if all the inputs are ' - 'not of double precision floating point or complex. ') + f"Input #{idx} requires gradient and " + "is not a double precision floating point or complex. " + "This check will likely fail if all the inputs are " + "not of double precision floating point or complex. " + ) if inp.is_sparse: content = inp._values() elif _is_sparse_compressed_tensor(inp): @@ -746,18 +922,23 @@ def _check_inputs(tupled_inputs) -> bool: # TODO: To cover more problematic cases, replace stride = 0 check with # "any overlap in memory" once we have a proper function to check it. if content.layout is not torch._mkldnn: # type: ignore[attr-defined] - if not all(st > 0 or sz <= 1 for st, sz in zip(content.stride(), content.size())): + if not all( + st > 0 or sz <= 1 + for st, sz in zip(content.stride(), content.size()) + ): raise RuntimeError( - f'The {idx}th input has a dimension with stride 0. gradcheck only ' - 'supports inputs that are non-overlapping to be able to ' - 'compute the numerical gradients correctly. You should call ' - '.contiguous on the input before passing it to gradcheck.') + f"The {idx}th input has a dimension with stride 0. gradcheck only " + "supports inputs that are non-overlapping to be able to " + "compute the numerical gradients correctly. You should call " + ".contiguous on the input before passing it to gradcheck." + ) any_input_requiring_grad = True if not any_input_requiring_grad: raise ValueError( - 'gradcheck expects at least one input tensor to require gradient, ' - 'but none of the them have requires_grad=True.') + "gradcheck expects at least one input tensor to require gradient, " + "but none of the them have requires_grad=True." + ) return True @@ -765,34 +946,46 @@ def _check_outputs(outputs) -> None: if any(_is_sparse_any_tensor(t) for t in outputs if isinstance(t, torch.Tensor)): # it is easier to call to_dense() on the sparse output than # to modify analytical jacobian - raise ValueError('Sparse output is not supported at gradcheck yet. ' - 'Please call to_dense(masked_grad=...) on the output of fn for gradcheck.') + raise ValueError( + "Sparse output is not supported at gradcheck yet. " + "Please call to_dense(masked_grad=...) on the output of fn for gradcheck." + ) if any(t.layout == torch._mkldnn for t in outputs if isinstance(t, torch.Tensor)): # type: ignore[attr-defined] - raise ValueError('MKLDNN output is not supported at gradcheck yet. ' - 'Please call to_dense(masked_grad=...) on the output of fn for gradcheck.') + raise ValueError( + "MKLDNN output is not supported at gradcheck yet. " + "Please call to_dense(masked_grad=...) on the output of fn for gradcheck." + ) -def _check_no_differentiable_outputs(func, inputs, func_out, eps, *, is_forward_ad) -> bool: +def _check_no_differentiable_outputs( + func, inputs, func_out, eps, *, is_forward_ad +) -> bool: # When there are no differentiable outputs, numerical gradient for a function is # expected to be zero. - jacobians_all_inputs_outputs = _get_numerical_jacobian(func, inputs, func_out, - eps=eps, is_forward_ad=is_forward_ad) + jacobians_all_inputs_outputs = _get_numerical_jacobian( + func, inputs, func_out, eps=eps, is_forward_ad=is_forward_ad + ) for jacobians_all_outputs_and_fixed_input in jacobians_all_inputs_outputs: for jacobian in jacobians_all_outputs_and_fixed_input: if torch.ne(jacobian, 0).sum() > 0: - raise GradcheckError('Numerical gradient for function expected to be zero') + raise GradcheckError( + "Numerical gradient for function expected to be zero" + ) return True -def _check_no_differentiable_outputs_fast(func, func_out, all_inputs, inputs_indices, - all_u, eps, nondet_tol): +def _check_no_differentiable_outputs_fast( + func, func_out, all_inputs, inputs_indices, all_u, eps, nondet_tol +): for inp_idx, u in zip(inputs_indices, all_u): jvps = _get_numerical_jvp_wrt_specific_input(func, inp_idx, all_inputs, u, eps) for jvp in jvps: if jvp.numel() == 0: continue if (jvp - torch.zeros_like(jvp)).abs().max() > nondet_tol: - raise GradcheckError('Numerical gradient for function expected to be zero') + raise GradcheckError( + "Numerical gradient for function expected to be zero" + ) return True @@ -836,7 +1029,10 @@ If the test to have `check_batched_forward_grad=False` """ -def _get_failed_batched_grad_test_msg(output_idx, input_idx, res, exp, is_forward_ad=False): + +def _get_failed_batched_grad_test_msg( + output_idx, input_idx, res, exp, is_forward_ad=False +): return f""" For output {output_idx} and input {input_idx}: @@ -849,8 +1045,9 @@ Expected: {exp} """.strip() + def _test_batched_grad_forward_ad(func, inputs) -> bool: - fwAD = torch.autograd.forward_ad # To avoid early import issues (do we need this?) + fwAD = torch.autograd.forward_ad # To avoid early import issues (do we need this?) assert isinstance(inputs, tuple) for input_idx, current_input in enumerate(inputs): @@ -860,8 +1057,12 @@ def _test_batched_grad_forward_ad(func, inputs) -> bool: def jvp(tangent: torch.Tensor): with fwAD.dual_level(): dual = fwAD.make_dual(current_input.detach(), tangent) - inputs_with_dual = tuple(dual if idx == input_idx else (inp.detach() if is_tensor_like(inp) else inp) - for idx, inp in enumerate(inputs)) + inputs_with_dual = tuple( + dual + if idx == input_idx + else (inp.detach() if is_tensor_like(inp) else inp) + for idx, inp in enumerate(inputs) + ) dual_outputs = _as_tuple(func(*inputs_with_dual)) ret = [] for dual_output in dual_outputs: @@ -871,8 +1072,13 @@ def _test_batched_grad_forward_ad(func, inputs) -> bool: if tangent_out is not None: ret.append(tangent_out) else: - ret.append(torch.zeros([], dtype=primal_out.dtype, device=primal_out.device).expand(primal_out.shape)) + ret.append( + torch.zeros( + [], dtype=primal_out.dtype, device=primal_out.device + ).expand(primal_out.shape) + ) return tuple(ret) + if not _is_float_or_complex_tensor(current_input): continue @@ -885,14 +1091,20 @@ def _test_batched_grad_forward_ad(func, inputs) -> bool: except RuntimeError as ex: # Rethrow to provide a better error message raise GradcheckError( - f'While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG_FWD_AD}') from ex + f"While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG_FWD_AD}" + ) from ex for input_idx, (res, exp) in enumerate(zip(result, expected)): if torch.allclose(res, exp): continue - raise GradcheckError(_get_failed_batched_grad_test_msg(input_idx, input_idx, res, exp, is_forward_ad=True)) + raise GradcheckError( + _get_failed_batched_grad_test_msg( + input_idx, input_idx, res, exp, is_forward_ad=True + ) + ) return True + def _test_batched_grad(input, output, output_idx) -> bool: # NB: _test_batched_grad compares two autograd.grad invocations with a single # vmap(autograd.grad) invocation. It's not exactly a "gradcheck" in the @@ -900,13 +1112,22 @@ def _test_batched_grad(input, output, output_idx) -> bool: # but it is morally similar (we could have computed a full analytic jac # via vmap, but that is potentially slow) diff_input_list = list(_iter_tensors(input, True)) - grad = functools.partial(torch.autograd.grad, output, diff_input_list, retain_graph=True, allow_unused=True) + grad = functools.partial( + torch.autograd.grad, + output, + diff_input_list, + retain_graph=True, + allow_unused=True, + ) def vjp(v): results = grad(v) - results = tuple(grad if grad is not None else - torch.zeros([], dtype=inp.dtype, device=inp.device).expand(inp.shape) - for grad, inp in zip(results, diff_input_list)) + results = tuple( + grad + if grad is not None + else torch.zeros([], dtype=inp.dtype, device=inp.device).expand(inp.shape) + for grad, inp in zip(results, diff_input_list) + ) return results grad_outputs = [torch.randn_like(output) for _ in range(2)] @@ -927,12 +1148,15 @@ def _test_batched_grad(input, output, output_idx) -> bool: # autograd.grad instead of the C++ traceback of what line in the # backward formula raise GradcheckError( - f'While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG}') from ex + f"While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG}" + ) from ex for input_idx, (res, exp) in enumerate(zip(result, expected)): if torch.allclose(res, exp): continue - raise GradcheckError(_get_failed_batched_grad_test_msg(output_idx, input_idx, res, exp)) + raise GradcheckError( + _get_failed_batched_grad_test_msg(output_idx, input_idx, res, exp) + ) return True @@ -941,38 +1165,55 @@ def _test_backward_mul_by_grad_output(outputs, inputs, masked) -> bool: diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True)) if not diff_input_list: raise GradcheckError("no Tensors requiring grad found in input") - grads_input = torch.autograd.grad(outputs, diff_input_list, - [torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) for o in outputs], - allow_unused=True) + grads_input = torch.autograd.grad( + outputs, + diff_input_list, + [ + torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) + for o in outputs + ], + allow_unused=True, + ) for gi, di in zip(grads_input, diff_input_list): if gi is None: continue if isinstance(gi, torch.Tensor) and gi.layout != torch.strided: if gi.layout != di.layout: - raise GradcheckError('grad is incorrect layout (' + str(gi.layout) + ' is not ' + str(di.layout) + ')') + raise GradcheckError( + "grad is incorrect layout (" + + str(gi.layout) + + " is not " + + str(di.layout) + + ")" + ) if _is_sparse_any_tensor(gi): - sparse_kind = str(gi.layout).replace('torch.', '').replace('_coo', '') + sparse_kind = str(gi.layout).replace("torch.", "").replace("_coo", "") if gi.sparse_dim() != di.sparse_dim(): - raise GradcheckError(f'grad is {sparse_kind} tensor, but has incorrect sparse_dim' - f' {gi.sparse_dim()}, expected {di.sparse_dim()}') + raise GradcheckError( + f"grad is {sparse_kind} tensor, but has incorrect sparse_dim" + f" {gi.sparse_dim()}, expected {di.sparse_dim()}" + ) if gi.dense_dim() != di.dense_dim(): - raise GradcheckError(f'grad is {sparse_kind} tensor, but has incorrect dense_dim' - f' {gi.dense_dim()}, expected {di.dense_dim()}') + raise GradcheckError( + f"grad is {sparse_kind} tensor, but has incorrect dense_dim" + f" {gi.dense_dim()}, expected {di.dense_dim()}" + ) gi = gi.to_dense() di = di.to_dense() if masked: if not torch.allclose(gi, torch.zeros_like(gi)): - raise GradcheckError('backward not multiplied by grad_output') + raise GradcheckError("backward not multiplied by grad_output") elif not gi.eq(0).all(): - raise GradcheckError('backward not multiplied by grad_output') + raise GradcheckError("backward not multiplied by grad_output") if gi.dtype != di.dtype: raise GradcheckError("grad is incorrect type") if gi.device != di.device: raise GradcheckError("grad is incorrect device") if gi.size() != di.size(): - raise GradcheckError('grad is incorrect size') + raise GradcheckError("grad is incorrect size") return True + def _test_undefined_forward_mode(func, outputs, inputs): fwAD = torch.autograd.forward_ad @@ -988,7 +1229,9 @@ def _test_undefined_forward_mode(func, outputs, inputs): for i, inp in enumerate(inputs): if is_tensor_like(inp) and inp.requires_grad: if inp.layout == torch._mkldnn: # type: ignore[attr-defined] - raise ValueError("MKLDNN inputs are not support for forward AD gradcheck.") + raise ValueError( + "MKLDNN inputs are not support for forward AD gradcheck." + ) inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp)) # If inp is a differentiable view, the dual might not be the tangent given to @@ -1024,11 +1267,20 @@ def _test_undefined_forward_mode(func, outputs, inputs): if not (res1 is None or res2 is None): if not torch.allclose(res1, res2): - raise GradcheckError("Mismatch in tangent values for output with index: ", index_o, - " when input: ", inp, " has an undefined tangent value. ", - " Got: ", res1, " but expected: ", res2) + raise GradcheckError( + "Mismatch in tangent values for output with index: ", + index_o, + " when input: ", + inp, + " has an undefined tangent value. ", + " Got: ", + res1, + " but expected: ", + res2, + ) return True + def _test_undefined_backward_mode(func, outputs, inputs) -> bool: diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True)) if not diff_input_list: @@ -1036,46 +1288,61 @@ def _test_undefined_backward_mode(func, outputs, inputs) -> bool: def warn_bc_breaking(): warnings.warn( - 'Backwards compatibility: New undefined gradient support checking ' - 'feature is enabled by default, but it may break existing callers ' - 'of this function. If this is true for you, you can call this ' - 'function with "check_undefined_grad=False" to disable the feature') + "Backwards compatibility: New undefined gradient support checking " + "feature is enabled by default, but it may break existing callers " + "of this function. If this is true for you, you can call this " + 'function with "check_undefined_grad=False" to disable the feature' + ) def check_undefined_grad_support(output_to_check): - grads_output = [torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) for o in output_to_check] + grads_output = [ + torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) + for o in output_to_check + ] try: - grads_input = torch.autograd.grad(output_to_check, diff_input_list, - grads_output, allow_unused=True) + grads_input = torch.autograd.grad( + output_to_check, diff_input_list, grads_output, allow_unused=True + ) except RuntimeError as e: warn_bc_breaking() raise GradcheckError( - 'Expected backward function to handle undefined output grads. ' + "Expected backward function to handle undefined output grads. " 'Please look at "Notes about undefined output gradients" in ' - '"tools/autograd/derivatives.yaml"') from e + '"tools/autograd/derivatives.yaml"' + ) from e for gi, i in zip(grads_input, diff_input_list): if (gi is not None) and (not gi.eq(0).all()): warn_bc_breaking() raise GradcheckError( - 'Expected all input grads to be undefined or zero when all output grads are undefined ' + "Expected all input grads to be undefined or zero when all output grads are undefined " 'or zero. Please look at "Notes about undefined output gradients" in ' - '"tools/autograd/derivatives.yaml"') + '"tools/autograd/derivatives.yaml"' + ) return True # All backward functions must work properly if all output grads are undefined - outputs_to_check = [[ - torch._C._functions.UndefinedGrad()(o) for o in _differentiable_outputs(func(*inputs)) - # This check filters out Tensor-likes that aren't instances of Tensor. - if isinstance(o, torch.Tensor) - ]] + outputs_to_check = [ + [ + torch._C._functions.UndefinedGrad()(o) + for o in _differentiable_outputs(func(*inputs)) + # This check filters out Tensor-likes that aren't instances of Tensor. + if isinstance(o, torch.Tensor) + ] + ] # If there are multiple output grads, we should be able to undef one at a time without error if len(outputs_to_check[0]) > 1: for undef_grad_idx in range(len(outputs)): output_to_check = _differentiable_outputs(func(*inputs)) - outputs_to_check.append([ - torch._C._functions.UndefinedGrad()(o) if idx == undef_grad_idx else o - for idx, o in enumerate(output_to_check)]) + outputs_to_check.append( + [ + torch._C._functions.UndefinedGrad()(o) + if idx == undef_grad_idx + else o + for idx, o in enumerate(output_to_check) + ] + ) return all(check_undefined_grad_support(output) for output in outputs_to_check) @@ -1086,24 +1353,39 @@ def _as_tuple(x): elif isinstance(x, list): return tuple(x) else: - return x, + return (x,) def _differentiable_outputs(x): return tuple(o for o in _as_tuple(x) if o.requires_grad) -def _get_notallclose_msg(analytical, numerical, output_idx, input_idx, complex_indices, - test_imag=False, is_forward_ad=False) -> str: - out_is_complex = (not is_forward_ad) and complex_indices and output_idx in complex_indices +def _get_notallclose_msg( + analytical, + numerical, + output_idx, + input_idx, + complex_indices, + test_imag=False, + is_forward_ad=False, +) -> str: + out_is_complex = ( + (not is_forward_ad) and complex_indices and output_idx in complex_indices + ) inp_is_complex = is_forward_ad and complex_indices and input_idx in complex_indices part = "imaginary" if test_imag else "real" element = "inputs" if is_forward_ad else "outputs" - prefix = "" if not (out_is_complex or inp_is_complex) else \ - f"While considering the {part} part of complex {element} only, " + prefix = ( + "" + if not (out_is_complex or inp_is_complex) + else f"While considering the {part} part of complex {element} only, " + ) mode = "computed with forward mode " if is_forward_ad else "" - return prefix + 'Jacobian %smismatch for output %d with respect to input %d,\n' \ - 'numerical:%s\nanalytical:%s\n' % (mode, output_idx, input_idx, numerical, analytical) + return ( + prefix + "Jacobian %smismatch for output %d with respect to input %d,\n" + "numerical:%s\nanalytical:%s\n" + % (mode, output_idx, input_idx, numerical, analytical) + ) def _transpose(matrix_of_tensors): @@ -1118,10 +1400,12 @@ def _real_and_imag_output(fn): def wrapped_fn(*inputs): outs = _as_tuple(fn(*inputs)) return tuple(fn_to_apply(o) if o.is_complex() else o for o in outs) + return wrapped_fn return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch.imag) + def _real_and_imag_input(fn, complex_inp_indices, tupled_inputs): # returns new functions that take real inputs instead of complex inputs as # (x, y) -> fn(x + y * 1j). And it computes: inp -> fn(inp + y * 1j) and inp -> fn(x + inp * 1j). @@ -1131,18 +1415,33 @@ def _real_and_imag_input(fn, complex_inp_indices, tupled_inputs): def wrapped_fn(*inputs): new_inputs = list(inputs) for should_be_complex in complex_inp_indices: - new_inputs[should_be_complex] = fn_to_apply(new_inputs[should_be_complex], - tupled_inputs[should_be_complex]) + new_inputs[should_be_complex] = fn_to_apply( + new_inputs[should_be_complex], tupled_inputs[should_be_complex] + ) return _as_tuple(fn(*new_inputs)) + return wrapped_fn + real_fn = apply_to_c_inps(fn, lambda inp, orig: inp + orig.imag * 1j) imag_fn = apply_to_c_inps(fn, lambda inp, orig: orig.real + inp * 1j) return real_fn, imag_fn -def _gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, eps, rtol, - atol, check_grad_dtypes, check_forward_ad, check_backward_ad, nondet_tol, - check_undefined_grad): +def _gradcheck_real_imag( + gradcheck_fn, + func, + func_out, + tupled_inputs, + outputs, + eps, + rtol, + atol, + check_grad_dtypes, + check_forward_ad, + check_backward_ad, + nondet_tol, + check_undefined_grad, +): complex_out_indices = [i for i, o in enumerate(outputs) if o.is_complex()] has_any_complex_output = any(o.is_complex() for o in _as_tuple(func_out)) if check_backward_ad: @@ -1151,75 +1450,179 @@ def _gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, e imag_func_out = imag_fn(*tupled_inputs) imag_outputs = _differentiable_outputs(imag_func_out) - gradcheck_fn(imag_fn, imag_func_out, tupled_inputs, imag_outputs, eps, - rtol, atol, check_grad_dtypes, nondet_tol, - complex_indices=complex_out_indices, test_imag=True) + gradcheck_fn( + imag_fn, + imag_func_out, + tupled_inputs, + imag_outputs, + eps, + rtol, + atol, + check_grad_dtypes, + nondet_tol, + complex_indices=complex_out_indices, + test_imag=True, + ) real_func_out = real_fn(*tupled_inputs) real_outputs = _differentiable_outputs(real_func_out) - gradcheck_fn(real_fn, real_func_out, tupled_inputs, real_outputs, eps, - rtol, atol, check_grad_dtypes, nondet_tol, complex_indices=complex_out_indices) + gradcheck_fn( + real_fn, + real_func_out, + tupled_inputs, + real_outputs, + eps, + rtol, + atol, + check_grad_dtypes, + nondet_tol, + complex_indices=complex_out_indices, + ) else: - gradcheck_fn(func, func_out, tupled_inputs, outputs, eps, - rtol, atol, check_grad_dtypes, nondet_tol) + gradcheck_fn( + func, + func_out, + tupled_inputs, + outputs, + eps, + rtol, + atol, + check_grad_dtypes, + nondet_tol, + ) if check_forward_ad: - complex_inp_indices = [i for i, inp in enumerate(tupled_inputs) if is_tensor_like(inp) and inp.is_complex()] + complex_inp_indices = [ + i + for i, inp in enumerate(tupled_inputs) + if is_tensor_like(inp) and inp.is_complex() + ] if complex_inp_indices: - real_fn, imag_fn = _real_and_imag_input(func, complex_inp_indices, tupled_inputs) + real_fn, imag_fn = _real_and_imag_input( + func, complex_inp_indices, tupled_inputs + ) - imag_inputs = [inp.imag if is_tensor_like(inp) and inp.is_complex() else inp for inp in tupled_inputs] + imag_inputs = [ + inp.imag if is_tensor_like(inp) and inp.is_complex() else inp + for inp in tupled_inputs + ] imag_func_out = imag_fn(*imag_inputs) diff_imag_func_out = _differentiable_outputs(imag_func_out) - gradcheck_fn(imag_fn, imag_func_out, imag_inputs, diff_imag_func_out, eps, - rtol, atol, check_grad_dtypes, nondet_tol, - complex_indices=complex_inp_indices, test_imag=True, use_forward_ad=True) + gradcheck_fn( + imag_fn, + imag_func_out, + imag_inputs, + diff_imag_func_out, + eps, + rtol, + atol, + check_grad_dtypes, + nondet_tol, + complex_indices=complex_inp_indices, + test_imag=True, + use_forward_ad=True, + ) - real_inputs = [inp.real if is_tensor_like(inp) and inp.is_complex() else inp for inp in tupled_inputs] + real_inputs = [ + inp.real if is_tensor_like(inp) and inp.is_complex() else inp + for inp in tupled_inputs + ] real_func_out = real_fn(*real_inputs) diff_real_func_out = _differentiable_outputs(real_func_out) - gradcheck_fn(real_fn, real_func_out, real_inputs, diff_real_func_out, eps, - rtol, atol, check_grad_dtypes, nondet_tol, complex_indices=complex_inp_indices, - use_forward_ad=True) + gradcheck_fn( + real_fn, + real_func_out, + real_inputs, + diff_real_func_out, + eps, + rtol, + atol, + check_grad_dtypes, + nondet_tol, + complex_indices=complex_inp_indices, + use_forward_ad=True, + ) if check_undefined_grad: _test_undefined_forward_mode(imag_fn, imag_func_out, imag_inputs) _test_undefined_forward_mode(real_fn, real_func_out, real_inputs) else: - gradcheck_fn(func, func_out, tupled_inputs, outputs, eps, - rtol, atol, check_grad_dtypes, nondet_tol, use_forward_ad=True) + gradcheck_fn( + func, + func_out, + tupled_inputs, + outputs, + eps, + rtol, + atol, + check_grad_dtypes, + nondet_tol, + use_forward_ad=True, + ) if check_undefined_grad: _test_undefined_forward_mode(func, outputs, tupled_inputs) -def _slow_gradcheck(func, func_out, tupled_inputs, outputs, eps, rtol, atol, check_grad_dtypes, - nondet_tol, *, use_forward_ad=False, complex_indices=None, test_imag=False, masked=False): + +def _slow_gradcheck( + func, + func_out, + tupled_inputs, + outputs, + eps, + rtol, + atol, + check_grad_dtypes, + nondet_tol, + *, + use_forward_ad=False, + complex_indices=None, + test_imag=False, + masked=False, +): func_out = _as_tuple(func_out) if not outputs: - return _check_no_differentiable_outputs(func, tupled_inputs, func_out, - eps=eps, is_forward_ad=use_forward_ad) + return _check_no_differentiable_outputs( + func, tupled_inputs, func_out, eps=eps, is_forward_ad=use_forward_ad + ) tupled_inputs_numerical = tupled_inputs if masked else _densify(tupled_inputs) - numerical = _transpose(_get_numerical_jacobian(func, tupled_inputs_numerical, func_out, - eps=eps, is_forward_ad=use_forward_ad)) + numerical = _transpose( + _get_numerical_jacobian( + func, + tupled_inputs_numerical, + func_out, + eps=eps, + is_forward_ad=use_forward_ad, + ) + ) # Note: [numerical vs analytical output length] # The numerical path returns jacobian quantity for all outputs, even if requires_grad of that # output is False. This behavior is necessary for _check_no_differentiable_outputs to work. numerical = [nj for o, nj in zip(func_out, numerical) if o.requires_grad] if use_forward_ad: - analytical_forward = _get_analytical_jacobian_forward_ad(func, tupled_inputs, func_out, check_grad_dtypes=check_grad_dtypes) + analytical_forward = _get_analytical_jacobian_forward_ad( + func, tupled_inputs, func_out, check_grad_dtypes=check_grad_dtypes + ) for i, n_per_out in enumerate(numerical): for j, n in enumerate(n_per_out): a = analytical_forward[j][i] if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol): - raise GradcheckError(_get_notallclose_msg(a, n, i, j, complex_indices, test_imag, - is_forward_ad=True)) + raise GradcheckError( + _get_notallclose_msg( + a, n, i, j, complex_indices, test_imag, is_forward_ad=True + ) + ) else: for i, o in enumerate(outputs): - analytical = _check_analytical_jacobian_attributes(tupled_inputs, o, nondet_tol, check_grad_dtypes) + analytical = _check_analytical_jacobian_attributes( + tupled_inputs, o, nondet_tol, check_grad_dtypes + ) for j, (a, n) in enumerate(zip(analytical, numerical[i])): if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol): - raise GradcheckError(_get_notallclose_msg(a, n, i, j, complex_indices, test_imag)) + raise GradcheckError( + _get_notallclose_msg(a, n, i, j, complex_indices, test_imag) + ) return True @@ -1244,6 +1647,7 @@ def _to_real_dtype(dtype): else: return dtype + def _vec_from_tensor(x, generator, downcast_complex=False): # Create a random vector with the same number of elements as x and the same # dtype/device. If x is complex and downcast_complex is False, we create a @@ -1253,9 +1657,11 @@ def _vec_from_tensor(x, generator, downcast_complex=False): # indices. Make sure size is set so that it isn't inferred to be smaller. x_values = x._values() dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype - values = torch.rand(x_values.numel(), generator=generator) \ - .to(dtype=dtype, device=x.device) \ + values = ( + torch.rand(x_values.numel(), generator=generator) + .to(dtype=dtype, device=x.device) .view(x_values.shape) + ) values /= values.norm() vec = torch.sparse_coo_tensor(x._indices(), values, x.size()) elif _is_sparse_compressed_tensor(x): @@ -1265,20 +1671,30 @@ def _vec_from_tensor(x, generator, downcast_complex=False): compressed_indices, plain_indices = x.ccol_indices(), x.row_indices() x_values = x.values() dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype - values = torch.rand(x_values.numel(), generator=generator) \ - .to(dtype=dtype, device=x.device) \ + values = ( + torch.rand(x_values.numel(), generator=generator) + .to(dtype=dtype, device=x.device) .view(x_values.shape) + ) values /= values.norm() - vec = torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, x.size(), layout=x.layout) + vec = torch.sparse_compressed_tensor( + compressed_indices, plain_indices, values, x.size(), layout=x.layout + ) else: dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype - vec = torch.rand(x.numel(), generator=generator).to(dtype=dtype, device=x.device) + vec = torch.rand(x.numel(), generator=generator).to( + dtype=dtype, device=x.device + ) vec /= vec.norm() return vec def _get_inp_tensors(tupled_inputs): - inp_idx_tup = [(i, t) for i, t in enumerate(tupled_inputs) if is_tensor_like(t) and t.requires_grad] + inp_idx_tup = [ + (i, t) + for i, t in enumerate(tupled_inputs) + if is_tensor_like(t) and t.requires_grad + ] return [tup[0] for tup in inp_idx_tup], [tup[1] for tup in inp_idx_tup] @@ -1293,7 +1709,7 @@ def _adjusted_atol(atol, u, v): # TODO: properly handle case when u is tuple instead of only taking first element u = u[0] if isinstance(u, tuple) else u sum_u = u.sum() - sum_v = 1. if v is None else v.sum() + sum_v = 1.0 if v is None else v.sum() return atol * float(sum_u) * float(sum_v) @@ -1314,29 +1730,40 @@ If the test """.strip() -def _run_slow_mode_and_get_error(func, tupled_inputs, outputs, input_idx, output_idx, rtol, atol, is_forward_ad): +def _run_slow_mode_and_get_error( + func, tupled_inputs, outputs, input_idx, output_idx, rtol, atol, is_forward_ad +): # Compute jacobians in slow mode for better error message - slow_numerical = _get_numerical_jacobian(func, tupled_inputs, outputs, is_forward_ad=is_forward_ad)[input_idx][output_idx] + slow_numerical = _get_numerical_jacobian( + func, tupled_inputs, outputs, is_forward_ad=is_forward_ad + )[input_idx][output_idx] if is_forward_ad: + def new_fn(inp): new_inputs = list(tupled_inputs) new_inputs[input_idx] = inp return _as_tuple(func(*new_inputs))[output_idx] - slow_analytical = _get_analytical_jacobian_forward_ad(new_fn, (tupled_inputs[input_idx],), (outputs[output_idx],))[0][0] - else: - slow_analytical = _get_analytical_jacobian(tupled_inputs, outputs, input_idx, output_idx) + slow_analytical = _get_analytical_jacobian_forward_ad( + new_fn, (tupled_inputs[input_idx],), (outputs[output_idx],) + )[0][0] + else: + slow_analytical = _get_analytical_jacobian( + tupled_inputs, outputs, input_idx, output_idx + ) # Assume jacobians are non-empty and have the same shape slow_max_diff = (slow_numerical - slow_analytical).abs().max() slow_allclose = torch.allclose(slow_analytical, slow_numerical, rtol, atol) - msg = ("\nThe above quantities relating the numerical and analytical jacobians are computed \n" - "in fast mode. See: https://github.com/pytorch/pytorch/issues/53876 for more background \n" - "about fast mode. Below, we recompute numerical and analytical jacobians in slow mode:\n\n" - f"Numerical:\n {slow_numerical}\n" - f"Analytical:\n{slow_analytical}\n\n" - f"The max per-element difference (slow mode) is: {slow_max_diff}.\n") + msg = ( + "\nThe above quantities relating the numerical and analytical jacobians are computed \n" + "in fast mode. See: https://github.com/pytorch/pytorch/issues/53876 for more background \n" + "about fast mode. Below, we recompute numerical and analytical jacobians in slow mode:\n\n" + f"Numerical:\n {slow_numerical}\n" + f"Analytical:\n{slow_analytical}\n\n" + f"The max per-element difference (slow mode) is: {slow_max_diff}.\n" + ) if slow_allclose: # Slow gradcheck would've passed! msg += FAST_FAIL_SLOW_OK_MSG @@ -1366,12 +1793,27 @@ def _make_vectors(inp_tensors, outputs, *, use_forward_ad): else: all_u.append(ur) all_u_dense.append(ur_dense) - all_v = None if use_forward_ad else [_vec_from_tensor(out, g_cpu) for out in outputs] + all_v = ( + None if use_forward_ad else [_vec_from_tensor(out, g_cpu) for out in outputs] + ) return all_v, all_u, all_u_dense -def _check_analytical_numerical_equal(all_analytical, all_numerical, complex_indices, tupled_inputs, outputs, - func, all_v, all_u, rtol, atol, test_imag, *, is_forward_ad=False): +def _check_analytical_numerical_equal( + all_analytical, + all_numerical, + complex_indices, + tupled_inputs, + outputs, + func, + all_v, + all_u, + rtol, + atol, + test_imag, + *, + is_forward_ad=False, +): for i, all_numerical_for_input_i in enumerate(all_numerical): for j, n in enumerate(all_numerical_for_input_i): # Forward AD generates the transpose of what this function expects @@ -1382,13 +1824,33 @@ def _check_analytical_numerical_equal(all_analytical, all_numerical, complex_ind n = n.to(device=a.device) updated_atol = _adjusted_atol(atol, all_u[i], all_v[j] if all_v else None) if not _allclose_with_type_promotion(a, n.to(a.device), rtol, updated_atol): - jacobians_str = _run_slow_mode_and_get_error(func, tupled_inputs, outputs, i, j, rtol, atol, is_forward_ad) - raise GradcheckError(_get_notallclose_msg(a, n, j, i, complex_indices, test_imag, is_forward_ad) + jacobians_str) + jacobians_str = _run_slow_mode_and_get_error( + func, tupled_inputs, outputs, i, j, rtol, atol, is_forward_ad + ) + raise GradcheckError( + _get_notallclose_msg( + a, n, j, i, complex_indices, test_imag, is_forward_ad + ) + + jacobians_str + ) -def _fast_gradcheck(func, func_out, inputs, outputs, eps, rtol, - atol, check_grad_dtypes, nondet_tol, *, use_forward_ad=False, complex_indices=None, test_imag=False, - masked=False): +def _fast_gradcheck( + func, + func_out, + inputs, + outputs, + eps, + rtol, + atol, + check_grad_dtypes, + nondet_tol, + *, + use_forward_ad=False, + complex_indices=None, + test_imag=False, + masked=False, +): # See https://github.com/pytorch/pytorch/issues/53876 for details inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs) # Backward mode computes v^T * J (VJP) @@ -1398,25 +1860,58 @@ def _fast_gradcheck(func, func_out, inputs, outputs, eps, rtol, # Forward mode computes J * u (JVP) # Since we already compute JVP through finite difference method, # we don't need v for correctness check here as asserted below - all_v, all_u, all_u_dense = _make_vectors(inp_tensors, outputs, use_forward_ad=use_forward_ad) + all_v, all_u, all_u_dense = _make_vectors( + inp_tensors, outputs, use_forward_ad=use_forward_ad + ) - inputs_numerical, all_u_numerical, all_v_numerical = (inputs, all_u, all_v) if masked else _densify((inputs, all_u, all_v)) + inputs_numerical, all_u_numerical, all_v_numerical = ( + (inputs, all_u, all_v) if masked else _densify((inputs, all_u, all_v)) + ) - numerical_vJu = _get_numerical_vJu(func, inputs_numerical, inp_tensors_idx, func_out, - all_u_numerical, all_v_numerical, eps, is_forward_ad=use_forward_ad) + numerical_vJu = _get_numerical_vJu( + func, + inputs_numerical, + inp_tensors_idx, + func_out, + all_u_numerical, + all_v_numerical, + eps, + is_forward_ad=use_forward_ad, + ) # TODO: replicate https://github.com/pytorch/pytorch/pull/77743 for fast gradcheck as well if use_forward_ad: assert all_v is None - analytical_vJu = _get_analytical_jacobian_forward_ad(func, inputs, _as_tuple(func_out), - all_u=all_u, check_grad_dtypes=check_grad_dtypes) + analytical_vJu = _get_analytical_jacobian_forward_ad( + func, + inputs, + _as_tuple(func_out), + all_u=all_u, + check_grad_dtypes=check_grad_dtypes, + ) else: if not outputs: - _check_no_differentiable_outputs_fast(func, func_out, inputs, inp_tensors_idx, all_u, eps, nondet_tol) + _check_no_differentiable_outputs_fast( + func, func_out, inputs, inp_tensors_idx, all_u, eps, nondet_tol + ) - analytical_vJu = _get_analytical_vJu_backward_mode(inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u_dense) + analytical_vJu = _get_analytical_vJu_backward_mode( + inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u_dense + ) - _check_analytical_numerical_equal(analytical_vJu, numerical_vJu, complex_indices, - inputs, outputs, func, all_v, all_u, rtol, atol, test_imag, is_forward_ad=use_forward_ad) + _check_analytical_numerical_equal( + analytical_vJu, + numerical_vJu, + complex_indices, + inputs, + outputs, + func, + all_v, + all_u, + rtol, + atol, + test_imag, + is_forward_ad=use_forward_ad, + ) return True @@ -1526,18 +2021,24 @@ def gradcheck( check_sparse_nnz = masked else: warnings.warn( - 'Backwards compatibility: check_sparse_nnz is deprecated, it will be removed in a future version of PyTorch.' - f' Use masked={check_sparse_nnz} instead.') + "Backwards compatibility: check_sparse_nnz is deprecated, it will be removed in a future version of PyTorch." + f" Use masked={check_sparse_nnz} instead." + ) if masked is None: masked = check_sparse_nnz elif check_sparse_nnz != masked: - raise ValueError(f"Expected specified check_sparse_nnz (={check_sparse_nnz}) to be equal to masked (={masked}).") - assert check_forward_ad or check_backward_ad, \ - "Expected at least one of check_forward_ad or check_backward_ad to be True" - assert not (check_batched_grad and not check_backward_ad), ( - "Setting check_batched_grad=True requires check_backward_ad to be True") - assert not (check_batched_forward_grad and not check_forward_ad), ( - "Setting check_batched_forward_grad=True requires check_forward_ad to be True") + raise ValueError( + f"Expected specified check_sparse_nnz (={check_sparse_nnz}) to be equal to masked (={masked})." + ) + assert ( + check_forward_ad or check_backward_ad + ), "Expected at least one of check_forward_ad or check_backward_ad to be True" + assert not ( + check_batched_grad and not check_backward_ad + ), "Setting check_batched_grad=True requires check_backward_ad to be True" + assert not ( + check_batched_forward_grad and not check_forward_ad + ), "Setting check_batched_forward_grad=True requires check_forward_ad to be True" args = locals().copy() args.pop("raise_exception") args.pop("check_sparse_nnz") @@ -1550,9 +2051,22 @@ def gradcheck( return _gradcheck_helper(**args) -def _gradcheck_helper(func, inputs, eps, atol, rtol, nondet_tol, check_undefined_grad, - check_grad_dtypes, check_batched_grad, check_batched_forward_grad, check_forward_ad, - check_backward_ad, fast_mode, masked): +def _gradcheck_helper( + func, + inputs, + eps, + atol, + rtol, + nondet_tol, + check_undefined_grad, + check_grad_dtypes, + check_batched_grad, + check_batched_forward_grad, + check_forward_ad, + check_backward_ad, + fast_mode, + masked, +): tupled_inputs = _as_tuple(inputs) _check_inputs(tupled_inputs) @@ -1560,11 +2074,24 @@ def _gradcheck_helper(func, inputs, eps, atol, rtol, nondet_tol, check_undefined outputs = _differentiable_outputs(func_out) _check_outputs(outputs) - gradcheck_fn = functools.partial(_fast_gradcheck if fast_mode else _slow_gradcheck, masked=masked) - _gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, eps, - rtol, atol, check_grad_dtypes, check_forward_ad=check_forward_ad, - check_backward_ad=check_backward_ad, nondet_tol=nondet_tol, - check_undefined_grad=check_undefined_grad) + gradcheck_fn = functools.partial( + _fast_gradcheck if fast_mode else _slow_gradcheck, masked=masked + ) + _gradcheck_real_imag( + gradcheck_fn, + func, + func_out, + tupled_inputs, + outputs, + eps, + rtol, + atol, + check_grad_dtypes, + check_forward_ad=check_forward_ad, + check_backward_ad=check_backward_ad, + nondet_tol=nondet_tol, + check_undefined_grad=check_undefined_grad, + ) if check_batched_forward_grad: _test_batched_grad_forward_ad(func, tupled_inputs) @@ -1657,12 +2184,15 @@ def gradgradcheck( Returns: True if all differences satisfy allclose condition """ - assert check_fwd_over_rev or check_rev_over_rev, \ - "Expected at least one of check_fwd_over_rev or check_rev_over_rev to be True" - assert not (check_undefined_grad and not check_rev_over_rev), \ - "Setting check_undefined_grad=True requires check_rev_over_rev to be True" - assert not (check_batched_grad and not check_rev_over_rev), ( - "Setting check_batched_grad=True requires check_rev_over_rev to be True") + assert ( + check_fwd_over_rev or check_rev_over_rev + ), "Expected at least one of check_fwd_over_rev or check_rev_over_rev to be True" + assert not ( + check_undefined_grad and not check_rev_over_rev + ), "Setting check_undefined_grad=True requires check_rev_over_rev to be True" + assert not ( + check_batched_grad and not check_rev_over_rev + ), "Setting check_batched_grad=True requires check_rev_over_rev to be True" # TODO: do we want to test this too? # assert not (check_batched_forward_grad and not check_fwd_over_rev), ( # "Setting check_batched_forward_grad=True requires check_fwd_over_rev to be True") @@ -1675,7 +2205,9 @@ def gradgradcheck( tupled_grad_outputs = tuple( torch.testing.make_tensor( x.shape, - dtype=x.dtype if x.is_floating_point() or x.is_complex() else torch.double, + dtype=x.dtype + if x.is_floating_point() or x.is_complex() + else torch.double, device=x.device, low=-1, high=1, @@ -1691,22 +2223,46 @@ def gradgradcheck( # NB: We need to save the requires_grad information about the inputs here because gradcheck detaches inputs # before running forward mode AD - diff_input_args_indices = {i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad} - diff_grad_output_indices = {i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad} + diff_input_args_indices = { + i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad + } + diff_grad_output_indices = { + i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad + } def new_func(*args): # Restore the requires_grad information - input_args = tuple(x.requires_grad_() if i in diff_input_args_indices else x for i, x in enumerate(args[:-num_outputs])) + input_args = tuple( + x.requires_grad_() if i in diff_input_args_indices else x + for i, x in enumerate(args[:-num_outputs]) + ) outputs = _differentiable_outputs(func(*input_args)) - grad_outputs = tuple(x.requires_grad_() if i in diff_grad_output_indices else x for i, x in enumerate(args[-num_outputs:])) - diff_input_args = tuple(x for i, x in enumerate(input_args) if i in diff_input_args_indices) - grad_inputs = torch.autograd.grad(outputs, diff_input_args, grad_outputs, create_graph=True, - allow_unused=True) + grad_outputs = tuple( + x.requires_grad_() if i in diff_grad_output_indices else x + for i, x in enumerate(args[-num_outputs:]) + ) + diff_input_args = tuple( + x for i, x in enumerate(input_args) if i in diff_input_args_indices + ) + grad_inputs = torch.autograd.grad( + outputs, diff_input_args, grad_outputs, create_graph=True, allow_unused=True + ) grad_inputs = tuple(g for g in grad_inputs if g is not None) return grad_inputs return gradcheck( - new_func, tupled_inputs + tupled_grad_outputs, eps=eps, atol=atol, rtol=rtol, raise_exception=raise_exception, - nondet_tol=nondet_tol, check_undefined_grad=check_undefined_grad, - check_grad_dtypes=check_grad_dtypes, check_batched_grad=check_batched_grad, fast_mode=fast_mode, - check_forward_ad=check_fwd_over_rev, check_backward_ad=check_rev_over_rev, masked=masked) + new_func, + tupled_inputs + tupled_grad_outputs, + eps=eps, + atol=atol, + rtol=rtol, + raise_exception=raise_exception, + nondet_tol=nondet_tol, + check_undefined_grad=check_undefined_grad, + check_grad_dtypes=check_grad_dtypes, + check_batched_grad=check_batched_grad, + fast_mode=fast_mode, + check_forward_ad=check_fwd_over_rev, + check_backward_ad=check_rev_over_rev, + masked=masked, + ) diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index 1b6b04c8226c..87d8b03eaf7f 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -1,11 +1,12 @@ -import torch -import contextlib -from typing import Callable, Any, Dict, Tuple, Optional, Sequence, List, Set -from torch.utils.hooks import RemovableHandle -from torch.utils._python_dispatch import TorchDispatchMode -from collections import defaultdict -import weakref import abc +import contextlib +import weakref +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple + +import torch +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils.hooks import RemovableHandle __all__ = [ "saved_tensors_hooks", @@ -17,6 +18,7 @@ __all__ = [ "increment_version", ] + class Node(abc.ABC): @abc.abstractmethod def name(self) -> str: @@ -35,7 +37,7 @@ class Node(abc.ABC): @property @abc.abstractmethod - def next_functions(self) -> Tuple[Tuple[Optional['Node'], int], ...]: + def next_functions(self) -> Tuple[Tuple[Optional["Node"], int], ...]: ... @abc.abstractmethod @@ -124,11 +126,13 @@ class Node(abc.ABC): @classmethod def __subclasshook__(cls, C): if cls is Node: - if ((C is not None and C is getattr(torch._C._functions, C.__name__, None)) - or issubclass(C, torch.autograd.function.BackwardCFunction)): + if ( + C is not None and C is getattr(torch._C._functions, C.__name__, None) + ) or issubclass(C, torch.autograd.function.BackwardCFunction): return True return NotImplemented + def increment_version(tensor): """This function can be used to let autograd know that a given Tensor was modified inplace to enable more accurate error checking within the autograd engine. @@ -144,6 +148,7 @@ def increment_version(tensor): """ torch._C._increment_version(tensor) + class saved_tensors_hooks: """Context-manager that sets a pair of pack / unpack hooks for saved tensors. @@ -205,12 +210,19 @@ class saved_tensors_hooks: Only one pair of hooks is allowed at a time. When recursively nesting this context-manager, only the inner-most pair of hooks will be applied. """ - def __init__(self, pack_hook: Callable[[torch.Tensor], Any], unpack_hook: Callable[[Any], torch.Tensor]): + + def __init__( + self, + pack_hook: Callable[[torch.Tensor], Any], + unpack_hook: Callable[[Any], torch.Tensor], + ): self.pack_hook = pack_hook self.unpack_hook = unpack_hook def __enter__(self): - torch._C._autograd._push_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook) + torch._C._autograd._push_saved_tensors_default_hooks( + self.pack_hook, self.unpack_hook + ) def __exit__(self, *args: Any): torch._C._autograd._pop_saved_tensors_default_hooks() @@ -258,6 +270,7 @@ class save_on_cpu(saved_tensors_hooks): >>> # all intermediary tensors are released (deleted) after the call to backward """ + def __init__(self, pin_memory=False, device_type="cuda"): device_module = getattr(torch, device_type, torch.cuda) @@ -268,7 +281,8 @@ class save_on_cpu(saved_tensors_hooks): tensor.size(), dtype=tensor.dtype, layout=tensor.layout, - pin_memory=(device_module.is_available() and not tensor.is_sparse)) + pin_memory=(device_module.is_available() and not tensor.is_sparse), + ) packed.copy_(tensor) return (tensor.device, packed) @@ -302,7 +316,9 @@ def disable_saved_tensors_hooks(error_message): """ try: - maybe_prev_message = torch._C._autograd._saved_tensors_hooks_get_disabled_error_message() + maybe_prev_message = ( + torch._C._autograd._saved_tensors_hooks_get_disabled_error_message() + ) torch._C._autograd._saved_tensors_hooks_disable(error_message) yield finally: @@ -313,7 +329,10 @@ def disable_saved_tensors_hooks(error_message): torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message) -def register_multi_grad_hook(tensors: Sequence[torch.Tensor], fn: Callable[[Sequence[Optional[torch.Tensor]]], None]): +def register_multi_grad_hook( + tensors: Sequence[torch.Tensor], + fn: Callable[[Sequence[Optional[torch.Tensor]]], None], +): r"""Registers a multi-grad backward hook. The hook will be called after gradients with respect to every tensor in @@ -388,6 +407,7 @@ def register_multi_grad_hook(tensors: Sequence[torch.Tensor], fn: Callable[[Sequ fn(buffer[id]) del count[id] del buffer[id] + return inner_hook class Handle(RemovableHandle): @@ -428,15 +448,19 @@ def register_multi_grad_hook(tensors: Sequence[torch.Tensor], fn: Callable[[Sequ # - if the clone exists, the tensor must've been modified in-place _allow_mutation_on_saved_tensors_enabled = False + def _get_tid(t) -> Tuple[int, int, int]: return (id(t), t.data_ptr(), t._version) + def _get_sid(t) -> Tuple[int, int]: return (t.data_ptr(), t._version) + class _Handle: pass + class _swap_with_cloned(saved_tensors_hooks): def __init__(self, ctx): def pack_hook(t): @@ -462,7 +486,8 @@ class _swap_with_cloned(saved_tensors_hooks): handle = tup error_msg = ( "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context" - "in which the graph was originally recorded.") + "in which the graph was originally recorded." + ) assert _allow_mutation_on_saved_tensors_enabled, error_msg if handle in ctx.cloned: res = ctx.cloned[handle] @@ -473,6 +498,7 @@ class _swap_with_cloned(saved_tensors_hooks): super().__init__(pack_hook, unpack_hook) + class _CloneArgBeforeMutateMode(TorchDispatchMode): def __init__(self, ctx): self.ctx = ctx @@ -509,12 +535,17 @@ class _CloneArgBeforeMutateMode(TorchDispatchMode): rs = func(*args, **kwargs) return rs + class _AllowMutationOnSavedContext: def __init__(self): self.cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() self.original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() - self.tid_to_weakhandle: weakref.WeakValueDictionary = weakref.WeakValueDictionary() - self.sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = defaultdict(set) + self.tid_to_weakhandle: weakref.WeakValueDictionary = ( + weakref.WeakValueDictionary() + ) + self.sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = defaultdict( + set + ) def clear(self): self.cloned.clear() @@ -522,6 +553,7 @@ class _AllowMutationOnSavedContext: self.tid_to_weakhandle.clear() self.sid_to_tid.clear() + @contextlib.contextmanager def allow_mutation_on_saved_tensors(): """Context manager under which mutating tensors saved for backward is allowed @@ -560,7 +592,9 @@ def allow_mutation_on_saved_tensors(): with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx): try: if _allow_mutation_on_saved_tensors_enabled: - raise RuntimeError("allow_mutation_on_saved_tensors contexts cannot be nested") + raise RuntimeError( + "allow_mutation_on_saved_tensors contexts cannot be nested" + ) _allow_mutation_on_saved_tensors_enabled = True yield ctx finally: diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 7082b200eab1..c986175cafc6 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -1,12 +1,12 @@ -from typing import Any, Dict, List, Optional from collections import defaultdict +from typing import Any, Dict, List, Optional from warnings import warn import torch import torch.cuda -from torch._C._profiler import _ExperimentalConfig from torch._C import _get_privateuse1_backend_name +from torch._C._profiler import _ExperimentalConfig from torch.autograd import ( _disable_profiler, @@ -33,8 +33,19 @@ from torch.autograd.profiler_util import ( ) from torch.futures import Future -__all__ = ["profile", "record_function", "emit_itt", "emit_nvtx", "load_nvprof", "EnforceUnique", - "parse_nvprof_trace", "KinetoStepTracker", "EventList", "FunctionEvent", "MemRecordsAcc"] +__all__ = [ + "profile", + "record_function", + "emit_itt", + "emit_nvtx", + "load_nvprof", + "EnforceUnique", + "parse_nvprof_trace", + "KinetoStepTracker", + "EventList", + "FunctionEvent", + "MemRecordsAcc", +] try: # Available in Python >= 3.2 @@ -43,7 +54,6 @@ except ImportError: import functools class _ContextDecorator: # type: ignore[no-redef] - def __enter__(self): raise NotImplementedError @@ -58,11 +68,13 @@ except ImportError: return wrapped + def _enable_dynamo_cache_lookup_profiler(enable: bool): from torch._dynamo.eval_frame import ( # type: ignore[attr-defined] clear_profiler_hooks, set_profiler_hooks, ) + """ Registers a hook within dynamo eval_frame.c called before and after the lookup process, which runs guards associated with each cached frame. @@ -77,6 +89,7 @@ def _enable_dynamo_cache_lookup_profiler(enable: bool): def _profiler_end(record): torch.ops.profiler._record_function_exit._RecordFunction(record) + set_profiler_hooks(_profiler_start, _profiler_end) else: clear_profiler_hooks() @@ -166,21 +179,23 @@ class profile: ----------------------------------- --------------- --------------- --------------- """ + def __init__( - self, - enabled=True, - *, - use_cuda=False, - use_device=None, - record_shapes=False, - with_flops=False, - profile_memory=False, - with_stack=False, - with_modules=False, - use_kineto=False, - use_cpu=True, - use_mtia=False, - experimental_config=None): + self, + enabled=True, + *, + use_cuda=False, + use_device=None, + record_shapes=False, + with_flops=False, + profile_memory=False, + with_stack=False, + with_modules=False, + use_kineto=False, + use_cpu=True, + use_mtia=False, + experimental_config=None, + ): self.enabled: bool = enabled if not self.enabled: return @@ -202,8 +217,9 @@ class profile: self.kineto_results: Optional[_ProfilerResult] = None if not self.use_cpu: - assert use_kineto, \ - "Device-only events supported only with Kineto (use_kineto=True)" + assert ( + use_kineto + ), "Device-only events supported only with Kineto (use_kineto=True)" if self.use_cuda and not torch.cuda.is_available(): warn("CUDA is not available, disabling CUDA profiling") @@ -217,21 +233,22 @@ class profile: self.profiler_kind = ProfilerState.KINETO if self.use_cuda: - if (not use_kineto or ProfilerActivity.CUDA not in - _supported_activities()): + if not use_kineto or ProfilerActivity.CUDA not in _supported_activities(): assert self.use_cpu, "Legacy CUDA profiling requires use_cpu=True" self.profiler_kind = ProfilerState.KINETO_GPU_FALLBACK else: self.kineto_activities.add(ProfilerActivity.CUDA) if self.use_device: - if self.use_device == 'cuda': + if self.use_device == "cuda": # TODO:using 'use_device' instead of 'use_cuda' facilitates access by other devices # and integrate it in subsequent pr. pass elif self.use_device == _get_privateuse1_backend_name(): if not use_kineto: - assert self.use_cpu, "Legacy custombackend profiling requires use_cpu=True" + assert ( + self.use_cpu + ), "Legacy custombackend profiling requires use_cpu=True" self.profiler_kind = ProfilerState.KINETO_PRIVATEUSE1_FALLBACK else: raise AssertionError( @@ -240,9 +257,9 @@ class profile: else: raise AssertionError(f"{self.use_device} doesn't support profile.") - assert len(self.kineto_activities) > 0, \ - "No activities specified for the profiler" - + assert ( + len(self.kineto_activities) > 0 + ), "No activities specified for the profiler" def config(self): return ProfilerConfig( @@ -252,7 +269,8 @@ class profile: self.with_stack, self.with_flops, self.with_modules, - self.experimental_config) + self.experimental_config, + ) def __enter__(self): if not self.enabled: @@ -284,18 +302,19 @@ class profile: parsed_results, use_cuda=self.use_cuda, profile_memory=self.profile_memory, - with_flops=self.with_flops) + with_flops=self.with_flops, + ) self.function_events._build_tree() return False def __repr__(self): if self.function_events is None: - return '' + return "" return repr(self.function_events) def __str__(self): if self.function_events is None: - return '' + return "" return str(self.function_events) def _check_finish(self): @@ -303,14 +322,14 @@ class profile: raise RuntimeError("Profiler didn't finish running") def table( - self, - sort_by=None, - row_limit=100, - max_src_column_width=75, - max_name_column_width=55, - max_shapes_column_width=80, - header=None, - top_level_events_only=False + self, + sort_by=None, + row_limit=100, + max_src_column_width=75, + max_name_column_width=55, + max_shapes_column_width=80, + header=None, + top_level_events_only=False, ): self._check_finish() assert self.function_events is not None @@ -321,8 +340,9 @@ class profile: max_name_column_width=max_name_column_width, max_shapes_column_width=max_shapes_column_width, header=header, - top_level_events_only=top_level_events_only + top_level_events_only=top_level_events_only, ) + table.__doc__ = EventList.table.__doc__ def export_chrome_trace(self, path): @@ -331,6 +351,7 @@ class profile: self.kineto_results.save(path) # type: ignore[union-attr] else: return self.function_events.export_chrome_trace(path) # type: ignore[union-attr] + export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__ def export_stacks(self, path: str, metric: str = "self_cpu_time_total"): @@ -343,17 +364,19 @@ class profile: self._check_finish() assert self.function_events is not None, "Expected profiling results" return self.function_events.key_averages(group_by_input_shape, group_by_stack_n) + key_averages.__doc__ = EventList.key_averages.__doc__ def total_average(self): self._check_finish() assert self.function_events is not None, "Expected profiling results" return self.function_events.total_average() + total_average.__doc__ = EventList.total_average.__doc__ @property def self_cpu_time_total(self): - """ Returns total time spent on CPU obtained as a sum of + """Returns total time spent on CPU obtained as a sum of all self times across all the events. """ self._check_finish() @@ -364,19 +387,28 @@ class profile: # result.events() has most of the events - PyTorch op-level and device-level events trace_start_us = result.trace_start_us() - mem_records = [[evt, False] for evt in result.events() if evt.name() == MEMORY_EVENT_NAME] - oom_records = [evt for evt in result.events() if evt.name() == OUT_OF_MEMORY_EVENT_NAME] + mem_records = [ + [evt, False] for evt in result.events() if evt.name() == MEMORY_EVENT_NAME + ] + oom_records = [ + evt for evt in result.events() if evt.name() == OUT_OF_MEMORY_EVENT_NAME + ] mem_records_acc = MemRecordsAcc(mem_records) def _cpu_memory_usage(mem_record): - return mem_record.nbytes() if \ - mem_record.device_type() in [DeviceType.CPU, DeviceType.MKLDNN, DeviceType.IDEEP] \ + return ( + mem_record.nbytes() + if mem_record.device_type() + in [DeviceType.CPU, DeviceType.MKLDNN, DeviceType.IDEEP] else 0 + ) def _cuda_memory_usage(mem_record): - return mem_record.nbytes() if \ - mem_record.device_type() in [DeviceType.CUDA, DeviceType.HIP] \ + return ( + mem_record.nbytes() + if mem_record.device_type() in [DeviceType.CUDA, DeviceType.HIP] else 0 + ) # Create and return FunctionEvent list function_events = [] @@ -393,7 +425,9 @@ class profile: cuda_memory_usage = 0 if kineto_event.device_type() == DeviceType.CPU: # find the corresponding memory allocation events - for mem_record in mem_records_acc.in_interval(kineto_event.start_us(), abs_end_us): + for mem_record in mem_records_acc.in_interval( + kineto_event.start_us(), abs_end_us + ): cpu_memory_usage += _cpu_memory_usage(mem_record[0]) cuda_memory_usage += _cuda_memory_usage(mem_record[0]) mem_record[1] = True @@ -412,7 +446,11 @@ class profile: fwd_thread=kineto_event.fwd_thread_id(), input_shapes=kineto_event.shapes(), concrete_inputs=kineto_event.concrete_inputs(), - stack=[entry for entry in kineto_event.stack() if _filter_stack_entry(entry)], + stack=[ + entry + for entry in kineto_event.stack() + if _filter_stack_entry(entry) + ], scope=kineto_event.scope(), cpu_memory_usage=cpu_memory_usage, cuda_memory_usage=cuda_memory_usage, @@ -427,10 +465,7 @@ class profile: # Check if we have CUDA time as a fallback cuda_time = kineto_event.cuda_elapsed_us() if cuda_time > 0: - fe.append_kernel( - fe.name, - fe.device_index, - cuda_time) + fe.append_kernel(fe.name, fe.device_index, cuda_time) fe.is_legacy = True function_events.append(fe) corr_id = kineto_event.linked_correlation_id() @@ -441,21 +476,24 @@ class profile: # associate CUDA kernels and CUDA runtime (CPU) with CPU events for fe in function_events: - if (fe.device_type == DeviceType.CPU and not fe.is_async and - fe.id in cuda_corr_map): + if ( + fe.device_type == DeviceType.CPU + and not fe.is_async + and fe.id in cuda_corr_map + ): for f_evt in cuda_corr_map[fe.id]: if f_evt.device_type == DeviceType.CUDA: fe.append_kernel( f_evt.name, f_evt.device_index, - f_evt.time_range.end - f_evt.time_range.start) + f_evt.time_range.end - f_evt.time_range.start, + ) elif f_evt.device_type == DeviceType.CPU: # make sure that 'thread' of a CPU Kineto (e.g. CUDA Runtime) event is associated # with the 'thread' of the corresponding linked PyTorch event to properly track # parents and children f_evt.thread = fe.thread - def createFunctionEventForMemoryEvents(evt): rel_start_us = evt.start_us() - trace_start_us fe = FunctionEvent( @@ -490,7 +528,9 @@ class profile: fe = createFunctionEventForMemoryEvents(oom_record) function_events.append(fe) - function_events.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end]) + function_events.sort( + key=lambda evt: [evt.time_range.start, -evt.time_range.end] + ) return function_events @@ -530,6 +570,7 @@ class record_function(_ContextDecorator): CUDA time total: 0.000us """ + def __init__(self, name: str, args: Optional[str] = None): self.name: str = name self.args: Optional[str] = args @@ -537,10 +578,14 @@ class record_function(_ContextDecorator): self.run_callbacks_on_exit: bool = True # TODO: TorchScript ignores standard type annotation here # self.record: Optional["torch.classes.profiler._RecordFunction"] = None - self.record = torch.jit.annotate(Optional["torch.classes.profiler._RecordFunction"], None) + self.record = torch.jit.annotate( + Optional["torch.classes.profiler._RecordFunction"], None + ) def __enter__(self): - self.record = torch.ops.profiler._record_function_enter_new(self.name, self.args) + self.record = torch.ops.profiler._record_function_enter_new( + self.name, self.args + ) return self def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): @@ -593,10 +638,15 @@ class record_function(_ContextDecorator): # See https://github.com/pytorch/pytorch/issues/76410 if not torch.jit.is_scripting(): with torch._C.DisableTorchFunctionSubclass(): - profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut._RecordFunction( - record, fut) + profiled_future = ( + torch.ops.profiler._call_end_callbacks_on_jit_fut._RecordFunction( + record, fut + ) + ) else: - profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut) + profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut( + record, fut + ) return profiled_future @@ -636,6 +686,7 @@ class emit_itt: ... model(x) """ + def __init__(self, enabled=True, record_shapes=False): self.enabled = enabled self.entered = False @@ -655,8 +706,9 @@ class emit_itt: False, False, False, - _ExperimentalConfig()), - set() + _ExperimentalConfig(), + ), + set(), ) return self @@ -751,6 +803,7 @@ class emit_nvtx: backward Function object. You may need to make a judgment based on analytic knowledge of what the expected correspondence should be. """ + def __init__(self, enabled=True, record_shapes=False): self.enabled = enabled self.entered = False @@ -771,8 +824,9 @@ class emit_nvtx: False, False, False, - _ExperimentalConfig()), - set() + _ExperimentalConfig(), + ), + set(), ) return self @@ -795,17 +849,19 @@ def load_nvprof(path): class EnforceUnique: """Raises an error if a key is seen more than once.""" + def __init__(self): self.seen = set() def see(self, *key): if key in self.seen: - raise RuntimeError('duplicate key: ' + str(key)) + raise RuntimeError("duplicate key: " + str(key)) self.seen.add(key) def parse_nvprof_trace(path): import sqlite3 + conn = sqlite3.connect(path) conn.row_factory = sqlite3.Row @@ -828,14 +884,16 @@ def parse_nvprof_trace(path): functions_map = {} unique = EnforceUnique() for row in conn.execute(marker_query): - unique.see(row['marker_id']) - evt = FunctionEvent(id=row['marker_id'], - node_id=0, # missing a node_id when calling FunctionEvent. This is just to ensure - # that pytorch doesn't crash when creating a FunctionEvent() object - name=strings[row['name']], - start_us=row['start_time'], - end_us=row['end_time'], - thread=0) # TODO: find in sqlite database + unique.see(row["marker_id"]) + evt = FunctionEvent( + id=row["marker_id"], + node_id=0, # missing a node_id when calling FunctionEvent. This is just to ensure + # that pytorch doesn't crash when creating a FunctionEvent() object + name=strings[row["name"]], + start_us=row["start_time"], + end_us=row["end_time"], + thread=0, + ) # TODO: find in sqlite database functions.append(evt) functions_map[evt.id] = evt @@ -856,13 +914,13 @@ def parse_nvprof_trace(path): """ unique = EnforceUnique() for row in conn.execute(kernel_query): - unique.see(row['marker_id'], row['runtime_id']) + unique.see(row["marker_id"], row["runtime_id"]) # 211 is cudaKernelLaunch for cuda >= 9.2 - assert (row['cbid'] == 211) - evt = functions_map[row['marker_id']] - evt.append_kernel(row['kernel_name'], - 0, - row['kernel_end'] - row['kernel_start']) + assert row["cbid"] == 211 + evt = functions_map[row["marker_id"]] + evt.append_kernel( + row["kernel_name"], 0, row["kernel_end"] - row["kernel_start"] + ) functions.sort(key=lambda evt: evt.time_range.start) return functions @@ -904,6 +962,7 @@ class KinetoStepTracker: NOTE: Please do not use the KinetoStepTracker in modules beside the Optimizer for now. The result could be incorrect increments of the step count. """ + _current_step = -1 _step_dict: Dict[str, int] = defaultdict(int) @@ -930,8 +989,10 @@ class KinetoStepTracker: if new_step > cls._current_step: delta = new_step - cls._current_step if delta > 1: - warn("Profiler step count has increased more than 1 - " - f"current_step = {cls._current_step} step dict = {cls._step_dict}") + warn( + "Profiler step count has increased more than 1 - " + f"current_step = {cls._current_step} step dict = {cls._step_dict}" + ) for _ in range(0, delta): _kineto_step() cls._current_step = new_step diff --git a/torch/autograd/profiler_legacy.py b/torch/autograd/profiler_legacy.py index 0f535f91f128..5de0965bc411 100644 --- a/torch/autograd/profiler_legacy.py +++ b/torch/autograd/profiler_legacy.py @@ -1,32 +1,42 @@ -import torch -import torch.cuda -from torch.autograd.profiler_util import ( - EventList, FunctionEvent, MEMORY_EVENT_NAME, - _filter_name, _filter_stack_entry, _rewrite_name -) - -from torch.autograd import ( - DeviceType, ProfilerConfig, ProfilerState, - _disable_profiler_legacy, _enable_profiler_legacy, -) - import itertools from warnings import warn +import torch +import torch.cuda + +from torch.autograd import ( + _disable_profiler_legacy, + _enable_profiler_legacy, + DeviceType, + ProfilerConfig, + ProfilerState, +) +from torch.autograd.profiler_util import ( + _filter_name, + _filter_stack_entry, + _rewrite_name, + EventList, + FunctionEvent, + MEMORY_EVENT_NAME, +) + __all__ = ["profile"] + class profile: """DEPRECATED: use torch.profiler instead""" + def __init__( - self, - enabled=True, - *, - use_cuda=False, - record_shapes=False, - with_flops=False, - profile_memory=False, - with_stack=False, - with_modules=False): + self, + enabled=True, + *, + use_cuda=False, + record_shapes=False, + with_flops=False, + profile_memory=False, + with_stack=False, + with_modules=False, + ): self.enabled: bool = enabled if not self.enabled: return @@ -85,18 +95,19 @@ class profile: parsed_results, use_cuda=self.use_cuda, profile_memory=self.profile_memory, - with_flops=self.with_flops) + with_flops=self.with_flops, + ) self.function_events._build_tree() return False def __repr__(self): if self.function_events is None: - return '' + return "" return repr(self.function_events) def __str__(self): if self.function_events is None: - return '' + return "" return str(self.function_events) def _check_finish(self): @@ -104,14 +115,14 @@ class profile: raise RuntimeError("Profiler didn't finish running") def table( - self, - sort_by=None, - row_limit=100, - max_src_column_width=75, - max_name_column_width=55, - max_shapes_column_width=80, - header=None, - top_level_events_only=False + self, + sort_by=None, + row_limit=100, + max_src_column_width=75, + max_name_column_width=55, + max_shapes_column_width=80, + header=None, + top_level_events_only=False, ): self._check_finish() assert self.function_events is not None @@ -122,14 +133,16 @@ class profile: max_name_column_width=max_name_column_width, max_shapes_column_width=max_shapes_column_width, header=header, - top_level_events_only=top_level_events_only + top_level_events_only=top_level_events_only, ) + table.__doc__ = EventList.table.__doc__ def export_chrome_trace(self, path): self._check_finish() assert self.function_events is not None return self.function_events.export_chrome_trace(path) + export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__ def export_stacks(self, path: str, metric: str = "self_cpu_time_total"): @@ -142,17 +155,19 @@ class profile: self._check_finish() assert self.function_events is not None, "Expected profiling results" return self.function_events.key_averages(group_by_input_shape, group_by_stack_n) + key_averages.__doc__ = EventList.key_averages.__doc__ def total_average(self): self._check_finish() assert self.function_events is not None, "Expected profiling results" return self.function_events.total_average() + total_average.__doc__ = EventList.total_average.__doc__ @property def self_cpu_time_total(self): - """ Returns total time spent on CPU obtained as a sum of + """Returns total time spent on CPU obtained as a sum of all self times across all the events. """ self._check_finish() @@ -176,7 +191,7 @@ def _parse_legacy_records(thread_records): # '__start_profile' is not guaranteed to be first, so we must find it here for record in itertools.chain(*thread_records): name = record.name() - if start_record is None and name == '__start_profile': + if start_record is None and name == "__start_profile": start_record = record assert start_record is not None and not start_record.is_remote() @@ -192,12 +207,11 @@ def _parse_legacy_records(thread_records): prev_record = None for record in thread_record_list: record_key = _get_record_key(record) - if (_filter_name(record.name()) or - record_key in filtered_handles): + if _filter_name(record.name()) or record_key in filtered_handles: filtered_handles.add(record_key) continue - if record.kind() == 'push': + if record.kind() == "push": # workaround to reduce double logging from operator # wrappers and redispatch if prev_record is not None: @@ -213,7 +227,7 @@ def _parse_legacy_records(thread_records): range_starts[record_key] = record cpu_memory_allocs[record_key] = 0 cuda_memory_allocs[record_key] = 0 - elif record.kind() == 'pop': + elif record.kind() == "pop": assert ( record_key in range_starts ), f"""Expected record with key {record_key} to exist in range_starts. @@ -223,9 +237,7 @@ def _parse_legacy_records(thread_records): cpu_memory_usage = cpu_memory_allocs[record_key] cuda_memory_usage = cuda_memory_allocs[record_key] - is_async = start.is_async() or ( - start.thread_id() != record.thread_id() - ) + is_async = start.is_async() or (start.thread_id() != record.thread_id()) is_remote_event = record.is_remote() start_flops = start.flops() @@ -239,7 +251,9 @@ def _parse_legacy_records(thread_records): end_us=start_record.cpu_elapsed_us(record), fwd_thread=start.fwd_thread_id(), input_shapes=start.shapes(), - stack=[entry for entry in start.stack() if _filter_stack_entry(entry)], + stack=[ + entry for entry in start.stack() if _filter_stack_entry(entry) + ], scope=start.scope(), cpu_memory_usage=cpu_memory_usage, cuda_memory_usage=cuda_memory_usage, @@ -254,15 +268,12 @@ def _parse_legacy_records(thread_records): if not is_async and start.has_cuda(): duration = start.cuda_elapsed_us(record) if duration > 0: - fe.append_kernel( - start.name(), - start.device(), - duration) + fe.append_kernel(start.name(), start.device(), duration) functions.append(fe) del range_starts[record_key] del cpu_memory_allocs[record_key] del cuda_memory_allocs[record_key] - elif record.kind() == 'memory_alloc': + elif record.kind() == "memory_alloc": num_open_handles_cpu = len(cpu_memory_allocs) num_open_handles_cuda = len(cuda_memory_allocs) assert num_open_handles_cpu == num_open_handles_cuda diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index b7370fc96ade..37437b4dce99 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -1,24 +1,34 @@ +import bisect import itertools -import torch -from torch.autograd import DeviceType +import math from collections import defaultdict, namedtuple from operator import attrgetter -from typing import Any, Dict, List, Tuple, Optional +from typing import Any, Dict, List, Optional, Tuple -import bisect -import math +import torch +from torch.autograd import DeviceType + +__all__ = [ + "EventList", + "FormattedTimesMixin", + "Interval", + "Kernel", + "FunctionEvent", + "FunctionEventAvg", + "StringTable", + "MemRecordsAcc", +] -__all__ = ["EventList", "FormattedTimesMixin", "Interval", "Kernel", "FunctionEvent", "FunctionEventAvg", - "StringTable", "MemRecordsAcc"] class EventList(list): """A list of Events (for pretty printing)""" + def __init__(self, *args, **kwargs): - use_cuda = kwargs.pop('use_cuda', True) - profile_memory = kwargs.pop('profile_memory', False) - with_flops = kwargs.pop('with_flops', False) + use_cuda = kwargs.pop("use_cuda", True) + profile_memory = kwargs.pop("profile_memory", False) + with_flops = kwargs.pop("with_flops", False) super().__init__(*args, **kwargs) self._use_cuda = use_cuda self._profile_memory = profile_memory @@ -38,9 +48,11 @@ class EventList(list): while True: to_delete = set() for idx in range(len(self)): - if (self[idx].cpu_parent is not None and - self[idx].cpu_parent.name == self[idx].name and - len(self[idx].cpu_parent.cpu_children) == 1): + if ( + self[idx].cpu_parent is not None + and self[idx].cpu_parent.name == self[idx].name + and len(self[idx].cpu_parent.cpu_children) == 1 + ): self[idx].cpu_parent.cpu_children = self[idx].cpu_children self[idx].cpu_parent.kernels = self[idx].kernels # lift kernels up for ch in self[idx].cpu_children: @@ -68,7 +80,11 @@ class EventList(list): # Some events can be async (i.e. start and end on different threads), # since it's generally undefined how to attribute children ranges to # async ranges, we do not use them when calculating nested ranges and stats - sync_events = [evt for evt in self if not evt.is_async and evt.device_type == DeviceType.CPU] + sync_events = [ + evt + for evt in self + if not evt.is_async and evt.device_type == DeviceType.CPU + ] events = sorted( sync_events, key=attrgetter("thread"), @@ -102,8 +118,10 @@ class EventList(list): for event in thread_events_: while len(current_events) > 0: parent = current_events[-1] - if event.time_range.start >= parent.time_range.end or \ - event.time_range.end > parent.time_range.end: + if ( + event.time_range.start >= parent.time_range.end + or event.time_range.end > parent.time_range.end + ): # this can't be a parent current_events.pop() else: @@ -147,14 +165,14 @@ class EventList(list): return sum([event.self_cpu_time_total for event in self]) def table( - self, - sort_by=None, - row_limit=100, - max_src_column_width=75, - max_name_column_width=55, - max_shapes_column_width=80, - header=None, - top_level_events_only=False + self, + sort_by=None, + row_limit=100, + max_src_column_width=75, + max_name_column_width=55, + max_shapes_column_width=80, + header=None, + top_level_events_only=False, ): """Prints an EventList as a nicely formatted table. @@ -183,7 +201,8 @@ class EventList(list): header=header, profile_memory=self._profile_memory, with_flops=self._with_flops, - top_level_events_only=top_level_events_only) + top_level_events_only=top_level_events_only, + ) def export_chrome_trace(self, path): """Exports an EventList as a Chrome tracing tools file. @@ -194,7 +213,8 @@ class EventList(list): path (str): Path where the trace will be written. """ import os - with open(path, 'w') as f: + + with open(path, "w") as f: chrome_events = [] next_id = 0 # Use file IO over using json.dump since JSON dumping is very slow and @@ -222,15 +242,18 @@ class EventList(list): for k in evt.kernels: # 's' and 'f' draw Flow arrows from # the CPU launch to the GPU kernel - f.write('{{"name": "{}", ' - '"ph": "s", ' - '"ts": {}, ' - '"tid": {}, ' - '"pid": "CPU functions", ' - '"id": {}, ' - '"cat": "cpu_to_cuda", ' - '"args": {{}}}}, '.format(evt.trace_name, evt.time_range.start, - evt.thread, next_id)) + f.write( + '{{"name": "{}", ' + '"ph": "s", ' + '"ts": {}, ' + '"tid": {}, ' + '"pid": "CPU functions", ' + '"id": {}, ' + '"cat": "cpu_to_cuda", ' + '"args": {{}}}}, '.format( + evt.trace_name, evt.time_range.start, evt.thread, next_id + ) + ) # Note: use torch.profiler to get device kernel trace next_id += 1 if len(self) > 0: @@ -244,9 +267,12 @@ class EventList(list): def export_stacks(self, path: str, metric: str): if metric not in self.supported_export_stacks_metrics(): - raise ValueError("metric should be one of: " + str(self.supported_export_stacks_metrics())) + raise ValueError( + "metric should be one of: " + + str(self.supported_export_stacks_metrics()) + ) translate_table = str.maketrans(" ;\t\n", "____") - with open(path, 'w') as f: + with open(path, "w") as f: for evt in self: if evt.stack and len(evt.stack) > 0: metric_value = getattr(evt, metric) @@ -277,12 +303,18 @@ class EventList(list): stats: Dict[Tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg) def get_key(event, group_by_input_shapes, group_by_stack_n) -> Tuple[str, ...]: - key = [str(event.key), str(event.node_id), str(event.device_type), str(event.is_legacy)] + key = [ + str(event.key), + str(event.node_id), + str(event.device_type), + str(event.is_legacy), + ] if group_by_input_shapes: key.append(str(event.input_shapes)) if group_by_stack_n > 0: key += event.stack[:group_by_stack_n] return tuple(key) + for evt in self: stats[get_key(evt, group_by_input_shapes, group_by_stack_n)].add(evt) @@ -290,7 +322,8 @@ class EventList(list): stats.values(), use_cuda=self._use_cuda, profile_memory=self._profile_memory, - with_flops=self._with_flops) + with_flops=self._with_flops, + ) for evt in avg_list: evt.stack = evt.stack[:group_by_stack_n] if not group_by_input_shapes: @@ -307,7 +340,7 @@ class EventList(list): for evt in self: total_stat += evt total_stat.key = None - total_stat.key = 'Total' + total_stat.key = "Total" return total_stat @@ -316,31 +349,34 @@ def _format_time(time_us): US_IN_SECOND = 1000.0 * 1000.0 US_IN_MS = 1000.0 if time_us >= US_IN_SECOND: - return f'{time_us / US_IN_SECOND:.3f}s' + return f"{time_us / US_IN_SECOND:.3f}s" if time_us >= US_IN_MS: - return f'{time_us / US_IN_MS:.3f}ms' - return f'{time_us:.3f}us' + return f"{time_us / US_IN_MS:.3f}ms" + return f"{time_us:.3f}us" + def _format_time_share(time_us, total_time_us): """Defines how to format time in FunctionEvent""" if total_time_us == 0: assert time_us == 0, f"Expected time_us == 0 but got {time_us}" return "NaN" - return f'{time_us * 100.0 / total_time_us:.2f}%' + return f"{time_us * 100.0 / total_time_us:.2f}%" + def _format_memory(nbytes): """Returns a formatted memory size string""" KB = 1024 MB = 1024 * KB GB = 1024 * MB - if (abs(nbytes) >= GB): - return f'{nbytes * 1.0 / GB:.2f} Gb' - elif (abs(nbytes) >= MB): - return f'{nbytes * 1.0 / MB:.2f} Mb' - elif (abs(nbytes) >= KB): - return f'{nbytes * 1.0 / KB:.2f} Kb' + if abs(nbytes) >= GB: + return f"{nbytes * 1.0 / GB:.2f} Gb" + elif abs(nbytes) >= MB: + return f"{nbytes * 1.0 / MB:.2f} Mb" + elif abs(nbytes) >= KB: + return f"{nbytes * 1.0 / KB:.2f} Kb" else: - return str(nbytes) + ' b' + return str(nbytes) + " b" + def _attr_formatter(name): return property(lambda self: _format_time(getattr(self, name))) @@ -351,12 +387,13 @@ class FormattedTimesMixin: The subclass should define `*_time_total` and `count` attributes. """ - cpu_time_str = _attr_formatter('cpu_time') - cuda_time_str = _attr_formatter('cuda_time') - cpu_time_total_str = _attr_formatter('cpu_time_total') - cuda_time_total_str = _attr_formatter('cuda_time_total') - self_cpu_time_total_str = _attr_formatter('self_cpu_time_total') - self_cuda_time_total_str = _attr_formatter('self_cuda_time_total') + + cpu_time_str = _attr_formatter("cpu_time") + cuda_time_str = _attr_formatter("cuda_time") + cpu_time_total_str = _attr_formatter("cpu_time_total") + cuda_time_total_str = _attr_formatter("cuda_time_total") + self_cpu_time_total_str = _attr_formatter("self_cpu_time_total") + self_cuda_time_total_str = _attr_formatter("self_cuda_time_total") @property def cpu_time(self): @@ -376,16 +413,36 @@ class Interval: return self.end - self.start -Kernel = namedtuple('Kernel', ['name', 'device', 'duration']) +Kernel = namedtuple("Kernel", ["name", "device", "duration"]) class FunctionEvent(FormattedTimesMixin): """Profiling information about a single function.""" + def __init__( - self, id, name, thread, start_us, end_us, fwd_thread=None, input_shapes=None, - stack=None, scope=0, cpu_memory_usage=0, cuda_memory_usage=0, is_async=False, - is_remote=False, sequence_nr=-1, node_id=-1, device_type=DeviceType.CPU, device_index=0, - is_legacy=False, flops=None, trace_name=None, concrete_inputs=None): + self, + id, + name, + thread, + start_us, + end_us, + fwd_thread=None, + input_shapes=None, + stack=None, + scope=0, + cpu_memory_usage=0, + cuda_memory_usage=0, + is_async=False, + is_remote=False, + sequence_nr=-1, + node_id=-1, + device_type=DeviceType.CPU, + device_index=0, + is_legacy=False, + flops=None, + trace_name=None, + concrete_inputs=None, + ): self.id: int = id self.node_id: int = node_id self.name: str = name @@ -421,9 +478,9 @@ class FunctionEvent(FormattedTimesMixin): One is supposed to append only direct children to the event to have correct self cpu time being reported. """ - assert(self.device_type == DeviceType.CPU) - assert(isinstance(child, FunctionEvent)) - assert(child.device_type == DeviceType.CPU) + assert self.device_type == DeviceType.CPU + assert isinstance(child, FunctionEvent) + assert child.device_type == DeviceType.CPU self.cpu_children.append(child) def set_cpu_parent(self, parent): @@ -433,9 +490,9 @@ class FunctionEvent(FormattedTimesMixin): the child's range interval is completely inside the parent's. We use this connection to determine the event is from top-level op or not. """ - assert(self.device_type == DeviceType.CPU) - assert(isinstance(parent, FunctionEvent)) - assert(parent.device_type == DeviceType.CPU) + assert self.device_type == DeviceType.CPU + assert isinstance(parent, FunctionEvent) + assert parent.device_type == DeviceType.CPU self.cpu_parent = parent # Note: async events don't have children, are not used when computing 'self' @@ -471,8 +528,9 @@ class FunctionEvent(FormattedTimesMixin): if self.device_type == DeviceType.CPU: if not self.is_legacy: # account for the kernels in the children ops - return (sum(kinfo.duration for kinfo in self.kernels) + - sum(ch.cuda_time_total for ch in self.cpu_children)) + return sum(kinfo.duration for kinfo in self.kernels) + sum( + ch.cuda_time_total for ch in self.cpu_children + ) else: # each legacy cpu events has a single (fake) kernel return sum(kinfo.duration for kinfo in self.kernels) @@ -485,10 +543,11 @@ class FunctionEvent(FormattedTimesMixin): if self.is_async: return 0 if self.device_type == DeviceType.CPU: - return self.cuda_time_total - \ - sum([child.cuda_time_total for child in self.cpu_children]) + return self.cuda_time_total - sum( + [child.cuda_time_total for child in self.cpu_children] + ) else: - assert(self.device_type == DeviceType.CUDA) + assert self.device_type == DeviceType.CUDA return self.cuda_time_total @property @@ -504,9 +563,9 @@ class FunctionEvent(FormattedTimesMixin): def __repr__(self): return ( - ''.format( + "".format( self.id, self.name, self.device_type, @@ -531,6 +590,7 @@ class FunctionEvent(FormattedTimesMixin): class FunctionEventAvg(FormattedTimesMixin): """Used to average stats over multiple FunctionEvent objects.""" + def __init__(self): self.key: Optional[str] = None self.count: int = 0 @@ -593,9 +653,9 @@ class FunctionEventAvg(FormattedTimesMixin): def __repr__(self): return ( - ''.format( + "".format( self.key, self.self_cpu_time_total_str, self.cpu_time_str, @@ -646,9 +706,11 @@ def _filter_stack_entry(entry): ] return all(not (f[0] in entry and f[1] in entry) for f in filtered_entries) + MEMORY_EVENT_NAME = "[memory]" OUT_OF_MEMORY_EVENT_NAME = "[OutOfMemory]" + def _filter_name(name): # ignoring the following utility ops filtered_out_names = [ @@ -663,6 +725,7 @@ def _filter_name(name): ] return name in filtered_out_names + # Demangles and optionally rewrites the provided event name, # with_wildcard - whether to replace certain numbered event names # with a wildcard name to aggregate them together in the profiler table @@ -675,17 +738,19 @@ def _rewrite_name(name, with_wildcard=False): name = "ProfilerStep*" return name + def _build_table( - events, - sort_by=None, - header=None, - row_limit=100, - max_src_column_width=75, - max_name_column_width=55, - max_shapes_column_width=80, - with_flops=False, - profile_memory=False, - top_level_events_only=False): + events, + sort_by=None, + header=None, + row_limit=100, + max_src_column_width=75, + max_name_column_width=55, + max_shapes_column_width=80, + with_flops=False, + profile_memory=False, + top_level_events_only=False, +): """Prints a summary of events (which can be a list of FunctionEvent or FunctionEventAvg).""" if len(events) == 0: return "" @@ -693,12 +758,17 @@ def _build_table( has_cuda_time = any(event.self_cuda_time_total > 0 for event in events) has_cuda_mem = any(event.self_cuda_memory_usage > 0 for event in events) has_input_shapes = any( - (event.input_shapes is not None and len(event.input_shapes) > 0) for event in events) + (event.input_shapes is not None and len(event.input_shapes) > 0) + for event in events + ) if sort_by is not None: - events = EventList(sorted( - events, key=lambda evt: getattr(evt, sort_by), reverse=True - ), use_cuda=has_cuda_time, profile_memory=profile_memory, with_flops=with_flops) + events = EventList( + sorted(events, key=lambda evt: getattr(evt, sort_by), reverse=True), + use_cuda=has_cuda_time, + profile_memory=profile_memory, + with_flops=with_flops, + ) name_column_width = max([len(evt.key) for evt in events]) + 4 if max_name_column_width is not None: @@ -718,42 +788,48 @@ def _build_table( stacks.append(evt.stack) has_stack = len(stacks) > 0 if has_stack: - src_column_width = max([max([len(entry) for entry in stack]) for stack in stacks]) + 4 + src_column_width = ( + max([max([len(entry) for entry in stack]) for stack in stacks]) + 4 + ) if max_src_column_width is not None: src_column_width = min(src_column_width, max_src_column_width) headers = [ - 'Name', - 'Self CPU %', - 'Self CPU', - 'CPU total %', - 'CPU total', - 'CPU time avg', + "Name", + "Self CPU %", + "Self CPU", + "CPU total %", + "CPU total", + "CPU time avg", ] if has_cuda_time: - headers.extend([ - 'Self CUDA', - 'Self CUDA %', - 'CUDA total', - 'CUDA time avg', - ]) + headers.extend( + [ + "Self CUDA", + "Self CUDA %", + "CUDA total", + "CUDA time avg", + ] + ) if profile_memory: - headers.extend([ - 'CPU Mem', - 'Self CPU Mem', - ]) + headers.extend( + [ + "CPU Mem", + "Self CPU Mem", + ] + ) if has_cuda_mem: - headers.extend([ - 'CUDA Mem', - 'Self CUDA Mem', - ]) - headers.append( - '# of Calls' - ) + headers.extend( + [ + "CUDA Mem", + "Self CUDA Mem", + ] + ) + headers.append("# of Calls") # Only append Node ID if any event has a valid (>= 0) Node ID append_node_id = any(evt.node_id != -1 for evt in events) if append_node_id: - headers.append('Node ID') + headers.append("Node ID") # Have to use a list because nonlocal is Py3 only... SPACING_SIZE = 2 @@ -762,19 +838,21 @@ def _build_table( line_length_lst = [-SPACING_SIZE] MAX_STACK_ENTRY = 5 - def add_column(padding, text_dir='>'): - row_format_lst[0] += '{: ' + text_dir + str(padding) + '}' + (' ' * SPACING_SIZE) - header_sep_lst[0] += '-' * padding + (' ' * SPACING_SIZE) + def add_column(padding, text_dir=">"): + row_format_lst[0] += ( + "{: " + text_dir + str(padding) + "}" + (" " * SPACING_SIZE) + ) + header_sep_lst[0] += "-" * padding + (" " * SPACING_SIZE) line_length_lst[0] += padding + SPACING_SIZE def auto_scale_flops(flops): flop_headers = [ - 'FLOPs', - 'KFLOPs', - 'MFLOPs', - 'GFLOPs', - 'TFLOPs', - 'PFLOPs', + "FLOPs", + "KFLOPs", + "MFLOPs", + "GFLOPs", + "TFLOPs", + "PFLOPs", ] assert flops > 0 log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1))) @@ -786,12 +864,12 @@ def _build_table( add_column(DEFAULT_COLUMN_WIDTH) if has_input_shapes: - headers.append('Input Shapes') + headers.append("Input Shapes") add_column(shapes_column_width) if has_stack: - headers.append('Source Location') - add_column(src_column_width, text_dir='<') + headers.append("Source Location") + add_column(src_column_width, text_dir="<") if with_flops: # Auto-scaling of flops header @@ -801,7 +879,7 @@ def _build_table( raw_flops.append(evt.flops) if len(raw_flops) != 0: (flops_scale, flops_header) = auto_scale_flops(min(raw_flops)) - headers.append(f'Total {flops_header}') + headers.append(f"Total {flops_header}") add_column(flops_column_width) else: with_flops = False # can't find any valid flops @@ -816,7 +894,7 @@ def _build_table( def append(s): result.append(s) - result.append('\n') # Yes, newline after the end as well + result.append("\n") # Yes, newline after the end as well sum_self_cpu_time_total = sum([event.self_cpu_time_total for event in events]) sum_self_cuda_time_total = 0 @@ -831,11 +909,11 @@ def _build_table( # Actual printing if header is not None: - append('=' * line_length) + append("=" * line_length) append(header) if top_level_events_only: - append('=' * line_length) - append('This report only display top-level ops statistics') + append("=" * line_length) + append("This report only display top-level ops statistics") append(header_sep) append(row_format.format(*headers)) @@ -859,39 +937,49 @@ def _build_table( event_limit += 1 name = evt.key if max_name_column_width is not None and len(name) >= max_name_column_width - 3: - name = name[:(max_name_column_width - 3)] + "..." + name = name[: (max_name_column_width - 3)] + "..." row_values = [ name, # Self CPU total %, 0 for async events. _format_time_share(evt.self_cpu_time_total, sum_self_cpu_time_total), evt.self_cpu_time_total_str, # Self CPU total # CPU total %, 0 for async events. - _format_time_share(evt.cpu_time_total, sum_self_cpu_time_total) if not evt.is_async else 0, + _format_time_share(evt.cpu_time_total, sum_self_cpu_time_total) + if not evt.is_async + else 0, evt.cpu_time_total_str, # CPU total evt.cpu_time_str, # CPU time avg ] if has_cuda_time: - row_values.extend([ - evt.self_cuda_time_total_str, - # CUDA time total % - _format_time_share(evt.self_cuda_time_total, sum_self_cuda_time_total), - evt.cuda_time_total_str, - evt.cuda_time_str, # Cuda time avg - ]) + row_values.extend( + [ + evt.self_cuda_time_total_str, + # CUDA time total % + _format_time_share( + evt.self_cuda_time_total, sum_self_cuda_time_total + ), + evt.cuda_time_total_str, + evt.cuda_time_str, # Cuda time avg + ] + ) if profile_memory: - row_values.extend([ - # CPU Mem Total - _format_memory(evt.cpu_memory_usage), - # Self CPU Mem Total - _format_memory(evt.self_cpu_memory_usage), - ]) + row_values.extend( + [ + # CPU Mem Total + _format_memory(evt.cpu_memory_usage), + # Self CPU Mem Total + _format_memory(evt.self_cpu_memory_usage), + ] + ) if has_cuda_mem: - row_values.extend([ - # CUDA Mem Total - _format_memory(evt.cuda_memory_usage), - # Self CUDA Mem Total - _format_memory(evt.self_cuda_memory_usage), - ]) + row_values.extend( + [ + # CUDA Mem Total + _format_memory(evt.cuda_memory_usage), + # Self CUDA Mem Total + _format_memory(evt.self_cuda_memory_usage), + ] + ) row_values.append( evt.count, # Number of calls ) @@ -904,7 +992,7 @@ def _build_table( if evt.flops <= 0: row_values.append("--") else: - row_values.append(f'{evt.flops * flops_scale:8.3f}') + row_values.append(f"{evt.flops * flops_scale:8.3f}") if has_stack: src_field = "" if len(evt.stack) > 0: @@ -915,7 +1003,11 @@ def _build_table( if has_stack: empty_headers = [""] * (len(headers) - 1) for entry in evt.stack[1:MAX_STACK_ENTRY]: - append(row_format.format(*(empty_headers + [trim_path(entry, src_column_width)]))) + append( + row_format.format( + *(empty_headers + [trim_path(entry, src_column_width)]) + ) + ) empty_headers.append("") append(row_format.format(*empty_headers)) @@ -923,4 +1015,4 @@ def _build_table( append(f"Self CPU time total: {_format_time(sum_self_cpu_time_total)}") if has_cuda_time: append(f"Self CUDA time total: {_format_time(sum_self_cuda_time_total)}") - return ''.join(result) + return "".join(result) diff --git a/torch/backends/__init__.py b/torch/backends/__init__.py index a848a5d16eb5..b5deca9f5ca7 100644 --- a/torch/backends/__init__.py +++ b/torch/backends/__init__.py @@ -1,18 +1,22 @@ -from contextlib import contextmanager import types +from contextlib import contextmanager + # The idea for this parameter is that we forbid bare assignment # to torch.backends..enabled and friends when running our # test suite, where it's very easy to forget to undo the change # later. __allow_nonbracketed_mutation_flag = True + def disable_global_flags(): global __allow_nonbracketed_mutation_flag __allow_nonbracketed_mutation_flag = False + def flags_frozen(): return not __allow_nonbracketed_mutation_flag + @contextmanager def __allow_nonbracketed_mutation(): global __allow_nonbracketed_mutation_flag @@ -23,6 +27,7 @@ def __allow_nonbracketed_mutation(): finally: __allow_nonbracketed_mutation_flag = old + class ContextProp: def __init__(self, getter, setter): self.getter = getter @@ -35,8 +40,12 @@ class ContextProp: if not flags_frozen(): self.setter(val) else: - raise RuntimeError("not allowed to set %s flags " - "after disable_global_flags; please use flags() context manager instead" % obj.__name__) + raise RuntimeError( + "not allowed to set %s flags " + "after disable_global_flags; please use flags() context manager instead" + % obj.__name__ + ) + class PropModule(types.ModuleType): def __init__(self, m, name): @@ -47,11 +56,13 @@ class PropModule(types.ModuleType): return self.m.__getattribute__(attr) -from torch.backends import cpu as cpu -from torch.backends import cuda as cuda -from torch.backends import mps as mps -from torch.backends import cudnn as cudnn -from torch.backends import mkl as mkl -from torch.backends import mkldnn as mkldnn -from torch.backends import openmp as openmp -from torch.backends import quantized as quantized +from torch.backends import ( + cpu as cpu, + cuda as cuda, + cudnn as cudnn, + mkl as mkl, + mkldnn as mkldnn, + mps as mps, + openmp as openmp, + quantized as quantized, +) diff --git a/torch/backends/_coreml/preprocess.py b/torch/backends/_coreml/preprocess.py index f72dae177ed2..f393929bb7c2 100644 --- a/torch/backends/_coreml/preprocess.py +++ b/torch/backends/_coreml/preprocess.py @@ -3,11 +3,12 @@ import json from typing import Dict, Tuple import coremltools as ct # type: ignore[import] -import torch from coremltools.converters.mil.input_types import TensorType # type: ignore[import] from coremltools.converters.mil.mil import types # type: ignore[import] from coremltools.models.neural_network import quantization_utils # type: ignore[import] +import torch + CT_METADATA_VERSION = "com.github.apple.coremltools.version" CT_METADATA_SOURCE = "com.github.apple.coremltools.source" @@ -19,6 +20,7 @@ class ScalarType: Long = 3 Undefined = 4 + # Supported Tensor types in coremltools: # https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/converter.py#L28 torch_to_mil_types = { @@ -34,24 +36,33 @@ class CoreMLComputeUnit: CPUAndGPU = "cpuAndGPU" ALL = "all" + class CoreMLQuantizationMode: LINEAR = "linear" LINEAR_SYMMETRIC = "linear_symmetric" NONE = "none" - def TensorSpec(shape, dtype=ScalarType.Float): return (shape, dtype) -def CompileSpec(inputs, - outputs, - backend=CoreMLComputeUnit.CPU, - allow_low_precision=True, - quantization_mode=CoreMLQuantizationMode.NONE, - mlmodel_export_path=None): - return (inputs, outputs, backend, allow_low_precision, quantization_mode, mlmodel_export_path) +def CompileSpec( + inputs, + outputs, + backend=CoreMLComputeUnit.CPU, + allow_low_precision=True, + quantization_mode=CoreMLQuantizationMode.NONE, + mlmodel_export_path=None, +): + return ( + inputs, + outputs, + backend, + allow_low_precision, + quantization_mode, + mlmodel_export_path, + ) def _check_enumerated_shape(shape): @@ -72,7 +83,14 @@ def _convert_to_mil_type(shape, dtype, name: str): def preprocess(script_module: torch._C.ScriptObject, compile_spec: Dict[str, Tuple]): spec = compile_spec["forward"] - input_specs, output_specs, backend, allow_low_precision, quantization_mode, mlmodel_export_path = spec + ( + input_specs, + output_specs, + backend, + allow_low_precision, + quantization_mode, + mlmodel_export_path, + ) = spec mil_inputs = [] inputs = [] for index, input in enumerate(input_specs): @@ -84,8 +102,10 @@ def preprocess(script_module: torch._C.ScriptObject, compile_spec: Dict[str, Tup model = torch.jit.RecursiveScriptModule._construct(script_module, lambda x: None) mlmodel = ct.convert(model, inputs=mil_inputs) - if(quantization_mode != CoreMLQuantizationMode.NONE): - quant_model_spec = quantization_utils.quantize_weights(mlmodel, nbits=8, quantization_mode=quantization_mode) + if quantization_mode != CoreMLQuantizationMode.NONE: + quant_model_spec = quantization_utils.quantize_weights( + mlmodel, nbits=8, quantization_mode=quantization_mode + ) mlmodel = ct.models.MLModel(quant_model_spec) spec = mlmodel.get_spec() diff --git a/torch/backends/_nnapi/prepare.py b/torch/backends/_nnapi/prepare.py index 1d3d04b8dc08..71b504dbdbfc 100644 --- a/torch/backends/_nnapi/prepare.py +++ b/torch/backends/_nnapi/prepare.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import List, Optional import torch from torch.backends._nnapi.serializer import _NnapiSerializer @@ -48,7 +48,12 @@ class NnapiModule(torch.nn.Module): self.out_templates = self.shape_compute_module.prepare(self.ser_model, args) # type: ignore[operator] self.weights = [w.contiguous() for w in self.weights] comp = torch.classes._nnapi.Compilation() - comp.init2(self.ser_model, self.weights, self.compilation_preference, self.relax_f32_to_f16) + comp.init2( + self.ser_model, + self.weights, + self.compilation_preference, + self.relax_f32_to_f16, + ) self.comp = comp @@ -85,6 +90,7 @@ class NnapiModule(torch.nn.Module): raise Exception("Invalid mem_fmt") return outs + def convert_model_to_nnapi( model, inputs, @@ -94,8 +100,16 @@ def convert_model_to_nnapi( compilation_preference=ANEURALNETWORKS_PREFER_SUSTAINED_SPEED, relax_f32_to_f16=False, ): - (shape_compute_module, ser_model_tensor, used_weights, inp_mem_fmts, out_mem_fmts, - retval_count) = process_for_nnapi(model, inputs, serializer, return_shapes, use_int16_for_qint16) + ( + shape_compute_module, + ser_model_tensor, + used_weights, + inp_mem_fmts, + out_mem_fmts, + retval_count, + ) = process_for_nnapi( + model, inputs, serializer, return_shapes, use_int16_for_qint16 + ) nnapi_model = NnapiModule( shape_compute_module, @@ -104,7 +118,7 @@ def convert_model_to_nnapi( inp_mem_fmts, out_mem_fmts, compilation_preference, - relax_f32_to_f16 + relax_f32_to_f16, ) class NnapiInterfaceWrapper(torch.nn.Module): @@ -115,6 +129,7 @@ def convert_model_to_nnapi( It returns results as either a single tensor or tuple, matching the original module. """ + def __init__(self, mod): super().__init__() self.mod = mod @@ -134,15 +149,26 @@ def convert_model_to_nnapi( ) return wrapper_model -def process_for_nnapi(model, inputs, serializer=None, return_shapes=None, use_int16_for_qint16=False): + +def process_for_nnapi( + model, inputs, serializer=None, return_shapes=None, use_int16_for_qint16=False +): model = torch.jit.freeze(model) if isinstance(inputs, torch.Tensor): inputs = [inputs] - serializer = serializer or _NnapiSerializer(config=None, use_int16_for_qint16=use_int16_for_qint16) - (ser_model, used_weights, inp_mem_fmts, out_mem_fmts, shape_compute_lines, - retval_count) = serializer.serialize_model(model, inputs, return_shapes) + serializer = serializer or _NnapiSerializer( + config=None, use_int16_for_qint16=use_int16_for_qint16 + ) + ( + ser_model, + used_weights, + inp_mem_fmts, + out_mem_fmts, + shape_compute_lines, + retval_count, + ) = serializer.serialize_model(model, inputs, return_shapes) ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32) # We have to create a new class here every time this function is called @@ -153,13 +179,13 @@ def process_for_nnapi(model, inputs, serializer=None, return_shapes=None, use_in module.prepare will mutate ser_model according to the computed operand shapes, based on the shapes of args. Returns a list of output templates. """ + pass + shape_compute_module = torch.jit.script(ShapeComputeModule()) real_shape_compute_lines = [ "def prepare(self, ser_model: torch.Tensor, args: List[torch.Tensor]) -> List[torch.Tensor]:\n", - ] + [ - f" {line}\n" for line in shape_compute_lines - ] + ] + [f" {line}\n" for line in shape_compute_lines] shape_compute_module.define("".join(real_shape_compute_lines)) return ( diff --git a/torch/backends/_nnapi/serializer.py b/torch/backends/_nnapi/serializer.py index b569e14196ab..6ff3f5f71e90 100644 --- a/torch/backends/_nnapi/serializer.py +++ b/torch/backends/_nnapi/serializer.py @@ -1,15 +1,10 @@ -import sys -import enum -import struct import array -import logging +import enum import functools -from typing import ( - Tuple, - NamedTuple, - List, - Optional, -) +import logging +import struct +import sys +from typing import List, NamedTuple, Optional, Tuple import torch @@ -182,6 +177,7 @@ def change_element(tup, index, value): class ConvPoolArgs2d(NamedTuple): """Configuration arguments for a convolution.""" + kernel_h: int kernel_w: int stride_h: int @@ -387,17 +383,23 @@ class _NnapiSerializer: elif dtype == "int16": if self.use_int16_for_qint16: nnapi_dtype = getattr(tensor, "nnapi_dtype", None) - op_codes = (NNAPI_OperandCode.TENSOR_QUANT16_SYMM, NNAPI_OperandCode.TENSOR_QUANT16_ASYMM) + op_codes = ( + NNAPI_OperandCode.TENSOR_QUANT16_SYMM, + NNAPI_OperandCode.TENSOR_QUANT16_ASYMM, + ) if nnapi_dtype in op_codes: op_type = nnapi_dtype scale = tensor.nnapi_scale zero_point = tensor.nnapi_zero_point else: - raise Exception(f"`nnapi_type` needs to be one of {op_codes} for `int16`") + raise Exception( + f"`nnapi_type` needs to be one of {op_codes} for `int16`" + ) else: raise Exception( "`int16` isn't supported. If you're trying to represent NNAPI" - " qint16 with Pytorch int16, set `use_int16_for_qint16 = True`") + " qint16 with Pytorch int16, set `use_int16_for_qint16 = True`" + ) else: raise Exception(f"Can't handle input with dtype '{tensor.dtype}'") return Operand( @@ -410,17 +412,23 @@ class _NnapiSerializer: def add_tensor_operand_for_input(self, arg_idx, jitval, tensor): dim_order = ( - DimOrder.CHANNELS_LAST if getattr(tensor, "nnapi_nhwc", False) - else DimOrder.PRESUMED_CONTIGUOUS) + DimOrder.CHANNELS_LAST + if getattr(tensor, "nnapi_nhwc", False) + else DimOrder.PRESUMED_CONTIGUOUS + ) toper = self.torch_tensor_to_operand(tensor, dim_order) operand_id = self.add_tensor_operand(jitval, toper) self.inputs.append(operand_id) for dim, size in enumerate(tensor.shape): if size == 0: - self.compute_operand_shape(operand_id, dim, f"args[{arg_idx}].shape[{dim}]") + self.compute_operand_shape( + operand_id, dim, f"args[{arg_idx}].shape[{dim}]" + ) return operand_id - def add_tensor_operand_for_weight(self, tensor, dim_order=DimOrder.UNKNOWN_CONSTANT): + def add_tensor_operand_for_weight( + self, tensor, dim_order=DimOrder.UNKNOWN_CONSTANT + ): toper = self.torch_tensor_to_operand(tensor, dim_order) operand_id = len(self.operands) self.operands.append(toper) @@ -429,11 +437,7 @@ class _NnapiSerializer: self.values.append((operand_id, OperandValueSourceType.NUMBERED_BUFFER)) buf_num = len(self.used_weights) offset = 0 - self.value_data.append(struct.pack( - "iii", - buf_num, - offset, - tsize)) + self.value_data.append(struct.pack("iii", buf_num, offset, tsize)) # For NHWC NNAPI op, lay out data in the same dim order by permuting torch tensor if dim_order == DimOrder.CHANNELS_LAST: tensor = tensor.permute(0, 2, 3, 1) @@ -453,27 +457,25 @@ class _NnapiSerializer: def add_immediate_int_scalar(self, value): return self.add_immediate_operand( - NNAPI_OperandCode.INT32, - struct.pack("i", value), - ()) + NNAPI_OperandCode.INT32, struct.pack("i", value), () + ) def add_immediate_float_scalar(self, value): return self.add_immediate_operand( - NNAPI_OperandCode.FLOAT32, - struct.pack("f", value), - ()) + NNAPI_OperandCode.FLOAT32, struct.pack("f", value), () + ) def add_immediate_bool_scalar(self, value): return self.add_immediate_operand( - NNAPI_OperandCode.BOOL, - b"\x01" if value else b"\x00", - ()) + NNAPI_OperandCode.BOOL, b"\x01" if value else b"\x00", () + ) def add_immediate_int_vector(self, value): return self.add_immediate_operand( NNAPI_OperandCode.TENSOR_INT32, array.array("i", value).tobytes(), - (len(value),)) + (len(value),), + ) def has_operand_for_jitval(self, jitval): return jitval in self.jitval_operand_map @@ -494,7 +496,9 @@ class _NnapiSerializer: LOG.warning("Operand %s has runtime flex shape", oper) return op_id, oper - def get_tensor_operand_or_constant(self, jitval, dim_order=DimOrder.PRESUMED_CONTIGUOUS): + def get_tensor_operand_or_constant( + self, jitval, dim_order=DimOrder.PRESUMED_CONTIGUOUS + ): operand_id = self.jitval_operand_map.get(jitval) if operand_id is None: _, value = self.get_constant_value(jitval, "TensorType") @@ -525,7 +529,8 @@ class _NnapiSerializer: ctype, _ = record if typekind is not None and ctype.kind() != typekind: raise Exception( - f"Expected constant value of type {typekind}, but got {ctype.kind()} for value '{jitval!r}'") + f"Expected constant value of type {typekind}, but got {ctype.kind()} for value '{jitval!r}'" + ) return record def operand_to_template_torchscript(self, op_id, oper, shape=None): @@ -545,7 +550,7 @@ class _NnapiSerializer: shape_parts.append(flex_name(op_id, d)) elif s == -1: # Runtime flexible shape - shape_parts.append('0') + shape_parts.append("0") else: raise Exception("Unknown dim value, dimensions should be >= -1") shape_parts.append(",") @@ -561,13 +566,17 @@ class _NnapiSerializer: f"torch.zeros(1), scale={oper.scale}, zero_point={oper.zero_point}, dtype=torch.quint8)" f".expand({shape_code}).contiguous()" ) - elif oper.op_type in (NNAPI_OperandCode.TENSOR_QUANT16_ASYMM, NNAPI_OperandCode.TENSOR_QUANT16_SYMM): + elif oper.op_type in ( + NNAPI_OperandCode.TENSOR_QUANT16_ASYMM, + NNAPI_OperandCode.TENSOR_QUANT16_SYMM, + ): if self.use_int16_for_qint16: return f"torch.zeros({shape_code}, dtype=torch.int16)" else: raise Exception( "`int16` isn't supported. If you're trying to represent NNAPI" - " qint16 with Pytorch int16, set `use_int16_for_qint16 = True`") + " qint16 with Pytorch int16, set `use_int16_for_qint16 = True`" + ) raise Exception(f"Unsupported output operand type: {oper.op_type}") @@ -575,7 +584,9 @@ class _NnapiSerializer: self.compute_operand_shape(out_op_id, out_dim, flex_name(in_op_id, in_dim)) def compute_operand_shape(self, op_id, dim, expr): - self.flexible_shape_computation_lines.append(f"{flex_name(op_id, dim)} = {expr}") + self.flexible_shape_computation_lines.append( + f"{flex_name(op_id, dim)} = {expr}" + ) def transpose_to_nhwc(self, in_id, oper): if oper.shape[2:] != (1, 1): @@ -607,7 +618,8 @@ class _NnapiSerializer: return (in0_id, in0_oper) + self.transpose_to_nhwc(in1_id, in1_oper) raise Exception( - f"Automatic transpose not supported for dim_orders: {in0_oper.dim_order!r}, {in1_oper.dim_order!r}") + f"Automatic transpose not supported for dim_orders: {in0_oper.dim_order!r}, {in1_oper.dim_order!r}" + ) def get_size_arg(self, jitval): ctype, value = self.get_constant_value(jitval) @@ -628,9 +640,13 @@ class _NnapiSerializer: assert len(pc) == 11 assert output_padding == [0, 0] - return self.get_conv_pool_args_2d_common(kernel_size, strides, paddings, dilations, group_num) + return self.get_conv_pool_args_2d_common( + kernel_size, strides, paddings, dilations, group_num + ) - def get_conv_pool_args_2d_from_jit(self, kernel_size, stride, padding, dilation=None, group=None): + def get_conv_pool_args_2d_from_jit( + self, kernel_size, stride, padding, dilation=None, group=None + ): strides = self.get_size_arg(stride) paddings = self.get_size_arg(padding) if dilation is None: @@ -641,9 +657,13 @@ class _NnapiSerializer: _, group_num = self.get_constant_value(group, "IntType") else: group_num = None - return self.get_conv_pool_args_2d_common(kernel_size, strides, paddings, dilations, group_num) + return self.get_conv_pool_args_2d_common( + kernel_size, strides, paddings, dilations, group_num + ) - def get_conv_pool_args_2d_common(self, kernel_size, strides, paddings, dilations, group_num): + def get_conv_pool_args_2d_common( + self, kernel_size, strides, paddings, dilations, group_num + ): kernels = list(kernel_size) assert len(kernels) == 2 @@ -655,7 +675,9 @@ class _NnapiSerializer: ph, pw = paddings real_paddings = [ph, ph, pw, pw] - return ConvPoolArgs2d(*(kernels + strides + real_paddings + dilations + [group_num])) + return ConvPoolArgs2d( + *(kernels + strides + real_paddings + dilations + [group_num]) + ) def serialize_model(self, model, inputs, return_shapes=None): self.add_immediate_bool_scalar(False) @@ -667,8 +689,12 @@ class _NnapiSerializer: self_jitval = next(model.graph.inputs()) self.add_constant_value(self_jitval, self_jitval.type(), model) - for arg_idx, (input_value, input_tensor) in enumerate(zip(list(model.graph.inputs())[1:], inputs)): - op_id = self.add_tensor_operand_for_input(arg_idx, input_value, input_tensor) + for arg_idx, (input_value, input_tensor) in enumerate( + zip(list(model.graph.inputs())[1:], inputs) + ): + op_id = self.add_tensor_operand_for_input( + arg_idx, input_value, input_tensor + ) inp_dim_orders.append(self.operands[op_id].dim_order.value) for idx, node in enumerate(model.graph.nodes()): @@ -697,8 +723,8 @@ class _NnapiSerializer: out_dim_orders.append(self.operands[op_id].dim_order.value) shape = return_shapes[i] if return_shapes else None template_return_lines.append( - self.operand_to_template_torchscript( - op_id, self.operands[op_id], shape) + "," + self.operand_to_template_torchscript(op_id, self.operands[op_id], shape) + + "," ) template_return_lines.append("]") @@ -718,7 +744,9 @@ class _NnapiSerializer: serialized_values, serialized_value_data = self.serialize_values() - model.extend(struct.pack("iifi", t, len(d), s, z) for (t, d, _m, s, z) in self.operands) + model.extend( + struct.pack("iifi", t, len(d), s, z) for (t, d, _m, s, z) in self.operands + ) model.extend(serialized_values) model.extend(struct.pack("iii", *x) for x in self.operations) @@ -731,13 +759,14 @@ class _NnapiSerializer: assert model_offset % 4 == 0 model_offset = int(model_offset / 4) - for (op_id, (_, dims, dim_order, _, _)) in enumerate(self.operands): + for op_id, (_, dims, dim_order, _, _) in enumerate(self.operands): shape = fix_shape(dims, dim_order) for d, s in enumerate(shape): if s == 0: pt_d = reverse_map_dim(dim_order, d) self.flexible_shape_computation_lines.append( - f"ser_model[{model_offset}] = {flex_name(op_id, pt_d)}") + f"ser_model[{model_offset}] = {flex_name(op_id, pt_d)}" + ) model_offset += 1 # convert runtime flex shape from -1 to 0 @@ -764,14 +793,16 @@ class _NnapiSerializer: serialized_values = [] serialized_value_data = [] assert len(self.values) == len(self.value_data) - for ((op_index, source_type), data) in zip(self.values, self.value_data): + for (op_index, source_type), data in zip(self.values, self.value_data): source_length = len(data) # Pad with 0 bytes out to a multiple of 4 for alignment. physical_length = ((source_length - 1) | 0x3) + 1 padded_data = data + (b"\0" * (physical_length - source_length)) - serialized_values.append(struct.pack("iii", op_index, source_type, source_length)) + serialized_values.append( + struct.pack("iii", op_index, source_type, source_length) + ) serialized_value_data.append(padded_data) return serialized_values, serialized_value_data @@ -781,86 +812,76 @@ class _NnapiSerializer: return array.array("i", ints).tobytes() ADDER_MAP = { - "prim::GetAttr": lambda self, node: - self.add_getattr(node), - "prim::Constant": lambda self, node: - self.add_constant_node(node), - "prim::ListConstruct": lambda self, node: - self.add_list_construct(node), - "prim::TupleConstruct": lambda self, node: - self.add_tuple_construct(node), - "aten::unsqueeze": lambda self, node: - self.add_unsqueeze(node), - "aten::to": lambda self, node: - self.add_to(node), - "aten::detach": lambda self, node: - self._identity(node), - "aten::reshape": lambda self, node: - self.add_reshape(node), - "aten::flatten": lambda self, node: - self.add_flatten(node), - "aten::slice": lambda self, node: - self.add_slice(node), - "aten::size": lambda self, node: - self.add_size(node), - "aten::cat": lambda self, node: - self.add_cat(node), - "aten::mean": lambda self, node: - self.add_mean(node), - "aten::quantize_per_tensor": lambda self, node: - self.add_quantize(node), - "aten::dequantize": lambda self, node: - self.add_dequantize(node), - "aten::add": lambda self, node: - self.add_add_sub_op(node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE), - "aten::sub": lambda self, node: - self.add_add_sub_op(node, NNAPI_OperationCode.SUB, NNAPI_FuseCode.FUSED_NONE), - "aten::mul": lambda self, node: - self.add_pointwise_simple_binary_broadcast_op(node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE), - "aten::div": lambda self, node: - self.add_pointwise_simple_binary_broadcast_op(node, NNAPI_OperationCode.DIV, NNAPI_FuseCode.FUSED_NONE), - "aten::relu": lambda self, node: - self.add_pointwise_simple_unary_op(node, NNAPI_OperationCode.RELU), - "aten::sigmoid": lambda self, node: - self.add_pointwise_simple_unary_op(node, NNAPI_OperationCode.LOGISTIC), - "aten::softmax": lambda self, node: - self.add_softmax(node), - "aten::hardtanh": lambda self, node: - self.add_hardtanh(node), - "aten::avg_pool2d": lambda self, node: - self.add_avg_pool2d(node), - "aten::max_pool2d": lambda self, node: - self.add_pool2d_node(node, NNAPI_OperationCode.MAX_POOL_2D), - "aten::adaptive_avg_pool2d": lambda self, node: - self.add_adaptive_avg_pool2d(node), - "aten::upsample_nearest2d": lambda self, node: - self.add_upsample_nearest2d(node), - "aten::prelu": lambda self, node: - self.add_prelu_op(node), - "aten::addmm": lambda self, node: - self.add_addmm(node), - "aten::linear": lambda self, node: - self.add_linear(node), - "aten::_convolution": lambda self, node: - self.add_conv_underscore(node), - "aten::conv2d": lambda self, node: - self.add_conv2d(node), - "aten::log_softmax": lambda self, node: - self.add_log_softmax(node), - "quantized::linear": lambda self, node: - self.add_qlinear(node), - "quantized::conv2d": lambda self, node: - self.add_qconv2d(node, NNAPI_FuseCode.FUSED_NONE), - "quantized::conv2d_relu": lambda self, node: - self.add_qconv2d(node, NNAPI_FuseCode.FUSED_RELU), - "quantized::conv_transpose2d": lambda self, node: - self.add_qconv2d(node, NNAPI_FuseCode.FUSED_NONE, transpose=True), - "quantized::add": lambda self, node: - self.add_qadd(node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE), - "quantized::add_relu": lambda self, node: - self.add_qadd(node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_RELU), - "quantized::mul": lambda self, node: - self.add_qadd(node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE), + "prim::GetAttr": lambda self, node: self.add_getattr(node), + "prim::Constant": lambda self, node: self.add_constant_node(node), + "prim::ListConstruct": lambda self, node: self.add_list_construct(node), + "prim::TupleConstruct": lambda self, node: self.add_tuple_construct(node), + "aten::unsqueeze": lambda self, node: self.add_unsqueeze(node), + "aten::to": lambda self, node: self.add_to(node), + "aten::detach": lambda self, node: self._identity(node), + "aten::reshape": lambda self, node: self.add_reshape(node), + "aten::flatten": lambda self, node: self.add_flatten(node), + "aten::slice": lambda self, node: self.add_slice(node), + "aten::size": lambda self, node: self.add_size(node), + "aten::cat": lambda self, node: self.add_cat(node), + "aten::mean": lambda self, node: self.add_mean(node), + "aten::quantize_per_tensor": lambda self, node: self.add_quantize(node), + "aten::dequantize": lambda self, node: self.add_dequantize(node), + "aten::add": lambda self, node: self.add_add_sub_op( + node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE + ), + "aten::sub": lambda self, node: self.add_add_sub_op( + node, NNAPI_OperationCode.SUB, NNAPI_FuseCode.FUSED_NONE + ), + "aten::mul": lambda self, node: self.add_pointwise_simple_binary_broadcast_op( + node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE + ), + "aten::div": lambda self, node: self.add_pointwise_simple_binary_broadcast_op( + node, NNAPI_OperationCode.DIV, NNAPI_FuseCode.FUSED_NONE + ), + "aten::relu": lambda self, node: self.add_pointwise_simple_unary_op( + node, NNAPI_OperationCode.RELU + ), + "aten::sigmoid": lambda self, node: self.add_pointwise_simple_unary_op( + node, NNAPI_OperationCode.LOGISTIC + ), + "aten::softmax": lambda self, node: self.add_softmax(node), + "aten::hardtanh": lambda self, node: self.add_hardtanh(node), + "aten::avg_pool2d": lambda self, node: self.add_avg_pool2d(node), + "aten::max_pool2d": lambda self, node: self.add_pool2d_node( + node, NNAPI_OperationCode.MAX_POOL_2D + ), + "aten::adaptive_avg_pool2d": lambda self, node: self.add_adaptive_avg_pool2d( + node + ), + "aten::upsample_nearest2d": lambda self, node: self.add_upsample_nearest2d( + node + ), + "aten::prelu": lambda self, node: self.add_prelu_op(node), + "aten::addmm": lambda self, node: self.add_addmm(node), + "aten::linear": lambda self, node: self.add_linear(node), + "aten::_convolution": lambda self, node: self.add_conv_underscore(node), + "aten::conv2d": lambda self, node: self.add_conv2d(node), + "aten::log_softmax": lambda self, node: self.add_log_softmax(node), + "quantized::linear": lambda self, node: self.add_qlinear(node), + "quantized::conv2d": lambda self, node: self.add_qconv2d( + node, NNAPI_FuseCode.FUSED_NONE + ), + "quantized::conv2d_relu": lambda self, node: self.add_qconv2d( + node, NNAPI_FuseCode.FUSED_RELU + ), + "quantized::conv_transpose2d": lambda self, node: self.add_qconv2d( + node, NNAPI_FuseCode.FUSED_NONE, transpose=True + ), + "quantized::add": lambda self, node: self.add_qadd( + node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE + ), + "quantized::add_relu": lambda self, node: self.add_qadd( + node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_RELU + ), + "quantized::mul": lambda self, node: self.add_qadd( + node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE + ), } def add_node(self, node): @@ -918,7 +939,8 @@ class _NnapiSerializer: self.add_tensor_sequence(output, tensors) if const_vals is None and tensors is None: raise Exception( - f"Unable to handle ListConstruct node. Neither all constants nor all tensors. {node!r}") + f"Unable to handle ListConstruct node. Neither all constants nor all tensors. {node!r}" + ) def add_tuple_construct(self, node): assert node.outputsSize() == 1 @@ -969,11 +991,14 @@ class _NnapiSerializer: if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_reshape: raise Exception( - "Currently, reshape is only supported on NHWC tensors if the target size is [X, -1].") + "Currently, reshape is only supported on NHWC tensors if the target size is [X, -1]." + ) # Bit of a hack here. Use a real tensor to infer the output shape. out_shape = torch.zeros(1).expand(in_oper.shape).reshape(shape).shape - out_oper = in_oper._replace(shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS) + out_oper = in_oper._replace( + shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS + ) inputs = [None] * 2 inputs[0] = in_id @@ -995,10 +1020,12 @@ class _NnapiSerializer: # channels last with channels == 1 or (height & width both 1) is_trivial_flatten = len(in_oper.shape) == 4 and ( - in_oper.shape[1] == 1 or (in_oper.shape[2] == 1 and in_oper.shape[3] == 1)) + in_oper.shape[1] == 1 or (in_oper.shape[2] == 1 and in_oper.shape[3] == 1) + ) if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_flatten: raise Exception( - "Currently, flatten is not supported on NHWC tensors unless C=1 or H=W=1") + "Currently, flatten is not supported on NHWC tensors unless C=1 or H=W=1" + ) if start_dim < 0: start_dim += len(in_oper.shape) @@ -1006,29 +1033,31 @@ class _NnapiSerializer: end_dim += len(in_oper.shape) out_shape = ( - in_oper.shape[: start_dim] + - (functools.reduce( - lambda x, y: x * y, in_oper.shape[start_dim: end_dim + 1]),) + - in_oper.shape[end_dim + 1:] + in_oper.shape[:start_dim] + + ( + functools.reduce( + lambda x, y: x * y, in_oper.shape[start_dim : end_dim + 1] + ), + ) + + in_oper.shape[end_dim + 1 :] ) - if any(dim == 0 for dim in in_oper.shape[start_dim: end_dim + 1]): + if any(dim == 0 for dim in in_oper.shape[start_dim : end_dim + 1]): raise Exception("Flattening flexible dims is not supported yet") - non_flattened_dims = in_oper.shape[: start_dim] + in_oper.shape[end_dim + 1:] + non_flattened_dims = in_oper.shape[:start_dim] + in_oper.shape[end_dim + 1 :] if non_flattened_dims.count(0) > 1: raise Exception("Only 1 dim can be flexible") - out_oper = in_oper._replace(shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS) + out_oper = in_oper._replace( + shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS + ) out_id = self.add_tensor_operand(node.outputsAt(0), out_oper) for idx, dim in enumerate(out_shape): if dim == 0: self.forward_operand_shape(out_id, idx, in_id, in_oper.shape.index(0)) - inputs_1 = tuple( - dim if dim != 0 else -1 - for dim in out_shape - ) + inputs_1 = tuple(dim if dim != 0 else -1 for dim in out_shape) inputs = [None] * 2 inputs[0] = in_id inputs[1] = self.add_immediate_int_vector(inputs_1) @@ -1074,24 +1103,34 @@ class _NnapiSerializer: raise Exception("Slice start value should be less than stop value") out_len = (stop_value - start_value) // step_value - out_shape = tuple(out_len if i == dim_value else dim for i, dim in enumerate(in_oper.shape)) - out_id = self.add_tensor_operand(node.outputsAt(0), in_oper._replace(shape=out_shape)) + out_shape = tuple( + out_len if i == dim_value else dim for i, dim in enumerate(in_oper.shape) + ) + out_id = self.add_tensor_operand( + node.outputsAt(0), in_oper._replace(shape=out_shape) + ) # flex inputs end_mask = 0 for idx, dim in enumerate(out_shape): if dim == 0: self.forward_operand_shape(out_id, idx, in_id, idx) - end_mask |= (1 << idx) + end_mask |= 1 << idx inputs = [None] * 7 inputs[0] = in_id inputs[1] = self.add_immediate_int_vector( - [start_value if i == dim_value else 0 for i in range(len(in_oper.shape))]) + [start_value if i == dim_value else 0 for i in range(len(in_oper.shape))] + ) inputs[2] = self.add_immediate_int_vector( - [stop_value if i == dim_value else dim for i, dim in enumerate(in_oper.shape)]) + [ + stop_value if i == dim_value else dim + for i, dim in enumerate(in_oper.shape) + ] + ) inputs[3] = self.add_immediate_int_vector( - [step_value if i == dim_value else 1 for i in range(len(in_oper.shape))]) + [step_value if i == dim_value else 1 for i in range(len(in_oper.shape))] + ) inputs[4] = self.add_immediate_int_scalar(0) # begin mask inputs[5] = self.add_immediate_int_scalar(end_mask) inputs[6] = self.add_immediate_int_scalar(0) # shrink axis mas @@ -1129,14 +1168,18 @@ class _NnapiSerializer: out_oper = in_oper._replace(shape=out_shape) assert in_oper.op_type == out_oper.op_type assert in_oper.dim_order == out_oper.dim_order - assert change_element(in_oper.shape, dim, -1) == change_element(out_oper.shape, dim, -1) + assert change_element(in_oper.shape, dim, -1) == change_element( + out_oper.shape, dim, -1 + ) # TODO: Possibly check scale and zero point. in_ids.append(in_id) # TODO: Possibly support variable-sized inputs. out_dim_size += in_oper.shape[dim] assert out_oper is not None - out_oper = out_oper._replace(shape=change_element(out_oper.shape, dim, out_dim_size)) + out_oper = out_oper._replace( + shape=change_element(out_oper.shape, dim, out_dim_size) + ) if in_oper.dim_order == DimOrder.CHANNELS_LAST: assert len(out_oper.shape) == 4 @@ -1217,14 +1260,16 @@ class _NnapiSerializer: if in_oper.dim_order != DimOrder.CHANNELS_LAST: raise Exception( "Most hardware backends prefer NHWC quantized tensors. " - "Try setting `t.nnapi_nhwc = True` on your tensor inputs. ") + "Try setting `t.nnapi_nhwc = True` on your tensor inputs. " + ) _, scale = self.get_constant_value(node.inputsAt(1), "FloatType") _, zero_point = self.get_constant_value(node.inputsAt(2), "IntType") _, scalar_type = self.get_constant_value(node.inputsAt(3), "IntType") if scalar_type != TorchScalarTypes.QUINT8.value: raise Exception( "PyTorch NNAPI export only supports quantized tensors " - "with the quint8 dtype.") + "with the quint8 dtype." + ) op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM out_oper = in_oper._replace( @@ -1297,16 +1342,21 @@ class _NnapiSerializer: if self.has_operand_for_jitval(node.inputsAt(0)): in0_id, in0_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) - in1_id, in1_oper = self.get_tensor_operand_or_constant(node.inputsAt(1), in0_oper.dim_order) + in1_id, in1_oper = self.get_tensor_operand_or_constant( + node.inputsAt(1), in0_oper.dim_order + ) elif self.has_operand_for_jitval(node.inputsAt(1)): in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1)) - in0_id, in0_oper = self.get_tensor_operand_or_constant(node.inputsAt(0), in1_oper.dim_order) + in0_id, in0_oper = self.get_tensor_operand_or_constant( + node.inputsAt(0), in1_oper.dim_order + ) else: raise Exception(f"Can't do a NNAPI binary op: {opcode} on two constants") assert in0_oper.op_type == in1_oper.op_type in0_id, in0_oper, in1_id, in1_oper = self.transpose_for_broadcast( - in0_id, in0_oper, in1_id, in1_oper) + in0_id, in0_oper, in1_id, in1_oper + ) # NOTE: PyTorch and NNAPI have the same broadcast semantics. out_shape = broadcast_shapes(in0_oper.shape, in1_oper.shape) out_oper = in0_oper._replace(shape=out_shape) @@ -1370,7 +1420,9 @@ class _NnapiSerializer: inputs = [None] * 3 inputs[0] = in_id - inputs[1] = self.add_immediate_float_scalar(1.0) # positive scaling factor of exponent, beta + inputs[1] = self.add_immediate_float_scalar( + 1.0 + ) # positive scaling factor of exponent, beta inputs[2] = self.add_immediate_int_scalar(softmax_dim) outputs = [None] * 1 @@ -1388,7 +1440,7 @@ class _NnapiSerializer: op_map = { (-1, 1): NNAPI_OperationCode.RELU1, - ( 0, 6): NNAPI_OperationCode.RELU6, # noqa: E201 + (0, 6): NNAPI_OperationCode.RELU6, # noqa: E201 } opcode = op_map.get((min_val, max_val)) @@ -1417,7 +1469,9 @@ class _NnapiSerializer: if w_oper.shape[0] > 1: if in_oper.use_nchw(): # TODO: Support this by adding trailing 1 dims. - raise Exception("Per-channel PReLU only supports channels_last right now.") + raise Exception( + "Per-channel PReLU only supports channels_last right now." + ) out_id = self.add_tensor_operand(node.outputsAt(0), in_oper) for dim, size in enumerate(in_oper.shape): @@ -1446,14 +1500,18 @@ class _NnapiSerializer: # TODO: Validate ceil_mode semantics. - args = self.get_conv_pool_args_2d_from_jit(self.get_size_arg(kernel), stride, padding, dilation) + args = self.get_conv_pool_args_2d_from_jit( + self.get_size_arg(kernel), stride, padding, dilation + ) if args.dilation_h != 1 or args.dilation_w != 1: raise Exception("NNAPI does not support dilated pooling.") image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(image) assert len(image_oper.shape) == 4 - out_shape = get_conv_pool_shape(image_oper.shape, args, image_oper.shape[1], False) + out_shape = get_conv_pool_shape( + image_oper.shape, args, image_oper.shape[1], False + ) use_nchw = image_oper.use_nchw() inputs = [None] * 11 @@ -1470,26 +1528,42 @@ class _NnapiSerializer: inputs[10] = self.add_immediate_bool_scalar(use_nchw) outputs = [None] * 1 - outputs[0] = self.add_tensor_operand(node.outputsAt(0), image_oper._replace(shape=out_shape)) + outputs[0] = self.add_tensor_operand( + node.outputsAt(0), image_oper._replace(shape=out_shape) + ) self.add_operation(opcode, inputs, outputs) def add_avg_pool2d(self, node): assert node.inputsSize() == 7 assert node.outputsSize() == 1 - image, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override = node.inputs() + ( + image, + kernel, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) = node.inputs() _, count_include_pad_value = self.get_constant_value(count_include_pad) _, divisor_override_value = self.get_constant_value(divisor_override) if not count_include_pad_value or divisor_override_value: - raise Exception("NNAPI doesn't support count_include_pad=False or divisor_override") + raise Exception( + "NNAPI doesn't support count_include_pad=False or divisor_override" + ) - args = self.get_conv_pool_args_2d_from_jit(self.get_size_arg(kernel), stride, padding) + args = self.get_conv_pool_args_2d_from_jit( + self.get_size_arg(kernel), stride, padding + ) image_id, image_oper = self.get_tensor_operand_by_jitval(image) assert len(image_oper.shape) == 4 - out_shape = get_conv_pool_shape(image_oper.shape, args, image_oper.shape[1], False) + out_shape = get_conv_pool_shape( + image_oper.shape, args, image_oper.shape[1], False + ) use_nchw = image_oper.use_nchw() inputs = [None] * 11 @@ -1506,7 +1580,9 @@ class _NnapiSerializer: inputs[10] = self.add_immediate_bool_scalar(use_nchw) outputs = [None] * 1 - out_id = self.add_tensor_operand(node.outputsAt(0), image_oper._replace(shape=out_shape)) + out_id = self.add_tensor_operand( + node.outputsAt(0), image_oper._replace(shape=out_shape) + ) self._handle_conv_pool_flexible_input(out_id, image, args, False) outputs[0] = out_id @@ -1516,14 +1592,18 @@ class _NnapiSerializer: assert node.inputsSize() == 2 assert node.outputsSize() == 1 - image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) + image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size( + node.inputsAt(0) + ) assert len(image_oper.shape) == 4 size_ctype, size_arg = self.get_constant_value(node.inputsAt(1)) assert size_ctype.kind() == "ListType" assert size_ctype.getElementType().kind() == "IntType" if size_arg != [1, 1]: - raise Exception("NNAPI only supports adaptive_avg_pool2d with output size (1, 1).") + raise Exception( + "NNAPI only supports adaptive_avg_pool2d with output size (1, 1)." + ) out_shape = image_oper.shape[0:2] + tuple(size_arg) use_nchw = image_oper.use_nchw() @@ -1542,7 +1622,9 @@ class _NnapiSerializer: inputs[10] = self.add_immediate_bool_scalar(use_nchw) outputs = [None] * 1 - outputs[0] = self.add_tensor_operand(node.outputsAt(0), image_oper._replace(shape=out_shape)) + outputs[0] = self.add_tensor_operand( + node.outputsAt(0), image_oper._replace(shape=out_shape) + ) self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs) @@ -1610,18 +1692,24 @@ class _NnapiSerializer: out_shape = (image_oper.shape[0], image_oper.shape[1], out_h, out_w) use_nchw = image_oper.use_nchw() - out_id = self.add_tensor_operand(node.outputsAt(0), image_oper._replace(shape=out_shape)) + out_id = self.add_tensor_operand( + node.outputsAt(0), image_oper._replace(shape=out_shape) + ) if image_oper.shape[0] == 0 or image_oper.shape[1] == 0: raise Exception("Flexible batch or channels not supported") # Handle variable input size - for dim in (2, 3): # h, w indices + for dim in (2, 3): # h, w indices if image_oper.shape[dim] == 0: if size_ctype.kind() != "NoneType": self.compute_operand_shape(out_id, dim, size_arg[dim - 2]) elif scale_ctype.kind() != "NoneType": - self.compute_operand_shape(out_id, dim, f"int({scale_arg[dim - 2]} * {flex_name(image_id, dim)})") + self.compute_operand_shape( + out_id, + dim, + f"int({scale_arg[dim - 2]} * {flex_name(image_id, dim)})", + ) else: raise Exception("Size and scale cannot both be None.") @@ -1645,7 +1733,9 @@ class _NnapiSerializer: scale_ctype, scale_value = self.get_constant_value(jitval) assert scale_ctype.kind() in ("IntType", "FloatType") if scale_value != 1: - raise Exception("NNAPI Fully-Connected does not support alpha and beta.") + raise Exception( + "NNAPI Fully-Connected does not support alpha and beta." + ) self.add_addmm_or_linear(node, True, jit_input, jit_weight, jit_bias) @@ -1656,7 +1746,9 @@ class _NnapiSerializer: self.add_addmm_or_linear(node, False, jit_input, jit_weight, jit_bias) - def add_addmm_or_linear(self, node, transpose_weight, jit_input, jit_weight, jit_bias): + def add_addmm_or_linear( + self, node, transpose_weight, jit_input, jit_weight, jit_bias + ): input_id, input_oper = self.get_tensor_operand_by_jitval(jit_input) bias_id, bias_oper = self.get_tensor_operand_for_weight(jit_bias) @@ -1674,7 +1766,9 @@ class _NnapiSerializer: weight_oper = self.operands[weight_id] out_shape = (input_oper.shape[0], weight_oper.shape[0]) - out_id = self.add_tensor_operand(node.outputsAt(0), input_oper._replace(shape=out_shape)) + out_id = self.add_tensor_operand( + node.outputsAt(0), input_oper._replace(shape=out_shape) + ) if input_oper.shape[0] == 0: self.forward_operand_shape(out_id, 0, input_id, 0) @@ -1724,7 +1818,8 @@ class _NnapiSerializer: unsigned_weight = torch._make_per_tensor_quantized_tensor( (raw_weight.int_repr().int() + 128).to(torch.uint8), scale=raw_weight.q_scale(), - zero_point=raw_weight.q_zero_point() + 128) + zero_point=raw_weight.q_zero_point() + 128, + ) weight_scale = unsigned_weight.q_scale() bias_scale = input_oper.scale * weight_scale int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32) @@ -1736,7 +1831,8 @@ class _NnapiSerializer: raise Exception( "Quantized convolution multiplier is greater than 1. " "This is supported by NNAPI, but not by most hardware backends. " - "Try training a model without quantization-aware training. ") + "Try training a model without quantization-aware training. " + ) # TODO: Transform at load time to share weights with CPU model. nnapi_weight_tensor = unsigned_weight.contiguous() @@ -1765,7 +1861,9 @@ class _NnapiSerializer: ctype, value = self.get_constant_value(jit_bias) if ctype.kind() == "NoneType": bias_idx = 1 if transpose else 0 - nnapi_bias_tensor = torch.zeros(weight_tensor.size()[bias_idx], dtype=weight_tensor.dtype) + nnapi_bias_tensor = torch.zeros( + weight_tensor.size()[bias_idx], dtype=weight_tensor.dtype + ) bias_id = self.add_tensor_operand_for_weight(nnapi_bias_tensor) bias_oper = self.operands[bias_id] return bias_id, bias_oper @@ -1789,7 +1887,8 @@ class _NnapiSerializer: _, weight_tensor = self.get_constant_value(jit_weight, "TensorType") bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor) args = self.get_conv_pool_args_2d_from_jit( - weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups) + weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups + ) return self.add_conv2d_common( node.outputsAt(0), @@ -1823,12 +1922,12 @@ class _NnapiSerializer: _, ) = node.inputs() - _, weight_tensor = self.get_constant_value(jit_weight, "TensorType") _, transpose = self.get_constant_value(jit_transpose) bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor, transpose) args = self.get_conv_pool_args_2d_from_jit( - weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups) + weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups + ) return self.add_conv2d_common( node.outputsAt(0), @@ -1846,11 +1945,7 @@ class _NnapiSerializer: assert node.inputsSize() == 3 assert node.outputsSize() == 1 - ( - jit_input, - jit_dim, - jit_half_to_float - ) = node.inputs() + (jit_input, jit_dim, jit_half_to_float) = node.inputs() input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input) _, dim = self.get_constant_value(jit_dim, "IntType") @@ -1863,10 +1958,11 @@ class _NnapiSerializer: inputs[2] = self.add_immediate_int_scalar(dim) outputs = [None] * 1 - outputs[0] = self.add_tensor_operand(node.outputsAt(0), input_oper._replace(shape=out_shape)) + outputs[0] = self.add_tensor_operand( + node.outputsAt(0), input_oper._replace(shape=out_shape) + ) self.add_operation(NNAPI_OperationCode.LOG_SOFTMAX, inputs, outputs) - def add_qconv2d(self, node, fuse_code, transpose=False): assert node.inputsSize() == 4 assert node.outputsSize() == 1 @@ -1889,9 +1985,11 @@ class _NnapiSerializer: ) = packed_weight.__getstate__()[0] assert pack_version == "2" packed_config, raw_weight = tensors - raw_bias, = opt_tensors + (raw_bias,) = opt_tensors assert raw_bias is not None - args = self.get_conv_pool_args_2d_from_pack(raw_weight.shape[2:4], packed_config) + args = self.get_conv_pool_args_2d_from_pack( + raw_weight.shape[2:4], packed_config + ) assert raw_weight.qscheme() == torch.per_tensor_affine if raw_weight.dtype == torch.quint8: @@ -1901,7 +1999,8 @@ class _NnapiSerializer: unsigned_weight = torch._make_per_tensor_quantized_tensor( (raw_weight.int_repr().int() + 128).to(torch.uint8), scale=raw_weight.q_scale(), - zero_point=raw_weight.q_zero_point() + 128) + zero_point=raw_weight.q_zero_point() + 128, + ) weight_scale = unsigned_weight.q_scale() _, image_oper = self.get_tensor_operand_by_jitval(jit_image) bias_scale = image_oper.scale * weight_scale @@ -1914,7 +2013,8 @@ class _NnapiSerializer: raise Exception( "Quantized convolution multiplier is greater than 1. " "This is supported by NNAPI, but not by most hardware backends. " - "Try training a model without quantization-aware training. ") + "Try training a model without quantization-aware training. " + ) return self.add_conv2d_common( node.outputsAt(0), @@ -1929,16 +2029,17 @@ class _NnapiSerializer: ) def add_conv2d_common( - self, - jit_out, - out_scale, - out_zero_point, - jit_image, - weight_tensor, - bias_id, - args, - transpose, - fuse_code): + self, + jit_out, + out_scale, + out_zero_point, + jit_image, + weight_tensor, + bias_id, + args, + transpose, + fuse_code, + ): image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image) in_c = image_oper.shape[1] @@ -1972,8 +2073,7 @@ class _NnapiSerializer: assert approx_equal(image_oper.scale * weight_oper.scale, bias_oper.scale) assert bias_oper.zero_point == 0 else: - raise Exception( - f"Unsupported input type for conv2d: {image_oper.op_type}") + raise Exception(f"Unsupported input type for conv2d: {image_oper.op_type}") assert len(image_oper.shape) == 4 assert len(weight_oper.shape) == 4 @@ -2051,30 +2151,32 @@ class _NnapiSerializer: self.compute_operand_shape( out_id, 2, - f"({flex_name(image_id, 2)} - 1) * {args.stride_h} + {args.kernel_h} - {args.pad_t} - {args.pad_b}" + f"({flex_name(image_id, 2)} - 1) * {args.stride_h} + {args.kernel_h} - {args.pad_t} - {args.pad_b}", ) if in_w == 0: self.compute_operand_shape( out_id, 3, - f"({flex_name(image_id, 3)} - 1) * {args.stride_w} + {args.kernel_w} - {args.pad_l} - {args.pad_r}" + f"({flex_name(image_id, 3)} - 1) * {args.stride_w} + {args.kernel_w} - {args.pad_l} - {args.pad_r}", ) else: if in_h == 0: self.compute_operand_shape( out_id, 2, - f"({flex_name(image_id, 2)} - {args.kernel_h} + {args.pad_t} + {args.pad_b}) // {args.stride_h} + 1" + f"({flex_name(image_id, 2)} - {args.kernel_h} + {args.pad_t} + {args.pad_b}) // {args.stride_h} + 1", ) if in_w == 0: self.compute_operand_shape( out_id, 3, - f"({flex_name(image_id, 3)} - {args.kernel_w} + {args.pad_l} + {args.pad_r}) // {args.stride_w} + 1" + f"({flex_name(image_id, 3)} - {args.kernel_w} + {args.pad_l} + {args.pad_r}) // {args.stride_w} + 1", ) -def serialize_model(module, inputs, *, config=None, return_shapes=None, use_int16_for_qint16=False): +def serialize_model( + module, inputs, *, config=None, return_shapes=None, use_int16_for_qint16=False +): """Convert to NNAPI and serialize torchscript module: Parameters: module: Torchscript module to convert @@ -2086,4 +2188,6 @@ def serialize_model(module, inputs, *, config=None, return_shapes=None, use_int1 use_int16_for_qint16 (optional): Use Pytorch int16 to represent NNAPI qint16 values """ - return _NnapiSerializer(config, use_int16_for_qint16).serialize_model(module, inputs, return_shapes) + return _NnapiSerializer(config, use_int16_for_qint16).serialize_model( + module, inputs, return_shapes + ) diff --git a/torch/backends/cpu/__init__.py b/torch/backends/cpu/__init__.py index 22524e3576c7..a3089bdb3366 100644 --- a/torch/backends/cpu/__init__.py +++ b/torch/backends/cpu/__init__.py @@ -1,6 +1,8 @@ import torch -__all__ = ["get_cpu_capability", ] +__all__ = [ + "get_cpu_capability", +] def get_cpu_capability() -> str: diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index 5a7536624901..320fa01e8f31 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -1,14 +1,30 @@ -import sys -import torch import contextlib +import sys from enum import IntEnum from typing import Union -__all__ = ["is_built", "cuFFTPlanCacheAttrContextProp", "cuFFTPlanCache", "cuFFTPlanCacheManager", - "cuBLASModule", "preferred_linalg_library", "cufft_plan_cache", "matmul", "SDPBackend", "enable_flash_sdp", - "flash_sdp_enabled", "enable_mem_efficient_sdp", "mem_efficient_sdp_enabled", - "math_sdp_enabled", "enable_math_sdp", "sdp_kernel"] +import torch + +__all__ = [ + "is_built", + "cuFFTPlanCacheAttrContextProp", + "cuFFTPlanCache", + "cuFFTPlanCacheManager", + "cuBLASModule", + "preferred_linalg_library", + "cufft_plan_cache", + "matmul", + "SDPBackend", + "enable_flash_sdp", + "flash_sdp_enabled", + "enable_mem_efficient_sdp", + "mem_efficient_sdp_enabled", + "math_sdp_enabled", + "enable_math_sdp", + "sdp_kernel", +] + def is_built(): r"""Returns whether PyTorch is built with CUDA support. Note that this @@ -40,16 +56,19 @@ class cuFFTPlanCache: attributes `size` and `max_size`, and method `clear`, can fetch and/ or change properties of the C++ cuFFT plan cache. """ + def __init__(self, device_index): self.device_index = device_index size = cuFFTPlanCacheAttrContextProp( torch._cufft_get_plan_cache_size, - '.size is a read-only property showing the number of plans currently in the ' - 'cache. To change the cache capacity, set cufft_plan_cache.max_size.') + ".size is a read-only property showing the number of plans currently in the " + "cache. To change the cache capacity, set cufft_plan_cache.max_size.", + ) - max_size = cuFFTPlanCacheAttrContextProp(torch._cufft_get_plan_cache_max_size, - torch._cufft_set_plan_cache_max_size) + max_size = cuFFTPlanCacheAttrContextProp( + torch._cufft_get_plan_cache_max_size, torch._cufft_set_plan_cache_max_size + ) def clear(self): return torch._cufft_clear_plan_cache(self.device_index) @@ -75,10 +94,15 @@ class cuFFTPlanCacheManager: index = torch.cuda._utils._get_device_index(device) if index < 0 or index >= torch.cuda.device_count(): raise RuntimeError( - ("cufft_plan_cache: expected 0 <= device index < {}, but got " - "device with index {}").format(torch.cuda.device_count(), index)) + ( + "cufft_plan_cache: expected 0 <= device index < {}, but got " + "device with index {}" + ).format(torch.cuda.device_count(), index) + ) if len(self.caches) == 0: - self.caches.extend(cuFFTPlanCache(index) for index in range(torch.cuda.device_count())) + self.caches.extend( + cuFFTPlanCache(index) for index in range(torch.cuda.device_count()) + ) return self.caches[index] def __getattr__(self, name): @@ -110,15 +134,19 @@ class cuBLASModule: return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value) raise AssertionError("Unknown attribute " + name) -_LinalgBackends = { - 'default': torch._C._LinalgBackend.Default, - 'cusolver': torch._C._LinalgBackend.Cusolver, - 'magma': torch._C._LinalgBackend.Magma, -} -_LinalgBackends_str = ', '.join(_LinalgBackends.keys()) -def preferred_linalg_library(backend: Union[None, str, torch._C._LinalgBackend] = None) -> torch._C._LinalgBackend: - r''' +_LinalgBackends = { + "default": torch._C._LinalgBackend.Default, + "cusolver": torch._C._LinalgBackend.Cusolver, + "magma": torch._C._LinalgBackend.Magma, +} +_LinalgBackends_str = ", ".join(_LinalgBackends.keys()) + + +def preferred_linalg_library( + backend: Union[None, str, torch._C._LinalgBackend] = None +) -> torch._C._LinalgBackend: + r""" .. warning:: This flag is experimental and subject to change. When PyTorch runs a CUDA linear algebra operation it often uses the cuSOLVER or MAGMA libraries, @@ -152,14 +180,15 @@ def preferred_linalg_library(backend: Union[None, str, torch._C._LinalgBackend] * :func:`torch.linalg.eighvals` * :func:`torch.linalg.svd` * :func:`torch.linalg.svdvals` - ''' + """ if backend is None: pass elif isinstance(backend, str): if backend not in _LinalgBackends: - raise RuntimeError("Unknown input value. " - f"Choose from: {_LinalgBackends_str}.") + raise RuntimeError( + "Unknown input value. " f"Choose from: {_LinalgBackends_str}." + ) torch._C._set_linalg_preferred_backend(_LinalgBackends[backend]) elif isinstance(backend, torch._C._LinalgBackend): torch._C._set_linalg_preferred_backend(backend) @@ -200,6 +229,7 @@ def enable_flash_sdp(enabled: bool): """ torch._C._set_sdp_use_flash(enabled) + def mem_efficient_sdp_enabled(): r""" .. warning:: This flag is beta and subject to change. @@ -217,6 +247,7 @@ def enable_mem_efficient_sdp(enabled: bool): """ torch._C._set_sdp_use_mem_efficient(enabled) + def math_sdp_enabled(): r""" .. warning:: This flag is beta and subject to change. @@ -236,7 +267,11 @@ def enable_math_sdp(enabled: bool): @contextlib.contextmanager -def sdp_kernel(enable_flash: bool = True, enable_math: bool = True, enable_mem_efficient: bool = True): +def sdp_kernel( + enable_flash: bool = True, + enable_math: bool = True, + enable_mem_efficient: bool = True, +): r""" .. warning:: This flag is beta and subject to change. @@ -250,11 +285,12 @@ def sdp_kernel(enable_flash: bool = True, enable_math: bool = True, enable_mem_e enable_flash_sdp(enable_flash) enable_mem_efficient_sdp(enable_mem_efficient) enable_math_sdp(enable_math) - yield{} + yield {} finally: enable_flash_sdp(previous_flash) enable_mem_efficient_sdp(previous_mem_efficient) enable_math_sdp(previous_math) + cufft_plan_cache = cuFFTPlanCacheManager() matmul = cuBLASModule() diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py index 98e095e294c2..7518b23faa23 100644 --- a/torch/backends/cudnn/__init__.py +++ b/torch/backends/cudnn/__init__.py @@ -1,9 +1,10 @@ -import sys import os -import torch +import sys import warnings from contextlib import contextmanager -from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation + +import torch +from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule try: from torch._C import _cudnn @@ -19,6 +20,7 @@ except ImportError: __cudnn_version = None if _cudnn is not None: + def _init(): global __cudnn_version if __cudnn_version is None: @@ -37,30 +39,40 @@ if _cudnn is not None: else: cudnn_compatible = runtime_minor >= compile_minor if not cudnn_compatible: - if os.environ.get('PYTORCH_SKIP_CUDNN_COMPATIBILITY_CHECK', '0') == '1': + if os.environ.get("PYTORCH_SKIP_CUDNN_COMPATIBILITY_CHECK", "0") == "1": return True - base_error_msg = (f'cuDNN version incompatibility: ' - f'PyTorch was compiled against {compile_version} ' - f'but found runtime version {runtime_version}. ' - f'PyTorch already comes bundled with cuDNN. ' - f'One option to resolving this error is to ensure PyTorch ' - f'can find the bundled cuDNN.') + base_error_msg = ( + f"cuDNN version incompatibility: " + f"PyTorch was compiled against {compile_version} " + f"but found runtime version {runtime_version}. " + f"PyTorch already comes bundled with cuDNN. " + f"One option to resolving this error is to ensure PyTorch " + f"can find the bundled cuDNN." + ) - if 'LD_LIBRARY_PATH' in os.environ: - ld_library_path = os.environ.get('LD_LIBRARY_PATH', '') - if any(substring in ld_library_path for substring in ['cuda', 'cudnn']): - raise RuntimeError(f'{base_error_msg}' - f'Looks like your LD_LIBRARY_PATH contains incompatible version of cudnn' - f'Please either remove it from the path or install cudnn {compile_version}') + if "LD_LIBRARY_PATH" in os.environ: + ld_library_path = os.environ.get("LD_LIBRARY_PATH", "") + if any( + substring in ld_library_path for substring in ["cuda", "cudnn"] + ): + raise RuntimeError( + f"{base_error_msg}" + f"Looks like your LD_LIBRARY_PATH contains incompatible version of cudnn" + f"Please either remove it from the path or install cudnn {compile_version}" + ) else: - raise RuntimeError(f'{base_error_msg}' - f'one possibility is that there is a ' - f'conflicting cuDNN in LD_LIBRARY_PATH.') + raise RuntimeError( + f"{base_error_msg}" + f"one possibility is that there is a " + f"conflicting cuDNN in LD_LIBRARY_PATH." + ) else: raise RuntimeError(base_error_msg) return True + else: + def _init(): return False @@ -87,29 +99,40 @@ def is_available(): def is_acceptable(tensor): if not torch._C._get_cudnn_enabled(): return False - if tensor.device.type != 'cuda' or tensor.dtype not in CUDNN_TENSOR_DTYPES: + if tensor.device.type != "cuda" or tensor.dtype not in CUDNN_TENSOR_DTYPES: return False if not is_available(): warnings.warn( "PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild " - "PyTorch making sure the library is visible to the build system.") + "PyTorch making sure the library is visible to the build system." + ) return False if not _init(): - warnings.warn('cuDNN/MIOpen library not found. Check your {libpath}'.format( - libpath={ - 'darwin': 'DYLD_LIBRARY_PATH', - 'win32': 'PATH' - }.get(sys.platform, 'LD_LIBRARY_PATH'))) + warnings.warn( + "cuDNN/MIOpen library not found. Check your {libpath}".format( + libpath={"darwin": "DYLD_LIBRARY_PATH", "win32": "PATH"}.get( + sys.platform, "LD_LIBRARY_PATH" + ) + ) + ) return False return True -def set_flags(_enabled=None, _benchmark=None, _benchmark_limit=None, _deterministic=None, _allow_tf32=None): - orig_flags = (torch._C._get_cudnn_enabled(), - torch._C._get_cudnn_benchmark(), - None if not is_available() else torch._C._cuda_get_cudnn_benchmark_limit(), - torch._C._get_cudnn_deterministic(), - torch._C._get_cudnn_allow_tf32()) +def set_flags( + _enabled=None, + _benchmark=None, + _benchmark_limit=None, + _deterministic=None, + _allow_tf32=None, +): + orig_flags = ( + torch._C._get_cudnn_enabled(), + torch._C._get_cudnn_benchmark(), + None if not is_available() else torch._C._cuda_get_cudnn_benchmark_limit(), + torch._C._get_cudnn_deterministic(), + torch._C._get_cudnn_allow_tf32(), + ) if _enabled is not None: torch._C._set_cudnn_enabled(_enabled) if _benchmark is not None: @@ -124,9 +147,17 @@ def set_flags(_enabled=None, _benchmark=None, _benchmark_limit=None, _determinis @contextmanager -def flags(enabled=False, benchmark=False, benchmark_limit=10, deterministic=False, allow_tf32=True): +def flags( + enabled=False, + benchmark=False, + benchmark_limit=10, + deterministic=False, + allow_tf32=True, +): with __allow_nonbracketed_mutation(): - orig_flags = set_flags(enabled, benchmark, benchmark_limit, deterministic, allow_tf32) + orig_flags = set_flags( + enabled, benchmark, benchmark_limit, deterministic, allow_tf32 + ) try: yield finally: @@ -139,17 +170,28 @@ def flags(enabled=False, benchmark=False, benchmark_limit=10, deterministic=Fals # # torch.backends..enabled = True + class CudnnModule(PropModule): def __init__(self, m, name): super().__init__(m, name) enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled) - deterministic = ContextProp(torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic) - benchmark = ContextProp(torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark) + deterministic = ContextProp( + torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic + ) + benchmark = ContextProp( + torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark + ) benchmark_limit = None if is_available(): - benchmark_limit = ContextProp(torch._C._cuda_get_cudnn_benchmark_limit, torch._C._cuda_set_cudnn_benchmark_limit) - allow_tf32 = ContextProp(torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32) + benchmark_limit = ContextProp( + torch._C._cuda_get_cudnn_benchmark_limit, + torch._C._cuda_set_cudnn_benchmark_limit, + ) + allow_tf32 = ContextProp( + torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32 + ) + # This is the sys.modules replacement trick, see # https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 diff --git a/torch/backends/cudnn/rnn.py b/torch/backends/cudnn/rnn.py index 3c5740622f28..5ce166c8b28a 100644 --- a/torch/backends/cudnn/rnn.py +++ b/torch/backends/cudnn/rnn.py @@ -9,13 +9,13 @@ except ImportError: def get_cudnn_mode(mode): - if mode == 'RNN_RELU': + if mode == "RNN_RELU": return int(_cudnn.RNNMode.rnn_relu) - elif mode == 'RNN_TANH': + elif mode == "RNN_TANH": return int(_cudnn.RNNMode.rnn_tanh) - elif mode == 'LSTM': + elif mode == "LSTM": return int(_cudnn.RNNMode.lstm) - elif mode == 'GRU': + elif mode == "GRU": return int(_cudnn.RNNMode.gru) else: raise Exception(f"Unknown mode: {mode}") @@ -25,7 +25,6 @@ def get_cudnn_mode(mode): # dropout state for even better reproducibility), but it is kept for backwards # compatibility for old models. class Unserializable: - def __init__(self, inner): self.inner = inner @@ -42,17 +41,22 @@ class Unserializable: def init_dropout_state(dropout, train, dropout_seed, dropout_state): - dropout_desc_name = 'desc_' + str(torch.cuda.current_device()) + dropout_desc_name = "desc_" + str(torch.cuda.current_device()) dropout_p = dropout if train else 0 - if (dropout_desc_name not in dropout_state) or (dropout_state[dropout_desc_name].get() is None): + if (dropout_desc_name not in dropout_state) or ( + dropout_state[dropout_desc_name].get() is None + ): if dropout_p == 0: dropout_state[dropout_desc_name] = Unserializable(None) else: - dropout_state[dropout_desc_name] = Unserializable(torch._cudnn_init_dropout_state( # type: ignore[call-arg] - dropout_p, - train, - dropout_seed, - self_ty=torch.uint8, - device=torch.device('cuda'))) + dropout_state[dropout_desc_name] = Unserializable( + torch._cudnn_init_dropout_state( # type: ignore[call-arg] + dropout_p, + train, + dropout_seed, + self_ty=torch.uint8, + device=torch.device("cuda"), + ) + ) dropout_ts = dropout_state[dropout_desc_name].get() return dropout_ts diff --git a/torch/backends/mkl/__init__.py b/torch/backends/mkl/__init__.py index 22cad6db2203..af618044f026 100644 --- a/torch/backends/mkl/__init__.py +++ b/torch/backends/mkl/__init__.py @@ -1,11 +1,15 @@ import torch + def is_available(): r"""Returns whether PyTorch is built with MKL support.""" return torch._C.has_mkl + VERBOSE_OFF = 0 VERBOSE_ON = 1 + + class verbose: """ On-demand oneMKL verbosing functionality @@ -41,7 +45,9 @@ class verbose: if self.enable == VERBOSE_OFF: return st = torch._C._verbose.mkl_set_verbose(self.enable) - assert st, "Failed to set MKL into verbose mode. Please consider to disable this verbose scope." + assert ( + st + ), "Failed to set MKL into verbose mode. Please consider to disable this verbose scope." return self def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/torch/backends/mkldnn/__init__.py b/torch/backends/mkldnn/__init__.py index 45275468daad..cee7373b41b4 100644 --- a/torch/backends/mkldnn/__init__.py +++ b/torch/backends/mkldnn/__init__.py @@ -1,15 +1,20 @@ import sys -import torch from contextlib import contextmanager -from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation + +import torch +from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule + def is_available(): r"""Returns whether PyTorch is built with MKL-DNN support.""" return torch._C._has_mkldnn + VERBOSE_OFF = 0 VERBOSE_ON = 1 VERBOSE_ON_CREATION = 2 + + class verbose: """ On-demand oneDNN (former MKL-DNN) verbosing functionality @@ -46,18 +51,22 @@ class verbose: if self.level == VERBOSE_OFF: return st = torch._C._verbose.mkldnn_set_verbose(self.level) - assert st, "Failed to set MKLDNN into verbose mode. Please consider to disable this verbose scope." + assert ( + st + ), "Failed to set MKLDNN into verbose mode. Please consider to disable this verbose scope." return self def __exit__(self, exc_type, exc_val, exc_tb): torch._C._verbose.mkldnn_set_verbose(VERBOSE_OFF) return False + def set_flags(_enabled): orig_flags = (torch._C._get_mkldnn_enabled(),) torch._C._set_mkldnn_enabled(_enabled) return orig_flags + @contextmanager def flags(enabled=False): with __allow_nonbracketed_mutation(): @@ -68,12 +77,14 @@ def flags(enabled=False): with __allow_nonbracketed_mutation(): set_flags(orig_flags[0]) + class MkldnnModule(PropModule): def __init__(self, m, name): super().__init__(m, name) enabled = ContextProp(torch._C._get_mkldnn_enabled, torch._C._set_mkldnn_enabled) + # Cool stuff from torch/backends/cudnn/__init__.py and # https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 sys.modules[__name__] = MkldnnModule(sys.modules[__name__], __name__) diff --git a/torch/backends/mps/__init__.py b/torch/backends/mps/__init__.py index 97370e8ad65e..33129276cecd 100644 --- a/torch/backends/mps/__init__.py +++ b/torch/backends/mps/__init__.py @@ -1,5 +1,6 @@ -import torch from functools import lru_cache as _lru_cache + +import torch from ...library import Library as _Library __all__ = ["is_built", "is_available", "is_macos13_or_newer"] @@ -26,13 +27,18 @@ def is_macos13_or_newer(minor: int = 0) -> bool: _lib = None + + def _init(): r"""Register prims as implementation of var_mean and group_norm""" global _lib if is_built() is False or _lib is not None: return - from ..._refs import var_mean as _var_mean, native_group_norm as _native_group_norm - from ..._decomp.decompositions import native_group_norm_backward as _native_group_norm_backward + from ..._decomp.decompositions import ( + native_group_norm_backward as _native_group_norm_backward, + ) + from ..._refs import native_group_norm as _native_group_norm, var_mean as _var_mean + _lib = _Library("aten", "IMPL") _lib.impl("var_mean.correction", _var_mean, "MPS") _lib.impl("native_group_norm", _native_group_norm, "MPS") diff --git a/torch/backends/opt_einsum/__init__.py b/torch/backends/opt_einsum/__init__.py index ab8c4bc193b9..e71768d3a82d 100644 --- a/torch/backends/opt_einsum/__init__.py +++ b/torch/backends/opt_einsum/__init__.py @@ -1,9 +1,10 @@ -from typing import Any -import warnings import sys -from functools import lru_cache as _lru_cache +import warnings from contextlib import contextmanager -from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation +from functools import lru_cache as _lru_cache +from typing import Any + +from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule try: import opt_einsum as _opt_einsum # type: ignore[import] @@ -24,10 +25,12 @@ def get_opt_einsum() -> Any: def _set_enabled(_enabled: bool) -> None: if not is_available() and _enabled: - raise ValueError(f'opt_einsum is not available, so setting `enabled` to {_enabled} will not reap ' - 'the benefits of calculating an optimal path for einsum. torch.einsum will ' - 'fall back to contracting from left to right. To enable this optimal path ' - 'calculation, please install opt-einsum.') + raise ValueError( + f"opt_einsum is not available, so setting `enabled` to {_enabled} will not reap " + "the benefits of calculating an optimal path for einsum. torch.einsum will " + "fall back to contracting from left to right. To enable this optimal path " + "calculation, please install opt-einsum." + ) global enabled enabled = _enabled @@ -38,15 +41,21 @@ def _get_enabled() -> bool: def _set_strategy(_strategy: str) -> None: if not is_available(): - raise ValueError(f'opt_einsum is not available, so setting `strategy` to {_strategy} will not be meaningful. ' - 'torch.einsum will bypass path calculation and simply contract from left to right. ' - 'Please install opt_einsum or unset `strategy`.') + raise ValueError( + f"opt_einsum is not available, so setting `strategy` to {_strategy} will not be meaningful. " + "torch.einsum will bypass path calculation and simply contract from left to right. " + "Please install opt_einsum or unset `strategy`." + ) if not enabled: - raise ValueError(f'opt_einsum is not enabled, so setting a `strategy` to {_strategy} will not be meaningful. ' - 'torch.einsum will bypass path calculation and simply contract from left to right. ' - 'Please set `enabled` to `True` as well or unset `strategy`.') - if _strategy not in ['auto', 'greedy', 'optimal']: - raise ValueError(f'`strategy` must be one of the following: [auto, greedy, optimal] but is {_strategy}') + raise ValueError( + f"opt_einsum is not enabled, so setting a `strategy` to {_strategy} will not be meaningful. " + "torch.einsum will bypass path calculation and simply contract from left to right. " + "Please set `enabled` to `True` as well or unset `strategy`." + ) + if _strategy not in ["auto", "greedy", "optimal"]: + raise ValueError( + f"`strategy` must be one of the following: [auto, greedy, optimal] but is {_strategy}" + ) global strategy strategy = _strategy @@ -80,6 +89,7 @@ def flags(enabled=None, strategy=None): # # torch.backends.opt_einsum.enabled = True + class OptEinsumModule(PropModule): def __init__(self, m, name): super().__init__(m, name) @@ -91,9 +101,10 @@ class OptEinsumModule(PropModule): if is_available(): strategy = ContextProp(_get_strategy, _set_strategy) + # This is the sys.modules replacement trick, see # https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__) enabled = True if is_available() else False -strategy = 'auto' if is_available() else None +strategy = "auto" if is_available() else None diff --git a/torch/backends/quantized/__init__.py b/torch/backends/quantized/__init__.py index 70b7db458f5b..85009753e0ae 100644 --- a/torch/backends/quantized/__init__.py +++ b/torch/backends/quantized/__init__.py @@ -1,29 +1,33 @@ import sys -import torch import types from typing import List +import torch + + # This function should correspond to the enums present in c10/core/QEngine.h def _get_qengine_id(qengine: str) -> int: - if qengine == 'none' or qengine == '' or qengine is None: + if qengine == "none" or qengine == "" or qengine is None: ret = 0 - elif qengine == 'fbgemm': + elif qengine == "fbgemm": ret = 1 - elif qengine == 'qnnpack': + elif qengine == "qnnpack": ret = 2 - elif qengine == 'onednn': + elif qengine == "onednn": ret = 3 - elif qengine == 'x86': + elif qengine == "x86": ret = 4 else: ret = -1 raise RuntimeError(f"{qengine} is not a valid value for quantized engine") return ret + # This function should correspond to the enums present in c10/core/QEngine.h def _get_qengine_str(qengine: int) -> str: - all_engines = {0 : 'none', 1 : 'fbgemm', 2 : 'qnnpack', 3 : 'onednn', 4 : 'x86'} - return all_engines.get(qengine, '*undefined') + all_engines = {0: "none", 1: "fbgemm", 2: "qnnpack", 3: "onednn", 4: "x86"} + return all_engines.get(qengine, "*undefined") + class _QEngineProp: def __get__(self, obj, objtype) -> str: @@ -32,6 +36,7 @@ class _QEngineProp: def __set__(self, obj, val: str) -> None: torch._C._set_qengine(_get_qengine_id(val)) + class _SupportedQEnginesProp: def __get__(self, obj, objtype) -> List[str]: qengines = torch._C._supported_qengines() @@ -40,6 +45,7 @@ class _SupportedQEnginesProp: def __set__(self, obj, val) -> None: raise RuntimeError("Assignment not supported") + class QuantizedEngine(types.ModuleType): def __init__(self, m, name): super().__init__(name) @@ -51,6 +57,7 @@ class QuantizedEngine(types.ModuleType): engine = _QEngineProp() supported_engines = _SupportedQEnginesProp() + # This is the sys.modules replacement trick, see # https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 sys.modules[__name__] = QuantizedEngine(sys.modules[__name__], __name__) diff --git a/torch/backends/xeon/run_cpu.py b/torch/backends/xeon/run_cpu.py index a0764cb8339a..accfea49873b 100644 --- a/torch/backends/xeon/run_cpu.py +++ b/torch/backends/xeon/run_cpu.py @@ -121,29 +121,30 @@ Memory allocator """ -import sys -import platform -import subprocess -import os -from os.path import expanduser -import re import glob -from argparse import ArgumentParser, REMAINDER -from argparse import RawTextHelpFormatter import logging -from torch.distributed.elastic.multiprocessing import Std, start_processes -from typing import List, Dict +import os +import platform +import re +import subprocess +import sys +from argparse import ArgumentParser, RawTextHelpFormatter, REMAINDER +from os.path import expanduser +from typing import Dict, List + +from torch.distributed.elastic.multiprocessing import start_processes, Std format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" logging.basicConfig(level=logging.INFO, format=format_str) logger = logging.getLogger(__name__) + class _CPUinfo: """ Get CPU information, such as cores list and NUMA information. """ - def __init__(self, test_input=""): + def __init__(self, test_input=""): self.cpuinfo = [] if platform.system() in ["Windows", "Darwin"]: raise RuntimeError(f"{platform.system()} is not supported!!!") @@ -159,7 +160,9 @@ class _CPUinfo: # ... if test_input == "": lscpu_cmd = ["lscpu", "--parse=CPU,Core,Socket,Node"] - lscpu_info = subprocess.check_output(lscpu_cmd, universal_newlines=True).split("\n") + lscpu_info = subprocess.check_output( + lscpu_cmd, universal_newlines=True + ).split("\n") else: lscpu_info = test_input.split("\n") @@ -174,9 +177,9 @@ class _CPUinfo: # logical cores := cPU column in lscpu output self.node_nums = int(max([line[3] for line in self.cpuinfo])) + 1 self.node_physical_cores: List[List[int]] = [] # node_id is index - self.node_logical_cores: List[List[int]] = [] # node_id is index + self.node_logical_cores: List[List[int]] = [] # node_id is index self.physical_core_node_map = {} # physical core to numa node id - self.logical_core_node_map = {} # logical core to numa node id + self.logical_core_node_map = {} # logical core to numa node id for node_id in range(self.node_nums): cur_node_physical_core = [] @@ -200,12 +203,16 @@ class _CPUinfo: def get_node_physical_cores(self, node_id): if node_id < 0 or node_id > self.node_nums - 1: - raise ValueError(f"Invalid node id: {node_id}. Valid node ids: {list(range(len(self.node_physical_cores)))}") + raise ValueError( + f"Invalid node id: {node_id}. Valid node ids: {list(range(len(self.node_physical_cores)))}" + ) return self.node_physical_cores[node_id] def get_node_logical_cores(self, node_id): if node_id < 0 or node_id > self.node_nums - 1: - raise ValueError(f"Invalid node id: {node_id}. Valid node ids: {list(range(len(self.node_physical_cores)))}") + raise ValueError( + f"Invalid node id: {node_id}. Valid node ids: {list(range(len(self.node_physical_cores)))}" + ) return self.node_logical_cores[node_id] def get_all_physical_cores(self): @@ -232,16 +239,23 @@ class _CPUinfo: if numa_id not in numa_ids: numa_ids.append(numa_id) if len(numa_ids) > 1: - logger.warning("Numa Aware: cores:%s on different NUMA nodes:%s. To avoid \ + logger.warning( + "Numa Aware: cores:%s on different NUMA nodes:%s. To avoid \ this behavior, please use --ncores-per-instance knob to make sure number of cores is divisible by --ncores-per-\ -instance. Alternatively, please use --skip-cross-node-cores knob.", str(core_list), str(numa_ids)) +instance. Alternatively, please use --skip-cross-node-cores knob.", + str(core_list), + str(numa_ids), + ) if len(numa_ids) == 0: - raise RuntimeError("invalid number of NUMA nodes; please make sure numa_ids >= 1") + raise RuntimeError( + "invalid number of NUMA nodes; please make sure numa_ids >= 1" + ) return numa_ids + class _Launcher: r""" - Class for launcher + Class for launcher """ msg_lib_notfound = f"Unable to find the {{0}} library file lib{{1}}.so in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib \ @@ -261,8 +275,13 @@ or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib6 if "VIRTUAL_ENV" in os.environ: library_paths.append(f"{os.environ['VIRTUAL_ENV']}/lib") - library_paths += [f"{expanduser('~')}/.local/lib", "/usr/local/lib", - "/usr/local/lib64", "/usr/lib", "/usr/lib64"] + library_paths += [ + f"{expanduser('~')}/.local/lib", + "/usr/local/lib", + "/usr/local/lib64", + "/usr/lib", + "/usr/lib64", + ] lib_find = False lib_set = False @@ -276,37 +295,46 @@ or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib6 matches = glob.glob(library_file) if len(matches) > 0: ld_preloads = [f"{matches[0]}", os.getenv("LD_PRELOAD", "")] - os.environ["LD_PRELOAD"] = os.pathsep.join([p.strip(os.pathsep) for p in ld_preloads if p]) + os.environ["LD_PRELOAD"] = os.pathsep.join( + [p.strip(os.pathsep) for p in ld_preloads if p] + ) lib_find = True break return lib_set or lib_find - def is_numactl_available(self): numactl_available = False try: cmd = ["numactl", "-C", "0", "-m", "0", "hostname"] - r = subprocess.run(cmd, env=os.environ, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + r = subprocess.run( + cmd, + env=os.environ, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) if r.returncode == 0: numactl_available = True except Exception: pass return numactl_available - - def set_memory_allocator(self, enable_tcmalloc=True, enable_jemalloc=False, use_default_allocator=False): + def set_memory_allocator( + self, enable_tcmalloc=True, enable_jemalloc=False, use_default_allocator=False + ): """ Enable TCMalloc/JeMalloc with LD_PRELOAD and set configuration for JeMalloc. By default, PTMalloc will be used for PyTorch, but TCMalloc and JeMalloc can get better memory reuse and reduce page fault to improve performance. """ if enable_tcmalloc and enable_jemalloc: - raise RuntimeError("Unable to enable TCMalloc and JEMalloc at the same time.") + raise RuntimeError( + "Unable to enable TCMalloc and JEMalloc at the same time." + ) if enable_tcmalloc: find_tc = self.add_lib_preload(lib_type="tcmalloc") if not find_tc: - msg = f"{self.msg_lib_notfound} you can use \"conda install -c conda-forge gperftools\" to install {{0}}" + msg = f'{self.msg_lib_notfound} you can use "conda install -c conda-forge gperftools" to install {{0}}' logger.warning(msg.format("TCmalloc", "tcmalloc")) # noqa: G001 else: logger.info("Use TCMalloc memory allocator") @@ -314,11 +342,14 @@ or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib6 elif enable_jemalloc: find_je = self.add_lib_preload(lib_type="jemalloc") if not find_je: - msg = f"{self.msg_lib_notfound} you can use \"conda install -c conda-forge jemalloc\" to install {{0}}" + msg = f'{self.msg_lib_notfound} you can use "conda install -c conda-forge jemalloc" to install {{0}}' logger.warning(msg.format("Jemalloc", "jemalloc")) # noqa: G001 else: logger.info("Use JeMalloc memory allocator") - self.set_env("MALLOC_CONF", "oversize_threshold:1,background_thread:true,metadata_thp:auto") + self.set_env( + "MALLOC_CONF", + "oversize_threshold:1,background_thread:true,metadata_thp:auto", + ) elif use_default_allocator: pass @@ -332,10 +363,13 @@ or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib6 if find_je: logger.info("Use JeMalloc memory allocator") return - logger.warning("""Neither TCMalloc nor JeMalloc is found in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib + logger.warning( + """Neither TCMalloc nor JeMalloc is found in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or %s/.local/lib/ so the LD_PRELOAD environment variable will not be set. - This may drop the performance""", expanduser("~")) + This may drop the performance""", + expanduser("~"), + ) def log_env_var(self, env_var_name=""): if env_var_name in os.environ: @@ -347,30 +381,40 @@ or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib6 if env_name not in os.environ: os.environ[env_name] = env_value elif os.environ[env_name] != env_value: - logger.warning("Overriding value with the one set in environment variable: %s. \ -Value applied: %s. Value ignored: %s", env_name, os.environ[env_name], env_value) + logger.warning( + "Overriding value with the one set in environment variable: %s. \ +Value applied: %s. Value ignored: %s", + env_name, + os.environ[env_name], + env_value, + ) self.log_env_var(env_name) # set_kmp_affinity is used to control whether to set KMP_AFFINITY or not. # In scenario that use all cores on all nodes, including logical cores, setting KMP_AFFINITY disables logical cores. # In this case, KMP_AFFINITY should not be set. - def set_multi_thread_and_allocator(self, ncores_per_instance, - disable_iomp=False, - set_kmp_affinity=True, - enable_tcmalloc=True, - enable_jemalloc=False, - use_default_allocator=False): + def set_multi_thread_and_allocator( + self, + ncores_per_instance, + disable_iomp=False, + set_kmp_affinity=True, + enable_tcmalloc=True, + enable_jemalloc=False, + use_default_allocator=False, + ): """ Set multi-thread configuration and enable Intel openMP and TCMalloc/JeMalloc. By default, GNU openMP and PTMalloc are used in PyTorch. but Intel openMP and TCMalloc/JeMalloc are better alternatives to get performance benefit. """ - self.set_memory_allocator(enable_tcmalloc, enable_jemalloc, use_default_allocator) + self.set_memory_allocator( + enable_tcmalloc, enable_jemalloc, use_default_allocator + ) self.set_env("OMP_NUM_THREADS", str(ncores_per_instance)) if not disable_iomp: find_iomp = self.add_lib_preload(lib_type="iomp5") if not find_iomp: - msg = f"{self.msg_lib_notfound} you can use \"conda install mkl\" to install {{0}}" + msg = f'{self.msg_lib_notfound} you can use "conda install mkl" to install {{0}}' logger.warning(msg.format("iomp", "iomp5")) # noqa: G001 else: logger.info("Using Intel OpenMP") @@ -382,6 +426,7 @@ Value applied: %s. Value ignored: %s", env_name, os.environ[env_name], env_value r""" Launcher for single instance and multi-instance """ + def launch(self, args): cores = [] set_kmp_affinity = True @@ -389,10 +434,19 @@ Value applied: %s. Value ignored: %s", env_name, os.environ[env_name], env_value if args.core_list: # user specify what cores will be used by params cores = [int(x) for x in args.core_list.split(",")] if args.ncores_per_instance == -1: - raise RuntimeError("please specify the \"--ncores-per-instance\" if you have pass the --core-list params") - elif args.ninstances > 1 and args.ncores_per_instance * args.ninstances < len(cores): - logger.warning("only first %s cores will be used, \ -but you specify %s cores in core_list", args.ncores_per_instance * args.ninstances, len(cores)) + raise RuntimeError( + 'please specify the "--ncores-per-instance" if you have pass the --core-list params' + ) + elif ( + args.ninstances > 1 + and args.ncores_per_instance * args.ninstances < len(cores) + ): + logger.warning( + "only first %s cores will be used, \ +but you specify %s cores in core_list", + args.ncores_per_instance * args.ninstances, + len(cores), + ) else: args.ninstances = len(cores) // args.ncores_per_instance @@ -410,15 +464,25 @@ but you specify %s cores in core_list", args.ncores_per_instance * args.ninstanc cores = self.cpuinfo.get_node_physical_cores(args.node_id) else: cores = self.cpuinfo.get_all_physical_cores() - if not args.multi_instance and args.ninstances == -1 and args.ncores_per_instance == -1: + if ( + not args.multi_instance + and args.ninstances == -1 + and args.ncores_per_instance == -1 + ): args.ninstances = 1 args.ncores_per_instance = len(cores) - elif args.multi_instance and args.ninstances == -1 and args.ncores_per_instance == -1: + elif ( + args.multi_instance + and args.ninstances == -1 + and args.ncores_per_instance == -1 + ): args.throughput_mode = True elif args.ncores_per_instance == -1 and args.ninstances != -1: if args.ninstances > len(cores): - raise RuntimeError(f"there are {len(cores)} total cores but you specify {args.ninstances} ninstances; \ -please make sure ninstances <= total_cores)") + raise RuntimeError( + f"there are {len(cores)} total cores but you specify {args.ninstances} ninstances; \ +please make sure ninstances <= total_cores)" + ) else: args.ncores_per_instance = len(cores) // args.ninstances elif args.ncores_per_instance != -1 and args.ninstances == -1: @@ -429,70 +493,99 @@ please make sure ninstances <= total_cores)") num_leftover_cores = ncore_per_node % args.ncores_per_instance if args.ncores_per_instance > ncore_per_node: # too many ncores_per_instance to skip cross-node cores - logger.warning("there are %s core(s) per socket, but you specify %s ncores_per_instance and \ + logger.warning( + "there are %s core(s) per socket, but you specify %s ncores_per_instance and \ skip_cross_node_cores. Please make sure --ncores-per-instance < core(s) per \ -socket", ncore_per_node, args.ncores_per_instance) +socket", + ncore_per_node, + args.ncores_per_instance, + ) exit(-1) elif num_leftover_cores == 0: # aren't any cross-node cores - logger.info('--skip-cross-node-cores is set, but there are no cross-node cores.') + logger.info( + "--skip-cross-node-cores is set, but there are no cross-node cores." + ) args.ninstances = len(cores) // args.ncores_per_instance else: # skip cross-node cores if args.ninstances != -1: - logger.warning('--skip-cross-node-cores is exclusive to --ninstances. --ninstances \ -won\'t take effect even if it is set explicitly.') + logger.warning( + "--skip-cross-node-cores is exclusive to --ninstances. --ninstances \ +won't take effect even if it is set explicitly." + ) i = 1 leftover_cores = set() while ncore_per_node * i <= len(cores): - leftover_cores.update(cores[ncore_per_node * i - num_leftover_cores : ncore_per_node * i]) + leftover_cores.update( + cores[ + ncore_per_node * i + - num_leftover_cores : ncore_per_node * i + ] + ) i += 1 cores = list(set(cores) - leftover_cores) assert len(cores) % args.ncores_per_instance == 0 args.ninstances = len(cores) // args.ncores_per_instance else: if args.ninstances * args.ncores_per_instance > len(cores): - raise RuntimeError("Please make sure ninstances * ncores_per_instance <= total_cores") + raise RuntimeError( + "Please make sure ninstances * ncores_per_instance <= total_cores" + ) if args.latency_mode: - logger.warning("--latency-mode is exclusive to --ninstances, --ncores-per-instance, --node-id and \ ---use-logical-core. They won't take effect even they are set explicitly.") + logger.warning( + "--latency-mode is exclusive to --ninstances, --ncores-per-instance, --node-id and \ +--use-logical-core. They won't take effect even they are set explicitly." + ) args.ncores_per_instance = 4 cores = self.cpuinfo.get_all_physical_cores() args.ninstances = len(cores) // args.ncores_per_instance if args.throughput_mode: - logger.warning("--throughput-mode is exclusive to --ninstances, --ncores-per-instance, --node-id and \ ---use-logical-core. They won't take effect even they are set explicitly.") + logger.warning( + "--throughput-mode is exclusive to --ninstances, --ncores-per-instance, --node-id and \ +--use-logical-core. They won't take effect even they are set explicitly." + ) args.ninstances = self.cpuinfo.node_nums cores = self.cpuinfo.get_all_physical_cores() args.ncores_per_instance = len(cores) // args.ninstances if args.ninstances > 1 and args.rank != -1: - logger.info("assigning %s cores for instance %s", args.ncores_per_instance, args.rank) + logger.info( + "assigning %s cores for instance %s", + args.ncores_per_instance, + args.rank, + ) if not args.disable_numactl: numactl_available = self.is_numactl_available() if not numactl_available: if not args.disable_taskset: - logger.warning("Core binding with numactl is not available. Disabling numactl and using taskset instead. \ - This may affect performance in multi-socket system; please use numactl if memory binding is needed.") + logger.warning( + "Core binding with numactl is not available. Disabling numactl and using taskset instead. \ + This may affect performance in multi-socket system; please use numactl if memory binding is needed." + ) args.disable_numactl = True enable_taskset = True else: - logger.warning("Core binding with numactl is not available, and --disable_taskset is set. \ - Please unset --disable_taskset to use taskset instead of numactl.") + logger.warning( + "Core binding with numactl is not available, and --disable_taskset is set. \ + Please unset --disable_taskset to use taskset instead of numactl." + ) exit(-1) if not args.disable_taskset: enable_taskset = True - self.set_multi_thread_and_allocator(args.ncores_per_instance, - args.disable_iomp, - set_kmp_affinity, - args.enable_tcmalloc, - args.enable_jemalloc, - args.use_default_allocator) + self.set_multi_thread_and_allocator( + args.ncores_per_instance, + args.disable_iomp, + set_kmp_affinity, + args.enable_tcmalloc, + args.enable_jemalloc, + args.use_default_allocator, + ) entrypoint = "" launch_args = {} launch_envs: Dict[int, Dict] = {} @@ -506,11 +599,20 @@ won\'t take effect even if it is set explicitly.') elif enable_taskset: cmd = ["taskset"] cores = sorted(cores) - if args.rank == -1: # sequentially assign ncores_per_instance to ninstances - core_list = cores[i * args.ncores_per_instance : (i + 1) * args.ncores_per_instance] + if ( + args.rank == -1 + ): # sequentially assign ncores_per_instance to ninstances + core_list = cores[ + i + * args.ncores_per_instance : (i + 1) + * args.ncores_per_instance + ] else: # assign ncores_per_instance from rank - core_list = cores[args.rank * args.ncores_per_instance - : (args.rank + 1) * args.ncores_per_instance] + core_list = cores[ + args.rank + * args.ncores_per_instance : (args.rank + 1) + * args.ncores_per_instance + ] core_ranges: List[Dict] = [] for core in core_list: @@ -528,7 +630,12 @@ won\'t take effect even if it is set explicitly.') cur_process_cores = cur_process_cores[:-1] if not args.disable_numactl: numa_params = f"-C {cur_process_cores} " - numa_ids = ",".join([str(numa_id) for numa_id in self.cpuinfo.numa_aware_check(core_list)]) + numa_ids = ",".join( + [ + str(numa_id) + for numa_id in self.cpuinfo.numa_aware_check(core_list) + ] + ) numa_params += f"-m {numa_ids}" cmd.extend(numa_params.split()) elif enable_taskset: @@ -554,94 +661,206 @@ won\'t take effect even if it is set explicitly.') if args.rank != -1: # launches single instance, rank, only break - ctx = start_processes(name=args.log_file_prefix, - entrypoint=entrypoint, - args=launch_args, - envs=launch_envs, - log_dir=args.log_path, - tee=launch_tee) + ctx = start_processes( + name=args.log_file_prefix, + entrypoint=entrypoint, + args=launch_args, + envs=launch_envs, + log_dir=args.log_path, + tee=launch_tee, + ) ctx.wait() def _add_memory_allocator_params(parser): - group = parser.add_argument_group("Memory Allocator Parameters") # allocator control - group.add_argument("--enable-tcmalloc", "--enable_tcmalloc", action="store_true", default=False, - help="Enable tcmalloc allocator") - group.add_argument("--enable-jemalloc", "--enable_jemalloc", action="store_true", default=False, - help="Enable jemalloc allocator") - group.add_argument("--use-default-allocator", "--use_default_allocator", action="store_true", default=False, - help="Use default memory allocator") + group.add_argument( + "--enable-tcmalloc", + "--enable_tcmalloc", + action="store_true", + default=False, + help="Enable tcmalloc allocator", + ) + group.add_argument( + "--enable-jemalloc", + "--enable_jemalloc", + action="store_true", + default=False, + help="Enable jemalloc allocator", + ) + group.add_argument( + "--use-default-allocator", + "--use_default_allocator", + action="store_true", + default=False, + help="Use default memory allocator", + ) + def _add_multi_instance_params(parser): - group = parser.add_argument_group("Multi-instance Parameters") # multi-instance control - group.add_argument("--ncores-per-instance", "--ncores_per_instance", metavar="\b", default=-1, type=int, - help="Cores per instance") - group.add_argument("--ninstances", metavar="\b", default=-1, type=int, - help="For multi-instance, you should give the cores number you used for per instance.") - group.add_argument("--skip-cross-node-cores", "--skip_cross_node_cores", action='store_true', default=False, - help="If specified --ncores-per-instance, skips cross-node cores.") - group.add_argument("--rank", metavar="\b", default="-1", type=int, - help="Specify instance index to assign ncores_per_instance for rank; \ + group.add_argument( + "--ncores-per-instance", + "--ncores_per_instance", + metavar="\b", + default=-1, + type=int, + help="Cores per instance", + ) + group.add_argument( + "--ninstances", + metavar="\b", + default=-1, + type=int, + help="For multi-instance, you should give the cores number you used for per instance.", + ) + group.add_argument( + "--skip-cross-node-cores", + "--skip_cross_node_cores", + action="store_true", + default=False, + help="If specified --ncores-per-instance, skips cross-node cores.", + ) + group.add_argument( + "--rank", + metavar="\b", + default="-1", + type=int, + help="Specify instance index to assign ncores_per_instance for rank; \ otherwise ncores_per_instance will be assigned sequentially to ninstances. Please refer to \ -https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md") - group.add_argument("--latency-mode", "--latency_mode", action="store_true", default=False, - help="By default 4 core per instance and use all physical cores") - group.add_argument("--throughput-mode", "--throughput_mode", action="store_true", default=False, - help="By default one instance per node and use all physical cores") - group.add_argument("--node-id", "--node_id", metavar="\b", default=-1, type=int, - help="node id for multi-instance, by default all nodes will be used") - group.add_argument("--use-logical-core", "--use_logical_core", action="store_true", default=False, - help="Whether only use physical cores") - group.add_argument("--disable-numactl", "--disable_numactl", action="store_true", default=False, - help="Disable numactl") - group.add_argument("--disable-taskset", "--disable_taskset", action="store_true", default=False, - help="Disable taskset") - group.add_argument("--core-list", "--core_list", metavar="\b", default=None, type=str, - help="Specify the core list as \"core_id, core_id, ....\", otherwise, all the cores will be used.") - group.add_argument("--log-path", "--log_path", metavar="\b", default="", type=str, - help="The log file directory. Default path is "", which means disable logging to files.") - group.add_argument("--log-file-prefix", "--log_file_prefix", metavar="\b", default="run", type=str, - help="log file prefix") +https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md", + ) + group.add_argument( + "--latency-mode", + "--latency_mode", + action="store_true", + default=False, + help="By default 4 core per instance and use all physical cores", + ) + group.add_argument( + "--throughput-mode", + "--throughput_mode", + action="store_true", + default=False, + help="By default one instance per node and use all physical cores", + ) + group.add_argument( + "--node-id", + "--node_id", + metavar="\b", + default=-1, + type=int, + help="node id for multi-instance, by default all nodes will be used", + ) + group.add_argument( + "--use-logical-core", + "--use_logical_core", + action="store_true", + default=False, + help="Whether only use physical cores", + ) + group.add_argument( + "--disable-numactl", + "--disable_numactl", + action="store_true", + default=False, + help="Disable numactl", + ) + group.add_argument( + "--disable-taskset", + "--disable_taskset", + action="store_true", + default=False, + help="Disable taskset", + ) + group.add_argument( + "--core-list", + "--core_list", + metavar="\b", + default=None, + type=str, + help='Specify the core list as "core_id, core_id, ....", otherwise, all the cores will be used.', + ) + group.add_argument( + "--log-path", + "--log_path", + metavar="\b", + default="", + type=str, + help="The log file directory. Default path is " + ", which means disable logging to files.", + ) + group.add_argument( + "--log-file-prefix", + "--log_file_prefix", + metavar="\b", + default="run", + type=str, + help="log file prefix", + ) + def _add_kmp_iomp_params(parser): - group = parser.add_argument_group("IOMP Parameters") - group.add_argument("--disable-iomp", "--disable_iomp", action="store_true", default=False, - help="By default, we use Intel OpenMP and libiomp5.so will be add to LD_PRELOAD") + group.add_argument( + "--disable-iomp", + "--disable_iomp", + action="store_true", + default=False, + help="By default, we use Intel OpenMP and libiomp5.so will be add to LD_PRELOAD", + ) + def create_args(parser=None): """ Helper function parsing the command line options @retval ArgumentParser """ - parser.add_argument("--multi-instance", "--multi_instance", action="store_true", default=False, - help="Enable multi-instance, by default one instance per node") + parser.add_argument( + "--multi-instance", + "--multi_instance", + action="store_true", + default=False, + help="Enable multi-instance, by default one instance per node", + ) - parser.add_argument("-m", "--module", default=False, action="store_true", - help="Changes each process to interpret the launch script " - "as a python module, executing with the same behavior as" - "\"python -m\".") + parser.add_argument( + "-m", + "--module", + default=False, + action="store_true", + help="Changes each process to interpret the launch script " + "as a python module, executing with the same behavior as" + '"python -m".', + ) - parser.add_argument("--no-python", "--no_python", default=False, action="store_true", - help="Do not prepend the --program script with \"python\" - just exec " - "it directly. Useful when the script is not a Python script.") + parser.add_argument( + "--no-python", + "--no_python", + default=False, + action="store_true", + help='Do not prepend the --program script with "python" - just exec ' + "it directly. Useful when the script is not a Python script.", + ) _add_memory_allocator_params(parser) _add_kmp_iomp_params(parser) _add_multi_instance_params(parser) # positional - parser.add_argument("program", type=str, - help="The full path to the program/script to be launched. " - "followed by all the arguments for the script") + parser.add_argument( + "program", + type=str, + help="The full path to the program/script to be launched. " + "followed by all the arguments for the script", + ) # rest from the training program parser.add_argument("program_args", nargs=REMAINDER) + def main(args): env_before = set(os.environ.keys()) if platform.system() in ["Windows", "Darwin"]: @@ -653,10 +872,14 @@ def main(args): args.log_path = os.devnull if args.latency_mode and args.throughput_mode: - raise RuntimeError("Either args.latency_mode or args.throughput_mode should be set") + raise RuntimeError( + "Either args.latency_mode or args.throughput_mode should be set" + ) if not args.no_python and not args.program.endswith(".py"): - raise RuntimeError("For non Python script, you should use \"--no-python\" parameter.") + raise RuntimeError( + 'For non Python script, you should use "--no-python" parameter.' + ) # Verify LD_PRELOAD if "LD_PRELOAD" in os.environ: @@ -678,22 +901,25 @@ def main(args): for x in sorted(set(os.environ.keys()) - env_before): logger.debug("%s=%s", x, os.environ[x]) + if __name__ == "__main__": - parser = ArgumentParser(description="This is a script for launching PyTorch inference on Intel(R) Xeon(R) Scalable " - "Processors with optimal configurations. Single instance inference, " - "multi-instance inference are enable. To get the peak performance on Intel(R) " - "Xeon(R) Scalable Processors, the script optimizes the configuration " - "of thread and memory management. For thread management, the script configures thread " - "affinity and the preload of Intel OMP library. For memory management, it configures " - "NUMA binding and preload optimized memory allocation library (e.g. tcmalloc, jemalloc) " - "\n################################# Basic usage ############################# \n" - "\n 1. single instance\n" - "\n >>> python -m torch.backends.xeon.run_cpu python_script args \n" - "\n2. multi-instance \n" - "\n >>> python -m torch.backends.xeon.run_cpu --ninstances xxx " - "--ncores-per-instance xx python_script args\n" - "\n############################################################################# \n", - formatter_class=RawTextHelpFormatter) + parser = ArgumentParser( + description="This is a script for launching PyTorch inference on Intel(R) Xeon(R) Scalable " + "Processors with optimal configurations. Single instance inference, " + "multi-instance inference are enable. To get the peak performance on Intel(R) " + "Xeon(R) Scalable Processors, the script optimizes the configuration " + "of thread and memory management. For thread management, the script configures thread " + "affinity and the preload of Intel OMP library. For memory management, it configures " + "NUMA binding and preload optimized memory allocation library (e.g. tcmalloc, jemalloc) " + "\n################################# Basic usage ############################# \n" + "\n 1. single instance\n" + "\n >>> python -m torch.backends.xeon.run_cpu python_script args \n" + "\n2. multi-instance \n" + "\n >>> python -m torch.backends.xeon.run_cpu --ninstances xxx " + "--ncores-per-instance xx python_script args\n" + "\n############################################################################# \n", + formatter_class=RawTextHelpFormatter, + ) create_args(parser) args = parser.parse_args() main(args) diff --git a/torch/backends/xnnpack/__init__.py b/torch/backends/xnnpack/__init__.py index 17c7f15b355b..c26dc11deb47 100644 --- a/torch/backends/xnnpack/__init__.py +++ b/torch/backends/xnnpack/__init__.py @@ -1,7 +1,9 @@ import sys -import torch import types +import torch + + class _XNNPACKEnabled: def __get__(self, obj, objtype): return torch._C._is_xnnpack_enabled() @@ -9,6 +11,7 @@ class _XNNPACKEnabled: def __set__(self, obj, val): raise RuntimeError("Assignment not supported") + class XNNPACKEngine(types.ModuleType): def __init__(self, m, name): super().__init__(name) @@ -19,6 +22,7 @@ class XNNPACKEngine(types.ModuleType): enabled = _XNNPACKEnabled() + # This is the sys.modules replacement trick, see # https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 sys.modules[__name__] = XNNPACKEngine(sys.modules[__name__], __name__) diff --git a/torch/cpu/__init__.py b/torch/cpu/__init__.py index 8845deaab572..86c788b695ab 100644 --- a/torch/cpu/__init__.py +++ b/torch/cpu/__init__.py @@ -3,11 +3,12 @@ This package implements abstractions found in ``torch.cuda`` to facilitate writing device-agnostic code. """ -from typing import Any, Optional, Union from contextlib import AbstractContextManager -from . import amp -from .. import device as _device +from typing import Any, Optional, Union + import torch +from .. import device as _device +from . import amp __all__ = [ "is_available", @@ -16,11 +17,12 @@ __all__ = [ "stream", "device_count", "Stream", - "StreamContext" + "StreamContext", ] _device_t = Union[_device, str, int, None] + def _is_cpu_support_vnni() -> bool: r"""Returns a bool indicating if CPU supports VNNI.""" return torch._C._cpu._is_cpu_support_vnni() @@ -34,6 +36,7 @@ def is_available() -> bool: """ return True + def synchronize(device: _device_t = None) -> None: r"""Waits for all kernels in all streams on the CPU device to complete. @@ -44,15 +47,19 @@ def synchronize(device: _device_t = None) -> None: """ pass + class Stream: """ N.B. This class only exists to facilitate device-agnostic code """ + pass + _default_cpu_stream = Stream() _current_stream = _default_cpu_stream + def current_stream(device: _device_t = None) -> Stream: r"""Returns the currently selected :class:`Stream` for a given device. @@ -64,13 +71,14 @@ def current_stream(device: _device_t = None) -> Stream: """ return _current_stream + class StreamContext(AbstractContextManager): r"""Context-manager that selects a given stream. N.B. This class only exists to facilitate device-agnostic code """ - cur_stream : Optional[Stream] + cur_stream: Optional[Stream] def __init__(self, stream): self.stream = stream @@ -93,6 +101,7 @@ class StreamContext(AbstractContextManager): global _current_stream _current_stream = self.prev_stream + def stream(stream: Stream) -> AbstractContextManager: r"""Wrapper around the Context-manager StreamContext that selects a given stream. @@ -101,6 +110,7 @@ def stream(stream: Stream) -> AbstractContextManager: """ return StreamContext(stream) + def device_count() -> int: r"""Returns number of CPU devices (not cores). Always 1. diff --git a/torch/cpu/amp/autocast_mode.py b/torch/cpu/amp/autocast_mode.py index 0909a0bcd556..a29a96891722 100644 --- a/torch/cpu/amp/autocast_mode.py +++ b/torch/cpu/amp/autocast_mode.py @@ -1,20 +1,30 @@ -import torch from typing import Any +import torch + __all__ = ["autocast"] + class autocast(torch.amp.autocast_mode.autocast): r""" See :class:`torch.autocast`. ``torch.cpu.amp.autocast(args...)`` is equivalent to ``torch.autocast("cpu", args...)`` """ - def __init__(self, enabled : bool = True, dtype : torch.dtype = torch.bfloat16, cache_enabled : bool = True): + + def __init__( + self, + enabled: bool = True, + dtype: torch.dtype = torch.bfloat16, + cache_enabled: bool = True, + ): if torch._jit_internal.is_scripting(): self._enabled = enabled self.device = "cpu" self.fast_dtype = dtype return - super().__init__("cpu", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) + super().__init__( + "cpu", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled + ) def __enter__(self): if torch._jit_internal.is_scripting(): diff --git a/torch/csrc/jit/tensorexpr/codegen_external.py b/torch/csrc/jit/tensorexpr/codegen_external.py index 120520b139cd..bc69b05162f6 100644 --- a/torch/csrc/jit/tensorexpr/codegen_external.py +++ b/torch/csrc/jit/tensorexpr/codegen_external.py @@ -1,15 +1,19 @@ #!/usr/bin/env python3 import argparse -from torchgen.gen import parse_native_yaml, FileManager + import torchgen.model as model +from torchgen.gen import FileManager, parse_native_yaml + def num_leading_spaces(line: str) -> int: return len(line) - len(line.lstrip()) + + def deindent(code: str) -> str: - lines = code.split('\n') + lines = code.split("\n") min_leading_spaces = min(map(num_leading_spaces, lines)) lines = [line[min_leading_spaces:] for line in lines] - return '\n'.join(lines) + return "\n".join(lines) def gen_external(native_functions_path, tags_path, external_path): @@ -29,11 +33,21 @@ def gen_external(native_functions_path, tags_path, external_path): continue # Doesn't currently support kwarg arguments - if len(args.pre_tensor_options_kwarg_only) > 0 or len(args.post_tensor_options_kwarg_only) > 0: + if ( + len(args.pre_tensor_options_kwarg_only) > 0 + or len(args.post_tensor_options_kwarg_only) > 0 + ): continue self_arg = [args.self_arg.argument] if args.self_arg is not None else [] - args = list(args.pre_self_positional) + self_arg + list(args.post_self_positional) - tensor_args = [arg for arg in args if isinstance(arg.type, model.BaseType) and arg.type.name == model.BaseTy.Tensor] + args = ( + list(args.pre_self_positional) + self_arg + list(args.post_self_positional) + ) + tensor_args = [ + arg + for arg in args + if isinstance(arg.type, model.BaseType) + and arg.type.name == model.BaseTy.Tensor + ] if len(tensor_args) != len(args): continue @@ -44,7 +58,7 @@ def gen_external(native_functions_path, tags_path, external_path): s = f"const at::Tensor& {arg.name} = tensors[{idx + 1}];" tensor_decls.append(s) arg_names[idx] = arg.name - nl = '\n' + nl = "\n" # print(tensor_decls, name, arg_names) func_decl = f"""\ @@ -72,27 +86,39 @@ const static RegisterNNCExternalFunction nnc_{name}( nnc_aten_{name});""" func_decls.append(func_decl) func_registrations.append(func_registration) - fm = FileManager(install_dir='.', template_dir='.', dry_run=False) - fm.write_with_template('external_functions_codegen.cpp', external_path, - lambda: {'external_registrations': func_registrations, 'external_functions': func_decls}) + fm = FileManager(install_dir=".", template_dir=".", dry_run=False) + fm.write_with_template( + "external_functions_codegen.cpp", + external_path, + lambda: { + "external_registrations": func_registrations, + "external_functions": func_decls, + }, + ) def main() -> None: - parser = argparse.ArgumentParser( - description='Generate annotated_fn_args script') - parser.add_argument('--native-functions', - '--native_functions', - help='path to native_functions.yaml', - default='../../../../aten/src/ATen/native/native_functions.yaml') - parser.add_argument('--tags', - help='path to tags.yaml', - default='../../../../aten/src/ATen/native/tags.yaml') - parser.add_argument('--template-path', - '--template_path', - help='path to external_functions_codegen_template.cpp', - default='../../../../tools/jit/templates/external_functions_codegen_template.cpp') + parser = argparse.ArgumentParser(description="Generate annotated_fn_args script") + parser.add_argument( + "--native-functions", + "--native_functions", + help="path to native_functions.yaml", + default="../../../../aten/src/ATen/native/native_functions.yaml", + ) + parser.add_argument( + "--tags", + help="path to tags.yaml", + default="../../../../aten/src/ATen/native/tags.yaml", + ) + parser.add_argument( + "--template-path", + "--template_path", + help="path to external_functions_codegen_template.cpp", + default="../../../../tools/jit/templates/external_functions_codegen_template.cpp", + ) args = parser.parse_args() gen_external(args.native_functions, args.tags, args.template_path) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/torch/csrc/jit/tensorexpr/scripts/bisect.py b/torch/csrc/jit/tensorexpr/scripts/bisect.py index c2549b4b38e1..4a99d66d10e4 100644 --- a/torch/csrc/jit/tensorexpr/scripts/bisect.py +++ b/torch/csrc/jit/tensorexpr/scripts/bisect.py @@ -1,4 +1,5 @@ import subprocess + import click diff --git a/torch/csrc/lazy/test_mnist.py b/torch/csrc/lazy/test_mnist.py index e5c0ecb12c77..e0ff82eed50a 100644 --- a/torch/csrc/lazy/test_mnist.py +++ b/torch/csrc/lazy/test_mnist.py @@ -1,13 +1,15 @@ +import os + import torch +import torch._lazy +import torch._lazy.metrics +import torch._lazy.ts_backend import torch.nn as nn import torch.nn.functional as F import torch.optim as optim -import os -from torchvision import datasets, transforms from torch.optim.lr_scheduler import StepLR -import torch._lazy -import torch._lazy.ts_backend -import torch._lazy.metrics +from torchvision import datasets, transforms + torch._lazy.ts_backend.init() @@ -49,33 +51,39 @@ def train(log_interval, model, device, train_loader, optimizer, epoch): torch._lazy.mark_step() if batch_idx % log_interval == 0: - print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( - epoch, batch_idx * len(data), len(train_loader.dataset), - 100. * batch_idx / len(train_loader), loss.item())) + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) -if __name__ == '__main__': +if __name__ == "__main__": bsz = 64 - device = 'lazy' + device = "lazy" epochs = 14 log_interval = 10 lr = 1 gamma = 0.7 - train_kwargs = {'batch_size': bsz} + train_kwargs = {"batch_size": bsz} # if we want to use CUDA if "LTC_TS_CUDA" in os.environ: - cuda_kwargs = {'num_workers': 1, - 'pin_memory': True, - 'shuffle': True, - 'batch_size': bsz} + cuda_kwargs = { + "num_workers": 1, + "pin_memory": True, + "shuffle": True, + "batch_size": bsz, + } train_kwargs.update(cuda_kwargs) - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ]) - dataset1 = datasets.MNIST('./data', train=True, download=True, - transform=transform) + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ) + dataset1 = datasets.MNIST("./data", train=True, download=True, transform=transform) train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs) model = Net().to(device) optimizer = optim.Adadelta(model.parameters(), lr=lr) diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 98e0cb54e9f2..05820e321a95 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -9,23 +9,29 @@ It is lazily initialized, so you can always import it, and use """ import contextlib +import importlib import os import sys -import importlib -import torch -from torch.types import Device +import threading import traceback import warnings -import threading from functools import lru_cache -from typing import Any, List, Optional, Tuple, Union, cast -from ._utils import _get_device_index, _dummy_type -from .._utils import classproperty -from .graphs import CUDAGraph, graph_pool_handle, graph, \ - make_graphed_callables, is_current_stream_capturing -from .streams import ExternalStream, Stream, Event -from .. import device as _device +from typing import Any, cast, List, Optional, Tuple, Union + +import torch import torch._C +from torch.types import Device +from .. import device as _device +from .._utils import classproperty +from ._utils import _dummy_type, _get_device_index +from .graphs import ( + CUDAGraph, + graph, + graph_pool_handle, + is_current_stream_capturing, + make_graphed_callables, +) +from .streams import Event, ExternalStream, Stream try: from torch._C import _cudart # type: ignore[attr-defined] @@ -43,12 +49,12 @@ _HAS_PYNVML = False _PYNVML_ERR = None try: import pynvml # type: ignore[import] + _HAS_PYNVML = True except ImportError as err: _PYNVML_ERR = err # sometimes a lib is installed but the import fails for some other reason, so we log the error for later - class _LazySeedTracker: # Since seeding is memory-less, only track the latest seed. # Note: `manual_seed_all` followed by `manual_seed` overwrites @@ -76,22 +82,25 @@ class _LazySeedTracker: _lazy_seed_tracker = _LazySeedTracker() # Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA -if hasattr(torch._C, '_CudaDeviceProperties'): +if hasattr(torch._C, "_CudaDeviceProperties"): _CudaDeviceProperties = torch._C._CudaDeviceProperties else: - _CudaDeviceProperties = _dummy_type('_CudaDeviceProperties') # type: ignore[assignment, misc] + _CudaDeviceProperties = _dummy_type("_CudaDeviceProperties") # type: ignore[assignment, misc] -if hasattr(torch._C, '_cuda_exchangeDevice'): +if hasattr(torch._C, "_cuda_exchangeDevice"): _exchange_device = torch._C._cuda_exchangeDevice else: + def _exchange_device(device: int) -> int: if device < 0: return -1 raise RuntimeError("PyTorch was compiled without CUDA support") -if hasattr(torch._C, '_cuda_maybeExchangeDevice'): + +if hasattr(torch._C, "_cuda_maybeExchangeDevice"): _maybe_exchange_device = torch._C._cuda_maybeExchangeDevice else: + def _maybe_exchange_device(device: int) -> int: if device < 0: return -1 @@ -103,12 +112,15 @@ has_magma: bool = False has_half: bool = False default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment] + def _is_compiled() -> bool: r"""Returns true if compile with CUDA support.""" - return hasattr(torch._C, '_cuda_getDeviceCount') + return hasattr(torch._C, "_cuda_getDeviceCount") + def _nvml_based_avail() -> bool: - return os.getenv('PYTORCH_NVML_BASED_CUDA_CHECK') == '1' + return os.getenv("PYTORCH_NVML_BASED_CUDA_CHECK") == "1" + def is_available() -> bool: r"""Returns a bool indicating if CUDA is currently available.""" @@ -135,10 +147,14 @@ def is_bf16_supported(): cu_vers = torch.version.cuda if cu_vers is not None: - cuda_maj_decide = int(cu_vers.split('.')[0]) >= 11 + cuda_maj_decide = int(cu_vers.split(".")[0]) >= 11 else: cuda_maj_decide = False - return torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8 and cuda_maj_decide + return ( + torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8 + and cuda_maj_decide + ) + def _sleep(cycles): torch._C._cuda_sleep(cycles) @@ -166,9 +182,16 @@ def _check_capability(): minor = capability[1] name = get_device_name(d) current_arch = major * 10 + minor - min_arch = min((int(arch.split("_")[1]) for arch in torch.cuda.get_arch_list()), default=35) + min_arch = min( + (int(arch.split("_")[1]) for arch in torch.cuda.get_arch_list()), + default=35, + ) if current_arch < min_arch: - warnings.warn(old_gpu_warn % (d, name, major, minor, min_arch // 10, min_arch % 10)) + warnings.warn( + old_gpu_warn + % (d, name, major, minor, min_arch // 10, min_arch % 10) + ) + def _check_cubins(): incompatible_device_warn = """ @@ -181,7 +204,7 @@ If you want to use the {} GPU with PyTorch, please check the instructions at htt arch_list = get_arch_list() if len(arch_list) == 0: return - supported_sm = [int(arch.split('_')[1]) for arch in arch_list if 'sm_' in arch] + supported_sm = [int(arch.split("_")[1]) for arch in arch_list if "sm_" in arch] for idx in range(device_count()): cap_major, cap_minor = get_device_capability(idx) # NVIDIA GPU compute architectures are backward compatible within major version @@ -189,7 +212,11 @@ If you want to use the {} GPU with PyTorch, please check the instructions at htt if not supported: device_name = get_device_name(idx) capability = cap_major * 10 + cap_minor - warnings.warn(incompatible_device_warn.format(device_name, capability, " ".join(arch_list), device_name)) + warnings.warn( + incompatible_device_warn.format( + device_name, capability, " ".join(arch_list), device_name + ) + ) def is_initialized(): @@ -213,6 +240,7 @@ def _lazy_call(callable, **kwargs): # Don't store the actual traceback to avoid memory cycle _queued_calls.append((callable, traceback.format_stack())) + _lazy_call(_check_capability) _lazy_call(_check_cubins) @@ -220,8 +248,10 @@ _lazy_call(_check_cubins) class DeferredCudaCallError(Exception): pass + OutOfMemoryError = torch._C._OutOfMemoryError + def init(): r"""Initialize PyTorch's CUDA state. You may need to call this explicitly if you are interacting with PyTorch via @@ -237,7 +267,7 @@ def init(): def _lazy_init(): global _initialized, _queued_calls - if is_initialized() or hasattr(_tls, 'is_initializing'): + if is_initialized() or hasattr(_tls, "is_initializing"): return with _initialization_lock: # We be double-checked locking, boys! This is OK because @@ -253,16 +283,18 @@ def _lazy_init(): if _is_in_bad_fork(): raise RuntimeError( "Cannot re-initialize CUDA in forked subprocess. To use CUDA with " - "multiprocessing, you must use the 'spawn' start method") - if not hasattr(torch._C, '_cuda_getDeviceCount'): + "multiprocessing, you must use the 'spawn' start method" + ) + if not hasattr(torch._C, "_cuda_getDeviceCount"): raise AssertionError("Torch not compiled with CUDA enabled") if _cudart is None: raise AssertionError( - "libcudart functions unavailable. It looks like you have a broken build?") + "libcudart functions unavailable. It looks like you have a broken build?" + ) # This function throws if there's a driver initialization error, no GPUs # are found or any other error occurs - if 'CUDA_MODULE_LOADING' not in os.environ: - os.environ['CUDA_MODULE_LOADING'] = 'LAZY' + if "CUDA_MODULE_LOADING" not in os.environ: + os.environ["CUDA_MODULE_LOADING"] = "LAZY" torch._C._cuda_init() # Some of the queued calls may reentrantly call _lazy_init(); # we need to just return without initializing in that case. @@ -278,11 +310,13 @@ def _lazy_init(): try: queued_call() except Exception as e: - msg = (f"CUDA call failed lazily at initialization with error: {str(e)}\n\n" - f"CUDA call was originally invoked at:\n\n{orig_traceback}") + msg = ( + f"CUDA call failed lazily at initialization with error: {str(e)}\n\n" + f"CUDA call was originally invoked at:\n\n{orig_traceback}" + ) raise DeferredCudaCallError(msg) from e finally: - delattr(_tls, 'is_initializing') + delattr(_tls, "is_initializing") _initialized = True @@ -295,10 +329,11 @@ class cudaStatus: SUCCESS: int = 0 ERROR_NOT_READY: int = 34 + class CudaError(RuntimeError): def __init__(self, code: int) -> None: msg = _cudart.cudaGetErrorString(_cudart.cudaError(code)) - super().__init__(f'{msg} ({code})') + super().__init__(f"{msg} ({code})") def check_error(res: int) -> None: @@ -417,9 +452,9 @@ def get_device_properties(device: _device_t) -> _CudaDeviceProperties: raise AssertionError("Invalid device id") return _get_device_properties(device) # type: ignore[name-defined] + def can_device_access_peer(device: _device_t, peer_device: _device_t) -> bool: - r"""Checks if peer access between two devices is possible. - """ + r"""Checks if peer access between two devices is possible.""" _lazy_init() device = _get_device_index(device, optional=True) peer_device = _get_device_index(peer_device) @@ -441,17 +476,21 @@ class StreamContext: ``None``. .. note:: Streams are per-device. """ - cur_stream : Optional['torch.cuda.Stream'] + cur_stream: Optional["torch.cuda.Stream"] - def __init__(self, stream: Optional['torch.cuda.Stream']): + def __init__(self, stream: Optional["torch.cuda.Stream"]): self.stream = stream self.idx = _get_device_index(None, True) if not torch.jit.is_scripting(): if self.idx is None: self.idx = -1 - self.src_prev_stream = None if not torch.jit.is_scripting() else torch.cuda.default_stream(None) - self.dst_prev_stream = None if not torch.jit.is_scripting() else torch.cuda.default_stream(None) + self.src_prev_stream = ( + None if not torch.jit.is_scripting() else torch.cuda.default_stream(None) + ) + self.dst_prev_stream = ( + None if not torch.jit.is_scripting() else torch.cuda.default_stream(None) + ) def __enter__(self): # Local cur_stream variable for type refinement @@ -481,7 +520,8 @@ class StreamContext: torch.cuda.set_stream(self.dst_prev_stream) # type: ignore[arg-type] torch.cuda.set_stream(self.src_prev_stream) # type: ignore[arg-type] -def stream(stream: Optional['torch.cuda.Stream']) -> StreamContext: + +def stream(stream: Optional["torch.cuda.Stream"]) -> StreamContext: r"""Wrapper around the Context-manager StreamContext that selects a given stream. @@ -493,6 +533,7 @@ def stream(stream: Optional['torch.cuda.Stream']) -> StreamContext: """ return StreamContext(stream) + def set_stream(stream: Stream): r"""Sets the current stream.This is a wrapper API to set the stream. Usage of this function is discouraged in favor of the ``stream`` @@ -504,7 +545,11 @@ def set_stream(stream: Stream): """ if stream is None: return - torch._C._cuda_setStream(stream_id=stream.stream_id, device_index=stream.device_index, device_type=stream.device_type) + torch._C._cuda_setStream( + stream_id=stream.stream_id, + device_index=stream.device_index, + device_type=stream.device_type, + ) def _parse_visible_devices() -> Union[List[int], List[str]]: @@ -518,7 +563,7 @@ def _parse_visible_devices() -> Union[List[int], List[str]]: if not s: return -1 for idx, c in enumerate(s): - if not (c.isdigit() or (idx == 0 and c in '+-')): + if not (c.isdigit() or (idx == 0 and c in "+-")): break if idx + 1 == len(s): idx += 1 @@ -558,7 +603,8 @@ def _parse_visible_devices() -> Union[List[int], List[str]]: def _raw_device_count_nvml() -> int: """Return number of devices as reported by NVML or negative value if NVML discovery/initialization failed.""" - from ctypes import CDLL, c_int, byref + from ctypes import byref, c_int, CDLL + nvml_h = CDLL("libnvidia-ml.so.1") rc = nvml_h.nvmlInit() if rc != 0: @@ -576,7 +622,8 @@ def _raw_device_count_nvml() -> int: def _raw_device_uuid_nvml() -> Optional[List[str]]: """Return list of device UUID as reported by NVML or None if NVM discovery/initialization failed.""" - from ctypes import CDLL, c_int, c_void_p, create_string_buffer, byref + from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer + nvml_h = CDLL("libnvidia-ml.so.1") rc = nvml_h.nvmlInit() if rc != 0: @@ -600,7 +647,7 @@ def _raw_device_uuid_nvml() -> Optional[List[str]]: if rc != 0: warnings.warn("Can't get device UUID") return None - uuids.append(buf.raw.decode("ascii").strip('\0')) + uuids.append(buf.raw.decode("ascii").strip("\0")) del nvml_h return uuids @@ -608,6 +655,7 @@ def _raw_device_uuid_nvml() -> Optional[List[str]]: def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List[int]: """Given the set of partial uuids and list of known uuids builds a set of ordinals excluding ambiguous partials IDs""" + def uuid_to_orinal(candidate: str, uuids: List[str]) -> int: best_match = -1 for idx, uuid in enumerate(uuids): @@ -646,7 +694,9 @@ def _device_count_nvml() -> int: uuids = _raw_device_uuid_nvml() if uuids is None: return -1 - visible_devices = _transform_uuid_to_ordinals(cast(List[str], visible_devices), uuids) + visible_devices = _transform_uuid_to_ordinals( + cast(List[str], visible_devices), uuids + ) else: raw_cnt = _raw_device_count_nvml() if raw_cnt <= 0: @@ -661,6 +711,7 @@ def _device_count_nvml() -> int: return -1 return len(visible_devices) + def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int: r"""Returns the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account.""" idx = _get_device_index(device, optional=True) @@ -669,12 +720,17 @@ def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int: uuids = _raw_device_uuid_nvml() if uuids is None: raise RuntimeError("Can't get device UUIDs") - visible_devices = _transform_uuid_to_ordinals(cast(List[str], visible_devices), uuids) + visible_devices = _transform_uuid_to_ordinals( + cast(List[str], visible_devices), uuids + ) idx_map = dict(enumerate(cast(List[int], visible_devices))) if idx not in idx_map: - raise RuntimeError(f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})") + raise RuntimeError( + f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})" + ) return idx_map[idx] + @lru_cache(maxsize=1) def device_count() -> int: r"""Returns the number of GPUs available.""" @@ -683,6 +739,7 @@ def device_count() -> int: nvml_count = _device_count_nvml() return torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count + def get_arch_list() -> List[str]: r"""Returns list CUDA architectures this library was compiled for.""" if not is_available(): @@ -692,14 +749,19 @@ def get_arch_list() -> List[str]: return [] return arch_flags.split() + def get_gencode_flags() -> str: r"""Returns NVCC gencode flags this library was compiled with.""" arch_list = get_arch_list() if len(arch_list) == 0: return "" arch_list_ = [arch.split("_") for arch in arch_list] - return " ".join([f"-gencode compute=compute_{arch},code={kind}_{arch}" for (kind, arch) in arch_list_]) - + return " ".join( + [ + f"-gencode compute=compute_{arch},code={kind}_{arch}" + for (kind, arch) in arch_list_ + ] + ) def current_device() -> int: @@ -745,8 +807,11 @@ def current_stream(device: Optional[_device_t] = None) -> Stream: """ _lazy_init() streamdata = torch._C._cuda_getCurrentStream( - _get_device_index(device, optional=True)) - return Stream(stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]) + _get_device_index(device, optional=True) + ) + return Stream( + stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2] + ) def default_stream(device: Optional[_device_t] = None) -> Stream: @@ -760,8 +825,11 @@ def default_stream(device: Optional[_device_t] = None) -> Stream: """ _lazy_init() streamdata = torch._C._cuda_getDefaultStream( - _get_device_index(device, optional=True)) - return Stream(stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]) + _get_device_index(device, optional=True) + ) + return Stream( + stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2] + ) def current_blas_handle(): @@ -769,6 +837,7 @@ def current_blas_handle(): _lazy_init() return torch._C._cuda_getCurrentBlasHandle() + def set_sync_debug_mode(debug_mode: Union[int, str]) -> None: r"""Sets the debug mode for cuda synchronizing operations. @@ -790,10 +859,13 @@ def set_sync_debug_mode(debug_mode: Union[int, str]) -> None: elif debug_mode == "error": debug_mode = 2 else: - raise RuntimeError("invalid value of debug_mode, expected one of `default`, `warn`, `error`") + raise RuntimeError( + "invalid value of debug_mode, expected one of `default`, `warn`, `error`" + ) torch._C._cuda_set_sync_debug_mode(debug_mode) + def get_sync_debug_mode() -> int: r"""Returns current value of debug mode for cuda synchronizing operations.""" @@ -803,8 +875,11 @@ def get_sync_debug_mode() -> int: def _get_pynvml_handler(device: Optional[Union[Device, int]] = None): if not _HAS_PYNVML: - raise ModuleNotFoundError("pynvml does not seem to be installed or it can't be imported.") from _PYNVML_ERR + raise ModuleNotFoundError( + "pynvml does not seem to be installed or it can't be imported." + ) from _PYNVML_ERR from pynvml import NVMLError_DriverNotLoaded + try: pynvml.nvmlInit() except NVMLError_DriverNotLoaded as e: @@ -814,6 +889,7 @@ def _get_pynvml_handler(device: Optional[Union[Device, int]] = None): handle = pynvml.nvmlDeviceGetHandleByIndex(device) return handle + def memory_usage(device: Optional[Union[Device, int]] = None) -> int: r"""Returns the percent of time over the past sample period during which global (device) memory was being read or written. as given by `nvidia-smi`. @@ -851,6 +927,7 @@ def utilization(device: Optional[Union[Device, int]] = None) -> int: handle = pynvml.nvmlDeviceGetHandleByIndex(device) return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu + def temperature(device: Optional[Union[Device, int]] = None) -> int: r"""Returns the average temperature of the GPU sensor in Degrees C (Centigrades) over the past sample period as given by `nvidia-smi`. @@ -867,6 +944,7 @@ def temperature(device: Optional[Union[Device, int]] = None) -> int: # 0 refers to the temperature sensor for the GPU die. return pynvml.nvmlDeviceGetTemperature(handle, 0) + def power_draw(device: Optional[Union[Device, int]] = None) -> int: r"""Returns the average power draw of the GPU sensor in mW (MilliWatts) over the past sample period as given by `nvidia-smi` for Fermi or newer fully supported devices. @@ -882,6 +960,7 @@ def power_draw(device: Optional[Union[Device, int]] = None) -> int: handle = _get_pynvml_handler(device) return pynvml.nvmlDeviceGetPowerUsage(handle) + def clock_rate(device: Optional[Union[Device, int]] = None) -> int: r"""Returns the clock speed of the GPU SM in Hz Hertz over the past sample period as given by `nvidia-smi`. @@ -897,8 +976,6 @@ def clock_rate(device: Optional[Union[Device, int]] = None) -> int: return pynvml.nvmlDeviceGetClockInfo(handle, 1) - - def _get_device(device: Union[int, str, torch.device]) -> torch.device: r"""Return the torch.device type object from the passed in device. @@ -908,7 +985,7 @@ def _get_device(device: Union[int, str, torch.device]) -> torch.device: if isinstance(device, str): device = torch.device(device) elif isinstance(device, int): - device = torch.device('cuda', device) + device = torch.device("cuda", device) return device @@ -925,7 +1002,9 @@ def _get_generator(device: torch.device) -> torch._C.Generator: return torch.cuda.default_generators[idx] -def _set_rng_state_offset(offset: int, device: Union[int, str, torch.device] = 'cuda') -> None: +def _set_rng_state_offset( + offset: int, device: Union[int, str, torch.device] = "cuda" +) -> None: r"""Sets the random number generator state offset of the specified GPU. Args: @@ -941,7 +1020,8 @@ def _set_rng_state_offset(offset: int, device: Union[int, str, torch.device] = ' _lazy_call(cb) -def _get_rng_state_offset(device: Union[int, str, torch.device] = 'cuda') -> int: + +def _get_rng_state_offset(device: Union[int, str, torch.device] = "cuda") -> int: r"""Returns the random number generator state offset of the specified GPU. Args: @@ -966,6 +1046,7 @@ from .random import * # noqa: F403 # Define Storage and Tensor classes ################################################################################ + @staticmethod # type: ignore[misc] def _lazy_new(cls, *args, **kwargs): _lazy_init() @@ -987,21 +1068,24 @@ class _CudaBase: __new__ = _lazy_new + from torch.storage import _LegacyStorage, _warn_typed_storage_removal + class _CudaLegacyStorage(_LegacyStorage): @classmethod def from_buffer(cls, *args, **kwargs): _warn_typed_storage_removal() - raise RuntimeError('from_buffer: Not available for CUDA storage') + raise RuntimeError("from_buffer: Not available for CUDA storage") @classmethod def _new_with_weak_ptr(cls, *args, **kwargs): - raise RuntimeError('_new_with_weak_ptr: Not available for CUDA storage') + raise RuntimeError("_new_with_weak_ptr: Not available for CUDA storage") @classmethod def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None): - raise RuntimeError('_new_shared_filename: Not available for CUDA storage') + raise RuntimeError("_new_shared_filename: Not available for CUDA storage") + class ByteStorage(_CudaLegacyStorage): @classproperty @@ -1013,6 +1097,7 @@ class ByteStorage(_CudaLegacyStorage): def _dtype(self): return torch.uint8 + class DoubleStorage(_CudaLegacyStorage): @classproperty def dtype(self): @@ -1023,6 +1108,7 @@ class DoubleStorage(_CudaLegacyStorage): def _dtype(self): return torch.double + class FloatStorage(_CudaLegacyStorage): @classproperty def dtype(self): @@ -1033,6 +1119,7 @@ class FloatStorage(_CudaLegacyStorage): def _dtype(self): return torch.float + class HalfStorage(_CudaLegacyStorage): @classproperty def dtype(self): @@ -1043,6 +1130,7 @@ class HalfStorage(_CudaLegacyStorage): def _dtype(self): return torch.half + class LongStorage(_CudaLegacyStorage): @classproperty def dtype(self): @@ -1053,6 +1141,7 @@ class LongStorage(_CudaLegacyStorage): def _dtype(self): return torch.long + class IntStorage(_CudaLegacyStorage): @classproperty def dtype(self): @@ -1063,6 +1152,7 @@ class IntStorage(_CudaLegacyStorage): def _dtype(self): return torch.int + class ShortStorage(_CudaLegacyStorage): @classproperty def dtype(self): @@ -1073,6 +1163,7 @@ class ShortStorage(_CudaLegacyStorage): def _dtype(self): return torch.short + class CharStorage(_CudaLegacyStorage): @classproperty def dtype(self): @@ -1083,6 +1174,7 @@ class CharStorage(_CudaLegacyStorage): def _dtype(self): return torch.int8 + class BoolStorage(_CudaLegacyStorage): @classproperty def dtype(self): @@ -1093,6 +1185,7 @@ class BoolStorage(_CudaLegacyStorage): def _dtype(self): return torch.bool + class BFloat16Storage(_CudaLegacyStorage): @classproperty def dtype(self): @@ -1103,6 +1196,7 @@ class BFloat16Storage(_CudaLegacyStorage): def _dtype(self): return torch.bfloat16 + class ComplexDoubleStorage(_CudaLegacyStorage): @classproperty def dtype(self): @@ -1113,6 +1207,7 @@ class ComplexDoubleStorage(_CudaLegacyStorage): def _dtype(self): return torch.cdouble + class ComplexFloatStorage(_CudaLegacyStorage): @classproperty def dtype(self): @@ -1123,6 +1218,7 @@ class ComplexFloatStorage(_CudaLegacyStorage): def _dtype(self): return torch.cfloat + del _LegacyStorage del _CudaLegacyStorage @@ -1141,8 +1237,7 @@ torch._storage_classes.add(ComplexFloatStorage) class _WrappedTritonKernel: - """ Just a simple wrapper to store some metadata for testing purposes. - """ + """Just a simple wrapper to store some metadata for testing purposes.""" def __init__(self, kernel): self.kernel = kernel @@ -1161,6 +1256,7 @@ def _register_triton_kernels(): @_WrappedTritonKernel def kernel_impl(*args, **kwargs): from torch.sparse._triton_ops import bsr_dense_mm + return bsr_dense_mm(*args, skip_checks=True, **kwargs) has_triton = importlib.util.find_spec("triton") is not None @@ -1169,44 +1265,126 @@ def _register_triton_kernels(): "_triton_bsr_dense_mm_out", "_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)", kernel_impl, - "SparseCsrCUDA" + "SparseCsrCUDA", ) _lazy_call(_register_triton_kernels) -from . import sparse -from . import profiler -from . import nvtx -from . import amp -from . import jiterator +from . import amp, jiterator, nvtx, profiler, sparse __all__ = [ # Typed storage and tensors - 'BFloat16Storage', 'BFloat16Tensor', - 'BoolStorage', 'BoolTensor', - 'ByteStorage', 'ByteTensor', - 'CharStorage', 'CharTensor', - 'ComplexDoubleStorage', 'ComplexFloatStorage', - 'DoubleStorage', 'DoubleTensor', - 'FloatStorage', 'FloatTensor', - 'HalfStorage', 'HalfTensor', - 'IntStorage', 'IntTensor', - 'LongStorage', 'LongTensor', - 'ShortStorage', 'ShortTensor', - 'CUDAGraph', 'CudaError', 'DeferredCudaCallError', 'Event', 'ExternalStream', 'OutOfMemoryError', - 'Stream', 'StreamContext', 'amp', 'caching_allocator_alloc', 'caching_allocator_delete', 'can_device_access_peer', - 'check_error', 'cudaStatus', 'cudart', 'current_blas_handle', 'current_device', 'current_stream', 'default_generators', - 'default_stream', 'device', 'device_count', 'device_of', 'empty_cache', 'get_allocator_backend', 'CUDAPluggableAllocator', - 'change_current_allocator', 'get_arch_list', 'get_device_capability', 'get_device_name', 'get_device_properties', - 'get_gencode_flags', 'get_rng_state', 'get_rng_state_all', 'get_sync_debug_mode', 'graph', 'graph_pool_handle', 'graphs', - 'has_half', 'has_magma', 'init', 'initial_seed', 'ipc_collect', 'is_available', 'is_bf16_supported', - 'is_current_stream_capturing', 'is_initialized', 'jiterator', 'list_gpu_processes', 'make_graphed_callables', - 'manual_seed', 'manual_seed_all', 'max_memory_allocated', 'max_memory_cached', 'max_memory_reserved', - 'mem_get_info', 'memory', 'memory_allocated', 'memory_cached', 'memory_reserved', 'memory_snapshot', - 'memory_stats', 'memory_stats_as_nested_dict', 'memory_summary', 'memory_usage', 'temperature', 'power_draw', - 'clock_rate', 'nccl', 'nvtx', 'profiler', 'random', 'reset_accumulated_memory_stats', 'reset_max_memory_allocated', - 'reset_max_memory_cached', 'reset_peak_memory_stats', 'seed', 'seed_all', 'set_device', 'set_per_process_memory_fraction', - 'set_rng_state', 'set_rng_state_all', 'set_stream', 'set_sync_debug_mode', 'sparse', 'stream', 'streams', - 'synchronize', 'utilization'] + "BFloat16Storage", + "BFloat16Tensor", + "BoolStorage", + "BoolTensor", + "ByteStorage", + "ByteTensor", + "CharStorage", + "CharTensor", + "ComplexDoubleStorage", + "ComplexFloatStorage", + "DoubleStorage", + "DoubleTensor", + "FloatStorage", + "FloatTensor", + "HalfStorage", + "HalfTensor", + "IntStorage", + "IntTensor", + "LongStorage", + "LongTensor", + "ShortStorage", + "ShortTensor", + "CUDAGraph", + "CudaError", + "DeferredCudaCallError", + "Event", + "ExternalStream", + "OutOfMemoryError", + "Stream", + "StreamContext", + "amp", + "caching_allocator_alloc", + "caching_allocator_delete", + "can_device_access_peer", + "check_error", + "cudaStatus", + "cudart", + "current_blas_handle", + "current_device", + "current_stream", + "default_generators", + "default_stream", + "device", + "device_count", + "device_of", + "empty_cache", + "get_allocator_backend", + "CUDAPluggableAllocator", + "change_current_allocator", + "get_arch_list", + "get_device_capability", + "get_device_name", + "get_device_properties", + "get_gencode_flags", + "get_rng_state", + "get_rng_state_all", + "get_sync_debug_mode", + "graph", + "graph_pool_handle", + "graphs", + "has_half", + "has_magma", + "init", + "initial_seed", + "ipc_collect", + "is_available", + "is_bf16_supported", + "is_current_stream_capturing", + "is_initialized", + "jiterator", + "list_gpu_processes", + "make_graphed_callables", + "manual_seed", + "manual_seed_all", + "max_memory_allocated", + "max_memory_cached", + "max_memory_reserved", + "mem_get_info", + "memory", + "memory_allocated", + "memory_cached", + "memory_reserved", + "memory_snapshot", + "memory_stats", + "memory_stats_as_nested_dict", + "memory_summary", + "memory_usage", + "temperature", + "power_draw", + "clock_rate", + "nccl", + "nvtx", + "profiler", + "random", + "reset_accumulated_memory_stats", + "reset_max_memory_allocated", + "reset_max_memory_cached", + "reset_peak_memory_stats", + "seed", + "seed_all", + "set_device", + "set_per_process_memory_fraction", + "set_rng_state", + "set_rng_state_all", + "set_stream", + "set_sync_debug_mode", + "sparse", + "stream", + "streams", + "synchronize", + "utilization", +] diff --git a/torch/cuda/_sanitizer.py b/torch/cuda/_sanitizer.py index 54e5b1a6c018..647dde280a5a 100644 --- a/torch/cuda/_sanitizer.py +++ b/torch/cuda/_sanitizer.py @@ -168,7 +168,7 @@ class _TensorsAccessed: "Found tensor with pointer: %s, but no matching tensor " "allocation in the trace. Backfilling the trace now. " "Perhaps the sanitizer was enabled after some torch operations?", - data_ptr + data_ptr, ) self.create_tensor(data_ptr, None) @@ -179,7 +179,7 @@ class _TensorsAccessed: "pointer: %s. Assuming the trace for tensor deallocation " "wasn't caught and backfilling it now. " "Perhaps the sanitizer was enabled after some torch operations?", - data_ptr + data_ptr, ) self.delete_tensor(data_ptr) @@ -226,7 +226,7 @@ class StreamSynchronizations: "Found Stream with id: %s, but no matching stream " "creation in the trace. Backfilling the trace now. " "Perhaps the sanitizer was enabled after some torch operations?", - stream + stream, ) self.create_stream(stream) @@ -236,7 +236,7 @@ class StreamSynchronizations: "Found Event with id: %s, but no matching event " "creation in the trace. Backfilling the trace now. " "Perhaps the sanitizer was enabled after some torch operations?", - event + event, ) self.create_event(event) @@ -247,7 +247,7 @@ class StreamSynchronizations: "id: %s. Assuming the trace for event deletion wasn't caught " "and backfilling it now. " "Perhaps the sanitizer was enabled after some torch operations?", - event + event, ) self.delete_event(event) @@ -257,7 +257,7 @@ class StreamSynchronizations: "Found duplicate Stream creation in the trace for Stream with " "id: %s. PyTorch Streams are only created once, so this " "trace entry is ignored.", - stream + stream, ) else: self.host_sync_state[stream] = 0 diff --git a/torch/cuda/_utils.py b/torch/cuda/_utils.py index 8a40767c662b..5cdb3f877089 100644 --- a/torch/cuda/_utils.py +++ b/torch/cuda/_utils.py @@ -1,11 +1,14 @@ -import torch from typing import Any + +import torch + # The _get_device_index has been moved to torch.utils._get_device_index from torch._utils import _get_device_index as _torch_get_device_index -def _get_device_index(device: Any, optional: bool = False, - allow_cpu: bool = False) -> int: +def _get_device_index( + device: Any, optional: bool = False, allow_cpu: bool = False +) -> int: r"""Gets the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``. @@ -26,10 +29,10 @@ def _get_device_index(device: Any, optional: bool = False, device = torch.device(device) if isinstance(device, torch.device): if allow_cpu: - if device.type not in ['cuda', 'cpu']: - raise ValueError(f'Expected a cuda or cpu device, but got: {device}') - elif device.type != 'cuda': - raise ValueError(f'Expected a cuda device, but got: {device}') + if device.type not in ["cuda", "cpu"]: + raise ValueError(f"Expected a cuda or cpu device, but got: {device}") + elif device.type != "cuda": + raise ValueError(f"Expected a cuda device, but got: {device}") if not torch.jit.is_scripting(): if isinstance(device, torch.cuda.device): return device.idx @@ -43,7 +46,10 @@ def _dummy_type(name: str) -> type: class_name = obj.__class__.__name__ else: class_name = obj.__name__ - raise RuntimeError( - f"Tried to instantiate dummy base class {class_name}") + raise RuntimeError(f"Tried to instantiate dummy base class {class_name}") + return err_fn - return type(name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}) + + return type( + name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)} + ) diff --git a/torch/cuda/amp/__init__.py b/torch/cuda/amp/__init__.py index 1c0ecd088765..d6b7868b1250 100644 --- a/torch/cuda/amp/__init__.py +++ b/torch/cuda/amp/__init__.py @@ -1,2 +1,2 @@ -from .autocast_mode import autocast, custom_fwd, custom_bwd # noqa: F401 +from .autocast_mode import autocast, custom_bwd, custom_fwd # noqa: F401 from .grad_scaler import GradScaler # noqa: F401 diff --git a/torch/cuda/amp/autocast_mode.py b/torch/cuda/amp/autocast_mode.py index 566b9ff76ebf..6a03345f8b88 100644 --- a/torch/cuda/amp/autocast_mode.py +++ b/torch/cuda/amp/autocast_mode.py @@ -1,8 +1,11 @@ -import torch -import functools import collections +import functools + +import torch + try: import numpy as np + HAS_NUMPY = True except ModuleNotFoundError: np = None # type: ignore[assignment] @@ -10,19 +13,27 @@ from typing import Any __all__ = ["autocast", "custom_fwd", "custom_bwd"] + class autocast(torch.amp.autocast_mode.autocast): r""" See :class:`torch.autocast`. ``torch.cuda.amp.autocast(args...)`` is equivalent to ``torch.autocast("cuda", args...)`` """ - def __init__(self, enabled : bool = True, dtype : torch.dtype = torch.float16, cache_enabled : bool = True): + def __init__( + self, + enabled: bool = True, + dtype: torch.dtype = torch.float16, + cache_enabled: bool = True, + ): if torch._jit_internal.is_scripting(): self._enabled = enabled self.device = "cuda" self.fast_dtype = dtype return - super().__init__("cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) + super().__init__( + "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled + ) def __enter__(self): if torch._jit_internal.is_scripting(): @@ -45,7 +56,11 @@ class autocast(torch.amp.autocast_mode.autocast): # may be falsely detected as "Iterables." def _cast(value, dtype): if isinstance(value, torch.Tensor): - is_eligible = (value.is_floating_point() and value.is_cuda and (value.dtype is not torch.float64)) + is_eligible = ( + value.is_floating_point() + and value.is_cuda + and (value.dtype is not torch.float64) + ) return value.to(dtype) if is_eligible else value elif isinstance(value, (str, bytes)): return value @@ -104,6 +119,7 @@ def custom_fwd(fwd=None, *, cast_inputs=None): return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs)) else: return fwd(*args, **kwargs) + return decorate_fwd @@ -117,8 +133,10 @@ def custom_bwd(bwd): Ensures that ``backward`` executes with the same autocast state as ``forward``. See the :ref:`example page` for more detail. """ + @functools.wraps(bwd) def decorate_bwd(*args, **kwargs): with autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype): return bwd(*args, **kwargs) + return decorate_bwd diff --git a/torch/cuda/amp/common.py b/torch/cuda/amp/common.py index d0c1e3c04d1c..c4e8c1cc99b0 100644 --- a/torch/cuda/amp/common.py +++ b/torch/cuda/amp/common.py @@ -1,7 +1,9 @@ -import torch from importlib.util import find_spec +import torch + __all__ = ["amp_definitely_not_available"] + def amp_definitely_not_available(): - return not (torch.cuda.is_available() or find_spec('torch_xla')) + return not (torch.cuda.is_available() or find_spec("torch_xla")) diff --git a/torch/cuda/amp/grad_scaler.py b/torch/cuda/amp/grad_scaler.py index c3a05db71d53..78b41cf47095 100644 --- a/torch/cuda/amp/grad_scaler.py +++ b/torch/cuda/amp/grad_scaler.py @@ -1,8 +1,8 @@ -from collections import defaultdict, abc -from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, cast import inspect import warnings +from collections import abc, defaultdict +from enum import Enum +from typing import Any, cast, Dict, List, Optional, Tuple import torch from .common import amp_definitely_not_available @@ -10,12 +10,14 @@ from .common import amp_definitely_not_available __all__ = ["OptState", "GradScaler"] + class _MultiDeviceReplicator: """ Lazily serves copies of a tensor to requested devices. Copies are cached per-device. """ + def __init__(self, master_tensor: torch.Tensor) -> None: - assert master_tensor.is_cuda or master_tensor.device.type == 'xla' + assert master_tensor.is_cuda or master_tensor.device.type == "xla" self.master = master_tensor self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} @@ -110,14 +112,19 @@ class GradScaler: invokes the underlying ``optimizer.step()``, and other methods become no-ops. Default: ``True`` """ - def __init__(self, - init_scale=2.**16, - growth_factor=2.0, - backoff_factor=0.5, - growth_interval=2000, - enabled=True): + + def __init__( + self, + init_scale=2.0**16, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + enabled=True, + ): if enabled and amp_definitely_not_available(): - warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.") + warnings.warn( + "torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling." + ) self._enabled = False else: self._enabled = enabled @@ -137,16 +144,24 @@ class GradScaler: self._growth_tracker = None self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) - def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]: + def _check_scale_growth_tracker( + self, funcname + ) -> Tuple[torch.Tensor, torch.Tensor]: fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." - assert self._scale is not None, f"Attempted {funcname} but _scale is None. " + fix - assert self._growth_tracker is not None, f"Attempted {funcname} but _growth_tracker is None. " + fix + assert self._scale is not None, ( + f"Attempted {funcname} but _scale is None. " + fix + ) + assert self._growth_tracker is not None, ( + f"Attempted {funcname} but _growth_tracker is None. " + fix + ) return (self._scale, self._growth_tracker) def _lazy_init_scale_growth_tracker(self, dev): assert self._growth_tracker is None, "_growth_tracker initialized before _scale" self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev) - self._growth_tracker = torch.full((), self._init_growth_tracker, dtype=torch.int32, device=dev) + self._growth_tracker = torch.full( + (), self._init_growth_tracker, dtype=torch.int32, device=dev + ) def scale(self, outputs): """ @@ -163,18 +178,20 @@ class GradScaler: # Short-circuit for the common case. if isinstance(outputs, torch.Tensor): - assert outputs.is_cuda or outputs.device.type == 'xla' + assert outputs.is_cuda or outputs.device.type == "xla" if self._scale is None: self._lazy_init_scale_growth_tracker(outputs.device) assert self._scale is not None return outputs * self._scale.to(device=outputs.device, non_blocking=True) # Invoke the more complex machinery only if we're treating multiple outputs. - stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale + stash: List[ + _MultiDeviceReplicator + ] = [] # holds a reference that can be overwritten by apply_scale def apply_scale(val): if isinstance(val, torch.Tensor): - assert val.is_cuda or val.device.type == 'xla' + assert val.is_cuda or val.device.type == "xla" if len(stash) == 0: if self._scale is None: self._lazy_init_scale_growth_tracker(val.device) @@ -222,13 +239,17 @@ class GradScaler: to_unscale = param.grad # TODO: is there a way to split by device and dtype without appending in the inner loop? - per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale) + per_device_and_dtype_grads[to_unscale.device][ + to_unscale.dtype + ].append(to_unscale) for device, per_dtype_grads in per_device_and_dtype_grads.items(): for grads in per_dtype_grads.values(): - torch._amp_foreach_non_finite_check_and_unscale_(grads, - per_device_found_inf.get(device), - per_device_inv_scale.get(device)) + torch._amp_foreach_non_finite_check_and_unscale_( + grads, + per_device_found_inf.get(device), + per_device_inv_scale.get(device), + ) return per_device_found_inf._per_device_tensors @@ -272,7 +293,9 @@ class GradScaler: optimizer_state = self._per_optimizer_states[id(optimizer)] if optimizer_state["stage"] is OptState.UNSCALED: - raise RuntimeError("unscale_() has already been called on this optimizer since the last update().") + raise RuntimeError( + "unscale_() has already been called on this optimizer since the last update()." + ) elif optimizer_state["stage"] is OptState.STEPPED: raise RuntimeError("unscale_() is being called after step().") @@ -281,7 +304,9 @@ class GradScaler: inv_scale = self._scale.double().reciprocal().float() found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device) - optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False) + optimizer_state["found_inf_per_device"] = self._unscale_grads_( + optimizer, inv_scale, found_inf, False + ) optimizer_state["stage"] = OptState.UNSCALED def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): @@ -311,22 +336,29 @@ class GradScaler: .. warning:: Closure use is not currently supported. """ - if (not self._enabled): + if not self._enabled: return optimizer.step(*args, **kwargs) if "closure" in kwargs: - raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.") + raise RuntimeError( + "Closure use is not currently supported if GradScaler is enabled." + ) self._check_scale_growth_tracker("step") optimizer_state = self._per_optimizer_states[id(optimizer)] if optimizer_state["stage"] is OptState.STEPPED: - raise RuntimeError("step() has already been called since the last update().") + raise RuntimeError( + "step() has already been called since the last update()." + ) retval = None - if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling): + if ( + hasattr(optimizer, "_step_supports_amp_scaling") + and optimizer._step_supports_amp_scaling + ): # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. # The contract with custom optimizers is that their step() should accept an additional, # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: @@ -339,13 +371,16 @@ class GradScaler: # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`, # while the method is expected to be called by users side, i.e. their optimizers. kwargs_ = kwargs - has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters + has_grad_scaler_kwarg = ( + "grad_scaler" in inspect.signature(optimizer.step).parameters + ) if has_grad_scaler_kwarg: warnings.warn( "GradScaler is going to stop passing itself as a keyword argument to the passed " "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and " "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.", - FutureWarning) + FutureWarning, + ) kwargs_.update({"grad_scaler": self}) else: if optimizer_state["stage"] is OptState.READY: @@ -353,11 +388,16 @@ class GradScaler: scaler = self._get_scale_async() found_inf = cast( torch.Tensor, - sum([ - t.to(scaler.device, non_blocking=True) for t in optimizer_state["found_inf_per_device"].values() - ]) + sum( + [ + t.to(scaler.device, non_blocking=True) + for t in optimizer_state["found_inf_per_device"].values() + ] + ), + ) + optimizer.grad_scale = ( + None if optimizer_state["stage"] == OptState.UNSCALED else scaler ) - optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler optimizer.found_inf = found_inf retval = optimizer.step(*args, **kwargs_) optimizer_state["stage"] = OptState.STEPPED @@ -369,7 +409,9 @@ class GradScaler: if optimizer_state["stage"] is OptState.READY: self.unscale_(optimizer) - assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer." + assert ( + len(optimizer_state["found_inf_per_device"]) > 0 + ), "No inf checks were recorded for this optimizer." retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs) @@ -421,9 +463,11 @@ class GradScaler: else: # Consume shared inf/nan data collected from optimizers to update the scale. # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. - found_infs = [found_inf.to(device=_scale.device, non_blocking=True) - for state in self._per_optimizer_states.values() - for found_inf in state["found_inf_per_device"].values()] + found_infs = [ + found_inf.to(device=_scale.device, non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values() + ] assert len(found_infs) > 0, "No inf checks were recorded prior to update." @@ -432,12 +476,14 @@ class GradScaler: for i in range(1, len(found_infs)): found_inf_combined += found_infs[i] - torch._amp_update_scale_(_scale, - _growth_tracker, - found_inf_combined, - self._growth_factor, - self._backoff_factor, - self._growth_interval) + torch._amp_update_scale_( + _scale, + _growth_tracker, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval, + ) # To prepare for next iteration, clear the data collected from optimizers this iteration. self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) @@ -453,7 +499,11 @@ class GradScaler: :meth:`get_scale` incurs a CPU-GPU sync. """ if self._enabled: - return self._init_scale if self._scale is None else self._get_scale_async().item() + return ( + self._init_scale + if self._scale is None + else self._get_scale_async().item() + ) else: return 1.0 @@ -498,7 +548,11 @@ class GradScaler: def _get_growth_tracker(self): if self._enabled: - return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item() + return ( + self._init_growth_tracker + if self._growth_tracker is None + else self._growth_tracker.item() + ) else: return 0 @@ -524,11 +578,17 @@ class GradScaler: If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` should be called after :meth:`update`. """ - return {"scale": self.get_scale(), + return ( + { + "scale": self.get_scale(), "growth_factor": self._growth_factor, "backoff_factor": self._backoff_factor, "growth_interval": self._growth_interval, - "_growth_tracker": self._get_growth_tracker()} if self._enabled else {} + "_growth_tracker": self._get_growth_tracker(), + } + if self._enabled + else {} + ) def load_state_dict(self, state_dict): r""" @@ -541,8 +601,10 @@ class GradScaler: return if len(state_dict) == 0: - raise RuntimeError("The source state dict is empty, possibly because it was saved " - "from a disabled instance of GradScaler.") + raise RuntimeError( + "The source state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler." + ) self._init_scale = state_dict["scale"] if self._scale is not None: @@ -557,15 +619,17 @@ class GradScaler: def __getstate__(self): state = self.__dict__.copy() if self._enabled: - assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ - "of an iteration, or at the end after scaler.update()." + assert len(self._per_optimizer_states) == 0, ( + "A GradScaler instance may only be pickled at the beginning " + "of an iteration, or at the end after scaler.update()." + ) # Pickling _scale and _growth_tracker Tensors directly triggers # "warnings.warn("pickle support for Storage will be removed in 1.5..." # so instead, we set the unpickled instance up to reinitialize them lazily. - state['_init_scale'] = self.get_scale() - state['_init_growth_tracker'] = self._get_growth_tracker() - state['_scale'] = None - state['_growth_tracker'] = None + state["_init_scale"] = self.get_scale() + state["_init_growth_tracker"] = self._get_growth_tracker() + state["_scale"] = None + state["_growth_tracker"] = None return state def __setstate__(self, state): @@ -577,8 +641,9 @@ class GradScaler: dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device) found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device) - self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ - self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) + self._per_optimizer_states[id(optimizer)][ + "found_inf_per_device" + ] = self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/torch/cuda/comm.py b/torch/cuda/comm.py index 557ffb0c0de4..2ea23c2072d8 100644 --- a/torch/cuda/comm.py +++ b/torch/cuda/comm.py @@ -1,5 +1,18 @@ # The functions here have been moved to torch.nn.parallel.comm -from torch.nn.parallel.comm import broadcast, broadcast_coalesced, reduce_add, \ - reduce_add_coalesced, scatter, gather +from torch.nn.parallel.comm import ( + broadcast, + broadcast_coalesced, + gather, + reduce_add, + reduce_add_coalesced, + scatter, +) -__all__ = ['broadcast', 'broadcast_coalesced', 'reduce_add', 'reduce_add_coalesced', 'scatter', 'gather'] +__all__ = [ + "broadcast", + "broadcast_coalesced", + "reduce_add", + "reduce_add_coalesced", + "scatter", + "gather", +] diff --git a/torch/cuda/graphs.py b/torch/cuda/graphs.py index 865a07dc2709..144959f5340d 100644 --- a/torch/cuda/graphs.py +++ b/torch/cuda/graphs.py @@ -1,19 +1,26 @@ import gc + import torch +from torch.utils._pytree import ( + tree_flatten as _tree_flatten, + tree_unflatten as _tree_unflatten, +) from ._utils import _dummy_type -from torch.utils._pytree import tree_flatten as _tree_flatten -from torch.utils._pytree import tree_unflatten as _tree_unflatten -if not hasattr(torch._C, '_CudaStreamBase'): +if not hasattr(torch._C, "_CudaStreamBase"): # Define dummy base classes - torch._C.__dict__['_CUDAGraph'] = _dummy_type('_CUDAGraph') - torch._C.__dict__['_graph_pool_handle'] = _dummy_type('_graph_pool_handle') - torch._C.__dict__['_cuda_isCurrentStreamCapturing'] = _dummy_type('_cuda_isCurrentStreamCapturing') + torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph") + torch._C.__dict__["_graph_pool_handle"] = _dummy_type("_graph_pool_handle") + torch._C.__dict__["_cuda_isCurrentStreamCapturing"] = _dummy_type( + "_cuda_isCurrentStreamCapturing" + ) -from torch._C import _CUDAGraph # noqa: F401 -from torch._C import _graph_pool_handle -from torch._C import _cuda_isCurrentStreamCapturing +from torch._C import ( # noqa: F401 + _cuda_isCurrentStreamCapturing, + _CUDAGraph, + _graph_pool_handle, +) def is_current_stream_capturing(): @@ -24,6 +31,7 @@ def is_current_stream_capturing(): """ return _cuda_isCurrentStreamCapturing() + # Python shim helps Sphinx process docstrings more reliably. def graph_pool_handle(): r""" @@ -44,6 +52,7 @@ class CUDAGraph(torch._C._CUDAGraph): .. warning:: This API is in beta and may change in future releases. """ + def __new__(cls): return super().__new__(cls) @@ -140,10 +149,7 @@ class graph: """ default_capture_stream = None - def __init__(self, - cuda_graph, - pool=None, - stream=None): + def __init__(self, cuda_graph, pool=None, stream=None): # Lazy-init of default_capture_stream helps avoid circular-import errors. # Not thread safe, but graphs already have the general (explicitly documented) # restriction that only one capture may be underway at a time in the process. @@ -151,7 +157,9 @@ class graph: self.__class__.default_capture_stream = torch.cuda.Stream() self.pool = () if pool is None else (pool,) - self.capture_stream = stream if stream is not None else self.__class__.default_capture_stream + self.capture_stream = ( + stream if stream is not None else self.__class__.default_capture_stream + ) assert self.capture_stream is not None self.stream_ctx = torch.cuda.stream(self.capture_stream) self.cuda_graph = cuda_graph @@ -168,14 +176,15 @@ class graph: self.cuda_graph.capture_begin(*self.pool) - def __exit__(self, exc_type, exc_value, traceback): self.cuda_graph.capture_end() self.stream_ctx.__exit__(exc_type, exc_value, traceback) # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__() -def make_graphed_callables(callables, sample_args, num_warmup_iters=3, allow_unused_input=False): +def make_graphed_callables( + callables, sample_args, num_warmup_iters=3, allow_unused_input=False +): r""" Accepts callables (functions or :class:`nn.Module`\ s) and returns graphed versions. @@ -243,7 +252,9 @@ def make_graphed_callables(callables, sample_args, num_warmup_iters=3, allow_unu caching. The context manager `torch.cuda.amp.autocast()` must have `cache_enabled=False`. """ if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled(): - raise RuntimeError("make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`.") + raise RuntimeError( + "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`." + ) just_one_callable = False @@ -256,25 +267,37 @@ def make_graphed_callables(callables, sample_args, num_warmup_iters=3, allow_unu for c, args in zip(callables, sample_args): if isinstance(c, torch.nn.Module): - assert len(c._backward_hooks) == 0 and len(c._forward_hooks) == 0 and len(c._forward_pre_hooks) == 0, \ - "Modules must not have hooks registered at the time they are passed. However, registering hooks " + \ - "on modules after passing them through make_graphed_callables is allowed." - assert all(b.requires_grad is False for b in c.buffers()), "In any :class:`~torch.nn.Module` passed to " + \ - ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have " + \ - "``requires_grad=False``." + assert ( + len(c._backward_hooks) == 0 + and len(c._forward_hooks) == 0 + and len(c._forward_pre_hooks) == 0 + ), ( + "Modules must not have hooks registered at the time they are passed. However, registering hooks " + + "on modules after passing them through make_graphed_callables is allowed." + ) + assert all(b.requires_grad is False for b in c.buffers()), ( + "In any :class:`~torch.nn.Module` passed to " + + ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have " + + "``requires_grad=False``." + ) flatten_arg, _ = _tree_flatten(args) flatten_sample_args.append(tuple(flatten_arg)) - assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), "In the beta API, sample_args " + \ - "for each callable must contain only Tensors. Other types are not allowed." - + assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), ( + "In the beta API, sample_args " + + "for each callable must contain only Tensors. Other types are not allowed." + ) # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly # passes to forward (ie, its sample_args) AND the module's parameter attributes. per_callable_len_user_args = [len(args) for args in flatten_sample_args] - per_callable_module_params = [tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () - for c in callables] - per_callable_static_input_surfaces = [flatten_sample_args[i] + per_callable_module_params[i] - for i in range(len(callables))] + per_callable_module_params = [ + tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () + for c in callables + ] + per_callable_static_input_surfaces = [ + flatten_sample_args[i] + per_callable_module_params[i] + for i in range(len(callables)) + ] fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))] bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))] @@ -286,16 +309,20 @@ def make_graphed_callables(callables, sample_args, num_warmup_iters=3, allow_unu # from ending up in any captures. torch.cuda.synchronize() with torch.cuda.stream(torch.cuda.Stream()): - for func, args, static_input_surface in zip(callables, - sample_args, - per_callable_static_input_surfaces): + for func, args, static_input_surface in zip( + callables, sample_args, per_callable_static_input_surfaces + ): for _ in range(num_warmup_iters): outputs, _ = _tree_flatten(func(*args)) - grad_inputs = torch.autograd.grad(outputs=tuple(o for o in outputs if o.requires_grad), - inputs=tuple(i for i in static_input_surface if i.requires_grad), - grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad), - only_inputs=True, - allow_unused=allow_unused_input) + grad_inputs = torch.autograd.grad( + outputs=tuple(o for o in outputs if o.requires_grad), + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple( + torch.empty_like(o) for o in outputs if o.requires_grad + ), + only_inputs=True, + allow_unused=allow_unused_input, + ) del outputs, grad_inputs torch.cuda.synchronize() @@ -306,9 +333,7 @@ def make_graphed_callables(callables, sample_args, num_warmup_iters=3, allow_unu # Capture forward graphs per_callable_static_outputs = [] per_callable_output_unflatten_spec = [] - for func, args, fwd_graph in zip(callables, - sample_args, - fwd_graphs): + for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs): with torch.cuda.graph(fwd_graph, pool=mempool): outputs = func(*args) @@ -316,26 +341,29 @@ def make_graphed_callables(callables, sample_args, num_warmup_iters=3, allow_unu per_callable_static_outputs.append(tuple(flatten_outputs)) per_callable_output_unflatten_spec.append(spec) - # Capture backward graphs in reverse order per_callable_static_grad_outputs = [] per_callable_static_grad_inputs = [] - for static_input_surface, static_outputs, bwd_graph, module_params in \ - zip(reversed(per_callable_static_input_surfaces), - reversed(per_callable_static_outputs), - reversed(bwd_graphs), - reversed(per_callable_module_params)): - + for static_input_surface, static_outputs, bwd_graph, module_params in zip( + reversed(per_callable_static_input_surfaces), + reversed(per_callable_static_outputs), + reversed(bwd_graphs), + reversed(per_callable_module_params), + ): # For now, assumes all static_outputs require grad # assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad." - static_grad_outputs = tuple(torch.empty_like(o) if o.requires_grad else None for o in static_outputs) + static_grad_outputs = tuple( + torch.empty_like(o) if o.requires_grad else None for o in static_outputs + ) with torch.cuda.graph(bwd_graph, pool=mempool): - grad_inputs = torch.autograd.grad(outputs=tuple(o for o in static_outputs if o.requires_grad), - inputs=tuple(i for i in static_input_surface if i.requires_grad), - grad_outputs=tuple(o for o in static_grad_outputs if o is not None), - only_inputs=True, - allow_unused=allow_unused_input) + grad_inputs = torch.autograd.grad( + outputs=tuple(o for o in static_outputs if o.requires_grad), + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple(o for o in static_grad_outputs if o is not None), + only_inputs=True, + allow_unused=allow_unused_input, + ) # Constructs a tuple suitable for returning from Graphed.backward: # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad. @@ -358,15 +386,17 @@ def make_graphed_callables(callables, sample_args, num_warmup_iters=3, allow_unu per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs)) # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable. - def make_graphed_autograd_function(fwd_graph, - bwd_graph, - module_params, - len_user_args, - output_unflatten_spec, - static_input_surface, - static_outputs, - static_grad_outputs, - static_grad_inputs): + def make_graphed_autograd_function( + fwd_graph, + bwd_graph, + module_params, + len_user_args, + output_unflatten_spec, + static_input_surface, + static_outputs, + static_grad_outputs, + static_grad_inputs, + ): class Graphed(torch.autograd.Function): @staticmethod def forward(ctx, *inputs): @@ -392,7 +422,9 @@ def make_graphed_callables(callables, sample_args, num_warmup_iters=3, allow_unu # Input args that didn't require grad expect a None gradient. assert isinstance(static_grad_inputs, tuple) - return tuple(b.detach() if b is not None else b for b in static_grad_inputs) + return tuple( + b.detach() if b is not None else b for b in static_grad_inputs + ) def functionalized(*user_args): # Runs the autograd function with inputs == all inputs to the graph that might require grad @@ -407,17 +439,20 @@ def make_graphed_callables(callables, sample_args, num_warmup_iters=3, allow_unu # Put together the final graphed callables ret = [] for i, func in enumerate(callables): - graphed = make_graphed_autograd_function(fwd_graphs[i], - bwd_graphs[i], - per_callable_module_params[i], - per_callable_len_user_args[i], - per_callable_output_unflatten_spec[i], - per_callable_static_input_surfaces[i], - per_callable_static_outputs[i], - per_callable_static_grad_outputs[i], - per_callable_static_grad_inputs[i]) + graphed = make_graphed_autograd_function( + fwd_graphs[i], + bwd_graphs[i], + per_callable_module_params[i], + per_callable_len_user_args[i], + per_callable_output_unflatten_spec[i], + per_callable_static_input_surfaces[i], + per_callable_static_outputs[i], + per_callable_static_grad_outputs[i], + per_callable_static_grad_inputs[i], + ) if isinstance(func, torch.nn.Module): + def make_graphed_forward(func, graph_training_state, graphed, orig_fwd): def new_fwd(*user_args): # If the module's training-or-eval state matches what we graphed, @@ -426,7 +461,9 @@ def make_graphed_callables(callables, sample_args, num_warmup_iters=3, allow_unu return graphed(*user_args) else: return orig_fwd(*user_args) + return new_fwd + func.forward = make_graphed_forward(func, func.training, graphed, func.forward) # type: ignore[assignment] ret.append(func) else: diff --git a/torch/cuda/jiterator.py b/torch/cuda/jiterator.py index 7a9b75cc0685..df72b57e9888 100644 --- a/torch/cuda/jiterator.py +++ b/torch/cuda/jiterator.py @@ -1,10 +1,10 @@ -import torch -from torch import Tensor +import re from typing import Callable, List -import re +import torch +from torch import Tensor -__all__ : List[str] = [] +__all__: List[str] = [] class _CodeParser: @@ -17,20 +17,30 @@ class _CodeParser: function_params = r"(?P\(.+\))" function_body = r"(?P\{.+\})" - pattern = \ - optional_ws \ - + "template" \ - + optional_ws + template_params \ - + optional_ws + return_type \ - + required_ws + function_name \ - + optional_ws + function_params \ - + optional_ws + function_body \ + pattern = ( + optional_ws + + "template" + optional_ws + + template_params + + optional_ws + + return_type + + required_ws + + function_name + + optional_ws + + function_params + + optional_ws + + function_body + + optional_ws + ) - result = re.match(pattern, code_string, re.DOTALL) # DOTALL for matching multiline + result = re.match( + pattern, code_string, re.DOTALL + ) # DOTALL for matching multiline if result is None: - raise Exception(f"Couldn't parse code, please check correctness:\n {code_string}") + raise Exception( + f"Couldn't parse code, please check correctness:\n {code_string}" + ) self.template_params = result["template_params"] self.return_type = result["return_type"] @@ -40,10 +50,14 @@ class _CodeParser: class _JittedFunction: - def __init__(self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs): + def __init__( + self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs + ): self.code_string = code_string - assert return_by_ref or num_outputs == 1, "Return by value only works for single output. " + assert ( + return_by_ref or num_outputs == 1 + ), "Return by value only works for single output. " self.return_by_ref = return_by_ref self.num_outputs = num_outputs @@ -56,7 +70,9 @@ class _JittedFunction: def __call__(self, *tensors: Tensor, **kwargs): # Jiterator follow torch.cuda's lazy initialization behavior # Defer checking cuda's availability at the function invocation time - assert self.is_cuda_available, "Jiterator is only supported on CUDA and ROCm GPUs, none are available." + assert ( + self.is_cuda_available + ), "Jiterator is only supported on CUDA and ROCm GPUs, none are available." assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs." @@ -73,7 +89,8 @@ class _JittedFunction: self.return_by_ref, self.num_outputs, tensors, - expanded_kwargs) + expanded_kwargs, + ) def _create_jit_fn(code_string: str, **kwargs) -> Callable: @@ -138,7 +155,9 @@ def _create_jit_fn(code_string: str, **kwargs) -> Callable: return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs) -def _create_multi_output_jit_fn(code_string: str, num_outputs: int, **kwargs) -> Callable: +def _create_multi_output_jit_fn( + code_string: str, num_outputs: int, **kwargs +) -> Callable: """ Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs. @@ -163,4 +182,6 @@ def _create_multi_output_jit_fn(code_string: str, num_outputs: int, **kwargs) -> This API only supports up to 8 inputs and 8 outputs """ - return _JittedFunction(code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs) + return _JittedFunction( + code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs + ) diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 04f4096c2dea..d5be5a18b68c 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -1,33 +1,57 @@ import collections import contextlib import ctypes -import warnings +import os import pickle import sys -import os +import warnings -from typing import Any, Dict, Union, Tuple, Optional +from typing import Any, Dict, Optional, Tuple, Union import torch -from . import is_initialized, _get_device_index, _lazy_init, _get_nvml_device_index -from ._utils import _dummy_type - -from ._memory_viz import segments as _segments, memory as _memory, segment_plot, trace_plot - -from torch.types import Device from torch import _C -__all__ = ["caching_allocator_alloc", "caching_allocator_delete", "set_per_process_memory_fraction", - "empty_cache", "memory_stats", "memory_stats_as_nested_dict", "reset_accumulated_memory_stats", - "reset_peak_memory_stats", "reset_max_memory_allocated", "reset_max_memory_cached", - "memory_allocated", "max_memory_allocated", "memory_reserved", "max_memory_reserved", - "memory_cached", "max_memory_cached", "memory_snapshot", "memory_summary", "list_gpu_processes", - "mem_get_info", "get_allocator_backend", "CUDAPluggableAllocator", "change_current_allocator"] +from torch.types import Device +from . import _get_device_index, _get_nvml_device_index, _lazy_init, is_initialized + +from ._memory_viz import ( + memory as _memory, + segment_plot, + segments as _segments, + trace_plot, +) +from ._utils import _dummy_type + +__all__ = [ + "caching_allocator_alloc", + "caching_allocator_delete", + "set_per_process_memory_fraction", + "empty_cache", + "memory_stats", + "memory_stats_as_nested_dict", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", + "reset_max_memory_allocated", + "reset_max_memory_cached", + "memory_allocated", + "max_memory_allocated", + "memory_reserved", + "max_memory_reserved", + "memory_cached", + "max_memory_cached", + "memory_snapshot", + "memory_summary", + "list_gpu_processes", + "mem_get_info", + "get_allocator_backend", + "CUDAPluggableAllocator", + "change_current_allocator", +] -if not hasattr(torch._C, '_cuda_CUDAAllocator'): +if not hasattr(torch._C, "_cuda_CUDAAllocator"): # Define dummy base classes - torch._C.__dict__['_cuda_CUDAAllocator'] = _dummy_type('_cuda_CUDAAllocator') + torch._C.__dict__["_cuda_CUDAAllocator"] = _dummy_type("_cuda_CUDAAllocator") def _host_allocator(): @@ -71,9 +95,11 @@ def caching_allocator_alloc(size, device: Union[Device, int] = None, stream=None if isinstance(stream, torch.cuda.streams.Stream): stream = stream.cuda_stream if not isinstance(stream, int): - raise TypeError('Invalid type for stream argument, must be ' - '`torch.cuda.Stream` or `int` representing a pointer ' - 'to a existing stream') + raise TypeError( + "Invalid type for stream argument, must be " + "`torch.cuda.Stream` or `int` representing a pointer " + "to a existing stream" + ) with torch.cuda.device(device): return torch._C._cuda_cudaCachingAllocator_raw_alloc(size, stream) @@ -95,7 +121,9 @@ def caching_allocator_delete(mem_ptr): torch._C._cuda_cudaCachingAllocator_raw_delete(mem_ptr) -def set_per_process_memory_fraction(fraction, device: Union[Device, int] = None) -> None: +def set_per_process_memory_fraction( + fraction, device: Union[Device, int] = None +) -> None: r"""Set memory fraction for a process. The fraction is used to limit an caching allocator to allocated memory on a CUDA device. The allowed value equals the total visible memory multiplied fraction. @@ -114,9 +142,9 @@ def set_per_process_memory_fraction(fraction, device: Union[Device, int] = None) device = torch.cuda.current_device() device = _get_device_index(device) if not isinstance(fraction, float): - raise TypeError('Invalid type for fraction argument, must be `float`') + raise TypeError("Invalid type for fraction argument, must be `float`") if fraction < 0 or fraction > 1: - raise ValueError(f'Invalid fraction value: {fraction}. Allowed range: 0~1') + raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~1") torch._C._cuda_setMemoryFraction(fraction, device) @@ -306,7 +334,8 @@ def reset_max_memory_allocated(device: Union[Device, int] = None) -> None: warnings.warn( "torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, " "which resets /all/ peak memory stats.", - FutureWarning) + FutureWarning, + ) return reset_peak_memory_stats(device=device) @@ -332,7 +361,8 @@ def reset_max_memory_cached(device: Union[Device, int] = None) -> None: warnings.warn( "torch.cuda.reset_max_memory_cached now calls torch.cuda.reset_peak_memory_stats, " "which resets /all/ peak memory stats.", - FutureWarning) + FutureWarning, + ) return reset_peak_memory_stats(device=device) @@ -418,7 +448,8 @@ def memory_cached(device: Union[Device, int] = None) -> int: r"""Deprecated; see :func:`~torch.cuda.memory_reserved`.""" warnings.warn( "torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved", - FutureWarning) + FutureWarning, + ) return memory_reserved(device=device) @@ -426,7 +457,8 @@ def max_memory_cached(device: Union[Device, int] = None) -> int: r"""Deprecated; see :func:`~torch.cuda.max_memory_reserved`.""" warnings.warn( "torch.cuda.max_memory_cached has been renamed to torch.cuda.max_memory_reserved", - FutureWarning) + FutureWarning, + ) return max_memory_reserved(device=device) @@ -440,7 +472,7 @@ def memory_snapshot(): See :ref:`cuda-memory-management` for more details about GPU memory management. """ - return torch._C._cuda_memorySnapshot()['segments'] + return torch._C._cuda_memorySnapshot()["segments"] def memory_summary(device: Union[Device, int] = None, abbreviated: bool = False) -> str: @@ -502,9 +534,13 @@ def memory_summary(device: Union[Device, int] = None, abbreviated: bool = False) lines.append("=" * 75) lines.append(" {_:16} PyTorch CUDA memory summary, device ID {device:<17d} ") lines.append("-" * 75) - lines.append(" {_:9} CUDA OOMs: {num_ooms:<12d} | {_:6} cudaMalloc retries: {num_alloc_retries:<8d} ") + lines.append( + " {_:9} CUDA OOMs: {num_ooms:<12d} | {_:6} cudaMalloc retries: {num_alloc_retries:<8d} " + ) lines.append("=" * 75) - lines.append(" Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed ") + lines.append( + " Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed " + ) for metric_key, metric_name, formatter in metrics_to_display: lines.append("-" * 75) @@ -513,7 +549,12 @@ def memory_summary(device: Union[Device, int] = None, abbreviated: bool = False) submetrics.append(("large_pool", " from large pool")) submetrics.append(("small_pool", " from small pool")) - current_prefval, peak_prefval, allocated_prefval, freed_prefval = None, None, None, None + current_prefval, peak_prefval, allocated_prefval, freed_prefval = ( + None, + None, + None, + None, + ) for submetric_key, submetric_name in submetrics: prefix = metric_key + "." + submetric_key + "." @@ -529,12 +570,14 @@ def memory_summary(device: Union[Device, int] = None, abbreviated: bool = False) allocated_prefval = allocated freed_prefval = freed - lines.append(" {:<21} | {} | {} | {} | {} ".format( - submetric_name, - formatter(current, current_prefval), - formatter(peak, peak_prefval), - formatter(allocated, allocated_prefval), - formatter(freed, freed_prefval)), + lines.append( + " {:<21} | {} | {} | {} | {} ".format( + submetric_name, + formatter(current, current_prefval), + formatter(peak, peak_prefval), + formatter(allocated, allocated_prefval), + formatter(freed, freed_prefval), + ), ) metrics_to_display = [ @@ -552,12 +595,14 @@ def memory_summary(device: Union[Device, int] = None, abbreviated: bool = False) allocated = stats[prefix + "allocated"] freed = stats[prefix + "freed"] - lines.append(" {:<21} | {} | {} | {} | {} ".format( - metric_name, - formatter(current, current), - formatter(peak, peak), - formatter(allocated, allocated), - formatter(freed, freed)), + lines.append( + " {:<21} | {} | {} | {} | {} ".format( + metric_name, + formatter(current, current), + formatter(peak, peak), + formatter(allocated, allocated), + formatter(freed, freed), + ), ) lines.append("=" * 75) @@ -584,12 +629,13 @@ def list_gpu_processes(device: Union[Device, int] = None) -> str: try: import pynvml # type: ignore[import] except ModuleNotFoundError: - return("pynvml module not found, please install pynvml") + return "pynvml module not found, please install pynvml" from pynvml import NVMLError_DriverNotLoaded + try: pynvml.nvmlInit() except NVMLError_DriverNotLoaded: - return ("cuda driver can't be loaded, is cuda enabled?") + return "cuda driver can't be loaded, is cuda enabled?" device = _get_nvml_device_index(device) handle = pynvml.nvmlDeviceGetHandleByIndex(device) procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle) @@ -602,6 +648,7 @@ def list_gpu_processes(device: Union[Device, int] = None) -> str: lines.append(f"process {p.pid:>10d} uses {mem:>12.3f} MB GPU memory") return "\n".join(lines) + def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]: r"""Returns the global free and total GPU memory for a given device using cudaMemGetInfo. @@ -620,14 +667,24 @@ def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]: device = _get_device_index(device) return torch.cuda.cudart().cudaMemGetInfo(device) -def _record_memory_history_legacy(enabled: bool, record_context=True, - trace_alloc_max_entries=1, - trace_alloc_record_context=False, device: Union[Device, int] = None, - record_context_cpp=False): +def _record_memory_history_legacy( + enabled: bool, + record_context=True, + trace_alloc_max_entries=1, + trace_alloc_record_context=False, + device: Union[Device, int] = None, + record_context_cpp=False, +): with torch.cuda.device(device): - _C._cuda_recordMemoryHistory(enabled, record_context, record_context_cpp, - trace_alloc_max_entries, trace_alloc_record_context) + _C._cuda_recordMemoryHistory( + enabled, + record_context, + record_context_cpp, + trace_alloc_max_entries, + trace_alloc_record_context, + ) + def _record_memory_history(enabled="all", *args, **kwargs): """Enables recording of stack traces associated with memory @@ -674,11 +731,14 @@ def _record_memory_history(enabled="all", *args, **kwargs): else: return _record_memory_history_impl(enabled, *args, **kwargs) -def _record_memory_history_impl(enabled: Optional[str] = "all", - context: Optional[str] = "all", - stacks: str = "all", - max_entries: int = sys.maxsize, - device: Union[Device, int] = None): + +def _record_memory_history_impl( + enabled: Optional[str] = "all", + context: Optional[str] = "all", + stacks: str = "all", + max_entries: int = sys.maxsize, + device: Union[Device, int] = None, +): if enabled not in ["state", "all", None]: raise TypeError("expected state to be 'state', 'all', or None") if context not in ["state", "all", None]: @@ -692,39 +752,49 @@ def _record_memory_history_impl(enabled: Optional[str] = "all", trace_alloc_record_context = context == "all" record_context_cpp = stacks == "all" with torch.cuda.device(device): - _C._cuda_recordMemoryHistory(enabled_, record_context, record_context_cpp, - trace_alloc_max_entries, trace_alloc_record_context) + _C._cuda_recordMemoryHistory( + enabled_, + record_context, + record_context_cpp, + trace_alloc_max_entries, + trace_alloc_record_context, + ) + def _snapshot(device: Union[Device, int] = None): with torch.cuda.device(device): return _C._cuda_memorySnapshot() -def _dump_snapshot(filename='snapshot_dump', device: Union[Device, int] = None): + +def _dump_snapshot(filename="snapshot_dump", device: Union[Device, int] = None): os.makedirs(filename, exist_ok=True) s = _snapshot(device) - with open(f'{filename}/snapshot.pickle', 'wb') as f: + with open(f"{filename}/snapshot.pickle", "wb") as f: pickle.dump(s, f) - with open(f'{filename}/trace_plot.html', 'w') as f: + with open(f"{filename}/trace_plot.html", "w") as f: f.write(trace_plot(s)) - with open(f'{filename}/segment_plot.html', 'w') as f: + with open(f"{filename}/segment_plot.html", "w") as f: f.write(segment_plot(s)) -def _save_segment_usage(filename='output.svg', snapshot=None): +def _save_segment_usage(filename="output.svg", snapshot=None): if snapshot is None: snapshot = _snapshot() - with open(filename, 'w') as f: + with open(filename, "w") as f: f.write(_segments(snapshot)) -def _save_memory_usage(filename='output.svg', snapshot=None): + +def _save_memory_usage(filename="output.svg", snapshot=None): if snapshot is None: snapshot = _snapshot() - with open(filename, 'w') as f: + with open(filename, "w") as f: f.write(_memory(snapshot)) + def _set_allocator_settings(env: str): return torch._C._cuda_cudaCachingAllocator_set_allocator_settings(env) + def get_allocator_backend() -> str: r"""Returns a string describing the active allocator backend as set by ``PYTORCH_CUDA_ALLOC_CONF``. Currently available backends are @@ -736,9 +806,10 @@ def get_allocator_backend() -> str: """ return torch._C._cuda_getAllocatorBackend() + class _CUDAAllocator: - r"""Wrapper over internal CUDA memory allocators. - """ + r"""Wrapper over internal CUDA memory allocators.""" + def __init__(self, allocator: torch._C._cuda_CUDAAllocator): self._allocator = allocator @@ -769,6 +840,7 @@ class CUDAPluggableAllocator(_CUDAAllocator): .. note:: See :ref:`cuda-memory-management` for details on creating and using a custom allocator """ + def __init__(self, path_to_so_file: str, alloc_fn_name: str, free_fn_name: str): allocator = ctypes.CDLL(path_to_so_file) alloc_fn = ctypes.cast(getattr(allocator, alloc_fn_name), ctypes.c_void_p).value diff --git a/torch/cuda/nccl.py b/torch/cuda/nccl.py index db8ba710c5c7..f04112289cd9 100644 --- a/torch/cuda/nccl.py +++ b/torch/cuda/nccl.py @@ -1,18 +1,18 @@ import collections import warnings - -import torch.cuda from typing import Optional, Sequence, Union +import torch.cuda -__all__ = ['all_reduce', 'reduce', 'broadcast', 'all_gather', 'reduce_scatter'] + +__all__ = ["all_reduce", "reduce", "broadcast", "all_gather", "reduce_scatter"] SUM = 0 # ncclRedOp_t def is_available(tensors): - if not hasattr(torch._C, '_nccl_all_reduce'): - warnings.warn('PyTorch is not compiled with NCCL support') + if not hasattr(torch._C, "_nccl_all_reduce"): + warnings.warn("PyTorch is not compiled with NCCL support") return False devices = set() @@ -48,7 +48,9 @@ def init_rank(num_ranks, uid, rank): def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None: - if not isinstance(inputs, collections.abc.Container) or isinstance(inputs, torch.Tensor): + if not isinstance(inputs, collections.abc.Container) or isinstance( + inputs, torch.Tensor + ): raise TypeError("Inputs should be a collection of tensors") @@ -62,13 +64,16 @@ def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None): # `output` used to be `outputs`, taking in a list of tensors. So we have two # arguments for BC reasons. -def reduce(inputs: Sequence[torch.Tensor], - output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None, - root: int = 0, - op: int = SUM, - streams: Optional[Sequence[torch.cuda.Stream]] = None, - comms=None, *, - outputs: Optional[Sequence[torch.Tensor]] = None) -> None: +def reduce( + inputs: Sequence[torch.Tensor], + output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None, + root: int = 0, + op: int = SUM, + streams: Optional[Sequence[torch.cuda.Stream]] = None, + comms=None, + *, + outputs: Optional[Sequence[torch.Tensor]] = None, +) -> None: _check_sequence_type(inputs) _output: torch.Tensor if outputs is not None: @@ -76,38 +81,53 @@ def reduce(inputs: Sequence[torch.Tensor], raise ValueError( "'output' and 'outputs' can not be both specified. 'outputs' is deprecated in " "favor of 'output', taking in a single output tensor. The signature of reduce is: " - "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None).") + "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None)." + ) else: warnings.warn( "nccl.reduce with an output tensor list is deprecated. " - "Please specify a single output tensor with argument 'output' instead instead.") + "Please specify a single output tensor with argument 'output' instead instead." + ) _output = outputs[root] - elif not isinstance(output, torch.Tensor) and isinstance(output, collections.abc.Sequence): + elif not isinstance(output, torch.Tensor) and isinstance( + output, collections.abc.Sequence + ): # User called old API with positional arguments of list of output tensors. warnings.warn( "nccl.reduce with an output tensor list is deprecated. " - "Please specify a single output tensor.") + "Please specify a single output tensor." + ) _output = output[root] else: _output = inputs[root] if output is None else output torch._C._nccl_reduce(inputs, _output, root, op, streams, comms) -def broadcast(inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None) -> None: +def broadcast( + inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None +) -> None: _check_sequence_type(inputs) torch._C._nccl_broadcast(inputs, root, streams, comms) -def all_gather(inputs: Sequence[torch.Tensor], outputs: Sequence[torch.Tensor], streams=None, comms=None) -> None: +def all_gather( + inputs: Sequence[torch.Tensor], + outputs: Sequence[torch.Tensor], + streams=None, + comms=None, +) -> None: _check_sequence_type(inputs) _check_sequence_type(outputs) torch._C._nccl_all_gather(inputs, outputs, streams, comms) -def reduce_scatter(inputs: Sequence[torch.Tensor], - outputs: Sequence[torch.Tensor], - op: int = SUM, - streams=None, comms=None) -> None: +def reduce_scatter( + inputs: Sequence[torch.Tensor], + outputs: Sequence[torch.Tensor], + op: int = SUM, + streams=None, + comms=None, +) -> None: _check_sequence_type(inputs) _check_sequence_type(outputs) torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms) diff --git a/torch/cuda/nvtx.py b/torch/cuda/nvtx.py index 1ec9e2610895..1bf707f9cdfa 100644 --- a/torch/cuda/nvtx.py +++ b/torch/cuda/nvtx.py @@ -3,10 +3,13 @@ from contextlib import contextmanager try: from torch._C import _nvtx except ImportError: + class _NVTXStub: @staticmethod def _fail(*args, **kwargs): - raise RuntimeError("NVTX functions not installed. Are you sure you have a CUDA build?") + raise RuntimeError( + "NVTX functions not installed. Are you sure you have a CUDA build?" + ) rangePushA = _fail rangePop = _fail diff --git a/torch/cuda/profiler.py b/torch/cuda/profiler.py index 6ea7c65d34cc..51c8aa46f714 100644 --- a/torch/cuda/profiler.py +++ b/torch/cuda/profiler.py @@ -1,7 +1,8 @@ -import tempfile -import torch import contextlib -from . import cudart, check_error +import tempfile + +import torch +from . import check_error, cudart __all__ = ["init", "start", "stop", "profile"] @@ -16,23 +17,29 @@ DEFAULT_FLAGS = [ ] -def init(output_file, flags=None, output_mode='key_value'): +def init(output_file, flags=None, output_mode="key_value"): rt = cudart() - if not hasattr(rt, 'cudaOutputMode'): + if not hasattr(rt, "cudaOutputMode"): raise AssertionError("HIP does not support profiler initialization!") - if hasattr(torch.version, "cuda") and torch.version.cuda is not None and int(torch.version.cuda.split(".")[0]) >= 12: + if ( + hasattr(torch.version, "cuda") + and torch.version.cuda is not None + and int(torch.version.cuda.split(".")[0]) >= 12 + ): # Check https://github.com/pytorch/pytorch/pull/91118 # cudaProfilerInitialize is no longer needed after CUDA 12 raise AssertionError("CUDA12+ does not need profiler initialization!") flags = DEFAULT_FLAGS if flags is None else flags - if output_mode == 'key_value': + if output_mode == "key_value": output_mode_enum = rt.cudaOutputMode.KeyValuePair - elif output_mode == 'csv': + elif output_mode == "csv": output_mode_enum = rt.cudaOutputMode.CSV else: - raise RuntimeError("supported CUDA profiler output modes are: key_value and csv") + raise RuntimeError( + "supported CUDA profiler output modes are: key_value and csv" + ) with tempfile.NamedTemporaryFile(delete=True) as f: - f.write(b'\n'.join(f.encode('ascii') for f in flags)) + f.write(b"\n".join(f.encode("ascii") for f in flags)) f.flush() check_error(rt.cudaProfilerInitialize(f.name, output_file, output_mode_enum)) diff --git a/torch/cuda/random.py b/torch/cuda/random.py index d55f147b2440..5680b8b45bdb 100644 --- a/torch/cuda/random.py +++ b/torch/cuda/random.py @@ -1,15 +1,23 @@ -import torch from typing import Iterable, List, Union -from . import _lazy_init, _lazy_call, device_count, current_device + +import torch from .. import Tensor +from . import _lazy_call, _lazy_init, current_device, device_count -__all__ = ['get_rng_state', 'get_rng_state_all', - 'set_rng_state', 'set_rng_state_all', - 'manual_seed', 'manual_seed_all', - 'seed', 'seed_all', 'initial_seed'] +__all__ = [ + "get_rng_state", + "get_rng_state_all", + "set_rng_state", + "set_rng_state_all", + "manual_seed", + "manual_seed_all", + "seed", + "seed_all", + "initial_seed", +] -def get_rng_state(device: Union[int, str, torch.device] = 'cuda') -> Tensor: +def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor: r"""Returns the random number generator state of the specified GPU as a ByteTensor. Args: @@ -23,7 +31,7 @@ def get_rng_state(device: Union[int, str, torch.device] = 'cuda') -> Tensor: if isinstance(device, str): device = torch.device(device) elif isinstance(device, int): - device = torch.device('cuda', device) + device = torch.device("cuda", device) idx = device.index if idx is None: idx = current_device() @@ -40,7 +48,9 @@ def get_rng_state_all() -> List[Tensor]: return results -def set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'cuda') -> None: +def set_rng_state( + new_state: Tensor, device: Union[int, str, torch.device] = "cuda" +) -> None: r"""Sets the random number generator state of the specified GPU. Args: @@ -53,7 +63,7 @@ def set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'cu if isinstance(device, str): device = torch.device(device) elif isinstance(device, int): - device = torch.device('cuda', device) + device = torch.device("cuda", device) def cb(): idx = device.index @@ -123,6 +133,7 @@ def seed() -> None: If you are working with a multi-GPU model, this function will only initialize the seed on one GPU. To initialize all GPUs, use :func:`seed_all`. """ + def cb(): idx = current_device() default_generator = torch.cuda.default_generators[idx] @@ -136,6 +147,7 @@ def seed_all() -> None: It's safe to call this function if CUDA is not available; in that case, it is silently ignored. """ + def cb(): random_seed = 0 seeded = False diff --git a/torch/cuda/streams.py b/torch/cuda/streams.py index bcb3e1faf40b..87269189e0d5 100644 --- a/torch/cuda/streams.py +++ b/torch/cuda/streams.py @@ -1,13 +1,15 @@ import ctypes + import torch from ._utils import _dummy_type -if not hasattr(torch._C, '_CudaStreamBase'): +if not hasattr(torch._C, "_CudaStreamBase"): # Define dummy base classes - torch._C.__dict__['_CudaStreamBase'] = _dummy_type('_CudaStreamBase') - torch._C.__dict__['_CudaEventBase'] = _dummy_type('_CudaEventBase') + torch._C.__dict__["_CudaStreamBase"] = _dummy_type("_CudaStreamBase") + torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase") + class Stream(torch._C._CudaStreamBase): r"""Wrapper around a CUDA stream. @@ -108,7 +110,7 @@ class Stream(torch._C._CudaStreamBase): return hash((self.cuda_stream, self.device)) def __repr__(self): - return (f'') + return f"" class ExternalStream(Stream): @@ -160,7 +162,10 @@ class Event(torch._C._CudaEventBase): def __new__(cls, enable_timing=False, blocking=False, interprocess=False): return super().__new__( cls, - enable_timing=enable_timing, blocking=blocking, interprocess=interprocess) + enable_timing=enable_timing, + blocking=blocking, + interprocess=interprocess, + ) @classmethod def from_ipc_handle(cls, device, handle): @@ -217,7 +222,7 @@ class Event(torch._C._CudaEventBase): def ipc_handle(self): r"""Returns an IPC handle of this event. If not recorded yet, the event - will use the current device. """ + will use the current device.""" return super().ipc_handle() @property @@ -226,6 +231,6 @@ class Event(torch._C._CudaEventBase): def __repr__(self): if self.cuda_event: - return f'' + return f"" else: - return '' + return "" diff --git a/torch/distributions/__init__.py b/torch/distributions/__init__.py index 71ad4b4fbfba..e4053573a2d7 100644 --- a/torch/distributions/__init__.py +++ b/torch/distributions/__init__.py @@ -90,7 +90,7 @@ from .gumbel import Gumbel from .half_cauchy import HalfCauchy from .half_normal import HalfNormal from .independent import Independent -from .kl import kl_divergence, register_kl, _add_kl_info +from .kl import _add_kl_info, kl_divergence, register_kl from .kumaraswamy import Kumaraswamy from .laplace import Laplace from .lkj_cholesky import LKJCholesky @@ -110,60 +110,60 @@ from .relaxed_categorical import RelaxedOneHotCategorical from .studentT import StudentT from .transformed_distribution import TransformedDistribution from .transforms import * # noqa: F403 +from . import transforms from .uniform import Uniform from .von_mises import VonMises from .weibull import Weibull from .wishart import Wishart -from . import transforms _add_kl_info() del _add_kl_info __all__ = [ - 'Bernoulli', - 'Beta', - 'Binomial', - 'Categorical', - 'Cauchy', - 'Chi2', - 'ContinuousBernoulli', - 'Dirichlet', - 'Distribution', - 'Exponential', - 'ExponentialFamily', - 'FisherSnedecor', - 'Gamma', - 'Geometric', - 'Gumbel', - 'HalfCauchy', - 'HalfNormal', - 'Independent', - 'Kumaraswamy', - 'LKJCholesky', - 'Laplace', - 'LogNormal', - 'LogisticNormal', - 'LowRankMultivariateNormal', - 'MixtureSameFamily', - 'Multinomial', - 'MultivariateNormal', - 'NegativeBinomial', - 'Normal', - 'OneHotCategorical', - 'OneHotCategoricalStraightThrough', - 'Pareto', - 'RelaxedBernoulli', - 'RelaxedOneHotCategorical', - 'StudentT', - 'Poisson', - 'Uniform', - 'VonMises', - 'Weibull', - 'Wishart', - 'TransformedDistribution', - 'biject_to', - 'kl_divergence', - 'register_kl', - 'transform_to', + "Bernoulli", + "Beta", + "Binomial", + "Categorical", + "Cauchy", + "Chi2", + "ContinuousBernoulli", + "Dirichlet", + "Distribution", + "Exponential", + "ExponentialFamily", + "FisherSnedecor", + "Gamma", + "Geometric", + "Gumbel", + "HalfCauchy", + "HalfNormal", + "Independent", + "Kumaraswamy", + "LKJCholesky", + "Laplace", + "LogNormal", + "LogisticNormal", + "LowRankMultivariateNormal", + "MixtureSameFamily", + "Multinomial", + "MultivariateNormal", + "NegativeBinomial", + "Normal", + "OneHotCategorical", + "OneHotCategoricalStraightThrough", + "Pareto", + "RelaxedBernoulli", + "RelaxedOneHotCategorical", + "StudentT", + "Poisson", + "Uniform", + "VonMises", + "Weibull", + "Wishart", + "TransformedDistribution", + "biject_to", + "kl_divergence", + "register_kl", + "transform_to", ] __all__.extend(transforms.__all__) diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py index 9d9b0fd7b8c9..8f021ea6677d 100644 --- a/torch/distributions/bernoulli.py +++ b/torch/distributions/bernoulli.py @@ -4,10 +4,16 @@ import torch from torch import nan from torch.distributions import constraints from torch.distributions.exp_family import ExponentialFamily -from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property +from torch.distributions.utils import ( + broadcast_all, + lazy_property, + logits_to_probs, + probs_to_logits, +) from torch.nn.functional import binary_cross_entropy_with_logits -__all__ = ['Bernoulli'] +__all__ = ["Bernoulli"] + class Bernoulli(ExponentialFamily): r""" @@ -28,21 +34,22 @@ class Bernoulli(ExponentialFamily): probs (Number, Tensor): the probability of sampling `1` logits (Number, Tensor): the log-odds of sampling `1` """ - arg_constraints = {'probs': constraints.unit_interval, - 'logits': constraints.real} + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.boolean has_enumerate_support = True _mean_carrier_measure = 0 def __init__(self, probs=None, logits=None, validate_args=None): if (probs is None) == (logits is None): - raise ValueError("Either `probs` or `logits` must be specified, but not both.") + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) if probs is not None: is_scalar = isinstance(probs, Number) - self.probs, = broadcast_all(probs) + (self.probs,) = broadcast_all(probs) else: is_scalar = isinstance(logits, Number) - self.logits, = broadcast_all(logits) + (self.logits,) = broadcast_all(logits) self._param = self.probs if probs is not None else self.logits if is_scalar: batch_shape = torch.Size() @@ -53,10 +60,10 @@ class Bernoulli(ExponentialFamily): def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(Bernoulli, _instance) batch_shape = torch.Size(batch_shape) - if 'probs' in self.__dict__: + if "probs" in self.__dict__: new.probs = self.probs.expand(batch_shape) new._param = new.probs - if 'logits' in self.__dict__: + if "logits" in self.__dict__: new.logits = self.logits.expand(batch_shape) new._param = new.logits super(Bernoulli, new).__init__(batch_shape, validate_args=False) @@ -101,10 +108,12 @@ class Bernoulli(ExponentialFamily): if self._validate_args: self._validate_sample(value) logits, value = broadcast_all(self.logits, value) - return -binary_cross_entropy_with_logits(logits, value, reduction='none') + return -binary_cross_entropy_with_logits(logits, value, reduction="none") def entropy(self): - return binary_cross_entropy_with_logits(self.logits, self.probs, reduction='none') + return binary_cross_entropy_with_logits( + self.logits, self.probs, reduction="none" + ) def enumerate_support(self, expand=True): values = torch.arange(2, dtype=self._param.dtype, device=self._param.device) @@ -115,7 +124,7 @@ class Bernoulli(ExponentialFamily): @property def _natural_params(self): - return (torch.logit(self.probs), ) + return (torch.logit(self.probs),) def _log_normalizer(self, x): return torch.log1p(torch.exp(x)) diff --git a/torch/distributions/beta.py b/torch/distributions/beta.py index dd6ed437c1e5..982cfe838c82 100644 --- a/torch/distributions/beta.py +++ b/torch/distributions/beta.py @@ -1,4 +1,4 @@ -from numbers import Real, Number +from numbers import Number, Real import torch from torch.distributions import constraints @@ -6,7 +6,8 @@ from torch.distributions.dirichlet import Dirichlet from torch.distributions.exp_family import ExponentialFamily from torch.distributions.utils import broadcast_all -__all__ = ['Beta'] +__all__ = ["Beta"] + class Beta(ExponentialFamily): r""" @@ -25,17 +26,28 @@ class Beta(ExponentialFamily): concentration0 (float or Tensor): 2nd concentration parameter of the distribution (often referred to as beta) """ - arg_constraints = {'concentration1': constraints.positive, 'concentration0': constraints.positive} + arg_constraints = { + "concentration1": constraints.positive, + "concentration0": constraints.positive, + } support = constraints.unit_interval has_rsample = True def __init__(self, concentration1, concentration0, validate_args=None): if isinstance(concentration1, Real) and isinstance(concentration0, Real): - concentration1_concentration0 = torch.tensor([float(concentration1), float(concentration0)]) + concentration1_concentration0 = torch.tensor( + [float(concentration1), float(concentration0)] + ) else: - concentration1, concentration0 = broadcast_all(concentration1, concentration0) - concentration1_concentration0 = torch.stack([concentration1, concentration0], -1) - self._dirichlet = Dirichlet(concentration1_concentration0, validate_args=validate_args) + concentration1, concentration0 = broadcast_all( + concentration1, concentration0 + ) + concentration1_concentration0 = torch.stack( + [concentration1, concentration0], -1 + ) + self._dirichlet = Dirichlet( + concentration1_concentration0, validate_args=validate_args + ) super().__init__(self._dirichlet._batch_shape, validate_args=validate_args) def expand(self, batch_shape, _instance=None): @@ -57,8 +69,7 @@ class Beta(ExponentialFamily): @property def variance(self): total = self.concentration1 + self.concentration0 - return (self.concentration1 * self.concentration0 / - (total.pow(2) * (total + 1))) + return self.concentration1 * self.concentration0 / (total.pow(2) * (total + 1)) def rsample(self, sample_shape=()): return self._dirichlet.rsample(sample_shape).select(-1, 0) diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py index c4d33ca8a4c4..ba8bb4ccc166 100644 --- a/torch/distributions/binomial.py +++ b/torch/distributions/binomial.py @@ -1,9 +1,15 @@ import torch from torch.distributions import constraints from torch.distributions.distribution import Distribution -from torch.distributions.utils import broadcast_all, probs_to_logits, lazy_property, logits_to_probs +from torch.distributions.utils import ( + broadcast_all, + lazy_property, + logits_to_probs, + probs_to_logits, +) + +__all__ = ["Binomial"] -__all__ = ['Binomial'] def _clamp_by_zero(x): # works like clamp(x, min=0) but has grad at 0 is 0.5 @@ -33,19 +39,29 @@ class Binomial(Distribution): probs (Tensor): Event probabilities logits (Tensor): Event log-odds """ - arg_constraints = {'total_count': constraints.nonnegative_integer, - 'probs': constraints.unit_interval, - 'logits': constraints.real} + arg_constraints = { + "total_count": constraints.nonnegative_integer, + "probs": constraints.unit_interval, + "logits": constraints.real, + } has_enumerate_support = True def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): if (probs is None) == (logits is None): - raise ValueError("Either `probs` or `logits` must be specified, but not both.") + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) if probs is not None: - self.total_count, self.probs, = broadcast_all(total_count, probs) + ( + self.total_count, + self.probs, + ) = broadcast_all(total_count, probs) self.total_count = self.total_count.type_as(self.probs) else: - self.total_count, self.logits, = broadcast_all(total_count, logits) + ( + self.total_count, + self.logits, + ) = broadcast_all(total_count, logits) self.total_count = self.total_count.type_as(self.logits) self._param = self.probs if probs is not None else self.logits @@ -56,10 +72,10 @@ class Binomial(Distribution): new = self._get_checked_instance(Binomial, _instance) batch_shape = torch.Size(batch_shape) new.total_count = self.total_count.expand(batch_shape) - if 'probs' in self.__dict__: + if "probs" in self.__dict__: new.probs = self.probs.expand(batch_shape) new._param = new.probs - if 'logits' in self.__dict__: + if "logits" in self.__dict__: new.logits = self.logits.expand(batch_shape) new._param = new.logits super(Binomial, new).__init__(batch_shape, validate_args=False) @@ -100,7 +116,9 @@ class Binomial(Distribution): def sample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) with torch.no_grad(): - return torch.binomial(self.total_count.expand(shape), self.probs.expand(shape)) + return torch.binomial( + self.total_count.expand(shape), self.probs.expand(shape) + ) def log_prob(self, value): if self._validate_args: @@ -113,15 +131,21 @@ class Binomial(Distribution): # (case logit > 0) = k * logit - n * (log(p) - log(1 - p)) + n * log(p) # = k * logit - n * logit - n * log1p(e^-logit) # (merge two cases) = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|) - normalize_term = (self.total_count * _clamp_by_zero(self.logits) - + self.total_count * torch.log1p(torch.exp(-torch.abs(self.logits))) - - log_factorial_n) - return value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term + normalize_term = ( + self.total_count * _clamp_by_zero(self.logits) + + self.total_count * torch.log1p(torch.exp(-torch.abs(self.logits))) + - log_factorial_n + ) + return ( + value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term + ) def entropy(self): total_count = int(self.total_count.max()) if not self.total_count.min() == total_count: - raise NotImplementedError("Inhomogeneous total count not supported by `entropy`.") + raise NotImplementedError( + "Inhomogeneous total count not supported by `entropy`." + ) log_prob = self.log_prob(self.enumerate_support(False)) return -(torch.exp(log_prob) * log_prob).sum(0) @@ -129,8 +153,12 @@ class Binomial(Distribution): def enumerate_support(self, expand=True): total_count = int(self.total_count.max()) if not self.total_count.min() == total_count: - raise NotImplementedError("Inhomogeneous total count not supported by `enumerate_support`.") - values = torch.arange(1 + total_count, dtype=self._param.dtype, device=self._param.device) + raise NotImplementedError( + "Inhomogeneous total count not supported by `enumerate_support`." + ) + values = torch.arange( + 1 + total_count, dtype=self._param.dtype, device=self._param.device + ) values = values.view((-1,) + (1,) * len(self._batch_shape)) if expand: values = values.expand((-1,) + self._batch_shape) diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py index 7cff0e4ee35a..5f010dd49878 100644 --- a/torch/distributions/categorical.py +++ b/torch/distributions/categorical.py @@ -2,9 +2,10 @@ import torch from torch import nan from torch.distributions import constraints from torch.distributions.distribution import Distribution -from torch.distributions.utils import probs_to_logits, logits_to_probs, lazy_property +from torch.distributions.utils import lazy_property, logits_to_probs, probs_to_logits + +__all__ = ["Categorical"] -__all__ = ['Categorical'] class Categorical(Distribution): r""" @@ -44,13 +45,14 @@ class Categorical(Distribution): probs (Tensor): event probabilities logits (Tensor): event log probabilities (unnormalized) """ - arg_constraints = {'probs': constraints.simplex, - 'logits': constraints.real_vector} + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} has_enumerate_support = True def __init__(self, probs=None, logits=None, validate_args=None): if (probs is None) == (logits is None): - raise ValueError("Either `probs` or `logits` must be specified, but not both.") + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) if probs is not None: if probs.dim() < 1: raise ValueError("`probs` parameter must be at least one-dimensional.") @@ -62,17 +64,19 @@ class Categorical(Distribution): self.logits = logits - logits.logsumexp(dim=-1, keepdim=True) self._param = self.probs if probs is not None else self.logits self._num_events = self._param.size()[-1] - batch_shape = self._param.size()[:-1] if self._param.ndimension() > 1 else torch.Size() + batch_shape = ( + self._param.size()[:-1] if self._param.ndimension() > 1 else torch.Size() + ) super().__init__(batch_shape, validate_args=validate_args) def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(Categorical, _instance) batch_shape = torch.Size(batch_shape) param_shape = batch_shape + torch.Size((self._num_events,)) - if 'probs' in self.__dict__: + if "probs" in self.__dict__: new.probs = self.probs.expand(param_shape) new._param = new.probs - if 'logits' in self.__dict__: + if "logits" in self.__dict__: new.logits = self.logits.expand(param_shape) new._param = new.logits new._num_events = self._num_events @@ -101,7 +105,12 @@ class Categorical(Distribution): @property def mean(self): - return torch.full(self._extended_shape(), nan, dtype=self.probs.dtype, device=self.probs.device) + return torch.full( + self._extended_shape(), + nan, + dtype=self.probs.dtype, + device=self.probs.device, + ) @property def mode(self): @@ -109,7 +118,12 @@ class Categorical(Distribution): @property def variance(self): - return torch.full(self._extended_shape(), nan, dtype=self.probs.dtype, device=self.probs.device) + return torch.full( + self._extended_shape(), + nan, + dtype=self.probs.dtype, + device=self.probs.device, + ) def sample(self, sample_shape=torch.Size()): if not isinstance(sample_shape, torch.Size): diff --git a/torch/distributions/cauchy.py b/torch/distributions/cauchy.py index 2ef0fb95aa82..84f749e7ab3d 100644 --- a/torch/distributions/cauchy.py +++ b/torch/distributions/cauchy.py @@ -1,13 +1,14 @@ import math -from torch import inf, nan from numbers import Number import torch +from torch import inf, nan from torch.distributions import constraints from torch.distributions.distribution import Distribution from torch.distributions.utils import broadcast_all -__all__ = ['Cauchy'] +__all__ = ["Cauchy"] + class Cauchy(Distribution): r""" @@ -26,7 +27,7 @@ class Cauchy(Distribution): loc (float or Tensor): mode or median of the distribution. scale (float or Tensor): half width at half maximum. """ - arg_constraints = {'loc': constraints.real, 'scale': constraints.positive} + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.real has_rsample = True @@ -49,7 +50,9 @@ class Cauchy(Distribution): @property def mean(self): - return torch.full(self._extended_shape(), nan, dtype=self.loc.dtype, device=self.loc.device) + return torch.full( + self._extended_shape(), nan, dtype=self.loc.dtype, device=self.loc.device + ) @property def mode(self): @@ -57,7 +60,9 @@ class Cauchy(Distribution): @property def variance(self): - return torch.full(self._extended_shape(), inf, dtype=self.loc.dtype, device=self.loc.device) + return torch.full( + self._extended_shape(), inf, dtype=self.loc.dtype, device=self.loc.device + ) def rsample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) @@ -67,7 +72,11 @@ class Cauchy(Distribution): def log_prob(self, value): if self._validate_args: self._validate_sample(value) - return -math.log(math.pi) - self.scale.log() - (((value - self.loc) / self.scale)**2).log1p() + return ( + -math.log(math.pi) + - self.scale.log() + - (((value - self.loc) / self.scale) ** 2).log1p() + ) def cdf(self, value): if self._validate_args: diff --git a/torch/distributions/chi2.py b/torch/distributions/chi2.py index 4394a078832f..e078923b548e 100644 --- a/torch/distributions/chi2.py +++ b/torch/distributions/chi2.py @@ -1,7 +1,8 @@ from torch.distributions import constraints from torch.distributions.gamma import Gamma -__all__ = ['Chi2'] +__all__ = ["Chi2"] + class Chi2(Gamma): r""" @@ -18,7 +19,7 @@ class Chi2(Gamma): Args: df (float or Tensor): shape parameter of the distribution """ - arg_constraints = {'df': constraints.positive} + arg_constraints = {"df": constraints.positive} def __init__(self, df, validate_args=None): super().__init__(0.5 * df, 0.5, validate_args=validate_args) diff --git a/torch/distributions/constraint_registry.py b/torch/distributions/constraint_registry.py index 88497fcfbce6..83192f69547f 100644 --- a/torch/distributions/constraint_registry.py +++ b/torch/distributions/constraint_registry.py @@ -70,9 +70,9 @@ import numbers from torch.distributions import constraints, transforms __all__ = [ - 'ConstraintRegistry', - 'biject_to', - 'transform_to', + "ConstraintRegistry", + "biject_to", + "transform_to", ] @@ -80,6 +80,7 @@ class ConstraintRegistry: """ Registry to link constraints to transforms. """ + def __init__(self): self._registry = {} super().__init__() @@ -109,8 +110,12 @@ class ConstraintRegistry: if isinstance(constraint, constraints.Constraint): constraint = type(constraint) - if not isinstance(constraint, type) or not issubclass(constraint, constraints.Constraint): - raise TypeError(f'Expected constraint to be either a Constraint subclass or instance, but got {constraint}') + if not isinstance(constraint, type) or not issubclass( + constraint, constraints.Constraint + ): + raise TypeError( + f"Expected constraint to be either a Constraint subclass or instance, but got {constraint}" + ) self._registry[constraint] = factory return factory @@ -139,7 +144,8 @@ class ConstraintRegistry: factory = self._registry[type(constraint)] except KeyError: raise NotImplementedError( - f'Cannot transform {type(constraint).__name__} constraints') from None + f"Cannot transform {type(constraint).__name__} constraints" + ) from None return factory(constraint) @@ -151,6 +157,7 @@ transform_to = ConstraintRegistry() # Registration Table ################################################################################ + @biject_to.register(constraints.real) @transform_to.register(constraints.real) def _transform_to_real(constraint): @@ -161,14 +168,16 @@ def _transform_to_real(constraint): def _biject_to_independent(constraint): base_transform = biject_to(constraint.base_constraint) return transforms.IndependentTransform( - base_transform, constraint.reinterpreted_batch_ndims) + base_transform, constraint.reinterpreted_batch_ndims + ) @transform_to.register(constraints.independent) def _transform_to_independent(constraint): base_transform = transform_to(constraint.base_constraint) return transforms.IndependentTransform( - base_transform, constraint.reinterpreted_batch_ndims) + base_transform, constraint.reinterpreted_batch_ndims + ) @biject_to.register(constraints.positive) @@ -184,15 +193,23 @@ def _transform_to_positive(constraint): @transform_to.register(constraints.greater_than) @transform_to.register(constraints.greater_than_eq) def _transform_to_greater_than(constraint): - return transforms.ComposeTransform([transforms.ExpTransform(), - transforms.AffineTransform(constraint.lower_bound, 1)]) + return transforms.ComposeTransform( + [ + transforms.ExpTransform(), + transforms.AffineTransform(constraint.lower_bound, 1), + ] + ) @biject_to.register(constraints.less_than) @transform_to.register(constraints.less_than) def _transform_to_less_than(constraint): - return transforms.ComposeTransform([transforms.ExpTransform(), - transforms.AffineTransform(constraint.upper_bound, -1)]) + return transforms.ComposeTransform( + [ + transforms.ExpTransform(), + transforms.AffineTransform(constraint.upper_bound, -1), + ] + ) @biject_to.register(constraints.interval) @@ -201,15 +218,22 @@ def _transform_to_less_than(constraint): @transform_to.register(constraints.half_open_interval) def _transform_to_interval(constraint): # Handle the special case of the unit interval. - lower_is_0 = isinstance(constraint.lower_bound, numbers.Number) and constraint.lower_bound == 0 - upper_is_1 = isinstance(constraint.upper_bound, numbers.Number) and constraint.upper_bound == 1 + lower_is_0 = ( + isinstance(constraint.lower_bound, numbers.Number) + and constraint.lower_bound == 0 + ) + upper_is_1 = ( + isinstance(constraint.upper_bound, numbers.Number) + and constraint.upper_bound == 1 + ) if lower_is_0 and upper_is_1: return transforms.SigmoidTransform() loc = constraint.lower_bound scale = constraint.upper_bound - constraint.lower_bound - return transforms.ComposeTransform([transforms.SigmoidTransform(), - transforms.AffineTransform(loc, scale)]) + return transforms.ComposeTransform( + [transforms.SigmoidTransform(), transforms.AffineTransform(loc, scale)] + ) @biject_to.register(constraints.simplex) @@ -242,29 +266,27 @@ def _transform_to_corr_cholesky(constraint): @biject_to.register(constraints.cat) def _biject_to_cat(constraint): - return transforms.CatTransform([biject_to(c) - for c in constraint.cseq], - constraint.dim, - constraint.lengths) + return transforms.CatTransform( + [biject_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths + ) @transform_to.register(constraints.cat) def _transform_to_cat(constraint): - return transforms.CatTransform([transform_to(c) - for c in constraint.cseq], - constraint.dim, - constraint.lengths) + return transforms.CatTransform( + [transform_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths + ) @biject_to.register(constraints.stack) def _biject_to_stack(constraint): return transforms.StackTransform( - [biject_to(c) - for c in constraint.cseq], constraint.dim) + [biject_to(c) for c in constraint.cseq], constraint.dim + ) @transform_to.register(constraints.stack) def _transform_to_stack(constraint): return transforms.StackTransform( - [transform_to(c) - for c in constraint.cseq], constraint.dim) + [transform_to(c) for c in constraint.cseq], constraint.dim + ) diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index 28d61a5de8cc..4ae81c1033c4 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -33,35 +33,35 @@ The following constraints are implemented: import torch __all__ = [ - 'Constraint', - 'boolean', - 'cat', - 'corr_cholesky', - 'dependent', - 'dependent_property', - 'greater_than', - 'greater_than_eq', - 'independent', - 'integer_interval', - 'interval', - 'half_open_interval', - 'is_dependent', - 'less_than', - 'lower_cholesky', - 'lower_triangular', - 'multinomial', - 'nonnegative_integer', - 'positive', - 'positive_semidefinite', - 'positive_definite', - 'positive_integer', - 'real', - 'real_vector', - 'simplex', - 'square', - 'stack', - 'symmetric', - 'unit_interval', + "Constraint", + "boolean", + "cat", + "corr_cholesky", + "dependent", + "dependent_property", + "greater_than", + "greater_than_eq", + "independent", + "integer_interval", + "interval", + "half_open_interval", + "is_dependent", + "less_than", + "lower_cholesky", + "lower_triangular", + "multinomial", + "nonnegative_integer", + "positive", + "positive_semidefinite", + "positive_definite", + "positive_integer", + "real", + "real_vector", + "simplex", + "square", + "stack", + "symmetric", + "unit_interval", ] @@ -79,6 +79,7 @@ class Constraint: an event. The :meth:`check` method will remove this many dimensions when computing validity. """ + is_discrete = False # Default to continuous. event_dim = 0 # Default to univariate. @@ -90,7 +91,7 @@ class Constraint: raise NotImplementedError def __repr__(self): - return self.__class__.__name__[1:] + '()' + return self.__class__.__name__[1:] + "()" class _Dependent(Constraint): @@ -106,6 +107,7 @@ class _Dependent(Constraint): can be computed statically. If not provided, access to the ``.event_dim`` attribute will raise a NotImplementedError. """ + def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented): self._is_discrete = is_discrete self._event_dim = event_dim @@ -136,7 +138,7 @@ class _Dependent(Constraint): return _Dependent(is_discrete=is_discrete, event_dim=event_dim) def check(self, x): - raise ValueError('Cannot determine validity of dependent constraint') + raise ValueError("Cannot determine validity of dependent constraint") def is_dependent(constraint): @@ -167,7 +169,10 @@ class _DependentProperty(property, _Dependent): can be computed statically. If not provided, access to the ``.event_dim`` attribute will raise a NotImplementedError. """ - def __init__(self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented): + + def __init__( + self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented + ): super().__init__(fn) self._is_discrete = is_discrete self._event_dim = event_dim @@ -180,7 +185,9 @@ class _DependentProperty(property, _Dependent): def support(self): ... """ - return _DependentProperty(fn, is_discrete=self._is_discrete, event_dim=self._event_dim) + return _DependentProperty( + fn, is_discrete=self._is_discrete, event_dim=self._event_dim + ) class _IndependentConstraint(Constraint): @@ -189,6 +196,7 @@ class _IndependentConstraint(Constraint): dims in :meth:`check`, so that an event is valid only if all its independent entries are valid. """ + def __init__(self, base_constraint, reinterpreted_batch_ndims): assert isinstance(base_constraint, Constraint) assert isinstance(reinterpreted_batch_ndims, int) @@ -209,8 +217,12 @@ class _IndependentConstraint(Constraint): result = self.base_constraint.check(value) if result.dim() < self.reinterpreted_batch_ndims: expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims - raise ValueError(f"Expected value.dim() >= {expected} but got {value.dim()}") - result = result.reshape(result.shape[:result.dim() - self.reinterpreted_batch_ndims] + (-1,)) + raise ValueError( + f"Expected value.dim() >= {expected} but got {value.dim()}" + ) + result = result.reshape( + result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,) + ) result = result.all(-1) return result @@ -222,6 +234,7 @@ class _Boolean(Constraint): """ Constrain to the two values `{0, 1}`. """ + is_discrete = True def check(self, value): @@ -232,6 +245,7 @@ class _OneHot(Constraint): """ Constrain to one-hot vectors. """ + is_discrete = True event_dim = 1 @@ -245,6 +259,7 @@ class _IntegerInterval(Constraint): """ Constrain to an integer interval `[lower_bound, upper_bound]`. """ + is_discrete = True def __init__(self, lower_bound, upper_bound): @@ -253,11 +268,15 @@ class _IntegerInterval(Constraint): super().__init__() def check(self, value): - return (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound) + return ( + (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound) + ) def __repr__(self): fmt_string = self.__class__.__name__[1:] - fmt_string += f'(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})' + fmt_string += ( + f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})" + ) return fmt_string @@ -265,6 +284,7 @@ class _IntegerLessThan(Constraint): """ Constrain to an integer interval `(-inf, upper_bound]`. """ + is_discrete = True def __init__(self, upper_bound): @@ -276,7 +296,7 @@ class _IntegerLessThan(Constraint): def __repr__(self): fmt_string = self.__class__.__name__[1:] - fmt_string += f'(upper_bound={self.upper_bound})' + fmt_string += f"(upper_bound={self.upper_bound})" return fmt_string @@ -284,6 +304,7 @@ class _IntegerGreaterThan(Constraint): """ Constrain to an integer interval `[lower_bound, inf)`. """ + is_discrete = True def __init__(self, lower_bound): @@ -295,7 +316,7 @@ class _IntegerGreaterThan(Constraint): def __repr__(self): fmt_string = self.__class__.__name__[1:] - fmt_string += f'(lower_bound={self.lower_bound})' + fmt_string += f"(lower_bound={self.lower_bound})" return fmt_string @@ -303,6 +324,7 @@ class _Real(Constraint): """ Trivially constrain to the extended real line `[-inf, inf]`. """ + def check(self, value): return value == value # False for NANs. @@ -311,6 +333,7 @@ class _GreaterThan(Constraint): """ Constrain to a real half line `(lower_bound, inf]`. """ + def __init__(self, lower_bound): self.lower_bound = lower_bound super().__init__() @@ -320,7 +343,7 @@ class _GreaterThan(Constraint): def __repr__(self): fmt_string = self.__class__.__name__[1:] - fmt_string += f'(lower_bound={self.lower_bound})' + fmt_string += f"(lower_bound={self.lower_bound})" return fmt_string @@ -328,6 +351,7 @@ class _GreaterThanEq(Constraint): """ Constrain to a real half line `[lower_bound, inf)`. """ + def __init__(self, lower_bound): self.lower_bound = lower_bound super().__init__() @@ -337,7 +361,7 @@ class _GreaterThanEq(Constraint): def __repr__(self): fmt_string = self.__class__.__name__[1:] - fmt_string += f'(lower_bound={self.lower_bound})' + fmt_string += f"(lower_bound={self.lower_bound})" return fmt_string @@ -345,6 +369,7 @@ class _LessThan(Constraint): """ Constrain to a real half line `[-inf, upper_bound)`. """ + def __init__(self, upper_bound): self.upper_bound = upper_bound super().__init__() @@ -354,7 +379,7 @@ class _LessThan(Constraint): def __repr__(self): fmt_string = self.__class__.__name__[1:] - fmt_string += f'(upper_bound={self.upper_bound})' + fmt_string += f"(upper_bound={self.upper_bound})" return fmt_string @@ -362,6 +387,7 @@ class _Interval(Constraint): """ Constrain to a real interval `[lower_bound, upper_bound]`. """ + def __init__(self, lower_bound, upper_bound): self.lower_bound = lower_bound self.upper_bound = upper_bound @@ -372,7 +398,9 @@ class _Interval(Constraint): def __repr__(self): fmt_string = self.__class__.__name__[1:] - fmt_string += f'(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})' + fmt_string += ( + f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})" + ) return fmt_string @@ -380,6 +408,7 @@ class _HalfOpenInterval(Constraint): """ Constrain to a real interval `[lower_bound, upper_bound)`. """ + def __init__(self, lower_bound, upper_bound): self.lower_bound = lower_bound self.upper_bound = upper_bound @@ -390,7 +419,9 @@ class _HalfOpenInterval(Constraint): def __repr__(self): fmt_string = self.__class__.__name__[1:] - fmt_string += f'(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})' + fmt_string += ( + f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})" + ) return fmt_string @@ -399,6 +430,7 @@ class _Simplex(Constraint): Constrain to the unit simplex in the innermost (rightmost) dimension. Specifically: `x >= 0` and `x.sum(-1) == 1`. """ + event_dim = 1 def check(self, value): @@ -413,6 +445,7 @@ class _Multinomial(Constraint): checks the weaker condition ``value.sum(-1) <= upper_bound``. In the future this may be strengthened to ``value.sum(-1) == upper_bound``. """ + is_discrete = True event_dim = 1 @@ -427,6 +460,7 @@ class _LowerTriangular(Constraint): """ Constrain to lower-triangular square matrices. """ + event_dim = 2 def check(self, value): @@ -438,11 +472,14 @@ class _LowerCholesky(Constraint): """ Constrain to lower-triangular square matrices with positive diagonals. """ + event_dim = 2 def check(self, value): value_tril = value.tril() - lower_triangular = (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] + lower_triangular = ( + (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] + ) positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).min(-1)[0] return lower_triangular & positive_diagonal @@ -453,12 +490,15 @@ class _CorrCholesky(Constraint): Constrain to lower-triangular square matrices with positive diagonals and each row vector being of unit length. """ + event_dim = 2 def check(self, value): - tol = torch.finfo(value.dtype).eps * value.size(-1) * 10 # 10 is an adjustable fudge factor + tol = ( + torch.finfo(value.dtype).eps * value.size(-1) * 10 + ) # 10 is an adjustable fudge factor row_norm = torch.linalg.norm(value.detach(), dim=-1) - unit_row_norm = (row_norm - 1.).abs().le(tol).all(dim=-1) + unit_row_norm = (row_norm - 1.0).abs().le(tol).all(dim=-1) return _LowerCholesky().check(value) & unit_row_norm @@ -466,6 +506,7 @@ class _Square(Constraint): """ Constrain to square matrices. """ + event_dim = 2 def check(self, value): @@ -473,7 +514,7 @@ class _Square(Constraint): size=value.shape[:-2], fill_value=(value.shape[-2] == value.shape[-1]), dtype=torch.bool, - device=value.device + device=value.device, ) @@ -493,6 +534,7 @@ class _PositiveSemidefinite(_Symmetric): """ Constrain to positive-semidefinite matrices. """ + def check(self, value): sym_check = super().check(value) if not sym_check.all(): @@ -504,6 +546,7 @@ class _PositiveDefinite(_Symmetric): """ Constrain to positive-definite matrices. """ + def check(self, value): sym_check = super().check(value) if not sym_check.all(): @@ -517,6 +560,7 @@ class _Cat(Constraint): `cseq` at the submatrices at dimension `dim`, each of size `lengths[dim]`, in a way compatible with :func:`torch.cat`. """ + def __init__(self, cseq, dim=0, lengths=None): assert all(isinstance(c, Constraint) for c in cseq) self.cseq = list(cseq) @@ -552,6 +596,7 @@ class _Stack(Constraint): `cseq` at the submatrices at dimension `dim`, in a way compatible with :func:`torch.stack`. """ + def __init__(self, cseq, dim=0): assert all(isinstance(c, Constraint) for c in cseq) self.cseq = list(cseq) @@ -572,8 +617,9 @@ class _Stack(Constraint): def check(self, value): assert -value.dim() <= self.dim < value.dim() vs = [value.select(self.dim, i) for i in range(value.size(self.dim))] - return torch.stack([constr.check(v) - for v, constr in zip(vs, self.cseq)], self.dim) + return torch.stack( + [constr.check(v) for v, constr in zip(vs, self.cseq)], self.dim + ) # Public interface. @@ -587,13 +633,13 @@ positive_integer = _IntegerGreaterThan(1) integer_interval = _IntegerInterval real = _Real() real_vector = independent(real, 1) -positive = _GreaterThan(0.) -nonnegative = _GreaterThanEq(0.) +positive = _GreaterThan(0.0) +nonnegative = _GreaterThanEq(0.0) greater_than = _GreaterThan greater_than_eq = _GreaterThanEq less_than = _LessThan multinomial = _Multinomial -unit_interval = _Interval(0., 1.) +unit_interval = _Interval(0.0, 1.0) interval = _Interval half_open_interval = _HalfOpenInterval simplex = _Simplex() diff --git a/torch/distributions/continuous_bernoulli.py b/torch/distributions/continuous_bernoulli.py index d14048566935..bcd879542840 100644 --- a/torch/distributions/continuous_bernoulli.py +++ b/torch/distributions/continuous_bernoulli.py @@ -1,13 +1,20 @@ -from numbers import Number import math +from numbers import Number import torch from torch.distributions import constraints from torch.distributions.exp_family import ExponentialFamily -from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property, clamp_probs +from torch.distributions.utils import ( + broadcast_all, + clamp_probs, + lazy_property, + logits_to_probs, + probs_to_logits, +) from torch.nn.functional import binary_cross_entropy_with_logits -__all__ = ['ContinuousBernoulli'] +__all__ = ["ContinuousBernoulli"] + class ContinuousBernoulli(ExponentialFamily): r""" @@ -35,27 +42,30 @@ class ContinuousBernoulli(ExponentialFamily): autoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019. https://arxiv.org/abs/1907.06845 """ - arg_constraints = {'probs': constraints.unit_interval, - 'logits': constraints.real} + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.unit_interval _mean_carrier_measure = 0 has_rsample = True - def __init__(self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None): + def __init__( + self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None + ): if (probs is None) == (logits is None): - raise ValueError("Either `probs` or `logits` must be specified, but not both.") + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) if probs is not None: is_scalar = isinstance(probs, Number) - self.probs, = broadcast_all(probs) + (self.probs,) = broadcast_all(probs) # validate 'probs' here if necessary as it is later clamped for numerical stability # close to 0 and 1, later on; otherwise the clamped 'probs' would always pass if validate_args is not None: - if not self.arg_constraints['probs'].check(self.probs).all(): + if not self.arg_constraints["probs"].check(self.probs).all(): raise ValueError("The parameter probs has invalid values") self.probs = clamp_probs(self.probs) else: is_scalar = isinstance(logits, Number) - self.logits, = broadcast_all(logits) + (self.logits,) = broadcast_all(logits) self._param = self.probs if probs is not None else self.logits if is_scalar: batch_shape = torch.Size() @@ -68,10 +78,10 @@ class ContinuousBernoulli(ExponentialFamily): new = self._get_checked_instance(ContinuousBernoulli, _instance) new._lims = self._lims batch_shape = torch.Size(batch_shape) - if 'probs' in self.__dict__: + if "probs" in self.__dict__: new.probs = self.probs.expand(batch_shape) new._param = new.probs - if 'logits' in self.__dict__: + if "logits" in self.__dict__: new.logits = self.logits.expand(batch_shape) new._param = new.logits super(ContinuousBernoulli, new).__init__(batch_shape, validate_args=False) @@ -82,27 +92,33 @@ class ContinuousBernoulli(ExponentialFamily): return self._param.new(*args, **kwargs) def _outside_unstable_region(self): - return torch.max(torch.le(self.probs, self._lims[0]), - torch.gt(self.probs, self._lims[1])) + return torch.max( + torch.le(self.probs, self._lims[0]), torch.gt(self.probs, self._lims[1]) + ) def _cut_probs(self): - return torch.where(self._outside_unstable_region(), - self.probs, - self._lims[0] * torch.ones_like(self.probs)) + return torch.where( + self._outside_unstable_region(), + self.probs, + self._lims[0] * torch.ones_like(self.probs), + ) def _cont_bern_log_norm(self): - '''computes the log normalizing constant as a function of the 'probs' parameter''' + """computes the log normalizing constant as a function of the 'probs' parameter""" cut_probs = self._cut_probs() - cut_probs_below_half = torch.where(torch.le(cut_probs, 0.5), - cut_probs, - torch.zeros_like(cut_probs)) - cut_probs_above_half = torch.where(torch.ge(cut_probs, 0.5), - cut_probs, - torch.ones_like(cut_probs)) - log_norm = torch.log(torch.abs(torch.log1p(-cut_probs) - torch.log(cut_probs))) - torch.where( + cut_probs_below_half = torch.where( + torch.le(cut_probs, 0.5), cut_probs, torch.zeros_like(cut_probs) + ) + cut_probs_above_half = torch.where( + torch.ge(cut_probs, 0.5), cut_probs, torch.ones_like(cut_probs) + ) + log_norm = torch.log( + torch.abs(torch.log1p(-cut_probs) - torch.log(cut_probs)) + ) - torch.where( torch.le(cut_probs, 0.5), torch.log1p(-2.0 * cut_probs_below_half), - torch.log(2.0 * cut_probs_above_half - 1.0)) + torch.log(2.0 * cut_probs_above_half - 1.0), + ) x = torch.pow(self.probs - 0.5, 2) taylor = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x return torch.where(self._outside_unstable_region(), log_norm, taylor) @@ -110,7 +126,9 @@ class ContinuousBernoulli(ExponentialFamily): @property def mean(self): cut_probs = self._cut_probs() - mus = cut_probs / (2.0 * cut_probs - 1.0) + 1.0 / (torch.log1p(-cut_probs) - torch.log(cut_probs)) + mus = cut_probs / (2.0 * cut_probs - 1.0) + 1.0 / ( + torch.log1p(-cut_probs) - torch.log(cut_probs) + ) x = self.probs - 0.5 taylor = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * torch.pow(x, 2)) * x return torch.where(self._outside_unstable_region(), mus, taylor) @@ -122,10 +140,11 @@ class ContinuousBernoulli(ExponentialFamily): @property def variance(self): cut_probs = self._cut_probs() - vars = cut_probs * (cut_probs - 1.0) / torch.pow(1.0 - 2.0 * cut_probs, 2) + 1.0 / torch.pow( - torch.log1p(-cut_probs) - torch.log(cut_probs), 2) + vars = cut_probs * (cut_probs - 1.0) / torch.pow( + 1.0 - 2.0 * cut_probs, 2 + ) + 1.0 / torch.pow(torch.log1p(-cut_probs) - torch.log(cut_probs), 2) x = torch.pow(self.probs - 0.5, 2) - taylor = 1.0 / 12.0 - (1.0 / 15.0 - 128. / 945.0 * x) * x + taylor = 1.0 / 12.0 - (1.0 / 15.0 - 128.0 / 945.0 * x) * x return torch.where(self._outside_unstable_region(), vars, taylor) @lazy_property @@ -155,44 +174,62 @@ class ContinuousBernoulli(ExponentialFamily): if self._validate_args: self._validate_sample(value) logits, value = broadcast_all(self.logits, value) - return -binary_cross_entropy_with_logits(logits, value, reduction='none') + self._cont_bern_log_norm() + return ( + -binary_cross_entropy_with_logits(logits, value, reduction="none") + + self._cont_bern_log_norm() + ) def cdf(self, value): if self._validate_args: self._validate_sample(value) cut_probs = self._cut_probs() - cdfs = (torch.pow(cut_probs, value) * torch.pow(1.0 - cut_probs, 1.0 - value) - + cut_probs - 1.0) / (2.0 * cut_probs - 1.0) + cdfs = ( + torch.pow(cut_probs, value) * torch.pow(1.0 - cut_probs, 1.0 - value) + + cut_probs + - 1.0 + ) / (2.0 * cut_probs - 1.0) unbounded_cdfs = torch.where(self._outside_unstable_region(), cdfs, value) return torch.where( torch.le(value, 0.0), torch.zeros_like(value), - torch.where(torch.ge(value, 1.0), torch.ones_like(value), unbounded_cdfs)) + torch.where(torch.ge(value, 1.0), torch.ones_like(value), unbounded_cdfs), + ) def icdf(self, value): cut_probs = self._cut_probs() return torch.where( self._outside_unstable_region(), - (torch.log1p(-cut_probs + value * (2.0 * cut_probs - 1.0)) - - torch.log1p(-cut_probs)) / (torch.log(cut_probs) - torch.log1p(-cut_probs)), - value) + ( + torch.log1p(-cut_probs + value * (2.0 * cut_probs - 1.0)) + - torch.log1p(-cut_probs) + ) + / (torch.log(cut_probs) - torch.log1p(-cut_probs)), + value, + ) def entropy(self): log_probs0 = torch.log1p(-self.probs) log_probs1 = torch.log(self.probs) - return self.mean * (log_probs0 - log_probs1) - self._cont_bern_log_norm() - log_probs0 + return ( + self.mean * (log_probs0 - log_probs1) + - self._cont_bern_log_norm() + - log_probs0 + ) @property def _natural_params(self): - return (self.logits, ) + return (self.logits,) def _log_normalizer(self, x): """computes the log normalizing constant as a function of the natural parameter""" - out_unst_reg = torch.max(torch.le(x, self._lims[0] - 0.5), - torch.gt(x, self._lims[1] - 0.5)) - cut_nat_params = torch.where(out_unst_reg, - x, - (self._lims[0] - 0.5) * torch.ones_like(x)) - log_norm = torch.log(torch.abs(torch.exp(cut_nat_params) - 1.0)) - torch.log(torch.abs(cut_nat_params)) + out_unst_reg = torch.max( + torch.le(x, self._lims[0] - 0.5), torch.gt(x, self._lims[1] - 0.5) + ) + cut_nat_params = torch.where( + out_unst_reg, x, (self._lims[0] - 0.5) * torch.ones_like(x) + ) + log_norm = torch.log(torch.abs(torch.exp(cut_nat_params) - 1.0)) - torch.log( + torch.abs(cut_nat_params) + ) taylor = 0.5 * x + torch.pow(x, 2) / 24.0 - torch.pow(x, 4) / 2880.0 return torch.where(out_unst_reg, log_norm, taylor) diff --git a/torch/distributions/dirichlet.py b/torch/distributions/dirichlet.py index 0a38ff50c268..2aa03e53ab63 100644 --- a/torch/distributions/dirichlet.py +++ b/torch/distributions/dirichlet.py @@ -4,7 +4,8 @@ from torch.autograd.function import once_differentiable from torch.distributions import constraints from torch.distributions.exp_family import ExponentialFamily -__all__ = ['Dirichlet'] +__all__ = ["Dirichlet"] + # This helper is exposed for testing. def _Dirichlet_backward(x, concentration, grad_output): @@ -42,13 +43,17 @@ class Dirichlet(ExponentialFamily): concentration (Tensor): concentration parameter of the distribution (often referred to as alpha) """ - arg_constraints = {'concentration': constraints.independent(constraints.positive, 1)} + arg_constraints = { + "concentration": constraints.independent(constraints.positive, 1) + } support = constraints.simplex has_rsample = True def __init__(self, concentration, validate_args=None): if concentration.dim() < 1: - raise ValueError("`concentration` parameter must be at least one-dimensional.") + raise ValueError( + "`concentration` parameter must be at least one-dimensional." + ) self.concentration = concentration batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:] super().__init__(batch_shape, event_shape, validate_args=validate_args) @@ -57,7 +62,9 @@ class Dirichlet(ExponentialFamily): new = self._get_checked_instance(Dirichlet, _instance) batch_shape = torch.Size(batch_shape) new.concentration = self.concentration.expand(batch_shape + self.event_shape) - super(Dirichlet, new).__init__(batch_shape, self.event_shape, validate_args=False) + super(Dirichlet, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) new._validate_args = self._validate_args return new @@ -69,9 +76,11 @@ class Dirichlet(ExponentialFamily): def log_prob(self, value): if self._validate_args: self._validate_sample(value) - return (torch.xlogy(self.concentration - 1.0, value).sum(-1) + - torch.lgamma(self.concentration.sum(-1)) - - torch.lgamma(self.concentration).sum(-1)) + return ( + torch.xlogy(self.concentration - 1.0, value).sum(-1) + + torch.lgamma(self.concentration.sum(-1)) + - torch.lgamma(self.concentration).sum(-1) + ) @property def mean(self): @@ -79,27 +88,36 @@ class Dirichlet(ExponentialFamily): @property def mode(self): - concentrationm1 = (self.concentration - 1).clamp(min=0.) + concentrationm1 = (self.concentration - 1).clamp(min=0.0) mode = concentrationm1 / concentrationm1.sum(-1, True) mask = (self.concentration < 1).all(axis=-1) - mode[mask] = torch.nn.functional.one_hot(mode[mask].argmax(axis=-1), concentrationm1.shape[-1]).to(mode) + mode[mask] = torch.nn.functional.one_hot( + mode[mask].argmax(axis=-1), concentrationm1.shape[-1] + ).to(mode) return mode @property def variance(self): con0 = self.concentration.sum(-1, True) - return self.concentration * (con0 - self.concentration) / (con0.pow(2) * (con0 + 1)) + return ( + self.concentration + * (con0 - self.concentration) + / (con0.pow(2) * (con0 + 1)) + ) def entropy(self): k = self.concentration.size(-1) a0 = self.concentration.sum(-1) - return (torch.lgamma(self.concentration).sum(-1) - torch.lgamma(a0) - - (k - a0) * torch.digamma(a0) - - ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1)) + return ( + torch.lgamma(self.concentration).sum(-1) + - torch.lgamma(a0) + - (k - a0) * torch.digamma(a0) + - ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1) + ) @property def _natural_params(self): - return (self.concentration, ) + return (self.concentration,) def _log_normalizer(self, x): return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1)) diff --git a/torch/distributions/exp_family.py b/torch/distributions/exp_family.py index 8db7075c7fff..e60f6489d5bf 100644 --- a/torch/distributions/exp_family.py +++ b/torch/distributions/exp_family.py @@ -1,7 +1,8 @@ import torch from torch.distributions.distribution import Distribution -__all__ = ['ExponentialFamily'] +__all__ = ["ExponentialFamily"] + class ExponentialFamily(Distribution): r""" diff --git a/torch/distributions/exponential.py b/torch/distributions/exponential.py index f333bfc18b75..5c40b4ad5778 100644 --- a/torch/distributions/exponential.py +++ b/torch/distributions/exponential.py @@ -5,7 +5,8 @@ from torch.distributions import constraints from torch.distributions.exp_family import ExponentialFamily from torch.distributions.utils import broadcast_all -__all__ = ['Exponential'] +__all__ = ["Exponential"] + class Exponential(ExponentialFamily): r""" @@ -21,7 +22,7 @@ class Exponential(ExponentialFamily): Args: rate (float or Tensor): rate = 1 / scale of the distribution """ - arg_constraints = {'rate': constraints.positive} + arg_constraints = {"rate": constraints.positive} support = constraints.nonnegative has_rsample = True _mean_carrier_measure = 0 @@ -43,7 +44,7 @@ class Exponential(ExponentialFamily): return self.rate.pow(-2) def __init__(self, rate, validate_args=None): - self.rate, = broadcast_all(rate) + (self.rate,) = broadcast_all(rate) batch_shape = torch.Size() if isinstance(rate, Number) else self.rate.size() super().__init__(batch_shape, validate_args=validate_args) @@ -77,7 +78,7 @@ class Exponential(ExponentialFamily): @property def _natural_params(self): - return (-self.rate, ) + return (-self.rate,) def _log_normalizer(self, x): return -torch.log(-x) diff --git a/torch/distributions/fishersnedecor.py b/torch/distributions/fishersnedecor.py index 26511ab4b894..d0d80732b233 100644 --- a/torch/distributions/fishersnedecor.py +++ b/torch/distributions/fishersnedecor.py @@ -1,4 +1,5 @@ from numbers import Number + import torch from torch import nan from torch.distributions import constraints @@ -6,7 +7,8 @@ from torch.distributions.distribution import Distribution from torch.distributions.gamma import Gamma from torch.distributions.utils import broadcast_all -__all__ = ['FisherSnedecor'] +__all__ = ["FisherSnedecor"] + class FisherSnedecor(Distribution): r""" @@ -23,7 +25,7 @@ class FisherSnedecor(Distribution): df1 (float or Tensor): degrees of freedom parameter 1 df2 (float or Tensor): degrees of freedom parameter 2 """ - arg_constraints = {'df1': constraints.positive, 'df2': constraints.positive} + arg_constraints = {"df1": constraints.positive, "df2": constraints.positive} support = constraints.positive has_rsample = True @@ -65,7 +67,12 @@ class FisherSnedecor(Distribution): def variance(self): df2 = self.df2.clone(memory_format=torch.contiguous_format) df2[df2 <= 4] = nan - return 2 * df2.pow(2) * (self.df1 + df2 - 2) / (self.df1 * (df2 - 2).pow(2) * (df2 - 4)) + return ( + 2 + * df2.pow(2) + * (self.df1 + df2 - 2) + / (self.df1 * (df2 - 2).pow(2) * (df2 - 4)) + ) def rsample(self, sample_shape=torch.Size(())): shape = self._extended_shape(sample_shape) diff --git a/torch/distributions/gamma.py b/torch/distributions/gamma.py index 2601109dcb4f..11e689f8eaa7 100644 --- a/torch/distributions/gamma.py +++ b/torch/distributions/gamma.py @@ -5,7 +5,8 @@ from torch.distributions import constraints from torch.distributions.exp_family import ExponentialFamily from torch.distributions.utils import broadcast_all -__all__ = ['Gamma'] +__all__ = ["Gamma"] + def _standard_gamma(concentration): return torch._standard_gamma(concentration) @@ -28,7 +29,10 @@ class Gamma(ExponentialFamily): rate (float or Tensor): rate = 1 / scale of the distribution (often referred to as beta) """ - arg_constraints = {'concentration': constraints.positive, 'rate': constraints.positive} + arg_constraints = { + "concentration": constraints.positive, + "rate": constraints.positive, + } support = constraints.nonnegative has_rsample = True _mean_carrier_measure = 0 @@ -64,21 +68,32 @@ class Gamma(ExponentialFamily): def rsample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) - value = _standard_gamma(self.concentration.expand(shape)) / self.rate.expand(shape) - value.detach().clamp_(min=torch.finfo(value.dtype).tiny) # do not record in autograd graph + value = _standard_gamma(self.concentration.expand(shape)) / self.rate.expand( + shape + ) + value.detach().clamp_( + min=torch.finfo(value.dtype).tiny + ) # do not record in autograd graph return value def log_prob(self, value): value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device) if self._validate_args: self._validate_sample(value) - return (torch.xlogy(self.concentration, self.rate) + - torch.xlogy(self.concentration - 1, value) - - self.rate * value - torch.lgamma(self.concentration)) + return ( + torch.xlogy(self.concentration, self.rate) + + torch.xlogy(self.concentration - 1, value) + - self.rate * value + - torch.lgamma(self.concentration) + ) def entropy(self): - return (self.concentration - torch.log(self.rate) + torch.lgamma(self.concentration) + - (1.0 - self.concentration) * torch.digamma(self.concentration)) + return ( + self.concentration + - torch.log(self.rate) + + torch.lgamma(self.concentration) + + (1.0 - self.concentration) * torch.digamma(self.concentration) + ) @property def _natural_params(self): diff --git a/torch/distributions/geometric.py b/torch/distributions/geometric.py index 0cac28f6e9ef..c97f31d65ced 100644 --- a/torch/distributions/geometric.py +++ b/torch/distributions/geometric.py @@ -3,10 +3,16 @@ from numbers import Number import torch from torch.distributions import constraints from torch.distributions.distribution import Distribution -from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property +from torch.distributions.utils import ( + broadcast_all, + lazy_property, + logits_to_probs, + probs_to_logits, +) from torch.nn.functional import binary_cross_entropy_with_logits -__all__ = ['Geometric'] +__all__ = ["Geometric"] + class Geometric(Distribution): r""" @@ -28,17 +34,18 @@ class Geometric(Distribution): probs (Number, Tensor): the probability of sampling `1`. Must be in range (0, 1] logits (Number, Tensor): the log-odds of sampling `1`. """ - arg_constraints = {'probs': constraints.unit_interval, - 'logits': constraints.real} + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.nonnegative_integer def __init__(self, probs=None, logits=None, validate_args=None): if (probs is None) == (logits is None): - raise ValueError("Either `probs` or `logits` must be specified, but not both.") + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) if probs is not None: - self.probs, = broadcast_all(probs) + (self.probs,) = broadcast_all(probs) else: - self.logits, = broadcast_all(logits) + (self.logits,) = broadcast_all(logits) probs_or_logits = probs if probs is not None else logits if isinstance(probs_or_logits, Number): batch_shape = torch.Size() @@ -61,9 +68,9 @@ class Geometric(Distribution): def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(Geometric, _instance) batch_shape = torch.Size(batch_shape) - if 'probs' in self.__dict__: + if "probs" in self.__dict__: new.probs = self.probs.expand(batch_shape) - if 'logits' in self.__dict__: + if "logits" in self.__dict__: new.logits = self.logits.expand(batch_shape) super(Geometric, new).__init__(batch_shape, validate_args=False) new._validate_args = self._validate_args @@ -71,7 +78,7 @@ class Geometric(Distribution): @property def mean(self): - return 1. / self.probs - 1. + return 1.0 / self.probs - 1.0 @property def mode(self): @@ -79,7 +86,7 @@ class Geometric(Distribution): @property def variance(self): - return (1. / self.probs - 1.) / self.probs + return (1.0 / self.probs - 1.0) / self.probs @lazy_property def logits(self): @@ -110,4 +117,7 @@ class Geometric(Distribution): return value * (-probs).log1p() + self.probs.log() def entropy(self): - return binary_cross_entropy_with_logits(self.logits, self.probs, reduction='none') / self.probs + return ( + binary_cross_entropy_with_logits(self.logits, self.probs, reduction="none") + / self.probs + ) diff --git a/torch/distributions/gumbel.py b/torch/distributions/gumbel.py index eeb89b15e6f2..fab393cd8530 100644 --- a/torch/distributions/gumbel.py +++ b/torch/distributions/gumbel.py @@ -1,13 +1,15 @@ -from numbers import Number import math +from numbers import Number + import torch from torch.distributions import constraints -from torch.distributions.uniform import Uniform from torch.distributions.transformed_distribution import TransformedDistribution from torch.distributions.transforms import AffineTransform, ExpTransform +from torch.distributions.uniform import Uniform from torch.distributions.utils import broadcast_all, euler_constant -__all__ = ['Gumbel'] +__all__ = ["Gumbel"] + class Gumbel(TransformedDistribution): r""" @@ -24,7 +26,7 @@ class Gumbel(TransformedDistribution): loc (float or Tensor): Location parameter of the distribution scale (float or Tensor): Scale parameter of the distribution """ - arg_constraints = {'loc': constraints.real, 'scale': constraints.positive} + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.real def __init__(self, loc, scale, validate_args=None): @@ -33,11 +35,17 @@ class Gumbel(TransformedDistribution): if isinstance(loc, Number) and isinstance(scale, Number): base_dist = Uniform(finfo.tiny, 1 - finfo.eps, validate_args=validate_args) else: - base_dist = Uniform(torch.full_like(self.loc, finfo.tiny), - torch.full_like(self.loc, 1 - finfo.eps), - validate_args=validate_args) - transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-torch.ones_like(self.scale)), - ExpTransform().inv, AffineTransform(loc=loc, scale=-self.scale)] + base_dist = Uniform( + torch.full_like(self.loc, finfo.tiny), + torch.full_like(self.loc, 1 - finfo.eps), + validate_args=validate_args, + ) + transforms = [ + ExpTransform().inv, + AffineTransform(loc=0, scale=-torch.ones_like(self.scale)), + ExpTransform().inv, + AffineTransform(loc=loc, scale=-self.scale), + ] super().__init__(base_dist, transforms, validate_args=validate_args) def expand(self, batch_shape, _instance=None): diff --git a/torch/distributions/half_cauchy.py b/torch/distributions/half_cauchy.py index c50107654342..6fc33cdc6736 100644 --- a/torch/distributions/half_cauchy.py +++ b/torch/distributions/half_cauchy.py @@ -3,11 +3,12 @@ import math import torch from torch import inf from torch.distributions import constraints -from torch.distributions.transforms import AbsTransform from torch.distributions.cauchy import Cauchy from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import AbsTransform + +__all__ = ["HalfCauchy"] -__all__ = ['HalfCauchy'] class HalfCauchy(TransformedDistribution): r""" @@ -26,7 +27,7 @@ class HalfCauchy(TransformedDistribution): Args: scale (float or Tensor): scale of the full Cauchy distribution """ - arg_constraints = {'scale': constraints.positive} + arg_constraints = {"scale": constraints.positive} support = constraints.nonnegative has_rsample = True @@ -44,7 +45,12 @@ class HalfCauchy(TransformedDistribution): @property def mean(self): - return torch.full(self._extended_shape(), math.inf, dtype=self.scale.dtype, device=self.scale.device) + return torch.full( + self._extended_shape(), + math.inf, + dtype=self.scale.dtype, + device=self.scale.device, + ) @property def mode(self): @@ -57,8 +63,9 @@ class HalfCauchy(TransformedDistribution): def log_prob(self, value): if self._validate_args: self._validate_sample(value) - value = torch.as_tensor(value, dtype=self.base_dist.scale.dtype, - device=self.base_dist.scale.device) + value = torch.as_tensor( + value, dtype=self.base_dist.scale.dtype, device=self.base_dist.scale.device + ) log_prob = self.base_dist.log_prob(value) + math.log(2) log_prob = torch.where(value >= 0, log_prob, -inf) return log_prob diff --git a/torch/distributions/half_normal.py b/torch/distributions/half_normal.py index 184d6f16c3c3..475bbf6112cc 100644 --- a/torch/distributions/half_normal.py +++ b/torch/distributions/half_normal.py @@ -3,11 +3,12 @@ import math import torch from torch import inf from torch.distributions import constraints -from torch.distributions.transforms import AbsTransform from torch.distributions.normal import Normal from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import AbsTransform + +__all__ = ["HalfNormal"] -__all__ = ['HalfNormal'] class HalfNormal(TransformedDistribution): r""" @@ -26,7 +27,7 @@ class HalfNormal(TransformedDistribution): Args: scale (float or Tensor): scale of the full Normal distribution """ - arg_constraints = {'scale': constraints.positive} + arg_constraints = {"scale": constraints.positive} support = constraints.nonnegative has_rsample = True diff --git a/torch/distributions/independent.py b/torch/distributions/independent.py index 44a01fd62f91..a58e81b7562e 100644 --- a/torch/distributions/independent.py +++ b/torch/distributions/independent.py @@ -1,10 +1,12 @@ +from typing import Dict + import torch from torch.distributions import constraints from torch.distributions.distribution import Distribution from torch.distributions.utils import _sum_rightmost -from typing import Dict -__all__ = ['Independent'] +__all__ = ["Independent"] + class Independent(Distribution): r""" @@ -37,15 +39,20 @@ class Independent(Distribution): """ arg_constraints: Dict[str, constraints.Constraint] = {} - def __init__(self, base_distribution, reinterpreted_batch_ndims, validate_args=None): + def __init__( + self, base_distribution, reinterpreted_batch_ndims, validate_args=None + ): if reinterpreted_batch_ndims > len(base_distribution.batch_shape): - raise ValueError("Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), " - "actual {} vs {}".format(reinterpreted_batch_ndims, - len(base_distribution.batch_shape))) + raise ValueError( + "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), " + "actual {} vs {}".format( + reinterpreted_batch_ndims, len(base_distribution.batch_shape) + ) + ) shape = base_distribution.batch_shape + base_distribution.event_shape event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape) - batch_shape = shape[:len(shape) - event_dim] - event_shape = shape[len(shape) - event_dim:] + batch_shape = shape[: len(shape) - event_dim] + event_shape = shape[len(shape) - event_dim :] self.base_dist = base_distribution self.reinterpreted_batch_ndims = reinterpreted_batch_ndims super().__init__(batch_shape, event_shape, validate_args=validate_args) @@ -53,10 +60,13 @@ class Independent(Distribution): def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(Independent, _instance) batch_shape = torch.Size(batch_shape) - new.base_dist = self.base_dist.expand(batch_shape + - self.event_shape[:self.reinterpreted_batch_ndims]) + new.base_dist = self.base_dist.expand( + batch_shape + self.event_shape[: self.reinterpreted_batch_ndims] + ) new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims - super(Independent, new).__init__(batch_shape, self.event_shape, validate_args=False) + super(Independent, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) new._validate_args = self._validate_args return new @@ -105,8 +115,13 @@ class Independent(Distribution): def enumerate_support(self, expand=True): if self.reinterpreted_batch_ndims > 0: - raise NotImplementedError("Enumeration over cartesian product is not implemented") + raise NotImplementedError( + "Enumeration over cartesian product is not implemented" + ) return self.base_dist.enumerate_support(expand=expand) def __repr__(self): - return self.__class__.__name__ + f'({self.base_dist}, {self.reinterpreted_batch_ndims})' + return ( + self.__class__.__name__ + + f"({self.base_dist}, {self.reinterpreted_batch_ndims})" + ) diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py index 7b2ce5b58ecd..2b9db6ef2558 100644 --- a/torch/distributions/kl.py +++ b/torch/distributions/kl.py @@ -1,7 +1,7 @@ import math import warnings from functools import total_ordering -from typing import Type, Dict, Callable, Tuple +from typing import Callable, Dict, Tuple, Type import torch from torch import inf @@ -14,17 +14,20 @@ from .cauchy import Cauchy from .continuous_bernoulli import ContinuousBernoulli from .dirichlet import Dirichlet from .distribution import Distribution -from .exponential import Exponential from .exp_family import ExponentialFamily +from .exponential import Exponential from .gamma import Gamma from .geometric import Geometric from .gumbel import Gumbel from .half_normal import HalfNormal from .independent import Independent from .laplace import Laplace -from .lowrank_multivariate_normal import (LowRankMultivariateNormal, _batch_lowrank_logdet, - _batch_lowrank_mahalanobis) -from .multivariate_normal import (MultivariateNormal, _batch_mahalanobis) +from .lowrank_multivariate_normal import ( + _batch_lowrank_logdet, + _batch_lowrank_mahalanobis, + LowRankMultivariateNormal, +) +from .multivariate_normal import _batch_mahalanobis, MultivariateNormal from .normal import Normal from .one_hot_categorical import OneHotCategorical from .pareto import Pareto @@ -33,11 +36,16 @@ from .transformed_distribution import TransformedDistribution from .uniform import Uniform from .utils import _sum_rightmost, euler_constant as _euler_gamma -_KL_REGISTRY = {} # Source of truth mapping a few general (type, type) pairs to functions. -_KL_MEMOIZE: Dict[Tuple[Type, Type], Callable] = {} # Memoized version mapping many specific (type, type) pairs to functions. +_KL_REGISTRY = ( + {} +) # Source of truth mapping a few general (type, type) pairs to functions. +_KL_MEMOIZE: Dict[ + Tuple[Type, Type], Callable +] = {} # Memoized version mapping many specific (type, type) pairs to functions. __all__ = ["register_kl", "kl_divergence"] + def register_kl(type_p, type_q): """ Decorator to register a pairwise function with :meth:`kl_divergence`. @@ -65,9 +73,13 @@ def register_kl(type_p, type_q): type_q (type): A subclass of :class:`~torch.distributions.Distribution`. """ if not isinstance(type_p, type) and issubclass(type_p, Distribution): - raise TypeError(f'Expected type_p to be a Distribution subclass but got {type_p}') + raise TypeError( + f"Expected type_p to be a Distribution subclass but got {type_p}" + ) if not isinstance(type_q, type) and issubclass(type_q, Distribution): - raise TypeError(f'Expected type_q to be a Distribution subclass but got {type_q}') + raise TypeError( + f"Expected type_q to be a Distribution subclass but got {type_q}" + ) def decorator(fun): _KL_REGISTRY[type_p, type_q] = fun @@ -79,7 +91,7 @@ def register_kl(type_p, type_q): @total_ordering class _Match: - __slots__ = ['types'] + __slots__ = ["types"] def __init__(self, *types): self.types = types @@ -100,8 +112,11 @@ def _dispatch_kl(type_p, type_q): """ Find the most specific approximate match, assuming single inheritance. """ - matches = [(super_p, super_q) for super_p, super_q in _KL_REGISTRY - if issubclass(type_p, super_p) and issubclass(type_q, super_q)] + matches = [ + (super_p, super_q) + for super_p, super_q in _KL_REGISTRY + if issubclass(type_p, super_p) and issubclass(type_q, super_q) + ] if not matches: return NotImplemented # Check that the left- and right- lexicographic orders agree. @@ -112,9 +127,12 @@ def _dispatch_kl(type_p, type_q): left_fun = _KL_REGISTRY[left_p, left_q] right_fun = _KL_REGISTRY[right_p, right_q] if left_fun is not right_fun: - warnings.warn('Ambiguous kl_divergence({}, {}). Please register_kl({}, {})'.format( - type_p.__name__, type_q.__name__, left_p.__name__, right_q.__name__), - RuntimeWarning) + warnings.warn( + "Ambiguous kl_divergence({}, {}). Please register_kl({}, {})".format( + type_p.__name__, type_q.__name__, left_p.__name__, right_q.__name__ + ), + RuntimeWarning, + ) return left_fun @@ -167,8 +185,9 @@ def kl_divergence(p: Distribution, q: Distribution) -> torch.Tensor: fun = _dispatch_kl(type(p), type(q)) _KL_MEMOIZE[type(p), type(q)] = fun if fun is NotImplemented: - raise NotImplementedError("No KL(p || q) is implemented for p type {} and q type {}" - .format(p.__class__.__name__, q.__class__.__name__)) + raise NotImplementedError( + f"No KL(p || q) is implemented for p type {p.__class__.__name__} and q type {q.__class__.__name__}" + ) return fun(p, q) @@ -181,10 +200,15 @@ def kl_divergence(p: Distribution, q: Distribution) -> torch.Tensor: @register_kl(Bernoulli, Bernoulli) def _kl_bernoulli_bernoulli(p, q): - t1 = p.probs * (torch.nn.functional.softplus(-q.logits) - torch.nn.functional.softplus(-p.logits)) + t1 = p.probs * ( + torch.nn.functional.softplus(-q.logits) + - torch.nn.functional.softplus(-p.logits) + ) t1[q.probs == 0] = inf t1[p.probs == 0] = 0 - t2 = (1 - p.probs) * (torch.nn.functional.softplus(q.logits) - torch.nn.functional.softplus(p.logits)) + t2 = (1 - p.probs) * ( + torch.nn.functional.softplus(q.logits) - torch.nn.functional.softplus(p.logits) + ) t2[q.probs == 1] = inf t2[p.probs == 1] = 0 return t1 + t2 @@ -207,8 +231,12 @@ def _kl_binomial_binomial(p, q): # from https://math.stackexchange.com/questions/2214993/ # kullback-leibler-divergence-for-binomial-distributions-p-and-q if (p.total_count < q.total_count).any(): - raise NotImplementedError('KL between Binomials where q.total_count > p.total_count is not implemented') - kl = p.total_count * (p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p()) + raise NotImplementedError( + "KL between Binomials where q.total_count > p.total_count is not implemented" + ) + kl = p.total_count * ( + p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p() + ) inf_idxs = p.total_count > q.total_count kl[inf_idxs] = _infinite_like(kl[inf_idxs]) return kl @@ -226,7 +254,7 @@ def _kl_categorical_categorical(p, q): def _kl_continuous_bernoulli_continuous_bernoulli(p, q): t1 = p.mean * (p.logits - q.logits) t2 = p._cont_bern_log_norm() + torch.log1p(-p.probs) - t3 = - q._cont_bern_log_norm() - torch.log1p(-q.probs) + t3 = -q._cont_bern_log_norm() - torch.log1p(-q.probs) return t1 + t2 + t3 @@ -252,8 +280,10 @@ def _kl_exponential_exponential(p, q): @register_kl(ExponentialFamily, ExponentialFamily) def _kl_expfamily_expfamily(p, q): if not type(p) == type(q): - raise NotImplementedError("The cross KL-divergence between different exponential families cannot \ - be computed using Bregman divergences") + raise NotImplementedError( + "The cross KL-divergence between different exponential families cannot \ + be computed using Bregman divergences" + ) p_nparams = [np.detach().requires_grad_() for np in p._natural_params] q_nparams = q._natural_params lg_normal = p._log_normalizer(*p_nparams) @@ -309,25 +339,31 @@ def _kl_laplace_laplace(p, q): @register_kl(LowRankMultivariateNormal, LowRankMultivariateNormal) def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q): if p.event_shape != q.event_shape: - raise ValueError("KL-divergence between two Low Rank Multivariate Normals with\ - different event shapes cannot be computed") + raise ValueError( + "KL-divergence between two Low Rank Multivariate Normals with\ + different event shapes cannot be computed" + ) - term1 = (_batch_lowrank_logdet(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, - q._capacitance_tril) - - _batch_lowrank_logdet(p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, - p._capacitance_tril)) - term3 = _batch_lowrank_mahalanobis(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, - q.loc - p.loc, - q._capacitance_tril) + term1 = _batch_lowrank_logdet( + q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril + ) - _batch_lowrank_logdet( + p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril + ) + term3 = _batch_lowrank_mahalanobis( + q._unbroadcasted_cov_factor, + q._unbroadcasted_cov_diag, + q.loc - p.loc, + q._capacitance_tril, + ) # Expands term2 according to # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD) # = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T) - qWt_qDinv = (q._unbroadcasted_cov_factor.mT / - q._unbroadcasted_cov_diag.unsqueeze(-2)) + qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2) A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False) term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1) - term22 = _batch_trace_XXT(p._unbroadcasted_cov_factor * - q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)) + term22 = _batch_trace_XXT( + p._unbroadcasted_cov_factor * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1) + ) term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2)) term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor)) term2 = term21 + term22 - term23 - term24 @@ -337,23 +373,28 @@ def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q): @register_kl(MultivariateNormal, LowRankMultivariateNormal) def _kl_multivariatenormal_lowrankmultivariatenormal(p, q): if p.event_shape != q.event_shape: - raise ValueError("KL-divergence between two (Low Rank) Multivariate Normals with\ - different event shapes cannot be computed") + raise ValueError( + "KL-divergence between two (Low Rank) Multivariate Normals with\ + different event shapes cannot be computed" + ) - term1 = (_batch_lowrank_logdet(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, - q._capacitance_tril) - - 2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)) - term3 = _batch_lowrank_mahalanobis(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, - q.loc - p.loc, - q._capacitance_tril) + term1 = _batch_lowrank_logdet( + q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril + ) - 2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + term3 = _batch_lowrank_mahalanobis( + q._unbroadcasted_cov_factor, + q._unbroadcasted_cov_diag, + q.loc - p.loc, + q._capacitance_tril, + ) # Expands term2 according to # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T # = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T - qWt_qDinv = (q._unbroadcasted_cov_factor.mT / - q._unbroadcasted_cov_diag.unsqueeze(-2)) + qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2) A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False) - term21 = _batch_trace_XXT(p._unbroadcasted_scale_tril * - q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)) + term21 = _batch_trace_XXT( + p._unbroadcasted_scale_tril * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1) + ) term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril)) term2 = term21 - term22 return 0.5 * (term1 + term2 + term3 - p.event_shape[0]) @@ -362,25 +403,36 @@ def _kl_multivariatenormal_lowrankmultivariatenormal(p, q): @register_kl(LowRankMultivariateNormal, MultivariateNormal) def _kl_lowrankmultivariatenormal_multivariatenormal(p, q): if p.event_shape != q.event_shape: - raise ValueError("KL-divergence between two (Low Rank) Multivariate Normals with\ - different event shapes cannot be computed") + raise ValueError( + "KL-divergence between two (Low Rank) Multivariate Normals with\ + different event shapes cannot be computed" + ) - term1 = (2 * q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) - - _batch_lowrank_logdet(p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, - p._capacitance_tril)) + term1 = 2 * q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum( + -1 + ) - _batch_lowrank_logdet( + p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril + ) term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc)) # Expands term2 according to # inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD) - combined_batch_shape = torch._C._infer_size(q._unbroadcasted_scale_tril.shape[:-2], - p._unbroadcasted_cov_factor.shape[:-2]) + combined_batch_shape = torch._C._infer_size( + q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_cov_factor.shape[:-2] + ) n = p.event_shape[0] q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n)) - p_cov_factor = p._unbroadcasted_cov_factor.expand(combined_batch_shape + - (n, p.cov_factor.size(-1))) - p_cov_diag = (torch.diag_embed(p._unbroadcasted_cov_diag.sqrt()) - .expand(combined_batch_shape + (n, n))) - term21 = _batch_trace_XXT(torch.linalg.solve_triangular(q_scale_tril, p_cov_factor, upper=False)) - term22 = _batch_trace_XXT(torch.linalg.solve_triangular(q_scale_tril, p_cov_diag, upper=False)) + p_cov_factor = p._unbroadcasted_cov_factor.expand( + combined_batch_shape + (n, p.cov_factor.size(-1)) + ) + p_cov_diag = torch.diag_embed(p._unbroadcasted_cov_diag.sqrt()).expand( + combined_batch_shape + (n, n) + ) + term21 = _batch_trace_XXT( + torch.linalg.solve_triangular(q_scale_tril, p_cov_factor, upper=False) + ) + term22 = _batch_trace_XXT( + torch.linalg.solve_triangular(q_scale_tril, p_cov_diag, upper=False) + ) term2 = term21 + term22 return 0.5 * (term1 + term2 + term3 - p.event_shape[0]) @@ -389,17 +441,23 @@ def _kl_lowrankmultivariatenormal_multivariatenormal(p, q): def _kl_multivariatenormal_multivariatenormal(p, q): # From https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence if p.event_shape != q.event_shape: - raise ValueError("KL-divergence between two Multivariate Normals with\ - different event shapes cannot be computed") + raise ValueError( + "KL-divergence between two Multivariate Normals with\ + different event shapes cannot be computed" + ) - half_term1 = (q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) - - p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)) - combined_batch_shape = torch._C._infer_size(q._unbroadcasted_scale_tril.shape[:-2], - p._unbroadcasted_scale_tril.shape[:-2]) + half_term1 = q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum( + -1 + ) - p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + combined_batch_shape = torch._C._infer_size( + q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_scale_tril.shape[:-2] + ) n = p.event_shape[0] q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n)) p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n)) - term2 = _batch_trace_XXT(torch.linalg.solve_triangular(q_scale_tril, p_scale_tril, upper=False)) + term2 = _batch_trace_XXT( + torch.linalg.solve_triangular(q_scale_tril, p_scale_tril, upper=False) + ) term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc)) return half_term1 + 0.5 * (term2 + term3 - n) @@ -457,7 +515,12 @@ def _kl_bernoulli_poisson(p, q): @register_kl(Beta, ContinuousBernoulli) def _kl_beta_continuous_bernoulli(p, q): - return -p.entropy() - p.mean * q.logits - torch.log1p(-q.probs) - q._cont_bern_log_norm() + return ( + -p.entropy() + - p.mean * q.logits + - torch.log1p(-q.probs) + - q._cont_bern_log_norm() + ) @register_kl(Beta, Pareto) @@ -467,17 +530,24 @@ def _kl_beta_infinity(p, q): @register_kl(Beta, Exponential) def _kl_beta_exponential(p, q): - return -p.entropy() - q.rate.log() + q.rate * (p.concentration1 / (p.concentration1 + p.concentration0)) + return ( + -p.entropy() + - q.rate.log() + + q.rate * (p.concentration1 / (p.concentration1 + p.concentration0)) + ) @register_kl(Beta, Gamma) def _kl_beta_gamma(p, q): t1 = -p.entropy() t2 = q.concentration.lgamma() - q.concentration * q.rate.log() - t3 = (q.concentration - 1) * (p.concentration1.digamma() - (p.concentration1 + p.concentration0).digamma()) + t3 = (q.concentration - 1) * ( + p.concentration1.digamma() - (p.concentration1 + p.concentration0).digamma() + ) t4 = q.rate * p.concentration1 / (p.concentration1 + p.concentration0) return t1 + t2 - t3 + t4 + # TODO: Add Beta-Laplace KL Divergence @@ -487,7 +557,10 @@ def _kl_beta_normal(p, q): var_normal = q.scale.pow(2) t1 = -p.entropy() t2 = 0.5 * (var_normal * 2 * math.pi).log() - t3 = (E_beta * (1 - E_beta) / (p.concentration1 + p.concentration0 + 1) + E_beta.pow(2)) * 0.5 + t3 = ( + E_beta * (1 - E_beta) / (p.concentration1 + p.concentration0 + 1) + + E_beta.pow(2) + ) * 0.5 t4 = q.loc * E_beta t5 = q.loc.pow(2) * 0.5 return t1 + t2 + (t3 - t4 + t5) / var_normal @@ -499,6 +572,7 @@ def _kl_beta_uniform(p, q): result[(q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)] = inf return result + # Note that the KL between a ContinuousBernoulli and Beta has no closed form @@ -511,6 +585,7 @@ def _kl_continuous_bernoulli_infinity(p, q): def _kl_continuous_bernoulli_exponential(p, q): return -p.entropy() - torch.log(q.rate) + q.rate * p.mean + # Note that the KL between a ContinuousBernoulli and Gamma has no closed form # TODO: Add ContinuousBernoulli-Laplace KL Divergence @@ -518,17 +593,26 @@ def _kl_continuous_bernoulli_exponential(p, q): @register_kl(ContinuousBernoulli, Normal) def _kl_continuous_bernoulli_normal(p, q): t1 = -p.entropy() - t2 = 0.5 * (math.log(2. * math.pi) + torch.square(q.loc / q.scale)) + torch.log(q.scale) - t3 = (p.variance + torch.square(p.mean) - 2. * q.loc * p.mean) / (2.0 * torch.square(q.scale)) + t2 = 0.5 * (math.log(2.0 * math.pi) + torch.square(q.loc / q.scale)) + torch.log( + q.scale + ) + t3 = (p.variance + torch.square(p.mean) - 2.0 * q.loc * p.mean) / ( + 2.0 * torch.square(q.scale) + ) return t1 + t2 + t3 @register_kl(ContinuousBernoulli, Uniform) def _kl_continuous_bernoulli_uniform(p, q): result = -p.entropy() + (q.high - q.low).log() - return torch.where(torch.max(torch.ge(q.low, p.support.lower_bound), - torch.le(q.high, p.support.upper_bound)), - torch.ones_like(result) * inf, result) + return torch.where( + torch.max( + torch.ge(q.low, p.support.lower_bound), + torch.le(q.high, p.support.upper_bound), + ), + torch.ones_like(result) * inf, + result, + ) @register_kl(Exponential, Beta) @@ -543,7 +627,13 @@ def _kl_exponential_infinity(p, q): def _kl_exponential_gamma(p, q): ratio = q.rate / p.rate t1 = -q.concentration * torch.log(ratio) - return t1 + ratio + q.concentration.lgamma() + q.concentration * _euler_gamma - (1 + _euler_gamma) + return ( + t1 + + ratio + + q.concentration.lgamma() + + q.concentration * _euler_gamma + - (1 + _euler_gamma) + ) @register_kl(Exponential, Gumbel) @@ -555,6 +645,7 @@ def _kl_exponential_gumbel(p, q): t3 = scale_rate_prod.reciprocal() return t1 - loc_scale_ratio + t2 + t3 + # TODO: Add Exponential-Laplace KL Divergence @@ -586,11 +677,20 @@ def _kl_gamma_exponential(p, q): def _kl_gamma_gumbel(p, q): beta_scale_prod = p.rate * q.scale loc_scale_ratio = q.loc / q.scale - t1 = (p.concentration - 1) * p.concentration.digamma() - p.concentration.lgamma() - p.concentration + t1 = ( + (p.concentration - 1) * p.concentration.digamma() + - p.concentration.lgamma() + - p.concentration + ) t2 = beta_scale_prod.log() + p.concentration / beta_scale_prod - t3 = torch.exp(loc_scale_ratio) * (1 + beta_scale_prod.reciprocal()).pow(-p.concentration) - loc_scale_ratio + t3 = ( + torch.exp(loc_scale_ratio) + * (1 + beta_scale_prod.reciprocal()).pow(-p.concentration) + - loc_scale_ratio + ) return t1 + t2 + t3 + # TODO: Add Gamma-Laplace KL Divergence @@ -598,11 +698,19 @@ def _kl_gamma_gumbel(p, q): def _kl_gamma_normal(p, q): var_normal = q.scale.pow(2) beta_sqr = p.rate.pow(2) - t1 = 0.5 * torch.log(beta_sqr * var_normal * 2 * math.pi) - p.concentration - p.concentration.lgamma() + t1 = ( + 0.5 * torch.log(beta_sqr * var_normal * 2 * math.pi) + - p.concentration + - p.concentration.lgamma() + ) t2 = 0.5 * (p.concentration.pow(2) + p.concentration) / beta_sqr t3 = q.loc * p.concentration / p.rate t4 = 0.5 * q.loc.pow(2) - return t1 + (p.concentration - 1) * p.concentration.digamma() + (t2 - t3 + t4) / var_normal + return ( + t1 + + (p.concentration - 1) * p.concentration.digamma() + + (t2 - t3 + t4) / var_normal + ) @register_kl(Gumbel, Beta) @@ -614,6 +722,7 @@ def _kl_gamma_normal(p, q): def _kl_gumbel_infinity(p, q): return _infinite_like(p.loc) + # TODO: Add Gumbel-Laplace KL Divergence @@ -674,7 +783,9 @@ def _kl_normal_laplace(p, q): scale_ratio = p.scale / q.scale loc_diff_scale_ratio = loc_diff / p.scale t1 = torch.log(scale_ratio) - t2 = math.sqrt(2 / math.pi) * p.scale * torch.exp(-0.5 * loc_diff_scale_ratio.pow(2)) + t2 = ( + math.sqrt(2 / math.pi) * p.scale * torch.exp(-0.5 * loc_diff_scale_ratio.pow(2)) + ) t3 = loc_diff * torch.erf(math.sqrt(0.5) * loc_diff_scale_ratio) return -t1 + (t2 + t3) / q.scale - (0.5 * (1 + math.log(0.5 * math.pi))) @@ -708,6 +819,7 @@ def _kl_pareto_gamma(p, q): result[p.alpha <= 1] = inf return result + # TODO: Add Pareto-Laplace KL Divergence @@ -734,9 +846,21 @@ def _kl_poisson_infinity(p, q): def _kl_uniform_beta(p, q): common_term = p.high - p.low t1 = torch.log(common_term) - t2 = (q.concentration1 - 1) * (_x_log_x(p.high) - _x_log_x(p.low) - common_term) / common_term - t3 = (q.concentration0 - 1) * (_x_log_x(1 - p.high) - _x_log_x(1 - p.low) + common_term) / common_term - t4 = q.concentration1.lgamma() + q.concentration0.lgamma() - (q.concentration1 + q.concentration0).lgamma() + t2 = ( + (q.concentration1 - 1) + * (_x_log_x(p.high) - _x_log_x(p.low) - common_term) + / common_term + ) + t3 = ( + (q.concentration0 - 1) + * (_x_log_x(1 - p.high) - _x_log_x(1 - p.low) + common_term) + / common_term + ) + t4 = ( + q.concentration1.lgamma() + + q.concentration0.lgamma() + - (q.concentration1 + q.concentration0).lgamma() + ) result = t3 + t4 - t1 - t2 result[(p.high > q.support.upper_bound) | (p.low < q.support.lower_bound)] = inf return result @@ -744,10 +868,20 @@ def _kl_uniform_beta(p, q): @register_kl(Uniform, ContinuousBernoulli) def _kl_uniform_continuous_bernoulli(p, q): - result = -p.entropy() - p.mean * q.logits - torch.log1p(-q.probs) - q._cont_bern_log_norm() - return torch.where(torch.max(torch.ge(p.high, q.support.upper_bound), - torch.le(p.low, q.support.lower_bound)), - torch.ones_like(result) * inf, result) + result = ( + -p.entropy() + - p.mean * q.logits + - torch.log1p(-q.probs) + - q._cont_bern_log_norm() + ) + return torch.where( + torch.max( + torch.ge(p.high, q.support.upper_bound), + torch.le(p.low, q.support.lower_bound), + ), + torch.ones_like(result) * inf, + result, + ) @register_kl(Uniform, Exponential) @@ -762,7 +896,11 @@ def _kl_uniform_gamma(p, q): common_term = p.high - p.low t1 = common_term.log() t2 = q.concentration.lgamma() - q.concentration * q.rate.log() - t3 = (1 - q.concentration) * (_x_log_x(p.high) - _x_log_x(p.low) - common_term) / common_term + t3 = ( + (1 - q.concentration) + * (_x_log_x(p.high) - _x_log_x(p.low) - common_term) + / common_term + ) t4 = q.rate * (p.high + p.low) / 2 result = -t1 + t2 + t3 + t4 result[p.low < q.support.lower_bound] = inf @@ -778,6 +916,7 @@ def _kl_uniform_gumbel(p, q): t2 = common_term * (torch.exp(-high_loc_diff) - torch.exp(-low_loc_diff)) return t1 - t2 + # TODO: Uniform-Laplace KL Divergence @@ -815,12 +954,18 @@ def _kl_cauchy_cauchy(p, q): t2 = (4 * p.scale * q.scale).log() return t1 - t2 + def _add_kl_info(): """Appends a list of implemented KL functions to the doc for kl_divergence.""" - rows = ["KL divergence is currently implemented for the following distribution pairs:"] - for p, q in sorted(_KL_REGISTRY, - key=lambda p_q: (p_q[0].__name__, p_q[1].__name__)): - rows.append(f"* :class:`~torch.distributions.{p.__name__}` and :class:`~torch.distributions.{q.__name__}`") - kl_info = '\n\t'.join(rows) + rows = [ + "KL divergence is currently implemented for the following distribution pairs:" + ] + for p, q in sorted( + _KL_REGISTRY, key=lambda p_q: (p_q[0].__name__, p_q[1].__name__) + ): + rows.append( + f"* :class:`~torch.distributions.{p.__name__}` and :class:`~torch.distributions.{q.__name__}`" + ) + kl_info = "\n\t".join(rows) if kl_divergence.__doc__: kl_divergence.__doc__ += kl_info # type: ignore[operator] diff --git a/torch/distributions/kumaraswamy.py b/torch/distributions/kumaraswamy.py index 249cdf07b14c..d0b1c993dd10 100644 --- a/torch/distributions/kumaraswamy.py +++ b/torch/distributions/kumaraswamy.py @@ -1,12 +1,13 @@ import torch from torch import nan from torch.distributions import constraints -from torch.distributions.uniform import Uniform from torch.distributions.transformed_distribution import TransformedDistribution from torch.distributions.transforms import AffineTransform, PowerTransform +from torch.distributions.uniform import Uniform from torch.distributions.utils import broadcast_all, euler_constant -__all__ = ['Kumaraswamy'] +__all__ = ["Kumaraswamy"] + def _moments(a, b, n): """ @@ -34,19 +35,28 @@ class Kumaraswamy(TransformedDistribution): concentration0 (float or Tensor): 2nd concentration parameter of the distribution (often referred to as beta) """ - arg_constraints = {'concentration1': constraints.positive, 'concentration0': constraints.positive} + arg_constraints = { + "concentration1": constraints.positive, + "concentration0": constraints.positive, + } support = constraints.unit_interval has_rsample = True def __init__(self, concentration1, concentration0, validate_args=None): - self.concentration1, self.concentration0 = broadcast_all(concentration1, concentration0) + self.concentration1, self.concentration0 = broadcast_all( + concentration1, concentration0 + ) finfo = torch.finfo(self.concentration0.dtype) - base_dist = Uniform(torch.full_like(self.concentration0, 0), - torch.full_like(self.concentration0, 1), - validate_args=validate_args) - transforms = [PowerTransform(exponent=self.concentration0.reciprocal()), - AffineTransform(loc=1., scale=-1.), - PowerTransform(exponent=self.concentration1.reciprocal())] + base_dist = Uniform( + torch.full_like(self.concentration0, 0), + torch.full_like(self.concentration0, 1), + validate_args=validate_args, + ) + transforms = [ + PowerTransform(exponent=self.concentration0.reciprocal()), + AffineTransform(loc=1.0, scale=-1.0), + PowerTransform(exponent=self.concentration1.reciprocal()), + ] super().__init__(base_dist, transforms, validate_args=validate_args) def expand(self, batch_shape, _instance=None): @@ -62,17 +72,26 @@ class Kumaraswamy(TransformedDistribution): @property def mode(self): # Evaluate in log-space for numerical stability. - log_mode = self.concentration0.reciprocal() * \ - (-self.concentration0).log1p() - (-self.concentration0 * self.concentration1).log1p() + log_mode = ( + self.concentration0.reciprocal() * (-self.concentration0).log1p() + - (-self.concentration0 * self.concentration1).log1p() + ) log_mode[(self.concentration0 < 1) | (self.concentration1 < 1)] = nan return log_mode.exp() @property def variance(self): - return _moments(self.concentration1, self.concentration0, 2) - torch.pow(self.mean, 2) + return _moments(self.concentration1, self.concentration0, 2) - torch.pow( + self.mean, 2 + ) def entropy(self): - t1 = (1 - self.concentration1.reciprocal()) - t0 = (1 - self.concentration0.reciprocal()) + t1 = 1 - self.concentration1.reciprocal() + t0 = 1 - self.concentration0.reciprocal() H0 = torch.digamma(self.concentration0 + 1) + euler_constant - return t0 + t1 * H0 - torch.log(self.concentration1) - torch.log(self.concentration0) + return ( + t0 + + t1 * H0 + - torch.log(self.concentration1) + - torch.log(self.concentration0) + ) diff --git a/torch/distributions/laplace.py b/torch/distributions/laplace.py index 3dfe968eda35..bd21c6006705 100644 --- a/torch/distributions/laplace.py +++ b/torch/distributions/laplace.py @@ -1,10 +1,12 @@ from numbers import Number + import torch from torch.distributions import constraints from torch.distributions.distribution import Distribution from torch.distributions.utils import broadcast_all -__all__ = ['Laplace'] +__all__ = ["Laplace"] + class Laplace(Distribution): r""" @@ -21,7 +23,7 @@ class Laplace(Distribution): loc (float or Tensor): mean of the distribution scale (float or Tensor): scale of the distribution """ - arg_constraints = {'loc': constraints.real, 'scale': constraints.positive} + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.real has_rsample = True @@ -39,7 +41,7 @@ class Laplace(Distribution): @property def stddev(self): - return (2 ** 0.5) * self.scale + return (2**0.5) * self.scale def __init__(self, loc, scale, validate_args=None): self.loc, self.scale = broadcast_all(loc, scale) @@ -64,7 +66,9 @@ class Laplace(Distribution): if torch._C._get_tracing_state(): # [JIT WORKAROUND] lack of support for .uniform_() u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device) * 2 - 1 - return self.loc - self.scale * u.sign() * torch.log1p(-u.abs().clamp(min=finfo.tiny)) + return self.loc - self.scale * u.sign() * torch.log1p( + -u.abs().clamp(min=finfo.tiny) + ) u = self.loc.new(shape).uniform_(finfo.eps - 1, 1) # TODO: If we ever implement tensor.nextafter, below is what we want ideally. # u = self.loc.new(shape).uniform_(self.loc.nextafter(-.5, 0), .5) @@ -78,7 +82,9 @@ class Laplace(Distribution): def cdf(self, value): if self._validate_args: self._validate_sample(value) - return 0.5 - 0.5 * (value - self.loc).sign() * torch.expm1(-(value - self.loc).abs() / self.scale) + return 0.5 - 0.5 * (value - self.loc).sign() * torch.expm1( + -(value - self.loc).abs() / self.scale + ) def icdf(self, value): term = value - 0.5 diff --git a/torch/distributions/lkj_cholesky.py b/torch/distributions/lkj_cholesky.py index dbc094adc2b8..86656120bc68 100644 --- a/torch/distributions/lkj_cholesky.py +++ b/torch/distributions/lkj_cholesky.py @@ -10,11 +10,12 @@ Original copyright notice: import math import torch -from torch.distributions import constraints, Beta +from torch.distributions import Beta, constraints from torch.distributions.distribution import Distribution from torch.distributions.utils import broadcast_all -__all__ = ['LKJCholesky'] +__all__ = ["LKJCholesky"] + class LKJCholesky(Distribution): r""" @@ -54,19 +55,25 @@ class LKJCholesky(Distribution): Daniel Lewandowski, Dorota Kurowicka, Harry Joe. Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008 """ - arg_constraints = {'concentration': constraints.positive} + arg_constraints = {"concentration": constraints.positive} support = constraints.corr_cholesky - def __init__(self, dim, concentration=1., validate_args=None): + def __init__(self, dim, concentration=1.0, validate_args=None): if dim < 2: - raise ValueError(f'Expected dim to be an integer greater than or equal to 2. Found dim={dim}.') + raise ValueError( + f"Expected dim to be an integer greater than or equal to 2. Found dim={dim}." + ) self.dim = dim - self.concentration, = broadcast_all(concentration) + (self.concentration,) = broadcast_all(concentration) batch_shape = self.concentration.size() event_shape = torch.Size((dim, dim)) # This is used to draw vectorized samples from the beta distribution in Sec. 3.2 of [1]. marginal_conc = self.concentration + 0.5 * (self.dim - 2) - offset = torch.arange(self.dim - 1, dtype=self.concentration.dtype, device=self.concentration.device) + offset = torch.arange( + self.dim - 1, + dtype=self.concentration.dtype, + device=self.concentration.device, + ) offset = torch.cat([offset.new_zeros((1,)), offset]) beta_conc1 = offset + 0.5 beta_conc0 = marginal_conc.unsqueeze(-1) - 0.5 * offset @@ -79,7 +86,9 @@ class LKJCholesky(Distribution): new.dim = self.dim new.concentration = self.concentration.expand(batch_shape) new._beta = self._beta.expand(batch_shape + (self.dim,)) - super(LKJCholesky, new).__init__(batch_shape, self.event_shape, validate_args=False) + super(LKJCholesky, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) new._validate_args = self._validate_args return new @@ -91,12 +100,12 @@ class LKJCholesky(Distribution): # the correlation matrix instead of the correlation matrix itself. As such, # we only need to generate `w`. y = self._beta.sample(sample_shape).unsqueeze(-1) - u_normal = torch.randn(self._extended_shape(sample_shape), - dtype=y.dtype, - device=y.device).tril(-1) + u_normal = torch.randn( + self._extended_shape(sample_shape), dtype=y.dtype, device=y.device + ).tril(-1) u_hypersphere = u_normal / u_normal.norm(dim=-1, keepdim=True) # Replace NaNs in first row - u_hypersphere[..., 0, :].fill_(0.) + u_hypersphere[..., 0, :].fill_(0.0) w = torch.sqrt(y) * u_hypersphere # Fill diagonal elements; clamp for numerical stability eps = torch.finfo(w.dtype).tiny diff --git a/torch/distributions/log_normal.py b/torch/distributions/log_normal.py index 1621b5cc2bd5..23111637d30e 100644 --- a/torch/distributions/log_normal.py +++ b/torch/distributions/log_normal.py @@ -1,9 +1,10 @@ from torch.distributions import constraints -from torch.distributions.transforms import ExpTransform from torch.distributions.normal import Normal from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import ExpTransform + +__all__ = ["LogNormal"] -__all__ = ['LogNormal'] class LogNormal(TransformedDistribution): r""" @@ -24,7 +25,7 @@ class LogNormal(TransformedDistribution): loc (float or Tensor): mean of log of distribution scale (float or Tensor): standard deviation of log of the distribution """ - arg_constraints = {'loc': constraints.real, 'scale': constraints.positive} + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.positive has_rsample = True diff --git a/torch/distributions/logistic_normal.py b/torch/distributions/logistic_normal.py index d424f1b14004..ac072d0d4d6c 100644 --- a/torch/distributions/logistic_normal.py +++ b/torch/distributions/logistic_normal.py @@ -3,7 +3,8 @@ from torch.distributions.normal import Normal from torch.distributions.transformed_distribution import TransformedDistribution from torch.distributions.transforms import StickBreakingTransform -__all__ = ['LogisticNormal'] +__all__ = ["LogisticNormal"] + class LogisticNormal(TransformedDistribution): r""" @@ -28,7 +29,7 @@ class LogisticNormal(TransformedDistribution): tensor([ 0.7653, 0.0341, 0.0579, 0.1427]) """ - arg_constraints = {'loc': constraints.real, 'scale': constraints.positive} + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.simplex has_rsample = True @@ -36,7 +37,9 @@ class LogisticNormal(TransformedDistribution): base_dist = Normal(loc, scale, validate_args=validate_args) if not base_dist.batch_shape: base_dist = base_dist.expand([1]) - super().__init__(base_dist, StickBreakingTransform(), validate_args=validate_args) + super().__init__( + base_dist, StickBreakingTransform(), validate_args=validate_args + ) def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(LogisticNormal, _instance) diff --git a/torch/distributions/lowrank_multivariate_normal.py b/torch/distributions/lowrank_multivariate_normal.py index 7ba920e970bc..a3acaa990966 100644 --- a/torch/distributions/lowrank_multivariate_normal.py +++ b/torch/distributions/lowrank_multivariate_normal.py @@ -6,7 +6,7 @@ from torch.distributions.distribution import Distribution from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv from torch.distributions.utils import _standard_normal, lazy_property -__all__ = ['LowRankMultivariateNormal'] +__all__ = ["LowRankMultivariateNormal"] def _batch_capacitance_tril(W, D): @@ -17,7 +17,7 @@ def _batch_capacitance_tril(W, D): m = W.size(-1) Wt_Dinv = W.mT / D.unsqueeze(-2) K = torch.matmul(Wt_Dinv, W).contiguous() - K.view(-1, m * m)[:, ::m + 1] += 1 # add identity matrix to K + K.view(-1, m * m)[:, :: m + 1] += 1 # add identity matrix to K return torch.linalg.cholesky(K) @@ -28,7 +28,9 @@ def _batch_lowrank_logdet(W, D, capacitance_tril): where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the log determinant. """ - return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum(-1) + return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum( + -1 + ) def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril): @@ -76,9 +78,11 @@ class LowRankMultivariateNormal(Distribution): capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor """ - arg_constraints = {"loc": constraints.real_vector, - "cov_factor": constraints.independent(constraints.real, 2), - "cov_diag": constraints.independent(constraints.positive, 1)} + arg_constraints = { + "loc": constraints.real_vector, + "cov_factor": constraints.independent(constraints.real, 2), + "cov_diag": constraints.independent(constraints.positive, 1), + } support = constraints.real_vector has_rsample = True @@ -87,20 +91,29 @@ class LowRankMultivariateNormal(Distribution): raise ValueError("loc must be at least one-dimensional.") event_shape = loc.shape[-1:] if cov_factor.dim() < 2: - raise ValueError("cov_factor must be at least two-dimensional, " - "with optional leading batch dimensions") + raise ValueError( + "cov_factor must be at least two-dimensional, " + "with optional leading batch dimensions" + ) if cov_factor.shape[-2:-1] != event_shape: - raise ValueError(f"cov_factor must be a batch of matrices with shape {event_shape[0]} x m") + raise ValueError( + f"cov_factor must be a batch of matrices with shape {event_shape[0]} x m" + ) if cov_diag.shape[-1:] != event_shape: - raise ValueError(f"cov_diag must be a batch of vectors with shape {event_shape}") + raise ValueError( + f"cov_diag must be a batch of vectors with shape {event_shape}" + ) loc_ = loc.unsqueeze(-1) cov_diag_ = cov_diag.unsqueeze(-1) try: - loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(loc_, cov_factor, cov_diag_) + loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors( + loc_, cov_factor, cov_diag_ + ) except RuntimeError as e: - raise ValueError("Incompatible batch shapes: loc {}, cov_factor {}, cov_diag {}" - .format(loc.shape, cov_factor.shape, cov_diag.shape)) from e + raise ValueError( + f"Incompatible batch shapes: loc {loc.shape}, cov_factor {cov_factor.shape}, cov_diag {cov_diag.shape}" + ) from e self.loc = loc_[..., 0] self.cov_diag = cov_diag_[..., 0] batch_shape = self.loc.shape[:-1] @@ -120,9 +133,9 @@ class LowRankMultivariateNormal(Distribution): new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag new._capacitance_tril = self._capacitance_tril - super(LowRankMultivariateNormal, new).__init__(batch_shape, - self.event_shape, - validate_args=False) + super(LowRankMultivariateNormal, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) new._validate_args = self._validate_args return new @@ -136,8 +149,9 @@ class LowRankMultivariateNormal(Distribution): @lazy_property def variance(self): - return (self._unbroadcasted_cov_factor.pow(2).sum(-1) - + self._unbroadcasted_cov_diag).expand(self._batch_shape + self._event_shape) + return ( + self._unbroadcasted_cov_factor.pow(2).sum(-1) + self._unbroadcasted_cov_diag + ).expand(self._batch_shape + self._event_shape) @lazy_property def scale_tril(self): @@ -150,55 +164,72 @@ class LowRankMultivariateNormal(Distribution): cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1) Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.mT).contiguous() - K.view(-1, n * n)[:, ::n + 1] += 1 # add identity matrix to K + K.view(-1, n * n)[:, :: n + 1] += 1 # add identity matrix to K scale_tril = cov_diag_sqrt_unsqueeze * torch.linalg.cholesky(K) - return scale_tril.expand(self._batch_shape + self._event_shape + self._event_shape) + return scale_tril.expand( + self._batch_shape + self._event_shape + self._event_shape + ) @lazy_property def covariance_matrix(self): - covariance_matrix = (torch.matmul(self._unbroadcasted_cov_factor, - self._unbroadcasted_cov_factor.mT) - + torch.diag_embed(self._unbroadcasted_cov_diag)) - return covariance_matrix.expand(self._batch_shape + self._event_shape + - self._event_shape) + covariance_matrix = torch.matmul( + self._unbroadcasted_cov_factor, self._unbroadcasted_cov_factor.mT + ) + torch.diag_embed(self._unbroadcasted_cov_diag) + return covariance_matrix.expand( + self._batch_shape + self._event_shape + self._event_shape + ) @lazy_property def precision_matrix(self): # We use "Woodbury matrix identity" to take advantage of low rank form:: # inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D) # where :math:`C` is the capacitance matrix. - Wt_Dinv = (self._unbroadcasted_cov_factor.mT - / self._unbroadcasted_cov_diag.unsqueeze(-2)) + Wt_Dinv = ( + self._unbroadcasted_cov_factor.mT + / self._unbroadcasted_cov_diag.unsqueeze(-2) + ) A = torch.linalg.solve_triangular(self._capacitance_tril, Wt_Dinv, upper=False) - precision_matrix = torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - A.mT @ A - return precision_matrix.expand(self._batch_shape + self._event_shape + - self._event_shape) + precision_matrix = ( + torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - A.mT @ A + ) + return precision_matrix.expand( + self._batch_shape + self._event_shape + self._event_shape + ) def rsample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) W_shape = shape[:-1] + self.cov_factor.shape[-1:] eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device) eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) - return (self.loc + _batch_mv(self._unbroadcasted_cov_factor, eps_W) - + self._unbroadcasted_cov_diag.sqrt() * eps_D) + return ( + self.loc + + _batch_mv(self._unbroadcasted_cov_factor, eps_W) + + self._unbroadcasted_cov_diag.sqrt() * eps_D + ) def log_prob(self, value): if self._validate_args: self._validate_sample(value) diff = value - self.loc - M = _batch_lowrank_mahalanobis(self._unbroadcasted_cov_factor, - self._unbroadcasted_cov_diag, - diff, - self._capacitance_tril) - log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor, - self._unbroadcasted_cov_diag, - self._capacitance_tril) + M = _batch_lowrank_mahalanobis( + self._unbroadcasted_cov_factor, + self._unbroadcasted_cov_diag, + diff, + self._capacitance_tril, + ) + log_det = _batch_lowrank_logdet( + self._unbroadcasted_cov_factor, + self._unbroadcasted_cov_diag, + self._capacitance_tril, + ) return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M) def entropy(self): - log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor, - self._unbroadcasted_cov_diag, - self._capacitance_tril) + log_det = _batch_lowrank_logdet( + self._unbroadcasted_cov_factor, + self._unbroadcasted_cov_diag, + self._capacitance_tril, + ) H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det) if len(self._batch_shape) == 0: return H diff --git a/torch/distributions/mixture_same_family.py b/torch/distributions/mixture_same_family.py index f60ad4b5419c..b5a853797b4e 100644 --- a/torch/distributions/mixture_same_family.py +++ b/torch/distributions/mixture_same_family.py @@ -1,10 +1,11 @@ -import torch -from torch.distributions.distribution import Distribution -from torch.distributions import Categorical -from torch.distributions import constraints from typing import Dict -__all__ = ['MixtureSameFamily'] +import torch +from torch.distributions import Categorical, constraints +from torch.distributions.distribution import Distribution + +__all__ = ["MixtureSameFamily"] + class MixtureSameFamily(Distribution): r""" @@ -51,57 +52,66 @@ class MixtureSameFamily(Distribution): arg_constraints: Dict[str, constraints.Constraint] = {} has_rsample = False - def __init__(self, - mixture_distribution, - component_distribution, - validate_args=None): + def __init__( + self, mixture_distribution, component_distribution, validate_args=None + ): self._mixture_distribution = mixture_distribution self._component_distribution = component_distribution if not isinstance(self._mixture_distribution, Categorical): - raise ValueError(" The Mixture distribution needs to be an " - " instance of torch.distributions.Categorical") + raise ValueError( + " The Mixture distribution needs to be an " + " instance of torch.distributions.Categorical" + ) if not isinstance(self._component_distribution, Distribution): - raise ValueError("The Component distribution need to be an " - "instance of torch.distributions.Distribution") + raise ValueError( + "The Component distribution need to be an " + "instance of torch.distributions.Distribution" + ) # Check that batch size matches mdbs = self._mixture_distribution.batch_shape cdbs = self._component_distribution.batch_shape[:-1] for size1, size2 in zip(reversed(mdbs), reversed(cdbs)): if size1 != 1 and size2 != 1 and size1 != size2: - raise ValueError(f"`mixture_distribution.batch_shape` ({mdbs}) is not " - "compatible with `component_distribution." - f"batch_shape`({cdbs})") + raise ValueError( + f"`mixture_distribution.batch_shape` ({mdbs}) is not " + "compatible with `component_distribution." + f"batch_shape`({cdbs})" + ) # Check that the number of mixture component matches km = self._mixture_distribution.logits.shape[-1] kc = self._component_distribution.batch_shape[-1] if km is not None and kc is not None and km != kc: - raise ValueError(f"`mixture_distribution component` ({km}) does not" - " equal `component_distribution.batch_shape[-1]`" - f" ({kc})") + raise ValueError( + f"`mixture_distribution component` ({km}) does not" + " equal `component_distribution.batch_shape[-1]`" + f" ({kc})" + ) self._num_component = km event_shape = self._component_distribution.event_shape self._event_ndims = len(event_shape) - super().__init__(batch_shape=cdbs, event_shape=event_shape, validate_args=validate_args) + super().__init__( + batch_shape=cdbs, event_shape=event_shape, validate_args=validate_args + ) def expand(self, batch_shape, _instance=None): batch_shape = torch.Size(batch_shape) batch_shape_comp = batch_shape + (self._num_component,) new = self._get_checked_instance(MixtureSameFamily, _instance) - new._component_distribution = \ - self._component_distribution.expand(batch_shape_comp) - new._mixture_distribution = \ - self._mixture_distribution.expand(batch_shape) + new._component_distribution = self._component_distribution.expand( + batch_shape_comp + ) + new._mixture_distribution = self._mixture_distribution.expand(batch_shape) new._num_component = self._num_component new._event_ndims = self._event_ndims event_shape = new._component_distribution.event_shape - super(MixtureSameFamily, new).__init__(batch_shape=batch_shape, - event_shape=event_shape, - validate_args=False) + super(MixtureSameFamily, new).__init__( + batch_shape=batch_shape, event_shape=event_shape, validate_args=False + ) new._validate_args = self._validate_args return new @@ -122,18 +132,21 @@ class MixtureSameFamily(Distribution): @property def mean(self): probs = self._pad_mixture_dimensions(self.mixture_distribution.probs) - return torch.sum(probs * self.component_distribution.mean, - dim=-1 - self._event_ndims) # [B, E] + return torch.sum( + probs * self.component_distribution.mean, dim=-1 - self._event_ndims + ) # [B, E] @property def variance(self): # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X]) probs = self._pad_mixture_dimensions(self.mixture_distribution.probs) - mean_cond_var = torch.sum(probs * self.component_distribution.variance, - dim=-1 - self._event_ndims) - var_cond_mean = torch.sum(probs * (self.component_distribution.mean - - self._pad(self.mean)).pow(2.0), - dim=-1 - self._event_ndims) + mean_cond_var = torch.sum( + probs * self.component_distribution.variance, dim=-1 - self._event_ndims + ) + var_cond_mean = torch.sum( + probs * (self.component_distribution.mean - self._pad(self.mean)).pow(2.0), + dim=-1 - self._event_ndims, + ) return mean_cond_var + var_cond_mean def cdf(self, x): @@ -148,8 +161,9 @@ class MixtureSameFamily(Distribution): self._validate_sample(x) x = self._pad(x) log_prob_x = self.component_distribution.log_prob(x) # [S, B, k] - log_mix_prob = torch.log_softmax(self.mixture_distribution.logits, - dim=-1) # [B, k] + log_mix_prob = torch.log_softmax( + self.mixture_distribution.logits, dim=-1 + ) # [B, k] return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1) # [S, B] def sample(self, sample_shape=torch.Size()): @@ -168,9 +182,11 @@ class MixtureSameFamily(Distribution): # Gather along the k dimension mix_sample_r = mix_sample.reshape( - mix_shape + torch.Size([1] * (len(es) + 1))) + mix_shape + torch.Size([1] * (len(es) + 1)) + ) mix_sample_r = mix_sample_r.repeat( - torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es) + torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es + ) samples = torch.gather(comp_samples, gather_dim, mix_sample_r) return samples.squeeze(gather_dim) @@ -181,13 +197,18 @@ class MixtureSameFamily(Distribution): def _pad_mixture_dimensions(self, x): dist_batch_ndims = self.batch_shape.numel() cat_batch_ndims = self.mixture_distribution.batch_shape.numel() - pad_ndims = 0 if cat_batch_ndims == 1 else \ - dist_batch_ndims - cat_batch_ndims + pad_ndims = 0 if cat_batch_ndims == 1 else dist_batch_ndims - cat_batch_ndims xs = x.shape - x = x.reshape(xs[:-1] + torch.Size(pad_ndims * [1]) + - xs[-1:] + torch.Size(self._event_ndims * [1])) + x = x.reshape( + xs[:-1] + + torch.Size(pad_ndims * [1]) + + xs[-1:] + + torch.Size(self._event_ndims * [1]) + ) return x def __repr__(self): - args_string = f'\n {self.mixture_distribution},\n {self.component_distribution}' - return 'MixtureSameFamily' + '(' + args_string + ')' + args_string = ( + f"\n {self.mixture_distribution},\n {self.component_distribution}" + ) + return "MixtureSameFamily" + "(" + args_string + ")" diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py index 579febb819a5..3f316e823a79 100644 --- a/torch/distributions/multinomial.py +++ b/torch/distributions/multinomial.py @@ -1,12 +1,12 @@ import torch from torch import inf +from torch.distributions import Categorical, constraints from torch.distributions.binomial import Binomial from torch.distributions.distribution import Distribution -from torch.distributions import Categorical -from torch.distributions import constraints from torch.distributions.utils import broadcast_all -__all__ = ['Multinomial'] +__all__ = ["Multinomial"] + class Multinomial(Distribution): r""" @@ -45,8 +45,7 @@ class Multinomial(Distribution): probs (Tensor): event probabilities logits (Tensor): event log probabilities (unnormalized) """ - arg_constraints = {'probs': constraints.simplex, - 'logits': constraints.real_vector} + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} total_count: int @property @@ -59,7 +58,7 @@ class Multinomial(Distribution): def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): if not isinstance(total_count, int): - raise NotImplementedError('inhomogeneous total_count is not supported') + raise NotImplementedError("inhomogeneous total_count is not supported") self.total_count = total_count self._categorical = Categorical(probs=probs, logits=logits) self._binomial = Binomial(total_count=total_count, probs=self.probs) @@ -72,7 +71,9 @@ class Multinomial(Distribution): batch_shape = torch.Size(batch_shape) new.total_count = self.total_count new._categorical = self._categorical.expand(batch_shape) - super(Multinomial, new).__init__(batch_shape, self.event_shape, validate_args=False) + super(Multinomial, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) new._validate_args = self._validate_args return new @@ -97,7 +98,9 @@ class Multinomial(Distribution): def sample(self, sample_shape=torch.Size()): sample_shape = torch.Size(sample_shape) - samples = self._categorical.sample(torch.Size((self.total_count,)) + sample_shape) + samples = self._categorical.sample( + torch.Size((self.total_count,)) + sample_shape + ) # samples.shape is (total_count, sample_shape, batch_shape), need to change it to # (sample_shape, batch_shape, total_count) shifted_idx = list(range(samples.dim())) diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py index 5354c848a957..2784eeb214d5 100644 --- a/torch/distributions/multivariate_normal.py +++ b/torch/distributions/multivariate_normal.py @@ -5,7 +5,7 @@ from torch.distributions import constraints from torch.distributions.distribution import Distribution from torch.distributions.utils import _standard_normal, lazy_property -__all__ = ['MultivariateNormal'] +__all__ = ["MultivariateNormal"] def _batch_mv(bmat, bvec): @@ -42,21 +42,25 @@ def _batch_mahalanobis(bL, bx): new_batch_dims = outer_batch_dims + 2 * bL_batch_dims # Reshape bx with the shape (..., 1, i, j, 1, n) bx_new_shape = bx.shape[:outer_batch_dims] - for (sL, sx) in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]): + for sL, sx in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]): bx_new_shape += (sx // sL, sL) bx_new_shape += (n,) bx = bx.reshape(bx_new_shape) # Permute bx to make it have shape (..., 1, j, i, 1, n) - permute_dims = (list(range(outer_batch_dims)) + - list(range(outer_batch_dims, new_batch_dims, 2)) + - list(range(outer_batch_dims + 1, new_batch_dims, 2)) + - [new_batch_dims]) + permute_dims = ( + list(range(outer_batch_dims)) + + list(range(outer_batch_dims, new_batch_dims, 2)) + + list(range(outer_batch_dims + 1, new_batch_dims, 2)) + + [new_batch_dims] + ) bx = bx.permute(permute_dims) flat_L = bL.reshape(-1, n, n) # shape = b x n x n flat_x = bx.reshape(-1, flat_L.size(0), n) # shape = c x b x n flat_x_swap = flat_x.permute(1, 2, 0) # shape = b x n x c - M_swap = torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2) # shape = b x c + M_swap = ( + torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2) + ) # shape = b x c M = M_swap.t() # shape = c x b # Now we revert the above reshape and permute operators. @@ -113,36 +117,59 @@ class MultivariateNormal(Distribution): :attr:`precision_matrix` is passed instead, it is only used to compute the corresponding lower triangular matrices using a Cholesky decomposition. """ - arg_constraints = {'loc': constraints.real_vector, - 'covariance_matrix': constraints.positive_definite, - 'precision_matrix': constraints.positive_definite, - 'scale_tril': constraints.lower_cholesky} + arg_constraints = { + "loc": constraints.real_vector, + "covariance_matrix": constraints.positive_definite, + "precision_matrix": constraints.positive_definite, + "scale_tril": constraints.lower_cholesky, + } support = constraints.real_vector has_rsample = True - def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None): + def __init__( + self, + loc, + covariance_matrix=None, + precision_matrix=None, + scale_tril=None, + validate_args=None, + ): if loc.dim() < 1: raise ValueError("loc must be at least one-dimensional.") - if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1: - raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.") + if (covariance_matrix is not None) + (scale_tril is not None) + ( + precision_matrix is not None + ) != 1: + raise ValueError( + "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified." + ) if scale_tril is not None: if scale_tril.dim() < 2: - raise ValueError("scale_tril matrix must be at least two-dimensional, " - "with optional leading batch dimensions") + raise ValueError( + "scale_tril matrix must be at least two-dimensional, " + "with optional leading batch dimensions" + ) batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1]) self.scale_tril = scale_tril.expand(batch_shape + (-1, -1)) elif covariance_matrix is not None: if covariance_matrix.dim() < 2: - raise ValueError("covariance_matrix must be at least two-dimensional, " - "with optional leading batch dimensions") - batch_shape = torch.broadcast_shapes(covariance_matrix.shape[:-2], loc.shape[:-1]) + raise ValueError( + "covariance_matrix must be at least two-dimensional, " + "with optional leading batch dimensions" + ) + batch_shape = torch.broadcast_shapes( + covariance_matrix.shape[:-2], loc.shape[:-1] + ) self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1)) else: if precision_matrix.dim() < 2: - raise ValueError("precision_matrix must be at least two-dimensional, " - "with optional leading batch dimensions") - batch_shape = torch.broadcast_shapes(precision_matrix.shape[:-2], loc.shape[:-1]) + raise ValueError( + "precision_matrix must be at least two-dimensional, " + "with optional leading batch dimensions" + ) + batch_shape = torch.broadcast_shapes( + precision_matrix.shape[:-2], loc.shape[:-1] + ) self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1)) self.loc = loc.expand(batch_shape + (-1,)) @@ -163,33 +190,35 @@ class MultivariateNormal(Distribution): cov_shape = batch_shape + self.event_shape + self.event_shape new.loc = self.loc.expand(loc_shape) new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril - if 'covariance_matrix' in self.__dict__: + if "covariance_matrix" in self.__dict__: new.covariance_matrix = self.covariance_matrix.expand(cov_shape) - if 'scale_tril' in self.__dict__: + if "scale_tril" in self.__dict__: new.scale_tril = self.scale_tril.expand(cov_shape) - if 'precision_matrix' in self.__dict__: + if "precision_matrix" in self.__dict__: new.precision_matrix = self.precision_matrix.expand(cov_shape) - super(MultivariateNormal, new).__init__(batch_shape, - self.event_shape, - validate_args=False) + super(MultivariateNormal, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) new._validate_args = self._validate_args return new @lazy_property def scale_tril(self): return self._unbroadcasted_scale_tril.expand( - self._batch_shape + self._event_shape + self._event_shape) + self._batch_shape + self._event_shape + self._event_shape + ) @lazy_property def covariance_matrix(self): - return (torch.matmul(self._unbroadcasted_scale_tril, - self._unbroadcasted_scale_tril.mT) - .expand(self._batch_shape + self._event_shape + self._event_shape)) + return torch.matmul( + self._unbroadcasted_scale_tril, self._unbroadcasted_scale_tril.mT + ).expand(self._batch_shape + self._event_shape + self._event_shape) @lazy_property def precision_matrix(self): return torch.cholesky_inverse(self._unbroadcasted_scale_tril).expand( - self._batch_shape + self._event_shape + self._event_shape) + self._batch_shape + self._event_shape + self._event_shape + ) @property def mean(self): @@ -201,8 +230,11 @@ class MultivariateNormal(Distribution): @property def variance(self): - return self._unbroadcasted_scale_tril.pow(2).sum(-1).expand( - self._batch_shape + self._event_shape) + return ( + self._unbroadcasted_scale_tril.pow(2) + .sum(-1) + .expand(self._batch_shape + self._event_shape) + ) def rsample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) @@ -214,11 +246,15 @@ class MultivariateNormal(Distribution): self._validate_sample(value) diff = value - self.loc M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff) - half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + half_log_det = ( + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + ) return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det def entropy(self): - half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + half_log_det = ( + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + ) H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det if len(self._batch_shape) == 0: return H diff --git a/torch/distributions/negative_binomial.py b/torch/distributions/negative_binomial.py index 1fdbd85488c8..59edee589f9a 100644 --- a/torch/distributions/negative_binomial.py +++ b/torch/distributions/negative_binomial.py @@ -2,9 +2,15 @@ import torch import torch.nn.functional as F from torch.distributions import constraints from torch.distributions.distribution import Distribution -from torch.distributions.utils import broadcast_all, probs_to_logits, lazy_property, logits_to_probs +from torch.distributions.utils import ( + broadcast_all, + lazy_property, + logits_to_probs, + probs_to_logits, +) + +__all__ = ["NegativeBinomial"] -__all__ = ['NegativeBinomial'] class NegativeBinomial(Distribution): r""" @@ -20,19 +26,29 @@ class NegativeBinomial(Distribution): probs (Tensor): Event probabilities of success in the half open interval [0, 1) logits (Tensor): Event log-odds for probabilities of success """ - arg_constraints = {'total_count': constraints.greater_than_eq(0), - 'probs': constraints.half_open_interval(0., 1.), - 'logits': constraints.real} + arg_constraints = { + "total_count": constraints.greater_than_eq(0), + "probs": constraints.half_open_interval(0.0, 1.0), + "logits": constraints.real, + } support = constraints.nonnegative_integer def __init__(self, total_count, probs=None, logits=None, validate_args=None): if (probs is None) == (logits is None): - raise ValueError("Either `probs` or `logits` must be specified, but not both.") + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) if probs is not None: - self.total_count, self.probs, = broadcast_all(total_count, probs) + ( + self.total_count, + self.probs, + ) = broadcast_all(total_count, probs) self.total_count = self.total_count.type_as(self.probs) else: - self.total_count, self.logits, = broadcast_all(total_count, logits) + ( + self.total_count, + self.logits, + ) = broadcast_all(total_count, logits) self.total_count = self.total_count.type_as(self.logits) self._param = self.probs if probs is not None else self.logits @@ -43,10 +59,10 @@ class NegativeBinomial(Distribution): new = self._get_checked_instance(NegativeBinomial, _instance) batch_shape = torch.Size(batch_shape) new.total_count = self.total_count.expand(batch_shape) - if 'probs' in self.__dict__: + if "probs" in self.__dict__: new.probs = self.probs.expand(batch_shape) new._param = new.probs - if 'logits' in self.__dict__: + if "logits" in self.__dict__: new.logits = self.logits.expand(batch_shape) new._param = new.logits super(NegativeBinomial, new).__init__(batch_shape, validate_args=False) @@ -62,7 +78,7 @@ class NegativeBinomial(Distribution): @property def mode(self): - return ((self.total_count - 1) * self.logits.exp()).floor().clamp(min=0.) + return ((self.total_count - 1) * self.logits.exp()).floor().clamp(min=0.0) @property def variance(self): @@ -83,9 +99,11 @@ class NegativeBinomial(Distribution): @lazy_property def _gamma(self): # Note we avoid validating because self.total_count can be zero. - return torch.distributions.Gamma(concentration=self.total_count, - rate=torch.exp(-self.logits), - validate_args=False) + return torch.distributions.Gamma( + concentration=self.total_count, + rate=torch.exp(-self.logits), + validate_args=False, + ) def sample(self, sample_shape=torch.Size()): with torch.no_grad(): @@ -96,14 +114,20 @@ class NegativeBinomial(Distribution): if self._validate_args: self._validate_sample(value) - log_unnormalized_prob = (self.total_count * F.logsigmoid(-self.logits) + - value * F.logsigmoid(self.logits)) + log_unnormalized_prob = self.total_count * F.logsigmoid( + -self.logits + ) + value * F.logsigmoid(self.logits) - log_normalization = (-torch.lgamma(self.total_count + value) + torch.lgamma(1. + value) + - torch.lgamma(self.total_count)) + log_normalization = ( + -torch.lgamma(self.total_count + value) + + torch.lgamma(1.0 + value) + + torch.lgamma(self.total_count) + ) # The case self.total_count == 0 and value == 0 has probability 1 but # lgamma(0) is infinite. Handle this case separately using a function # that does not modify tensors in place to allow Jit compilation. - log_normalization = log_normalization.masked_fill(self.total_count + value == 0., 0.) + log_normalization = log_normalization.masked_fill( + self.total_count + value == 0.0, 0.0 + ) return log_unnormalized_prob - log_normalization diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py index 39e41d729eeb..8234dd260759 100644 --- a/torch/distributions/normal.py +++ b/torch/distributions/normal.py @@ -1,13 +1,13 @@ import math -from numbers import Real -from numbers import Number +from numbers import Number, Real import torch from torch.distributions import constraints from torch.distributions.exp_family import ExponentialFamily from torch.distributions.utils import _standard_normal, broadcast_all -__all__ = ['Normal'] +__all__ = ["Normal"] + class Normal(ExponentialFamily): r""" @@ -26,7 +26,7 @@ class Normal(ExponentialFamily): scale (float or Tensor): standard deviation of the distribution (often referred to as sigma) """ - arg_constraints = {'loc': constraints.real, 'scale': constraints.positive} + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.real has_rsample = True _mean_carrier_measure = 0 @@ -78,14 +78,22 @@ class Normal(ExponentialFamily): if self._validate_args: self._validate_sample(value) # compute the variance - var = (self.scale ** 2) - log_scale = math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log() - return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi)) + var = self.scale**2 + log_scale = ( + math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log() + ) + return ( + -((value - self.loc) ** 2) / (2 * var) + - log_scale + - math.log(math.sqrt(2 * math.pi)) + ) def cdf(self, value): if self._validate_args: self._validate_sample(value) - return 0.5 * (1 + torch.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2))) + return 0.5 * ( + 1 + torch.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2)) + ) def icdf(self, value): return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2) diff --git a/torch/distributions/one_hot_categorical.py b/torch/distributions/one_hot_categorical.py index 128010c4ce45..46583bc8b853 100644 --- a/torch/distributions/one_hot_categorical.py +++ b/torch/distributions/one_hot_categorical.py @@ -3,7 +3,8 @@ from torch.distributions import constraints from torch.distributions.categorical import Categorical from torch.distributions.distribution import Distribution -__all__ = ['OneHotCategorical', 'OneHotCategoricalStraightThrough'] +__all__ = ["OneHotCategorical", "OneHotCategoricalStraightThrough"] + class OneHotCategorical(Distribution): r""" @@ -34,8 +35,7 @@ class OneHotCategorical(Distribution): probs (Tensor): event probabilities logits (Tensor): event log probabilities (unnormalized) """ - arg_constraints = {'probs': constraints.simplex, - 'logits': constraints.real_vector} + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} support = constraints.one_hot has_enumerate_support = True @@ -49,7 +49,9 @@ class OneHotCategorical(Distribution): new = self._get_checked_instance(OneHotCategorical, _instance) batch_shape = torch.Size(batch_shape) new._categorical = self._categorical.expand(batch_shape) - super(OneHotCategorical, new).__init__(batch_shape, self.event_shape, validate_args=False) + super(OneHotCategorical, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) new._validate_args = self._validate_args return new @@ -110,6 +112,7 @@ class OneHotCategorical(Distribution): values = values.expand((n,) + self.batch_shape + (n,)) return values + class OneHotCategoricalStraightThrough(OneHotCategorical): r""" Creates a reparameterizable :class:`OneHotCategorical` distribution based on the straight- diff --git a/torch/distributions/pareto.py b/torch/distributions/pareto.py index f57ccd559c63..91672ebfc221 100644 --- a/torch/distributions/pareto.py +++ b/torch/distributions/pareto.py @@ -4,7 +4,8 @@ from torch.distributions.transformed_distribution import TransformedDistribution from torch.distributions.transforms import AffineTransform, ExpTransform from torch.distributions.utils import broadcast_all -__all__ = ['Pareto'] +__all__ = ["Pareto"] + class Pareto(TransformedDistribution): r""" @@ -21,7 +22,7 @@ class Pareto(TransformedDistribution): scale (float or Tensor): Scale parameter of the distribution alpha (float or Tensor): Shape parameter of the distribution """ - arg_constraints = {'alpha': constraints.positive, 'scale': constraints.positive} + arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive} def __init__(self, scale, alpha, validate_args=None): self.scale, self.alpha = broadcast_all(scale, alpha) @@ -56,4 +57,4 @@ class Pareto(TransformedDistribution): return constraints.greater_than_eq(self.scale) def entropy(self): - return ((self.scale / self.alpha).log() + (1 + self.alpha.reciprocal())) + return (self.scale / self.alpha).log() + (1 + self.alpha.reciprocal()) diff --git a/torch/distributions/poisson.py b/torch/distributions/poisson.py index bad1d0548705..81c0898a577b 100644 --- a/torch/distributions/poisson.py +++ b/torch/distributions/poisson.py @@ -5,7 +5,8 @@ from torch.distributions import constraints from torch.distributions.exp_family import ExponentialFamily from torch.distributions.utils import broadcast_all -__all__ = ['Poisson'] +__all__ = ["Poisson"] + class Poisson(ExponentialFamily): r""" @@ -26,7 +27,7 @@ class Poisson(ExponentialFamily): Args: rate (Number, Tensor): the rate parameter """ - arg_constraints = {'rate': constraints.nonnegative} + arg_constraints = {"rate": constraints.nonnegative} support = constraints.nonnegative_integer @property @@ -42,7 +43,7 @@ class Poisson(ExponentialFamily): return self.rate def __init__(self, rate, validate_args=None): - self.rate, = broadcast_all(rate) + (self.rate,) = broadcast_all(rate) if isinstance(rate, Number): batch_shape = torch.Size() else: @@ -70,7 +71,7 @@ class Poisson(ExponentialFamily): @property def _natural_params(self): - return (torch.log(self.rate), ) + return (torch.log(self.rate),) def _log_normalizer(self, x): return torch.exp(x) diff --git a/torch/distributions/relaxed_bernoulli.py b/torch/distributions/relaxed_bernoulli.py index 634c0131ca04..2d86c5c3a636 100644 --- a/torch/distributions/relaxed_bernoulli.py +++ b/torch/distributions/relaxed_bernoulli.py @@ -1,12 +1,20 @@ -import torch from numbers import Number + +import torch from torch.distributions import constraints from torch.distributions.distribution import Distribution from torch.distributions.transformed_distribution import TransformedDistribution from torch.distributions.transforms import SigmoidTransform -from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property, clamp_probs +from torch.distributions.utils import ( + broadcast_all, + clamp_probs, + lazy_property, + logits_to_probs, + probs_to_logits, +) + +__all__ = ["LogitRelaxedBernoulli", "RelaxedBernoulli"] -__all__ = ['LogitRelaxedBernoulli', 'RelaxedBernoulli'] class LogitRelaxedBernoulli(Distribution): r""" @@ -27,20 +35,21 @@ class LogitRelaxedBernoulli(Distribution): [2] Categorical Reparametrization with Gumbel-Softmax (Jang et al, 2017) """ - arg_constraints = {'probs': constraints.unit_interval, - 'logits': constraints.real} + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.real def __init__(self, temperature, probs=None, logits=None, validate_args=None): self.temperature = temperature if (probs is None) == (logits is None): - raise ValueError("Either `probs` or `logits` must be specified, but not both.") + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) if probs is not None: is_scalar = isinstance(probs, Number) - self.probs, = broadcast_all(probs) + (self.probs,) = broadcast_all(probs) else: is_scalar = isinstance(logits, Number) - self.logits, = broadcast_all(logits) + (self.logits,) = broadcast_all(logits) self._param = self.probs if probs is not None else self.logits if is_scalar: batch_shape = torch.Size() @@ -52,10 +61,10 @@ class LogitRelaxedBernoulli(Distribution): new = self._get_checked_instance(LogitRelaxedBernoulli, _instance) batch_shape = torch.Size(batch_shape) new.temperature = self.temperature - if 'probs' in self.__dict__: + if "probs" in self.__dict__: new.probs = self.probs.expand(batch_shape) new._param = new.probs - if 'logits' in self.__dict__: + if "logits" in self.__dict__: new.logits = self.logits.expand(batch_shape) new._param = new.logits super(LogitRelaxedBernoulli, new).__init__(batch_shape, validate_args=False) @@ -80,8 +89,12 @@ class LogitRelaxedBernoulli(Distribution): def rsample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) probs = clamp_probs(self.probs.expand(shape)) - uniforms = clamp_probs(torch.rand(shape, dtype=probs.dtype, device=probs.device)) - return (uniforms.log() - (-uniforms).log1p() + probs.log() - (-probs).log1p()) / self.temperature + uniforms = clamp_probs( + torch.rand(shape, dtype=probs.dtype, device=probs.device) + ) + return ( + uniforms.log() - (-uniforms).log1p() + probs.log() - (-probs).log1p() + ) / self.temperature def log_prob(self, value): if self._validate_args: @@ -111,8 +124,7 @@ class RelaxedBernoulli(TransformedDistribution): probs (Number, Tensor): the probability of sampling `1` logits (Number, Tensor): the log-odds of sampling `1` """ - arg_constraints = {'probs': constraints.unit_interval, - 'logits': constraints.real} + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.unit_interval has_rsample = True diff --git a/torch/distributions/relaxed_categorical.py b/torch/distributions/relaxed_categorical.py index 859078284b33..6f0f87afada6 100644 --- a/torch/distributions/relaxed_categorical.py +++ b/torch/distributions/relaxed_categorical.py @@ -1,12 +1,13 @@ import torch from torch.distributions import constraints from torch.distributions.categorical import Categorical -from torch.distributions.utils import clamp_probs, broadcast_all from torch.distributions.distribution import Distribution from torch.distributions.transformed_distribution import TransformedDistribution from torch.distributions.transforms import ExpTransform +from torch.distributions.utils import broadcast_all, clamp_probs + +__all__ = ["ExpRelaxedCategorical", "RelaxedOneHotCategorical"] -__all__ = ['ExpRelaxedCategorical', 'RelaxedOneHotCategorical'] class ExpRelaxedCategorical(Distribution): r""" @@ -30,9 +31,10 @@ class ExpRelaxedCategorical(Distribution): [2] Categorical Reparametrization with Gumbel-Softmax (Jang et al, 2017) """ - arg_constraints = {'probs': constraints.simplex, - 'logits': constraints.real_vector} - support = constraints.real_vector # The true support is actually a submanifold of this. + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} + support = ( + constraints.real_vector + ) # The true support is actually a submanifold of this. has_rsample = True def __init__(self, temperature, probs=None, logits=None, validate_args=None): @@ -47,7 +49,9 @@ class ExpRelaxedCategorical(Distribution): batch_shape = torch.Size(batch_shape) new.temperature = self.temperature new._categorical = self._categorical.expand(batch_shape) - super(ExpRelaxedCategorical, new).__init__(batch_shape, self.event_shape, validate_args=False) + super(ExpRelaxedCategorical, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) new._validate_args = self._validate_args return new @@ -68,7 +72,9 @@ class ExpRelaxedCategorical(Distribution): def rsample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) - uniforms = clamp_probs(torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)) + uniforms = clamp_probs( + torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device) + ) gumbels = -((-(uniforms.log())).log()) scores = (self.logits + gumbels) / self.temperature return scores - scores.logsumexp(dim=-1, keepdim=True) @@ -78,8 +84,9 @@ class ExpRelaxedCategorical(Distribution): if self._validate_args: self._validate_sample(value) logits, value = broadcast_all(self.logits, value) - log_scale = (torch.full_like(self.temperature, float(K)).lgamma() - - self.temperature.log().mul(-(K - 1))) + log_scale = torch.full_like( + self.temperature, float(K) + ).lgamma() - self.temperature.log().mul(-(K - 1)) score = logits - value.mul(self.temperature) score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1) return score + log_scale @@ -105,13 +112,14 @@ class RelaxedOneHotCategorical(TransformedDistribution): probs (Tensor): event probabilities logits (Tensor): unnormalized log probability for each event """ - arg_constraints = {'probs': constraints.simplex, - 'logits': constraints.real_vector} + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} support = constraints.simplex has_rsample = True def __init__(self, temperature, probs=None, logits=None, validate_args=None): - base_dist = ExpRelaxedCategorical(temperature, probs, logits, validate_args=validate_args) + base_dist = ExpRelaxedCategorical( + temperature, probs, logits, validate_args=validate_args + ) super().__init__(base_dist, ExpTransform(), validate_args=validate_args) def expand(self, batch_shape, _instance=None): diff --git a/torch/distributions/studentT.py b/torch/distributions/studentT.py index 83b06c668a2f..24a146d79f2c 100644 --- a/torch/distributions/studentT.py +++ b/torch/distributions/studentT.py @@ -6,7 +6,8 @@ from torch.distributions import Chi2, constraints from torch.distributions.distribution import Distribution from torch.distributions.utils import _standard_normal, broadcast_all -__all__ = ['StudentT'] +__all__ = ["StudentT"] + class StudentT(Distribution): r""" @@ -25,7 +26,11 @@ class StudentT(Distribution): loc (float or Tensor): mean of the distribution scale (float or Tensor): scale of the distribution """ - arg_constraints = {'df': constraints.positive, 'loc': constraints.real, 'scale': constraints.positive} + arg_constraints = { + "df": constraints.positive, + "loc": constraints.real, + "scale": constraints.positive, + } support = constraints.real has_rsample = True @@ -42,12 +47,16 @@ class StudentT(Distribution): @property def variance(self): m = self.df.clone(memory_format=torch.contiguous_format) - m[self.df > 2] = self.scale[self.df > 2].pow(2) * self.df[self.df > 2] / (self.df[self.df > 2] - 2) + m[self.df > 2] = ( + self.scale[self.df > 2].pow(2) + * self.df[self.df > 2] + / (self.df[self.df > 2] - 2) + ) m[(self.df <= 2) & (self.df > 1)] = inf m[self.df <= 1] = nan return m - def __init__(self, df, loc=0., scale=1., validate_args=None): + def __init__(self, df, loc=0.0, scale=1.0, validate_args=None): self.df, self.loc, self.scale = broadcast_all(df, loc, scale) self._chi2 = Chi2(self.df) batch_shape = self.df.size() @@ -82,16 +91,26 @@ class StudentT(Distribution): if self._validate_args: self._validate_sample(value) y = (value - self.loc) / self.scale - Z = (self.scale.log() + - 0.5 * self.df.log() + - 0.5 * math.log(math.pi) + - torch.lgamma(0.5 * self.df) - - torch.lgamma(0.5 * (self.df + 1.))) - return -0.5 * (self.df + 1.) * torch.log1p(y**2. / self.df) - Z + Z = ( + self.scale.log() + + 0.5 * self.df.log() + + 0.5 * math.log(math.pi) + + torch.lgamma(0.5 * self.df) + - torch.lgamma(0.5 * (self.df + 1.0)) + ) + return -0.5 * (self.df + 1.0) * torch.log1p(y**2.0 / self.df) - Z def entropy(self): - lbeta = torch.lgamma(0.5 * self.df) + math.lgamma(0.5) - torch.lgamma(0.5 * (self.df + 1)) - return (self.scale.log() + - 0.5 * (self.df + 1) * - (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df)) + - 0.5 * self.df.log() + lbeta) + lbeta = ( + torch.lgamma(0.5 * self.df) + + math.lgamma(0.5) + - torch.lgamma(0.5 * (self.df + 1)) + ) + return ( + self.scale.log() + + 0.5 + * (self.df + 1) + * (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df)) + + 0.5 * self.df.log() + + lbeta + ) diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py index cd7b5f088a99..060909f38ad0 100644 --- a/torch/distributions/transformed_distribution.py +++ b/torch/distributions/transformed_distribution.py @@ -1,12 +1,14 @@ +from typing import Dict + import torch from torch.distributions import constraints from torch.distributions.distribution import Distribution from torch.distributions.independent import Independent from torch.distributions.transforms import ComposeTransform, Transform from torch.distributions.utils import _sum_rightmost -from typing import Dict -__all__ = ['TransformedDistribution'] +__all__ = ["TransformedDistribution"] + class TransformedDistribution(Distribution): r""" @@ -45,36 +47,51 @@ class TransformedDistribution(Distribution): def __init__(self, base_distribution, transforms, validate_args=None): if isinstance(transforms, Transform): - self.transforms = [transforms, ] + self.transforms = [ + transforms, + ] elif isinstance(transforms, list): if not all(isinstance(t, Transform) for t in transforms): - raise ValueError("transforms must be a Transform or a list of Transforms") + raise ValueError( + "transforms must be a Transform or a list of Transforms" + ) self.transforms = transforms else: - raise ValueError(f"transforms must be a Transform or list, but was {transforms}") + raise ValueError( + f"transforms must be a Transform or list, but was {transforms}" + ) # Reshape base_distribution according to transforms. base_shape = base_distribution.batch_shape + base_distribution.event_shape base_event_dim = len(base_distribution.event_shape) transform = ComposeTransform(self.transforms) if len(base_shape) < transform.domain.event_dim: - raise ValueError("base_distribution needs to have shape with size at least {}, but got {}." - .format(transform.domain.event_dim, base_shape)) + raise ValueError( + "base_distribution needs to have shape with size at least {}, but got {}.".format( + transform.domain.event_dim, base_shape + ) + ) forward_shape = transform.forward_shape(base_shape) expanded_base_shape = transform.inverse_shape(forward_shape) if base_shape != expanded_base_shape: - base_batch_shape = expanded_base_shape[:len(expanded_base_shape) - base_event_dim] + base_batch_shape = expanded_base_shape[ + : len(expanded_base_shape) - base_event_dim + ] base_distribution = base_distribution.expand(base_batch_shape) reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim if reinterpreted_batch_ndims > 0: - base_distribution = Independent(base_distribution, reinterpreted_batch_ndims) + base_distribution = Independent( + base_distribution, reinterpreted_batch_ndims + ) self.base_dist = base_distribution # Compute shapes. - transform_change_in_event_dim = transform.codomain.event_dim - transform.domain.event_dim + transform_change_in_event_dim = ( + transform.codomain.event_dim - transform.domain.event_dim + ) event_dim = max( transform.codomain.event_dim, # the transform is coupled - base_event_dim + transform_change_in_event_dim # the base dist is coupled + base_event_dim + transform_change_in_event_dim, # the base dist is coupled ) assert len(forward_shape) >= event_dim cut = len(forward_shape) - event_dim @@ -88,10 +105,12 @@ class TransformedDistribution(Distribution): shape = batch_shape + self.event_shape for t in reversed(self.transforms): shape = t.inverse_shape(shape) - base_batch_shape = shape[:len(shape) - len(self.base_dist.event_shape)] + base_batch_shape = shape[: len(shape) - len(self.base_dist.event_shape)] new.base_dist = self.base_dist.expand(base_batch_shape) new.transforms = self.transforms - super(TransformedDistribution, new).__init__(batch_shape, self.event_shape, validate_args=False) + super(TransformedDistribution, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) new._validate_args = self._validate_args return new @@ -101,7 +120,9 @@ class TransformedDistribution(Distribution): return self.base_dist.support support = self.transforms[-1].codomain if len(self.event_shape) > support.event_dim: - support = constraints.independent(support, len(self.event_shape) - support.event_dim) + support = constraints.independent( + support, len(self.event_shape) - support.event_dim + ) return support @property @@ -146,12 +167,15 @@ class TransformedDistribution(Distribution): for transform in reversed(self.transforms): x = transform.inv(y) event_dim += transform.domain.event_dim - transform.codomain.event_dim - log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y), - event_dim - transform.domain.event_dim) + log_prob = log_prob - _sum_rightmost( + transform.log_abs_det_jacobian(x, y), + event_dim - transform.domain.event_dim, + ) y = x - log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y), - event_dim - len(self.base_dist.event_shape)) + log_prob = log_prob + _sum_rightmost( + self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape) + ) return log_prob def _monotonize_cdf(self, value): diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py index 6745d1f6fbd5..8463609856c3 100644 --- a/torch/distributions/transforms.py +++ b/torch/distributions/transforms.py @@ -8,33 +8,36 @@ from typing import List import torch import torch.nn.functional as F from torch.distributions import constraints -from torch.distributions.utils import (_sum_rightmost, broadcast_all, - lazy_property, tril_matrix_to_vec, - vec_to_tril_matrix) -from torch.nn.functional import pad -from torch.nn.functional import softplus +from torch.distributions.utils import ( + _sum_rightmost, + broadcast_all, + lazy_property, + tril_matrix_to_vec, + vec_to_tril_matrix, +) +from torch.nn.functional import pad, softplus __all__ = [ - 'AbsTransform', - 'AffineTransform', - 'CatTransform', - 'ComposeTransform', - 'CorrCholeskyTransform', - 'CumulativeDistributionTransform', - 'ExpTransform', - 'IndependentTransform', - 'LowerCholeskyTransform', - 'PositiveDefiniteTransform', - 'PowerTransform', - 'ReshapeTransform', - 'SigmoidTransform', - 'SoftplusTransform', - 'TanhTransform', - 'SoftmaxTransform', - 'StackTransform', - 'StickBreakingTransform', - 'Transform', - 'identity_transform', + "AbsTransform", + "AffineTransform", + "CatTransform", + "ComposeTransform", + "CorrCholeskyTransform", + "CumulativeDistributionTransform", + "ExpTransform", + "IndependentTransform", + "LowerCholeskyTransform", + "PositiveDefiniteTransform", + "PowerTransform", + "ReshapeTransform", + "SigmoidTransform", + "SoftplusTransform", + "TanhTransform", + "SoftmaxTransform", + "StackTransform", + "StickBreakingTransform", + "Transform", + "identity_transform", ] @@ -82,6 +85,7 @@ class Transform: should be +1 or -1 depending on whether transform is monotone increasing or decreasing. """ + bijective = False domain: constraints.Constraint codomain: constraints.Constraint @@ -94,7 +98,7 @@ class Transform: elif cache_size == 1: self._cached_x_y = None, None else: - raise ValueError('cache_size must be 0 or 1') + raise ValueError("cache_size must be 0 or 1") super().__init__() def __getstate__(self): @@ -189,7 +193,7 @@ class Transform: raise NotImplementedError def __repr__(self): - return self.__class__.__name__ + '()' + return self.__class__.__name__ + "()" def forward_shape(self, shape): """ @@ -211,6 +215,7 @@ class _InverseTransform(Transform): Inverts a single :class:`Transform`. This class is private; please instead use the ``Transform.inv`` property. """ + def __init__(self, transform: Transform): super().__init__(cache_size=transform._cache_size) self._inv: Transform = transform @@ -277,6 +282,7 @@ class ComposeTransform(Transform): cache_size (int): Size of cache. If zero, no caching is done. If one, the latest single value is cached. Only 0 and 1 are supported. """ + def __init__(self, parts: List[Transform], cache_size=0): if cache_size: parts = [part.with_cache(cache_size) for part in parts] @@ -363,8 +369,11 @@ class ComposeTransform(Transform): terms = [] event_dim = self.domain.event_dim for part, x, y in zip(self.parts, xs[:-1], xs[1:]): - terms.append(_sum_rightmost(part.log_abs_det_jacobian(x, y), - event_dim - part.domain.event_dim)) + terms.append( + _sum_rightmost( + part.log_abs_det_jacobian(x, y), event_dim - part.domain.event_dim + ) + ) event_dim += part.codomain.event_dim - part.domain.event_dim return functools.reduce(operator.add, terms) @@ -379,9 +388,9 @@ class ComposeTransform(Transform): return shape def __repr__(self): - fmt_string = self.__class__.__name__ + '(\n ' - fmt_string += ',\n '.join([p.__repr__() for p in self.parts]) - fmt_string += '\n)' + fmt_string = self.__class__.__name__ + "(\n " + fmt_string += ",\n ".join([p.__repr__() for p in self.parts]) + fmt_string += "\n)" return fmt_string @@ -401,6 +410,7 @@ class IndependentTransform(Transform): reinterpreted_batch_ndims (int): The number of extra rightmost dimensions to treat as dependent. """ + def __init__(self, base_transform, reinterpreted_batch_ndims, cache_size=0): super().__init__(cache_size=cache_size) self.base_transform = base_transform.with_cache(cache_size) @@ -409,19 +419,21 @@ class IndependentTransform(Transform): def with_cache(self, cache_size=1): if self._cache_size == cache_size: return self - return IndependentTransform(self.base_transform, - self.reinterpreted_batch_ndims, - cache_size=cache_size) + return IndependentTransform( + self.base_transform, self.reinterpreted_batch_ndims, cache_size=cache_size + ) @constraints.dependent_property(is_discrete=False) def domain(self): - return constraints.independent(self.base_transform.domain, - self.reinterpreted_batch_ndims) + return constraints.independent( + self.base_transform.domain, self.reinterpreted_batch_ndims + ) @constraints.dependent_property(is_discrete=False) def codomain(self): - return constraints.independent(self.base_transform.codomain, - self.reinterpreted_batch_ndims) + return constraints.independent( + self.base_transform.codomain, self.reinterpreted_batch_ndims + ) @property def bijective(self): @@ -467,6 +479,7 @@ class ReshapeTransform(Transform): in_shape (torch.Size): The input event shape. out_shape (torch.Size): The output event shape. """ + bijective = True def __init__(self, in_shape, out_shape, cache_size=0): @@ -490,15 +503,15 @@ class ReshapeTransform(Transform): return ReshapeTransform(self.in_shape, self.out_shape, cache_size=cache_size) def _call(self, x): - batch_shape = x.shape[:x.dim() - len(self.in_shape)] + batch_shape = x.shape[: x.dim() - len(self.in_shape)] return x.reshape(batch_shape + self.out_shape) def _inverse(self, y): - batch_shape = y.shape[:y.dim() - len(self.out_shape)] + batch_shape = y.shape[: y.dim() - len(self.out_shape)] return y.reshape(batch_shape + self.in_shape) def log_abs_det_jacobian(self, x, y): - batch_shape = x.shape[:x.dim() - len(self.in_shape)] + batch_shape = x.shape[: x.dim() - len(self.in_shape)] return x.new_zeros(batch_shape) def forward_shape(self, shape): @@ -506,7 +519,9 @@ class ReshapeTransform(Transform): raise ValueError("Too few dimensions on input") cut = len(shape) - len(self.in_shape) if shape[cut:] != self.in_shape: - raise ValueError(f"Shape mismatch: expected {shape[cut:]} but got {self.in_shape}") + raise ValueError( + f"Shape mismatch: expected {shape[cut:]} but got {self.in_shape}" + ) return shape[:cut] + self.out_shape def inverse_shape(self, shape): @@ -514,7 +529,9 @@ class ReshapeTransform(Transform): raise ValueError("Too few dimensions on input") cut = len(shape) - len(self.out_shape) if shape[cut:] != self.out_shape: - raise ValueError(f"Shape mismatch: expected {shape[cut:]} but got {self.out_shape}") + raise ValueError( + f"Shape mismatch: expected {shape[cut:]} but got {self.out_shape}" + ) return shape[:cut] + self.in_shape @@ -551,7 +568,7 @@ class PowerTransform(Transform): def __init__(self, exponent, cache_size=0): super().__init__(cache_size=cache_size) - self.exponent, = broadcast_all(exponent) + (self.exponent,) = broadcast_all(exponent) def with_cache(self, cache_size=1): if self._cache_size == cache_size: @@ -581,7 +598,7 @@ class PowerTransform(Transform): def _clipped_sigmoid(x): finfo = torch.finfo(x.dtype) - return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1. - finfo.eps) + return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1.0 - finfo.eps) class SigmoidTransform(Transform): @@ -601,7 +618,7 @@ class SigmoidTransform(Transform): def _inverse(self, y): finfo = torch.finfo(y.dtype) - y = y.clamp(min=finfo.tiny, max=1. - finfo.eps) + y = y.clamp(min=finfo.tiny, max=1.0 - finfo.eps) return y.log() - (-y).log1p() def log_abs_det_jacobian(self, x, y): @@ -664,7 +681,7 @@ class TanhTransform(Transform): def log_abs_det_jacobian(self, x, y): # We use a formula that is more numerically stable, see details in the following link # https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80 - return 2. * (math.log(2.) - x - softplus(-2. * x)) + return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x)) class AbsTransform(Transform): @@ -722,20 +739,26 @@ class AffineTransform(Transform): def with_cache(self, cache_size=1): if self._cache_size == cache_size: return self - return AffineTransform(self.loc, self.scale, self.event_dim, cache_size=cache_size) + return AffineTransform( + self.loc, self.scale, self.event_dim, cache_size=cache_size + ) def __eq__(self, other): if not isinstance(other, AffineTransform): return False - if isinstance(self.loc, numbers.Number) and isinstance(other.loc, numbers.Number): + if isinstance(self.loc, numbers.Number) and isinstance( + other.loc, numbers.Number + ): if self.loc != other.loc: return False else: if not (self.loc == other.loc).all().item(): return False - if isinstance(self.scale, numbers.Number) and isinstance(other.scale, numbers.Number): + if isinstance(self.scale, numbers.Number) and isinstance( + other.scale, numbers.Number + ): if self.scale != other.scale: return False else: @@ -764,20 +787,20 @@ class AffineTransform(Transform): else: result = torch.abs(scale).log() if self.event_dim: - result_size = result.size()[:-self.event_dim] + (-1,) + result_size = result.size()[: -self.event_dim] + (-1,) result = result.view(result_size).sum(-1) - shape = shape[:-self.event_dim] + shape = shape[: -self.event_dim] return result.expand(shape) def forward_shape(self, shape): - return torch.broadcast_shapes(shape, - getattr(self.loc, "shape", ()), - getattr(self.scale, "shape", ())) + return torch.broadcast_shapes( + shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ()) + ) def inverse_shape(self, shape): - return torch.broadcast_shapes(shape, - getattr(self.loc, "shape", ()), - getattr(self.scale, "shape", ())) + return torch.broadcast_shapes( + shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ()) + ) class CorrCholeskyTransform(Transform): @@ -808,7 +831,7 @@ class CorrCholeskyTransform(Transform): # apply stick-breaking on the squared values # Note that y = sign(r) * sqrt(z * z1m_cumprod) # = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod) - z = r ** 2 + z = r**2 z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1) # Diagonal elements must be 1. r = r + torch.eye(r.shape[-1], dtype=r.dtype, device=r.device) @@ -838,7 +861,7 @@ class CorrCholeskyTransform(Transform): # also works for 2 x 2 matrix y1m_cumsum_tril = tril_matrix_to_vec(y1m_cumsum, diag=-2) stick_breaking_logdet = 0.5 * (y1m_cumsum_tril).log().sum(-1) - tanh_logdet = -2 * (x + softplus(-2 * x) - math.log(2.)).sum(dim=-1) + tanh_logdet = -2 * (x + softplus(-2 * x) - math.log(2.0)).sum(dim=-1) return stick_breaking_logdet + tanh_logdet def forward_shape(self, shape): @@ -910,6 +933,7 @@ class StickBreakingTransform(Transform): This is bijective and appropriate for use in HMC; however it mixes coordinates together and is less appropriate for optimization. """ + domain = constraints.real_vector codomain = constraints.simplex bijective = True @@ -960,6 +984,7 @@ class LowerCholeskyTransform(Transform): This is useful for parameterizing positive definite matrices in terms of their Cholesky factorization. """ + domain = constraints.independent(constraints.real, 2) codomain = constraints.lower_cholesky @@ -977,6 +1002,7 @@ class PositiveDefiniteTransform(Transform): """ Transform from unconstrained matrices to positive-definite matrices. """ + domain = constraints.independent(constraints.real, 2) codomain = constraints.positive_definite # type: ignore[assignment] @@ -1006,6 +1032,7 @@ class CatTransform(Transform): t = CatTransform([t0, t0], dim=0, lengths=[20, 20]) y = t(x) """ + transforms: List[Transform] def __init__(self, tseq, dim=0, lengths=None, cache_size=0): @@ -1086,13 +1113,15 @@ class CatTransform(Transform): @constraints.dependent_property def domain(self): - return constraints.cat([t.domain for t in self.transforms], - self.dim, self.lengths) + return constraints.cat( + [t.domain for t in self.transforms], self.dim, self.lengths + ) @constraints.dependent_property def codomain(self): - return constraints.cat([t.codomain for t in self.transforms], - self.dim, self.lengths) + return constraints.cat( + [t.codomain for t in self.transforms], self.dim, self.lengths + ) class StackTransform(Transform): @@ -1107,6 +1136,7 @@ class StackTransform(Transform): t = StackTransform([ExpTransform(), identity_transform], dim=1) y = t(x) """ + transforms: List[Transform] def __init__(self, tseq, dim=0, cache_size=0): @@ -1184,6 +1214,7 @@ class CumulativeDistributionTransform(Transform): transform = CumulativeDistributionTransform(Normal(0, 1)) copula = TransformedDistribution(base_dist, [transform]) """ + bijective = True codomain = constraints.unit_interval sign = +1 diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py index cbbd8d1ed28d..e939bb4aae39 100644 --- a/torch/distributions/uniform.py +++ b/torch/distributions/uniform.py @@ -6,7 +6,8 @@ from torch.distributions import constraints from torch.distributions.distribution import Distribution from torch.distributions.utils import broadcast_all -__all__ = ['Uniform'] +__all__ = ["Uniform"] + class Uniform(Distribution): r""" @@ -25,8 +26,10 @@ class Uniform(Distribution): high (float or Tensor): upper range (exclusive). """ # TODO allow (loc,scale) parameterization to allow independent constraints. - arg_constraints = {'low': constraints.dependent(is_discrete=False, event_dim=0), - 'high': constraints.dependent(is_discrete=False, event_dim=0)} + arg_constraints = { + "low": constraints.dependent(is_discrete=False, event_dim=0), + "high": constraints.dependent(is_discrete=False, event_dim=0), + } has_rsample = True @property diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py index a73d41ef0ae5..7a6d31a05722 100644 --- a/torch/distributions/utils.py +++ b/torch/distributions/utils.py @@ -1,14 +1,23 @@ from functools import update_wrapper from numbers import Number +from typing import Any, Dict + import torch import torch.nn.functional as F -from typing import Dict, Any from torch.overrides import is_tensor_like euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant -__all__ = ["broadcast_all", "logits_to_probs", "clamp_probs", "probs_to_logits", "lazy_property", - "tril_matrix_to_vec", "vec_to_tril_matrix"] +__all__ = [ + "broadcast_all", + "logits_to_probs", + "clamp_probs", + "probs_to_logits", + "lazy_property", + "tril_matrix_to_vec", + "vec_to_tril_matrix", +] + def broadcast_all(*values): r""" @@ -26,18 +35,20 @@ def broadcast_all(*values): ValueError: if any of the values is not a `numbers.Number` instance, a `torch.*Tensor` instance, or an instance implementing __torch_function__ """ - if not all(is_tensor_like(v) or isinstance(v, Number) - for v in values): - raise ValueError('Input arguments must all be instances of numbers.Number, ' - 'torch.Tensor or objects implementing __torch_function__.') + if not all(is_tensor_like(v) or isinstance(v, Number) for v in values): + raise ValueError( + "Input arguments must all be instances of numbers.Number, " + "torch.Tensor or objects implementing __torch_function__." + ) if not all(is_tensor_like(v) for v in values): options: Dict[str, Any] = dict(dtype=torch.get_default_dtype()) for value in values: if isinstance(value, torch.Tensor): options = dict(dtype=value.dtype, device=value.device) break - new_values = [v if is_tensor_like(v) else torch.tensor(v, **options) - for v in values] + new_values = [ + v if is_tensor_like(v) else torch.tensor(v, **options) for v in values + ] return torch.broadcast_tensors(*new_values) return torch.broadcast_tensors(*values) @@ -45,8 +56,10 @@ def broadcast_all(*values): def _standard_normal(shape, dtype, device): if torch._C._get_tracing_state(): # [JIT WORKAROUND] lack of support for .normal_() - return torch.normal(torch.zeros(shape, dtype=dtype, device=device), - torch.ones(shape, dtype=dtype, device=device)) + return torch.normal( + torch.zeros(shape, dtype=dtype, device=device), + torch.ones(shape, dtype=dtype, device=device), + ) return torch.empty(shape, dtype=dtype, device=device).normal_() @@ -101,6 +114,7 @@ class lazy_property: first call; thereafter replacing the wrapped method into an instance attribute. """ + def __init__(self, wrapped): self.wrapped = wrapped update_wrapper(self, wrapped) @@ -120,6 +134,7 @@ class _lazy_property_and_property(lazy_property, property): * property when Sphinx autodoc looks * lazy_property when Distribution validate_args looks """ + def __init__(self, wrapped): property.__init__(self, wrapped) @@ -131,7 +146,7 @@ def tril_matrix_to_vec(mat: torch.Tensor, diag: int = 0) -> torch.Tensor: """ n = mat.shape[-1] if not torch._C._get_tracing_state() and (diag < -n or diag >= n): - raise ValueError(f'diag ({diag}) provided is outside [{-n}, {n-1}].') + raise ValueError(f"diag ({diag}) provided is outside [{-n}, {n-1}].") arange = torch.arange(n, device=mat.device) tril_mask = arange < arange.view(-1, 1) + (diag + 1) vec = mat[..., tril_mask] @@ -144,11 +159,16 @@ def vec_to_tril_matrix(vec: torch.Tensor, diag: int = 0) -> torch.Tensor: lower triangular matrix containing elements from the vector in row order. """ # +ve root of D**2 + (1+2*diag)*D - |diag| * (diag+1) - 2*vec.shape[-1] = 0 - n = (-(1 + 2 * diag) + ((1 + 2 * diag)**2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1))**0.5) / 2 + n = ( + -(1 + 2 * diag) + + ((1 + 2 * diag) ** 2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1)) ** 0.5 + ) / 2 eps = torch.finfo(vec.dtype).eps if not torch._C._get_tracing_state() and (round(n) - n > eps): - raise ValueError(f'The size of last dimension is {vec.shape[-1]} which cannot be expressed as ' + - 'the lower triangular part of a square D x D matrix.') + raise ValueError( + f"The size of last dimension is {vec.shape[-1]} which cannot be expressed as " + + "the lower triangular part of a square D x D matrix." + ) n = round(n.item()) if isinstance(n, torch.Tensor) else round(n) mat = vec.new_zeros(vec.shape[:-1] + torch.Size((n, n))) arange = torch.arange(n, device=vec.device) diff --git a/torch/distributions/von_mises.py b/torch/distributions/von_mises.py index 30457d7de715..f42f400869a5 100644 --- a/torch/distributions/von_mises.py +++ b/torch/distributions/von_mises.py @@ -6,7 +6,7 @@ from torch.distributions import constraints from torch.distributions.distribution import Distribution from torch.distributions.utils import broadcast_all, lazy_property -__all__ = ['VonMises'] +__all__ = ["VonMises"] def _eval_poly(y, coef): @@ -17,12 +17,46 @@ def _eval_poly(y, coef): return result -_I0_COEF_SMALL = [1.0, 3.5156229, 3.0899424, 1.2067492, 0.2659732, 0.360768e-1, 0.45813e-2] -_I0_COEF_LARGE = [0.39894228, 0.1328592e-1, 0.225319e-2, -0.157565e-2, 0.916281e-2, - -0.2057706e-1, 0.2635537e-1, -0.1647633e-1, 0.392377e-2] -_I1_COEF_SMALL = [0.5, 0.87890594, 0.51498869, 0.15084934, 0.2658733e-1, 0.301532e-2, 0.32411e-3] -_I1_COEF_LARGE = [0.39894228, -0.3988024e-1, -0.362018e-2, 0.163801e-2, -0.1031555e-1, - 0.2282967e-1, -0.2895312e-1, 0.1787654e-1, -0.420059e-2] +_I0_COEF_SMALL = [ + 1.0, + 3.5156229, + 3.0899424, + 1.2067492, + 0.2659732, + 0.360768e-1, + 0.45813e-2, +] +_I0_COEF_LARGE = [ + 0.39894228, + 0.1328592e-1, + 0.225319e-2, + -0.157565e-2, + 0.916281e-2, + -0.2057706e-1, + 0.2635537e-1, + -0.1647633e-1, + 0.392377e-2, +] +_I1_COEF_SMALL = [ + 0.5, + 0.87890594, + 0.51498869, + 0.15084934, + 0.2658733e-1, + 0.301532e-2, + 0.32411e-3, +] +_I1_COEF_LARGE = [ + 0.39894228, + -0.3988024e-1, + -0.362018e-2, + 0.163801e-2, + -0.1031555e-1, + 0.2282967e-1, + -0.2895312e-1, + 0.1787654e-1, + -0.420059e-2, +] _COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL] _COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE] @@ -36,7 +70,7 @@ def _log_modified_bessel_fn(x, order=0): assert order == 0 or order == 1 # compute small solution - y = (x / 3.75) + y = x / 3.75 y = y * y small = _eval_poly(y, _COEF_SMALL[order]) if order == 1: @@ -84,7 +118,8 @@ class VonMises(Distribution): :param torch.Tensor loc: an angle in radians. :param torch.Tensor concentration: concentration parameter """ - arg_constraints = {'loc': constraints.real, 'concentration': constraints.positive} + + arg_constraints = {"loc": constraints.real, "concentration": constraints.positive} support = constraints.real has_rsample = False @@ -94,9 +129,9 @@ class VonMises(Distribution): event_shape = torch.Size() # Parameters for sampling - tau = 1 + (1 + 4 * self.concentration ** 2).sqrt() + tau = 1 + (1 + 4 * self.concentration**2).sqrt() rho = (tau - (2 * tau).sqrt()) / (2 * self.concentration) - self._proposal_r = (1 + rho ** 2) / (2 * rho) + self._proposal_r = (1 + rho**2) / (2 * rho) super().__init__(batch_shape, event_shape, validate_args) @@ -104,7 +139,11 @@ class VonMises(Distribution): if self._validate_args: self._validate_sample(value) log_prob = self.concentration * torch.cos(value - self.loc) - log_prob = log_prob - math.log(2 * math.pi) - _log_modified_bessel_fn(self.concentration, order=0) + log_prob = ( + log_prob + - math.log(2 * math.pi) + - _log_modified_bessel_fn(self.concentration, order=0) + ) return log_prob @torch.no_grad() @@ -122,7 +161,7 @@ class VonMises(Distribution): try: return super().expand(batch_shape) except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') + validate_args = self.__dict__.get("_validate_args") loc = self.loc.expand(batch_shape) concentration = self.concentration.expand(batch_shape) return type(self)(loc, concentration, validate_args=validate_args) @@ -143,5 +182,10 @@ class VonMises(Distribution): """ The provided variance is the circular one. """ - return 1 - (_log_modified_bessel_fn(self.concentration, order=1) - - _log_modified_bessel_fn(self.concentration, order=0)).exp() + return ( + 1 + - ( + _log_modified_bessel_fn(self.concentration, order=1) + - _log_modified_bessel_fn(self.concentration, order=0) + ).exp() + ) diff --git a/torch/distributions/weibull.py b/torch/distributions/weibull.py index 6d8b16c448f7..2cef7f8550bc 100644 --- a/torch/distributions/weibull.py +++ b/torch/distributions/weibull.py @@ -1,12 +1,13 @@ import torch from torch.distributions import constraints from torch.distributions.exponential import Exponential +from torch.distributions.gumbel import euler_constant from torch.distributions.transformed_distribution import TransformedDistribution from torch.distributions.transforms import AffineTransform, PowerTransform from torch.distributions.utils import broadcast_all -from torch.distributions.gumbel import euler_constant -__all__ = ['Weibull'] +__all__ = ["Weibull"] + class Weibull(TransformedDistribution): r""" @@ -23,15 +24,22 @@ class Weibull(TransformedDistribution): scale (float or Tensor): Scale parameter of distribution (lambda). concentration (float or Tensor): Concentration parameter of distribution (k/shape). """ - arg_constraints = {'scale': constraints.positive, 'concentration': constraints.positive} + arg_constraints = { + "scale": constraints.positive, + "concentration": constraints.positive, + } support = constraints.positive def __init__(self, scale, concentration, validate_args=None): self.scale, self.concentration = broadcast_all(scale, concentration) self.concentration_reciprocal = self.concentration.reciprocal() - base_dist = Exponential(torch.ones_like(self.scale), validate_args=validate_args) - transforms = [PowerTransform(exponent=self.concentration_reciprocal), - AffineTransform(loc=0, scale=self.scale)] + base_dist = Exponential( + torch.ones_like(self.scale), validate_args=validate_args + ) + transforms = [ + PowerTransform(exponent=self.concentration_reciprocal), + AffineTransform(loc=0, scale=self.scale), + ] super().__init__(base_dist, transforms, validate_args=validate_args) def expand(self, batch_shape, _instance=None): @@ -40,11 +48,11 @@ class Weibull(TransformedDistribution): new.concentration = self.concentration.expand(batch_shape) new.concentration_reciprocal = new.concentration.reciprocal() base_dist = self.base_dist.expand(batch_shape) - transforms = [PowerTransform(exponent=new.concentration_reciprocal), - AffineTransform(loc=0, scale=new.scale)] - super(Weibull, new).__init__(base_dist, - transforms, - validate_args=False) + transforms = [ + PowerTransform(exponent=new.concentration_reciprocal), + AffineTransform(loc=0, scale=new.scale), + ] + super(Weibull, new).__init__(base_dist, transforms, validate_args=False) new._validate_args = self._validate_args return new @@ -54,13 +62,22 @@ class Weibull(TransformedDistribution): @property def mode(self): - return self.scale * ((self.concentration - 1) / self.concentration) ** self.concentration.reciprocal() + return ( + self.scale + * ((self.concentration - 1) / self.concentration) + ** self.concentration.reciprocal() + ) @property def variance(self): - return self.scale.pow(2) * (torch.exp(torch.lgamma(1 + 2 * self.concentration_reciprocal)) - - torch.exp(2 * torch.lgamma(1 + self.concentration_reciprocal))) + return self.scale.pow(2) * ( + torch.exp(torch.lgamma(1 + 2 * self.concentration_reciprocal)) + - torch.exp(2 * torch.lgamma(1 + self.concentration_reciprocal)) + ) def entropy(self): - return euler_constant * (1 - self.concentration_reciprocal) + \ - torch.log(self.scale * self.concentration_reciprocal) + 1 + return ( + euler_constant * (1 - self.concentration_reciprocal) + + torch.log(self.scale * self.concentration_reciprocal) + + 1 + ) diff --git a/torch/distributions/wishart.py b/torch/distributions/wishart.py index 0277d2299719..733efbbeb95f 100644 --- a/torch/distributions/wishart.py +++ b/torch/distributions/wishart.py @@ -7,11 +7,11 @@ import torch from torch import nan from torch.distributions import constraints from torch.distributions.exp_family import ExponentialFamily -from torch.distributions.utils import lazy_property from torch.distributions.multivariate_normal import _precision_to_scale_tril +from torch.distributions.utils import lazy_property -__all__ = ['Wishart'] +__all__ = ["Wishart"] _log_2 = math.log(2) @@ -23,10 +23,12 @@ def _mvdigamma(x: torch.Tensor, p: int) -> torch.Tensor: - torch.arange(p, dtype=x.dtype, device=x.device).div(2).expand(x.shape + (-1,)) ).sum(-1) + def _clamp_above_eps(x: torch.Tensor) -> torch.Tensor: # We assume positive input for this function return x.clamp(min=torch.finfo(x.dtype).eps) + class Wishart(ExponentialFamily): r""" Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`, @@ -61,28 +63,37 @@ class Wishart(ExponentialFamily): [5] Ku, Y.-C. & Bloomfield, P., 2010. `Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX`. """ arg_constraints = { - 'covariance_matrix': constraints.positive_definite, - 'precision_matrix': constraints.positive_definite, - 'scale_tril': constraints.lower_cholesky, - 'df': constraints.greater_than(0), + "covariance_matrix": constraints.positive_definite, + "precision_matrix": constraints.positive_definite, + "scale_tril": constraints.lower_cholesky, + "df": constraints.greater_than(0), } support = constraints.positive_definite has_rsample = True _mean_carrier_measure = 0 - def __init__(self, - df: Union[torch.Tensor, Number], - covariance_matrix: Optional[torch.Tensor] = None, - precision_matrix: Optional[torch.Tensor] = None, - scale_tril: Optional[torch.Tensor] = None, - validate_args=None): - assert (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) == 1, \ - "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified." + def __init__( + self, + df: Union[torch.Tensor, Number], + covariance_matrix: Optional[torch.Tensor] = None, + precision_matrix: Optional[torch.Tensor] = None, + scale_tril: Optional[torch.Tensor] = None, + validate_args=None, + ): + assert (covariance_matrix is not None) + (scale_tril is not None) + ( + precision_matrix is not None + ) == 1, "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified." - param = next(p for p in (covariance_matrix, precision_matrix, scale_tril) if p is not None) + param = next( + p + for p in (covariance_matrix, precision_matrix, scale_tril) + if p is not None + ) if param.dim() < 2: - raise ValueError("scale_tril must be at least two-dimensional, with optional leading batch dimensions") + raise ValueError( + "scale_tril must be at least two-dimensional, with optional leading batch dimensions" + ) if isinstance(df, Number): batch_shape = torch.Size(param.shape[:-2]) @@ -93,7 +104,9 @@ class Wishart(ExponentialFamily): event_shape = param.shape[-2:] if self.df.le(event_shape[-1] - 1).any(): - raise ValueError(f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}.") + raise ValueError( + f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}." + ) if scale_tril is not None: self.scale_tril = param.expand(batch_shape + (-1, -1)) @@ -102,9 +115,11 @@ class Wishart(ExponentialFamily): elif precision_matrix is not None: self.precision_matrix = param.expand(batch_shape + (-1, -1)) - self.arg_constraints['df'] = constraints.greater_than(event_shape[-1] - 1) + self.arg_constraints["df"] = constraints.greater_than(event_shape[-1] - 1) if self.df.lt(event_shape[-1]).any(): - warnings.warn("Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim.") + warnings.warn( + "Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim." + ) super().__init__(batch_shape, event_shape, validate_args=validate_args) self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))] @@ -137,11 +152,11 @@ class Wishart(ExponentialFamily): new._batch_dims = [-(x + 1) for x in range(len(batch_shape))] - if 'covariance_matrix' in self.__dict__: + if "covariance_matrix" in self.__dict__: new.covariance_matrix = self.covariance_matrix.expand(cov_shape) - if 'scale_tril' in self.__dict__: + if "scale_tril" in self.__dict__: new.scale_tril = self.scale_tril.expand(cov_shape) - if 'precision_matrix' in self.__dict__: + if "precision_matrix" in self.__dict__: new.precision_matrix = self.precision_matrix.expand(cov_shape) # Chi2 distribution is needed for Bartlett decomposition sampling @@ -163,12 +178,14 @@ class Wishart(ExponentialFamily): @lazy_property def scale_tril(self): return self._unbroadcasted_scale_tril.expand( - self._batch_shape + self._event_shape) + self._batch_shape + self._event_shape + ) @lazy_property def covariance_matrix(self): return ( - self._unbroadcasted_scale_tril @ self._unbroadcasted_scale_tril.transpose(-2, -1) + self._unbroadcasted_scale_tril + @ self._unbroadcasted_scale_tril.transpose(-2, -1) ).expand(self._batch_shape + self._event_shape) @lazy_property @@ -178,9 +195,9 @@ class Wishart(ExponentialFamily): device=self._unbroadcasted_scale_tril.device, dtype=self._unbroadcasted_scale_tril.dtype, ) - return torch.cholesky_solve( - identity, self._unbroadcasted_scale_tril - ).expand(self._batch_shape + self._event_shape) + return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand( + self._batch_shape + self._event_shape + ) @property def mean(self): @@ -192,12 +209,13 @@ class Wishart(ExponentialFamily): factor[factor <= 0] = nan return factor.view(self._batch_shape + (1, 1)) * self.covariance_matrix - @property def variance(self): V = self.covariance_matrix # has shape (batch_shape x event_shape) diag_V = V.diagonal(dim1=-2, dim2=-1) - return self.df.view(self._batch_shape + (1, 1)) * (V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V)) + return self.df.view(self._batch_shape + (1, 1)) * ( + V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V) + ) def _bartlett_sampling(self, sample_shape=torch.Size()): p = self._event_shape[-1] # has singleton shape @@ -272,10 +290,19 @@ class Wishart(ExponentialFamily): nu = self.df # has shape (batch_shape) p = self._event_shape[-1] # has singleton shape return ( - - nu * (p * _log_2 / 2 + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)) + -nu + * ( + p * _log_2 / 2 + + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1) + .log() + .sum(-1) + ) - torch.mvlgamma(nu / 2, p=p) + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet - - torch.cholesky_solve(value, self._unbroadcasted_scale_tril).diagonal(dim1=-2, dim2=-1).sum(dim=-1) / 2 + - torch.cholesky_solve(value, self._unbroadcasted_scale_tril) + .diagonal(dim1=-2, dim2=-1) + .sum(dim=-1) + / 2 ) def entropy(self): @@ -283,7 +310,13 @@ class Wishart(ExponentialFamily): p = self._event_shape[-1] # has singleton shape V = self.covariance_matrix # has shape (batch_shape x event_shape) return ( - (p + 1) * (p * _log_2 / 2 + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)) + (p + 1) + * ( + p * _log_2 / 2 + + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1) + .log() + .sum(-1) + ) + torch.mvlgamma(nu / 2, p=p) - (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p) + nu * p / 2 @@ -293,11 +326,10 @@ class Wishart(ExponentialFamily): def _natural_params(self): nu = self.df # has shape (batch_shape) p = self._event_shape[-1] # has singleton shape - return - self.precision_matrix / 2, (nu - p - 1) / 2 + return -self.precision_matrix / 2, (nu - p - 1) / 2 def _log_normalizer(self, x, y): p = self._event_shape[-1] - return ( - (y + (p + 1) / 2) * (- torch.linalg.slogdet(- 2 * x).logabsdet + _log_2 * p) - + torch.mvlgamma(y + (p + 1) / 2, p=p) - ) + return (y + (p + 1) / 2) * ( + -torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p + ) + torch.mvlgamma(y + (p + 1) / 2, p=p) diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 3cdad2e2c9cb..67c545dcc028 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -1,93 +1,99 @@ -import torch._C - -from contextlib import contextmanager -from typing import Iterator, Any import warnings -from torch.utils import set_module +from contextlib import contextmanager +from typing import Any, Iterator + +import torch._C # These are imported so users can access them from the `torch.jit` module from torch._jit_internal import ( - Final, - Future, _Await, _drop, _IgnoreContextManager, + _isinstance, _overload, _overload_method, - ignore, - _isinstance, - is_scripting, export, + Final, + Future, + ignore, + is_scripting, unused, ) +from torch.jit._async import fork, wait +from torch.jit._await import _awaitable, _awaitable_nowait, _awaitable_wait +from torch.jit._decomposition_utils import _register_decomposition +from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations +from torch.jit._fuser import ( + fuser, + last_executed_optimized_graph, + optimized_execution, + set_fusion_strategy, +) +from torch.jit._ir_utils import _InsertPoint from torch.jit._script import ( - script, - Attribute, - ScriptModule, - script_method, - RecursiveScriptClass, - RecursiveScriptModule, - ScriptWarning, - interface, - CompilationUnit, - ScriptFunction, _ScriptProfile, _unwrap_optional, + Attribute, + CompilationUnit, + interface, + RecursiveScriptClass, + RecursiveScriptModule, + script, + script_method, + ScriptFunction, + ScriptModule, + ScriptWarning, +) +from torch.jit._serialization import ( + jit_module_from_flatbuffer, + load, + save, + save_jit_module_to_flatbuffer, ) from torch.jit._trace import ( + _flatten, + _get_trace_graph, + _script_if_tracing, + _unique_state_dict, + is_tracing, + ONNXTracedModule, + TopLevelTracedModule, trace, trace_module, TracedModule, TracerWarning, TracingCheckError, - is_tracing, - ONNXTracedModule, - TopLevelTracedModule, - _unique_state_dict, - _flatten, - _script_if_tracing, - _get_trace_graph, ) -from torch.jit._async import fork, wait -from torch.jit._await import _awaitable, _awaitable_wait, _awaitable_nowait -from torch.jit._decomposition_utils import _register_decomposition -from torch.jit._serialization import ( - save, - load, - jit_module_from_flatbuffer, - save_jit_module_to_flatbuffer, -) -from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph, set_fusion_strategy -from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations -from torch.jit._ir_utils import _InsertPoint + +from torch.utils import set_module __all__ = [ - 'Attribute', - 'CompilationUnit', - 'Error', - 'Future', - 'ScriptFunction', - 'ScriptModule', - 'annotate', - 'enable_onednn_fusion', - 'export_opnames', - 'fork', - 'freeze', - 'ignore', - 'isinstance', - 'load', - 'onednn_fusion_enabled', - 'optimize_for_inference', - 'save', - 'script', - 'script_if_tracing', - 'set_fusion_strategy', - 'strict_fusion', - 'trace', - 'trace_module', - 'unused', - 'wait' + "Attribute", + "CompilationUnit", + "Error", + "Future", + "ScriptFunction", + "ScriptModule", + "annotate", + "enable_onednn_fusion", + "export_opnames", + "fork", + "freeze", + "ignore", + "isinstance", + "load", + "onednn_fusion_enabled", + "optimize_for_inference", + "save", + "script", + "script_if_tracing", + "set_fusion_strategy", + "strict_fusion", + "trace", + "trace_module", + "unused", + "wait", ] # For backwards compatibility @@ -98,10 +104,10 @@ _set_fusion_strategy = set_fusion_strategy def export_opnames(m): r""" - Generates new bytecode for a Script module and returns what the op list - would be for a Script Module based off the current code base. If you - have a LiteScriptModule and want to get the currently present - list of ops call _export_operator_list instead. + Generates new bytecode for a Script module and returns what the op list + would be for a Script Module based off the current code base. If you + have a LiteScriptModule and want to get the currently present + list of ops call _export_operator_list instead. """ return torch._C._export_opnames(m._c) @@ -113,6 +119,7 @@ set_module(Error, "torch.jit") Error.__name__ = "Error" Error.__qualname__ = "Error" + # for use in python if using annotate def annotate(the_type, the_value): """ @@ -224,6 +231,7 @@ def isinstance(obj, target_type): """ return _isinstance(obj, target_type) + class strict_fusion: """ This class errors if not all nodes have been fused in @@ -253,6 +261,7 @@ class strict_fusion: def __exit__(self, type: Any, value: Any, tb: Any) -> None: pass + # Context manager for globally hiding source ranges when printing graphs. # Note that these functions are exposed to Python as static members of the # Graph class, so mypy checks need to be skipped. @@ -265,6 +274,7 @@ def _hide_source_ranges() -> Iterator[None]: finally: torch._C.Graph.set_global_print_source_ranges(old_enable_source_ranges) # type: ignore[attr-defined] + def enable_onednn_fusion(enabled: bool): """ Enables or disables onednn JIT fusion based on the parameter `enabled`. @@ -272,12 +282,14 @@ def enable_onednn_fusion(enabled: bool): torch._C._jit_set_llga_enabled(enabled) + def onednn_fusion_enabled(): """ Returns whether onednn JIT fusion is enabled """ return torch._C._jit_llga_enabled() + del Any if not torch._C._jit_init(): diff --git a/torch/jit/_async.py b/torch/jit/_async.py index e4dbce6dca4a..9fdadead5382 100644 --- a/torch/jit/_async.py +++ b/torch/jit/_async.py @@ -8,10 +8,10 @@ functionalities in `torch.jit`. """ import torch +from torch._jit_internal import Future +from torch.jit._builtins import _register_builtin from torch.utils import set_module -from torch.jit._builtins import _register_builtin -from torch._jit_internal import Future set_module(Future, "torch.jit") diff --git a/torch/jit/_await.py b/torch/jit/_await.py index d0df60d72405..8814726ff8ca 100644 --- a/torch/jit/_await.py +++ b/torch/jit/_await.py @@ -1,11 +1,12 @@ import torch +from torch._jit_internal import _Await +from torch.jit._builtins import _register_builtin from torch.utils import set_module -from torch.jit._builtins import _register_builtin -from torch._jit_internal import _Await set_module(_Await, "torch.jit") + def _awaitable(func, *args, **kwargs): r""" Creates Await object that will call specified functioni with specified args, @@ -13,6 +14,7 @@ def _awaitable(func, *args, **kwargs): """ return torch._C._awaitable(func, *args, **kwargs) + def _awaitable_wait(aw): r""" Requests await the result of execution, if Await is not completed yet, @@ -20,6 +22,7 @@ def _awaitable_wait(aw): """ return torch._C._awaitable_wait(aw) + def _awaitable_nowait(o): r""" Creates completed Await with specified result. diff --git a/torch/jit/_builtins.py b/torch/jit/_builtins.py index c070b7a566a9..f50e1bbfedb5 100644 --- a/torch/jit/_builtins.py +++ b/torch/jit/_builtins.py @@ -1,14 +1,14 @@ -import math import cmath +import math import warnings +from collections import OrderedDict +from typing import Dict, Optional + import torch import torch.backends.cudnn as cudnn -from ..nn.modules.utils import _single, _pair, _triple, _quadruple, _list_with_default - -from collections import OrderedDict -from typing import Dict, Optional +from ..nn.modules.utils import _list_with_default, _pair, _quadruple, _single, _triple _builtin_table: Optional[Dict[int, str]] = None @@ -112,19 +112,32 @@ _builtin_ops = [ # in these cases, we want to resolve the function to their python implementation # instead looking up a builtin "aten::" schema + def _gen_torch_functional_registered_ops(): # eventually ops should encompass all of torch/functional.py, (torch.functional.__all__) # but we are currently only able to compile some of the functions. additionally, # some functions directly map to their aten:: implementations. # TODO: add support for more ops - ops = ["stft", "istft", "lu", "cdist", "norm", "unique", "unique_consecutive", "tensordot"] + ops = [ + "stft", + "istft", + "lu", + "cdist", + "norm", + "unique", + "unique_consecutive", + "tensordot", + ] return {getattr(torch.functional, name) for name in ops} + _functional_registered_ops = _gen_torch_functional_registered_ops() + def _is_special_functional_bound_op(fn): return fn in _functional_registered_ops + # lazily built to ensure the correct initialization order def _get_builtin_table(): global _builtin_table @@ -135,11 +148,17 @@ def _get_builtin_table(): def register_all(mod): for name in dir(mod): v = getattr(mod, name) - if callable(v) and not _is_special_functional_bound_op(v) and v is not torch.no_grad and v is not torch.autocast: + if ( + callable(v) + and not _is_special_functional_bound_op(v) + and v is not torch.no_grad + and v is not torch.autocast + ): # Fixup inconsistency in segment_reduce if name == "_segment_reduce": name = name[1:] _builtin_ops.append((v, "aten::" + name)) + for mod in _modules_containing_builtins: register_all(mod) @@ -148,6 +167,7 @@ def _get_builtin_table(): _builtin_ops.append((math.remainder, "aten::mathremainder")) # type: ignore[attr-defined] import torch.distributed.autograd as dist_autograd + if dist_autograd.is_available(): _builtin_ops.append((dist_autograd.get_gradients, "aten::get_gradients")) _builtin_ops.append((dist_autograd.backward, "aten::dist_backward")) diff --git a/torch/jit/_check.py b/torch/jit/_check.py index 9d8557d9d2c5..6b8ed52033d6 100644 --- a/torch/jit/_check.py +++ b/torch/jit/_check.py @@ -1,10 +1,11 @@ - import ast import inspect import textwrap -import torch import warnings +import torch + + class AttributeTypeIsSupportedChecker(ast.NodeVisitor): """ Checks the ``__init__`` method of a given ``nn.Module`` to ensure @@ -64,7 +65,10 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor): def is_useless_comment(line): line = line.strip() return line.startswith("#") and not line.startswith("# type:") - source_lines = "\n".join([l for l in source_lines.split("\n") if not is_useless_comment(l)]) + + source_lines = "\n".join( + [l for l in source_lines.split("\n") if not is_useless_comment(l)] + ) # This AST only contains the `__init__` method of the nn.Module init_ast = ast.parse(textwrap.dedent(source_lines)) @@ -114,8 +118,10 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor): target built in.) """ try: - if (isinstance(node.value, ast.Call) - and node.targets[0].attr in self.class_level_annotations): + if ( + isinstance(node.value, ast.Call) + and node.targets[0].attr in self.class_level_annotations + ): self.visiting_class_level_ann = True except AttributeError: return @@ -169,11 +175,13 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor): if not self._is_empty_container(node.value, ann_type): return - warnings.warn("The TorchScript type system doesn't support " - "instance-level annotations on empty non-base " - "types in `__init__`. Instead, either 1) use a " - "type annotation in the class body, or 2) wrap " - "the type in `torch.jit.Attribute`.") + warnings.warn( + "The TorchScript type system doesn't support " + "instance-level annotations on empty non-base " + "types in `__init__`. Instead, either 1) use a " + "type annotation in the class body, or 2) wrap " + "the type in `torch.jit.Attribute`." + ) def visit_Call(self, node): """ @@ -188,12 +196,15 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor): # If this isn't a call to `torch.jit.annotate` try: - if (node.func.value.value.id != "torch" - or node.func.value.attr != "jit" - or node.func.attr != "annotate"): + if ( + node.func.value.value.id != "torch" + or node.func.value.attr != "jit" + or node.func.attr != "annotate" + ): self.generic_visit(node) - elif (node.func.value.value.id != "jit" - or node.func.value.attr != "annotate"): + elif ( + node.func.value.value.id != "jit" or node.func.value.attr != "annotate" + ): self.generic_visit(node) except AttributeError: # Looks like we didn't even have the right node structure @@ -217,7 +228,7 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor): containers = {"List", "Dict", "Optional"} try: - ann_type = node.args[0].value.id # type: ignore[attr-defined] + ann_type = node.args[0].value.id # type: ignore[attr-defined] except AttributeError: return @@ -228,8 +239,10 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor): if not self._is_empty_container(node.args[1], ann_type): return - warnings.warn("The TorchScript type system doesn't support " - "instance-level annotations on empty non-base " - "types in `__init__`. Instead, either 1) use a " - "type annotation in the class body, or 2) wrap " - "the type in `torch.jit.Attribute`.") + warnings.warn( + "The TorchScript type system doesn't support " + "instance-level annotations on empty non-base " + "types in `__init__`. Instead, either 1) use a " + "type annotation in the class body, or 2) wrap " + "the type in `torch.jit.Attribute`." + ) diff --git a/torch/jit/_dataclass_impls.py b/torch/jit/_dataclass_impls.py index 6adfa4f70100..52056ce46bea 100644 --- a/torch/jit/_dataclass_impls.py +++ b/torch/jit/_dataclass_impls.py @@ -1,20 +1,22 @@ # Functions for synthesizing magic methods for JIT-compiled dataclasses -import os -from functools import partial -from torch._jit_internal import is_optional, FAKE_FILENAME_PREFIX -from torch._sources import ParsedDef, SourceContext -from typing import Callable, Dict, List import ast import dataclasses import inspect +import os +from functools import partial +from typing import Callable, Dict, List + +from torch._jit_internal import FAKE_FILENAME_PREFIX, is_optional +from torch._sources import ParsedDef, SourceContext + def _get_fake_filename(cls, method_name): return os.path.join(FAKE_FILENAME_PREFIX, cls.__name__, method_name) def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef: - body = '\n'.join(f' {b}' for b in body_lines) - decl = f'def {name}{signature}:\n{body}' + body = "\n".join(f" {b}" for b in body_lines) + decl = f"def {name}{signature}:\n{body}" # Parse the function declaration try: @@ -31,22 +33,24 @@ def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedD return ParsedDef( py_ast, ctx=SourceContext( - source=decl, - filename=fake_filename, - file_lineno=0, - leading_whitespace_len=0 + source=decl, filename=fake_filename, file_lineno=0, leading_whitespace_len=0 ), source=decl, filename=fake_filename, - file_lineno=0 + file_lineno=0, ) def synthesize__init__(cls) -> ParsedDef: # Supporting default factories in the way that people expect would sort of require us to # allow compiling lambda functions, which is not currently supported. - if any(field.default_factory is not dataclasses.MISSING for field in dataclasses.fields(cls)): - raise NotImplementedError("Default factory initializers are not supported in TorchScript dataclasses") + if any( + field.default_factory is not dataclasses.MISSING + for field in dataclasses.fields(cls) + ): + raise NotImplementedError( + "Default factory initializers are not supported in TorchScript dataclasses" + ) # Simply read off the generated __init__ signature from CPython's implementation. It'll be # almost correct except for InitVar annotations, which we need to handle specially. @@ -62,7 +66,7 @@ def synthesize__init__(cls) -> ParsedDef: if isinstance(ann, dataclasses.InitVar): # The TorchScript interpreter can't handle InitVar annotations, so we unwrap the underlying type here init_vars.append(name) - params.append(param.replace(annotation=ann.type)) # type: ignore[attr-defined] + params.append(param.replace(annotation=ann.type)) # type: ignore[attr-defined] else: params.append(param) @@ -70,72 +74,107 @@ def synthesize__init__(cls) -> ParsedDef: body = [ # Assign all attributes to self - f'self.{field.name} = {field.name}' + f"self.{field.name} = {field.name}" for field in dataclasses.fields(cls) if field.init and field.name not in init_vars ] # Call user's impl of __post_init__ if it exists - if hasattr(cls, '__post_init__'): - body.append('self.__post_init__(' + ', '.join(init_vars) + ')') + if hasattr(cls, "__post_init__"): + body.append("self.__post_init__(" + ", ".join(init_vars) + ")") + + return compose_fn(cls, "__init__", body or ["pass"], signature=str(signature)) - return compose_fn(cls, '__init__', body or ['pass'], signature=str(signature)) # This is a placeholder at the moment since the TorchScript interpreter doesn't call __repr__ def synthesize__repr__(cls) -> ParsedDef: return compose_fn( - cls, '__repr__', - [f"return '{cls.__name__}(" + ", ".join([ - f"{field.name}=self.{field.name}" - for field in dataclasses.fields(cls) if field.repr - ]) + ")'"], - signature='(self) -> str' + cls, + "__repr__", + [ + f"return '{cls.__name__}(" + + ", ".join( + [ + f"{field.name}=self.{field.name}" + for field in dataclasses.fields(cls) + if field.repr + ] + ) + + ")'" + ], + signature="(self) -> str", ) + def synthesize__hash__(cls) -> ParsedDef: return compose_fn( - cls, '__hash__', + cls, + "__hash__", [ # This is just a placeholder to prevent compilation from failing; this won't even get called at # all right now because the TorchScript interpreter doesn't call custom __hash__ implementations "raise NotImplementedError('__hash__ is not supported for dataclasses in TorchScript')" ], - signature='(self) -> int' + signature="(self) -> int", ) + # Implementation for __eq__ and __ne__ def synthesize_equality(cls, name: str, converse: str) -> ParsedDef: - return synthesize_comparison(cls, name, allow_eq=True, raise_on_none=False, inner=[ - f"if val1 {converse} val2: return False" - ]) + return synthesize_comparison( + cls, + name, + allow_eq=True, + raise_on_none=False, + inner=[f"if val1 {converse} val2: return False"], + ) + def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef: - return synthesize_comparison(cls, name, allow_eq, raise_on_none=True, inner=[ - f"if val1 {op} val2: return True", - f"elif val2 {op} val1: return False", - ]) + return synthesize_comparison( + cls, + name, + allow_eq, + raise_on_none=True, + inner=[ + f"if val1 {op} val2: return True", + f"elif val2 {op} val1: return False", + ], + ) -def synthesize_comparison(cls, name: str, allow_eq: bool, raise_on_none: bool, inner: List[str]) -> ParsedDef: + +def synthesize_comparison( + cls, name: str, allow_eq: bool, raise_on_none: bool, inner: List[str] +) -> ParsedDef: body = [] for field in dataclasses.fields(cls): if not field.compare: continue - body.extend([ - f"val1 = self.{field.name}", - f"val2 = other.{field.name}", - ]) body.extend( - inner if not is_optional(field.type) else [ + [ + f"val1 = self.{field.name}", + f"val2 = other.{field.name}", + ] + ) + body.extend( + inner + if not is_optional(field.type) + else [ # Type refinement for optional fields; we need this to avoid type errors from the interpreter "if val1 is not None and val2 is not None:", - *[' ' + line for line in inner], + *[" " + line for line in inner], "elif (val1 is None) != (val2 is None):", - f" raise TypeError('Cannot compare {cls.__name__} with None')" if raise_on_none else " return False" + f" raise TypeError('Cannot compare {cls.__name__} with None')" + if raise_on_none + else " return False", ] ) body.append(f"return {allow_eq}") - return compose_fn(cls, name, body, signature=f'(self, other: {cls.__name__}) -> bool') + return compose_fn( + cls, name, body, signature=f"(self, other: {cls.__name__}) -> bool" + ) + DATACLASS_MAGIC_METHODS: Dict[str, Callable] = { "__init__": synthesize__init__, diff --git a/torch/jit/_decomposition_utils.py b/torch/jit/_decomposition_utils.py index 3aa9b670ed4a..fb4448e2b900 100644 --- a/torch/jit/_decomposition_utils.py +++ b/torch/jit/_decomposition_utils.py @@ -1,8 +1,11 @@ import torch from torch._ops import OpOverload, OpOverloadPacket + def _register_decomposition(op: OpOverload, graph: torch._C.Graph): - assert not isinstance(op, OpOverloadPacket), f"Must pass specific op overload, not overload packet, found {op}" + assert not isinstance( + op, OpOverloadPacket + ), f"Must pass specific op overload, not overload packet, found {op}" assert isinstance(op, OpOverload) torch._C._jit_register_decomposition_for_schema(op._schema, graph) diff --git a/torch/jit/_decompositions.py b/torch/jit/_decompositions.py index beeca5df4e95..babb70eaf7cb 100644 --- a/torch/jit/_decompositions.py +++ b/torch/jit/_decompositions.py @@ -1,25 +1,29 @@ - - import torch from torch import Tensor + aten = torch.ops.aten -from typing import Optional, List, Dict, Set import inspect import warnings +from typing import Dict, List, Optional, Set + from torch.types import Number decomposition_table: Dict[str, torch.jit.ScriptFunction] = {} function_name_set: Set[str] = set() -def check_decomposition_has_type_annotations(f): +def check_decomposition_has_type_annotations(f): inspect_empty = inspect._empty # type: ignore[attr-defined] sig = inspect.signature(f) for param in sig.parameters.values(): - assert param.annotation != inspect_empty, \ - f"No signature on param {param.name} for function {f.name}" + assert ( + param.annotation != inspect_empty + ), f"No signature on param {param.name} for function {f.name}" + + assert ( + sig.return_annotation != inspect_empty + ), f"No return annotation for function {f.name}" - assert sig.return_annotation != inspect_empty, f"No return annotation for function {f.name}" def signatures_match(decomposition_sig, torch_op_sig): decomp_params = decomposition_sig.parameters @@ -28,15 +32,14 @@ def signatures_match(decomposition_sig, torch_op_sig): if len(decomp_params) != len(op_params): return False - for decomp_param, op_param in zip(decomp_params.values(), op_params.values()): # can't check full equality yet because not all fields are correcly deduced # in the torch_op_sig - like default value # can't check 'kind' bc # kwarg-only values with defaults not yet supported in TS inspect_empty = inspect._empty # type: ignore[attr-defined] - for field in ['name', 'annotation']: - if field == 'name' and decomp_param.name == "self": + for field in ["name", "annotation"]: + if field == "name" and decomp_param.name == "self": warnings.warn("PyTorch uses 'input' instead of 'self' on public api") if getattr(decomp_param, field) != getattr(op_param, field): @@ -52,6 +55,7 @@ def signatures_match(decomposition_sig, torch_op_sig): return decomposition_sig.return_annotation == torch_op_sig.return_annotation + def register_decomposition(aten_op, registry=None): def decomposition_decorator(f): nonlocal registry @@ -61,7 +65,9 @@ def register_decomposition(aten_op, registry=None): assert isinstance(aten_op, torch._ops.OpOverload) # Need unique name for jit function serialization - assert f.__name__ not in function_name_set, f"Duplicated function name {f.__name__}" + assert ( + f.__name__ not in function_name_set + ), f"Duplicated function name {f.__name__}" function_name_set.add(f.__name__) scripted_func = torch.jit.script(f) @@ -76,12 +82,17 @@ def register_decomposition(aten_op, registry=None): return decomposition_decorator + # TODO: replace torch.sigmoid -> aten.sigmoid + @register_decomposition(aten.var.correction) -def var_decomposition(input: Tensor, dim: Optional[List[int]] = None, - correction: Optional[Number] = None, - keepdim: bool = False) -> Tensor: +def var_decomposition( + input: Tensor, + dim: Optional[List[int]] = None, + correction: Optional[Number] = None, + keepdim: bool = False, +) -> Tensor: if dim is None: dim_i: List[int] = [] dim = dim_i @@ -110,6 +121,7 @@ def var_decomposition(input: Tensor, dim: Optional[List[int]] = None, return sum / max(0, denom) + @register_decomposition(aten.var.default) def var(input: Tensor, unbiased: bool = True) -> Tensor: return var_decomposition(input, correction=(1 if unbiased else 0)) diff --git a/torch/jit/_freeze.py b/torch/jit/_freeze.py index 0db888f6411d..5725c36ff985 100644 --- a/torch/jit/_freeze.py +++ b/torch/jit/_freeze.py @@ -4,13 +4,15 @@ This is not intended to be imported directly; please use the exposed functionalities in `torch.jit`. """ -from typing import Optional, List +from typing import List, Optional import torch from torch.jit._script import RecursiveScriptModule, ScriptModule -def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics: bool = True): +def freeze( + mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics: bool = True +): r""" Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned module's submodules, parameters, and attributes as constants in the TorchScript IR Graph. @@ -174,7 +176,9 @@ def run_frozen_optimizations( ) -def optimize_for_inference(mod: ScriptModule, other_methods: Optional[List[str]] = None) -> ScriptModule: +def optimize_for_inference( + mod: ScriptModule, other_methods: Optional[List[str]] = None +) -> ScriptModule: """ Performs a set of optimization passes to optimize a model for the purposes of inference. If the model is not already frozen, optimize_for_inference @@ -206,7 +210,8 @@ def optimize_for_inference(mod: ScriptModule, other_methods: Optional[List[str]] if not isinstance(mod, ScriptModule): raise RuntimeError( "optimize_for_inference expects a ScriptModule as input. " - "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'.") + "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'." + ) if other_methods is None: other_methods = [] diff --git a/torch/jit/_fuser.py b/torch/jit/_fuser.py index ddab9d99cf2b..4ba275e52aab 100644 --- a/torch/jit/_fuser.py +++ b/torch/jit/_fuser.py @@ -1,7 +1,8 @@ import contextlib +from typing import List, Tuple import torch -from typing import List, Tuple + @contextlib.contextmanager def optimized_execution(should_optimize): @@ -16,6 +17,7 @@ def optimized_execution(should_optimize): finally: torch._C._set_graph_executor_optimize(stored_flag) + @contextlib.contextmanager def fuser(name): """ @@ -33,13 +35,13 @@ def fuser(name): old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() old_nvfuser_state = torch._C._jit_nvfuser_enabled() old_llga_state = torch._C._jit_llga_enabled() - if name == 'fuser0': # legacy fuser + if name == "fuser0": # legacy fuser torch._C._jit_override_can_fuse_on_cpu(True) torch._C._jit_override_can_fuse_on_gpu(True) torch._C._jit_set_texpr_fuser_enabled(False) torch._C._jit_set_nvfuser_enabled(False) torch._C._jit_set_llga_enabled(False) - elif name == 'fuser1': # NNC + elif name == "fuser1": # NNC old_profiling_executor = torch._C._jit_set_profiling_executor(True) old_profiling_mode = torch._C._get_graph_executor_optimize(True) torch._C._jit_override_can_fuse_on_cpu(True) @@ -47,13 +49,13 @@ def fuser(name): torch._C._jit_set_texpr_fuser_enabled(True) torch._C._jit_set_nvfuser_enabled(False) torch._C._jit_set_llga_enabled(False) - elif name == 'fuser2': # nvFuser + elif name == "fuser2": # nvFuser torch._C._jit_override_can_fuse_on_cpu(False) torch._C._jit_override_can_fuse_on_gpu(False) torch._C._jit_set_texpr_fuser_enabled(False) torch._C._jit_set_nvfuser_enabled(True) torch._C._jit_set_llga_enabled(False) - elif name == 'fuser3': # oneDNN Graph + elif name == "fuser3": # oneDNN Graph old_profiling_executor = torch._C._jit_set_profiling_executor(True) old_profiling_mode = torch._C._get_graph_executor_optimize(True) torch._C._jit_override_can_fuse_on_cpu(True) @@ -61,7 +63,7 @@ def fuser(name): torch._C._jit_set_texpr_fuser_enabled(True) torch._C._jit_set_nvfuser_enabled(False) torch._C._jit_set_llga_enabled(True) - elif name == 'none': # Turn Pytorch fuser off + elif name == "none": # Turn Pytorch fuser off torch._C._jit_override_can_fuse_on_cpu(False) torch._C._jit_override_can_fuse_on_gpu(False) torch._C._jit_set_texpr_fuser_enabled(False) @@ -72,7 +74,7 @@ def fuser(name): try: yield finally: - if name in ['fuser1', 'fuser3']: # NNC or oneDNN Graph + if name in ["fuser1", "fuser3"]: # NNC or oneDNN Graph torch._C._jit_set_profiling_executor(old_profiling_executor) torch._C._get_graph_executor_optimize(old_profiling_mode) # recover the previous values @@ -85,22 +87,25 @@ def fuser(name): last_executed_optimized_graph = torch._C._last_executed_optimized_graph + def _get_differentiable_graph_node(node, diff_node): - if node.kind() == 'prim::DifferentiableGraph': + if node.kind() == "prim::DifferentiableGraph": diff_node.append(node) else: for block in node.blocks(): for n in block.nodes(): _get_differentiable_graph_node(n, diff_node) + def _graph_for(self, *args, **kwargs): return _script_method_graph_for(self, self, *args, **kwargs) + def _script_method_graph_for(self, parent, *args, **kwargs): try: dbs = parent.get_debug_state() eps = list(dbs.execution_plans.values()) - assert(len(eps) == 1) + assert len(eps) == 1 graph = eps[0].graph.copy() # graph_executor_states for differentiable node @@ -109,7 +114,7 @@ def _script_method_graph_for(self, parent, *args, **kwargs): for n in graph.nodes(): _get_differentiable_graph_node(n, diff_nodes) - assert(len(fw_states) == len(diff_nodes)) + assert len(fw_states) == len(diff_nodes) # swap each differentiable graph with optimized graph in their execution plan for n, state in zip(diff_nodes, fw_states): fw_execution_plans = list(state.execution_plans.values()) @@ -117,7 +122,7 @@ def _script_method_graph_for(self, parent, *args, **kwargs): # plan. Avoid assert here so we would skip the ones that can't be # updated while try the best effort to update other nodes. if len(fw_execution_plans) == 1: - n.g_('Subgraph', fw_execution_plans[0].graph) + n.g_("Subgraph", fw_execution_plans[0].graph) return graph except Exception: @@ -126,6 +131,7 @@ def _script_method_graph_for(self, parent, *args, **kwargs): self(*args, **kwargs) return last_executed_optimized_graph() + def set_fusion_strategy(strategy: List[Tuple[str, int]]): """ Sets the type and number of specializations that can occur during fusion. diff --git a/torch/jit/_ir_utils.py b/torch/jit/_ir_utils.py index 9e4596de7758..028247f54011 100644 --- a/torch/jit/_ir_utils.py +++ b/torch/jit/_ir_utils.py @@ -1,8 +1,14 @@ -import torch from typing import Union +import torch + + class _InsertPoint: - def __init__(self, insert_point_graph: torch._C.Graph, insert_point: Union[torch._C.Node, torch._C.Block]): + def __init__( + self, + insert_point_graph: torch._C.Graph, + insert_point: Union[torch._C.Node, torch._C.Block], + ): self.insert_point = insert_point self.g = insert_point_graph self.guard = None @@ -14,5 +20,6 @@ class _InsertPoint: def __exit__(self, *args): self.g.setInsertPoint(self.prev_insert_point) + def insert_point_guard(self, insert_point: Union[torch._C.Node, torch._C.Block]): return _InsertPoint(self, insert_point) diff --git a/torch/jit/_monkeytype_config.py b/torch/jit/_monkeytype_config.py index 9957541ff25d..dbc2769341c5 100644 --- a/torch/jit/_monkeytype_config.py +++ b/torch/jit/_monkeytype_config.py @@ -1,36 +1,41 @@ - -import torch - import inspect -import typing import pathlib import sys -from typing import Optional, Iterable, List, Dict +import typing from collections import defaultdict from types import CodeType +from typing import Dict, Iterable, List, Optional + +import torch _IS_MONKEYTYPE_INSTALLED = True try: import monkeytype # type: ignore[import] from monkeytype import trace as monkeytype_trace - from monkeytype.db.base import CallTraceThunk, CallTraceStore, CallTraceStoreLogger # type: ignore[import] from monkeytype.config import _startswith, LIB_PATHS # type: ignore[import] + from monkeytype.db.base import ( # type: ignore[import] + CallTraceStore, + CallTraceStoreLogger, + CallTraceThunk, + ) from monkeytype.tracing import CallTrace, CodeFilter # type: ignore[import] except ImportError: _IS_MONKEYTYPE_INSTALLED = False + # Checks whether a class is defind in `torch.*` modules def is_torch_native_class(cls): - if not hasattr(cls, '__module__'): + if not hasattr(cls, "__module__"): return False - parent_modules = cls.__module__.split('.') + parent_modules = cls.__module__.split(".") if not parent_modules: return False root_module = sys.modules.get(parent_modules[0]) return root_module is torch + def get_type(type): """ Helper function which converts the given type to a torchScript acceptable format. @@ -43,15 +48,16 @@ def get_type(type): # with a null string. This needs to be done since # typing.List is not accepted by TorchScript. type_to_string = str(type) - return type_to_string.replace(type.__module__ + '.', '') + return type_to_string.replace(type.__module__ + ".", "") elif is_torch_native_class(type): # If the type is a subtype of torch module, then TorchScript expects a fully qualified name # for the type which is obtained by combining the module name and type name. - return type.__module__ + '.' + type.__name__ + return type.__module__ + "." + type.__name__ else: # For all other types use the name for the type. return type.__name__ + def get_optional_of_element_type(types): """ Helper function to extracts the type of the element to be annotated to Optional @@ -63,15 +69,18 @@ def get_optional_of_element_type(types): # Optional type is internally converted to Union[type, NoneType], which # is not supported yet in TorchScript. Hence, representing the optional type as string. - return 'Optional[' + elem_type + ']' + return "Optional[" + elem_type + "]" + def get_qualified_name(func): return func.__qualname__ + if _IS_MONKEYTYPE_INSTALLED: class JitTypeTraceStoreLogger(CallTraceStoreLogger): """A JitTypeCallTraceLogger that stores logged traces in a CallTraceStore.""" + def __init__(self, store: CallTraceStore): super().__init__(store) @@ -95,7 +104,7 @@ if _IS_MONKEYTYPE_INSTALLED: self, qualified_name: str, qualname_prefix: Optional[str] = None, - limit: int = 2000 + limit: int = 2000, ) -> List[CallTraceThunk]: return self.trace_records[qualified_name] @@ -122,7 +131,7 @@ if _IS_MONKEYTYPE_INSTALLED: # TODO: To remove this check once Union suppport in TorchScript lands. all_args[arg] = get_optional_of_element_type(types) elif type_length > 1: - all_args[arg] = 'Any' + all_args[arg] = "Any" elif type_length == 1: all_args[arg] = get_type(types[0]) return all_args @@ -147,6 +156,7 @@ if _IS_MONKEYTYPE_INSTALLED: def code_filter(self) -> Optional[CodeFilter]: return jit_code_filter + else: # When MonkeyType is not installed, we provide dummy class definitions # for the below classes. @@ -164,6 +174,7 @@ else: monkeytype_trace = None # noqa: F811 + def jit_code_filter(code: CodeType) -> bool: """ Custom CodeFilter for Torchscript to trace forward calls. @@ -176,7 +187,9 @@ def jit_code_filter(code: CodeType) -> bool: excludes tracing of stdlib and site-packages. """ # Filter code without a source file and exclude this check for 'forward' calls. - if code.co_name != 'forward' and (not code.co_filename or code.co_filename[0] == '<'): + if code.co_name != "forward" and ( + not code.co_filename or code.co_filename[0] == "<" + ): return False filename = pathlib.Path(code.co_filename).resolve() diff --git a/torch/jit/_passes/_property_propagation.py b/torch/jit/_passes/_property_propagation.py index b0a307cfc0d6..8ebd21e4bc10 100644 --- a/torch/jit/_passes/_property_propagation.py +++ b/torch/jit/_passes/_property_propagation.py @@ -28,14 +28,19 @@ def apply_input_props_using_example(graph: Graph, example_input: List[Any]): if not len(graph_inputs) == len(example_input): raise RuntimeError( - "Number of inputs in graph does not match number of inputs in the example") + "Number of inputs in graph does not match number of inputs in the example" + ) for i, (graph_i, example_i) in enumerate(zip(graph_inputs, example_input)): if example_i is None: continue # Skip the type check - if isinstance(example_i, torch.Tensor) != isinstance(graph_i.type(), TensorType): - raise RuntimeError(f"Input {i} does not match type of example", graph_i, example_i) + if isinstance(example_i, torch.Tensor) != isinstance( + graph_i.type(), TensorType + ): + raise RuntimeError( + f"Input {i} does not match type of example", graph_i, example_i + ) if isinstance(example_i, torch.Tensor): graph_i.setType(TensorType.create_from_tensor(example_i)) # type: ignore[arg-type] diff --git a/torch/jit/_pickle.py b/torch/jit/_pickle.py index db2982e822d4..1cb4a0a93efd 100644 --- a/torch/jit/_pickle.py +++ b/torch/jit/_pickle.py @@ -7,6 +7,7 @@ # a type attached and restored via `restore_type_tag` below. The legacy # functions should stick around for backwards-compatibility. + def build_intlist(data): return data diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index f1f4e66cc20f..aa4e370e3b25 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -1,24 +1,32 @@ -import inspect -import torch -import types import collections -import textwrap import functools -import warnings +import inspect import sys +import textwrap +import types +import warnings from typing import Dict, List, Set, Type +import torch + import torch._jit_internal as _jit_internal from torch._sources import fake_range -from torch.jit.frontend import get_default_args, get_jit_class_def, get_jit_def, get_class_properties from torch.jit._builtins import _find_builtin from torch.jit._check import AttributeTypeIsSupportedChecker -from torch.jit._state import _python_cu, _add_script_class, _get_script_class +from torch.jit._state import _add_script_class, _get_script_class, _python_cu +from torch.jit.frontend import ( + get_class_properties, + get_default_args, + get_jit_class_def, + get_jit_def, +) from torch.nn import Module -ScriptMethodStub = collections.namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method')) -PropertyStub = collections.namedtuple('PropertyStub', ('resolution_callback', 'def_')) +ScriptMethodStub = collections.namedtuple( + "ScriptMethodStub", ("resolution_callback", "def_", "original_method") +) +PropertyStub = collections.namedtuple("PropertyStub", ("resolution_callback", "def_")) # TODO: there should be a more principled way of doing this. @@ -43,22 +51,27 @@ ignored_attributes = [ "dump_patches", ] + def _compile_and_register_class(obj, rcb, qualified_name): script_class = _get_script_class(obj) if not script_class: ast = get_jit_class_def(obj, obj.__name__) defaults = torch.jit.frontend.get_default_args_for_class(obj) - script_class = torch._C._jit_script_class_compile(qualified_name, ast, defaults, rcb) + script_class = torch._C._jit_script_class_compile( + qualified_name, ast, defaults, rcb + ) _add_script_class(obj, script_class) return script_class + def make_stub(func, name): rcb = _jit_internal.createResolutionCallbackFromClosure(func) ast = get_jit_def(func, name, self_name="RecursiveScriptModule") return ScriptMethodStub(rcb, ast, func) + def make_stub_from_method(nn_module, method_name): func = getattr(nn_module, method_name) if isinstance(func, ScriptMethodStub): @@ -86,8 +99,11 @@ def make_stubs_from_exported_methods(mod): return stubs + def jit_ignored_properties(module): - user_annotated_ignored_attributes = getattr(module, "__jit_ignored_attributes__", list()) + user_annotated_ignored_attributes = getattr( + module, "__jit_ignored_attributes__", list() + ) def get_properties_names(module): return {k for k, v in vars(module).items() if isinstance(v, property)} @@ -100,11 +116,22 @@ def jit_ignored_properties(module): user_annoted_ignored_properties.add(ignored_attr) return user_annoted_ignored_properties + # base types that can be constants # in addition, tuples and lists of these base types are also considered constants # If you edit this list, then you also need to edit the handlers in # ConstantValue in jit/script/init.cpp -_constant_types = (bool, float, int, str, type(None), torch.device, torch.layout, torch.dtype) +_constant_types = ( + bool, + float, + int, + str, + type(None), + torch.device, + torch.layout, + torch.dtype, +) + def _get_valid_constant(attr, v, owner_type): if isinstance(v, _constant_types): @@ -112,13 +139,17 @@ def _get_valid_constant(attr, v, owner_type): elif isinstance(v, (tuple, list)): return tuple(_get_valid_constant(attr, x, owner_type) for x in v) constants = ", ".join(torch.typename(typ) for typ in _constant_types) - raise TypeError(textwrap.dedent(f""" + raise TypeError( + textwrap.dedent( + f""" '{torch.typename(type(v))}' object in attribute '{owner_type}.{attr}' is not a valid constant. Valid constants are: 1. a nn.ModuleList 2. a value of type {{{constants}}} 3. a list or tuple of (2) - """)) + """ + ) + ) class SourceContext(torch._C._jit_tree_views.SourceRangeFactory): @@ -128,7 +159,7 @@ class SourceContext(torch._C._jit_tree_views.SourceRangeFactory): def get_annotations(obj): if sys.version_info < (3, 10): - return getattr(obj, '__annotations__', {}) + return getattr(obj, "__annotations__", {}) # In Python-3.10+ it is recommended to use inspect.get_annotations # See https://docs.python.org/3.10/howto/annotations.html # But also, in 3.10 annotations from base class are not inherited @@ -172,7 +203,9 @@ def infer_concrete_type_builder(nn_module, share_types=True): class_annotations = {} # Get user-annotated ignored attributes. - user_annotated_ignored_attributes = getattr(nn_module, "__jit_ignored_attributes__", list()) + user_annotated_ignored_attributes = getattr( + nn_module, "__jit_ignored_attributes__", list() + ) concrete_type_builder.add_ignored_attributes(user_annotated_ignored_attributes) ignored_properties = jit_ignored_properties(nn_module) @@ -185,8 +218,14 @@ def infer_concrete_type_builder(nn_module, share_types=True): # is also true! inferred = False try: - if name in class_annotations and class_annotations[name] != torch.nn.Module.__annotations__["forward"]: - ann_to_type = torch.jit.annotations.ann_to_type(class_annotations[name], fake_range()) + if ( + name in class_annotations + and class_annotations[name] + != torch.nn.Module.__annotations__["forward"] + ): + ann_to_type = torch.jit.annotations.ann_to_type( + class_annotations[name], fake_range() + ) attr_type = torch._C.InferredType(ann_to_type) elif isinstance(item, torch.jit.Attribute): ann_to_type = torch.jit.annotations.ann_to_type(item.type, fake_range()) @@ -195,9 +234,7 @@ def infer_concrete_type_builder(nn_module, share_types=True): attr_type = torch._C._jit_try_infer_type(item) inferred = True except RuntimeError as re: - raise RuntimeError( - f"Error inferring type for {name}: {item}: {re}" - ) from re + raise RuntimeError(f"Error inferring type for {name}: {item}: {re}") from re return attr_type, inferred @@ -239,7 +276,9 @@ def infer_concrete_type_builder(nn_module, share_types=True): if attr_type.success(): assert attr_type.type().is_interface_type() # if the type can be inferred, it should be a module interface type - sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type(attr_type.type()) + sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type( + attr_type.type() + ) else: # otherwise we get the concrete module type for item and add it to concrete_type sub_concrete_type = get_module_concrete_type(item, share_types) @@ -266,26 +305,38 @@ def infer_concrete_type_builder(nn_module, share_types=True): elif name in nn_module._parameters: hint = "parameter" else: - raise AssertionError("added_names must be submodule, parameter, or buffer") + raise AssertionError( + "added_names must be submodule, parameter, or buffer" + ) - warnings.warn("'{}' was found in ScriptModule constants, " - " but it is a non-constant {}. Consider removing it.".format(name, hint)) + warnings.warn( + "'{}' was found in ScriptModule constants, " + " but it is a non-constant {}. Consider removing it.".format(name, hint) + ) continue if not hasattr(nn_module, name): # TODO: We should really error in this case, but its bc-breaking so # we need to warn for at least one release - warnings.warn("'{}' was found in ScriptModule constants, " - "but was not actually set in __init__. " - "Consider removing it.".format(name)) + warnings.warn( + "'{}' was found in ScriptModule constants, " + "but was not actually set in __init__. " + "Consider removing it.".format(name) + ) continue value = getattr(nn_module, name) - concrete_type_builder.add_constant(name, _get_valid_constant(name, value, type(nn_module).__name__)) + concrete_type_builder.add_constant( + name, _get_valid_constant(name, value, type(nn_module).__name__) + ) added_names.add(name) # populate overloads overloads = getattr(nn_module, "__overloads__", {}) # update with any annotated overloads - overloads.update(get_overload_name_mapping(get_overload_annotations(nn_module, ignored_properties))) + overloads.update( + get_overload_name_mapping( + get_overload_annotations(nn_module, ignored_properties) + ) + ) for name, overloaded_names in overloads.items(): concrete_type_builder.add_overload(name, overloaded_names) @@ -310,16 +361,17 @@ def infer_concrete_type_builder(nn_module, share_types=True): try: scripted_fn = torch.jit.script(value) concrete_type_builder.add_function_attribute( - name, - torch._C._jit_try_infer_type(scripted_fn).type(), - value) + name, torch._C._jit_try_infer_type(scripted_fn).type(), value + ) except Exception as e: # If we fail to script the function, it isn't a hard error. # Instead, we will add it to the list of attributes we failed # to convert, with the compilation error. - hint = ("(This function exists as an attribute on the Python module, " - "but we failed to compile it to a TorchScript function. " - "\nThe error stack is reproduced here:\n{}").format(e) + hint = ( + "(This function exists as an attribute on the Python module, " + "but we failed to compile it to a TorchScript function. " + "\nThe error stack is reproduced here:\n{}" + ).format(e) concrete_type_builder.add_failed_attribute(name, hint) pass @@ -335,9 +387,8 @@ def infer_concrete_type_builder(nn_module, share_types=True): # Handle Script function attributes if isinstance(value, torch.jit.ScriptFunction): concrete_type_builder.add_function_attribute( - name, - torch._C._jit_try_infer_type(value).type(), - value) + name, torch._C._jit_try_infer_type(value).type(), value + ) continue # If we got here, this is a regular "data" attribute, add it to the concrete type @@ -347,11 +398,17 @@ def infer_concrete_type_builder(nn_module, share_types=True): else: # TODO: could add more detail here. For example, what the user should do # when the pytype is `list` or `NoneType` - inferred_msg = "Its type was inferred; try adding a type annotation for the attribute." if inferred else "" + inferred_msg = ( + "Its type was inferred; try adding a type annotation for the attribute." + if inferred + else "" + ) additional_info = f"{attr_type.reason()}. {inferred_msg}" - hint = "(This attribute exists on the Python module, " \ - f"but we failed to convert Python type: '{torch.typename(type(value))}' " \ + hint = ( + "(This attribute exists on the Python module, " + f"but we failed to convert Python type: '{torch.typename(type(value))}' " f"to a TorchScript type. {additional_info})" + ) concrete_type_builder.add_failed_attribute(name, hint) # add hooks to concrete type @@ -362,6 +419,7 @@ def infer_concrete_type_builder(nn_module, share_types=True): return concrete_type_builder + class ConcreteTypeStore: type_store: Dict[Type[Module], List[torch._C.ConcreteModuleType]] methods_compiled: Set[torch._C.ConcreteModuleType] @@ -394,10 +452,13 @@ class ConcreteTypeStore: self.type_store[nn_module_type].append(concrete_type) return concrete_type + concrete_type_store = ConcreteTypeStore() -def create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs): +def create_methods_and_properties_from_stubs( + concrete_type, method_stubs, property_stubs +): method_defs = [m.def_ for m in method_stubs] method_rcbs = [m.resolution_callback for m in method_stubs] method_defaults = [get_default_args(m.original_method) for m in method_stubs] @@ -405,7 +466,10 @@ def create_methods_and_properties_from_stubs(concrete_type, method_stubs, proper property_defs = [p.def_ for p in property_stubs] property_rcbs = [p.resolution_callback for p in property_stubs] - concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults) + concrete_type._create_methods_and_properties( + property_defs, property_rcbs, method_defs, method_rcbs, method_defaults + ) + def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs): hook_defs = [h.def_ for h in hook_stubs] @@ -416,6 +480,7 @@ def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs): concrete_type._create_hooks(hook_defs, hook_rcbs, pre_hook_defs, pre_hook_rcbs) + def get_module_concrete_type(nn_module, share_types=True): """ Gets a concrete type for nn_modules. If share_types is True, the concrete @@ -430,8 +495,9 @@ def get_module_concrete_type(nn_module, share_types=True): A concrete type for nn_module. """ assert isinstance(nn_module, Module) - if isinstance(nn_module, torch.jit.ScriptModule) and \ - hasattr(nn_module, "_concrete_type"): + if isinstance(nn_module, torch.jit.ScriptModule) and hasattr( + nn_module, "_concrete_type" + ): return nn_module._concrete_type if share_types: @@ -446,6 +512,7 @@ def get_module_concrete_type(nn_module, share_types=True): return concrete_type + def create_script_class(obj): """ Create and return a RecursiveScriptClass instance from a Python object. @@ -467,6 +534,7 @@ def create_script_class(obj): # Wrap the torch._C.ScriptObject in a RecursiveScriptClass instance. return wrap_cpp_class(cpp_object) + def create_script_module(nn_module, stubs_fn, share_types=True, is_tracing=False): """ Creates a new ScriptModule from an nn.Module @@ -490,6 +558,7 @@ def create_script_module(nn_module, stubs_fn, share_types=True, is_tracing=False AttributeTypeIsSupportedChecker().check(nn_module) return create_script_module_impl(nn_module, concrete_type, stubs_fn) + def create_script_module_impl(nn_module, concrete_type, stubs_fn): """ Convert an nn.Module to a RecursiveScriptModule. @@ -504,7 +573,9 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn): property_stubs = get_property_stubs(nn_module) hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module) - user_annotated_ignored_attributes = getattr(nn_module, "__jit_ignored_attributes__", list()) + user_annotated_ignored_attributes = getattr( + nn_module, "__jit_ignored_attributes__", list() + ) ignored_properties = jit_ignored_properties(nn_module) def init_fn(script_module): @@ -512,14 +583,20 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn): # 1. Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule. for name in concrete_type.get_attributes().keys(): orig_value = getattr(nn_module, name) - orig_value = orig_value.value if isinstance(orig_value, torch.jit.Attribute) else orig_value + orig_value = ( + orig_value.value + if isinstance(orig_value, torch.jit.Attribute) + else orig_value + ) cpp_module.setattr(name, orig_value) # 2. Copy the submodules from the original `nn_module` to the new ScriptModule, # recursively scripting them. for name, sub_concrete_type in concrete_type.get_modules(): orig_value = getattr(nn_module, name) - assert isinstance(orig_value, Module), f"Expected Module but got {type(orig_value)}" + assert isinstance( + orig_value, Module + ), f"Expected Module but got {type(orig_value)}" module_type = sub_concrete_type.jit_type if isinstance(module_type, torch._C.InterfaceType): # use the interface inference rule to compile the module @@ -528,7 +605,9 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn): scripted = orig_value else: # always reuse the provided stubs_fn to infer the methods to compile - scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn) + scripted = create_script_module_impl( + orig_value, sub_concrete_type, stubs_fn + ) cpp_module.setattr(name, scripted) script_module._modules[name] = scripted @@ -554,7 +633,9 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn): # Compile methods if necessary if concrete_type not in concrete_type_store.methods_compiled: - create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs) + create_methods_and_properties_from_stubs( + concrete_type, method_stubs, property_stubs + ) # Create hooks after methods to ensure no name collisions between hooks and methods. # If done before, hooks can overshadow methods that aren't exported. create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs) @@ -568,20 +649,26 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn): for idx, fn in enumerate(script_module._c._get_forward_hooks()): script_module._forward_hooks[idx] = fn - # Special handling so methods like __len__ work in script methods on classes derived from containers - if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)) and \ - '__len__' not in cpp_module._method_names(): + if ( + isinstance( + nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict) + ) + and "__len__" not in cpp_module._method_names() + ): script_module.define(f"def __len__(self):\n return {len(nn_module)}\n") - if isinstance(nn_module, torch.nn.ModuleDict) and \ - '__contains__' not in cpp_module._method_names(): + if ( + isinstance(nn_module, torch.nn.ModuleDict) + and "__contains__" not in cpp_module._method_names() + ): if len(nn_module.keys()): keys = repr(list(nn_module.keys())) - script_module.define(f"def __contains__(self, key: str):\n return key in {keys}\n") + script_module.define( + f"def __contains__(self, key: str):\n return key in {keys}\n" + ) else: script_module.define("def __contains__(self, key: str):\n return False\n") - # Make the compiled methods available to the Python ScriptModule class. for method_stub in method_stubs: if method_stub.original_method is None: @@ -599,14 +686,15 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn): # Wrap the original to propagate docstrings and such. # TODO: we don't currently do this functions that are recursively # compiled, we should. - wrapped_script_method = functools.wraps(method_stub.original_method)(script_method) + wrapped_script_method = functools.wraps(method_stub.original_method)( + script_method + ) # Add the methods to the script_module directly. This ensures they will # be found first when `name` is looked up (as opposed to the stubs or # nn.Module.forward) script_module.__dict__[name] = wrapped_script_method - # Make module properties available on the Python ScriptModule class. for property_stub in property_stubs: property_name = property_stub.def_.name().name @@ -622,7 +710,10 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn): if name in ignored_properties: continue item = getattr(nn_module, name, None) - if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.COPY_TO_SCRIPT_WRAPPER: + if ( + _jit_internal.get_torchscript_modifier(item) + is _jit_internal.FunctionModifiers.COPY_TO_SCRIPT_WRAPPER + ): add_python_attr_to_scripted_model(script_module, nn_module, name) return script_module @@ -640,10 +731,12 @@ def script_model_defines_attr(script_model, attr): return False return script_attr != default_attr + def add_python_attr_to_scripted_model(script_model, orig, attr): if hasattr(orig, attr) and script_model_defines_attr(script_model, attr): setattr(script_model, attr, getattr(orig, attr)) + def get_overload_annotations(mod, jit_ignored_properties): # original function => [(mangled overload name, overload function)] overloads = {} @@ -657,19 +750,25 @@ def get_overload_annotations(mod, jit_ignored_properties): # builtin functions like repr() in python 2 do not have __module__ defined if hasattr(item, "__module__") and item.__module__ is not None: - method_overloads = _jit_internal._get_overloaded_methods(item, mod.__class__) + method_overloads = _jit_internal._get_overloaded_methods( + item, mod.__class__ + ) if method_overloads is None: continue if item.__func__ in method_overloads: - raise RuntimeError(_jit_internal.get_overload_no_implementation_error_message( - 'method', item.__func__)) + raise RuntimeError( + _jit_internal.get_overload_no_implementation_error_message( + "method", item.__func__ + ) + ) names = [name + "__" + str(i) for i in range(len(method_overloads))] overloads[item] = list(zip(names, method_overloads)) return overloads + def get_overload_name_mapping(overload_info): # Same format as __overloads__ # original function => [overload names] @@ -683,39 +782,61 @@ def get_overload_name_mapping(overload_info): overload_name_mappings[original_name].append(overload_name) return overload_name_mappings + def _check_no_signature(func): - signature = torch.jit.annotations.get_signature(func, None, fake_range(), inspect.ismethod(func)) + signature = torch.jit.annotations.get_signature( + func, None, fake_range(), inspect.ismethod(func) + ) if signature is None: qual_name = _jit_internal._qualified_name(func) - raise RuntimeError(f"Must explicitly add type annotations to overloaded functions: {qual_name}") + raise RuntimeError( + f"Must explicitly add type annotations to overloaded functions: {qual_name}" + ) + def make_stubs_for_overloads(overload_info): overload_stubs = [] for orig_fn, overloads in overload_info.items(): - orig_ast = get_jit_def(orig_fn, orig_fn.__name__, self_name="RecursiveScriptModule") + orig_ast = get_jit_def( + orig_fn, orig_fn.__name__, self_name="RecursiveScriptModule" + ) for overload_name, overload_fn in overloads: _check_no_signature(overload_fn) - over_ast = get_jit_def(overload_fn, overload_fn.__name__, self_name="RecursiveScriptModule") - new_ast = torch._C._replace_overloaded_method_decl(over_ast.decl(), orig_ast, overload_name) + over_ast = get_jit_def( + overload_fn, overload_fn.__name__, self_name="RecursiveScriptModule" + ) + new_ast = torch._C._replace_overloaded_method_decl( + over_ast.decl(), orig_ast, overload_name + ) _rcb = _jit_internal.createResolutionCallbackFromClosure(orig_fn) overload_stubs.append(ScriptMethodStub(_rcb, new_ast, overload_fn)) return overload_stubs + def check_module_initialized(mod): assert isinstance(mod, torch.nn.Module) - if not hasattr(mod, '_parameters'): - raise RuntimeError(f"'{torch.typename(type(mod))}' has not been initialized, did you forget to call 'super()'?") + if not hasattr(mod, "_parameters"): + raise RuntimeError( + f"'{torch.typename(type(mod))}' has not been initialized, did you forget to call 'super()'?" + ) # This is to avoid importing torch.distributed.nn - if not hasattr(mod, 'remote_parameters'): + if not hasattr(mod, "remote_parameters"): for name, param in mod._parameters.items(): if param is not None and torch.nn.parameter.is_lazy(param): - raise RuntimeError("'{}' has uninitialized parameters {}. Did you forget to run a forward pass?" - .format(torch.typename(type(mod)), name)) + raise RuntimeError( + "'{}' has uninitialized parameters {}. Did you forget to run a forward pass?".format( + torch.typename(type(mod)), name + ) + ) for name, buf in mod._buffers.items(): if buf is not None and torch.nn.parameter.is_lazy(buf): - raise RuntimeError("'{}' has uninitialized buffers {}. Did you forget to run a forward pass?" - .format(torch.typename(type(mod)), name)) + raise RuntimeError( + "'{}' has uninitialized buffers {}. Did you forget to run a forward pass?".format( + torch.typename(type(mod)), name + ) + ) + def infer_methods_to_compile(nn_module): """ @@ -723,22 +844,29 @@ def infer_methods_to_compile(nn_module): points for compilation (TODO add a link when the rules are published). """ check_module_initialized(nn_module) - user_annotated_ignored_attributes = getattr(nn_module, "__jit_ignored_attributes__", list()) + user_annotated_ignored_attributes = getattr( + nn_module, "__jit_ignored_attributes__", list() + ) ignored_properties = jit_ignored_properties(nn_module) methods: List[str] = [] - if hasattr(nn_module, 'forward') and not _jit_internal.is_ignored_fn(nn_module.forward): + if hasattr(nn_module, "forward") and not _jit_internal.is_ignored_fn( + nn_module.forward + ): forward_func = getattr(nn_module.forward, "__func__", None) module_forward = getattr(torch.nn.Module, "forward", None) if forward_func != module_forward: - methods = ['forward'] + methods = ["forward"] exported = [] for name in dir(nn_module): if name in ignored_properties: continue item = getattr(nn_module, name, None) - if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.EXPORT: + if ( + _jit_internal.get_torchscript_modifier(item) + is _jit_internal.FunctionModifiers.EXPORT + ): exported.append(name) methods = methods + exported @@ -821,7 +949,9 @@ def get_property_stubs(nn_module): item = getattr(module_ty, name, None) if isinstance(item, property): if not item.fget: - raise RuntimeError(f'Property {name} of {nn_module.__name__} must have a getter') + raise RuntimeError( + f"Property {name} of {nn_module.__name__} must have a getter" + ) rcbs[name] = _jit_internal.createResolutionCallbackFromClosure(item.fget) @@ -855,6 +985,7 @@ def interface_script(mod_interface, nn_module): return create_script_module(nn_module, infer_interface_methods_to_compile) + def try_compile_fn(fn, loc): if _jit_internal.is_ignored_fn(fn): # Don't do anything for @ignore'd functions @@ -866,9 +997,11 @@ def try_compile_fn(fn, loc): return None if not inspect.isfunction(fn) and not inspect.ismethod(fn): - raise RuntimeError("`{}` is not a function. Recursive scripting only supports " - "Python functions or methods currently.\n" - "Consider manually annotating `{}` with @torch.jit.script.".format(fn, fn)) + raise RuntimeError( + "`{}` is not a function. Recursive scripting only supports " + "Python functions or methods currently.\n" + "Consider manually annotating `{}` with @torch.jit.script.".format(fn, fn) + ) # We don't have the actual scope where the function was defined, but we can # extract the necessary info from the closed over variables on the function @@ -876,20 +1009,25 @@ def try_compile_fn(fn, loc): rcb = _jit_internal.createResolutionCallbackFromClosure(fn) return torch.jit.script(fn, _rcb=rcb) + def wrap_cpp_class(cpp_class): """ Wrap this torch._C.Object in a Python RecursiveScriptClass. """ return torch.jit.RecursiveScriptClass(cpp_class) + def wrap_cpp_module(cpp_module): """ Wrap this torch._C.ScriptModule in a Python ScriptModule, recursively for all submodules """ + def init_fn(script_module): for name, cpp_module in torch._C.ModuleDict(script_module._c).items(): setattr(script_module, name, wrap_cpp_module(cpp_module)) - script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type(script_module._c._type()) + script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type( + script_module._c._type() + ) for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()): script_module._forward_pre_hooks[idx] = fn @@ -898,6 +1036,7 @@ def wrap_cpp_module(cpp_module): return torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn) + def compile_unbound_method(concrete_type, fn): if _jit_internal.is_ignored_fn(fn): return None @@ -908,6 +1047,7 @@ def compile_unbound_method(concrete_type, fn): create_methods_and_properties_from_stubs(concrete_type, (stub,), ()) return stub + def lazy_bind(concrete_type, unbound_method): """ Returns a function that lazily binds `unbound_method` to a provided @@ -915,6 +1055,7 @@ def lazy_bind(concrete_type, unbound_method): shenanigans that will poison type sharing are impossible at compile time. """ + def lazy_binding_method(cpp_module, *args): def init_fn(script_module): orig_class = concrete_type.py_class diff --git a/torch/jit/_shape_functions.py b/torch/jit/_shape_functions.py index fb38b535ddb3..51515039866d 100644 --- a/torch/jit/_shape_functions.py +++ b/torch/jit/_shape_functions.py @@ -1,5 +1,6 @@ -from typing import List, Any, Optional, Union, Dict, Callable, Tuple import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + number = Union[int, float] # flake8: noqa @@ -44,12 +45,15 @@ def broadcast(a: List[int], b: List[int]): return expandedSizes + def broadcast_three(a: List[int], b: List[int], c: List[int]): return broadcast(broadcast(a, b), c) + def broadcast_one_three(a: List[int], b: Any, c: List[int]): return broadcast(a, c) + def adaptive_avg_pool2d(self: List[int], out: List[int]): assert len(out) == 2 assert len(self) == 3 or len(self) == 4 @@ -159,7 +163,9 @@ def view_one_unused(self: List[int], sizes: List[int], *, implicit: bool = False return view(self, sizes) -def sum_mean_dim(self: List[int], opt_dims: Optional[List[int]], keep_dim: bool, dt: Any): +def sum_mean_dim( + self: List[int], opt_dims: Optional[List[int]], keep_dim: bool, dt: Any +): out: List[int] = [] if opt_dims is None or len(opt_dims) == 0: dims: List[int] = list(range(len(self))) @@ -178,10 +184,12 @@ def sum_mean_dim(self: List[int], opt_dims: Optional[List[int]], keep_dim: bool, out.append(self[idx]) return out + def max_dim(self: List[int], dim: int, keep_dim: bool): out = sum_mean_dim(self, [dim], keep_dim, None) return out, out + # note: python already rounds down towards negative infinity on integer division, special arithmetic not needed def div_rtn(x: int, y: int): return x // y @@ -354,7 +362,7 @@ def upsample_nearest2d( out.append(input[0]) out.append(input[1]) - if (scale_factors is None and output_size is None): + if scale_factors is None and output_size is None: assert 0, "Either output_size or scale_factors must be presented" if output_size is not None: @@ -424,6 +432,7 @@ def squeeze(li: List[int], dim: int): out.append(li[i]) return out + def squeeze_dims(li: List[int], dims: List[int]): if len(dims) == 0: return li @@ -439,6 +448,7 @@ def squeeze_dims(li: List[int], dims: List[int]): result.append(li[i]) return result + def index_select(self: List[int], dim: int, index: List[int]): dim = maybe_wrap_dim(dim, len(self)) numel = multiply_integers(index) @@ -505,6 +515,7 @@ def check_cat_no_zero_dim(tensors: List[List[int]]): for tensor in tensors: assert len(tensor) > 0 + def legacy_cat_wrap_dim(dim: int, tensor_sizes: List[List[int]]): out_dim: Optional[int] = None for size in tensor_sizes: @@ -760,11 +771,27 @@ def conv2d( assert len(input) == 4 return conv_output_size(input, weight, bias, stride, padding, dilation, groups) -def conv_backwards(grad_output: List[int], input:List[int], weight:List[int], biases:Optional[List[int]]): + +def conv_backwards( + grad_output: List[int], + input: List[int], + weight: List[int], + biases: Optional[List[int]], +): # Bias gradient is always generated regardess of if biases is supplied return _copy(input), _copy(weight), [grad_output[1]] -def conv_transpose2d_input(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: Optional[List[int]] = None, padding: Optional[List[int]] = None, output_padding: Optional[List[int]] = None, groups: int = 1, dilation: Optional[List[int]] = None) -> List[int]: + +def conv_transpose2d_input( + input: List[int], + weight: List[int], + bias: Optional[List[int]] = None, + stride: Optional[List[int]] = None, + padding: Optional[List[int]] = None, + output_padding: Optional[List[int]] = None, + groups: int = 1, + dilation: Optional[List[int]] = None, +) -> List[int]: if stride is None: stride = [1, 1] if padding is None: @@ -784,10 +811,27 @@ def conv_transpose2d_input(input: List[int], weight: List[int], bias: Optional[L for d in range(2, dim): dilation_ = dilation[d - 2] if has_dilation else 1 kernel = dilation_ * (weight[d] - 1) - output_size.append((input[d] - 1) * stride[d - 2] - 2 * padding[d - 2] + kernel + output_padding[d - 2] + 1) + output_size.append( + (input[d] - 1) * stride[d - 2] + - 2 * padding[d - 2] + + kernel + + output_padding[d - 2] + + 1 + ) return output_size -def conv_forwards(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> List[int]: + +def conv_forwards( + input: List[int], + weight: List[int], + bias: Optional[List[int]], + stride: List[int], + padding: List[int], + dilation: List[int], + transposed: bool, + output_padding: List[int], + groups: int, +) -> List[int]: has_dilation = len(dilation) > 0 has_output_padding = len(output_padding) > 0 dim = len(input) @@ -805,14 +849,48 @@ def conv_forwards(input: List[int], weight: List[int], bias: Optional[List[int]] output_padding_ = output_padding[d - 2] if has_output_padding else 0 if transposed: kernel = dilation_ * (weight[d] - 1) - output_size.append((input[d] - 1) * stride[d - 2] - 2 * padding[d - 2] + kernel + output_padding_ + 1) + output_size.append( + (input[d] - 1) * stride[d - 2] + - 2 * padding[d - 2] + + kernel + + output_padding_ + + 1 + ) else: kernel = dilation_ * (weight[d] - 1) + 1 - output_size.append((input[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1) + output_size.append( + (input[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1 + ) return output_size -def _conv_forwards(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]: - return conv_forwards(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) + +def _conv_forwards( + input: List[int], + weight: List[int], + bias: Optional[List[int]], + stride: List[int], + padding: List[int], + dilation: List[int], + transposed: bool, + output_padding: List[int], + groups: int, + benchmark: bool, + deterministic: bool, + cudnn_enabled: bool, + allow_tf32: bool, +) -> List[int]: + return conv_forwards( + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + def batch_norm( input: List[int], @@ -907,12 +985,13 @@ def permute(input: List[int], dims: List[int]): assert seen_dims[i] != seen_dims[j] return newSizes + def movedim(self: List[int], source: List[int], destination: List[int]) -> List[int]: self_dim = len(self) if self_dim <= 1: return self - normalized_src : List[int] = [] - normalized_dst : List[int] = [] + normalized_src: List[int] = [] + normalized_dst: List[int] = [] for i in range(len(source)): normalized_src.append(maybe_wrap_dim(source[i], self_dim)) normalized_dst.append(maybe_wrap_dim(destination[i], self_dim)) @@ -925,8 +1004,8 @@ def movedim(self: List[int], source: List[int], destination: List[int]) -> List[ src_dims[normalized_src[i]] = -1 dst_dims[normalized_dst[i]] = -1 - source_dims : List[int] = [] - destination_dims : List[int] = [] + source_dims: List[int] = [] + destination_dims: List[int] = [] for ele in src_dims: if ele != -1: source_dims.append(ele) @@ -939,6 +1018,7 @@ def movedim(self: List[int], source: List[int], destination: List[int]) -> List[ order[destination_dims[i]] = source_dims[i] return permute(self, order) + def flatten(input: List[int], start_dim: int, end_dim: int): start_dim = maybe_wrap_dim(start_dim, len(input)) end_dim = maybe_wrap_dim(end_dim, len(input)) @@ -964,12 +1044,15 @@ def flatten(input: List[int], start_dim: int, end_dim: int): shape.append(input[i]) return shape + def nonzero_lower_bound(input: List[int]): return [0, len(input)] + def nonzero_upper_bound(input: List[int]): return [numel(input), len(input)] + def _reduce_along_dim(self: List[int], dim: int, keepdim: bool): dim = maybe_wrap_dim(dim, len(self)) out: List[int] = [] @@ -981,11 +1064,15 @@ def _reduce_along_dim(self: List[int], dim: int, keepdim: bool): out.append(self_dim) return out -def argmax(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]: + +def argmax( + self: List[int], dim: Optional[int] = None, keepdim: bool = False +) -> List[int]: if dim is None: return [] return _reduce_along_dim(self, dim, keepdim) + def bmm(self: List[int], mat2: List[int]) -> List[int]: assert len(self) == 3, "bmm only supports 3D tensors" assert len(mat2) == 3, "bmm only supports 3D tensors" @@ -993,19 +1080,26 @@ def bmm(self: List[int], mat2: List[int]) -> List[int]: assert self[2] == mat2[1], "mismatching contracting dimension" return [self[0], self[1], mat2[2]] + def _shape_as_tensor(self: List[int]) -> List[int]: return [len(self)] + def topk(self: List[int], k: int, dim: int = -1) -> Tuple[List[int], List[int]]: if len(self) == 0: result: List[int] = [] else: - assert k <= self[dim], f"k ({k}) is too big for dimension {dim} of size {self[dim]}" + assert ( + k <= self[dim] + ), f"k ({k}) is too big for dimension {dim} of size {self[dim]}" result = _copy(self) result[dim] = k return result, result -def nll_loss_forward(self: List[int], target: List[int], weight: Optional[List[int]], reduction: int) -> Tuple[List[int], List[int]]: + +def nll_loss_forward( + self: List[int], target: List[int], weight: Optional[List[int]], reduction: int +) -> Tuple[List[int], List[int]]: # This is taken shamelessly from the meta function in LossNLL.cpp self_dim = len(self) target_dim = len(target) @@ -1022,7 +1116,10 @@ def nll_loss_forward(self: List[int], target: List[int], weight: Optional[List[i reduction_shape = scalar_shape return reduction_shape, scalar_shape -def native_layer_norm(input: List[int], normalized_shape: List[int]) -> Tuple[List[int], List[int], List[int]]: + +def native_layer_norm( + input: List[int], normalized_shape: List[int] +) -> Tuple[List[int], List[int], List[int]]: reduction_shape: List[int] = [] num_unreduced_dimensions = len(input) - len(normalized_shape) assert num_unreduced_dimensions >= 0 @@ -1032,17 +1129,34 @@ def native_layer_norm(input: List[int], normalized_shape: List[int]) -> Tuple[Li reduction_shape.append(1) return _copy(input), reduction_shape, reduction_shape -def native_batch_norm(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool) -> Tuple[List[int], List[int], List[int]]: + +def native_batch_norm( + input: List[int], + weight: Optional[List[int]], + bias: Optional[List[int]], + running_mean: Optional[List[int]], + running_var: Optional[List[int]], + training: bool, +) -> Tuple[List[int], List[int], List[int]]: if training: _size = [input[1]] else: _size = [0] return _copy(input), _size, _size -def cross_entropy_loss(self: List[int], target: List[int], weight: Optional[List[int]] = None, reduction: int = 1, ignore_index: int = -100, label_smoothing: float = 0.) -> List[int]: + +def cross_entropy_loss( + self: List[int], + target: List[int], + weight: Optional[List[int]] = None, + reduction: int = 1, + ignore_index: int = -100, + label_smoothing: float = 0.0, +) -> List[int]: result_shape = nll_loss_forward(self, target, weight, reduction)[0] return result_shape + """ Currently deferring the enabling of this, as part of the propoasal to suspend adding ops. @@ -1063,10 +1177,11 @@ def index_Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[in """ ScriptFn = torch._C.ScriptFunction -shape_compute_graph_mapping : Dict[str, ScriptFn ] = {} -bounded_compute_graph_mapping : Dict[str, Tuple[ScriptFn, ScriptFn]] = {} +shape_compute_graph_mapping: Dict[str, ScriptFn] = {} +bounded_compute_graph_mapping: Dict[str, Tuple[ScriptFn, ScriptFn]] = {} script_func_map: Dict[Callable, ScriptFn] = {} + def process_func(func: Callable): if func not in script_func_map: scripted_func = torch.jit.script(func) @@ -1086,90 +1201,259 @@ def add_shape_compute_mapping(operator_schema: str, func: Callable): shape_compute_graph_mapping[operator_schema] = process_func(func) -def add_bounded_compute_mapping(operator_schema: str, lower_bound_func: Callable, upper_bound_func: Callable): + +def add_bounded_compute_mapping( + operator_schema: str, lower_bound_func: Callable, upper_bound_func: Callable +): # Adds a shape compute function for both upper and lower bounds fns = (process_func(lower_bound_func), process_func(upper_bound_func)) bounded_compute_graph_mapping[operator_schema] = fns -add_shape_compute_mapping("aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)", unary) -add_shape_compute_mapping("aten::rsub.Tensor(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", unary) -add_shape_compute_mapping("aten::dropout(Tensor input, float p, bool train) -> Tensor", unary) -add_shape_compute_mapping("aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor", adaptive_avg_pool2d) -add_shape_compute_mapping("prim::NumToTensor.Scalar(Scalar a) -> Tensor", zero_dim_tensor) + +add_shape_compute_mapping( + "aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)", + unary, +) +add_shape_compute_mapping( + "aten::rsub.Tensor(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", unary +) +add_shape_compute_mapping( + "aten::dropout(Tensor input, float p, bool train) -> Tensor", unary +) +add_shape_compute_mapping( + "aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor", + adaptive_avg_pool2d, +) +add_shape_compute_mapping( + "prim::NumToTensor.Scalar(Scalar a) -> Tensor", zero_dim_tensor +) add_shape_compute_mapping("prim::NumToTensor.bool(bool a) -> Tensor", zero_dim_tensor) -add_shape_compute_mapping("aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", unary) -add_shape_compute_mapping("aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))", unary) -add_shape_compute_mapping("aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", arange_end) -add_shape_compute_mapping("aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", arange_start) -add_shape_compute_mapping("aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", arange_start_step) +add_shape_compute_mapping( + "aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", + unary, +) +add_shape_compute_mapping( + "aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))", + unary, +) +add_shape_compute_mapping( + "aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", + arange_end, +) +add_shape_compute_mapping( + "aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", + arange_start, +) +add_shape_compute_mapping( + "aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", + arange_start_step, +) add_shape_compute_mapping("aten::squeeze(Tensor(a) self) -> Tensor(a)", squeeze_nodim) -add_shape_compute_mapping("aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", squeeze) -add_shape_compute_mapping("aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)", squeeze_dims) -add_shape_compute_mapping("aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", unsqueeze) -add_shape_compute_mapping("aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)", slice) -add_shape_compute_mapping("aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)", select) -add_shape_compute_mapping("aten::index_select(Tensor self, int dim, Tensor index) -> Tensor", index_select) -add_shape_compute_mapping("aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, " - "float eps=1e-05, bool cudnn_enable=True) -> Tensor", unary) -add_shape_compute_mapping("aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", unary) -add_shape_compute_mapping("aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor", unary) -add_shape_compute_mapping("aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)", unary) -add_shape_compute_mapping("aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor", embedding) +add_shape_compute_mapping( + "aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", squeeze +) +add_shape_compute_mapping( + "aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)", squeeze_dims +) +add_shape_compute_mapping( + "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", unsqueeze +) +add_shape_compute_mapping( + "aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)", + slice, +) +add_shape_compute_mapping( + "aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)", select +) +add_shape_compute_mapping( + "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor", index_select +) +add_shape_compute_mapping( + "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, " + "float eps=1e-05, bool cudnn_enable=True) -> Tensor", + unary, +) +add_shape_compute_mapping( + "aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", unary +) +add_shape_compute_mapping( + "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor", + unary, +) +add_shape_compute_mapping( + "aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)", + unary, +) +add_shape_compute_mapping( + "aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor", + embedding, +) add_shape_compute_mapping("aten::mm(Tensor self, Tensor mat2) -> Tensor", mm) add_shape_compute_mapping("aten::dot(Tensor self, Tensor tensor) -> Tensor", dot) add_shape_compute_mapping("aten::mv(Tensor self, Tensor vec) -> Tensor", mv) add_shape_compute_mapping("aten::matmul(Tensor self, Tensor other) -> Tensor", matmul) -add_shape_compute_mapping("aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor", linear) -add_shape_compute_mapping("aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", max_pool2d) -add_shape_compute_mapping("aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", max_pool2d_with_indices) +add_shape_compute_mapping( + "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor", linear +) +add_shape_compute_mapping( + "aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", + max_pool2d, +) +add_shape_compute_mapping( + "aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", + max_pool2d_with_indices, +) add_shape_compute_mapping("aten::t(Tensor(a) self) -> Tensor(a)", t) -add_shape_compute_mapping("aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", transpose) -add_shape_compute_mapping("aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor", conv1d) -add_shape_compute_mapping("aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor", conv2d) -add_shape_compute_mapping("aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", batch_norm) -add_shape_compute_mapping("aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor", conv3d) -add_shape_compute_mapping("aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", conv_backwards) -add_shape_compute_mapping("aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor", conv_forwards) -add_shape_compute_mapping("aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor", _conv_forwards) -add_shape_compute_mapping("aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor", conv_transpose2d_input) -add_shape_compute_mapping("aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", flatten) +add_shape_compute_mapping( + "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", transpose +) +add_shape_compute_mapping( + "aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor", + conv1d, +) +add_shape_compute_mapping( + "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor", + conv2d, +) +add_shape_compute_mapping( + "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", + batch_norm, +) +add_shape_compute_mapping( + "aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor", + conv3d, +) +add_shape_compute_mapping( + "aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", + conv_backwards, +) +add_shape_compute_mapping( + "aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor", + conv_forwards, +) +add_shape_compute_mapping( + "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor", + _conv_forwards, +) +add_shape_compute_mapping( + "aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor", + conv_transpose2d_input, +) +add_shape_compute_mapping( + "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", + flatten, +) add_shape_compute_mapping("aten::cat(Tensor[] tensors, int dim=0) -> Tensor", cat) add_shape_compute_mapping("aten::stack(Tensor[] tensors, int dim=0) -> Tensor", stack) -add_shape_compute_mapping("aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", permute) -add_shape_compute_mapping("aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)", movedim) +add_shape_compute_mapping( + "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", permute +) +add_shape_compute_mapping( + "aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)", + movedim, +) add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)", view) -add_shape_compute_mapping("aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", expand) -add_shape_compute_mapping("aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", expand_one_unused) -add_shape_compute_mapping("aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", sum_mean_dim) -add_shape_compute_mapping("aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", sum_mean_dim) -add_shape_compute_mapping("aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", max_dim) -add_shape_compute_mapping("aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor) -add_shape_compute_mapping("aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor) -add_shape_compute_mapping("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", addmm) -add_shape_compute_mapping("aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)", upsample_nearest2d) -add_shape_compute_mapping("aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor", unary) -add_shape_compute_mapping("aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor", unary) +add_shape_compute_mapping( + "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", expand +) +add_shape_compute_mapping( + "aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", + expand_one_unused, +) +add_shape_compute_mapping( + "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", + sum_mean_dim, +) +add_shape_compute_mapping( + "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", + sum_mean_dim, +) +add_shape_compute_mapping( + "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", + max_dim, +) +add_shape_compute_mapping( + "aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor +) +add_shape_compute_mapping( + "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor +) +add_shape_compute_mapping( + "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", + addmm, +) +add_shape_compute_mapping( + "aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)", + upsample_nearest2d, +) +add_shape_compute_mapping( + "aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor", + unary, +) +add_shape_compute_mapping( + "aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor", + unary, +) add_shape_compute_mapping("aten::dequantize(Tensor self) -> Tensor", unary) -add_shape_compute_mapping("quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc", broadcast) -add_shape_compute_mapping("aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", argmax) +add_shape_compute_mapping( + "quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc", + broadcast, +) +add_shape_compute_mapping( + "aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", argmax +) add_shape_compute_mapping("aten::bmm(Tensor self, Tensor mat2) -> Tensor", bmm) -add_shape_compute_mapping("aten::_shape_as_tensor(Tensor self) -> Tensor", _shape_as_tensor) -add_shape_compute_mapping("aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)", topk) -add_shape_compute_mapping("aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)", nll_loss_forward) -add_shape_compute_mapping("aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)", native_layer_norm) -add_shape_compute_mapping("aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", native_batch_norm) -add_shape_compute_mapping("aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", native_batch_norm) -add_shape_compute_mapping("aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", native_batch_norm) -add_shape_compute_mapping("aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor", cross_entropy_loss) +add_shape_compute_mapping( + "aten::_shape_as_tensor(Tensor self) -> Tensor", _shape_as_tensor +) +add_shape_compute_mapping( + "aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)", + topk, +) +add_shape_compute_mapping( + "aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)", + nll_loss_forward, +) +add_shape_compute_mapping( + "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)", + native_layer_norm, +) +add_shape_compute_mapping( + "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", + native_batch_norm, +) +add_shape_compute_mapping( + "aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", + native_batch_norm, +) +add_shape_compute_mapping( + "aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", + native_batch_norm, +) +add_shape_compute_mapping( + "aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor", + cross_entropy_loss, +) # add_shape_compute_mapping("aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", index_Tensor) # TODO: migrate over all of symbolic_shape_registry_util.cpp # These are duplicated here so that the functions will be serialiazed -add_shape_compute_mapping("aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", broadcast_three) -add_shape_compute_mapping("aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", broadcast_one_three) -add_shape_compute_mapping("aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", broadcast_inplace) +add_shape_compute_mapping( + "aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", + broadcast_three, +) +add_shape_compute_mapping( + "aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", + broadcast_one_three, +) +add_shape_compute_mapping( + "aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", + broadcast_inplace, +) # quantized_conv_prepack TODO # Shape Compute Fn with upper and lower bounds -add_bounded_compute_mapping("aten::nonzero(Tensor self) -> (Tensor)", nonzero_lower_bound, nonzero_upper_bound) +add_bounded_compute_mapping( + "aten::nonzero(Tensor self) -> (Tensor)", nonzero_lower_bound, nonzero_upper_bound +) diff --git a/torch/jit/_state.py b/torch/jit/_state.py index 3980a1e74405..5bde05421fdd 100644 --- a/torch/jit/_state.py +++ b/torch/jit/_state.py @@ -5,10 +5,12 @@ This module stores various pieces of Python-global state relating to the JIT. This is not intended to be imported directly; please the exposed functionalities in `torch.jit`. """ -import torch import os import weakref +import torch + + class EnabledProxy: """Stores whether the JIT is enabled or not. @@ -94,6 +96,7 @@ def _clear_class_state(): _jit_caching_layer: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() _jit_function_overload_caching: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + def _try_get_jit_cached_overloads(key): qual_names = _jit_function_overload_caching.get(key, None) if qual_names: @@ -101,9 +104,11 @@ def _try_get_jit_cached_overloads(key): else: return None + def _set_jit_overload_cache(key, compiled_fns): _jit_function_overload_caching[key] = [fn.qualified_name for fn in compiled_fns] + def _try_get_jit_cached_function(key): if getattr(key, "__disable_jit_function_caching__", False) is True: return None @@ -113,6 +118,7 @@ def _try_get_jit_cached_function(key): else: return None + def _set_jit_function_cache(key, value): # only free functions currently supported assert isinstance(value, torch.jit.ScriptFunction) diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index f5e3635d8932..11e8de464aeb 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -7,21 +7,26 @@ This module contains functionality to support the JIT's tracing frontend, notabl This is not intended to be imported directly; please use the exposed functionalities in `torch.jit`. """ -import torch +import contextlib import copy -import os -import contextlib import functools -import warnings import inspect +import os import re +import warnings from typing import Any, Callable, Dict, List, Optional, Set -from torch.jit._state import _python_cu, _enabled -from torch.jit._script import ScriptModule, _CachedForward, script -from torch._jit_internal import _qualified_name, is_scripting, get_callable_argument_names +import torch +from torch._jit_internal import ( + _qualified_name, + get_callable_argument_names, + is_scripting, +) from torch.autograd import function +from torch.jit._script import _CachedForward, script, ScriptModule + +from torch.jit._state import _enabled, _python_cu from torch.nn import Module from torch.testing._comparison import default_tolerances @@ -105,7 +110,7 @@ class ONNXTracedModule(torch.nn.Module): in_args: List[torch.Tensor] = [] for i in range(len(in_vars)): if not isinstance(args[i], torch.Tensor): - raise RuntimeError('Expected Tensor argument') + raise RuntimeError("Expected Tensor argument") in_args.append(args[i]) trace_inputs = _unflatten(in_args, in_desc) @@ -321,7 +326,6 @@ def _check_trace( ): # Note: tracing is independent of optimizations, which consume the trace for inputs in check_inputs: - if isinstance(inputs, torch.Tensor): inputs = (inputs,) @@ -338,11 +342,15 @@ def _check_trace( _module_class=_module_class, _compilation_unit=torch._C.CompilationUnit(), example_inputs_is_kwarg=example_inputs_is_kwarg, - _store_inputs=False + _store_inputs=False, ) check_mod_func = check_mod._c._get_method(traced_func.name) inputs = inputs[traced_func.name] - if isinstance(inputs, (torch.Tensor)) or isinstance(inputs, dict) and not example_inputs_is_kwarg: + if ( + isinstance(inputs, (torch.Tensor)) + or isinstance(inputs, dict) + and not example_inputs_is_kwarg + ): inputs = (inputs,) else: if example_inputs_is_kwarg: @@ -353,7 +361,7 @@ def _check_trace( _force_outplace=force_outplace, _module_class=_module_class, example_kwarg_inputs=_clone_inputs(inputs), - _store_inputs=False + _store_inputs=False, ) else: check_mod = torch.jit.trace( @@ -363,7 +371,7 @@ def _check_trace( strict=strict, _force_outplace=force_outplace, _module_class=_module_class, - _store_inputs=False + _store_inputs=False, ) check_mod_func = check_mod @@ -433,7 +441,9 @@ def _check_trace( check_tensor_val = n_check.t("value") try: - torch.testing.assert_close(mod_tensor_val, check_tensor_val, equal_nan=True) + torch.testing.assert_close( + mod_tensor_val, check_tensor_val, equal_nan=True + ) except (RuntimeError, AssertionError) as e: if tensor_compare_errors is None: tensor_compare_errors = "" @@ -591,13 +601,9 @@ def make_module(mod, _module_class, _compilation_unit): if isinstance(mod, ScriptModule): return mod elif torch._jit_internal.module_has_exports(mod): - infer_methods_stubs_fn = torch.jit._recursive.make_stubs_from_exported_methods return torch.jit._recursive.create_script_module( - mod, - infer_methods_stubs_fn, - share_types=False, - is_tracing=True + mod, infer_methods_stubs_fn, share_types=False, is_tracing=True ) else: if _module_class is None: @@ -624,7 +630,7 @@ def trace( _module_class=None, _compilation_unit=_python_cu, example_kwarg_inputs=None, - _store_inputs=True + _store_inputs=True, ): """ Trace a function and return an executable or :class:`ScriptFunction` @@ -785,7 +791,6 @@ def trace( ) return func - if isinstance(func, torch.nn.Module): if example_inputs is None: if isinstance(example_kwarg_inputs, dict): @@ -803,7 +808,7 @@ def trace( _force_outplace, _module_class, example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), - _store_inputs=_store_inputs + _store_inputs=_store_inputs, ) if ( hasattr(func, "__self__") @@ -826,11 +831,14 @@ def trace( _force_outplace, _module_class, example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), - _store_inputs=_store_inputs + _store_inputs=_store_inputs, ) # Special case for common case of passing a single Tensor - if isinstance(example_inputs, (torch.Tensor, dict)) and example_kwarg_inputs is None: + if ( + isinstance(example_inputs, (torch.Tensor, dict)) + and example_kwarg_inputs is None + ): example_inputs = (example_inputs,) # done primarily so that weird iterables fail here and not pybind11 code elif example_kwarg_inputs is None and not isinstance(example_inputs, tuple): @@ -854,7 +862,7 @@ def trace( var_lookup_fn, strict, _force_outplace, - get_callable_argument_names(func) + get_callable_argument_names(func), ) else: traced = torch._C._create_function_from_trace( @@ -864,7 +872,7 @@ def trace( var_lookup_fn, strict, _force_outplace, - get_callable_argument_names(func) + get_callable_argument_names(func), ) # Check the trace against new traces created from user-specified inputs @@ -1039,9 +1047,11 @@ def trace_module( # Raise exception when the user provided key names are not aligned with forward() method's arguments' name/ for key in example_inputs: if key not in argument_names: - valid_arguments = "[" + ','.join(argument_names) + "]" - raise NameError(f"""'{key}' is not in forward() method's arguments, - valid arguments name are {valid_arguments}""") + valid_arguments = "[" + ",".join(argument_names) + "]" + raise NameError( + f"""'{key}' is not in forward() method's arguments, + valid arguments name are {valid_arguments}""" + ) module._c._create_method_from_trace_with_dict( method_name, func, @@ -1050,7 +1060,7 @@ def trace_module( strict, _force_outplace, argument_names, - _store_inputs + _store_inputs, ) else: example_inputs = make_tuple(example_inputs) @@ -1062,7 +1072,7 @@ def trace_module( strict, _force_outplace, argument_names, - _store_inputs + _store_inputs, ) check_trace_method = module._c._get_method(method_name) @@ -1140,6 +1150,7 @@ class TracedModule(ScriptModule): "TracedModules don't support parameter sharing between modules" ) id_set.add(param) + tmp_module.training = orig.training for name, param in orig._parameters.items(): @@ -1229,8 +1240,15 @@ def _script_if_tracing(fn): return wrapper -def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False, - return_inputs=False, _return_inputs_states=False): +def _get_trace_graph( + f, + args=(), + kwargs=None, + strict=True, + _force_outplace=False, + return_inputs=False, + _return_inputs_states=False, +): """ .. warning:: This function is internal-only and should only be used by the ONNX @@ -1266,5 +1284,7 @@ def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False kwargs = {} if not isinstance(args, tuple): args = (args,) - outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs) + outs = ONNXTracedModule( + f, strict, _force_outplace, return_inputs, _return_inputs_states + )(*args, **kwargs) return outs diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py index e0734ba85ba9..ddb23be4e00d 100644 --- a/torch/jit/annotations.py +++ b/torch/jit/annotations.py @@ -1,31 +1,70 @@ import ast +import builtins import dis import enum import inspect import re -import builtins -import torch import warnings -from .._jit_internal import List, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \ - is_optional, _qualified_name, Any, Future, is_future, _Await, is_await, is_ignored_fn, Union, is_union -from .._jit_internal import BroadcastingList1, BroadcastingList2, BroadcastingList3 # type: ignore[attr-defined] -from ._state import _get_script_class - -from torch._C import TensorType, TupleType, FloatType, IntType, ComplexType, \ - ListType, StringType, DictType, BoolType, OptionalType, InterfaceType, AnyType, \ - NoneType, DeviceObjType, StreamObjType, FutureType, AwaitType, EnumType, UnionType, NumberType - from textwrap import dedent -from torch._sources import get_source_lines_and_file from typing import Type +import torch + +from torch._C import ( + AnyType, + AwaitType, + BoolType, + ComplexType, + DeviceObjType, + DictType, + EnumType, + FloatType, + FutureType, + InterfaceType, + IntType, + ListType, + NoneType, + NumberType, + OptionalType, + StreamObjType, + StringType, + TensorType, + TupleType, + UnionType, +) +from torch._sources import get_source_lines_and_file +from .._jit_internal import ( # type: ignore[attr-defined] + _Await, + _qualified_name, + Any, + BroadcastingList1, + BroadcastingList2, + BroadcastingList3, + Dict, + Future, + is_await, + is_dict, + is_future, + is_ignored_fn, + is_list, + is_optional, + is_tuple, + is_union, + List, + Optional, + Tuple, + Union, +) +from ._state import _get_script_class + if torch.distributed.rpc.is_available(): - from .._jit_internal import RRef, is_rref from torch._C import RRefType + from .._jit_internal import is_rref, RRef from torch._ops import OpOverloadPacket + class Module: def __init__(self, name, members): self.name = name @@ -35,27 +74,29 @@ class Module: try: return self.members[name] except KeyError: - raise RuntimeError(f"Module {self.name} has no member called {name}") from None + raise RuntimeError( + f"Module {self.name} has no member called {name}" + ) from None class EvalEnv: env = { - 'torch': Module('torch', {'Tensor': torch.Tensor}), - 'Tensor': torch.Tensor, - 'typing': Module('typing', {'Tuple': Tuple}), - 'Tuple': Tuple, - 'List': List, - 'Dict': Dict, - 'Optional': Optional, - 'Union': Union, - 'Future': Future, - 'Await': _Await + "torch": Module("torch", {"Tensor": torch.Tensor}), + "Tensor": torch.Tensor, + "typing": Module("typing", {"Tuple": Tuple}), + "Tuple": Tuple, + "List": List, + "Dict": Dict, + "Optional": Optional, + "Union": Union, + "Future": Future, + "Await": _Await, } def __init__(self, rcb): self.rcb = rcb if torch.distributed.rpc.is_available(): - self.env['RRef'] = RRef + self.env["RRef"] = RRef def __getitem__(self, name): if name in self.env: @@ -64,6 +105,7 @@ class EvalEnv: return self.rcb(name) return getattr(builtins, name, None) + def get_signature(fn, rcb, loc, is_method): if isinstance(fn, OpOverloadPacket): signature = try_real_annotations(fn.op, loc) @@ -81,7 +123,7 @@ def get_signature(fn, rcb, loc, is_method): if signature is None: type_line, source = None, None try: - source = dedent(''.join(get_source_lines_and_file(fn)[0])) + source = dedent("".join(get_source_lines_and_file(fn)[0])) type_line = get_type_line(source) except TypeError: pass @@ -100,7 +142,7 @@ def is_function_or_method(the_callable): def is_vararg(the_callable): - if not is_function_or_method(the_callable) and hasattr(the_callable, '__call__'): # noqa: B004 + if not is_function_or_method(the_callable) and callable(the_callable): # noqa: B004 # If `the_callable` is a class, de-sugar the call so we can still get # the signature the_callable = the_callable.__call__ @@ -115,7 +157,11 @@ def get_param_names(fn, n_args): if isinstance(fn, OpOverloadPacket): fn = fn.op - if not is_function_or_method(fn) and hasattr(fn, '__call__') and is_function_or_method(fn.__call__): # noqa: B004 + if ( + not is_function_or_method(fn) + and callable(fn) + and is_function_or_method(fn.__call__) + ): # noqa: B004 # De-sugar calls to classes fn = fn.__call__ @@ -132,7 +178,7 @@ def get_param_names(fn, n_args): def check_fn(fn, loc): # Make sure the function definition is not a class instantiation try: - source = dedent(''.join(get_source_lines_and_file(fn)[0])) + source = dedent("".join(get_source_lines_and_file(fn)[0])) except (OSError, TypeError): return if source is None: @@ -141,9 +187,13 @@ def check_fn(fn, loc): py_ast = ast.parse(source) if len(py_ast.body) == 1 and isinstance(py_ast.body[0], ast.ClassDef): raise torch.jit.frontend.FrontendError( - loc, f"Cannot instantiate class '{py_ast.body[0].name}' in a script function") + loc, + f"Cannot instantiate class '{py_ast.body[0].name}' in a script function", + ) if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef): - raise torch.jit.frontend.FrontendError(loc, "Expected a single top-level function") + raise torch.jit.frontend.FrontendError( + loc, "Expected a single top-level function" + ) def _eval_no_call(stmt, glob, loc): @@ -151,7 +201,9 @@ def _eval_no_call(stmt, glob, loc): bytecode = compile(stmt, "", mode="eval") for insn in dis.get_instructions(bytecode): if "CALL" in insn.opname: - raise RuntimeError(f"Type annotation should not contain calls, but '{stmt}' does") + raise RuntimeError( + f"Type annotation should not contain calls, but '{stmt}' does" + ) return eval(bytecode, glob, loc) # type: ignore[arg-type] # noqa: P204 @@ -167,7 +219,9 @@ def parse_type_line(type_line, rcb, loc): try: arg_ann = _eval_no_call(arg_ann_str, {}, EvalEnv(rcb)) except (NameError, SyntaxError) as e: - raise RuntimeError("Failed to parse the argument list of a type annotation") from e + raise RuntimeError( + "Failed to parse the argument list of a type annotation" + ) from e if not isinstance(arg_ann, tuple): arg_ann = (arg_ann,) @@ -175,7 +229,9 @@ def parse_type_line(type_line, rcb, loc): try: ret_ann = _eval_no_call(ret_ann_str, {}, EvalEnv(rcb)) except (NameError, SyntaxError) as e: - raise RuntimeError("Failed to parse the return type of a type annotation") from e + raise RuntimeError( + "Failed to parse the return type of a type annotation" + ) from e arg_types = [ann_to_type(ann, loc) for ann in arg_ann] return arg_types, ann_to_type(ret_ann, loc) @@ -183,9 +239,9 @@ def parse_type_line(type_line, rcb, loc): def get_type_line(source): """Tries to find the line containing a comment with the type annotation.""" - type_comment = '# type:' + type_comment = "# type:" - lines = source.split('\n') + lines = source.split("\n") lines = [(line_num, line) for line_num, line in enumerate(lines)] type_lines = list(filter(lambda line: type_comment in line[1], lines)) # `type: ignore` comments may be needed in JIT'ed functions for mypy, due @@ -199,18 +255,22 @@ def get_type_line(source): # adding an extra backslash before the space, to avoid triggering # one of the checks in .github/workflows/lint.yml type_pattern = re.compile("# type:\\ ignore(\\[[a-zA-Z-]+\\])?$") - type_lines = list(filter(lambda line: not type_pattern.search(line[1]), - type_lines)) + type_lines = list(filter(lambda line: not type_pattern.search(line[1]), type_lines)) if len(type_lines) == 0: # Catch common typo patterns like extra spaces, typo in 'ignore', etc. wrong_type_pattern = re.compile("#[\t ]*type[\t ]*(?!: ignore(\\[.*\\])?$):") - wrong_type_lines = list(filter(lambda line: wrong_type_pattern.search(line[1]), lines)) + wrong_type_lines = list( + filter(lambda line: wrong_type_pattern.search(line[1]), lines) + ) if len(wrong_type_lines) > 0: - raise RuntimeError("The annotation prefix in line " + str(wrong_type_lines[0][0]) - + " is probably invalid.\nIt must be '# type:'" - + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)" # noqa: B950 - + "\nfor examples") + raise RuntimeError( + "The annotation prefix in line " + + str(wrong_type_lines[0][0]) + + " is probably invalid.\nIt must be '# type:'" + + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)" # noqa: B950 + + "\nfor examples" + ) return None elif len(type_lines) == 1: # Only 1 type line, quit now @@ -221,7 +281,7 @@ def get_type_line(source): return_line = None parameter_type_lines = [] for line_num, line in type_lines: - if '# type: (...) -> ' in line: + if "# type: (...) -> " in line: return_line = (line_num, line) break elif type_comment in line: @@ -229,12 +289,13 @@ def get_type_line(source): if return_line is None: raise RuntimeError( "Return type line '# type: (...) -> ...' not found on multiline " - "type annotation\nfor type lines:\n" + - '\n'.join([line[1] for line in type_lines]) + - "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)") + "type annotation\nfor type lines:\n" + + "\n".join([line[1] for line in type_lines]) + + "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)" + ) def get_parameter_type(line): - item_type = line[line.find(type_comment) + len(type_comment):] + item_type = line[line.find(type_comment) + len(type_comment) :] return item_type.strip() types = map(get_parameter_type, parameter_type_lines) @@ -253,12 +314,14 @@ def split_type_line(type_line): ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]") """ - start_offset = len('# type:') + start_offset = len("# type:") try: - arrow_pos = type_line.index('->') + arrow_pos = type_line.index("->") except ValueError: - raise RuntimeError("Syntax error in type annotation (cound't find `->`)") from None - return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2:].strip() + raise RuntimeError( + "Syntax error in type annotation (cound't find `->`)" + ) from None + return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2 :].strip() def try_real_annotations(fn, loc): @@ -271,12 +334,13 @@ def try_real_annotations(fn, loc): except ValueError: return None - all_annots = [sig.return_annotation] + [p.annotation for p in sig.parameters.values()] + all_annots = [sig.return_annotation] + [ + p.annotation for p in sig.parameters.values() + ] if all(ann is sig.empty for ann in all_annots): return None - arg_types = [ann_to_type(p.annotation, loc) - for p in sig.parameters.values()] + arg_types = [ann_to_type(p.annotation, loc) for p in sig.parameters.values()] return_type = ann_to_type(sig.return_annotation, loc) return arg_types, return_type @@ -300,16 +364,30 @@ def get_enum_value_type(e: Type[enum.Enum], loc): return AnyType.get() return res + def is_tensor(ann): if issubclass(ann, torch.Tensor): return True - if issubclass(ann, (torch.LongTensor, torch.DoubleTensor, torch.FloatTensor, - torch.IntTensor, torch.ShortTensor, torch.HalfTensor, - torch.CharTensor, torch.ByteTensor, torch.BoolTensor)): - warnings.warn("TorchScript will treat type annotations of Tensor " - "dtype-specific subtypes as if they are normal Tensors. " - "dtype constraints are not enforced in compilation either.") + if issubclass( + ann, + ( + torch.LongTensor, + torch.DoubleTensor, + torch.FloatTensor, + torch.IntTensor, + torch.ShortTensor, + torch.HalfTensor, + torch.CharTensor, + torch.ByteTensor, + torch.BoolTensor, + ), + ): + warnings.warn( + "TorchScript will treat type annotations of Tensor " + "dtype-specific subtypes as if they are normal Tensors. " + "dtype constraints are not enforced in compilation either." + ) return True return False @@ -340,9 +418,13 @@ def try_ann_to_type(ann, loc, rcb=None): value = try_ann_to_type(ann.__args__[1], loc) # Raise error if key or value is None if key is None: - raise ValueError(f"Unknown type annotation: '{ann.__args__[0]}' at {loc.highlight()}") + raise ValueError( + f"Unknown type annotation: '{ann.__args__[0]}' at {loc.highlight()}" + ) if value is None: - raise ValueError(f"Unknown type annotation: '{ann.__args__[1]}' at {loc.highlight()}") + raise ValueError( + f"Unknown type annotation: '{ann.__args__[1]}' at {loc.highlight()}" + ) return DictType(key, value) if is_optional(ann): if issubclass(ann.__args__[1], type(None)): @@ -368,13 +450,17 @@ def try_ann_to_type(ann, loc, rcb=None): msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}" assert maybe_type, msg.format(repr(ann), repr(maybe_type), repr(loc)) inner.append(maybe_type) - return UnionType(inner) # type: ignore[arg-type] + return UnionType(inner) # type: ignore[arg-type] if torch.distributed.rpc.is_available() and is_rref(ann): return RRefType(try_ann_to_type(ann.__args__[0], loc)) if is_future(ann): return FutureType(try_ann_to_type(ann.__args__[0], loc)) if is_await(ann): - elementType = try_ann_to_type(ann.__args__[0], loc) if hasattr(ann, "__args__") else AnyType.get() + elementType = ( + try_ann_to_type(ann.__args__[0], loc) + if hasattr(ann, "__args__") + else AnyType.get() + ) return AwaitType(elementType) if ann is float: return FloatType.get() @@ -426,37 +512,37 @@ def ann_to_type(ann, loc, rcb=None): __all__ = [ - 'Any', - 'List', - 'BroadcastingList1', - 'BroadcastingList2', - 'BroadcastingList3', - 'Tuple', - 'is_tuple', - 'is_list', - 'Dict', - 'is_dict', - 'is_optional', - 'is_union', - 'TensorType', - 'TupleType', - 'FloatType', - 'ComplexType', - 'IntType', - 'ListType', - 'StringType', - 'DictType', - 'AnyType', - 'Module', + "Any", + "List", + "BroadcastingList1", + "BroadcastingList2", + "BroadcastingList3", + "Tuple", + "is_tuple", + "is_list", + "Dict", + "is_dict", + "is_optional", + "is_union", + "TensorType", + "TupleType", + "FloatType", + "ComplexType", + "IntType", + "ListType", + "StringType", + "DictType", + "AnyType", + "Module", # TODO: Consider not exporting these during wildcard import (reserve # that for the types; for idiomatic typing code.) - 'get_signature', - 'check_fn', - 'get_param_names', - 'parse_type_line', - 'get_type_line', - 'split_type_line', - 'try_real_annotations', - 'try_ann_to_type', - 'ann_to_type', + "get_signature", + "check_fn", + "get_param_names", + "parse_type_line", + "get_type_line", + "split_type_line", + "try_real_annotations", + "try_ann_to_type", + "ann_to_type", ] diff --git a/torch/jit/generate_bytecode.py b/torch/jit/generate_bytecode.py index b838f3ccd5a3..8e56c7665d1c 100644 --- a/torch/jit/generate_bytecode.py +++ b/torch/jit/generate_bytecode.py @@ -1,6 +1,8 @@ -from torch._C import _compile_graph_to_code_table, _generate_upgraders_graph from typing import List +from torch._C import _compile_graph_to_code_table, _generate_upgraders_graph + + def format_bytecode(table): # given a nested tuple, convert it to nested list def listify(content): @@ -16,6 +18,7 @@ def format_bytecode(table): formatted_table[identifier] = content return formatted_table + def generate_upgraders_bytecode() -> List: yaml_content = [] upgraders_graph_map = _generate_upgraders_graph() @@ -25,5 +28,6 @@ def generate_upgraders_bytecode() -> List: yaml_content.append(entry) return yaml_content + if __name__ == "__main__": raise RuntimeError("This file is not meant to be run directly") diff --git a/torch/jit/mobile/__init__.py b/torch/jit/mobile/__init__.py index f58dfd04d59f..b6824183aa8a 100644 --- a/torch/jit/mobile/__init__.py +++ b/torch/jit/mobile/__init__.py @@ -1,9 +1,11 @@ +import os + +import pathlib + import torch from torch.jit._serialization import validate_map_location -import pathlib -import os def _load_for_lite_interpreter(f, map_location=None): r""" @@ -47,10 +49,13 @@ def _load_for_lite_interpreter(f, map_location=None): if isinstance(f, (str, pathlib.Path)): cpp_module = torch._C._load_for_lite_interpreter(f, map_location) else: - cpp_module = torch._C._load_for_lite_interpreter_from_buffer(f.read(), map_location) + cpp_module = torch._C._load_for_lite_interpreter_from_buffer( + f.read(), map_location + ) return LiteScriptModule(cpp_module) + class LiteScriptModule: def __init__(self, cpp_module): self._c = cpp_module @@ -68,13 +73,15 @@ class LiteScriptModule: def run_method(self, method_name, *input): return self._c.run_method(method_name, input) + def _export_operator_list(module: LiteScriptModule): r""" - return a set of root operator names (with overload name) that are used by any method - in this mobile module. + return a set of root operator names (with overload name) that are used by any method + in this mobile module. """ return torch._C._export_operator_list(module._c) + def _get_model_bytecode_version(f_input) -> int: r""" Args: @@ -101,11 +108,12 @@ def _get_model_bytecode_version(f_input) -> int: if os.path.isdir(f_input): raise ValueError(f"The provided filename {f_input} is a directory") - if (isinstance(f_input, (str, pathlib.Path))): + if isinstance(f_input, (str, pathlib.Path)): return torch._C._get_model_bytecode_version(str(f_input)) else: return torch._C._get_model_bytecode_version_from_buffer(f_input.read()) + def _get_mobile_model_contained_types(f_input) -> int: r""" Args: @@ -131,11 +139,12 @@ def _get_mobile_model_contained_types(f_input) -> int: if os.path.isdir(f_input): raise ValueError(f"The provided filename {f_input} is a directory") - if (isinstance(f_input, (str, pathlib.Path))): + if isinstance(f_input, (str, pathlib.Path)): return torch._C._get_mobile_model_contained_types(str(f_input)) else: return torch._C._get_mobile_model_contained_types_from_buffer(f_input.read()) + def _backport_for_mobile(f_input, f_output, to_version): r""" Args: @@ -152,11 +161,15 @@ def _backport_for_mobile(f_input, f_output, to_version): if os.path.isdir(f_input): raise ValueError(f"The provided filename {f_input} is a directory") - if ((isinstance(f_input, (str, pathlib.Path))) and ( - isinstance(f_output, (str, pathlib.Path)))): + if (isinstance(f_input, (str, pathlib.Path))) and ( + isinstance(f_output, (str, pathlib.Path)) + ): return torch._C._backport_for_mobile(str(f_input), str(f_output), to_version) else: - return torch._C._backport_for_mobile_from_buffer(f_input.read(), str(f_output), to_version) + return torch._C._backport_for_mobile_from_buffer( + f_input.read(), str(f_output), to_version + ) + def _backport_for_mobile_to_buffer(f_input, to_version): r""" @@ -171,10 +184,13 @@ def _backport_for_mobile_to_buffer(f_input, to_version): if os.path.isdir(f_input): raise ValueError(f"The provided filename {f_input} is a directory") - if (isinstance(f_input, (str, pathlib.Path))): + if isinstance(f_input, (str, pathlib.Path)): return torch._C._backport_for_mobile_to_buffer(str(f_input), to_version) else: - return torch._C._backport_for_mobile_from_buffer_to_buffer(f_input.read(), to_version) + return torch._C._backport_for_mobile_from_buffer_to_buffer( + f_input.read(), to_version + ) + def _get_model_ops_and_info(f_input): r""" @@ -211,7 +227,7 @@ def _get_model_ops_and_info(f_input): if os.path.isdir(f_input): raise ValueError(f"The provided filename {f_input} is a directory") - if (isinstance(f_input, (str, pathlib.Path))): + if isinstance(f_input, (str, pathlib.Path)): return torch._C._get_model_ops_and_info(str(f_input)) else: return torch._C._get_model_ops_and_info(f_input.read()) diff --git a/torch/jit/quantized.py b/torch/jit/quantized.py index 5a74401a9c6f..63de5c5bb463 100644 --- a/torch/jit/quantized.py +++ b/torch/jit/quantized.py @@ -1,105 +1,144 @@ -from torch import Tensor, _VF # noqa: F401 -from torch.nn.utils.rnn import PackedSequence -import torch - import warnings from typing import List, Optional, Tuple +import torch +from torch import _VF, Tensor # noqa: F401 +from torch.nn.utils.rnn import PackedSequence + class QuantizedLinear(torch.jit.ScriptModule): - __constants__ = ['scale', 'zero_point'] + __constants__ = ["scale", "zero_point"] def __init__(self, other): super().__init__() warnings.warn( "torch.jit.QuantizedLinear is deprecated and will be removed in an upcoming " - "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.Linear instead.") + "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.Linear instead." + ) self.in_features = other.in_features self.out_features = other.out_features # Quantize weight and discard the original - self.weight, self.col_offsets, self.scale, self.zero_point = torch.fbgemm_linear_quantize_weight( - other.weight.clone(memory_format=torch.contiguous_format).float()) + ( + self.weight, + self.col_offsets, + self.scale, + self.zero_point, + ) = torch.fbgemm_linear_quantize_weight( + other.weight.clone(memory_format=torch.contiguous_format).float() + ) self.weight = torch.nn.Parameter(self.weight, requires_grad=False) self.col_offsets = torch.nn.Parameter(self.col_offsets, requires_grad=False) - assert other.bias is not None, 'QuantizedLinear requires a bias' - self.bias = torch.nn.Parameter(other.bias.clone(memory_format=torch.contiguous_format).float(), requires_grad=False) + assert other.bias is not None, "QuantizedLinear requires a bias" + self.bias = torch.nn.Parameter( + other.bias.clone(memory_format=torch.contiguous_format).float(), + requires_grad=False, + ) self.register_buffer( - 'packed_tensor_ptr', - torch.fbgemm_pack_quantized_matrix(self.weight.clone(memory_format=torch.contiguous_format))) + "packed_tensor_ptr", + torch.fbgemm_pack_quantized_matrix( + self.weight.clone(memory_format=torch.contiguous_format) + ), + ) @torch.jit.script_method def _unpack(self): - self.packed_tensor_ptr.set_( - torch.fbgemm_pack_quantized_matrix(self.weight)) + self.packed_tensor_ptr.set_(torch.fbgemm_pack_quantized_matrix(self.weight)) @torch.jit.script_method def _pack(self): self.packed_tensor_ptr.set_( - torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()) + torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach() + ) @torch.jit.script_method def forward(self, input): out = torch.fbgemm_linear_int8_weight_fp32_activation( - input.float(), self.weight, self.packed_tensor_ptr, self.col_offsets, - self.scale, self.zero_point, self.bias) + input.float(), + self.weight, + self.packed_tensor_ptr, + self.col_offsets, + self.scale, + self.zero_point, + self.bias, + ) return out.to(input.dtype) def extra_repr(self): - repr = 'in_features={in_features}, out_features={out_features}, ' \ - 'scale={scale}, zero_point={zero_point}'.format(**self.__dict__) + repr = ( + "in_features={in_features}, out_features={out_features}, " + "scale={scale}, zero_point={zero_point}".format(**self.__dict__) + ) return repr + # FP16 weights class QuantizedLinearFP16(torch.jit.ScriptModule): - def __init__(self, other): super().__init__() warnings.warn( "torch.jit.QuantizedLinearFP16 is deprecated and will be removed in an upcoming " - "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.Linear instead.") + "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.Linear instead." + ) self.in_features = other.in_features self.out_features = other.out_features self.original_weight = other.weight self.weight = torch.fbgemm_pack_gemm_matrix_fp16( - other.weight.clone(memory_format=torch.contiguous_format).float()) - assert other.bias is not None, 'QuantizedLinearFP16 requires a bias' - self.bias = torch.nn.Parameter(other.bias.clone(memory_format=torch.contiguous_format).float(), requires_grad=False) - self.register_buffer('packed_weight', self.weight) + other.weight.clone(memory_format=torch.contiguous_format).float() + ) + assert other.bias is not None, "QuantizedLinearFP16 requires a bias" + self.bias = torch.nn.Parameter( + other.bias.clone(memory_format=torch.contiguous_format).float(), + requires_grad=False, + ) + self.register_buffer("packed_weight", self.weight) @torch.jit.script_method def _unpack(self): self.packed_weight.set_( - torch.fbgemm_pack_gemm_matrix_fp16( - self.original_weight)) + torch.fbgemm_pack_gemm_matrix_fp16(self.original_weight) + ) @torch.jit.script_method def _pack(self): self.packed_weight.set_( - torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()) + torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach() + ) @torch.jit.script_method def forward(self, input): out = torch.fbgemm_linear_fp16_weight_fp32_activation( - input.float(), self.packed_weight, self.bias) + input.float(), self.packed_weight, self.bias + ) return out def extra_repr(self): - repr = 'in_features={in_features}, out_features={out_features}, '.format(**self.__dict__) + repr = "in_features={in_features}, out_features={out_features}, ".format( + **self.__dict__ + ) return repr + # Quantized RNN cell implementations class QuantizedRNNCellBase(torch.jit.ScriptModule): - __constants__ = ['input_size', 'hidden_size', 'bias', 'scale_hh', 'scale_ih', - 'zero_point_ih', 'zero_point_hh'] + __constants__ = [ + "input_size", + "hidden_size", + "bias", + "scale_hh", + "scale_ih", + "zero_point_ih", + "zero_point_hh", + ] def __init__(self, other): super().__init__() warnings.warn( "torch.jit.QuantizedRNNCellBase is deprecated and will be removed in an upcoming " - "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.RNNCell instead.") + "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.RNNCell instead." + ) self.input_size = other.input_size self.hidden_size = other.hidden_size @@ -107,46 +146,69 @@ class QuantizedRNNCellBase(torch.jit.ScriptModule): if not self.bias: raise ValueError("Quantized RNN cells require bias terms") - weight_ih, col_offsets_ih, self.scale_ih, self.zero_point_ih = \ - torch.fbgemm_linear_quantize_weight(other.weight_ih.clone(memory_format=torch.contiguous_format).float()) - self.register_buffer('weight_ih', weight_ih) - self.register_buffer('col_offsets_ih', col_offsets_ih) - weight_hh, col_offsets_hh, self.scale_hh, self.zero_point_hh = \ - torch.fbgemm_linear_quantize_weight(other.weight_hh.clone(memory_format=torch.contiguous_format).float()) - self.register_buffer('weight_hh', weight_hh) - self.register_buffer('col_offsets_hh', col_offsets_hh) + ( + weight_ih, + col_offsets_ih, + self.scale_ih, + self.zero_point_ih, + ) = torch.fbgemm_linear_quantize_weight( + other.weight_ih.clone(memory_format=torch.contiguous_format).float() + ) + self.register_buffer("weight_ih", weight_ih) + self.register_buffer("col_offsets_ih", col_offsets_ih) + ( + weight_hh, + col_offsets_hh, + self.scale_hh, + self.zero_point_hh, + ) = torch.fbgemm_linear_quantize_weight( + other.weight_hh.clone(memory_format=torch.contiguous_format).float() + ) + self.register_buffer("weight_hh", weight_hh) + self.register_buffer("col_offsets_hh", col_offsets_hh) packed_ih = torch.fbgemm_pack_quantized_matrix(self.weight_ih) - self.register_buffer('packed_ih', packed_ih) + self.register_buffer("packed_ih", packed_ih) packed_hh = torch.fbgemm_pack_quantized_matrix(self.weight_hh) - self.register_buffer('packed_hh', packed_hh) + self.register_buffer("packed_hh", packed_hh) - self.bias_ih = torch.nn.Parameter(other.bias_ih.clone(memory_format=torch.contiguous_format).float(), requires_grad=False) - self.bias_hh = torch.nn.Parameter(other.bias_hh.clone(memory_format=torch.contiguous_format).float(), requires_grad=False) + self.bias_ih = torch.nn.Parameter( + other.bias_ih.clone(memory_format=torch.contiguous_format).float(), + requires_grad=False, + ) + self.bias_hh = torch.nn.Parameter( + other.bias_hh.clone(memory_format=torch.contiguous_format).float(), + requires_grad=False, + ) def extra_repr(self): - s = '{input_size}, {hidden_size}' - if 'bias' in self.__dict__ and self.bias is not True: - s += ', bias={bias}' - if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh": - s += ', nonlinearity={nonlinearity}' + s = "{input_size}, {hidden_size}" + if "bias" in self.__dict__ and self.bias is not True: + s += ", bias={bias}" + if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh": + s += ", nonlinearity={nonlinearity}" return s.format(**self.__dict__) @torch.jit.script_method def check_forward_input(self, input): if input.size(1) != self.input_size: raise RuntimeError( - f"input has inconsistent input_size: got {input.size(1)}, expected {self.input_size}") + f"input has inconsistent input_size: got {input.size(1)}, expected {self.input_size}" + ) @torch.jit.script_method - def check_forward_hidden(self, input: Tensor, hx: Tensor, hidden_label: str = '') -> None: + def check_forward_hidden( + self, input: Tensor, hx: Tensor, hidden_label: str = "" + ) -> None: if input.size(0) != hx.size(0): raise RuntimeError( - f"Input batch size {input.size(0)} doesn't match hidden{hidden_label} batch size {hx.size(0)}") + f"Input batch size {input.size(0)} doesn't match hidden{hidden_label} batch size {hx.size(0)}" + ) if hx.size(1) != self.hidden_size: raise RuntimeError( - f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}") + f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}" + ) # TODO: for some reason weak_script_method causes a destruction of the # module to occur, which in turn frees the packed_ih object via its DataPtr @@ -161,46 +223,78 @@ class QuantizedRNNCellBase(torch.jit.ScriptModule): @torch.jit.script_method def _pack(self): self.packed_ih.set_( - torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()) + torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach() + ) self.packed_hh.set_( - torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()) + torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach() + ) class QuantizedRNNCell(QuantizedRNNCellBase): - __constants__ = ['input_size', 'hidden_size', 'bias', 'scale_hh', 'scale_ih', - 'zero_point_ih', 'zero_point_hh', 'nonlinearity'] + __constants__ = [ + "input_size", + "hidden_size", + "bias", + "scale_hh", + "scale_ih", + "zero_point_ih", + "zero_point_hh", + "nonlinearity", + ] def __init__(self, other): super().__init__(other) warnings.warn( "torch.jit.QuantizedRNNCell is deprecated and will be removed in an upcoming " - "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.RNNCell instead.") + "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.RNNCell instead." + ) self.nonlinearity = other.nonlinearity @torch.jit.script_method def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: self.check_forward_input(input) if hx is None: - hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) - self.check_forward_hidden(input, hx, '') + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + self.check_forward_hidden(input, hx, "") if self.nonlinearity == "tanh": ret = _VF.quantized_rnn_tanh_cell( - input, hx, self.weight_ih, self.weight_hh, self.bias_ih, - self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih, - self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih, - self.zero_point_hh + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + self.packed_ih, + self.packed_hh, + self.col_offsets_ih, + self.col_offsets_hh, + self.scale_ih, + self.scale_hh, + self.zero_point_ih, + self.zero_point_hh, ) elif self.nonlinearity == "relu": ret = _VF.quantized_rnn_relu_cell( - input, hx, self.weight_ih, self.weight_hh, self.bias_ih, - self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih, - self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih, - self.zero_point_hh + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + self.packed_ih, + self.packed_hh, + self.col_offsets_ih, + self.col_offsets_hh, + self.scale_ih, + self.scale_hh, + self.zero_point_ih, + self.zero_point_hh, ) else: ret = input # TODO: remove when jit supports exception flow - raise RuntimeError( - f"Unknown nonlinearity: {self.nonlinearity}") + raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}") return ret @@ -209,21 +303,36 @@ class QuantizedLSTMCell(QuantizedRNNCellBase): super().__init__(other) warnings.warn( "torch.jit.QuantizedLSTMCell is deprecated and will be removed in an upcoming " - "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.LSTMCell instead.") + "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.LSTMCell instead." + ) @torch.jit.script_method - def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]: + def forward( + self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None + ) -> Tuple[Tensor, Tensor]: self.check_forward_input(input) if hx is None: - zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) + zeros = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) hx = (zeros, zeros) - self.check_forward_hidden(input, hx[0], '[0]') - self.check_forward_hidden(input, hx[1], '[1]') + self.check_forward_hidden(input, hx[0], "[0]") + self.check_forward_hidden(input, hx[1], "[1]") return _VF.quantized_lstm_cell( - input, hx, self.weight_ih, self.weight_hh, self.bias_ih, - self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih, - self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih, - self.zero_point_hh + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + self.packed_ih, + self.packed_hh, + self.col_offsets_ih, + self.col_offsets_hh, + self.scale_ih, + self.scale_hh, + self.zero_point_ih, + self.zero_point_hh, ) @@ -232,19 +341,32 @@ class QuantizedGRUCell(QuantizedRNNCellBase): super().__init__(other) warnings.warn( "torch.jit.QuantizedGRUCell is deprecated and will be removed in an upcoming " - "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.GRUCell instead.") + "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.GRUCell instead." + ) @torch.jit.script_method def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: self.check_forward_input(input) if hx is None: - hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) - self.check_forward_hidden(input, hx, '') + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + self.check_forward_hidden(input, hx, "") return _VF.quantized_gru_cell( - input, hx, self.weight_ih, self.weight_hh, self.bias_ih, - self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih, - self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih, - self.zero_point_hh + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + self.packed_ih, + self.packed_hh, + self.col_offsets_ih, + self.col_offsets_hh, + self.scale_ih, + self.scale_hh, + self.zero_point_ih, + self.zero_point_hh, ) @@ -253,21 +375,31 @@ def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tens class QuantizedRNNBase(torch.jit.ScriptModule): - __constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias', - 'batch_first', 'dropout', 'bidirectional', 'dtype'] + __constants__ = [ + "mode", + "input_size", + "hidden_size", + "num_layers", + "bias", + "batch_first", + "dropout", + "bidirectional", + "dtype", + ] def __init__(self, other, dtype=torch.int8): super().__init__() warnings.warn( "torch.jit.QuantizedRNNBase is deprecated and will be removed in an upcoming " - "PyTorch release. Please use the torch.ao.nn.quantized.dynamic instead.") + "PyTorch release. Please use the torch.ao.nn.quantized.dynamic instead." + ) self.mode = other.mode self.input_size = other.input_size self.hidden_size = other.hidden_size self.num_layers = other.num_layers self.bias = other.bias self.batch_first = other.batch_first - if self.mode != 'GRU': + if self.mode != "GRU": assert not self.batch_first self.dropout = other.dropout self.bidirectional = other.bidirectional @@ -277,43 +409,49 @@ class QuantizedRNNBase(torch.jit.ScriptModule): assert self.bias # TODO: support more than just LSTM - if self.mode != 'LSTM' and self.mode != 'GRU': - raise RuntimeError('Only LSTM or GRU is supported for QuantizedRNN') + if self.mode != "LSTM" and self.mode != "GRU": + raise RuntimeError("Only LSTM or GRU is supported for QuantizedRNN") if dtype != torch.int8 and dtype != torch.float16: - raise RuntimeError(f'Unsupported dtype: {dtype}') + raise RuntimeError(f"Unsupported dtype: {dtype}") self.all_weights = [] for layer in range(self.num_layers): for direction in range(num_directions): - layer_input_size = self.input_size if layer == 0 else self.hidden_size * num_directions + layer_input_size = ( + self.input_size if layer == 0 else self.hidden_size * num_directions + ) - suffix = '_reverse' if direction == 1 else '' + suffix = "_reverse" if direction == 1 else "" def get_weight_bias(ihhh): - weight_name = f'weight_{ihhh}_l{layer}{suffix}' - bias_name = f'bias_{ihhh}_l{layer}{suffix}' + weight_name = f"weight_{ihhh}_l{layer}{suffix}" + bias_name = f"bias_{ihhh}_l{layer}{suffix}" weight = getattr(other, weight_name) bias = getattr(other, bias_name) return weight, bias - weight_ih, bias_ih = get_weight_bias('ih') - weight_hh, bias_hh = get_weight_bias('hh') + weight_ih, bias_ih = get_weight_bias("ih") + weight_hh, bias_hh = get_weight_bias("hh") if dtype == torch.int8: cell_params = torch.ops.quantized.make_quantized_cell_params( - weight_ih, weight_hh, bias_ih, bias_hh) + weight_ih, weight_hh, bias_ih, bias_hh + ) else: packed_ih = torch.ops.quantized.linear_prepack_fp16( - weight_ih.float(), bias_ih) + weight_ih.float(), bias_ih + ) packed_hh = torch.ops.quantized.linear_prepack_fp16( - weight_hh.float(), bias_hh) + weight_hh.float(), bias_hh + ) cell_params = torch.ops.quantized.make_quantized_cell_params_fp16( - packed_ih, packed_hh) + packed_ih, packed_hh + ) - setattr(self, f'cell_params_{layer}_{suffix}', cell_params) + setattr(self, f"cell_params_{layer}_{suffix}", cell_params) self.all_weights.append(cell_params) @torch.jit.script_method @@ -321,33 +459,48 @@ class QuantizedRNNBase(torch.jit.ScriptModule): expected_input_dim = 2 if batch_sizes is not None else 3 if input.dim() != expected_input_dim: raise RuntimeError( - f'input must have {expected_input_dim} dimensions, got {input.dim()}') + f"input must have {expected_input_dim} dimensions, got {input.dim()}" + ) if self.input_size != input.size(-1): raise RuntimeError( - f'input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}') + f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}" + ) @torch.jit.script_method - def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]: + def get_expected_hidden_size( + self, input: Tensor, batch_sizes: Optional[Tensor] + ) -> Tuple[int, int, int]: if batch_sizes is not None: mini_batch = int(batch_sizes[0]) else: mini_batch = input.size(0) if self.batch_first else input.size(1) num_directions = 2 if self.bidirectional else 1 - expected_hidden_size = (self.num_layers * num_directions, - mini_batch, self.hidden_size) + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.hidden_size, + ) return expected_hidden_size @torch.jit.script_method - def check_hidden_size(self, hx: Tensor, expected_hidden_size: Tuple[int, int, int], - msg: str = 'Expected hidden size {}, got {}') -> None: + def check_hidden_size( + self, + hx: Tensor, + expected_hidden_size: Tuple[int, int, int], + msg: str = "Expected hidden size {}, got {}", + ) -> None: if hx.size() != expected_hidden_size: raise RuntimeError(msg.format(expected_hidden_size, list(hx.size()))) @torch.jit.script_method - def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]) -> None: + def check_forward_args( + self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor] + ) -> None: self.check_input(input, batch_sizes) expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) - self.check_hidden_size(hidden, expected_hidden_size, msg='Expected hidden size {}, got {}') + self.check_hidden_size( + hidden, expected_hidden_size, msg="Expected hidden size {}, got {}" + ) @torch.jit.script_method def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor: @@ -357,22 +510,33 @@ class QuantizedRNNBase(torch.jit.ScriptModule): class QuantizedLSTM(QuantizedRNNBase): - __overloads__ = {'forward': ['forward_packed', 'forward_tensor']} + __overloads__ = {"forward": ["forward_packed", "forward_tensor"]} def __init__(self, other, dtype): super().__init__(other, dtype) warnings.warn( "torch.jit.QuantizedLSTM is deprecated and will be removed in an upcoming " - "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.LSTM instead.") + "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.LSTM instead." + ) @torch.jit.script_method - def forward_impl(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]], batch_sizes: Optional[Tensor], - max_batch_size: int, sorted_indices: Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + def forward_impl( + self, + input: Tensor, + hx: Optional[Tuple[Tensor, Tensor]], + batch_sizes: Optional[Tensor], + max_batch_size: int, + sorted_indices: Optional[Tensor], + ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: if hx is None: num_directions = 2 if self.bidirectional else 1 - zeros = torch.zeros(self.num_layers * num_directions, - max_batch_size, self.hidden_size, - dtype=input.dtype, device=input.device) + zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) hx = (zeros, zeros) else: # Each batch of the hidden state should match the input sequence that @@ -381,22 +545,36 @@ class QuantizedLSTM(QuantizedRNNBase): self.check_forward_args(input, hx, batch_sizes) assert batch_sizes is None - result = torch.quantized_lstm(input, hx, self.all_weights, self.bias, self.num_layers, - float(self.dropout), self.training, self.bidirectional, - self.batch_first, dtype=self.dtype, use_dynamic=False) + result = torch.quantized_lstm( + input, + hx, + self.all_weights, + self.bias, + self.num_layers, + float(self.dropout), + self.training, + self.bidirectional, + self.batch_first, + dtype=self.dtype, + use_dynamic=False, + ) output = result[0] hidden = result[1:] return output, hidden @torch.jit.script_method - def forward_tensor(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + def forward_tensor( + self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None + ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: batch_sizes = None max_batch_size = input.size(0) if self.batch_first else input.size(1) sorted_indices = None unsorted_indices = None - output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices) + output, hidden = self.forward_impl( + input, hx, batch_sizes, max_batch_size, sorted_indices + ) return output, self.permute_hidden(hidden, unsorted_indices) @@ -414,22 +592,32 @@ class QuantizedLSTM(QuantizedRNNBase): output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices) return output, self.permute_hidden(hidden, unsorted_indices) - @torch.jit.script_method - def permute_hidden(self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]) -> Tuple[Tensor, Tensor]: + def permute_hidden( + self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor] + ) -> Tuple[Tensor, Tensor]: if permutation is None: return hx - return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation) + return apply_permutation(hx[0], permutation), apply_permutation( + hx[1], permutation + ) @torch.jit.script_method - def check_forward_args(self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]) -> None: + def check_forward_args( + self, + input: Tensor, + hidden: Tuple[Tensor, Tensor], + batch_sizes: Optional[Tensor], + ) -> None: self.check_input(input, batch_sizes) expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) - self.check_hidden_size(hidden[0], expected_hidden_size, - 'Expected hidden[0] size {}, got {}') - self.check_hidden_size(hidden[1], expected_hidden_size, - 'Expected hidden[1] size {}, got {}') + self.check_hidden_size( + hidden[0], expected_hidden_size, "Expected hidden[0] size {}, got {}" + ) + self.check_hidden_size( + hidden[1], expected_hidden_size, "Expected hidden[1] size {}, got {}" + ) def forward(self, input, hx=None): if isinstance(input, PackedSequence): @@ -439,23 +627,33 @@ class QuantizedLSTM(QuantizedRNNBase): class QuantizedGRU(QuantizedRNNBase): - __overloads__ = {'forward': ['forward_packed', 'forward_tensor']} + __overloads__ = {"forward": ["forward_packed", "forward_tensor"]} def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) warnings.warn( "torch.jit.QuantizedGRU is deprecated and will be removed in an upcoming " - "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.GRU instead.") - + "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.GRU instead." + ) @torch.jit.script_method - def forward_impl(self, input: Tensor, hx: Optional[Tensor], batch_sizes: Optional[Tensor], max_batch_size: int, - sorted_indices: Optional[Tensor]) -> Tuple[Tensor, Tensor]: + def forward_impl( + self, + input: Tensor, + hx: Optional[Tensor], + batch_sizes: Optional[Tensor], + max_batch_size: int, + sorted_indices: Optional[Tensor], + ) -> Tuple[Tensor, Tensor]: if hx is None: num_directions = 2 if self.bidirectional else 1 - hx = torch.zeros(self.num_layers * num_directions, - max_batch_size, self.hidden_size, - dtype=input.dtype, device=input.device) + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) else: # Each batch of the hidden state should match the input sequence that # the user believes he/she is passing in. @@ -463,12 +661,29 @@ class QuantizedGRU(QuantizedRNNBase): self.check_forward_args(input, hx, batch_sizes) if batch_sizes is None: - result = torch.quantized_gru(input, hx, self.all_weights, self.bias, self.num_layers, - float(self.dropout), self.training, self.bidirectional, - self.batch_first) + result = torch.quantized_gru( + input, + hx, + self.all_weights, + self.bias, + self.num_layers, + float(self.dropout), + self.training, + self.bidirectional, + self.batch_first, + ) else: - result = torch.quantized_gru(input, batch_sizes, hx, self.all_weights, self.bias, self.num_layers, - float(self.dropout), self.training, self.bidirectional) + result = torch.quantized_gru( + input, + batch_sizes, + hx, + self.all_weights, + self.bias, + self.num_layers, + float(self.dropout), + self.training, + self.bidirectional, + ) output = result[0] hidden = result[1] @@ -476,17 +691,23 @@ class QuantizedGRU(QuantizedRNNBase): return output, hidden @torch.jit.script_method - def forward_tensor(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: + def forward_tensor( + self, input: Tensor, hx: Optional[Tensor] = None + ) -> Tuple[Tensor, Tensor]: batch_sizes = None max_batch_size = input.size(0) if self.batch_first else input.size(1) sorted_indices = None unsorted_indices = None - output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices) + output, hidden = self.forward_impl( + input, hx, batch_sizes, max_batch_size, sorted_indices + ) return output, self.permute_hidden(hidden, unsorted_indices) @torch.jit.script_method - def forward_packed(self, input: PackedSequence, hx: Optional[Tensor] = None) -> Tuple[PackedSequence, Tensor]: + def forward_packed( + self, input: PackedSequence, hx: Optional[Tensor] = None + ) -> Tuple[PackedSequence, Tensor]: input_, batch_sizes, sorted_indices, unsorted_indices = input max_batch_size = int(batch_sizes[0]) @@ -505,8 +726,10 @@ class QuantizedGRU(QuantizedRNNBase): def quantize_rnn_cell_modules(module): - warnings.warn("quantize_rnn_cell_modules function has been deprecated. " - "Please use torch.ao.quantization.quantize_dynamic API instead.") + warnings.warn( + "quantize_rnn_cell_modules function has been deprecated. " + "Please use torch.ao.quantization.quantize_dynamic API instead." + ) reassign = {} for name, mod in module.named_modules(): if mod is module: @@ -526,8 +749,10 @@ def quantize_rnn_cell_modules(module): def quantize_linear_modules(module, dtype=torch.int8): - warnings.warn("quantize_linear_modules function has been deprecated. " - "Please use torch.ao.quantization.quantize_dynamic API instead.") + warnings.warn( + "quantize_linear_modules function has been deprecated. " + "Please use torch.ao.quantization.quantize_dynamic API instead." + ) reassign = {} for name, mod in module.named_modules(): @@ -545,14 +770,15 @@ def quantize_linear_modules(module, dtype=torch.int8): elif dtype == torch.float16: return QuantizedLinearFP16(module) else: - raise RuntimeError( - f"Unsupported dtype: {dtype}") + raise RuntimeError(f"Unsupported dtype: {dtype}") return module def quantize_rnn_modules(module, dtype=torch.int8): - warnings.warn("quantize_rnn_modules function has been deprecated. " - "Please use torch.ao.quantization.quantize_dynamic API instead.") + warnings.warn( + "quantize_rnn_modules function has been deprecated. " + "Please use torch.ao.quantization.quantize_dynamic API instead." + ) reassign = {} for name, mod in module.named_modules(): if mod is module: diff --git a/torch/jit/supported_ops.py b/torch/jit/supported_ops.py index e3664674fbd8..053d1d598233 100644 --- a/torch/jit/supported_ops.py +++ b/torch/jit/supported_ops.py @@ -1,17 +1,22 @@ -import torch.jit -from torch.jit._builtins import _find_builtin import inspect import textwrap + +import torch.jit +from torch.jit._builtins import _find_builtin + # this file is for generating documentation using sphinx autodoc # > help(torch.jit.supported_ops) will also give a nice listed of the # supported ops programmatically + def _hidden(name): - return name.startswith('_') and not name.startswith('__') + return name.startswith("_") and not name.startswith("__") + def _emit_type(type): return str(type) + def _emit_arg(indent, i, arg): v = f"{arg.name} : {_emit_type(arg.type)}" default = arg.default_value @@ -21,33 +26,40 @@ def _emit_arg(indent, i, arg): v = f"\n{' ' * indent}{v}" return v + def _emit_args(indent, arguments): return ",".join(_emit_arg(indent, i, arg) for i, arg in enumerate(arguments)) + def _emit_ret(ret): return _emit_type(ret.type) + def _emit_rets(returns): if len(returns) == 1: return _emit_ret(returns[0]) return f"Tuple[{', '.join(_emit_ret(r) for r in returns)}]" + def _emit_schema(mod, name, schema, arg_start=0, padding=4): if mod is None: qualified_name = name else: qualified_name = f"{mod}.{name}" - schema_str = "{}({}) -> {}".format(qualified_name, - _emit_args(len(qualified_name) + 1 + padding, schema.arguments[arg_start:]), - _emit_rets(schema.returns)) + schema_str = "{}({}) -> {}".format( + qualified_name, + _emit_args(len(qualified_name) + 1 + padding, schema.arguments[arg_start:]), + _emit_rets(schema.returns), + ) return schema_str + def _get_tensor_ops(): def is_tensor_method(schema): if len(schema.arguments) == 0: return False self = schema.arguments[0] - if self.name != 'self': + if self.name != "self": return False if not self.type.isSubtypeOf(torch._C.TensorType.get()): return False @@ -60,10 +72,11 @@ def _get_tensor_ops(): schemas = torch._C._jit_get_schemas_for_operator("aten::" + elem) for schema in schemas: if is_tensor_method(schema): - methods.append(_emit_schema('Tensor', elem, schema, arg_start=1)) + methods.append(_emit_schema("Tensor", elem, schema, arg_start=1)) return "Supported Tensor Methods", methods + def _get_nn_functional_ops(): functions = [] @@ -78,9 +91,9 @@ def _get_nn_functional_ops(): attr_module = inspect.getmodule(attr) if not attr_module: - raise RuntimeError(f'Module for {attr} not found') + raise RuntimeError(f"Module for {attr} not found") - if 'torch.nn.functional' not in attr_module.__name__: + if "torch.nn.functional" not in attr_module.__name__: # Ignore functions from outside torch.nn.functional continue @@ -106,12 +119,13 @@ def _get_nn_functional_ops(): functions.append(_emit_schema(name, elem, schema)) return "Supported PyTorch Functions", functions + def _get_builtins_helper(): builtins = [] for fn, _builtin_name in torch.jit._builtins._builtin_ops: mod = inspect.getmodule(fn) - if not hasattr(fn, '__name__'): + if not hasattr(fn, "__name__"): # typing classes continue if not mod: @@ -120,19 +134,20 @@ def _get_builtins_helper(): # skip internal-only methods continue - if 'torch._C' in mod.__name__: + if "torch._C" in mod.__name__: continue builtins.append((fn, _builtin_name)) return builtins + def _is_math_fn(fn): mod = inspect.getmodule(fn) if not mod: - raise RuntimeError(f'Module for {fn} not found') + raise RuntimeError(f"Module for {fn} not found") - return mod.__name__ == 'math' + return mod.__name__ == "math" def _get_torchscript_builtins(): @@ -143,7 +158,7 @@ def _get_torchscript_builtins(): for fn, _builtin_name in builtins_list: mod = inspect.getmodule(fn) if not mod: - raise RuntimeError(f'Module for {fn} not found') + raise RuntimeError(f"Module for {fn} not found") builtin = _find_builtin(fn) if builtin is not None: schemas = torch._C._jit_get_schemas_for_operator(builtin) @@ -162,13 +177,13 @@ def _get_math_builtins(): for fn, _builtin_name in builtins_list: mod = inspect.getmodule(fn) if not mod: - raise RuntimeError(f'Module for {fn} not found') + raise RuntimeError(f"Module for {fn} not found") builtin = _find_builtin(fn) if builtin is not None: schemas = torch._C._jit_get_schemas_for_operator(builtin) for schema in schemas: schema_str = _emit_schema(mod.__name__, fn.__name__, schema) - if 'Tensor' in schema_str: + if "Tensor" in schema_str: # Skip Tensor ops that have the same name as math functions # (they will show up in the tensor methods section) continue @@ -181,67 +196,67 @@ def _get_math_builtins(): def _get_global_builtins(): # Taken from the 'globals' map in torch/csrc/jit/frontend/ir_emitter.cpp supported_builtins = [ - 'print', - 'tuple', - 'float', - 'complex', - 'int', - 'bool', - 'str', - 'getattr', - 'hasattr', - 'isinstance', - 'len', - 'hex', - 'oct', - 'round', - 'hash', - 'min', - 'max', - 'abs', - 'all', - 'divmod', - 'list', - 'ord', - 'chr', - 'bin', - 'range', - 'zip', - 'enumerate', - 'sorted', + "print", + "tuple", + "float", + "complex", + "int", + "bool", + "str", + "getattr", + "hasattr", + "isinstance", + "len", + "hex", + "oct", + "round", + "hash", + "min", + "max", + "abs", + "all", + "divmod", + "list", + "ord", + "chr", + "bin", + "range", + "zip", + "enumerate", + "sorted", ] op_renames = { - 'bool': 'aten::Bool', - 'int': 'aten::Int', - 'float': 'aten::Float', - 'complex': 'aten::Complex', - 'abs': 'prim::abs', - 'max': 'prim::max', - 'min': 'prim::min', - 'range': 'fake::does_not_exist', + "bool": "aten::Bool", + "int": "aten::Int", + "float": "aten::Float", + "complex": "aten::Complex", + "abs": "prim::abs", + "max": "prim::max", + "min": "prim::min", + "range": "fake::does_not_exist", } schemaless_op_explanations = { - 'print': 'Print any value', - 'tuple': 'Lists cannot be converted to tuples with this method since their size is not statically known', - 'getattr': 'Attribute name must be a literal string', - 'hasattr': 'Attribute name must be a literal string', - 'isinstance': 'Result is static', - 'zip': 'Arguments must be iterable. See :ref:`Iterables ` for details.', - 'enumerate': 'Arguments must be iterable. See :ref:`Iterables ` for details.', - 'range': 'Can only be used as an iterator in a for loop', + "print": "Print any value", + "tuple": "Lists cannot be converted to tuples with this method since their size is not statically known", + "getattr": "Attribute name must be a literal string", + "hasattr": "Attribute name must be a literal string", + "isinstance": "Result is static", + "zip": "Arguments must be iterable. See :ref:`Iterables ` for details.", + "enumerate": "Arguments must be iterable. See :ref:`Iterables ` for details.", + "range": "Can only be used as an iterator in a for loop", } magic_methods = [ - ('complex', '__complex__'), - ('float', '__float__'), - ('int', '__int__'), - ('bool', '__bool__'), - ('str', '__str__'), - ('len', '__len__'), - ('hex', '__hex__'), - ('oct', '__oct__'), + ("complex", "__complex__"), + ("float", "__float__"), + ("int", "__int__"), + ("bool", "__bool__"), + ("str", "__str__"), + ("len", "__len__"), + ("hex", "__hex__"), + ("oct", "__oct__"), ] magic_methods_rows = [] @@ -252,24 +267,24 @@ def _get_global_builtins(): schemaless_ops = [] for fn in supported_builtins: - op_name = f'aten::{fn}' + op_name = f"aten::{fn}" if fn in op_renames: op_name = op_renames[fn] schemas = torch._C._jit_get_schemas_for_operator(op_name) for s in schemas: schematized_ops.append(_emit_schema(None, fn, s, padding=0)) if len(schemas) > 0: - schematized_ops.append('') + schematized_ops.append("") else: table_row = f'":any:`{fn}`", "{schemaless_op_explanations[fn]}"' schemaless_ops.append(table_row) - schematized_ops_str = '\n'.join(schematized_ops) - schemaless_ops_str = '\n'.join(schemaless_ops) - magic_methods_rows_str = '\n'.join(magic_methods_rows) - schematized_ops_str = textwrap.indent(schematized_ops_str, '\t') - schemaless_ops_str = textwrap.indent(schemaless_ops_str, '\t') - magic_methods_rows_str = textwrap.indent(magic_methods_rows_str, '\t') + schematized_ops_str = "\n".join(schematized_ops) + schemaless_ops_str = "\n".join(schemaless_ops) + magic_methods_rows_str = "\n".join(magic_methods_rows) + schematized_ops_str = textwrap.indent(schematized_ops_str, "\t") + schemaless_ops_str = textwrap.indent(schemaless_ops_str, "\t") + magic_methods_rows_str = textwrap.indent(magic_methods_rows_str, "\t") section = f""" The functions in the following table are supported but do not have a static schema @@ -299,9 +314,11 @@ These built-in functions use the schema def _list_supported_ops(): def emit_block(decls): - return '\n.. rst-class:: codeblock-height-limiter\n\n::\n\n{}\n'.format(''.join(f' {d}\n\n' for d in decls)) + return "\n.. rst-class:: codeblock-height-limiter\n\n::\n\n{}\n".format( + "".join(f" {d}\n\n" for d in decls) + ) - body = '' + body = "" op_gathering_fns = ( _get_tensor_ops, _get_nn_functional_ops, @@ -311,14 +328,15 @@ def _list_supported_ops(): ) for fn in op_gathering_fns: header, items = fn() - link_target = header.replace('`', '').replace('-', '').lower().replace(' ', '-') + link_target = header.replace("`", "").replace("-", "").lower().replace(" ", "-") if isinstance(items, str): section = f"{header}\n{'~' * len(header)}\n{items}\n" else: section = f"{header}\n{'~' * len(header)}\n{emit_block(items)}" - section = f'.. _{link_target}:' + '\n\n' + section + section = f".. _{link_target}:" + "\n\n" + section body += section return body + __doc__ = _list_supported_ops() diff --git a/torch/jit/unsupported_tensor_ops.py b/torch/jit/unsupported_tensor_ops.py index 29d910051cfd..4e553757eab4 100644 --- a/torch/jit/unsupported_tensor_ops.py +++ b/torch/jit/unsupported_tensor_ops.py @@ -1,20 +1,35 @@ -import torch.jit from textwrap import dedent -from typing import Dict, Any +from typing import Any, Dict + +import torch.jit + def execWrapper(code, glob, loc): exec(code, glob, loc) + def _gen_unsupported_methods_properties(): tensor_attrs = set(filter(lambda x: x[0] != "_", dir(torch.Tensor))) tensor = torch.tensor([2]) - funcs_template = dedent(''' + funcs_template = dedent( + """ def func(x): return x.{op}() - ''') + """ + ) - deprecated_apis = {"volatile", "resize", "reinforce", "new", "name", "map2_", "has_names", "grad_fn", "resize_as"} + deprecated_apis = { + "volatile", + "resize", + "reinforce", + "new", + "name", + "map2_", + "has_names", + "grad_fn", + "resize_as", + } tensor_attrs = tensor_attrs - deprecated_apis properties = [] @@ -46,10 +61,18 @@ Unsupported Tensor Methods ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ """ methods, properties = _gen_unsupported_methods_properties() - return header + "\n" + methods + """ + return ( + header + + "\n" + + methods + + """ Unsupported Tensor Properties ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - """ + "\n" + properties + """ + + "\n" + + properties + ) + __doc__ = _list_unsupported_tensor_ops() diff --git a/torch/mps/__init__.py b/torch/mps/__init__.py index e455e7de0557..13ba91d10de7 100644 --- a/torch/mps/__init__.py +++ b/torch/mps/__init__.py @@ -10,6 +10,7 @@ from .. import Tensor _is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False) _default_mps_generator: torch._C.Generator = None # type: ignore[assignment] + # local helper function (not public or exported) def _get_default_mps_generator() -> torch._C.Generator: global _default_mps_generator @@ -17,14 +18,17 @@ def _get_default_mps_generator() -> torch._C.Generator: _default_mps_generator = torch._C._mps_get_default_generator() return _default_mps_generator + def synchronize() -> None: r"""Waits for all kernels in all streams on a MPS device to complete.""" return torch._C._mps_deviceSynchronize() + def get_rng_state() -> Tensor: r"""Returns the random number generator state as a ByteTensor.""" return _get_default_mps_generator().get_state() + def set_rng_state(new_state: Tensor) -> None: r"""Sets the random number generator state. @@ -34,6 +38,7 @@ def set_rng_state(new_state: Tensor) -> None: new_state_copy = new_state.clone(memory_format=torch.contiguous_format) _get_default_mps_generator().set_state(new_state_copy) + def manual_seed(seed: int) -> None: r"""Sets the seed for generating random numbers. @@ -49,16 +54,19 @@ def manual_seed(seed: int) -> None: seed = int(seed) _get_default_mps_generator().manual_seed(seed) + def seed() -> None: r"""Sets the seed for generating random numbers to a random number.""" _get_default_mps_generator().seed() + def empty_cache() -> None: r"""Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU applications. """ torch._C._mps_emptyCache() + def set_per_process_memory_fraction(fraction) -> None: r"""Set memory fraction for limiting process's memory allocation on MPS device. The allowed value equals the fraction multiplied by recommended maximum device memory @@ -77,33 +85,44 @@ def set_per_process_memory_fraction(fraction) -> None: """ if not isinstance(fraction, float): - raise TypeError('Invalid type for fraction argument, must be `float`') + raise TypeError("Invalid type for fraction argument, must be `float`") if fraction < 0 or fraction > 2: - raise ValueError(f'Invalid fraction value: {fraction}. Allowed range: 0~2') + raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~2") torch._C._mps_setMemoryFraction(fraction) + def current_allocated_memory() -> int: r"""Returns the current GPU memory occupied by tensors in bytes. - .. note:: - The returned size does not include cached allocations in - memory pools of MPSAllocator. + .. note:: + The returned size does not include cached allocations in + memory pools of MPSAllocator. """ return torch._C._mps_currentAllocatedMemory() + def driver_allocated_memory() -> int: r"""Returns total GPU memory allocated by Metal driver for the process in bytes. - .. note:: - The returned size includes cached allocations in MPSAllocator pools - as well as allocations from MPS/MPSGraph frameworks. + .. note:: + The returned size includes cached allocations in MPSAllocator pools + as well as allocations from MPS/MPSGraph frameworks. """ return torch._C._mps_driverAllocatedMemory() + from . import profiler __all__ = [ - 'get_rng_state', 'manual_seed', 'seed', 'set_rng_state', 'synchronize', - 'empty_cache', 'set_per_process_memory_fraction', 'current_allocated_memory', - 'driver_allocated_memory', 'profiler'] + "get_rng_state", + "manual_seed", + "seed", + "set_rng_state", + "synchronize", + "empty_cache", + "set_per_process_memory_fraction", + "current_allocated_memory", + "driver_allocated_memory", + "profiler", +] diff --git a/torch/mps/profiler.py b/torch/mps/profiler.py index 5ad94d01312c..9094a275136c 100644 --- a/torch/mps/profiler.py +++ b/torch/mps/profiler.py @@ -1,8 +1,10 @@ -import torch import contextlib +import torch + __all__ = ["start", "stop", "profile"] + def start(mode: str = "interval", wait_until_completed: bool = False) -> None: r"""Start OS Signpost tracing from MPS backend. @@ -26,10 +28,12 @@ def start(mode: str = "interval", wait_until_completed: bool = False) -> None: mode_normalized = mode.lower().replace(" ", "") torch._C._mps_profilerStartTrace(mode_normalized, wait_until_completed) + def stop(): r"""Stops generating OS Signpost tracing from MPS backend.""" torch._C._mps_profilerStopTrace() + @contextlib.contextmanager def profile(mode: str = "interval", wait_until_completed: bool = False): r"""Context Manager to enabling generating OS Signpost tracing from MPS backend. diff --git a/torch/multiprocessing/__init__.py b/torch/multiprocessing/__init__.py index 69a1590fe983..42860f024a56 100644 --- a/torch/multiprocessing/__init__.py +++ b/torch/multiprocessing/__init__.py @@ -13,13 +13,13 @@ memory. Because of the similarity of APIs we do not document most of this package contents, and we recommend referring to very good docs of the original module. """ -import torch -import sys -from .reductions import init_reductions import multiprocessing +import sys -__all__ = ['set_sharing_strategy', 'get_sharing_strategy', - 'get_all_sharing_strategies'] +import torch +from .reductions import init_reductions + +__all__ = ["set_sharing_strategy", "get_sharing_strategy", "get_all_sharing_strategies"] from multiprocessing import * # noqa: F403 @@ -35,16 +35,22 @@ torch._C._multiprocessing_init() """Add helper function to spawn N processes and wait for completion of any of them. This depends `mp.get_context` which was added in Python 3.4.""" -from .spawn import spawn, SpawnContext, start_processes, ProcessContext, \ - ProcessRaisedException, ProcessExitedException +from .spawn import ( + ProcessContext, + ProcessExitedException, + ProcessRaisedException, + spawn, + SpawnContext, + start_processes, +) -if sys.platform == 'darwin' or sys.platform == 'win32': - _sharing_strategy = 'file_system' - _all_sharing_strategies = {'file_system'} +if sys.platform == "darwin" or sys.platform == "win32": + _sharing_strategy = "file_system" + _all_sharing_strategies = {"file_system"} else: - _sharing_strategy = 'file_descriptor' - _all_sharing_strategies = {'file_descriptor', 'file_system'} + _sharing_strategy = "file_descriptor" + _all_sharing_strategies = {"file_descriptor", "file_system"} def set_sharing_strategy(new_strategy): diff --git a/torch/multiprocessing/_atfork.py b/torch/multiprocessing/_atfork.py index 74b4ec9fff16..92a3280fee78 100644 --- a/torch/multiprocessing/_atfork.py +++ b/torch/multiprocessing/_atfork.py @@ -1,20 +1,23 @@ import sys -__all__ = ['register_after_fork'] +__all__ = ["register_after_fork"] -if sys.platform == 'win32': +if sys.platform == "win32": import multiprocessing.util as _util def _register(func): def wrapper(arg): func() + _util.register_after_fork(_register, wrapper) + else: import os def _register(func): os.register_at_fork(after_in_child=func) + def register_after_fork(func): """Register a callable to be executed in the child process after a fork. diff --git a/torch/multiprocessing/pool.py b/torch/multiprocessing/pool.py index 85281e7e729f..e19c38e0c497 100644 --- a/torch/multiprocessing/pool.py +++ b/torch/multiprocessing/pool.py @@ -6,6 +6,7 @@ from .queue import SimpleQueue def clean_worker(*args, **kwargs): import gc + multiprocessing.pool.worker(*args, **kwargs) # Regular multiprocessing workers don't fully clean up after themselves, # so we have to explicitly trigger garbage collection to make sure that all @@ -30,14 +31,18 @@ class Pool(multiprocessing.pool.Pool): """ for i in range(self._processes - len(self._pool)): # changed worker -> clean_worker - args = (self._inqueue, self._outqueue, - self._initializer, - self._initargs, self._maxtasksperchild) - if hasattr(self, '_wrap_exception'): + args = ( + self._inqueue, + self._outqueue, + self._initializer, + self._initargs, + self._maxtasksperchild, + ) + if hasattr(self, "_wrap_exception"): args += (self._wrap_exception,) w = self.Process(target=clean_worker, args=args) self._pool.append(w) - w.name = w.name.replace('Process', 'PoolWorker') + w.name = w.name.replace("Process", "PoolWorker") w.daemon = True w.start() - util.debug('added worker') + util.debug("added worker") diff --git a/torch/multiprocessing/queue.py b/torch/multiprocessing/queue.py index 673c0a05c6bd..648520a47312 100644 --- a/torch/multiprocessing/queue.py +++ b/torch/multiprocessing/queue.py @@ -1,7 +1,7 @@ import io import multiprocessing.queues -from multiprocessing.reduction import ForkingPickler import pickle +from multiprocessing.reduction import ForkingPickler class ConnectionWrapper: @@ -21,13 +21,12 @@ class ConnectionWrapper: return pickle.loads(buf) def __getattr__(self, name): - if 'conn' in self.__dict__: + if "conn" in self.__dict__: return getattr(self.conn, name) raise AttributeError(f"'{type(self).__name__}' object has no attribute 'conn'") class Queue(multiprocessing.queues.Queue): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._reader: ConnectionWrapper = ConnectionWrapper(self._reader) @@ -37,7 +36,6 @@ class Queue(multiprocessing.queues.Queue): class SimpleQueue(multiprocessing.queues.SimpleQueue): - def _make_methods(self): if not isinstance(self._reader, ConnectionWrapper): self._reader: ConnectionWrapper = ConnectionWrapper(self._reader) diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index 7fbde655d5eb..f2a5a7fb978f 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -1,12 +1,13 @@ +import multiprocessing +import os +import threading +from multiprocessing.reduction import ForkingPickler +from multiprocessing.util import register_after_fork +from typing import Union + import torch import torch.utils.hooks from torch._namedtensor_internals import check_serializing_named_tensor -import os -import threading -import multiprocessing -from multiprocessing.util import register_after_fork -from multiprocessing.reduction import ForkingPickler -from typing import Union try: # Early load resource_sharer to prevent a partially initialized instance @@ -117,14 +118,30 @@ def rebuild_tensor(cls, storage, metadata): return t -def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset, - storage_cls, dtype, storage_device, storage_handle, storage_size_bytes, storage_offset_bytes, - requires_grad, ref_counter_handle, ref_counter_offset, event_handle, event_sync_required): +def rebuild_cuda_tensor( + tensor_cls, + tensor_size, + tensor_stride, + tensor_offset, + storage_cls, + dtype, + storage_device, + storage_handle, + storage_size_bytes, + storage_offset_bytes, + requires_grad, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, +): # If storage_handle is None, storage points to nullptr. if storage_handle is None or storage_size_bytes == 0: storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True) else: - storage = storage_from_cache(storage_cls, (storage_handle, storage_offset_bytes)) + storage = storage_from_cache( + storage_cls, (storage_handle, storage_offset_bytes) + ) if storage is None: torch.cuda._lazy_init() storage = storage_cls._new_shared_cuda( @@ -135,17 +152,29 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset, ref_counter_handle, ref_counter_offset, event_handle, - event_sync_required) - shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(storage) + event_sync_required, + ) + shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef( + storage + ) else: # We already ref counting this Storage, but producer needs new ref-counters to be released. - storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset, device=storage_device) + storage_cls._release_ipc_counter( + ref_counter_handle, ref_counter_offset, device=storage_device + ) - _storage = storage if isinstance(storage, torch.UntypedStorage) else storage._untyped_storage + _storage = ( + storage + if isinstance(storage, torch.UntypedStorage) + else storage._untyped_storage + ) t = torch._utils._rebuild_tensor( torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True), - tensor_offset, tensor_size, tensor_stride) + tensor_offset, + tensor_size, + tensor_stride, + ) if tensor_cls == torch.nn.parameter.Parameter: # It is crucial for integer tensors to receive @@ -161,10 +190,12 @@ def reduce_tensor(tensor): storage = tensor._typed_storage() if tensor.requires_grad and not tensor.is_leaf: - raise RuntimeError("Cowardly refusing to serialize non-leaf tensor which requires_grad, " - "since autograd does not support crossing process boundaries. " - "If you just want to transfer the data, call detach() on the tensor " - "before serializing (e.g., putting it on the queue).") + raise RuntimeError( + "Cowardly refusing to serialize non-leaf tensor which requires_grad, " + "since autograd does not support crossing process boundaries. " + "If you just want to transfer the data, call detach() on the tensor " + "before serializing (e.g., putting it on the queue)." + ) check_serializing_named_tensor(tensor) torch.utils.hooks.warn_if_has_hooks(tensor) @@ -259,42 +290,50 @@ def reduce_tensor(tensor): # eliminated it so that we could just use tensor views to implement the same # thing. # - if storage._untyped_storage.device.type == 'cuda': - (device, - handle, - storage_size_bytes, - storage_offset_bytes, - ref_counter_handle, - ref_counter_offset, - event_handle, - event_sync_required) = storage._share_cuda_() + if storage._untyped_storage.device.type == "cuda": + ( + device, + handle, + storage_size_bytes, + storage_offset_bytes, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ) = storage._share_cuda_() tensor_offset = tensor.storage_offset() shared_cache[handle] = StorageWeakRef(storage) # _backward_hooks purposely omitted here, see # Note [Don't serialize hooks] - return (rebuild_cuda_tensor, - (type(tensor), - tensor.size(), - tensor.stride(), - tensor_offset, # tensor offset in its storage - type(storage), - tensor.dtype, - device, - handle, # identifier which CUDA allocation is the storage in. - storage_size_bytes, # size(in bytes) of the storage - storage_offset_bytes, # offset(in bytes) of the storage in the CUDA allocation - tensor.requires_grad, - ref_counter_handle, - ref_counter_offset, - event_handle, - event_sync_required)) + return ( + rebuild_cuda_tensor, + ( + type(tensor), + tensor.size(), + tensor.stride(), + tensor_offset, # tensor offset in its storage + type(storage), + tensor.dtype, + device, + handle, # identifier which CUDA allocation is the storage in. + storage_size_bytes, # size(in bytes) of the storage + storage_offset_bytes, # offset(in bytes) of the storage in the CUDA allocation + tensor.requires_grad, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ), + ) # _backward_hooks purposely omitted here, see Note [Don't serialize hooks] - metadata = (tensor.storage_offset(), tensor.size(), tensor.stride(), tensor.requires_grad) - return (rebuild_tensor, ( - type(tensor), - storage, - metadata)) + metadata = ( + tensor.storage_offset(), + tensor.size(), + tensor.stride(), + tensor.requires_grad, + ) + return (rebuild_tensor, (type(tensor), storage, metadata)) def fd_id(fd): @@ -326,18 +365,21 @@ def rebuild_storage_fd(cls, df, size): def rebuild_storage_filename(cls, manager, handle, size, dtype=None): - storage: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache(cls, handle) + storage: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache( + cls, handle + ) if storage is not None: return storage._shared_decref() if dtype is None: storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size) else: byte_size = size * torch._utils._element_size(dtype) - untyped_storage: torch.UntypedStorage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size) + untyped_storage: torch.UntypedStorage = ( + torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size) + ) storage = torch.TypedStorage( - wrap_storage=untyped_storage, - dtype=dtype, - _internal=True) + wrap_storage=untyped_storage, dtype=dtype, _internal=True + ) shared_cache[handle] = StorageWeakRef(storage) return storage._shared_decref() @@ -345,25 +387,33 @@ def rebuild_storage_filename(cls, manager, handle, size, dtype=None): def rebuild_storage_empty(cls): return cls() + def rebuild_typed_storage(storage, dtype): return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype, _internal=True) + # Use for torch.storage.TypedStorage def reduce_typed_storage(storage): return (rebuild_typed_storage, (storage._untyped_storage, storage.dtype)) + def rebuild_typed_storage_child(storage, storage_type): return storage_type(wrap_storage=storage, _internal=True) + # Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage def reduce_typed_storage_child(storage): return (rebuild_typed_storage_child, (storage._untyped_storage, type(storage))) + def reduce_storage(storage): from . import get_sharing_strategy + if storage.is_cuda: - raise RuntimeError("Cannot pickle CUDA storage; try pickling a CUDA tensor instead") - elif get_sharing_strategy() == 'file_system': + raise RuntimeError( + "Cannot pickle CUDA storage; try pickling a CUDA tensor instead" + ) + elif get_sharing_strategy() == "file_system": metadata = storage._share_filename_cpu_() cache_key = metadata[1] rebuild = rebuild_storage_filename @@ -389,7 +439,7 @@ def init_reductions(): ForkingPickler.register(torch.cuda.Event, reduce_event) for t in torch._storage_classes: - if t.__name__ == 'UntypedStorage': + if t.__name__ == "UntypedStorage": ForkingPickler.register(t, reduce_storage) else: ForkingPickler.register(t, reduce_typed_storage_child) diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index e802c3d14a44..5c683865f30e 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -1,10 +1,9 @@ - -from typing import Optional import multiprocessing import multiprocessing.connection import signal import sys import warnings +from typing import Optional from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined] @@ -27,6 +26,7 @@ class ProcessRaisedException(ProcessException): Exception is thrown when the process failed due to exception raised by the code. """ + def __init__( self, msg: str, @@ -41,11 +41,16 @@ class ProcessExitedException(ProcessException): Exception is thrown when the process failed due to signal or exited with a specific code. """ + __slots__ = ["exit_code"] def __init__( - self, msg: str, error_index: int, error_pid: int, - exit_code: int, signal_name: Optional[str] = None + self, + msg: str, + error_index: int, + error_pid: int, + exit_code: int, + signal_name: Optional[str] = None, ): super().__init__(msg, error_index, error_pid) self.exit_code = exit_code @@ -72,6 +77,7 @@ def _wrap(fn, i, args, error_queue): except Exception: # Propagate exception to parent process, keeping original traceback import traceback + error_queue.put(traceback.format_exc()) sys.exit(1) @@ -81,8 +87,7 @@ class ProcessContext: self.error_queues = error_queues self.processes = processes self.sentinels = { - process.sentinel: index - for index, process in enumerate(processes) + process.sentinel: index for index, process in enumerate(processes) } def pids(self): @@ -138,20 +143,18 @@ class ProcessContext: if exitcode < 0: name = signal.Signals(-exitcode).name raise ProcessExitedException( - "process %d terminated with signal %s" % - (error_index, name), + "process %d terminated with signal %s" % (error_index, name), error_index=error_index, error_pid=failed_process.pid, exit_code=exitcode, - signal_name=name + signal_name=name, ) else: raise ProcessExitedException( - "process %d terminated with exit code %d" % - (error_index, exitcode), + "process %d terminated with exit code %d" % (error_index, exitcode), error_index=error_index, error_pid=failed_process.pid, - exit_code=exitcode + exit_code=exitcode, ) original_trace = self.error_queues[error_index].get() @@ -162,7 +165,7 @@ class ProcessContext: class SpawnContext(ProcessContext): def __init__(self, processes, error_queues): - warnings.warn('SpawnContext is renamed to ProcessContext since 1.4 release.') + warnings.warn("SpawnContext is renamed to ProcessContext since 1.4 release.") super().__init__(processes, error_queues) @@ -174,7 +177,9 @@ class SpawnContext(ProcessContext): # general enough, and backends like XLA can reuse them in Colab notebooks as well. # Currently we only add this API first, we can consider adding it to documentation as # needed in the future. -def start_processes(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'): +def start_processes( + fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn" +): mp = multiprocessing.get_context(start_method) error_queues = [] processes = [] @@ -198,7 +203,7 @@ def start_processes(fn, args=(), nprocs=1, join=True, daemon=False, start_method pass -def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'): +def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"): r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``. If one of the processes exits with a non-zero exit status, the @@ -231,9 +236,11 @@ def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'): :class:`~ProcessContext` if ``join`` is ``False`` """ - if start_method != 'spawn': - msg = ('This method only supports start_method=spawn (got: %s).\n' - 'To use a different start_method use:\n\t\t' - ' torch.multiprocessing.start_processes(...)' % start_method) + if start_method != "spawn": + msg = ( + "This method only supports start_method=spawn (got: %s).\n" + "To use a different start_method use:\n\t\t" + " torch.multiprocessing.start_processes(...)" % start_method + ) warnings.warn(msg) - return start_processes(fn, args, nprocs, join, daemon, start_method='spawn') + return start_processes(fn, args, nprocs, join, daemon, start_method="spawn") diff --git a/torch/profiler/__init__.py b/torch/profiler/__init__.py index 35d71411aab1..e3c4145fd91f 100644 --- a/torch/profiler/__init__.py +++ b/torch/profiler/__init__.py @@ -11,7 +11,7 @@ import os from torch._C._autograd import _supported_activities, DeviceType, kineto_available from torch._C._profiler import _ExperimentalConfig, ProfilerActivity, RecordScope -from torch.autograd.profiler import record_function, KinetoStepTracker +from torch.autograd.profiler import KinetoStepTracker, record_function from torch.optim.optimizer import register_optimizer_step_post_hook from .profiler import ( @@ -39,8 +39,10 @@ __all__ = [ from . import itt + def _optimizer_post_hook(optimizer, args, kwargs): KinetoStepTracker.increment_step("Optimizer") + if os.environ.get("KINETO_USE_DAEMON", None): _ = register_optimizer_step_post_hook(_optimizer_post_hook) diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index 46f9f22dd09d..2974fde90f4a 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -16,6 +16,8 @@ from typing import ( Union, ) +from typing_extensions import Literal + import torch from torch._C import FunctionSchema from torch._C._autograd import _ProfilerResult @@ -30,13 +32,12 @@ from torch._C._profiler import ( from torch._utils import _element_size from torch.profiler import _utils -from typing_extensions import Literal - KeyAndID = Tuple["Key", int] TensorAndID = Tuple["TensorKey", int] log = logging.getLogger(__name__) + class Category(enum.Enum): INPUT = enum.auto() TEMPORARY = enum.auto() @@ -46,6 +47,7 @@ class Category(enum.Enum): PARAMETER = enum.auto() OPTIMIZER_STATE = enum.auto() + _CATEGORY_TO_COLORS = { Category.PARAMETER: "darkgreen", Category.OPTIMIZER_STATE: "goldenrod", @@ -59,14 +61,17 @@ _CATEGORY_TO_COLORS = { _CATEGORY_TO_INDEX = {c: i for i, c in enumerate(_CATEGORY_TO_COLORS)} + class Action(enum.Enum): PREEXISTING = enum.auto() CREATE = enum.auto() INCREMENT_VERSION = enum.auto() DESTROY = enum.auto() + _ACTION_TO_INDEX = {i: i.value for i in Action} + @dataclasses.dataclass(eq=True, unsafe_hash=False, frozen=True) class Key: device: torch.device @@ -690,14 +695,18 @@ class MemoryProfile: ptr_and_device = (alloc_fields.ptr, key.device) if is_allocation: if ptr_and_device in live_unknown: - output.append((t, Action.INCREMENT_VERSION, (key, 0), alloc_size)) + output.append( + (t, Action.INCREMENT_VERSION, (key, 0), alloc_size) + ) else: live_unknown[ptr_and_device] = True output.append((t, Action.CREATE, (key, 0), alloc_size)) else: output.append((t, Action.DESTROY, (key, 0), -alloc_size)) if not live_unknown.pop(ptr_and_device, False): - output.append((-1, Action.PREEXISTING, (key, 0), -alloc_size)) + output.append( + (-1, Action.PREEXISTING, (key, 0), -alloc_size) + ) snapshot = self._category_snapshot() last_version = dict(sorted(snapshot.keys())) @@ -971,6 +980,7 @@ class MemoryProfile: key, version, Category.AUTOGRAD_DETAIL ) + class MemoryProfileTimeline: def __init__(self, memory_profile): """The minimum representation of the memory profile timeline @@ -1046,7 +1056,8 @@ class MemoryProfileTimeline: times, sizes = self._coalesce_timeline(device) # TODO: Write a faster serialize (orjson not available in CI) import json - with open(path, 'w') as f: + + with open(path, "w") as f: json.dump([times, sizes], f) def export_memory_timeline_raw(self, path, device_str) -> None: @@ -1070,37 +1081,72 @@ class MemoryProfileTimeline: continue if action in (Action.PREEXISTING, Action.CREATE): - raw_events.append((t, _ACTION_TO_INDEX[action], numbytes, get_category_index(key, version))) + raw_events.append( + ( + t, + _ACTION_TO_INDEX[action], + numbytes, + get_category_index(key, version), + ) + ) elif action == Action.INCREMENT_VERSION: - raw_events.append((t, _ACTION_TO_INDEX[action], -numbytes, get_category_index(key, version))) - raw_events.append((t, _ACTION_TO_INDEX[action], numbytes, get_category_index(key, version + 1))) + raw_events.append( + ( + t, + _ACTION_TO_INDEX[action], + -numbytes, + get_category_index(key, version), + ) + ) + raw_events.append( + ( + t, + _ACTION_TO_INDEX[action], + numbytes, + get_category_index(key, version + 1), + ) + ) elif action == Action.DESTROY: - raw_events.append((t, _ACTION_TO_INDEX[action], -numbytes, get_category_index(key, version))) + raw_events.append( + ( + t, + _ACTION_TO_INDEX[action], + -numbytes, + get_category_index(key, version), + ) + ) else: raise ValueError(f"Unknown action: {action}") import json - with open(path, 'w') as f: + + with open(path, "w") as f: json.dump(raw_events, f) - def export_memory_timeline_html(self, path, device, figsize=(20, 12), title=None) -> None: + def export_memory_timeline_html( + self, path, device, figsize=(20, 12), title=None + ) -> None: """Exports the memory timeline as an HTML file which contains the memory timeline plot embedded as a PNG file.""" # Check if user has matplotlib installed, return gracefully if not. import importlib.util + matplotlib_spec = importlib.util.find_spec("matplotlib") if matplotlib_spec is None: - print("export_memory_timeline_html failed because matplotlib was not found.") + print( + "export_memory_timeline_html failed because matplotlib was not found." + ) return + from base64 import b64encode + from os import remove + from tempfile import NamedTemporaryFile + import matplotlib.pyplot as plt import numpy as np - from base64 import b64encode - from tempfile import NamedTemporaryFile - from os import remove mt = self._coalesce_timeline(device) times, sizes = np.array(mt[0]), np.array(mt[1]) @@ -1123,12 +1169,12 @@ class MemoryProfileTimeline: axes.set_title(title) # Embed the memory timeline image into the HTML file - tmpfile = NamedTemporaryFile('wb', suffix='.png', delete=False) + tmpfile = NamedTemporaryFile("wb", suffix=".png", delete=False) tmpfile.close() - fig.savefig(tmpfile.name, format='png') + fig.savefig(tmpfile.name, format="png") - with open(tmpfile.name, 'rb') as tmp: - encoded = b64encode(tmp.read()).decode('utf-8') + with open(tmpfile.name, "rb") as tmp: + encoded = b64encode(tmp.read()).decode("utf-8") html = f""" GPU Memory Timeline HTML @@ -1136,6 +1182,6 @@ class MemoryProfileTimeline: """ - with open(path, 'w') as f: + with open(path, "w") as f: f.write(html) remove(tmpfile.name) diff --git a/torch/profiler/_pattern_matcher.py b/torch/profiler/_pattern_matcher.py index 1d85d193ecf8..02e9b014d308 100644 --- a/torch/profiler/_pattern_matcher.py +++ b/torch/profiler/_pattern_matcher.py @@ -5,21 +5,25 @@ import re from typing import Dict, List, Optional, Set import torch -from torch.profiler import profile import torch.utils.benchmark as benchmark +from torch._C._profiler import ( + _EventType, + _ExtraFields_PyCall, + _ExtraFields_PyCCall, + _ExtraFields_TorchOp, + _ProfilerEvent, +) +from torch.profiler import profile from torch.profiler._utils import index_of_first_match, traverse_bfs, traverse_dfs -from torch._C._profiler import (_ProfilerEvent, _ExtraFields_TorchOp, - _ExtraFields_PyCCall, _ExtraFields_PyCall, - _EventType) class Pattern: - ''' + """ Base class for all patterns, subclass this class and implement match() to define custom patterns. In subclass, define description and skip property. - ''' + """ def __init__(self, prof: profile, should_benchmark: bool = False): self.prof = prof @@ -28,8 +32,7 @@ class Pattern: self.description = "Please specify a description for pattern" self.url = "" assert prof.profiler is not None and prof.profiler.kineto_results is not None - self.event_tree = prof.profiler.kineto_results.experimental_event_tree( - ) + self.event_tree = prof.profiler.kineto_results.experimental_event_tree() self.tid_root: Dict[int, List[_ProfilerEvent]] = {} for event in self.event_tree: self.tid_root.setdefault(event.start_tid, []).append(event) @@ -39,27 +42,30 @@ class Pattern: return False def report(self, event: _ProfilerEvent): - msg = f"{self.description}\n[Source Code Location] {source_code_location(event)}" + msg = ( + f"{self.description}\n[Source Code Location] {source_code_location(event)}" + ) return msg def eventTreeTraversal(self): - ''' + """ Traverse the event tree and yield all events. Override this method in subclass to customize the traversal. - ''' + """ yield from traverse_dfs(self.event_tree) def summary(self, events: List[_ProfilerEvent]): default_summary = f"{self.name}: {len(events)} events matched." if self.should_benchmark: # If benchmark summary is not empty, use it. - return self.benchmark_summary( - events) if hasattr( # type: ignore[attr-defined] - self, 'benchmark') else default_summary + return ( + self.benchmark_summary(events) + if hasattr(self, "benchmark") # type: ignore[attr-defined] + else default_summary + ) return default_summary def benchmark_summary(self, events: List[_ProfilerEvent]): - def format_time(time_ns: int): unit_lst = ["ns", "us", "ms"] for unit in unit_lst: @@ -68,22 +74,23 @@ class Pattern: time_ns //= 1000 return f"{time_ns:.2f} s" - assert hasattr(self, 'benchmark'), 'Please implement benchmark()' - shapes_factor_map = self.benchmark( # type: ignore[attr-defined] - events) + assert hasattr(self, "benchmark"), "Please implement benchmark()" + shapes_factor_map = self.benchmark(events) # type: ignore[attr-defined] original_time = sum(event.duration_time_ns for event in events) - new_time = sum(shapes_factor_map[input_shapes(event)] * - event.duration_time_ns for event in events) + new_time = sum( + shapes_factor_map[input_shapes(event)] * event.duration_time_ns + for event in events + ) return ( f"{self.name}: {len(events)} events matched. " f"Total Estimated Speedup: {format_time(original_time - new_time)} ({round(original_time/new_time, 2)}X)" ) def match(self, event: _ProfilerEvent): - ''' + """ Return True if the event matches the pattern. This method should be overriden in subclass. - ''' + """ raise NotImplementedError def matched_events(self): @@ -106,7 +113,7 @@ class Pattern: else: children = self.tid_root[event.start_tid] index = children.index(event) - return children[:index], children[index + 1:] + return children[:index], children[index + 1 :] def next_of(self, event: _ProfilerEvent): _, next_events = self.siblings_of(event) @@ -128,11 +135,7 @@ class Pattern: class NamePattern(Pattern): - - def __init__(self, - prof: profile, - name: str, - should_benchmark: bool = False): + def __init__(self, prof: profile, name: str, should_benchmark: bool = False): super().__init__(prof, should_benchmark) self.description = f"Matched Name Event: {name}" self.name = name @@ -142,7 +145,7 @@ class NamePattern(Pattern): class ExtraCUDACopyPattern(Pattern): - ''' + """ This pattern identifies if we creates a constant tensor on CPU and immediately moves it to GPU. example: torch.zeros((100, 100)).to("cuda") @@ -156,7 +159,7 @@ class ExtraCUDACopyPattern(Pattern): and check if we have a aten::fill_/aten::zero_ as we keep going down the tree. We always select the last child in the children list when we go down the tree. If at any step we failed, it is not a match. - ''' + """ def __init__(self, prof: profile, should_benchmark: bool = False): super().__init__(prof, should_benchmark) @@ -164,7 +167,10 @@ class ExtraCUDACopyPattern(Pattern): self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initialize it on GPU." self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#create-tensors-directly-on-the-target-device" self.init_ops = { - "aten::fill_", "aten::zero_", "aten::normal_", "aten::uniform_" + "aten::fill_", + "aten::zero_", + "aten::normal_", + "aten::uniform_", } @property @@ -213,10 +219,12 @@ class ExtraCUDACopyPattern(Pattern): shapes_factor_map = {input_shapes(event): 0.0 for event in events} for shape in shapes_factor_map: size = shape[0] - to_timer = benchmark.Timer(stmt='torch.ones(size).to("cuda")', - globals={'size': size}) - de_timer = benchmark.Timer(stmt='torch.ones(size, device="cuda")', - globals={'size': size}) + to_timer = benchmark.Timer( + stmt='torch.ones(size).to("cuda")', globals={"size": size} + ) + de_timer = benchmark.Timer( + stmt='torch.ones(size, device="cuda")', globals={"size": size} + ) to_time = to_timer.timeit(10).mean de_time = de_timer.timeit(10).mean shapes_factor_map[shape] = de_time / to_time @@ -224,7 +232,7 @@ class ExtraCUDACopyPattern(Pattern): class ForLoopIndexingPattern(Pattern): - ''' + """ This pattern identifies if we use a for loop to index a tensor that can be vectorized. example: @@ -238,7 +246,7 @@ class ForLoopIndexingPattern(Pattern): Algorithm: We start at node aten::select, and we check if we can find this alternating patterns. We also keep a dictionary to avoid duplicate match in the for loop. - ''' + """ def __init__(self, prof: profile, should_benchmark: bool = False): super().__init__(prof, should_benchmark) @@ -247,9 +255,9 @@ class ForLoopIndexingPattern(Pattern): self.visited: Set[int] = set() def eventTreeTraversal(self): - ''' + """ We need to use BFS traversal order to avoid duplicate match. - ''' + """ yield from traverse_bfs(self.event_tree) def match(self, event: _ProfilerEvent): @@ -272,14 +280,13 @@ class ForLoopIndexingPattern(Pattern): return True # Record the ops between two aten::select - next_select_idx = index_of_first_match( - next, lambda e: e.name == "aten::select") + next_select_idx = index_of_first_match(next, lambda e: e.name == "aten::select") if next_select_idx is None: return False indexing_ops = [event] + next[:next_select_idx] - next = next[len(indexing_ops) - 1:] + next = next[len(indexing_ops) - 1 :] for i in range(0, len(next), len(indexing_ops)): - if same_ops(indexing_ops, next[i:i + len(indexing_ops)]): + if same_ops(indexing_ops, next[i : i + len(indexing_ops)]): repeat_count += 1 self.visited.add(next[i].id) else: @@ -288,7 +295,6 @@ class ForLoopIndexingPattern(Pattern): class FP32MatMulPattern(Pattern): - def __init__(self, prof: profile, should_benchmark: bool = False): super().__init__(prof, should_benchmark) self.name = "FP32 MatMul Pattern" @@ -304,8 +310,7 @@ class FP32MatMulPattern(Pattern): has_tf32 = False else: # Anything less than sm_80 is not Ampere which doesn't support TF32 - has_tf32 = all( - int(arch[3:]) >= 80 for arch in torch.cuda.get_arch_list()) + has_tf32 = all(int(arch[3:]) >= 80 for arch in torch.cuda.get_arch_list()) return has_tf32 is False or super().skip or not self.prof.record_shapes def match(self, event: _ProfilerEvent): @@ -326,18 +331,15 @@ class FP32MatMulPattern(Pattern): for shape in shapes_factor_map: matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float32) matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float32) - fp32_timer = benchmark.Timer(stmt='torch.mm(matrixA, matrixB)', - globals={ - "matrixA": matrixA, - "matrixB": matrixB - }) + fp32_timer = benchmark.Timer( + stmt="torch.mm(matrixA, matrixB)", + globals={"matrixA": matrixA, "matrixB": matrixB}, + ) tf32_timer = benchmark.Timer( - stmt='torch.mm(matrixA, matrixB)', - setup='torch.backends.cuda.matmul.allow_tf32 = True', - globals={ - "matrixA": matrixA, - "matrixB": matrixB - }) + stmt="torch.mm(matrixA, matrixB)", + setup="torch.backends.cuda.matmul.allow_tf32 = True", + globals={"matrixA": matrixA, "matrixB": matrixB}, + ) torch.backends.cuda.matmul.allow_tf32 = False fp32_time = fp32_timer.timeit(10).mean tf32_time = tf32_timer.timeit(10).mean @@ -346,7 +348,7 @@ class FP32MatMulPattern(Pattern): class OptimizerSingleTensorPattern(Pattern): - ''' + """ This pattern identifies if we are using the single-tensor version of an optimizer. example: optimizer = torch.optim.SGD(model.parameters(), lr=0.1) @@ -358,7 +360,7 @@ class OptimizerSingleTensorPattern(Pattern): Algorithm: String match - ''' + """ def __init__(self, prof: profile, should_benchmark: bool = False): super().__init__(prof, should_benchmark) @@ -378,7 +380,7 @@ class OptimizerSingleTensorPattern(Pattern): class SynchronizedDataLoaderPattern(Pattern): - ''' + """ This pattern identifies if we are using num_workers=0 in DataLoader. example: torch.utils.data.DataLoader(dataset, batch_size=batch_size) @@ -393,7 +395,7 @@ class SynchronizedDataLoaderPattern(Pattern): If we don't see check_worker_number_rationality call in the dataloader __iter__, It is not an asynchronous dataloader. - ''' + """ def __init__(self, prof: profile, should_benchmark: bool = False): super().__init__(prof, should_benchmark) @@ -404,14 +406,14 @@ class SynchronizedDataLoaderPattern(Pattern): ) self.url = ( "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html" - "#enable-async-data-loading-and-augmentation") + "#enable-async-data-loading-and-augmentation" + ) def match(self, event: _ProfilerEvent): - def is_dataloader_function(name: str, function_name: str): return name.startswith( - os.path.join("torch", "utils", "data", - "dataloader.py")) and name.endswith(function_name) + os.path.join("torch", "utils", "data", "dataloader.py") + ) and name.endswith(function_name) # TODO: fixme! Due to lifetime issues of the function name, this field might # actually point to an already freed string when the even is a PyCall. @@ -431,13 +433,12 @@ class SynchronizedDataLoaderPattern(Pattern): if not event.children: return False event = event.children[0] - return not is_dataloader_function(event.name, - "check_worker_number_rationality") + return not is_dataloader_function(event.name, "check_worker_number_rationality") # TODO: We should also check if the loader is bottleneck. class GradNotSetToNonePattern(Pattern): - ''' + """ This pattern identifies if we are not setting grad to None in zero_grad. example: optimizer.zero_grad() @@ -453,17 +454,19 @@ class GradNotSetToNonePattern(Pattern): Algorithm: String match - ''' + """ def __init__(self, prof: profile, should_benchmark: bool = False): super().__init__(prof, should_benchmark) self.name = "Gradient Set To Zero Instead of None Pattern" self.description = ( "Detected gradient set to zero instead of None. " - "Please add 'set_to_none=True' when calling zero_grad().") + "Please add 'set_to_none=True' when calling zero_grad()." + ) self.url = ( "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html" - "#disable-gradient-calculation-for-validation-or-inference") + "#disable-gradient-calculation-for-validation-or-inference" + ) def match(self, event: _ProfilerEvent): if not event.name.endswith(": zero_grad"): @@ -472,14 +475,17 @@ class GradNotSetToNonePattern(Pattern): return False for sub_event in traverse_dfs(event.children): - if sub_event.name == "aten::zero_" and sub_event.parent.name != "aten::zeros": + if ( + sub_event.name == "aten::zero_" + and sub_event.parent.name != "aten::zeros" + ): return True # TODO: We should also check if the optimizer's numerical behavior will change. return False class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern): - ''' + """ This pattern identifies if we are enabling bias in Conv2d which is followed by BatchNorm2d. Bias doesn't do anything when followed by batchnorm. Pattern: @@ -489,7 +495,7 @@ class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern): The third argument is the bias Algorithm: String match - ''' + """ def __init__(self, prof: profile, should_benchmark: bool = False): super().__init__(prof, should_benchmark) @@ -497,7 +503,8 @@ class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern): self.description = "Detected bias enabled in Conv2d that is followed by BatchNorm2d. Please set 'bias=False' in Conv2d." self.url = ( "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html" - "#disable-bias-for-convolutions-directly-followed-by-a-batch-norm") + "#disable-bias-for-convolutions-directly-followed-by-a-batch-norm" + ) @property def skip(self): @@ -510,7 +517,8 @@ class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern): return False # This means bias=True event = self.go_up_until( - event, lambda e: e.name.startswith("nn.Module: Conv2d")) + event, lambda e: e.name.startswith("nn.Module: Conv2d") + ) if not event: return False event = self.next_of(event) @@ -520,7 +528,6 @@ class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern): class MatMulDimInFP16Pattern(Pattern): - def __init__(self, prof: profile, should_benchmark: bool = False): super().__init__(prof, should_benchmark) self.name = "Matrix Multiplication Dimension Not Aligned Pattern" @@ -532,22 +539,21 @@ class MatMulDimInFP16Pattern(Pattern): return not self.prof.with_stack or not self.prof.record_shapes def match(self, event: _ProfilerEvent): - def mutiple_of(shapes, multiple): - return all(dim % multiple == 0 for shape in shapes - for dim in shape[-2:]) + return all(dim % multiple == 0 for shape in shapes for dim in shape[-2:]) if event.name not in ("aten::mm", "aten::bmm", "aten::addmm"): return False if not input_dtypes(event): return False arg_dtype = input_dtypes(event)[0] - if arg_dtype in (torch.bfloat16, torch.half) and not mutiple_of(input_shapes(event), 8): + if arg_dtype in (torch.bfloat16, torch.half) and not mutiple_of( + input_shapes(event), 8 + ): return True return False def benchmark(self, events: List[_ProfilerEvent]): - def closest_multiple(shapes, multiple): return [multiple * math.ceil(shape / multiple) for shape in shapes] @@ -556,23 +562,19 @@ class MatMulDimInFP16Pattern(Pattern): matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float16) matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float16) not_aligned_dim_timer = benchmark.Timer( - stmt='torch.mm(matrixA, matrixB)', - globals={ - "matrixA": matrixA, - "matrixB": matrixB - }) - matrixA = torch.randn(closest_multiple(shape[0], 8), - device="cuda", - dtype=torch.float16) - matrixB = torch.randn(closest_multiple(shape[1], 8), - device="cuda", - dtype=torch.float16) + stmt="torch.mm(matrixA, matrixB)", + globals={"matrixA": matrixA, "matrixB": matrixB}, + ) + matrixA = torch.randn( + closest_multiple(shape[0], 8), device="cuda", dtype=torch.float16 + ) + matrixB = torch.randn( + closest_multiple(shape[1], 8), device="cuda", dtype=torch.float16 + ) aligned_dim_timer = benchmark.Timer( - stmt='torch.mm(matrixA, matrixB)', - globals={ - "matrixA": matrixA, - "matrixB": matrixB - }) + stmt="torch.mm(matrixA, matrixB)", + globals={"matrixA": matrixA, "matrixB": matrixB}, + ) not_aligned_dim_time = not_aligned_dim_timer.timeit(10).mean aligned_dim_time = aligned_dim_timer.timeit(10).mean shapes_factor_map[shape] = aligned_dim_time / not_aligned_dim_time @@ -582,9 +584,10 @@ class MatMulDimInFP16Pattern(Pattern): def source_code_location(event: Optional[_ProfilerEvent]): while event: if event.tag == _EventType.PyCall or event.tag == _EventType.PyCCall: - assert isinstance(event.extra_fields, (_ExtraFields_PyCall, _ExtraFields_PyCCall)) - if not event.extra_fields.caller.file_name.startswith("torch" + - os.sep): + assert isinstance( + event.extra_fields, (_ExtraFields_PyCall, _ExtraFields_PyCCall) + ) + if not event.extra_fields.caller.file_name.startswith("torch" + os.sep): return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}" event = event.parent return "No source code location found" @@ -600,10 +603,12 @@ def input_dtypes(event: _ProfilerEvent): return tuple(getattr(i, "dtype", None) for i in event.extra_fields.inputs) -def report_all_anti_patterns(prof, - should_benchmark: bool = False, - print_enable: bool = True, - json_report_dir: Optional[str] = None): +def report_all_anti_patterns( + prof, + should_benchmark: bool = False, + print_enable: bool = True, + json_report_dir: Optional[str] = None, +): report_dict: Dict = {} anti_patterns = [ ExtraCUDACopyPattern(prof, should_benchmark), @@ -613,7 +618,7 @@ def report_all_anti_patterns(prof, SynchronizedDataLoaderPattern(prof, should_benchmark), GradNotSetToNonePattern(prof, should_benchmark), Conv2dBiasFollowedByBatchNorm2dPattern(prof, should_benchmark), - MatMulDimInFP16Pattern(prof, should_benchmark) + MatMulDimInFP16Pattern(prof, should_benchmark), ] reported = set() summaries = [] @@ -631,16 +636,17 @@ def report_all_anti_patterns(prof, message_list.append(report_msg) reported.add(report_msg) src_location, line_no = source_code_location(event).split(":") - report_dict.setdefault(src_location, []).append({ - "line_number": int(line_no), - "name": anti_pattern.name, - "url": anti_pattern.url, - "message": anti_pattern.description, - }) + report_dict.setdefault(src_location, []).append( + { + "line_number": int(line_no), + "name": anti_pattern.name, + "url": anti_pattern.url, + "message": anti_pattern.description, + } + ) if json_report_dir is not None: - json_report_path = os.path.join(json_report_dir, - "torchtidy_report.json") + json_report_path = os.path.join(json_report_dir, "torchtidy_report.json") if os.path.exists(json_report_path): with open(json_report_path) as f: exisiting_report = json.load(f) diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index f8a1f53bf7db..cb9469e4c983 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -1,12 +1,13 @@ -from collections import deque -from dataclasses import dataclass import functools import re +from collections import deque +from dataclasses import dataclass from typing import Dict, List -from torch.profiler import DeviceType -from torch.autograd.profiler import profile from torch.autograd import _KinetoEvent +from torch.autograd.profiler import profile + +from torch.profiler import DeviceType def _traverse(tree, next_fn, children_fn=lambda x: x.children, reverse: bool = False): @@ -18,8 +19,11 @@ def _traverse(tree, next_fn, children_fn=lambda x: x.children, reverse: bool = F for child_event in order(children_fn(curr_event)): remaining.append(child_event) + traverse_dfs = functools.partial(_traverse, next_fn=lambda x: x.pop(), reverse=True) -traverse_bfs = functools.partial(_traverse, next_fn=lambda x: x.popleft(), reverse=False) +traverse_bfs = functools.partial( + _traverse, next_fn=lambda x: x.popleft(), reverse=False +) @dataclass @@ -44,7 +48,6 @@ class Interval: class EventKey: - def __init__(self, event): self.event = event @@ -69,7 +72,7 @@ class EventKey: overlap_time += overlap_end - overlap_start i, j = 0, 1 - while (j < len(intervals)): + while j < len(intervals): prev_interval = intervals[i] curr_interval = intervals[j] j += 1 @@ -91,23 +94,23 @@ class EventKey: class BasicEvaluation: - def __init__(self, prof: profile): self.profile = prof self.metrics: Dict[EventKey, EventMetrics] = {} self.compute_self_time() - self.event_keys = sorted((e for e in self.metrics.keys()), - key=lambda x: x.event.start_time_ns) + self.event_keys = sorted( + (e for e in self.metrics.keys()), key=lambda x: x.event.start_time_ns + ) self.events = [e.event for e in self.event_keys] self.cuda_events: List[_KinetoEvent] = [] self.queue_depth_list = self.compute_queue_depth() self.compute_idle_time() def compute_self_time(self): - ''' + """ Computes event's self time(total time - time in child ops). - ''' - assert (self.profile.kineto_results is not None) + """ + assert self.profile.kineto_results is not None stack = deque(self.profile.kineto_results.experimental_event_tree()) # standard iterating dfs @@ -117,21 +120,21 @@ class BasicEvaluation: for child_event in curr_event.children: self_time -= child_event.duration_time_ns stack.append(child_event) - assert EventKey( - curr_event - ) not in self.metrics, f"Duplicate id: {curr_event.id}, {curr_event.name}" - self.metrics[EventKey(curr_event)] = EventMetrics( - self_time_ns=self_time) - self.metrics[EventKey( - curr_event)].duration_time_ns = curr_event.duration_time_ns + assert ( + EventKey(curr_event) not in self.metrics + ), f"Duplicate id: {curr_event.id}, {curr_event.name}" + self.metrics[EventKey(curr_event)] = EventMetrics(self_time_ns=self_time) + self.metrics[ + EventKey(curr_event) + ].duration_time_ns = curr_event.duration_time_ns def compute_queue_depth(self): - ''' + """ Computes queue_depth at each event. This will calculate the queue depth data for All the events in the tree. This will return a list of Interval of queue depth data of cuda launch and kernels. - ''' - assert (self.profile.kineto_results is not None) + """ + assert self.profile.kineto_results is not None cuda_event_list = self.profile.kineto_results.events() def is_cuda_launch_kernel(e): @@ -144,22 +147,26 @@ class BasicEvaluation: cuda_launch_events = sorted( (e for e in cuda_event_list if is_cuda_launch_kernel(e)), - key=lambda x: x.start_us()) + key=lambda x: x.start_us(), + ) cuda_kernel_events = sorted( (e for e in cuda_event_list if is_cuda_kernel(e)), - key=lambda x: x.start_us()) + key=lambda x: x.start_us(), + ) - self.cuda_events = sorted(cuda_launch_events + cuda_kernel_events, - key=lambda x: x.start_us()) + self.cuda_events = sorted( + cuda_launch_events + cuda_kernel_events, key=lambda x: x.start_us() + ) kernel_mapping: Dict[_KinetoEvent, int] = {} last_mapped_kernel = 0 for cuda_launch_event in cuda_launch_events: index = index_of_first_match( cuda_kernel_events, - lambda x: x.linked_correlation_id( - ) == cuda_launch_event.linked_correlation_id(), - start=last_mapped_kernel) + lambda x: x.linked_correlation_id() + == cuda_launch_event.linked_correlation_id(), + start=last_mapped_kernel, + ) kernel_mapping[cuda_launch_event] = index last_mapped_kernel = index if index is not None else last_mapped_kernel @@ -183,42 +190,42 @@ class BasicEvaluation: start_time = event.start_us() * 1000 end_time = (event.start_us() + event.duration_us()) * 1000 # Find current spawned cuda kernel event - if event in kernel_mapping and kernel_mapping[ - event] is not None: + if event in kernel_mapping and kernel_mapping[event] is not None: spawned_kernel_index = kernel_mapping[event] elif hasattr(event, "start_time_ns"): start_time = event.start_time_ns # type: ignore[attr-defined] end_time = event.end_time_ns # type: ignore[attr-defined] - while (current_kernel_index < len(cuda_kernel_events) and - (cuda_kernel_events[current_kernel_index].start_us()) * 1000 - <= start_time): + while ( + current_kernel_index < len(cuda_kernel_events) + and (cuda_kernel_events[current_kernel_index].start_us()) * 1000 + <= start_time + ): current_kernel_index += 1 current_queue_depth = spawned_kernel_index - current_kernel_index + 1 current_queue_depth = max(current_queue_depth, 0) if hasattr(event, "start_us"): queue_depth_list.append( - Interval(start_time, end_time, current_queue_depth)) + Interval(start_time, end_time, current_queue_depth) + ) elif hasattr(event, "start_time_ns"): self.metrics[EventKey(event)].queue_depth = current_queue_depth return queue_depth_list def compute_idle_time(self): - ''' + """ Computes idle time of the profile. - ''' + """ # Based on queue_depth_list, we can calculate idle time for all the events idle = False idle_start = 0 idle_intervals: List[Interval] = [] if self.queue_depth_list and self.events: idle_intervals += [ - Interval(self.events[0].start_time_ns, - self.queue_depth_list[0].start), - Interval(self.queue_depth_list[-1].end, - self.events[-1].end_time_ns) + Interval(self.events[0].start_time_ns, self.queue_depth_list[0].start), + Interval(self.queue_depth_list[-1].end, self.events[-1].end_time_ns), ] for data_point in self.queue_depth_list: @@ -232,20 +239,22 @@ class BasicEvaluation: event_list = [e.event for e in self.metrics.keys()] for event in event_list: self.metrics[EventKey(event)].idle_time_ns = EventKey( - event).intervals_overlap(idle_intervals) + event + ).intervals_overlap(idle_intervals) def rank_events(self, length): - ''' + """ Filter and Rank the events based on some heuristics: 1) Events that are in the falling phase of the queue depth. 2) Events that have a high idle_time, self_time difference. Parameters: length: The number of events to return. - ''' + """ # Find the interval when qd is falling to 0 import torch + queue_depth_list = list(reversed(self.queue_depth_list)) qd_values = [e.queue_depth for e in queue_depth_list] @@ -253,7 +262,7 @@ class BasicEvaluation: top_threashold = 4 decrease_interval = [] i = 0 - while (i < len(qd_values)): + while i < len(qd_values): if qd_values[i] > bottom_threashold: i += 1 continue @@ -261,61 +270,67 @@ class BasicEvaluation: # Find next zero and if the max value between them exceeds # the threshold, then we have a falling interval next_minimum_idx = index_of_first_match( - qd_values, lambda x: x <= bottom_threashold, start=j) + qd_values, lambda x: x <= bottom_threashold, start=j + ) peak_idx = argmax(qd_values, start=j, end=next_minimum_idx) # if is a valid peak, we add to list and continue - if peak_idx is not None and qd_values[ - peak_idx] >= top_threashold: + if peak_idx is not None and qd_values[peak_idx] >= top_threashold: decrease_interval.append( - Interval(queue_depth_list[peak_idx].start, - queue_depth_list[i].start)) + Interval( + queue_depth_list[peak_idx].start, queue_depth_list[i].start + ) + ) i = next_minimum_idx if next_minimum_idx is not None else i break i += 1 # Filter out events that are not in the decrease interval event_list = [ - event for event in self.metrics.keys() + event + for event in self.metrics.keys() if event.intervals_overlap(decrease_interval) ] if event_list: self_time = torch.tensor( [self.metrics[event].self_time_ns for event in event_list], - dtype=torch.float32) - idle_time = torch.tensor([ - self.metrics[event].fraction_idle_time for event in event_list - ], dtype=torch.float32) - normalized_gain = (idle_time - - torch.mean(idle_time)) / torch.std(idle_time) - normalized_self = (self_time - - torch.mean(self_time)) / torch.std(self_time) + dtype=torch.float32, + ) + idle_time = torch.tensor( + [self.metrics[event].fraction_idle_time for event in event_list], + dtype=torch.float32, + ) + normalized_gain = (idle_time - torch.mean(idle_time)) / torch.std(idle_time) + normalized_self = (self_time - torch.mean(self_time)) / torch.std(self_time) heuristic_score_list = normalized_gain + 0.6 * normalized_self # Sort events by heuristic event_list = [ event - for _, event in sorted(zip(heuristic_score_list, event_list), - key=lambda x: x[0], - reverse=True) + for _, event in sorted( + zip(heuristic_score_list, event_list), + key=lambda x: x[0], + reverse=True, + ) ] event_list = event_list[:length] return event_list - def get_optimizable_events(self, - length: int = 1, - print_enable: bool = True): + def get_optimizable_events(self, length: int = 1, print_enable: bool = True): event_list = self.rank_events(length) if not print_enable: return event_list output = "Optimizable events:\n" if event_list else "No events to optimize\n" - output += "\n".join([ - f"""{'-'*80} + output += "\n".join( + [ + f"""{'-'*80} Event: {event} Source code location: {source_code_location(event.event)} Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}% -{'-'*80}""" for event in event_list - ]) +{'-'*80}""" + for event in event_list + ] + ) if print_enable: print(output) return event_list @@ -338,9 +353,9 @@ def argmax(seq, key=lambda x: x, start=0, end=None): def source_code_location(event): - while (event is not None): + while event is not None: match = re.search(r"\.py\(.*\)", event.name) - if (match is None): + if match is None: event = event.parent continue return event.name @@ -353,5 +368,6 @@ def source_code_location(event): # we stop supporting older CUDA versions. def _init_for_cuda_graphs(): from torch.autograd.profiler import profile + with profile(): pass diff --git a/torch/profiler/itt.py b/torch/profiler/itt.py index 7f4de54597fd..4d072957d6fe 100644 --- a/torch/profiler/itt.py +++ b/torch/profiler/itt.py @@ -3,10 +3,13 @@ from contextlib import contextmanager try: from torch._C import _itt except ImportError: + class _ITTStub: @staticmethod def _fail(*args, **kwargs): - raise RuntimeError("ITT functions not installed. Are you sure you have a ITT build?") + raise RuntimeError( + "ITT functions not installed. Are you sure you have a ITT build?" + ) @staticmethod def is_available(): @@ -19,7 +22,7 @@ except ImportError: _itt = _ITTStub() # type: ignore[assignment] -__all__ = ['is_available', 'range_push', 'range_pop', 'mark', 'range'] +__all__ = ["is_available", "range_push", "range_pop", "mark", "range"] def is_available(): diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index 9f21e0afc28e..9c70db0ec2fc 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -30,6 +30,7 @@ __all__ = [ ] PROFILER_STEP_NAME = "ProfilerStep" + def supported_activities(): """ Returns a set of supported profiler tracing activities. @@ -74,16 +75,18 @@ class _KinetoProfile: that may further prevent certain optimizations that depend on the reference count and introduce extra tensor copies. """ + def __init__( - self, - *, - activities: Optional[Iterable[ProfilerActivity]] = None, - record_shapes: bool = False, - profile_memory: bool = False, - with_stack: bool = False, - with_flops: bool = False, - with_modules: bool = False, - experimental_config: Optional[_ExperimentalConfig] = None): + self, + *, + activities: Optional[Iterable[ProfilerActivity]] = None, + record_shapes: bool = False, + profile_memory: bool = False, + with_stack: bool = False, + with_flops: bool = False, + with_modules: bool = False, + experimental_config: Optional[_ExperimentalConfig] = None, + ): self.activities = set(activities) if activities else supported_activities() self.record_shapes = record_shapes self.with_flops = with_flops @@ -137,15 +140,12 @@ class _KinetoProfile: self.add_metadata_json("distributedInfo", json.dumps(dist_info)) # FIXME: CUPTI Lazy Re-init and CUDA Graph crashes with CUDA 11. - is_cuda11_or_lower = ( - (torch.version.cuda is not None) - and ([int(x) for x in torch.version.cuda.split(".")] < [12, 0]) + is_cuda11_or_lower = (torch.version.cuda is not None) and ( + [int(x) for x in torch.version.cuda.split(".")] < [12, 0] ) - if ( - is_cuda11_or_lower - and hasattr(torch, '_inductor') - ): + if is_cuda11_or_lower and hasattr(torch, "_inductor"): import torch._inductor.config as inductor_config + if inductor_config.triton.cudagraphs: os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1" self.add_metadata_json("DISABLE_CUPTI_LAZY_REINIT", "1") @@ -159,12 +159,12 @@ class _KinetoProfile: Exports the collected trace in Chrome JSON format. """ assert self.profiler - if path.endswith('.gz'): - fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False) + if path.endswith(".gz"): + fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False) fp.close() retvalue = self.profiler.export_chrome_trace(fp.name) with open(fp.name) as fin: - with gzip.open(path, 'wt') as fout: + with gzip.open(path, "wt") as fout: fout.writelines(fin) os.remove(fp.name) return retvalue @@ -188,7 +188,9 @@ class _KinetoProfile: assert self.profiler return self.profiler.export_stacks(path, metric) - def key_averages(self, group_by_input_shape: bool = False, group_by_stack_n: int = 0): + def key_averages( + self, group_by_input_shape: bool = False, group_by_stack_n: int = 0 + ): """Averages events, grouping them by operator name and (optionally) input shapes and stack. @@ -212,7 +214,7 @@ class _KinetoProfile: Adds a user defined metadata with a string key and a string value into the trace file """ - wrapped_value = "\"" + value.replace('"', '\\"') + "\"" + wrapped_value = '"' + value.replace('"', '\\"') + '"' torch.autograd._add_metadata_json(key, wrapped_value) def add_metadata_json(self, key: str, value: str): @@ -224,13 +226,14 @@ class _KinetoProfile: def _get_distributed_info(self): import torch.distributed as dist + if not dist.is_available() or not dist.is_initialized(): return None return { "backend": dist.get_backend(), "rank": dist.get_rank(), - "world_size": dist.get_world_size() + "world_size": dist.get_world_size(), } def _memory_profile(self) -> MemoryProfile: @@ -261,17 +264,17 @@ class _KinetoProfile: # Depending on the file suffix, save the data as json.gz or json. # For html, we can embed the image into an HTML file. - if path.endswith('.html'): + if path.endswith(".html"): self.mem_tl.export_memory_timeline_html(path, device) - elif path.endswith('.gz'): - fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False) + elif path.endswith(".gz"): + fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False) fp.close() - if path.endswith('raw.json.gz'): + if path.endswith("raw.json.gz"): self.mem_tl.export_memory_timeline_raw(fp.name, device) else: self.mem_tl.export_memory_timeline(fp.name, device) with open(fp.name) as fin: - with gzip.open(path, 'wt') as fout: + with gzip.open(path, "wt") as fout: fout.writelines(fin) os.remove(fp.name) else: @@ -282,13 +285,16 @@ class ProfilerAction(Enum): """ Profiler actions that can be taken at the specified intervals """ + NONE = 0 WARMUP = 1 RECORD = 2 RECORD_AND_SAVE = 3 -def schedule(*, wait: int, warmup: int, active: int, repeat: int = 0, skip_first: int = 0) -> Callable: +def schedule( + *, wait: int, warmup: int, active: int, repeat: int = 0, skip_first: int = 0 +) -> Callable: """ Returns a callable that can be used as profiler ``schedule`` argument. The profiler will skip the first ``skip_first`` steps, then wait for ``wait`` steps, then do the warmup for the next ``warmup`` steps, @@ -296,6 +302,7 @@ def schedule(*, wait: int, warmup: int, active: int, repeat: int = 0, skip_first The optional number of cycles is specified with the ``repeat`` parameter, the zero value means that the cycles will continue until the profiling is finished. """ + def schedule_fn(step: int) -> ProfilerAction: assert step >= 0 if step < skip_first: @@ -311,10 +318,15 @@ def schedule(*, wait: int, warmup: int, active: int, repeat: int = 0, skip_first elif mod_step < wait + warmup: return ProfilerAction.WARMUP else: - return ProfilerAction.RECORD if mod_step < num_steps - 1 \ + return ( + ProfilerAction.RECORD + if mod_step < num_steps - 1 else ProfilerAction.RECORD_AND_SAVE - assert wait >= 0 and warmup >= 0 and active > 0 and \ - repeat >= 0 and skip_first >= 0, "Invalid profiler schedule arguments" + ) + + assert ( + wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0 + ), "Invalid profiler schedule arguments" if warmup == 0: warn("Profiler won't be using warmup, this can skew profiler results") return schedule_fn @@ -328,7 +340,9 @@ def _default_schedule_fn(_: int) -> ProfilerAction: return ProfilerAction.RECORD -def tensorboard_trace_handler(dir_name: str, worker_name: Optional[str] = None, use_gzip: bool = False): +def tensorboard_trace_handler( + dir_name: str, worker_name: Optional[str] = None, use_gzip: bool = False +): """ Outputs tracing files to directory of ``dir_name``, then that directory can be directly delivered to tensorboard as logdir. @@ -351,8 +365,9 @@ def tensorboard_trace_handler(dir_name: str, worker_name: Optional[str] = None, # Use nanosecond here to avoid naming clash when exporting the trace file_name = f"{worker_name}.{time.time_ns()}.pt.trace.json" if use_gzip: - file_name = file_name + '.gz' + file_name = file_name + ".gz" prof.export_chrome_trace(os.path.join(dir_name, file_name)) + return handler_fn @@ -466,21 +481,22 @@ class profile(_KinetoProfile): # send a signal to the profiler that the next iteration has started p.step() """ - def __init__( - self, - *, - activities: Optional[Iterable[ProfilerActivity]] = None, - schedule: Optional[Callable[[int], ProfilerAction]] = None, - on_trace_ready: Optional[Callable[..., Any]] = None, - record_shapes: bool = False, - profile_memory: bool = False, - with_stack: bool = False, - with_flops: bool = False, - with_modules: bool = False, - experimental_config: Optional[_ExperimentalConfig] = None, - # deprecated: - use_cuda: Optional[bool] = None): + def __init__( + self, + *, + activities: Optional[Iterable[ProfilerActivity]] = None, + schedule: Optional[Callable[[int], ProfilerAction]] = None, + on_trace_ready: Optional[Callable[..., Any]] = None, + record_shapes: bool = False, + profile_memory: bool = False, + with_stack: bool = False, + with_flops: bool = False, + with_modules: bool = False, + experimental_config: Optional[_ExperimentalConfig] = None, + # deprecated: + use_cuda: Optional[bool] = None, + ): activities_set = set(activities) if activities else supported_activities() if use_cuda is not None: warn("use_cuda is deprecated, use activities argument instead") @@ -512,43 +528,66 @@ class profile(_KinetoProfile): self.current_action = self.schedule(self.step_num) self.step_rec_fn: Optional[prof.record_function] = None - self.action_map: Dict[Tuple[ProfilerAction, Optional[ProfilerAction]], List[Any]] = { + self.action_map: Dict[ + Tuple[ProfilerAction, Optional[ProfilerAction]], List[Any] + ] = { # key is (prev_action, current_action), value is action list corresponding to the state pair. (ProfilerAction.NONE, ProfilerAction.NONE): [], (ProfilerAction.NONE, ProfilerAction.WARMUP): [self.prepare_trace], - (ProfilerAction.NONE, ProfilerAction.RECORD): [self.prepare_trace, self.start_trace], - (ProfilerAction.NONE, ProfilerAction.RECORD_AND_SAVE): [self.prepare_trace, self.start_trace], + (ProfilerAction.NONE, ProfilerAction.RECORD): [ + self.prepare_trace, + self.start_trace, + ], + (ProfilerAction.NONE, ProfilerAction.RECORD_AND_SAVE): [ + self.prepare_trace, + self.start_trace, + ], (ProfilerAction.WARMUP, ProfilerAction.NONE): [ partial(warn, "Incorrect schedule: WARMUP followed by NONE"), self.start_trace, - self.stop_trace], + self.stop_trace, + ], (ProfilerAction.WARMUP, ProfilerAction.WARMUP): [], (ProfilerAction.WARMUP, ProfilerAction.RECORD): [self.start_trace], (ProfilerAction.WARMUP, ProfilerAction.RECORD_AND_SAVE): [self.start_trace], (ProfilerAction.RECORD, ProfilerAction.NONE): [ partial(warn, "Incorrect schedule: RECORD followed by NONE"), - self.stop_trace], + self.stop_trace, + ], (ProfilerAction.RECORD, ProfilerAction.WARMUP): [ partial(warn, "Incorrect schedule: RECORD followed by WARMUP"), - self.stop_trace], + self.stop_trace, + ], (ProfilerAction.RECORD, ProfilerAction.RECORD): [], (ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE): [], - (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.NONE): [self.stop_trace, self._trace_ready], - (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.WARMUP): [self.stop_trace, self._trace_ready, self.prepare_trace], + (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.NONE): [ + self.stop_trace, + self._trace_ready, + ], + (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.WARMUP): [ + self.stop_trace, + self._trace_ready, + self.prepare_trace, + ], (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD): [ self.stop_trace, self._trace_ready, self.prepare_trace, - self.start_trace], + self.start_trace, + ], (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD_AND_SAVE): [ self.stop_trace, self._trace_ready, self.prepare_trace, - self.start_trace], + self.start_trace, + ], # used for exit action (ProfilerAction.WARMUP, None): [self.start_trace, self.stop_trace], (ProfilerAction.RECORD, None): [self.stop_trace, self._trace_ready], - (ProfilerAction.RECORD_AND_SAVE, None): [self.stop_trace, self._trace_ready] + (ProfilerAction.RECORD_AND_SAVE, None): [ + self.stop_trace, + self._trace_ready, + ], } # Start tracking increments to profiler step, this will be used # by Kineto @@ -567,7 +606,9 @@ class profile(_KinetoProfile): def start(self): self._transit_action(ProfilerAction.NONE, self.current_action) if self.record_steps: - self.step_rec_fn = prof.record_function("ProfilerStep#" + str(self.step_num)) + self.step_rec_fn = prof.record_function( + "ProfilerStep#" + str(self.step_num) + ) self.step_rec_fn.__enter__() def stop(self): @@ -604,7 +645,6 @@ class profile(_KinetoProfile): action() - class ExecutionTraceObserver: """Execution Trace Observer @@ -619,6 +659,7 @@ class ExecutionTraceObserver: record function callbacks, finalize the output file, and will stop incurring any overheads. """ + def __init__(self): """ Initializes the default states. diff --git a/torch/profiler/python_tracer.py b/torch/profiler/python_tracer.py index f803b64c95fb..b3e624911f95 100644 --- a/torch/profiler/python_tracer.py +++ b/torch/profiler/python_tracer.py @@ -8,11 +8,11 @@ import torch def _prefix_regex() -> typing.List[str]: raw_paths = ( - site.getsitepackages() + - sys.path + - [site.getuserbase()] + - [site.getusersitepackages()] + - [os.path.dirname(os.path.dirname(torch.__file__))] + site.getsitepackages() + + sys.path + + [site.getuserbase()] + + [site.getusersitepackages()] + + [os.path.dirname(os.path.dirname(torch.__file__))] ) path_prefixes = sorted({os.path.abspath(i) for i in raw_paths}, reverse=True) diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py index 806086a78a18..fd83d88a3e3e 100644 --- a/torch/quantization/__init__.py +++ b/torch/quantization/__init__.py @@ -6,10 +6,12 @@ from .fuse_modules import fuse_modules from .stubs import * # noqa: F403 from .quant_type import * # noqa: F403 from .quantize_jit import * # noqa: F403 + # from .quantize_fx import * from .quantization_mappings import * # noqa: F403 from .fuser_method_mappings import * # noqa: F403 + def default_eval_fn(model, calib_data): r""" Default evaluation function takes a torch.utils.data.Dataset or a list of @@ -18,45 +20,68 @@ def default_eval_fn(model, calib_data): for data, target in calib_data: model(data) + __all__ = [ - 'QuantWrapper', 'QuantStub', 'DeQuantStub', + "QuantWrapper", + "QuantStub", + "DeQuantStub", # Top level API for eager mode quantization - 'quantize', 'quantize_dynamic', 'quantize_qat', - 'prepare', 'convert', 'prepare_qat', + "quantize", + "quantize_dynamic", + "quantize_qat", + "prepare", + "convert", + "prepare_qat", # Top level API for graph mode quantization on TorchScript - 'quantize_jit', 'quantize_dynamic_jit', '_prepare_ondevice_dynamic_jit', - '_convert_ondevice_dynamic_jit', '_quantize_ondevice_dynamic_jit', + "quantize_jit", + "quantize_dynamic_jit", + "_prepare_ondevice_dynamic_jit", + "_convert_ondevice_dynamic_jit", + "_quantize_ondevice_dynamic_jit", # Top level API for graph mode quantization on GraphModule(torch.fx) # 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx # 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx', - 'QuantType', # quantization type + "QuantType", # quantization type # custom module APIs - 'get_default_static_quant_module_mappings', 'get_static_quant_module_class', - 'get_default_dynamic_quant_module_mappings', - 'get_default_qat_module_mappings', - 'get_default_qconfig_propagation_list', - 'get_default_compare_output_module_list', - 'get_quantized_operator', - 'get_fuser_method', + "get_default_static_quant_module_mappings", + "get_static_quant_module_class", + "get_default_dynamic_quant_module_mappings", + "get_default_qat_module_mappings", + "get_default_qconfig_propagation_list", + "get_default_compare_output_module_list", + "get_quantized_operator", + "get_fuser_method", # Sub functions for `prepare` and `swap_module` - 'propagate_qconfig_', 'add_quant_dequant', 'swap_module', - 'default_eval_fn', + "propagate_qconfig_", + "add_quant_dequant", + "swap_module", + "default_eval_fn", # Observers - 'ObserverBase', 'WeightObserver', 'HistogramObserver', - 'observer', 'default_observer', - 'default_weight_observer', 'default_placeholder_observer', - 'default_per_channel_weight_observer', + "ObserverBase", + "WeightObserver", + "HistogramObserver", + "observer", + "default_observer", + "default_weight_observer", + "default_placeholder_observer", + "default_per_channel_weight_observer", # FakeQuantize (for qat) - 'default_fake_quant', 'default_weight_fake_quant', - 'default_fixed_qparams_range_neg1to1_fake_quant', - 'default_fixed_qparams_range_0to1_fake_quant', - 'default_per_channel_weight_fake_quant', - 'default_histogram_fake_quant', + "default_fake_quant", + "default_weight_fake_quant", + "default_fixed_qparams_range_neg1to1_fake_quant", + "default_fixed_qparams_range_0to1_fake_quant", + "default_per_channel_weight_fake_quant", + "default_histogram_fake_quant", # QConfig - 'QConfig', 'default_qconfig', 'default_dynamic_qconfig', 'float16_dynamic_qconfig', - 'float_qparams_weight_only_qconfig', + "QConfig", + "default_qconfig", + "default_dynamic_qconfig", + "float16_dynamic_qconfig", + "float_qparams_weight_only_qconfig", # QAT utilities - 'default_qat_qconfig', 'prepare_qat', 'quantize_qat', + "default_qat_qconfig", + "prepare_qat", + "quantize_qat", # module transformations - 'fuse_modules', + "fuse_modules", ] diff --git a/torch/quantization/_numeric_suite.py b/torch/quantization/_numeric_suite.py index c5a7848f7b0f..49ccc8e69523 100644 --- a/torch/quantization/_numeric_suite.py +++ b/torch/quantization/_numeric_suite.py @@ -8,21 +8,21 @@ here. """ from torch.ao.ns._numeric_suite import ( - NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST, - _find_match, - compare_weights, - _get_logger_dict_helper, - get_logger_dict, - Logger, - ShadowLogger, - OutputLogger, _convert_tuple_to_list, _dequantize_tensor_list, - Shadow, - prepare_model_with_stubs, + _find_match, + _get_logger_dict_helper, _is_identical_module_type, - compare_model_stub, - get_matching_activations, - prepare_model_outputs, compare_model_outputs, + compare_model_stub, + compare_weights, + get_logger_dict, + get_matching_activations, + Logger, + NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST, + OutputLogger, + prepare_model_outputs, + prepare_model_with_stubs, + Shadow, + ShadowLogger, ) diff --git a/torch/quantization/_numeric_suite_fx.py b/torch/quantization/_numeric_suite_fx.py index 991b847a34a7..55cd7085740d 100644 --- a/torch/quantization/_numeric_suite_fx.py +++ b/torch/quantization/_numeric_suite_fx.py @@ -8,19 +8,19 @@ here. """ from torch.ao.ns._numeric_suite_fx import ( - RNNReturnType, - OutputLogger, - NSTracer, - _extract_weights_one_model, - _extract_weights_impl, - extract_weights, - _add_loggers_one_model, _add_loggers_impl, - add_loggers, - _extract_logger_info_one_model, - extract_logger_info, + _add_loggers_one_model, _add_shadow_loggers_impl, + _extract_logger_info_one_model, + _extract_weights_impl, + _extract_weights_one_model, + add_loggers, add_shadow_loggers, - extract_shadow_logger_info, extend_logger_results_with_comparison, + extract_logger_info, + extract_shadow_logger_info, + extract_weights, + NSTracer, + OutputLogger, + RNNReturnType, ) diff --git a/torch/quantization/fake_quantize.py b/torch/quantization/fake_quantize.py index e7da7a485ebb..69a5d730bfb6 100644 --- a/torch/quantization/fake_quantize.py +++ b/torch/quantization/fake_quantize.py @@ -8,25 +8,25 @@ here. """ from torch.ao.quantization.fake_quantize import ( + _is_fake_quant_script_module, _is_per_channel, _is_per_tensor, _is_symmetric_quant, - FakeQuantizeBase, + default_fake_quant, + default_fixed_qparams_range_0to1_fake_quant, + default_fixed_qparams_range_neg1to1_fake_quant, + default_fused_act_fake_quant, + default_fused_per_channel_wt_fake_quant, + default_fused_wt_fake_quant, + default_histogram_fake_quant, + default_per_channel_weight_fake_quant, + default_weight_fake_quant, + disable_fake_quant, + disable_observer, + enable_fake_quant, + enable_observer, FakeQuantize, + FakeQuantizeBase, FixedQParamsFakeQuantize, FusedMovingAvgObsFakeQuantize, - default_fake_quant, - default_weight_fake_quant, - default_fixed_qparams_range_neg1to1_fake_quant, - default_fixed_qparams_range_0to1_fake_quant, - default_per_channel_weight_fake_quant, - default_histogram_fake_quant, - default_fused_act_fake_quant, - default_fused_wt_fake_quant, - default_fused_per_channel_wt_fake_quant, - _is_fake_quant_script_module, - disable_fake_quant, - enable_fake_quant, - disable_observer, - enable_observer, ) diff --git a/torch/quantization/fuse_modules.py b/torch/quantization/fuse_modules.py index 55bd8363524b..6b704fa8094e 100644 --- a/torch/quantization/fuse_modules.py +++ b/torch/quantization/fuse_modules.py @@ -7,18 +7,16 @@ If you are adding a new entry/functionality, please, add it to the here. """ -from torch.ao.quantization.fuse_modules import fuse_modules -from torch.ao.quantization.fuse_modules import fuse_known_modules -from torch.ao.quantization.fuse_modules import get_fuser_method - -# for backward compatiblity -from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn -from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn_relu - # TODO: These functions are not used outside the `fuse_modules.py` # Keeping here for now, need to remove them later. from torch.ao.quantization.fuse_modules import ( _fuse_modules, _get_module, _set_module, + fuse_known_modules, + fuse_modules, + get_fuser_method, ) + +# for backward compatiblity +from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn, fuse_conv_bn_relu diff --git a/torch/quantization/fuser_method_mappings.py b/torch/quantization/fuser_method_mappings.py index 22f4e638ea69..cfb13ac96271 100644 --- a/torch/quantization/fuser_method_mappings.py +++ b/torch/quantization/fuser_method_mappings.py @@ -7,9 +7,9 @@ If you are adding a new entry/functionality, please, add it to the here. """ from torch.ao.quantization.fuser_method_mappings import ( + _DEFAULT_OP_LIST_TO_FUSER_METHOD, fuse_conv_bn, fuse_conv_bn_relu, fuse_linear_bn, - _DEFAULT_OP_LIST_TO_FUSER_METHOD, get_fuser_method, ) diff --git a/torch/quantization/fx/__init__.py b/torch/quantization/fx/__init__.py index c1c1effbb281..c01cbd457374 100644 --- a/torch/quantization/fx/__init__.py +++ b/torch/quantization/fx/__init__.py @@ -7,8 +7,9 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import stat here. """ +from torch.ao.quantization.fx.convert import convert +from torch.ao.quantization.fx.fuse import fuse + # omitting files that's unlikely to be used right now, for example # the newly added lower_to_fbgemm etc. from torch.ao.quantization.fx.prepare import prepare -from torch.ao.quantization.fx.convert import convert -from torch.ao.quantization.fx.fuse import fuse diff --git a/torch/quantization/fx/_equalize.py b/torch/quantization/fx/_equalize.py index 4cdd9a9adf1b..7acea4f84a2a 100644 --- a/torch/quantization/fx/_equalize.py +++ b/torch/quantization/fx/_equalize.py @@ -7,32 +7,32 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import stat here. """ from torch.ao.quantization.fx._equalize import ( - reshape_scale, + _convert_equalization_ref, _InputEqualizationObserver, _WeightEqualizationObserver, calculate_equalization_scale, - EqualizationQConfig, - input_equalization_observer, - weight_equalization_observer, - default_equalization_qconfig, - fused_module_supports_equalization, - nn_module_supports_equalization, - custom_module_supports_equalization, - node_supports_equalization, - is_equalization_observer, - get_op_node_and_weight_eq_obs, - maybe_get_weight_eq_obs_node, - maybe_get_next_input_eq_obs, - maybe_get_next_equalization_scale, - scale_input_observer, - scale_weight_node, - scale_weight_functional, clear_weight_quant_obs_node, - remove_node, - update_obs_for_equalization, convert_eq_obs, - _convert_equalization_ref, - get_layer_sqnr_dict, - get_equalization_qconfig_dict, CUSTOM_MODULE_SUPP_LIST, + custom_module_supports_equalization, + default_equalization_qconfig, + EqualizationQConfig, + fused_module_supports_equalization, + get_equalization_qconfig_dict, + get_layer_sqnr_dict, + get_op_node_and_weight_eq_obs, + input_equalization_observer, + is_equalization_observer, + maybe_get_next_equalization_scale, + maybe_get_next_input_eq_obs, + maybe_get_weight_eq_obs_node, + nn_module_supports_equalization, + node_supports_equalization, + remove_node, + reshape_scale, + scale_input_observer, + scale_weight_functional, + scale_weight_node, + update_obs_for_equalization, + weight_equalization_observer, ) diff --git a/torch/quantization/fx/fusion_patterns.py b/torch/quantization/fx/fusion_patterns.py index 36c74bd27722..e29337b3f861 100644 --- a/torch/quantization/fx/fusion_patterns.py +++ b/torch/quantization/fx/fusion_patterns.py @@ -6,7 +6,4 @@ If you are adding a new entry/functionality, please, add it to the appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ -from torch.ao.quantization.fx.fuse_handler import ( - FuseHandler, - DefaultFuseHandler, -) +from torch.ao.quantization.fx.fuse_handler import DefaultFuseHandler, FuseHandler diff --git a/torch/quantization/fx/graph_module.py b/torch/quantization/fx/graph_module.py index 08c90c2165f6..a71e980a57ba 100644 --- a/torch/quantization/fx/graph_module.py +++ b/torch/quantization/fx/graph_module.py @@ -7,11 +7,11 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import stat here. """ from torch.ao.quantization.fx.graph_module import ( - GraphModule, - FusedGraphModule, - ObservedGraphModule, _is_observed_module, - ObservedStandaloneGraphModule, _is_observed_standalone_module, - QuantizedGraphModule + FusedGraphModule, + GraphModule, + ObservedGraphModule, + ObservedStandaloneGraphModule, + QuantizedGraphModule, ) diff --git a/torch/quantization/fx/match_utils.py b/torch/quantization/fx/match_utils.py index d39c141e1ee8..8b49f7c645d8 100644 --- a/torch/quantization/fx/match_utils.py +++ b/torch/quantization/fx/match_utils.py @@ -7,8 +7,8 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import stat here. """ from torch.ao.quantization.fx.match_utils import ( + _find_matches, + _is_match, _MatchResult, MatchAllNode, - _is_match, - _find_matches ) diff --git a/torch/quantization/fx/pattern_utils.py b/torch/quantization/fx/pattern_utils.py index d528f42a4937..26954833bb48 100644 --- a/torch/quantization/fx/pattern_utils.py +++ b/torch/quantization/fx/pattern_utils.py @@ -7,12 +7,12 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import stat here. """ from torch.ao.quantization.fx.pattern_utils import ( - QuantizeHandler, _register_fusion_pattern, - get_default_fusion_patterns, _register_quant_pattern, + get_default_fusion_patterns, + get_default_output_activation_post_process_map, get_default_quant_patterns, - get_default_output_activation_post_process_map + QuantizeHandler, ) # QuantizeHandler.__module__ = _NAMESPACE @@ -20,7 +20,9 @@ _register_fusion_pattern.__module__ = "torch.ao.quantization.fx.pattern_utils" get_default_fusion_patterns.__module__ = "torch.ao.quantization.fx.pattern_utils" _register_quant_pattern.__module__ = "torch.ao.quantization.fx.pattern_utils" get_default_quant_patterns.__module__ = "torch.ao.quantization.fx.pattern_utils" -get_default_output_activation_post_process_map.__module__ = "torch.ao.quantization.fx.pattern_utils" +get_default_output_activation_post_process_map.__module__ = ( + "torch.ao.quantization.fx.pattern_utils" +) # __all__ = [ # "QuantizeHandler", diff --git a/torch/quantization/fx/prepare.py b/torch/quantization/fx/prepare.py index 770f07751e24..ca65dcc04dd0 100644 --- a/torch/quantization/fx/prepare.py +++ b/torch/quantization/fx/prepare.py @@ -6,6 +6,4 @@ If you are adding a new entry/functionality, please, add it to the appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ -from torch.ao.quantization.fx.prepare import ( - prepare -) +from torch.ao.quantization.fx.prepare import prepare diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 50bfa0bfbe8e..34ee88a4713c 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -7,20 +7,20 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import stat here. """ from torch.ao.quantization.fx.quantize_handler import ( - QuantizeHandler, + BatchNormQuantizeHandler, BinaryOpQuantizeHandler, CatQuantizeHandler, ConvReluQuantizeHandler, - LinearReLUQuantizeHandler, - BatchNormQuantizeHandler, - EmbeddingQuantizeHandler, - RNNDynamicQuantizeHandler, - DefaultNodeQuantizeHandler, - FixedQParamsOpQuantizeHandler, CopyNodeQuantizeHandler, CustomModuleQuantizeHandler, + DefaultNodeQuantizeHandler, + EmbeddingQuantizeHandler, + FixedQParamsOpQuantizeHandler, GeneralTensorShapeOpQuantizeHandler, - StandaloneModuleQuantizeHandler + LinearReLUQuantizeHandler, + QuantizeHandler, + RNNDynamicQuantizeHandler, + StandaloneModuleQuantizeHandler, ) QuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" @@ -32,8 +32,16 @@ BatchNormQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_pat EmbeddingQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" RNNDynamicQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" DefaultNodeQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" -FixedQParamsOpQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +FixedQParamsOpQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) CopyNodeQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" -CustomModuleQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" -GeneralTensorShapeOpQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" -StandaloneModuleQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +CustomModuleQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) +GeneralTensorShapeOpQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) +StandaloneModuleQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) diff --git a/torch/quantization/fx/quantization_types.py b/torch/quantization/fx/quantization_types.py index f31cdf5ba1c8..a422cdd3142e 100644 --- a/torch/quantization/fx/quantization_types.py +++ b/torch/quantization/fx/quantization_types.py @@ -6,7 +6,4 @@ If you are adding a new entry/functionality, please, add it to the appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ -from torch.ao.quantization.utils import ( - Pattern, - QuantizerCls -) +from torch.ao.quantization.utils import Pattern, QuantizerCls diff --git a/torch/quantization/fx/utils.py b/torch/quantization/fx/utils.py index 96f4f68c592b..ef35559884b7 100644 --- a/torch/quantization/fx/utils.py +++ b/torch/quantization/fx/utils.py @@ -7,14 +7,14 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import stat here. """ from torch.ao.quantization.fx.utils import ( - get_custom_module_class_keys, - get_linear_prepack_op_for_dtype, - get_qconv_prepack_op, - get_new_attr_name_with_prefix, - graph_module_from_producer_nodes, + all_node_args_have_no_tensors, assert_and_get_unique_device, create_getattr_from_value, - all_node_args_have_no_tensors, + get_custom_module_class_keys, + get_linear_prepack_op_for_dtype, + get_new_attr_name_with_prefix, get_non_observable_arg_indexes_and_types, - maybe_get_next_module + get_qconv_prepack_op, + graph_module_from_producer_nodes, + maybe_get_next_module, ) diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py index 50a6894f99a1..6e6c7c1917c8 100644 --- a/torch/quantization/observer.py +++ b/torch/quantization/observer.py @@ -7,30 +7,30 @@ If you are adding a new entry/functionality, please, add it to the here. """ from torch.ao.quantization.observer import ( + _is_activation_post_process, + _is_per_channel_script_obs_instance, + _ObserverBase, _PartialWrapper, _with_args, _with_callable_args, ABC, - ObserverBase, - _ObserverBase, - MinMaxObserver, - MovingAverageMinMaxObserver, - PerChannelMinMaxObserver, - MovingAveragePerChannelMinMaxObserver, - HistogramObserver, - PlaceholderObserver, - RecordingObserver, - NoopObserver, - _is_activation_post_process, - _is_per_channel_script_obs_instance, - get_observer_state_dict, - load_observer_state_dict, - default_observer, - default_placeholder_observer, default_debug_observer, - default_weight_observer, - default_histogram_observer, - default_per_channel_weight_observer, default_dynamic_quant_observer, default_float_qparams_observer, + default_histogram_observer, + default_observer, + default_per_channel_weight_observer, + default_placeholder_observer, + default_weight_observer, + get_observer_state_dict, + HistogramObserver, + load_observer_state_dict, + MinMaxObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + NoopObserver, + ObserverBase, + PerChannelMinMaxObserver, + PlaceholderObserver, + RecordingObserver, ) diff --git a/torch/quantization/qconfig.py b/torch/quantization/qconfig.py index 9da450abd67b..6bb7e14110cb 100644 --- a/torch/quantization/qconfig.py +++ b/torch/quantization/qconfig.py @@ -7,24 +7,24 @@ If you are adding a new entry/functionality, please, add it to the here. """ from torch.ao.quantization.qconfig import ( - QConfig, - default_qconfig, + _add_module_to_qconfig_obs_ctr, + _assert_valid_qconfig, + default_activation_only_qconfig, default_debug_qconfig, - default_per_channel_qconfig, - QConfigDynamic, default_dynamic_qconfig, + default_per_channel_qconfig, + default_qat_qconfig, + default_qat_qconfig_v2, + default_qconfig, + default_weight_only_qconfig, float16_dynamic_qconfig, float16_static_qconfig, - per_channel_dynamic_qconfig, float_qparams_weight_only_qconfig, - default_qat_qconfig, - default_weight_only_qconfig, - default_activation_only_qconfig, - default_qat_qconfig_v2, - get_default_qconfig, get_default_qat_qconfig, - _assert_valid_qconfig, + get_default_qconfig, + per_channel_dynamic_qconfig, + QConfig, + qconfig_equals, QConfigAny, - _add_module_to_qconfig_obs_ctr, - qconfig_equals + QConfigDynamic, ) diff --git a/torch/quantization/quant_type.py b/torch/quantization/quant_type.py index c7f7cc15dbdd..8555f0379266 100644 --- a/torch/quantization/quant_type.py +++ b/torch/quantization/quant_type.py @@ -7,5 +7,4 @@ If you are adding a new entry/functionality, please, add it to the here. """ -from torch.ao.quantization.quant_type import QuantType -from torch.ao.quantization.quant_type import _get_quant_type_to_str +from torch.ao.quantization.quant_type import _get_quant_type_to_str, QuantType diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py index e1a59f88c1e7..8b44a980ce82 100644 --- a/torch/quantization/quantization_mappings.py +++ b/torch/quantization/quantization_mappings.py @@ -7,23 +7,23 @@ If you are adding a new entry/functionality, please, add it to the here. """ from torch.ao.quantization.quantization_mappings import ( - DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS, - DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, - DEFAULT_QAT_MODULE_MAPPINGS, - DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, - _INCLUDE_QCONFIG_PROPAGATE_LIST, - DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS, - DEFAULT_MODULE_TO_ACT_POST_PROCESS, - no_observer_set, - get_default_static_quant_module_mappings, - get_static_quant_module_class, - get_dynamic_quant_module_class, - get_default_qat_module_mappings, - get_default_dynamic_quant_module_mappings, - get_default_qconfig_propagation_list, - get_default_compare_output_module_list, - get_default_float_to_quantized_operator_mappings, - get_quantized_operator, _get_special_act_post_process, _has_special_act_post_process, + _INCLUDE_QCONFIG_PROPAGATE_LIST, + DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, + DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS, + DEFAULT_MODULE_TO_ACT_POST_PROCESS, + DEFAULT_QAT_MODULE_MAPPINGS, + DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS, + DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, + get_default_compare_output_module_list, + get_default_dynamic_quant_module_mappings, + get_default_float_to_quantized_operator_mappings, + get_default_qat_module_mappings, + get_default_qconfig_propagation_list, + get_default_static_quant_module_mappings, + get_dynamic_quant_module_class, + get_quantized_operator, + get_static_quant_module_class, + no_observer_set, ) diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py index e416f85ec5ba..600d3a46fed0 100644 --- a/torch/quantization/quantize.py +++ b/torch/quantization/quantize.py @@ -7,22 +7,24 @@ If you are adding a new entry/functionality, please, add it to the here. """ -from torch.ao.quantization.quantize import _convert -from torch.ao.quantization.quantize import _observer_forward_hook -from torch.ao.quantization.quantize import _propagate_qconfig_helper -from torch.ao.quantization.quantize import _remove_activation_post_process -from torch.ao.quantization.quantize import _remove_qconfig -from torch.ao.quantization.quantize import _add_observer_ -from torch.ao.quantization.quantize import add_quant_dequant -from torch.ao.quantization.quantize import convert -from torch.ao.quantization.quantize import _get_observer_dict -from torch.ao.quantization.quantize import _get_unique_devices_ -from torch.ao.quantization.quantize import _is_activation_post_process -from torch.ao.quantization.quantize import prepare -from torch.ao.quantization.quantize import prepare_qat -from torch.ao.quantization.quantize import propagate_qconfig_ -from torch.ao.quantization.quantize import quantize -from torch.ao.quantization.quantize import quantize_dynamic -from torch.ao.quantization.quantize import quantize_qat -from torch.ao.quantization.quantize import _register_activation_post_process_hook -from torch.ao.quantization.quantize import swap_module +from torch.ao.quantization.quantize import ( + _add_observer_, + _convert, + _get_observer_dict, + _get_unique_devices_, + _is_activation_post_process, + _observer_forward_hook, + _propagate_qconfig_helper, + _register_activation_post_process_hook, + _remove_activation_post_process, + _remove_qconfig, + add_quant_dequant, + convert, + prepare, + prepare_qat, + propagate_qconfig_, + quantize, + quantize_dynamic, + quantize_qat, + swap_module, +) diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index aad3bc7253e4..649142c7a7ee 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -7,23 +7,20 @@ If you are adding a new entry/functionality, please, add it to the here. """ +from torch.ao.quantization.fx.graph_module import ObservedGraphModule from torch.ao.quantization.quantize_fx import ( _check_is_graph_module, - _swap_ff_with_fxff, + _convert_fx, + _convert_standalone_module_fx, _fuse_fx, - Scope, - ScopeContextManager, - QuantizationTracer, _prepare_fx, _prepare_standalone_module_fx, + _swap_ff_with_fxff, + convert_fx, fuse_fx, prepare_fx, prepare_qat_fx, - _convert_fx, - convert_fx, - _convert_standalone_module_fx, -) - -from torch.ao.quantization.fx.graph_module import ( - ObservedGraphModule, + QuantizationTracer, + Scope, + ScopeContextManager, ) diff --git a/torch/quantization/quantize_jit.py b/torch/quantization/quantize_jit.py index 6228e5ca24c7..aa627dc7bb51 100644 --- a/torch/quantization/quantize_jit.py +++ b/torch/quantization/quantize_jit.py @@ -8,19 +8,19 @@ here. """ from torch.ao.quantization.quantize_jit import ( - _check_is_script_module, _check_forward_method, + _check_is_script_module, + _convert_jit, + _prepare_jit, + _prepare_ondevice_dynamic_jit, + _quantize_jit, + convert_dynamic_jit, + convert_jit, + fuse_conv_bn_jit, + prepare_dynamic_jit, + prepare_jit, + quantize_dynamic_jit, + quantize_jit, script_qconfig, script_qconfig_dict, - fuse_conv_bn_jit, - _prepare_jit, - prepare_jit, - prepare_dynamic_jit, - _prepare_ondevice_dynamic_jit, - _convert_jit, - convert_jit, - convert_dynamic_jit, - _quantize_jit, - quantize_jit, - quantize_dynamic_jit ) diff --git a/torch/quantization/stubs.py b/torch/quantization/stubs.py index 10d297d3c7ce..d3fd5c63683d 100644 --- a/torch/quantization/stubs.py +++ b/torch/quantization/stubs.py @@ -7,8 +7,4 @@ If you are adding a new entry/functionality, please, add it to the here. """ -from torch.ao.quantization.stubs import ( - QuantStub, - DeQuantStub, - QuantWrapper -) +from torch.ao.quantization.stubs import DeQuantStub, QuantStub, QuantWrapper