# Owner(s): ["module: unknown"]

import copy
from collections.abc import Sequence
from functools import partial
import warnings
import unittest
import inspect
import itertools
import torch
import contextlib
import re
import os

from collections import defaultdict
from importlib import import_module
from torch.utils._pytree import tree_map
from typing import Dict
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    floating_and_complex_types_and,
    all_types_and_complex_and,
)

from torch.testing._internal.common_utils import (
    TestCase,
    is_iterable_of_tensors,
    run_tests,
    IS_SANDCASTLE,
    clone_input_helper,
    IS_CI,
    set_default_dtype,
    suppress_warnings,
    noncontiguous_like,
    TEST_WITH_ASAN,
    TEST_WITH_UBSAN,
    IS_WINDOWS,
    IS_FBCODE,
    first_sample,
    parametrize,
    skipIfTorchInductor,
    slowTest,
)
from torch.testing._internal.common_methods_invocations import (
    op_db,
    UnaryUfuncInfo,
    ReductionOpInfo,
    ReductionPythonRefInfo,
    SpectralFuncInfo,
    ops_and_refs,
    python_ref_db,
    BinaryUfuncInfo,
    xfail,
    skip,
    skipOps
)
from torch.testing._internal.common_device_type import (
    deviceCountAtLeast,
    instantiate_device_type_tests,
    ops,
    onlyCUDA,
    onlyCPU,
    onlyNativeDeviceTypes,
    OpDTypes,
    skipMeta,
)
from torch._subclasses.fake_tensor import (
    FakeTensor,
    FakeTensorMode,
)
from torch._subclasses.fake_utils import outputs_alias_inputs

import torch._prims as prims
from torch._prims.context import TorchRefsMode
from torch._prims_common.wrappers import _maybe_remove_out_wrapper

from torch.testing._internal import opinfo
from torch.testing._internal import composite_compliance

from torch.utils._pytree import tree_flatten
from torch.utils._python_dispatch import TorchDispatchMode

assert torch.get_default_dtype() == torch.float32

# variant testing is only done with torch.float and torch.cfloat to avoid
#   excessive test times and maximize signal to noise ratio
_variant_ops = partial(
    ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat)
)
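# Illustrative sketch (not executed): applying `_variant_ops` to a test is the same
# as stacking the underlying `ops` decorator with these dtype restrictions, e.g.
#
#   @_variant_ops(op_db)
#   def test_foo(self, device, dtype, op): ...
#
# is equivalent to
#
#   @ops(op_db, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat))
#   def test_foo(self, device, dtype, op): ...
#
# (`test_foo` is a hypothetical name; see test_variant_consistency_eager below for
# the real usage.)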

# Get names of all the operators which have ref in their entry in OpInfo (testing infra)
#   except for elementwise unary operators (separately implemented in test/test_unary_ufuncs.py),
#   elementwise binary operators (separately implemented in test_binary_ufuncs.py),
#   reduction operations (separately implemented in test_reductions.py),
#   and Spectral Functions (separately implemented for only 1D as of now, in test/test_spectral_ops.py)
_ref_test_ops = tuple(
    filter(
        lambda op: not isinstance(
            op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo)
        )
        and op.ref is not None,
        op_db,
    )
)
_ops_and_refs = op_db + python_ref_db

def reduction_dtype_filter(op):
    if(not isinstance(op, ReductionPythonRefInfo) or not op.supports_out
       or torch.int16 not in op.dtypes):
        return False

    argspec = inspect.getfullargspec(op.op)
    if 'dtype' not in argspec.kwonlyargs:
        return False
    return True
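# Illustrative sketch (not executed): reduction_dtype_filter is meant to be used as a
# predicate over the op database, e.g.
#
#   int16_reduction_refs = list(filter(reduction_dtype_filter, _ops_and_refs))
#
# keeping only the ReductionPythonRefInfo entries that support out=, accept int16
# inputs, and expose a keyword-only `dtype` argument. (`int16_reduction_refs` is a
# made-up name; test_out_integral_dtype below consumes exactly this filter.)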

# Create a list of operators that are a subset of _ref_test_ops but don't have a
# numpy ref to compare them to. If both CPU and CUDA are compared to numpy
# then they do not need to be compared to each other
_ops_and_refs_with_no_numpy_ref = [op for op in _ops_and_refs if op.ref is None]

aten = torch.ops.aten

# Tests that apply to all operators and aren't related to any particular
#   system
class TestCommon(TestCase):
    exact_dtype = True

    # Verifies, on teardown, that no OpInfo is still using dynamic dtypes in CI
    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        if IS_CI:
            err_msg = (
                "The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries. "
                "This is OK for testing, but be sure to set the dtypes manually before landing your PR!"
            )
            # Assure no opinfo entry has dynamic_dtypes
            filtered_ops = list(filter(opinfo.utils.is_dynamic_dtype_set, op_db))
            for op in filtered_ops:
                fmt_str = opinfo.utils.str_format_dynamic_dtype(op)
                err_msg += "\n" + fmt_str

            assert len(filtered_ops) == 0, err_msg

    # Validates that each OpInfo works correctly on different CUDA devices
    @onlyCUDA
    @deviceCountAtLeast(2)
    @ops(op_db, allowed_dtypes=(torch.float32, torch.long))
    def test_multiple_devices(self, devices, dtype, op):
        for cuda_device_str in devices:
            cuda_device = torch.device(cuda_device_str)
            # NOTE: only tests on first sample
            samples = op.sample_inputs(cuda_device, dtype)
            sample = first_sample(self, samples)
            result = op(sample.input, *sample.args, **sample.kwargs)

            if isinstance(result, torch.Tensor):
                self.assertTrue(result.device == cuda_device)
            elif is_iterable_of_tensors(result):
                self.assertTrue(all(t.device == cuda_device for t in result))
            else:
                self.skipTest(
                    "Skipped! Only supports single tensor or iterable of tensor outputs."
                )

    def test_pointwise_tag_coverage(self):

        pytorch_dir = os.path.abspath(__file__ + "/../../")
        files = [
            "aten/src/ATen/native/UnaryOps.cpp",
            "aten/src/ATen/native/BinaryOps.cpp",
            "aten/src/ATen/native/PointwiseOps.cpp",
            "aten/src/ATen/native/TensorCompare.cpp",
        ]

        allowed_functions = (
            # reduction version of these operators
            "aten.max.default",
            "aten.max.dim",
            "aten.max.dim_max",
            "aten.max.names_dim",
            "aten.max.names_dim_max",
            "aten.max.unary_out",
            "aten.min.default",
            "aten.min.dim",
            "aten.min.dim_min",
            "aten.min.names_dim",
            "aten.min.names_dim_min",
            "aten.min.unary_out",
            # not pointwise
            "aten.isin.Tensor_Tensor",
            "aten.isin.Tensor_Tensor_out",
            "aten.isin.Tensor_Scalar",
            "aten.isin.Tensor_Scalar_out",
            "aten.isin.Scalar_Tensor",
            "aten.isin.Scalar_Tensor_out",
            "aten.mode.default",
            "aten.mode.dimname",
            "aten.mode.dimname_out",
            "aten.mode.values",
        )

        regex = re.compile(r"DEFINE_DISPATCH\(.*_stub")
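        # Illustrative note: the regex above matches kernel declarations such as
        # "DEFINE_DISPATCH(add_stub"; the slicing below then recovers the kernel
        # name ("add"), which is mapped back to an aten OpOverloadPacket.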

        def get_opoverloadpacket_from_dispatch(kernel):
            if hasattr(torch.ops.aten, kernel):
                return kernel
            if hasattr(torch.ops.aten, f"__{kernel}__"):
                return f"__{kernel}__"
            if hasattr(torch.ops.aten, f"special_{kernel}"):
                return f"special_{kernel}"
            if "_" in kernel:
                kernel_split = kernel.split("_")
                new_kernel = "_".join(kernel_split[:-1])
                if hasattr(torch.ops.aten, new_kernel):
                    return new_kernel

            # could not find op from kernel dispatch string
            self.assertTrue(False)

        for file_name in files:
            with open(os.path.join(pytorch_dir, file_name)) as f:
                lines = f.read()
                matches = regex.findall(lines)
                for match in matches:
                    kernel = match[len("DEFINE_DISPATCH("):-len("_stub")]

                    # no op definition for it, but defined with DEFINE_DISPATCH ?
                    if kernel == "trigamma":
                        continue

                    kernel = get_opoverloadpacket_from_dispatch(kernel)
                    overloadpacket = getattr(torch.ops.aten, kernel)

                    for overload_name in overloadpacket.overloads():
                        overload = getattr(overloadpacket, overload_name)

                        if not torch._C._dispatch_has_kernel(overload.name()):
                            continue

                        # TODO: tags are not propagated to generated overload,
                        # and there's no way of specifying them
                        if torch.Tag.generated in overload.tags:
                            continue

                        if str(overload) in allowed_functions:
                            continue

                        self.assertTrue(torch.Tag.pointwise in overload.tags)

    # Tests that the function and its (ndarray-accepting) reference produce the same
    #   values on the tensors from sample_inputs func for the corresponding op.
    # This test runs in double and complex double precision because
    # NumPy does computation internally using double precision for many functions
    # resulting in possible equality check failures.
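    # Illustrative sketch (not executed): an OpInfo's `ref` here is typically the
    # corresponding NumPy callable, so the comparison below amounts to something like
    #
    #   t = torch.tensor([0.0, 1.0], dtype=torch.float64)
    #   assert np.allclose(torch.sin(t).numpy(), np.sin(t.numpy()))
    #
    # (`np.sin` as the ref for `torch.sin` is an assumption for illustration; the
    # actual ref comes from each op_db entry.)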
    @onlyNativeDeviceTypes
    @suppress_warnings
    @ops(_ref_test_ops, allowed_dtypes=(torch.float64, torch.long, torch.complex128))
    def test_numpy_ref(self, device, dtype, op):
        # Sets the default dtype to NumPy's default dtype of double
        with set_default_dtype(torch.double):
            for sample_input in op.reference_inputs(device, dtype):
                self.compare_with_reference(
                    op, op.ref, sample_input, exact_dtype=(dtype is not torch.long)
                )

    # Tests that the cpu and gpu results are consistent
    @onlyCUDA
    @suppress_warnings
    @slowTest
    @ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one)
    def test_compare_cpu(self, device, dtype, op):

        def to_cpu(arg):
            if isinstance(arg, torch.Tensor):
                return arg.to(device='cpu')
            return arg

        samples = op.reference_inputs(device, dtype)

        for sample in samples:
            cpu_sample = sample.transform(to_cpu)
            cuda_results = op(sample.input, *sample.args, **sample.kwargs)
            cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)

            # output_process_fn_grad has a very unfortunate name
            # We use this function in linalg extensively to postprocess the outputs of functions
            # that are not completely well-defined. Think svd and multiplying the singular vectors by -1.
            # CPU and CUDA implementations of the SVD can return valid SVDs that are different.
            # We use this function to compare them.
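            # Illustrative example of that ambiguity: if A = U @ diag(S) @ Vh is a
            # valid SVD, then flipping the sign of a column of U together with the
            # matching row of Vh gives another valid SVD of A, so the raw U/Vh
            # tensors from CPU and CUDA may differ elementwise even though both
            # results are correct.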
            cuda_results = sample.output_process_fn_grad(cuda_results)
            cpu_results = cpu_sample.output_process_fn_grad(cpu_results)

            # Lower tolerance because we are running this as a `@slowTest`
            # Don't want the periodic tests to fail frequently
            self.assertEqual(cuda_results, cpu_results, atol=1e-3, rtol=1e-3)

    # Tests that experimental Python References can propagate shape, dtype,
    # and device metadata properly.
    # See https://github.com/pytorch/pytorch/issues/78050 for a discussion of stride propagation.
    @onlyNativeDeviceTypes
    @ops(python_ref_db)
    @skipIfTorchInductor("Takes too long for inductor")
    def test_python_ref_meta(self, device, dtype, op):
        with FakeTensorMode() as mode:
            pass
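        # Note: the empty `with` block above is just an idiom to construct and
        # immediately exit a FakeTensorMode so that `mode` can be reused below.
        # Tensors converted via FakeTensor.from_tensor(x, mode) carry only metadata
        # (shape, dtype, device) and no real storage, which is what lets this test
        # check metadata propagation without running the real kernels.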

        def _to_tensormeta(x):
            if isinstance(x, torch.Tensor):
                out = FakeTensor.from_tensor(x, mode)
                return out
            return x

        # TODO: iterate over requires_grad true/false
        for sample in op.reference_inputs(device, dtype, requires_grad=False):
            result = op(sample.input, *sample.args, **sample.kwargs)

            meta_sample = sample.transform(_to_tensormeta)
            try:
                with mode:
                    meta_result = op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs)
            except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
                continue
            except torch._subclasses.fake_tensor.DataDependentOutputException:
                continue
            except torch._subclasses.fake_tensor.UnsupportedOperatorException:
                continue

            if isinstance(result, torch.Tensor):
                self.assertTrue(isinstance(meta_result, FakeTensor))
                prims.utils.compare_tensor_meta(result, meta_result)
            elif isinstance(result, Sequence):
                for a, b in zip(result, meta_result):
                    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
                        self.assertTrue(isinstance(b, FakeTensor))
                        prims.utils.compare_tensor_meta(a, b)

    def _ref_test_helper(
        self,
        ctx,
        device,
        dtype,
        op,
        skip_zero_numel=False,
        skip_zero_dim=False,
        skip_bfloat=False,
        skip_view_consistency=False,
    ):
        # NOTE: this helper works by comparing the reference's results to the
        # torch operator's results on the same inputs
        ex = None
        for sample in op.reference_inputs(device, dtype, requires_grad=False):
            if isinstance(sample.input, torch.Tensor) and sample.input.numel() == 0 and skip_zero_numel:
                continue
            if isinstance(sample.input, torch.Tensor) and sample.input.ndim == 0 and skip_zero_dim:
                continue

            if (
                skip_bfloat
                and (
                    (
                        isinstance(sample.input, torch.Tensor)
                        and sample.input.dtype == torch.bfloat16
                    )
                    or any(
                        isinstance(arg, torch.Tensor) and arg.dtype == torch.bfloat16
                        for arg in sample.args
                    )
                )
            ):
                continue
            with ctx():
                ref_result = op(sample.input, *sample.args, **sample.kwargs)
            torch_result = op.torch_opinfo(sample.input, *sample.args, **sample.kwargs)

            for a, b in zip(tree_flatten(ref_result)[0], tree_flatten(torch_result)[0]):
                if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
                    prims.utils.compare_tensor_meta(a, b)
                    if getattr(op, 'validate_view_consistency', True) and not skip_view_consistency:
                        msg = (f"The torch implementation {'returns' if b._is_view() else 'does not return'} "
                               f"a view, while the reference {'does' if a._is_view() else 'does not'}")
                        self.assertEqual(a._is_view(), b._is_view(), msg)

            # Computes the dtype in which the more precise computation would occur
            precise_dtype = torch.bool
            if prims.utils.is_integer_dtype(dtype):
                # Note: bool and integer dtypes do not have more
                # precise dtypes -- they simply must be close
                precise_dtype = dtype
            if prims.utils.is_float_dtype(dtype):
                precise_dtype = torch.double
            if prims.utils.is_complex_dtype(dtype):
                precise_dtype = torch.cdouble
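            # For example, with dtype=torch.float32 the "precise" rerun below is done
            # in torch.double, and with torch.complex64 it is done in torch.cdouble;
            # integer and bool inputs are simply compared at their own dtype.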

            # Checks if the results are close
            try:
                self.assertEqual(
                    ref_result,
                    torch_result,
                    exact_stride=False,
                    exact_device=True,
                    exact_layout=True,
                    exact_is_coalesced=True,
                )
            except AssertionError as e:
                # Raises the error if the precise dtype comparison wouldn't be
                # different
                if dtype is precise_dtype:
                    raise e

                ex = e

            # Goes to next sample if these results are close
            if not ex:
                continue

            # If the results are not close, checks that the
            # reference is more accurate than the torch op
            def _make_precise(x):
                if isinstance(x, torch.dtype):
                    return precise_dtype
                if isinstance(x, torch.Tensor) and x.dtype is dtype:
                    return x.to(precise_dtype)
                return x

            precise_sample = sample.transform(_make_precise)
            precise_result = op.torch_opinfo(precise_sample.input, *precise_sample.args, **precise_sample.kwargs)

            def _distance(a, b):
                # Special-cases boolean comparisons
                if prims.utils.is_boolean_dtype(a.dtype):
                    assert b.dtype is torch.bool
                    return (a ^ b).sum()

                same = (a == b)
                if prims.utils.is_float_dtype(a.dtype) or prims.utils.is_complex_dtype(a.dtype):
                    same = torch.logical_or(same, torch.logical_and(torch.isnan(a), torch.isnan(b)))

                actual_error = torch.where(same, 0, torch.abs(a - b)).sum()
                return actual_error

            ref_distance = 0
            for a, b in zip(tree_flatten(ref_result)[0], tree_flatten(precise_result)[0]):
                ref_distance = ref_distance + _distance(a, b)

            torch_distance = 0
            for a, b in zip(tree_flatten(torch_result)[0], tree_flatten(precise_result)[0]):
                torch_distance = torch_distance + _distance(a, b)

            # TODO: consider adding some tolerance to this comparison
            msg = f"Reference result was farther ({ref_distance}) from the precise " \
                  f"computation than the torch result was ({torch_distance})!"
            self.assertTrue(ref_distance <= torch_distance, msg=msg)

        # Reports numerical accuracy discrepancies
        if ex is not None:
            msg = "Test passed because the reference was more accurate than the torch operator."
            warnings.warn(msg)

    # Tests that experimental Python References perform the same computation
    # as the operators they reference, when operator calls in the torch
    # namespace are remapped to the refs namespace (torch.foo becomes refs.foo).
    @onlyNativeDeviceTypes
    @ops(python_ref_db)
    @skipIfTorchInductor("Takes too long for inductor")
    def test_python_ref(self, device, dtype, op):
        # In this test, primTorch refs call into the refs namespace
        # For example, a ref with torch.foo in it will call refs.foo instead
        # Direct calls to refs and prims are not affected
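        # Illustrative sketch of the remapping this exercises (hypothetical ref):
        #
        #   def _my_ref(x):        # a Python reference written against torch.*
        #       return torch.add(x, 1)
        #
        #   with TorchRefsMode(strict=True):
        #       _my_ref(t)         # torch.add here resolves to torch._refs.add
        #
        # (`_my_ref` is made up for illustration; the real refs come from python_ref_db.)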
        self._ref_test_helper(lambda: TorchRefsMode(strict=True), device, dtype, op)

    # Tests that experimental Python References perform the same computation
    # as the operators they reference, when operator calls in the torch
    # namespace are preserved (torch.foo remains torch.foo).
    @onlyNativeDeviceTypes
    @ops(python_ref_db)
    @skipIfTorchInductor("Takes too long for inductor")
    def test_python_ref_torch_fallback(self, device, dtype, op):
        # In this test, refs call into the torch namespace (after the initial invocation)
        # For example, a ref with torch.foo in it will call torch.foo instead of refs.foo
        # Direct calls to refs and prims are not translated
        self._ref_test_helper(contextlib.nullcontext, device, dtype, op)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyCUDA
    @ops(python_ref_db)
    @parametrize('executor', ['aten',])
    @skipIfTorchInductor("Takes too long for inductor")
    def test_python_ref_executor(self, device, dtype, op, executor):
        # skip zero-dim tensors for some composites of reduction operations and view
        skip_zero_dim_ops = [
            "_refs.logsumexp",
            "_refs.log_softmax",
            "_refs.native_group_norm",
            "_refs.softmax",
            "_refs.sum_to_size",
            "ops.nvprims.view",
        ]

        from torch._prims.executor import make_traced
        from copy import copy
        op = copy(op)
        op.op = partial(make_traced(op.op), executor=executor)
        self._ref_test_helper(
            contextlib.nullcontext,
            device,
            dtype,
            op,
        )

    @skipMeta
    @onlyNativeDeviceTypes
    @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
    def test_errors(self, device, op):
        error_inputs = op.error_inputs(device)
        for ei in error_inputs:
            si = ei.sample_input
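            # Illustrative note: an ErrorInput bundles a SampleInput together with the
            # exception type and a regex for the expected message, e.g. an entry might
            # carry error_type=RuntimeError and a regex matching a shape-mismatch
            # message (made-up example; the real values come from the OpInfo's
            # error_inputs_func).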
            with self.assertRaisesRegex(ei.error_type, ei.error_regex):
                out = op(si.input, *si.args, **si.kwargs)
                self.assertFalse(isinstance(out, type(NotImplemented)))

    @skipMeta
    @onlyNativeDeviceTypes
    @ops([op for op in op_db if op.error_inputs_sparse_func is not None], dtypes=OpDTypes.none)
    @parametrize("layout", (torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc, torch.sparse_coo))
    def test_errors_sparse(self, device, op, layout):
        for ei in op.error_inputs_sparse(device, layout):
            si = ei.sample_input
            with self.assertRaisesRegex(ei.error_type, ei.error_regex):
                out = op(si.input, *si.args, **si.kwargs)
                self.assertFalse(isinstance(out, type(NotImplemented)))

    @skipMeta
    @onlyNativeDeviceTypes
    @ops([op for op in python_ref_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
    @skipIfTorchInductor("Takes too long for inductor")
    def test_python_ref_errors(self, device, op):
        mode = FakeTensorMode()
        with mode:
            pass

        def _to_tensormeta(x):
            if isinstance(x, torch.Tensor):
                return FakeTensor.from_tensor(x, mode)
            return x

        error_inputs = op.error_inputs(device)
        for ei in error_inputs:
            si = ei.sample_input
            meta_sample = si.transform(_to_tensormeta)
            with self.assertRaisesRegex(ei.error_type, ei.error_regex):
                op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs)

    # Tests that the function produces the same result when called with
    #   noncontiguous tensors.
    # TODO: get working with Windows by addressing failing operators
    # TODO: get working with ASAN by addressing failing operators
    @unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
    @onlyNativeDeviceTypes
    @suppress_warnings
    @ops(op_db, allowed_dtypes=(torch.float32, torch.long, torch.complex64))
    def test_noncontiguous_samples(self, device, dtype, op):
        test_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=test_grad)
        for sample_input in sample_inputs:
            t_inp, t_args, t_kwargs = (
                sample_input.input,
                sample_input.args,
                sample_input.kwargs,
            )
            noncontig_sample = sample_input.noncontiguous()
            n_inp, n_args, n_kwargs = (
                noncontig_sample.input,
                noncontig_sample.args,
                noncontig_sample.kwargs,
            )
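            # Illustrative note: the noncontiguous sample holds the same values but
            # with non-default strides; e.g. torch.arange(6).reshape(2, 3).t() is a
            # noncontiguous view of contiguous storage.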

            # validates forward
            expected = op(t_inp, *t_args, **t_kwargs)
            actual = op(n_inp, *n_args, **n_kwargs)

            self.assertEqual(actual, expected)

            # Validate backward
            # Short-circuits if the op doesn't support grad in this device x dtype
            if not test_grad:
                continue

            expected = sample_input.output_process_fn_grad(expected)
            actual = sample_input.output_process_fn_grad(actual)

            if isinstance(expected, torch.Tensor):
                grad_for_expected = torch.randn_like(expected)
                grad_for_actual = noncontiguous_like(grad_for_expected)
            elif isinstance(expected, Sequence):
                # Filter output elements that do not require grad
                expected = [
                    t
                    for t in expected
                    if isinstance(t, torch.Tensor) and t.requires_grad
                ]
                actual = [
                    n for n in actual if isinstance(n, torch.Tensor) and n.requires_grad
                ]
                grad_for_expected = [torch.randn_like(t) for t in expected]
                grad_for_actual = [noncontiguous_like(n) for n in grad_for_expected]
            else:
                # Nothing to do if it returns a scalar or things like that
                continue

            # Concatenate inputs into a tuple
            t_inputs = (
                (t_inp,) + t_args
                if isinstance(t_inp, torch.Tensor)
                else tuple(t_inp) + t_args
            )
            n_inputs = (
                (n_inp,) + n_args
                if isinstance(n_inp, torch.Tensor)
                else tuple(n_inp) + n_args
            )

            # Filter the elements that are tensors that require grad
            t_input_tensors = [
                t for t in t_inputs if isinstance(t, torch.Tensor) and t.requires_grad
            ]
            n_input_tensors = [
                n for n in n_inputs if isinstance(n, torch.Tensor) and n.requires_grad
            ]

            self.assertEqual(len(t_input_tensors), len(n_input_tensors))

            # Some functions may not use all the inputs to generate gradients. One of the
            # few examples of this "odd" behaviour is F.hinge_embedding_loss
            t_grads = torch.autograd.grad(
                expected, t_input_tensors, grad_for_expected, allow_unused=True
            )
            n_grads = torch.autograd.grad(
                actual, n_input_tensors, grad_for_actual, allow_unused=True
            )

            msg = "Got different gradients for contiguous / non-contiguous inputs wrt input {}."
            for i, (t, n) in enumerate(zip(t_grads, n_grads)):
                self.assertEqual(t, n, msg=msg.format(i))

    # Separates one case from the following test_out because many ops don't properly
    #   implement the incorrectly sized out parameter warning yet
    # Cases tested here:
    #   - out= with the correct dtype and device, but the wrong shape
    @ops(_ops_and_refs, dtypes=OpDTypes.none)
    def test_out_warning(self, device, op):
        # Prefers running in float32 but has a fallback for the first listed supported dtype
        supported_dtypes = op.supported_dtypes(self.device_type)
        if len(supported_dtypes) == 0:
            self.skipTest("Skipped! Op has no supported dtypes on this device.")
        dtype = (
            torch.float32
            if torch.float32 in supported_dtypes
            else list(supported_dtypes)[0]
        )

        # Ops from python_ref_db point to python decomps that are potentially
        # wrapped with `torch._prims_common.wrappers.out_wrapper`. Unwrap these
        # ops before testing to avoid clashing with OpInfo.supports_out
        if not op.supports_out:
            op = copy.copy(op)
            op.op = _maybe_remove_out_wrapper(op.op)

        samples = op.sample_inputs(device, dtype)
        for sample in samples:
            # calls it normally to get the expected result
            expected = op(sample.input, *sample.args, **sample.kwargs)
            op_out = partial(op, sample.input, *sample.args, **sample.kwargs)

            # Short-circuits if output is not a single tensor or an
            #   iterable of tensors
            if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
                expected, include_empty=True
            ):
                self.skipTest(
                    "Skipped! Only supports single tensor or iterable of tensor outputs."
                )

            # Validates the op doesn't support out if it claims not to
            if not op.supports_out:
                with self.assertRaises(Exception):
                    assert op_out(out=expected) != NotImplemented
                return

            # A wrapper around map that works with single tensors and always
            #   instantiates the map. Used below to apply transforms to
            #   single tensor and iterable tensor outputs.
            def _apply_out_transform(fn, out):
                if isinstance(out, torch.Tensor):
                    return fn(out)

                # assumes (see above) that out is an iterable of tensors
                return tuple(map(fn, out))

            # Extracts strides from a tensor or iterable of tensors into a tuple
            def _extract_strides(out):
                if isinstance(out, torch.Tensor):
                    return (out.stride(),)

                # assumes (see above) that out is an iterable of tensors
                return tuple(t.stride() for t in out)

            # Extracts data pointers from a tensor or iterable of tensors into a tuple
            # NOTE: only extracts on the CPU and CUDA device types since some
            #   device types don't have storage
            def _extract_data_ptrs(out):
                if self.device_type != "cpu" and self.device_type != "cuda":
                    return ()

                if isinstance(out, torch.Tensor):
                    return (out.data_ptr(),)

                # assumes (see above) that out is an iterable of tensors
                return tuple(t.data_ptr() for t in out)

            @suppress_warnings
            def _compare_out(transform, *, compare_strides_and_data_ptrs=True):
                out = _apply_out_transform(transform, expected)
                original_strides = _extract_strides(out)
                original_ptrs = _extract_data_ptrs(out)

                op_out(out=out)
                final_strides = _extract_strides(out)
                final_ptrs = _extract_data_ptrs(out)

                self.assertEqual(expected, out)

                if compare_strides_and_data_ptrs:
                    stride_msg = "Strides are not the same! Original strides were {} and strides are now {}".format(
                        original_strides, final_strides
                    )
                    self.assertEqual(original_strides, final_strides, msg=stride_msg)
                    self.assertEqual(original_ptrs, final_ptrs)

            # Case Zero: out= with the correct dtype and device, but the wrong shape
            #   Expected behavior: if nonempty, resize with a warning.
            def _case_zero_transform(t):
                wrong_shape = list(t.shape)

                if len(wrong_shape) == 0:
                    # Handles scalar tensor case (empty list)
                    wrong_shape = [2]
                else:
                    wrong_shape[-1] = wrong_shape[-1] + 1
                return make_tensor(wrong_shape, dtype=t.dtype, device=t.device)

            # Verifies the out values are correct
            _compare_out(_case_zero_transform, compare_strides_and_data_ptrs=False)

            # Additionally validates that the appropriate warning is thrown if a nonempty
            #   tensor is resized.
            def _any_nonempty(out):
                if isinstance(out, torch.Tensor):
                    return out.numel() > 0

                return any(x.numel() > 0 for x in out)

            out = _apply_out_transform(_case_zero_transform, expected)
            msg_fail = "Resized a non-empty tensor but did not warn about it."
            if _any_nonempty(out):
                with self.assertWarnsRegex(
                    UserWarning, "An output with one or more elements", msg=msg_fail
                ):
                    op_out(out=out)

    # Validates ops implement the correct out= behavior
    # See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
    #   for a description of the correct behavior
    # Validates the following cases:
    #   - Case 0: out has the correct shape, dtype, and device but is full of extremal values
    #   - Case 1: out has the correct shape, dtype, and device but is noncontiguous
    #   - Case 2: out has the correct dtype and device, but is zero elements
    #   - Case 3: out has the correct shape and dtype, but is on a different device type
    #   - Case 4: out has the correct shape and device, but a dtype that the result
    #       cannot be "safely" cast to
    #
    # Cases 3 and 4 are slightly different when the op is a factory function:
    #   - if device, dtype are NOT passed, any combination of dtype/device should be OK for out
    #   - if device, dtype are passed, device and dtype should match
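    # Illustrative sketch (not executed) of the Case 4 rule for a non-factory op:
    #
    #   a = torch.ones(3)                       # float32 result expected
    #   bad_out = torch.empty(3, dtype=torch.long)
    #   torch.add(a, 1, out=bad_out)            # should raise RuntimeError, since a
    #                                           # float result cannot be safely cast
    #                                           # to a long out= tensor
    #
    # (torch.add is just an example op; the test below exercises this generically.)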
    @ops(_ops_and_refs, dtypes=OpDTypes.any_one)
    def test_out(self, device, dtype, op):
        # Prefers running in float32 but has a fallback for the first listed supported dtype
        samples = op.sample_inputs(device, dtype)

        # Ops from python_ref_db point to python decomps that are potentially
        # wrapped with `torch._prims_common.wrappers.out_wrapper`. Unwrap these
        # ops before testing to avoid clashing with OpInfo.supports_out
        if not op.supports_out:
            op = copy.copy(op)
            op.op = _maybe_remove_out_wrapper(op.op)

        for sample in samples:
            # calls it normally to get the expected result
            expected = op(sample.input, *sample.args, **sample.kwargs)
            op_out = partial(op, sample.input, *sample.args, **sample.kwargs)

            # Short-circuits if output is not a single tensor or an
            #   iterable of tensors
            if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
                expected, include_empty=True
            ):
                self.skipTest(
                    "Skipped! Only supports single tensor or iterable of tensor outputs."
                )

            # Validates the op doesn't support out if it claims not to
            if not op.supports_out:
                with self.assertRaises(Exception):
                    assert op_out(out=expected) != NotImplemented
                return

            # A wrapper around map that works with single tensors and always
            #   instantiates the map. Used below to apply transforms to
            #   single tensor and iterable tensor outputs.
            def _apply_out_transform(fn, out):
                if isinstance(out, torch.Tensor):
                    return fn(out)

                # assumes (see above) that out is an iterable of tensors
                return tuple(map(fn, out))

            # Extracts strides from a tensor or iterable of tensors into a tuple
            def _extract_strides(out):
                if isinstance(out, torch.Tensor):
                    return (out.stride(),)

                # assumes (see above) that out is an iterable of tensors
                return tuple(t.stride() for t in out)

            # Extracts data pointers from a tensor or iterable of tensors into a tuple
            # NOTE: only extracts on the CPU and CUDA device types since some
            #   device types don't have storage
            def _extract_data_ptrs(out):
                if self.device_type != "cpu" and self.device_type != "cuda":
                    return ()

                if isinstance(out, torch.Tensor):
                    return (out.data_ptr(),)

                # assumes (see above) that out is an iterable of tensors
                return tuple(t.data_ptr() for t in out)

            def _compare_out(transform, *, compare_strides_and_data_ptrs=True):
                out = _apply_out_transform(transform, expected)
                original_strides = _extract_strides(out)
                original_ptrs = _extract_data_ptrs(out)

                op_out(out=out)
                final_strides = _extract_strides(out)
                final_ptrs = _extract_data_ptrs(out)
                self.assertEqual(expected, out)

                if compare_strides_and_data_ptrs:
                    stride_msg = "Strides are not the same! Original strides were {} and strides are now {}".format(
                        original_strides, final_strides
                    )
                    self.assertEqual(original_strides, final_strides, msg=stride_msg)
                    self.assertEqual(original_ptrs, final_ptrs)

            # Case 0: out= with the correct shape, dtype, and device
            #   but NaN values for floating point and complex tensors, and
            #   maximum values for integer tensors.
            #   Expected behavior: out= values have no effect on the computation.
            def _case_zero_transform(t):
                try:
                    info = torch.iinfo(t.dtype)
                    return torch.full_like(t, info.max)
                except TypeError as te:
                    # for non-integer types fills with NaN
                    return torch.full_like(t, float("nan"))

            _compare_out(_case_zero_transform)

            # Case 1: out= with the correct shape, dtype, and device,
            #   but noncontiguous.
            #   Expected behavior: strides are respected and `out` storage is not changed.
            def _case_one_transform(t):
                return make_tensor(
                    t.shape, dtype=t.dtype, device=t.device, noncontiguous=True
                )

            _compare_out(_case_one_transform)

            # Case 2: out= with the correct dtype and device, but has no elements.
            #   Expected behavior: resize without warning.
            def _case_two_transform(t):
                return make_tensor((0,), dtype=t.dtype, device=t.device)

            _compare_out(_case_two_transform, compare_strides_and_data_ptrs=False)

            # Also validates that no warning is thrown when this out is resized
            out = _apply_out_transform(_case_two_transform, expected)
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter("always")
                op_out(out=out)

            # Verifies no warning is a resize warning
            for w in caught:
                if "An output with one or more elements" in str(w.message):
                    self.fail(
                        "Resizing an out= argument with no elements threw a resize warning!"
                    )

            # Case 3: out= with correct shape and dtype, but wrong device.
            wrong_device = None
            if torch.device(device).type != "cpu":
                wrong_device = "cpu"
            elif torch.cuda.is_available():
                wrong_device = "cuda"

            factory_fn_msg = (
                "\n\nNOTE: If your op is a factory function (i.e., it accepts TensorOptions) you should mark its "
                "OpInfo with `is_factory_function=True`."
            )
            if wrong_device is not None:

                def _case_three_transform(t):
                    return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)

                out = _apply_out_transform(_case_three_transform, expected)

                if op.is_factory_function and sample.kwargs.get("device", None) is None:
                    op_out(out=out)
                else:
                    msg_fail = (
                        f"Expected RuntimeError when calling with input.device={device} and out.device={wrong_device}."
                    ) + factory_fn_msg
                    with self.assertRaises(RuntimeError, msg=msg_fail):
                        op_out(out=out)

            # Case 4: out= with correct shape and device, but a dtype
            #   that output cannot be "safely" cast to (long).
            #   Expected behavior: error.
            # NOTE: this case is filtered by dtype since some ops produce
            #   bool tensors, for example, which can be safely cast to any
            #   dtype. It is applied when single tensors are floating point or complex
            #   dtypes, or if an op returns multiple tensors and at least one such
            #   tensor is a floating point or complex dtype.
            _dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
            if (
                isinstance(expected, torch.Tensor)
                and expected.dtype in _dtypes
                or (
                    not isinstance(expected, torch.Tensor)
                    and any(t.dtype in _dtypes for t in expected)
                )
            ):

                def _case_four_transform(t):
                    return make_tensor(t.shape, dtype=torch.long, device=t.device)

                out = _apply_out_transform(_case_four_transform, expected)
                msg_fail = "Expected RuntimeError when doing an unsafe cast!"
                msg_fail = (
                    msg_fail
                    if not isinstance(expected, torch.Tensor)
                    else (
                        "Expected RuntimeError when doing an unsafe cast from a result of dtype "
                        f"{expected.dtype} into an out= with dtype torch.long"
                    )
                ) + factory_fn_msg

                if op.is_factory_function and sample.kwargs.get("dtype", None) is None:
                    op_out(out=out)
                else:
                    with self.assertRaises(RuntimeError, msg=msg_fail):
                        op_out(out=out)

    @ops(filter(reduction_dtype_filter, _ops_and_refs), dtypes=(torch.int16,))
    def test_out_integral_dtype(self, device, dtype, op):
        def helper(with_out, expectFail, op_to_test, inputs, *args, **kwargs):
            out = None
            try:
                if with_out:
                    out = torch.empty(0, dtype=torch.int32, device=device)
                    op_to_test(inputs, out=out, *args, **kwargs)
                else:
                    out = op_to_test(inputs, *args, **kwargs)
                self.assertFalse(expectFail)
            except RuntimeError as err:
                self.assertEqual(
                    str(err), "dtype argument and out dtype must match in reduction")
                self.assertTrue(expectFail)
            return out
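        # Illustrative usage of the helper above (not executed):
        #   helper(True, True, op, t, dtype=torch.int16) expects the reduction to
        #   raise because the requested dtype (int16) does not match the int32 out=
        #   tensor, while helper(True, False, ...) expects success.
        # (The `dtype` kwarg shown is hypothetical; real kwargs come from the samples.)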
        samples = op.sample_inputs(device, dtype)
        for sample in samples:
            if 'dtype' not in sample.kwargs:
                helper(False, False, op, sample.input, *sample.args, **sample.kwargs)
                helper(True, False, op, sample.input, *sample.args, **sample.kwargs)
                sample.kwargs['dtype'] = torch.int16
                helper(False, False, op, sample.input, *sample.args, **sample.kwargs)
                helper(True, True, op, sample.input, *sample.args, **sample.kwargs)
                sample.kwargs['dtype'] = torch.int32
                helper(False, False, op, sample.input, *sample.args, **sample.kwargs)
                helper(True, False, op, sample.input, *sample.args, **sample.kwargs)
            else:
                helper(False, False, op, sample.input, *sample.args, **sample.kwargs)
                helper(True, sample.kwargs['dtype'] != torch.int32, op, sample.input,
                       *sample.args, **sample.kwargs)

    # Tests that the forward and backward passes of operations produce the
    #   same values for the cross-product of op variants (method, inplace)
    #   against eager's gold standard op function variant
    @_variant_ops(op_db)
    def test_variant_consistency_eager(self, device, dtype, op):
        # Acquires variants (method variant, inplace variant, operator variant, inplace_operator variant, aliases)

        method = op.method_variant
        inplace = op.inplace_variant
        operator = op.operator_variant
        inplace_operator = op.inplace_operator_variant

        # list of all inplace ops: inplace variant + alias inplace variants if exist
        inplace_ops = [inplace, inplace_operator]
        variants = [method, inplace, operator, inplace_operator]
        operators = [operator, inplace_operator]
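        # Illustrative note: for an op such as torch.add, the method variant would be
        # torch.Tensor.add, the inplace variant torch.Tensor.add_, and the operator
        # variants the `+`/`+=` implementations (torch.add is just an example; the
        # actual variants are provided by each OpInfo).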

        for a_op in op.aliases:
            variants.append(a_op.op)
            variants.append(a_op.method_variant)
            variants.append(a_op.inplace_variant)
            inplace_ops.append(a_op.inplace_variant)

        inplace_variants = tuple(filter(None, inplace_ops))
        variants = tuple(filter(None, variants))
        operators = tuple(filter(None, operators))

        _requires_grad = dtype in op.supported_backward_dtypes(
            torch.device(device).type
        )

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
        samples = op.sample_inputs(
            device,
            dtype,
            requires_grad=_requires_grad,
            include_conjugated_inputs=include_conjugated_inputs,
        )
        samples = list(samples)

        def _test_consistency_helper(samples, variants):
            for sample in samples:
                # TODO: Check grad for all Tensors requiring grad if sample.input is TensorList
                tensor = (
                    sample.input
                    if isinstance(sample.input, torch.Tensor)
                    else sample.input[0]
                )

                # Computes function forward and backward values
                tensor.grad = None
                expected_forward = op(sample.input, *sample.args, **sample.kwargs)
                expected_grad = None

                output_process_fn_grad = (
                    sample.output_process_fn_grad
                    if sample.output_process_fn_grad
                    else lambda x: x
                )

                # Skips inplace variants if the output dtype is not the same as
                #   the input dtype
                skip_inplace = False
                if (
                    isinstance(expected_forward, torch.Tensor)
                    and expected_forward.dtype is not tensor.dtype
                ):
                    skip_inplace = True

                # TODO: backward consistency only supported for single tensor outputs
                # TODO: backward consistency only checked on sample.input, not all
                #   tensor inputs
                # TODO: update to handle checking grads of all tensor inputs as
                #   derived from each tensor output
                if isinstance(
                    expected_forward, torch.Tensor
                ) and dtype in op.supported_backward_dtypes(torch.device(device).type):
                    out = output_process_fn_grad(expected_forward).sum()
                    if out.dtype.is_complex:
                        out = out.abs()
                    out.backward()
                    expected_grad = tensor.grad

                # Test eager consistency
                for variant in variants:
                    # Skips inplace ops
                    if variant in inplace_ops and skip_inplace:
                        continue

                    # Compares variant's forward
                    # Note: copies the to-be-modified input when testing the inplace variant
                    tensor.grad = None
 | 
						|
                    cloned = (
 | 
						|
                        clone_input_helper(sample.input)
 | 
						|
                        if variant in inplace_ops
 | 
						|
                        else sample.input
 | 
						|
                    )
 | 
						|
 | 
						|
                    if variant in inplace_ops and sample.broadcasts_input:
 | 
						|
                        with self.assertRaises(
 | 
						|
                            RuntimeError,
 | 
						|
                            msg=(
 | 
						|
                                "inplace variant either incorrectly allowed "
 | 
						|
                                f"resizing or you have marked the sample {sample.summary()}"
 | 
						|
                                " incorrectly with `broadcasts_self=True"
 | 
						|
                            ),
 | 
						|
                        ):
 | 
						|
                            variant_forward = variant(
 | 
						|
                                cloned, *sample.args, **sample.kwargs
 | 
						|
                            )
 | 
						|
                        continue
 | 
						|
 | 
						|
                    if variant in operators and sample.kwargs:
 | 
						|
                        # skip samples with kwargs for operator variants
 | 
						|
                        continue
 | 
						|
 | 
						|
                    variant_forward = variant(cloned, *sample.args, **sample.kwargs)
 | 
						|
                    self.assertEqual(expected_forward, variant_forward)
 | 
						|
 | 
						|
                    # Compares variant's backward
 | 
						|
                    if expected_grad is not None and (
 | 
						|
                        variant not in inplace_ops or op.supports_inplace_autograd
 | 
						|
                    ):
 | 
						|
                        out = output_process_fn_grad(variant_forward).sum()
 | 
						|
                        if out.dtype.is_complex:
 | 
						|
                            out = out.abs()
 | 
						|
                        out.backward()
 | 
						|
                        self.assertEqual(expected_grad, tensor.grad)
 | 
						|
 | 
						|
        _test_consistency_helper(samples, variants)
 | 
						|
 | 
						|
        def _test_inplace_preserve_storage(samples, variants):
 | 
						|
            for sample in samples:
 | 
						|
                # Skips inplace variants if the output dtype is not the same as
 | 
						|
                #   the input dtype
 | 
						|
                expected_forward = op(sample.input, *sample.args, **sample.kwargs)
 | 
						|
                tensor = (
 | 
						|
                    sample.input
 | 
						|
                    if isinstance(sample.input, torch.Tensor)
 | 
						|
                    else sample.input[0]
 | 
						|
                )
 | 
						|
                skip_inplace = False
 | 
						|
                if (
 | 
						|
                    isinstance(expected_forward, torch.Tensor)
 | 
						|
                    and expected_forward.dtype is not tensor.dtype
 | 
						|
                ):
 | 
						|
                    skip_inplace = True
 | 
						|
                if skip_inplace:
 | 
						|
                    return
 | 
						|
                for variant in variants:
 | 
						|
                    cloned = (
 | 
						|
                        clone_input_helper(sample.input)
 | 
						|
                        if variant in inplace_ops
 | 
						|
                        else sample.input
 | 
						|
                    )
 | 
						|
                    inp_tensor = (
 | 
						|
                        cloned if isinstance(cloned, torch.Tensor) else cloned[0]
 | 
						|
                    )
 | 
						|
                    data_ptr = inp_tensor.data_ptr()
 | 
						|
                    if variant in operators and sample.kwargs:
 | 
						|
                        # skip samples with kwargs for operator variants
 | 
						|
                        continue
 | 
						|
 | 
						|
                    variant_forward = variant(cloned, *sample.args, **sample.kwargs)
 | 
						|
                    # TODO Support non-tensor outputs if they exist for inplace ops
 | 
						|
                    if isinstance(variant_forward, torch.Tensor):
 | 
						|
                        self.assertEqual(
 | 
						|
                            data_ptr, variant_forward.data_ptr(), atol=0, rtol=0
 | 
						|
                        )
 | 
						|
                    else:
 | 
						|
                        self.assertTrue(
 | 
						|
                            False,
 | 
						|
                            "Non-tensor outputs for inplace ops are not supported",
 | 
						|
                        )
 | 
						|
 | 
						|
        if len(inplace_ops) > 0:
 | 
						|
            inplace_samples = list(
 | 
						|
                filter(lambda sample: not sample.broadcasts_input, samples)
 | 
						|
            )
 | 
						|
            _test_inplace_preserve_storage(inplace_samples, inplace_variants)
 | 
						|
 | 
						|
    # Reference testing for operations in complex32 against complex64.
 | 
						|
    # NOTE: We test against complex64 as NumPy doesn't have a complex32 equivalent dtype.
 | 
						|
    @ops(op_db, allowed_dtypes=(torch.complex32,))
 | 
						|
    def test_complex_half_reference_testing(self, device, dtype, op):
 | 
						|
        if not op.supports_dtype(torch.complex32, device):
 | 
						|
            unittest.skip("Does not support complex32")
 | 
						|
 | 
						|
        for sample in op.sample_inputs(device, dtype):
 | 
						|
            actual = op(sample.input, *sample.args, **sample.kwargs)
 | 
						|
            # sample.transform applies the lambda to torch.Tensor and torch.dtype.
 | 
						|
            # However, we only want to apply it to Tensors with dtype `torch.complex32`..
 | 
						|
            transformed_sample = sample.transform(lambda x: x.to(torch.complex64) if isinstance(
 | 
						|
                x, torch.Tensor) and x.dtype is torch.complex32 else x)
 | 
						|
            expected = op(
 | 
						|
                transformed_sample.input,
 | 
						|
                *transformed_sample.args,
 | 
						|
                **transformed_sample.kwargs,
 | 
						|
            )
 | 
						|
            # Since range of chalf is much less compared to cfloat,
 | 
						|
            # we get `inf`s easily (eg. with `pow`, `exp`),
 | 
						|
            # so we cast `cfloat` back to `chalf`.
 | 
						|
            expected = tree_map(lambda x: x.to(torch.complex32) if isinstance(
 | 
						|
                x, torch.Tensor) and x.dtype is torch.complex64 else x, expected)
 | 
						|
 | 
						|
            # `exact_dtype` is False because for ops like real, imag
 | 
						|
            # we get different dtypes for `actual` and `expected`
 | 
						|
            # `chalf` input -> `half` output
 | 
						|
            # `cfloat` input -> `float` output
 | 
						|
            self.assertEqual(actual, expected, exact_dtype=False)
 | 
						|
 | 
						|
 | 
						|
    @ops(op_db, allowed_dtypes=(torch.bool,))
 | 
						|
    @unittest.skipIf(TEST_WITH_UBSAN, "Test uses undefined behavior")
 | 
						|
    def test_non_standard_bool_values(self, device, dtype, op):
 | 
						|
        # Test boolean values other than 0x00 and 0x01 (gh-54789)
 | 
						|
        def convert_boolean_tensors(x):
 | 
						|
            if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
 | 
						|
                return x
 | 
						|
 | 
						|
            # Map False -> 0 and True -> Random value in [2, 255]
 | 
						|
            true_vals = torch.randint(2, 255, x.shape, dtype=torch.uint8, device=x.device)
 | 
						|
            false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
 | 
						|
            x_int = torch.where(x, true_vals, false_vals)
 | 
						|
 | 
						|
            ret = x_int.view(torch.bool)
 | 
						|
            self.assertEqual(ret, x)
 | 
						|
            return ret
 | 
						|
 | 
						|
        for sample in op.sample_inputs(device, dtype):
 | 
						|
            expect = op(sample.input, *sample.args, **sample.kwargs)
 | 
						|
 | 
						|
            transformed = sample.transform(convert_boolean_tensors)
 | 
						|
            actual = op(transformed.input, *transformed.args, **transformed.kwargs)
 | 
						|
 | 
						|
            self.assertEqual(expect, actual)
 | 
						|
 | 
						|
    # Validates that each OpInfo specifies its forward and backward dtypes
 | 
						|
    #   correctly for CPU and CUDA devices
 | 
						|
    @skipMeta
 | 
						|
    @onlyNativeDeviceTypes
 | 
						|
    @ops(ops_and_refs, dtypes=OpDTypes.none)
 | 
						|
    def test_dtypes(self, device, op):
 | 
						|
        # Check complex32 support only if the op claims.
 | 
						|
        # TODO: Once the complex32 support is better, we should add check for complex32 unconditionally.
 | 
						|
        device_type = torch.device(device).type
 | 
						|
        include_complex32 = (
 | 
						|
            (torch.complex32,)
 | 
						|
            if op.supports_dtype(torch.complex32, device_type)
 | 
						|
            else ()
 | 
						|
        )
 | 
						|
 | 
						|
        # dtypes to try to backward in
 | 
						|
        allowed_backward_dtypes = floating_and_complex_types_and(
 | 
						|
            *((torch.half, torch.bfloat16) + include_complex32)
 | 
						|
        )
 | 
						|
 | 
						|
        # lists for (un)supported dtypes
 | 
						|
        supported_dtypes = set()
 | 
						|
        unsupported_dtypes = set()
 | 
						|
        supported_backward_dtypes = set()
 | 
						|
        unsupported_backward_dtypes = set()
 | 
						|
        dtype_error: Dict[torch.dtype, Exception] = dict()
 | 
						|
 | 
						|
        def unsupported(dtype, e):
 | 
						|
            dtype_error[dtype] = e
 | 
						|
            unsupported_dtypes.add(dtype)
 | 
						|
            if dtype in allowed_backward_dtypes:
 | 
						|
                unsupported_backward_dtypes.add(dtype)
 | 
						|
 | 
						|
        for dtype in all_types_and_complex_and(
 | 
						|
            *((torch.half, torch.bfloat16, torch.bool) + include_complex32)
 | 
						|
        ):
 | 
						|
            # tries to acquire samples - failure indicates lack of support
 | 
						|
            requires_grad = dtype in allowed_backward_dtypes
 | 
						|
            try:
 | 
						|
                samples = tuple(
 | 
						|
                    op.sample_inputs(device, dtype, requires_grad=requires_grad)
 | 
						|
                )
 | 
						|
            except Exception as e:
 | 
						|
                unsupported(dtype, e)
 | 
						|
                continue
 | 
						|
 | 
						|
            for sample in samples:
 | 
						|
                # tries to call operator with the sample - failure indicates
 | 
						|
                #   lack of support
 | 
						|
                try:
 | 
						|
                    result = op(sample.input, *sample.args, **sample.kwargs)
 | 
						|
                    supported_dtypes.add(dtype)
 | 
						|
                except Exception as e:
 | 
						|
                    # NOTE: some ops will fail in forward if their inputs
 | 
						|
                    #   require grad but they don't support computing the gradient
 | 
						|
                    #   in that type! This is a bug in the op!
 | 
						|
                    unsupported(dtype, e)
 | 
						|
                    continue
 | 
						|
 | 
						|
                # Checks for backward support in the same dtype, if the input has
 | 
						|
                # one or more tensors requiring grad
 | 
						|
                def _tensor_requires_grad(x):
 | 
						|
                    if isinstance(x, dict):
 | 
						|
                        for v in x.values():
 | 
						|
                            if _tensor_requires_grad(v):
 | 
						|
                                return True
 | 
						|
                    if isinstance(x, (list, tuple)):
 | 
						|
                        for a in x:
 | 
						|
                            if _tensor_requires_grad(a):
 | 
						|
                                return True
 | 
						|
                    if isinstance(x, torch.Tensor) and x.requires_grad:
 | 
						|
                        return True
 | 
						|
 | 
						|
                    return False
 | 
						|
 | 
						|
                requires_grad = _tensor_requires_grad(sample.input) \
 | 
						|
                    or _tensor_requires_grad(sample.args) or _tensor_requires_grad(sample.kwargs)
 | 
						|
                if not requires_grad:
 | 
						|
                    continue
 | 
						|
 | 
						|
                try:
 | 
						|
                    result = sample.output_process_fn_grad(result)
 | 
						|
                    if isinstance(result, torch.Tensor):
 | 
						|
                        backward_tensor = result
 | 
						|
                    elif isinstance(result, Sequence) and isinstance(
 | 
						|
                        result[0], torch.Tensor
 | 
						|
                    ):
 | 
						|
                        backward_tensor = result[0]
 | 
						|
                    else:
 | 
						|
                        continue
 | 
						|
 | 
						|
                    # Note: this grad may not have the same dtype as dtype
 | 
						|
                    # For functions like complex (float -> complex) or abs
 | 
						|
                    #   (complex -> float) the grad tensor will have a
 | 
						|
                    #   different dtype than the input.
 | 
						|
                    #   For simplicity, this is still modeled as these ops
 | 
						|
                    #   supporting grad in the input dtype.
 | 
						|
                    grad = torch.randn_like(backward_tensor)
 | 
						|
                    backward_tensor.backward(grad)
 | 
						|
                    supported_backward_dtypes.add(dtype)
 | 
						|
                except Exception as e:
 | 
						|
                    dtype_error[dtype] = e
 | 
						|
                    unsupported_backward_dtypes.add(dtype)
 | 
						|
 | 
						|
        # Checks that dtypes are listed correctly and generates an informative
 | 
						|
        #   error message
 | 
						|
 | 
						|
        supported_forward = supported_dtypes - unsupported_dtypes
 | 
						|
        partially_supported_forward = supported_dtypes & unsupported_dtypes
 | 
						|
        unsupported_forward = unsupported_dtypes - supported_dtypes
 | 
						|
        supported_backward = supported_backward_dtypes - unsupported_backward_dtypes
 | 
						|
        partially_supported_backward = (
 | 
						|
            supported_backward_dtypes & unsupported_backward_dtypes
 | 
						|
        )
 | 
						|
        unsupported_backward = unsupported_backward_dtypes - supported_backward_dtypes
 | 
						|
 | 
						|
        device_type = torch.device(device).type
 | 
						|
 | 
						|
        claimed_forward = set(op.supported_dtypes(device_type))
 | 
						|
        supported_but_unclaimed_forward = supported_forward - claimed_forward
 | 
						|
        claimed_but_unsupported_forward = claimed_forward & unsupported_forward
 | 
						|
 | 
						|
        claimed_backward = set(op.supported_backward_dtypes(device_type))
 | 
						|
        supported_but_unclaimed_backward = supported_backward - claimed_backward
 | 
						|
        claimed_but_unsupported_backward = claimed_backward & unsupported_backward
 | 
						|
 | 
						|
        # Partially supporting a dtype is not an error, but we print a warning
 | 
						|
        if (len(partially_supported_forward) + len(partially_supported_backward)) > 0:
 | 
						|
            msg = f"Some dtypes for {op.name} on device type {device_type} are only partially supported!\n"
 | 
						|
            if len(partially_supported_forward) > 0:
 | 
						|
                msg = (
 | 
						|
                    msg
 | 
						|
                    + "The following dtypes only worked on some samples during forward: {}.\n".format(
 | 
						|
                        partially_supported_forward
 | 
						|
                    )
 | 
						|
                )
 | 
						|
            if len(partially_supported_backward) > 0:
 | 
						|
                msg = (
 | 
						|
                    msg
 | 
						|
                    + "The following dtypes only worked on some samples during backward: {}.\n".format(
 | 
						|
                        partially_supported_backward
 | 
						|
                    )
 | 
						|
                )
 | 
						|
            print(msg)
 | 
						|
 | 
						|
        if (
 | 
						|
            len(supported_but_unclaimed_forward)
 | 
						|
            + len(claimed_but_unsupported_forward)
 | 
						|
            + len(supported_but_unclaimed_backward)
 | 
						|
            + len(claimed_but_unsupported_backward)
 | 
						|
        ) == 0:
 | 
						|
            return
 | 
						|
 | 
						|
        # Reference operators often support additional dtypes, and that's OK
 | 
						|
        if op in python_ref_db:
 | 
						|
            if (
 | 
						|
                len(claimed_but_unsupported_forward)
 | 
						|
                + len(claimed_but_unsupported_backward)
 | 
						|
            ) == 0:
 | 
						|
                return
 | 
						|
 | 
						|
        # Generates error msg
 | 
						|
        msg = f"The supported dtypes for {op.name} on device type {device_type} are incorrect!\n"
 | 
						|
        if len(supported_but_unclaimed_forward) > 0:
 | 
						|
            msg = (
 | 
						|
                msg
 | 
						|
                + "The following dtypes worked in forward but are not listed by the OpInfo: {}.\n".format(
 | 
						|
                    supported_but_unclaimed_forward
 | 
						|
                )
 | 
						|
            )
 | 
						|
        if len(supported_but_unclaimed_backward) > 0:
 | 
						|
            msg = (
 | 
						|
                msg
 | 
						|
                + "The following dtypes worked in backward but are not listed by the OpInfo: {}.\n".format(
 | 
						|
                    supported_but_unclaimed_backward
 | 
						|
                )
 | 
						|
            )
 | 
						|
        if len(claimed_but_unsupported_forward) > 0:
 | 
						|
            msg = (
 | 
						|
                msg
 | 
						|
                + "The following dtypes did not work in forward but are listed by the OpInfo: {}.\n".format(
 | 
						|
                    claimed_but_unsupported_forward
 | 
						|
                )
 | 
						|
            )
 | 
						|
        if len(claimed_but_unsupported_backward) > 0:
 | 
						|
            msg = (
 | 
						|
                msg
 | 
						|
                + "The following dtypes did not work in backward but are listed by the OpInfo: {}.\n".format(
 | 
						|
                    claimed_but_unsupported_backward
 | 
						|
                )
 | 
						|
            )
 | 
						|
 | 
						|
        all_claimed_but_unsupported = set.union(claimed_but_unsupported_backward, claimed_but_unsupported_forward)
 | 
						|
        if all_claimed_but_unsupported:
 | 
						|
            msg += "Unexpected failures raised the following errors:\n"
 | 
						|
            for dtype in all_claimed_but_unsupported:
 | 
						|
                msg += f"{dtype} - {dtype_error[dtype]}\n"
 | 
						|
 | 
						|
        self.fail(msg)
 | 
						|
 | 
						|
 | 
						|
class TestCompositeCompliance(TestCase):
 | 
						|
    # Checks if the operator (if it is composite) is written to support most
 | 
						|
    # backends and Tensor subclasses. See "CompositeImplicitAutograd Compliance"
 | 
						|
    # in aten/src/ATen/native/README.md for more details
 | 
						|
    @unittest.skipIf(
 | 
						|
        IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
 | 
						|
    )
 | 
						|
    @ops(op_db, allowed_dtypes=(torch.float,))
 | 
						|
    def test_operator(self, device, dtype, op):
 | 
						|
        samples = op.sample_inputs(device, dtype, requires_grad=False)
 | 
						|
 | 
						|
        for sample in samples:
 | 
						|
            args = [sample.input] + list(sample.args)
 | 
						|
            kwargs = sample.kwargs
 | 
						|
            composite_compliance.check_with_mode(op, args, kwargs, self.assertEqual)
 | 
						|
            composite_compliance.check_all_permutations(op, args, kwargs, self.assertEqual)
 | 
						|
 | 
						|
    @unittest.skipIf(
 | 
						|
        IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
 | 
						|
    )
 | 
						|
    @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
 | 
						|
    def test_backward(self, device, dtype, op):
 | 
						|
        samples = op.sample_inputs(device, dtype, requires_grad=True)
 | 
						|
 | 
						|
        for sample in samples:
 | 
						|
            args = [sample.input] + list(sample.args)
 | 
						|
            kwargs = sample.kwargs
 | 
						|
            # We pass assertEqual so that decorators like `toleranceOverride`
 | 
						|
            # actually work (otherwise they silently do nothing!)
 | 
						|
            composite_compliance.check_backward_formula(
 | 
						|
                op.get_op(), args, kwargs,
 | 
						|
                sample.output_process_fn_grad,
 | 
						|
                op.gradcheck_wrapper, self.assertEqual)
 | 
						|
 | 
						|
    @unittest.skipIf(
 | 
						|
        IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
 | 
						|
    )
 | 
						|
    @ops(op_db, allowed_dtypes=(torch.float,))
 | 
						|
    def test_forward_ad(self, device, dtype, op):
 | 
						|
        if torch.float not in op.supported_backward_dtypes(device):
 | 
						|
            raise unittest.SkipTest("Does not support autograd")
 | 
						|
 | 
						|
        if not op.supports_forward_ad:
 | 
						|
            raise unittest.SkipTest("Does not support forward_ad")
 | 
						|
 | 
						|
        samples = op.sample_inputs(device, dtype, requires_grad=True)
 | 
						|
 | 
						|
        for sample in samples:
 | 
						|
            args = [sample.input] + list(sample.args)
 | 
						|
            kwargs = sample.kwargs
 | 
						|
            # We pass assertEqual so that decorators like `toleranceOverride`
 | 
						|
            # actually work (otherwise they silently do nothing!)
 | 
						|
            composite_compliance.check_forward_ad_formula(
 | 
						|
                op.get_op(), args, kwargs, op.gradcheck_wrapper, self.assertEqual)
 | 
						|
 | 
						|
 | 
						|
class TestMathBits(TestCase):
 | 
						|
    # Tests that
 | 
						|
    # 1. The operator's output for physically conjugated/negated tensors and conjugate/negative view tensors
 | 
						|
    # produces the same value
 | 
						|
    # 2. The gradients are same in both cases mentioned in (1)
 | 
						|
    # 3. If the operator's inplace variant is supported, tests that the inplace operation
 | 
						|
    #    produces the correct value when called on a conjugate/negative view tensor and that the output
 | 
						|
    #    has its conj/neg bit set to true
 | 
						|
    # This test only runs for C -> R and C -> C functions
 | 
						|
    # TODO: add tests for `R->C` functions
 | 
						|
    # Note: This test runs for functions that take both tensors and tensorlists as input.
 | 
						|
    def _test_math_view(
 | 
						|
        self,
 | 
						|
        device,
 | 
						|
        dtype,
 | 
						|
        op,
 | 
						|
        samples,
 | 
						|
        math_op_physical,
 | 
						|
        math_op_view,
 | 
						|
        is_bit_set,
 | 
						|
        out_type,
 | 
						|
    ):
 | 
						|
        inplace_variant = op.inplace_variant
 | 
						|
 | 
						|
        # helper function to clone and conjugate/negate the input if its a tensor
 | 
						|
        # else clone the sequence and conjugate/negate the first element in the sequence
 | 
						|
        # If a requires_grad argument is provided the tensor being conjugated/negated will
 | 
						|
        # have its requires_grad set to that value.
 | 
						|
        def clone_and_perform_view(input, **kwargs):
 | 
						|
            if isinstance(input, torch.Tensor):
 | 
						|
                requires_grad = kwargs.get("requires_grad", input.requires_grad)
 | 
						|
                with torch.no_grad():
 | 
						|
                    # Ensure view represents the original sample input
 | 
						|
                    input = math_op_physical(input)
 | 
						|
                # Note: .conj() is not called under no_grad mode since it's not allowed to modify a
 | 
						|
                # view created in no_grad mode. Here it's ok to do so, so as a workaround we call conj
 | 
						|
                # before resetting the requires_grad field for input
 | 
						|
                input = math_op_view(input)
 | 
						|
                assert input.is_leaf
 | 
						|
                return input.requires_grad_(requires_grad)
 | 
						|
 | 
						|
            if isinstance(input, Sequence):
 | 
						|
                out = list(map(clone_input_helper, input))
 | 
						|
                out[0] = clone_and_perform_view(out[0])
 | 
						|
                return tuple(out)
 | 
						|
 | 
						|
        for sample in samples:
 | 
						|
            tensor = (
 | 
						|
                sample.input
 | 
						|
                if isinstance(sample.input, torch.Tensor)
 | 
						|
                else sample.input[0]
 | 
						|
            )
 | 
						|
            cloned1 = clone_and_perform_view(sample.input)
 | 
						|
 | 
						|
            # Computes function forward value with a physically conjugated/negated tensor and
 | 
						|
            # a conj/neg view tensor and verifies that the output in both case are equal.
 | 
						|
            expected_forward = op(sample.input, *sample.args, **sample.kwargs)
 | 
						|
            forward_with_mathview = op(cloned1, *sample.args, **sample.kwargs)
 | 
						|
            self.assertEqual(expected_forward, forward_with_mathview)
 | 
						|
 | 
						|
            # If the op has an inplace variant, and the input doesn't require broadcasting
 | 
						|
            # and has the same dtype as output, verify that the inplace operation on a conjugated/negated
 | 
						|
            # input produces correct output, and the output tensor has the conj/neg bit set to True
 | 
						|
            if inplace_variant is not None and not sample.broadcasts_input:
 | 
						|
                cloned2 = clone_and_perform_view(tensor, requires_grad=False)
 | 
						|
                if (
 | 
						|
                    isinstance(expected_forward, torch.Tensor)
 | 
						|
                    and expected_forward.dtype is tensor.dtype
 | 
						|
                ):
 | 
						|
                    inplace_forward = inplace_variant(
 | 
						|
                        cloned2, *sample.args, **sample.kwargs
 | 
						|
                    )
 | 
						|
                    self.assertTrue(is_bit_set(inplace_forward))
 | 
						|
                    self.assertEqual(inplace_forward, expected_forward)
 | 
						|
 | 
						|
            # TODO: backward consistency only supported for single tensor outputs
 | 
						|
            # TODO: backward consistency only checked on sample.input, not all
 | 
						|
            #   tensor inputs
 | 
						|
            # TODO: update to handle checking grads of all tensor inputs as
 | 
						|
            #   derived from each tensor output
 | 
						|
            if (
 | 
						|
                isinstance(expected_forward, torch.Tensor)
 | 
						|
                and expected_forward.requires_grad
 | 
						|
            ):
 | 
						|
                output_process_fn_grad = sample.output_process_fn_grad or (lambda x: x)
 | 
						|
                expected_forward = output_process_fn_grad(expected_forward)
 | 
						|
                forward_with_mathview = output_process_fn_grad(forward_with_mathview)
 | 
						|
 | 
						|
                tensor = (
 | 
						|
                    sample.input
 | 
						|
                    if isinstance(sample.input, torch.Tensor)
 | 
						|
                    else sample.input[0]
 | 
						|
                )
 | 
						|
                expected_forward.sum().abs().backward(retain_graph=True)
 | 
						|
                forward_with_mathview.sum().abs().backward(retain_graph=True)
 | 
						|
                if tensor.grad is not None:
 | 
						|
                    cloned1_tensor = (
 | 
						|
                        cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0]
 | 
						|
                    )
 | 
						|
                    self.assertEqual(tensor.grad, cloned1_tensor.grad)
 | 
						|
 | 
						|
                    tensor.grad, cloned1_tensor.grad = None, None
 | 
						|
 | 
						|
                    # a repeat of the above test if output is not complex valued
 | 
						|
                    if out_type(expected_forward):
 | 
						|
                        grad = torch.randn_like(expected_forward)
 | 
						|
                        expected_forward.backward(grad)
 | 
						|
                        forward_with_mathview.backward(
 | 
						|
                            math_op_view(math_op_physical(grad))
 | 
						|
                        )
 | 
						|
 | 
						|
                        self.assertEqual(tensor.grad, cloned1_tensor.grad)
 | 
						|
 | 
						|
    @ops(ops_and_refs, allowed_dtypes=(torch.cfloat,))
 | 
						|
    def test_conj_view(self, device, dtype, op):
 | 
						|
        if not op.test_conjugated_samples:
 | 
						|
            self.skipTest("Operation doesn't support conjugated inputs.")
 | 
						|
        math_op_physical = torch.conj_physical
 | 
						|
        math_op_view = torch.conj
 | 
						|
        _requires_grad = torch.cfloat in op.supported_backward_dtypes(
 | 
						|
            torch.device(device).type
 | 
						|
        )
 | 
						|
        is_bit_set = torch.is_conj
 | 
						|
        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
 | 
						|
        self._test_math_view(
 | 
						|
            device,
 | 
						|
            dtype,
 | 
						|
            op,
 | 
						|
            samples,
 | 
						|
            math_op_physical,
 | 
						|
            math_op_view,
 | 
						|
            is_bit_set,
 | 
						|
            torch.is_complex,
 | 
						|
        )
 | 
						|
 | 
						|
    @ops(ops_and_refs, allowed_dtypes=(torch.double,))
 | 
						|
    def test_neg_view(self, device, dtype, op):
 | 
						|
        if not op.test_neg_view:
 | 
						|
            self.skipTest("Operation not tested with tensors with negative bit.")
 | 
						|
        math_op_physical = torch.neg
 | 
						|
        math_op_view = torch._neg_view
 | 
						|
        is_bit_set = torch.is_neg
 | 
						|
        samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)
 | 
						|
        self._test_math_view(
 | 
						|
            device,
 | 
						|
            dtype,
 | 
						|
            op,
 | 
						|
            samples,
 | 
						|
            math_op_physical,
 | 
						|
            math_op_view,
 | 
						|
            is_bit_set,
 | 
						|
            lambda x: True,
 | 
						|
        )
 | 
						|
 | 
						|
    @ops(ops_and_refs, allowed_dtypes=(torch.cdouble,))
 | 
						|
    def test_neg_conj_view(self, device, dtype, op):
 | 
						|
        if not op.test_neg_view:
 | 
						|
            self.skipTest("Operation not tested with tensors with negative bit.")
 | 
						|
        if not op.test_conjugated_samples:
 | 
						|
            self.skipTest("Operation doesn't support conjugated inputs.")
 | 
						|
 | 
						|
        def math_op_physical(x):
 | 
						|
            return -x.conj_physical()
 | 
						|
 | 
						|
        def math_op_view(x):
 | 
						|
            return torch._neg_view(x).conj()
 | 
						|
 | 
						|
        def is_bit_set(x):
 | 
						|
            return torch.is_neg(x) and torch.is_conj(x)
 | 
						|
 | 
						|
        _requires_grad = dtype in op.supported_backward_dtypes(
 | 
						|
            torch.device(device).type
 | 
						|
        )
 | 
						|
        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
 | 
						|
        # Only test one sample
 | 
						|
        samples = itertools.islice(samples, 1)
 | 
						|
        self._test_math_view(
 | 
						|
            device,
 | 
						|
            dtype,
 | 
						|
            op,
 | 
						|
            samples,
 | 
						|
            math_op_physical,
 | 
						|
            math_op_view,
 | 
						|
            is_bit_set,
 | 
						|
            torch.is_complex,
 | 
						|
        )
 | 
						|
 | 
						|
# input strides and size may have been altered due to the result of an inplace op
 | 
						|
def check_inplace_view(func, input, rs, input_size, input_strides):
 | 
						|
    if func is None:
 | 
						|
        return
 | 
						|
    # TODO: extend this test to test ops with multiple outputs and ops like native_batch_norm(_legit).out
 | 
						|
    # which mutate not necessarily the first input.
 | 
						|
    if isinstance(rs, torch.Tensor) and rs is input:
 | 
						|
        unequal_size = rs.size() != input_size
 | 
						|
        unequal_strides = rs.stride() != input_strides
 | 
						|
        # resize_ should probably have inplace_view tag. Not adding the tag since it
 | 
						|
        # breaks some codegen logic
 | 
						|
        if (unequal_size or unequal_strides):
 | 
						|
            if isinstance(func, torch._ops.OpOverloadPacket):
 | 
						|
                func = func.default
 | 
						|
            # Reference: https://github.com/pytorch/pytorch/issues/78759
 | 
						|
            if func is not torch.ops.aten.resize_.default:
 | 
						|
                # TODO: use self.assertIn when we have separate tests for each tag
 | 
						|
                assert torch.Tag.inplace_view in func.tags
 | 
						|
 | 
						|
# A mode that when enabled runs correctness checks to ensure
 | 
						|
# that operators have expected tags based on their input and
 | 
						|
# ouput tensor properties
 | 
						|
class TestTagsMode(TorchDispatchMode):
 | 
						|
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
 | 
						|
        if isinstance(args[0], torch.Tensor):
 | 
						|
            old_size = args[0].size()
 | 
						|
            old_stride = args[0].stride()
 | 
						|
            rs = func(*args, **kwargs)
 | 
						|
            check_inplace_view(func, args[0], rs, old_size, old_stride)
 | 
						|
        else:
 | 
						|
            rs = func(*args, **kwargs)
 | 
						|
        return rs
 | 
						|
 | 
						|
# Test to verify the correctness for tags in `tags.yaml`, also available for access through `torch.Tags`
 | 
						|
class TestTags(TestCase):
 | 
						|
    @onlyCPU
 | 
						|
    @ops(ops_and_refs, dtypes=OpDTypes.any_one)
 | 
						|
    def test_tags(self, device, dtype, op):
 | 
						|
        samples = op.sample_inputs(device, dtype, requires_grad=False)
 | 
						|
        for sample in samples:
 | 
						|
            # TODO: Test tags for ops that return a list of tensors
 | 
						|
            input = sample.input
 | 
						|
            if isinstance(input, torch.Tensor):
 | 
						|
                old_size = input.size()
 | 
						|
                old_stride = input.stride()
 | 
						|
                with TestTagsMode():
 | 
						|
                    rs = op(input, *sample.args, **sample.kwargs)
 | 
						|
                # TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761
 | 
						|
                aten_name = op.aten_name if op.aten_name is not None else op.name
 | 
						|
                opoverloadpacket = getattr(torch.ops.aten, aten_name, None)
 | 
						|
                check_inplace_view(opoverloadpacket, input, rs, old_size, old_stride)
 | 
						|
 | 
						|
 | 
						|
class TestRefsOpsInfo(TestCase):
 | 
						|
 | 
						|
    import_paths = ["_refs", "_refs.special", "_refs.nn.functional", "_refs.fft", "_refs._conversions"]
 | 
						|
    module_alls = [(path, import_module(f"torch.{path}").__all__) for path in import_paths]
 | 
						|
    ref_ops_names = tuple(itertools.chain.from_iterable(
 | 
						|
        [f"{path}.{op}" for op in module_all] for path, module_all in module_alls))
 | 
						|
    ref_db_names = {ref_op.name for ref_op in python_ref_db}
 | 
						|
 | 
						|
    # TODO: References that do not have an entry in python_ref_db
 | 
						|
    skip_ref_ops = {
 | 
						|
        '_refs.alias',
 | 
						|
        '_refs.bitwise_right_shift',
 | 
						|
        '_refs.copy_to',
 | 
						|
        '_refs.empty_permuted',
 | 
						|
        '_refs.empty_strided',
 | 
						|
        '_refs.equal',
 | 
						|
        '_refs.full',
 | 
						|
        '_refs.full_like',
 | 
						|
        '_refs.is_complex',
 | 
						|
        '_refs.to',
 | 
						|
        '_refs.mvlgamma',
 | 
						|
        '_refs.ones',
 | 
						|
        '_refs.ones_like',
 | 
						|
        '_refs.special.expit',
 | 
						|
        '_refs.std_var',
 | 
						|
        '_refs.swap_axes',
 | 
						|
        '_refs.uniform',
 | 
						|
        '_refs.scalar_tensor',
 | 
						|
        '_refs.trunc_divide',
 | 
						|
        '_refs.zero',
 | 
						|
        '_refs.zeros',
 | 
						|
        '_refs.zeros_like',
 | 
						|
        '_refs.rfloordiv',
 | 
						|
        '_refs.rtruediv',
 | 
						|
        '_refs.rpow',
 | 
						|
        # These should be tested with their out-of-place counterparts
 | 
						|
        '_refs.index_add_',
 | 
						|
        '_refs.index_copy_',
 | 
						|
        '_refs.index_fill_',
 | 
						|
        '_refs.native_group_norm',
 | 
						|
    }
 | 
						|
 | 
						|
    not_in_decomp_table = {
 | 
						|
        # duplicated in _decomp and _refs
 | 
						|
        '_refs.nn.functional.group_norm',
 | 
						|
        '_refs.nn.functional.mse_loss',
 | 
						|
        '_refs.floor_divide',
 | 
						|
        '_refs.rsub',
 | 
						|
        # duplicated as refs do not have decent support for advanced indexing
 | 
						|
        '_refs.index_copy',
 | 
						|
        '_refs.index_copy_',
 | 
						|
        '_refs.index_add',
 | 
						|
        '_refs.index_add_',
 | 
						|
        # these are not aten ops?
 | 
						|
        '_refs._conversions.bfloat16',
 | 
						|
        '_refs._conversions.bool',
 | 
						|
        '_refs._conversions.byte',
 | 
						|
        '_refs._conversions.char',
 | 
						|
        '_refs._conversions.double',
 | 
						|
        '_refs._conversions.float',
 | 
						|
        '_refs._conversions.half',
 | 
						|
        '_refs._conversions.int',
 | 
						|
        '_refs._conversions.long',
 | 
						|
        '_refs._conversions.short',
 | 
						|
        '_refs._conversions.chalf',
 | 
						|
        '_refs._conversions.cfloat',
 | 
						|
        '_refs._conversions.cdouble',
 | 
						|
        '_refs.broadcast_shapes',
 | 
						|
        '_refs.broadcast_tensors',
 | 
						|
        '_refs.mvlgamma',
 | 
						|
        '_refs.nn.functional.layer_norm',
 | 
						|
        '_refs.nn.functional.tanhshrink',
 | 
						|
        '_refs.nn.functional.triplet_margin_loss',
 | 
						|
        '_refs.rfloordiv',
 | 
						|
        '_refs.rtruediv',
 | 
						|
        '_refs.rpow',
 | 
						|
        # CompositeImplicitAutograd
 | 
						|
        '_refs.allclose',
 | 
						|
        '_refs.atleast_1d',
 | 
						|
        '_refs.atleast_2d',
 | 
						|
        '_refs.atleast_3d',
 | 
						|
        '_refs.broadcast_to',
 | 
						|
        '_refs.chunk',
 | 
						|
        '_refs.column_stack',
 | 
						|
        '_refs.contiguous',
 | 
						|
        '_refs.dsplit',
 | 
						|
        '_refs.dstack',
 | 
						|
        '_refs.fill',
 | 
						|
        '_refs.fill_',
 | 
						|
        '_refs.flatten',
 | 
						|
        '_refs.fliplr',
 | 
						|
        '_refs.flipud',
 | 
						|
        '_refs.float_power',
 | 
						|
        '_refs.hsplit',
 | 
						|
        '_refs.hstack',
 | 
						|
        '_refs.isclose',
 | 
						|
        '_refs.isfinite',
 | 
						|
        '_refs.isreal',
 | 
						|
        '_refs.istft',
 | 
						|
        '_refs.log_softmax',
 | 
						|
        '_refs.movedim',
 | 
						|
        '_refs.narrow',
 | 
						|
        '_refs.nn.functional.dropout',
 | 
						|
        '_refs.nn.functional.l1_loss',
 | 
						|
        '_refs.nn.functional.smooth_l1_loss',
 | 
						|
        '_refs.nn.functional.log_softmax',
 | 
						|
        '_refs.nn.functional.poisson_nll_loss',
 | 
						|
        '_refs.nn.functional.softmax',
 | 
						|
        '_refs.nn.functional.softmin',
 | 
						|
        '_refs.positive',
 | 
						|
        '_refs.ravel',
 | 
						|
        '_refs.reshape',
 | 
						|
        '_refs.softmax',
 | 
						|
        '_refs.special.expit',
 | 
						|
        '_refs.special.log_softmax',
 | 
						|
        '_refs.special.softmax',
 | 
						|
        '_refs.square',
 | 
						|
        '_refs.stft',
 | 
						|
        '_refs.T',
 | 
						|
        '_refs.take_along_dim',
 | 
						|
        '_refs.tensor_split',
 | 
						|
        '_refs.to',
 | 
						|
        '_refs.true_divide',
 | 
						|
        '_refs.trunc_divide',
 | 
						|
        '_refs.vsplit',
 | 
						|
        '_refs.vstack',
 | 
						|
        '_refs.linalg.matrix_norm',
 | 
						|
        '_refs.linalg.norm',
 | 
						|
        '_refs.linalg.svd',
 | 
						|
        '_refs.linalg.svdvals',
 | 
						|
        '_refs.unflatten',
 | 
						|
        '_refs.sum_to_size',
 | 
						|
        # ref implementation missing kwargs
 | 
						|
        '_refs.full_like',  # missing "layout"
 | 
						|
        '_refs.round',  # missing "decimals"
 | 
						|
        '_refs.scalar_tensor',  # missing "layout"
 | 
						|
        # other
 | 
						|
        '_refs.empty',  # intentional; direct empty is faster and has less guards
 | 
						|
        '_refs.empty_permuted',  # intentional; direct empty is faster and has less guards
 | 
						|
        '_refs.expand_as',
 | 
						|
        '_refs.as_strided',  # _prims._as_strided_meta: "reduce() of empty sequence with no initial value"
 | 
						|
        '_refs.copy_to',  # torch._C._jit_get_operation: No such operator aten::copy_to
 | 
						|
        '_refs.equal',  # 'bool' object has no attribute 'dtype'
 | 
						|
        '_refs.conj',  # Calls _prims.conj
 | 
						|
        '_refs.real',
 | 
						|
        '_refs.imag',
 | 
						|
        '_refs.reshape_as',
 | 
						|
        '_refs.view_as',
 | 
						|
        '_refs.view_as_complex'  # TorchInductor does not support complex at the moment.
 | 
						|
    }
 | 
						|
 | 
						|
    @parametrize("op", ref_ops_names)
 | 
						|
    def test_refs_are_in_python_ref_db(self, op):
 | 
						|
        inplace = op[-1] == "_"
 | 
						|
        if op in self.skip_ref_ops:
 | 
						|
            raise unittest.SkipTest(f"{op} does not have an entry in python_ref_db")
 | 
						|
        elif inplace:
 | 
						|
            self.assertNotIn(op, self.ref_db_names, msg=f"{op} is an in-place operation and should not have an OpInfo")
 | 
						|
        else:
 | 
						|
            # Intentionally don't use assertIn to avoid printing the
 | 
						|
            # (very large) container
 | 
						|
            self.assertTrue(op in self.ref_db_names, msg=f"{op} not in ref_db_names")
 | 
						|
 | 
						|
    @parametrize("op", ref_ops_names)
 | 
						|
    def test_refs_are_in_decomp_table(self, op):
 | 
						|
        path = op.split('.')
 | 
						|
        module_path = '.'.join(path[:-1])
 | 
						|
        op_name = path[-1]
 | 
						|
        op_impl = getattr(import_module(f"torch.{module_path}"), op_name)
 | 
						|
 | 
						|
        if op in self.not_in_decomp_table:
 | 
						|
            self.assertNotIn(op_impl, torch._decomp.decomposition_table.values(),
 | 
						|
                             f"Unexpectedly found {op} in torch._decomp.decomposition_table.values()")
 | 
						|
        else:
 | 
						|
            self.assertIn(op_impl, torch._decomp.decomposition_table.values(),
 | 
						|
                          f"Did not find {op} in torch._decomp.decomposition_table.values()")
 | 
						|
 | 
						|
 | 
						|
fake_skips = (
 | 
						|
    "aminmax",  # failing input
 | 
						|
    "cov",  # aweights cannot be negtaive
 | 
						|
    "istft",  # window overlap add min: 0
 | 
						|
    "linalg.eigvals",  # The tensor has a non-zero number of elements, but its data is not allocated yet
 | 
						|
    "linalg.eigvalsh",  # aten::linalg_eigvalsh.out' with arguments from the 'Meta' backend
 | 
						|
    "linalg.matrix_power",  # Could not run 'aten::eye.m_out' with arguments from the 'Meta' backend
 | 
						|
    # "linalg.pinv",  # Could not run 'aten::pinv.out' with arguments from the 'Meta' backen
 | 
						|
    "linalg.matrix_rank.hermitian",  # Could not run 'aten::linalg_eigvalsh.out' with arguments from the 'Meta' backend
 | 
						|
    "linalg.pinv.hermitian",  # tensor.mH is only supported on matrices or batches of matrices. Got 1-D tensor
 | 
						|
    "linalg.solve",  # Could not run 'aten::linalg_solve' with arguments from the 'Meta' backend
 | 
						|
    "linalg.tensorsolve",  # Could not run 'aten::linalg_solve' with arguments from the 'Meta'
 | 
						|
    "lu_solve",  # MALLOC ERROR: debug
 | 
						|
    "multinomial",  # Could not run 'aten::multinomial' with arguments from the 'Meta' backend
 | 
						|
    "mvlgamma.mvlgamma_p_1",  # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend
 | 
						|
    "mvlgamma.mvlgamma_p_3",  # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend
 | 
						|
    "mvlgamma.mvlgamma_p_5",  # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend
 | 
						|
    "nanmean",  # logical_not() got an unexpected keyword argument 'out'
 | 
						|
    "quantile",  # quantile() q values must be in the range [0, 1]
 | 
						|
    "nanquantile",  # quantile() q values must be in the range [0, 1]
 | 
						|
    "nn.functional.ctc_loss",  # The tensor has a non-zero number of elements, but its data is not allocated yet
 | 
						|
    "nn.functional.embedding_bag",  # sometimes errors
 | 
						|
    "nn.functional.nll_loss",  # sometimes errors
 | 
						|
    "nn.functional.max_pool1d",  # The tensor has a non-zero number of elements
 | 
						|
    "to_sparse",  # Could not run 'aten::_to_sparse' with arguments from the 'Meta' backend
 | 
						|
    "tensor_split",  # The tensor has a non-zero number of elements, but its data is not allocated yet
 | 
						|
    "repeat_interleave",  # cannot repeat_interleave a meta tensor without output_size
 | 
						|
    "sparse.sampled.addmm",  # sparsity not supported
 | 
						|
    # Can not infer total number of classes from meta. no way at present to throw DynamicOutputShapeException
 | 
						|
    "nn.functional.one_hot",
 | 
						|
    "narrow",  # Fails only for one overload with DataDependentOutputException (hence skip).
 | 
						|
)
 | 
						|
 | 
						|
fake_autocast_device_skips = defaultdict(dict)
 | 
						|
 | 
						|
# TODO: investigate/fix
 | 
						|
fake_autocast_device_skips["cpu"] = {"linalg.pinv"}
 | 
						|
 | 
						|
 | 
						|
dynamic_output_op_tests = (
 | 
						|
    "argwhere",
 | 
						|
    "bincount",
 | 
						|
    "combinations",
 | 
						|
    "linalg.lstsq",
 | 
						|
    "masked_select",
 | 
						|
    "nonzero",
 | 
						|
    "unique_consecutive",
 | 
						|
    "unique",
 | 
						|
    "linalg.lstsq.grad_oriented",
 | 
						|
)
 | 
						|
 | 
						|
# some inputs invoke dynamic output shape operators, some do not
 | 
						|
sometimes_dynamic_output_op_test = (
 | 
						|
    "__getitem__",
 | 
						|
    "index_select",
 | 
						|
)
 | 
						|
 | 
						|
data_dependent_op_tests = (
 | 
						|
    "equal",
 | 
						|
    "corrcoef",
 | 
						|
    "nn.functional.gaussian_nll_loss",
 | 
						|
    "allclose",
 | 
						|
)
 | 
						|
 | 
						|
aliasing_failures = (
 | 
						|
    "histogramdd",
 | 
						|
)
 | 
						|
 | 
						|
fake_backward_skips = {
 | 
						|
    "linalg.cond",
 | 
						|
    "linalg.matrix_norm",
 | 
						|
    "linalg.norm",
 | 
						|
    "linalg.svd",
 | 
						|
    "linalg.svdvals",
 | 
						|
    "pca_lowrank",
 | 
						|
    "roll",
 | 
						|
    "svd_lowrank",
 | 
						|
    "sgn",
 | 
						|
}
 | 
						|
 | 
						|
fake_backward_xfails = {skip(s) for s in fake_backward_skips} | {
 | 
						|
    xfail("fft.ihfftn"),  # Mismatch in aten._conj_physical.default
 | 
						|
    xfail("fft.ihfft2"),  # Mismatch in aten._conj_physical.default
 | 
						|
    skip('nn.functional.ctc_loss'),
 | 
						|
}
 | 
						|
 | 
						|
fake_autocast_backward_xfails = {
 | 
						|
    skip("nn.functional.binary_cross_entropy"),
 | 
						|
    skip("sparse.sampled_addmm"),
 | 
						|
    skip("linalg.pinv"),
 | 
						|
    skip("linalg.pinv", "hermitian"),
 | 
						|
    skip("linalg.pinv", "singular"),
 | 
						|
    skip('pinverse'),
 | 
						|
}
 | 
						|
 | 
						|
class TestFakeTensor(TestCase):
 | 
						|
    def _test_fake_helper(self, device, dtype, op, context):
        name = op.name
        if op.variant_test_name:
            name += "." + op.variant_test_name
        if name in fake_skips or "sparse" in name or "jiterator" in name:
            self.skipTest("Skip failing test")

        samples = op.sample_inputs(device, dtype, requires_grad=False)
        for sample in samples:
            try:
                mode = FakeTensorMode()

                def map_to_fake(e):
                    if isinstance(e, torch.Tensor):
                        return mode.from_tensor(e)
                    else:
                        return e

                input = tree_map(map_to_fake, sample.input)
                args = tree_map(map_to_fake, sample.args)
                kwargs = tree_map(map_to_fake, sample.kwargs)

                try:
                    with context():
                        res = op(sample.input, *sample.args, **sample.kwargs)
                except Exception:
                    continue

                with context():
                    with mode:
                        res_fake = op(input, *args, **kwargs)

                for fake_out, real_out in zip(
                    tree_flatten(res_fake)[0], tree_flatten(res)[0]
                ):
                    if not isinstance(fake_out, torch.Tensor):
                        self.assertTrue(not isinstance(real_out, torch.Tensor))
                        continue

                    self.assertTrue(isinstance(fake_out, FakeTensor))
                    # if you see a shape exception here, you may need to add
                    # a `dynamic_output_shape` tag to an operator

                    # prims/decomps must correctly model strides,
                    # see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
                    prims.utils.compare_tensor_meta(fake_out, real_out, True)

                    if name not in aliasing_failures:
                        fake_aliasing = outputs_alias_inputs((input, args, kwargs), res_fake)
                        real_aliasing = outputs_alias_inputs((sample.input, sample.args, sample.kwargs), res)
                        self.assertEqual(fake_aliasing, real_aliasing)

                self.assertTrue(name not in dynamic_output_op_tests and name not in data_dependent_op_tests)

            except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
                pass
            except torch._subclasses.fake_tensor.UnsupportedOperatorException:
                pass
            except torch._subclasses.fake_tensor.DynamicOutputShapeException:
                self.assertTrue(name in dynamic_output_op_tests or name in sometimes_dynamic_output_op_test)
            except torch._subclasses.fake_tensor.DataDependentOutputException:
                self.assertTrue(name in data_dependent_op_tests)

    @ops(op_db, dtypes=OpDTypes.any_one)
    def test_pointwise_ops(self, device, dtype, op):
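        # Check that ops carrying the pointwise tag really behave pointwise:
        # under FakeTensorMode, every tensor output should have the shape
        # obtained by broadcasting the tensor inputs' shapes.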
        name = op.name
        if op.variant_test_name:
            name += "." + op.variant_test_name
        if name in fake_skips or "sparse" in name or "jiterator" in name:
            self.skipTest("Skip failing test")

        test_self = self

        class TestPointwiseMode(TorchDispatchMode):
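            # Dispatch-mode hook: after each dispatched call whose op is tagged
            # pointwise, compare every tensor output's shape against the
            # broadcast of the tensor inputs' shapes.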
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                kwargs = kwargs or {}

                out = func(*args, **kwargs)

                if torch.Tag.pointwise in func.tags:
                    shapes = []
                    for inp in tree_flatten((args, kwargs))[0]:
                        if isinstance(inp, torch.Tensor):
                            shapes.append(inp.shape)

                    out_shape = torch._refs._broadcast_shapes(*shapes)

                    for out_elem in tree_flatten(out)[0]:
                        if isinstance(out_elem, torch.Tensor):
                            test_self.assertEqual(out_elem.shape, out_shape)

                return out

        samples = op.sample_inputs(device, dtype, requires_grad=False)
        for sample in samples:
            mode = FakeTensorMode()

            def map_to_fake(e):
                if isinstance(e, torch.Tensor):
                    return mode.from_tensor(e)
                else:
                    return e

            input = tree_map(map_to_fake, sample.input)
            args = tree_map(map_to_fake, sample.args)
            kwargs = tree_map(map_to_fake, sample.kwargs)

            try:
                op(input, *args, **kwargs)
            except Exception:
                continue

            with TestPointwiseMode():
                with mode:
                    op(input, *args, **kwargs)

    @ops(op_db, dtypes=OpDTypes.any_one)
    def test_fake(self, device, dtype, op):
        self._test_fake_helper(device, dtype, op, contextlib.nullcontext)

    @ops(op_db, dtypes=OpDTypes.any_one)
    def test_fake_autocast(self, device, dtype, op):
        if op.name in fake_autocast_device_skips[device]:
            self.skipTest("Skip failing test")
        context = torch.cuda.amp.autocast if device == "cuda" else torch.cpu.amp.autocast
        self._test_fake_helper(device, dtype, op, context)

    def _test_fake_crossref_helper(self, device, dtype, op, context):
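        # Run forward and backward under CrossRefFakeMode so that each real
        # ATen call is cross-checked against its FakeTensor counterpart,
        # including aliasing behavior.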
        samples = op.sample_inputs(device, dtype, requires_grad=True)

        for sample in samples:
            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs

            # skip these to speed up tests
            common_skip_ops = (
                aten.detach.default,
                aten.empty_strided.default,
                aten.copy_.default,
                aten.is_same_size.default,
            )

            # TODO: enable check_aliasing, batch norm fails
            try:
                with torch._subclasses.CrossRefFakeMode(ignore_op_fn=lambda fn: fn in common_skip_ops, check_aliasing=True):
                    with warnings.catch_warnings(), context(), torch.autograd.set_multithreading_enabled(False):
                        composite_compliance.compute_expected_grads(
                            op.get_op(), args, kwargs,
                            sample.output_process_fn_grad,
                            op.gradcheck_wrapper)
            except torch._subclasses.fake_tensor.UnsupportedOperatorException:
                pass

    @onlyCUDA
    @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
    @skipOps('TestFakeTensor', 'test_fake_crossref_backward_no_amp', fake_backward_xfails)
    def test_fake_crossref_backward_no_amp(self, device, dtype, op):
        self._test_fake_crossref_helper(device, dtype, op, contextlib.nullcontext)

    @onlyCUDA
    @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
    @skipOps('TestFakeTensor', 'test_fake_crossref_backward_amp', fake_backward_xfails | fake_autocast_backward_xfails)
    def test_fake_crossref_backward_amp(self, device, dtype, op):
        self._test_fake_crossref_helper(device, dtype, op, torch.cuda.amp.autocast)


instantiate_device_type_tests(TestCommon, globals())
instantiate_device_type_tests(TestCompositeCompliance, globals())
instantiate_device_type_tests(TestMathBits, globals())
instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu")
instantiate_device_type_tests(TestFakeTensor, globals())
instantiate_device_type_tests(TestTags, globals())

if __name__ == "__main__":
    TestCase._default_dtype_check_enabled = True
    run_tests()