mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Addresses several of the challenges described in https://github.com/pytorch/pytorch/issues/49468. This PR builds on https://github.com/pytorch/pytorch/pull/50741 and https://github.com/pytorch/pytorch/issues/53105 to extend OpInfo out= testing. It covers the following cases for ops that produce a single tensor:
- out= values don't affect computation
- out= noncontiguous produces the correct output and preserves strides
- out= with the wrong shape throws a warning
- out= with an empty tensor throws no warning
- out= with the wrong device throws an error
- out= with a dtype the computation's result can't be "safely" cast to throws an error

It works with operations that produce a single tensor and operations that produce an iterable of tensors (the latter is tested with operations like torch.svd).

In addition to the new out= test, the OpInfos have been updated. "supports_tensor_out" is replaced with the more general and straightforward "supports_out" metadata, and many operations which previously had to skip out= testing with an explicit SkipInfo no longer need to. A couple of redundant tests in test_unary_ufuncs.py have been removed, too.

One other perk of these tests is that, once all operations have OpInfos, they will allow us to validate that we've universally deprecated incorrectly sized tensors passed to out=, and give us the option to actually disable the behavior.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53259
Reviewed By: mrshenli
Differential Revision: D26894723
Pulled By: mruberry
fbshipit-source-id: 2b536e9baf126f36386a35f2f806dd88c58690b3
609 lines
27 KiB
Python
from functools import partial, wraps, reduce
|
|
import warnings
|
|
|
|
import torch
|
|
|
|
from torch.testing import \
|
|
(FileCheck, floating_and_complex_types_and)
|
|
from torch.testing._internal.common_utils import \
|
|
(TestCase, run_tests, IS_SANDCASTLE, clone_input_helper, make_tensor)
|
|
from torch.testing._internal.common_methods_invocations import \
|
|
(op_db)
|
|
from torch.testing._internal.common_device_type import \
|
|
(instantiate_device_type_tests, ops, onlyOnCPUAndCUDA, skipCUDAIfRocm, OpDTypes)
|
|
from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference
|
|
from torch.autograd.gradcheck import gradcheck, gradgradcheck
|
|
|
|
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, \
|
|
check_alias_annotation
|
|
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining
|
|
|
|
|
|
# Tests that apply to all operators
class TestOpInfo(TestCase):
    exact_dtype = True

    def _call_first_sample(self, device, dtype, op):
        # Generates op's sample inputs for (device, dtype) and calls op on the
        # first of them only; skips the test when there are no samples at all.
        # NOTE: the SkipTest raised here is not a RuntimeError, so it escapes
        # the assertRaises(RuntimeError) context in test_unsupported_dtypes.
        samples = op.sample_inputs(device, dtype)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")

        first = samples[0]
        op(*first.input, *first.args, **first.kwargs)

    # Each dtype an op registers as unsupported must raise a RuntimeError,
    # either while its sample inputs are generated (sample input functions may
    # themselves reject a dtype, see
    # https://github.com/pytorch/pytorch/issues/49024) or while the op runs on
    # the first sample.
    @skipCUDAIfRocm
    @onlyOnCPUAndCUDA
    @ops(op_db, dtypes=OpDTypes.unsupported)
    def test_unsupported_dtypes(self, device, dtype, op):
        with self.assertRaises(RuntimeError):
            self._call_first_sample(device, dtype, op)

    # Each dtype an op registers as supported must run on the first sample
    # without raising.
    @onlyOnCPUAndCUDA
    @ops(op_db, dtypes=OpDTypes.supported)
    def test_supported_dtypes(self, device, dtype, op):
        self._call_first_sample(device, dtype, op)
# gradcheck requires double precision: this decorator is `ops` pre-bound so
# that gradient tests use the op's supported dtypes
# (dtypes=OpDTypes.supported), narrowed via allowed_dtypes to torch.double and
# torch.cdouble (presumably an intersection; see `ops` for the exact rule).
_gradcheck_ops = partial(
    ops,
    allowed_dtypes=[torch.double, torch.cdouble],
    dtypes=OpDTypes.supported,
)
class TestGradients(TestCase):
    exact_dtype = True

    # Copies inputs to inplace operations to avoid inplace modifications
    # to leaves requiring gradient
    def _get_safe_inplace(self, inplace_variant):
        # The returned wrapper clones only the first positional argument (the
        # tensor the inplace op writes to) before forwarding the call.
        @wraps(inplace_variant)
        def _fn(t, *args, **kwargs):
            return inplace_variant(t.clone(), *args, **kwargs)

        return _fn

    def _check_helper(self, device, dtype, op, variant, check):
        """Runs ``check`` on ``variant`` for every sample input of ``op``.

        Args:
            device: device string the samples are created on.
            dtype: dtype of the samples (created with requires_grad=True).
            op: the OpInfo under test.
            variant: callable to check (function or safe-inplace variant), or
                None when that variant is not implemented.
            check: either 'gradcheck' or 'gradgradcheck'.

        Skips when ``variant`` is None or ``op`` doesn't support ``dtype`` on
        this device type; fails on an unknown ``check`` (when there is at least
        one sample).
        """
        if variant is None:
            self.skipTest("Skipped! Variant not implemented.")
        if not op.supports_dtype(dtype, torch.device(device).type):
            self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")

        samples = op.sample_inputs(device, dtype, requires_grad=True)
        for sample in samples:
            # Optionally post-processes the variant's output (e.g. to select
            # the outputs that gradients should flow through).
            if sample.output_process_fn_grad is not None:
                out_fn = sample.output_process_fn_grad

                def variant_out_fn(*args, **kwargs):
                    return out_fn(variant(*args, **kwargs))
            else:
                variant_out_fn = variant

            # Closures over this iteration's sample; they are consumed before
            # the next iteration rebinds the names.
            def fn(*inputs):
                output = variant_out_fn(*inputs, **sample.kwargs)
                return op.output_func(output)

            inputs = (*sample.input,) + sample.args
            if check == 'gradcheck':
                self.assertTrue(gradcheck(fn, inputs,
                                          check_batched_grad=op.check_batched_grad,
                                          check_grad_dtypes=True))
            elif check == 'gradgradcheck':
                # Checks both contiguous and noncontiguous grad outputs.
                for non_contig_grad_outputs in (False, True):
                    self.assertTrue(gradgradcheck(fn, inputs,
                                                  gen_non_contig_grad_outputs=non_contig_grad_outputs,
                                                  check_batched_grad=op.check_batched_gradgrad,
                                                  check_grad_dtypes=True))
            else:
                self.fail("Unknown check requested!")

    def _grad_test_helper(self, device, dtype, op, variant):
        return self._check_helper(device, dtype, op, variant, 'gradcheck')

    def _gradgrad_test_helper(self, device, dtype, op, variant):
        return self._check_helper(device, dtype, op, variant, 'gradgradcheck')

    # Skips ops without autograd support, and complex dtypes for ops whose
    # complex grad tests are marked to skip.
    def _skip_helper(self, op, dtype):
        if not op.supports_autograd:
            self.skipTest("Skipped! autograd not supported")
        if not op.test_complex_grad and dtype.is_complex:
            self.skipTest("Skipped! complex grad tests marked to skip.")

    # Tests that gradients are computed correctly
    @_gradcheck_ops(op_db)
    def test_fn_grad(self, device, dtype, op):
        self._skip_helper(op, dtype)
        self._grad_test_helper(device, dtype, op, op.get_op())

    # Method grad (and gradgrad, see below) tests are disabled since they're
    # costly and redundant with function grad (and gradgrad) tests
    # @_gradcheck_ops(op_db)
    # def test_method_grad(self, device, dtype, op):
    #     self._skip_helper(op, dtype)
    #     self._grad_test_helper(device, dtype, op, op.get_method())

    @_gradcheck_ops(op_db)
    def test_inplace_grad(self, device, dtype, op):
        self._skip_helper(op, dtype)
        if not op.test_inplace_grad:
            self.skipTest("Skipped! Inplace gradcheck marked to skip.")
        self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))

    # Test that gradients of gradients are computed correctly
    @_gradcheck_ops(op_db)
    def test_fn_gradgrad(self, device, dtype, op):
        self._skip_helper(op, dtype)
        self._gradgrad_test_helper(device, dtype, op, op.get_op())

    # Method gradgrad (and grad, see above) tests are disabled since they're
    # costly and redundant with function gradgrad (and grad) tests
    # @_gradcheck_ops(op_db)
    # def test_method_gradgrad(self, device, dtype, op):
    #     self._skip_helper(op, dtype)
    #     self._gradgrad_test_helper(device, dtype, op, op.get_method())

    @_gradcheck_ops(op_db)
    def test_inplace_gradgrad(self, device, dtype, op):
        self._skip_helper(op, dtype)
        if not op.test_inplace_grad:
            self.skipTest("Skipped! Inplace gradgradcheck marked to skip.")
        self._gradgrad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
# Tests operators for consistency between JIT and eager, also checks
# correctness of JIT specific alias schemas and intended
# autodifferentiation behavior.
# Inherits from JitCommonTestCase instead of TestCase directly to share
# functionality with original test_jit.py method operator tests
class TestCommon(JitCommonTestCase):
    exact_dtype = True

    @staticmethod
    def _get_and_clear_grads(input):
        """Returns the gradient(s) accumulated on ``input`` and resets them to None.

        Args:
            input: a single tensor, or an iterable of inputs such as
                ``sample.input``; non-tensor elements are ignored.

        Returns:
            ``input.grad`` for a single tensor, otherwise a tuple holding the
            ``.grad`` of every tensor element, in iteration order.

        Resetting keeps one backward pass from accumulating into the gradients
        read after the next one.
        """
        if isinstance(input, torch.Tensor):
            grad = input.grad
            input.grad = None
            return grad

        grads = []
        for t in input:
            if isinstance(t, torch.Tensor):
                grads.append(t.grad)
                t.grad = None
        return tuple(grads)

    # Runs backward on a variant's forward result and compares the gradient(s)
    # of ``input`` (a tensor or a sequence of inputs) with ``expected_grad``.
    # NOTE: when the reference backward failed (``expected_exception``), verifies
    # that the variant's backward fails too, and vice versa.
    def check_variant_backward(self, input, forward_result, expected_grad, expected_exception):
        variant_exception_during_backwards = False
        try:
            forward_result.sum().backward()
        except Exception:
            if not expected_exception:
                self.fail("Unexpected exception during backwards!")
            variant_exception_during_backwards = True
        finally:
            # Always read and reset the grads, even after a failed backward,
            # so partially accumulated gradients can't leak into the next check.
            variant_grad = self._get_and_clear_grads(input)

        if expected_exception != variant_exception_during_backwards:
            self.fail("Unexpected success during backwards!")

        if not expected_exception:
            self.assertEqual(variant_grad, expected_grad)

    # Tests that the forward and backward passes of operations produce the
    # same values for the cross-product of op variants (method, inplace)
    # against eager's gold standard op function variant
    @ops(op_db)
    def test_variant_consistency_eager(self, device, dtype, op):
        test_backward = op.supports_autograd and (op.test_complex_grad or not dtype.is_complex)
        samples = op.sample_inputs(device, dtype, requires_grad=test_backward)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")

        # Acquires variants to test (independent of the sample, so done once)
        method = op.get_method()
        inplace = op.get_inplace()
        # list of all inplace ops: inplace variant + alias inplace variants if exist
        inplace_ops = [inplace]
        aliases = []
        for a_op in op.aliases:
            aliases.extend((a_op.op, a_op.method_variant, a_op.inplace_variant))
            inplace_ops.append(a_op.inplace_variant)
        inplace_ops = tuple(v for v in inplace_ops if v is not None)
        variants = tuple(v for v in (method, inplace, *aliases) if v is not None)

        for sample in samples:
            # Computes expected forward by calling op's function variant
            expected_forward = op(*sample.input, *sample.args, **sample.kwargs)

            # Computes expected backward
            # NOTE: backward may fail for some dtypes
            exception_during_backwards = False
            try:
                expected_forward.sum().backward()
            except Exception:
                exception_during_backwards = True
            # sample.input is a sequence, so the grads of all of its tensors are
            # collected (and reset before the variants run their backward).
            expected_grad = self._get_and_clear_grads(sample.input)

            # Test eager consistency
            for variant in variants:
                is_inplace = variant in inplace_ops

                # Verifies that inplace operations that promote int->float fail
                # on tensors with integer dtypes.
                if is_inplace and not torch.can_cast(expected_forward.dtype, dtype):
                    try:
                        variant(*(clone_input_helper(input) for input in sample.input),
                                *sample.args,
                                **sample.kwargs)
                    except Exception:
                        continue
                    self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")

                # Compares variant's forward
                # Note: copy the tensor-type inputs when testing inplace operation
                variant_forward = variant(*(clone_input_helper(input) if is_inplace else input
                                            for input in sample.input),
                                          *sample.args,
                                          **sample.kwargs)
                self.assertEqual(variant_forward, expected_forward)

                # Compares variant's backward
                if test_backward and (not is_inplace or op.test_inplace_grad):
                    self.check_variant_backward(sample.input, variant_forward,
                                                expected_grad, exception_during_backwards)

    # Tests that the forward and backward passes of operations produce the
    # same values for the cross-product of op variants (function, method, inplace)
    # and runtimes (eager, traced, scripted).
    # TODO WARNING: inplace x {traced, scripted} not currently tested
    @ops(op_db)
    def test_variant_consistency_jit(self, device, dtype, op):
        test_backward = op.supports_autograd and (
            (dtype.is_complex and op.test_complex_grad) or
            (dtype.is_floating_point and (not op.skip_bfloat16_grad or dtype != torch.bfloat16)))

        samples = op.sample_inputs(device, dtype, requires_grad=test_backward)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")

        # Acquires variants to test (independent of the sample, so done once)
        func = op.get_op()
        variants = {
            'function': func, 'method': op.get_method(),
            # TODO: inplace tests currently fail
            # 'inplace': op.get_inplace(),
        }

        for sample in samples:
            # Test traced and scripted consistency
            for func_type, variant in variants.items():
                if variant is None:
                    continue

                # Create accessor for script function variant
                name = op.name + '_' if func_type == 'inplace' else op.name

                # run with disable_autodiff_subgraph_inlining(True) to test
                # autodiff support. Context manager forces the graph to contain
                # DifferentiableGraph nodes if they are present
                with disable_autodiff_subgraph_inlining():
                    # Check scripted forward, grad, and grad grad
                    script_fn = create_script_fn(self, name, func_type)
                    check_against_reference(self,
                                            script_fn,
                                            func,
                                            op.output_func,
                                            (*sample.input,) + sample.args,
                                            sample.kwargs,
                                            no_grad=not test_backward)

                    # Check traced forward, grad, and grad grad
                    traced_fn = create_traced_fn(self, variant)
                    check_against_reference(self,
                                            traced_fn,
                                            func,
                                            op.output_func,
                                            (*sample.input,) + sample.args,
                                            sample.kwargs,
                                            no_grad=not test_backward)

                    # Check alias annotation schema for correctness (make
                    # sure inputs that aren't supposed to be modified aren't)
                    # Note: only runs in float32 and int32 because schema isn't affected by dtype,
                    # so running it on all dtypes would be excessive
                    if dtype in [torch.float32, torch.int32]:
                        check_alias_annotation(name, (*sample.input,) + sample.args, sample.kwargs,
                                               func_type=func_type, aten_name=op.aten_name)

                    # Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample
                    if dtype is torch.float32:
                        # Sandcastle doesn't fuse nodes
                        if IS_SANDCASTLE:
                            # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                            nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                            fusible_nodes = []
                        else:
                            nonfusible_nodes = op.autodiff_nonfusible_nodes
                            fusible_nodes = op.autodiff_fusible_nodes

                        self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
                        self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)

    # Verifies that scripted and traced graphs built from an op's alias
    # variants contain the original op name and not the alias name.
    @ops([op for op in op_db if op.aliases])
    def test_jit_alias_remapping(self, device, dtype, op):
        samples = op.sample_inputs(device, dtype, requires_grad=True)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")

        # NOTE: only tests on first sample
        sample = samples[0]

        # Prepare data for test scripting
        # Below we prepare strings of args/kwargs with and without type annotations.
        # These strings are inserted into function template strings which is then torch scripted.
        # - args string is ["t0", "t1", ...] corresponds to the input tensors required by the op
        # - args_annot_kw is the string for the template function signature, for example,
        # ["t0", "t1", "s0: float", "s1: bool", "max: float = 1.0", "min: float = 0.0"] ->
        #    def fn(t0, t1, s0: float, s1: bool, max: float = 1.0, min: float = 0.0)
        # - args_kw is the string of args/kwargs used to call the op, same as args_annot_kw but
        # without type annotations
        args = [f"t{i}" for i in range(len(sample.input))]
        args_annot_kw = args + \
            [f"s{i}: {type(v).__name__}" for i, v in enumerate(sample.args)] + \
            [f"{k}: {type(v).__name__} = {v}" for k, v in sample.kwargs.items()]
        args_kw = args + \
            [f"s{i}" for i in range(len(sample.args))] + \
            [f"{k}={v}" for k, v in sample.kwargs.items()]

        original_name = op.name
        original_name_inplace = original_name + "_"
        expected_dtype = op(*sample.input, *sample.args, **sample.kwargs).dtype

        for a_op in op.aliases:
            inplace = a_op.inplace_variant
            method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
            # A tuple, not a generator: it is iterated twice (scripting, then
            # tracing) and a generator would be exhausted by the first loop.
            variants = tuple(v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant) if v is not None)

            # Test scripting:
            for variant in variants:
                variant_name = variant.__name__
                op_name = original_name_inplace if variant is inplace else original_name

                if variant in method_or_inplace:
                    fn_template = '''
                        def _fn(t0{c}{args_annot_kw}):
                            return t0.{alias_name}({args_kw})
                    '''
                    # remove the first input tensor; a separator is needed
                    # whenever at least one parameter follows t0
                    script = fn_template.format(
                        c=", " if len(args_kw[1:]) > 0 else "",
                        args_annot_kw=", ".join(args_annot_kw[1:]),
                        args_kw=", ".join(args_kw[1:]),
                        alias_name=variant_name,
                    )
                else:
                    fn_template = '''
                        def _fn({args_annot_kw}):
                            return variant({args_kw})
                    '''
                    script = fn_template.format(
                        args_annot_kw=", ".join(args_annot_kw),
                        args_kw=", ".join(args_kw),
                    )
                scripted = torch.jit.CompilationUnit(script)._fn

                if variant is inplace and not torch.can_cast(expected_dtype, dtype):
                    try:
                        inp = (clone_input_helper(input) for input in sample.input)
                        scripted(*inp, *sample.args, **sample.kwargs)
                    except Exception:
                        continue
                    self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")

                inp = (clone_input_helper(input) for input in sample.input)
                scripted(*inp, *sample.args, **sample.kwargs)
                inp = (clone_input_helper(input) for input in sample.input)
                graph = scripted.graph_for(*inp, *sample.args, **sample.kwargs)
                FileCheck().check(op_name).check_not(variant_name).run(graph)

            # Test tracing:
            for variant in variants:
                variant_name = variant.__name__
                op_name = original_name_inplace if variant is inplace else original_name

                # Inplace variants that must fail for this dtype were already
                # verified to fail by the scripting loop; they have no graph.
                if variant is inplace and not torch.can_cast(expected_dtype, dtype):
                    continue

                # Only the tensor inputs are traced; the non-tensor args and
                # kwargs are captured by the closure and recorded as constants.
                def _fn(*tensors):
                    return variant(*tensors, *sample.args, **sample.kwargs)

                inp = tuple(clone_input_helper(input) for input in sample.input)
                traced = torch.jit.trace(_fn, inp)
                inp = tuple(clone_input_helper(input) for input in sample.input)
                traced(*inp)
                inp = tuple(clone_input_helper(input) for input in sample.input)
                graph = traced.graph_for(*inp)
                FileCheck().check(op_name).check_not(variant_name).run(graph)

    # Validates ops implement the correct out= behavior
    # See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
    # for a description of the correct behavior
    # TODO: operations that support out= but don't support float
    # are not covered by this test.
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_out(self, device, dtype, op):
        # TODO: verify the op doesn't support the out= kwarg
        if not op.supports_out:
            self.skipTest("Skipped! Op doesn't support out= kwarg.")

        # NOTE: only tests on first sample
        samples = op.sample_inputs(device, dtype)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")
        sample = samples[0]

        # calls it normally to get the expected result
        expected = op(*sample.input, *sample.args, **sample.kwargs)
        op_out = partial(op, *sample.input, *sample.args, **sample.kwargs)

        # Returns True if iterable is an iterable of tensors (includes empty iterables)
        # and False o.w.
        def _is_iterable_of_tensors(iterable):
            try:
                return all(isinstance(t, torch.Tensor) for t in iter(iterable))
            except TypeError:
                return False

        # Short-circuits if output is not a single tensor or an
        # iterable of tensors
        if not isinstance(expected, torch.Tensor) and not _is_iterable_of_tensors(expected):
            self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")

        # A wrapper around map that works with single tensors and always
        # instantiates the map. Used below to apply transforms to
        # single tensor and iterable tensor outputs.
        def _apply_out_transform(fn, out):
            if isinstance(out, torch.Tensor):
                return fn(out)

            # assumes (see above) that out is an iterable of tensors
            return tuple(map(fn, out))

        # Case 0: out= with the correct shape, dtype, and device
        # but NaN values for floating point and complex tensors, and
        # maximum values for integer tensors.
        # Expected behavior: out= values have no effect on the computation.
        def _case_zero_transform(t):
            try:
                info = torch.iinfo(t.dtype)
                return torch.full_like(t, info.max)
            except TypeError:
                # for non-integer types fills with NaN
                return torch.full_like(t, float('nan'))

        out = _apply_out_transform(_case_zero_transform, expected)
        op_out(out=out)
        self.assertEqual(expected, out)

        # Case 1: out= with the correct shape, dtype, and device,
        # but noncontiguous.
        # Expected behavior: strides are respected.
        def _case_one_transform(t):
            return make_tensor(t.shape,
                               dtype=t.dtype,
                               device=t.device,
                               discontiguous=True)

        # Extracts strides from a tensor or iterable of tensors into a tuple
        def _extract_strides(out):
            if isinstance(out, torch.Tensor):
                return (out.stride(),)

            # assumes (see above) that out is an iterable of tensors
            return tuple(t.stride() for t in out)

        out = _apply_out_transform(_case_one_transform, expected)
        original_strides = _extract_strides(out)

        op_out(out=out)
        final_strides = _extract_strides(out)

        self.assertEqual(expected, out)
        self.assertEqual(original_strides, final_strides)

        # Case 2: out= with the correct dtype and device, but the wrong shape
        # Expected behavior: resize with a warning.
        def _case_two_transform(t):
            wrong_shape = list(t.shape)

            if len(wrong_shape) == 0:
                # Handles scalar tensor case (empty list)
                wrong_shape = [2]
            else:
                wrong_shape[-1] = wrong_shape[-1] + 1
            return make_tensor(wrong_shape, dtype=t.dtype, device=t.device)

        out = _apply_out_transform(_case_two_transform, expected)
        with self.assertWarnsRegex(UserWarning, "An output with one or more elements"):
            op_out(out=out)
        self.assertEqual(expected, out)

        # Case 3: out= with the correct dtype and device, but an empty
        # tensor.
        # Expected behavior: resize without warning.
        def _case_three_transform(t):
            return make_tensor((0,),
                               dtype=t.dtype,
                               device=t.device)

        out = _apply_out_transform(_case_three_transform, expected)
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            op_out(out=out)

        # Verifies no warning is a resize warning
        for w in caught:
            if "An output with one or more elements" in str(w.message):
                self.fail("Resizing an out= argument with no elements threw a resize warning!")

        self.assertEqual(expected, out)

        # Case 4: out= with correct shape and dtype, but wrong device.
        # Expected behavior: error. Only runs when a second device type is available.
        wrong_device = None
        if torch.device(device).type != 'cpu':
            wrong_device = 'cpu'
        elif torch.cuda.is_available():
            wrong_device = 'cuda'

        if wrong_device is not None:
            def _case_four_transform(t):
                return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)

            out = _apply_out_transform(_case_four_transform, expected)
            with self.assertRaises(RuntimeError):
                op_out(out=out)

        # Case 5: out= with correct shape and device, but a dtype
        # that output cannot be "safely" cast to (long).
        # Expected behavior: error.
        # NOTE: this case is filtered by dtype since some ops produce
        # bool tensors, for example, which can be safely cast to any
        # dtype. It is applied when single tensors are floating point or complex
        # dtypes, or if an op returns multiple tensors when at least one such
        # tensor is a floating point or complex dtype.
        _dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
        if isinstance(expected, torch.Tensor):
            check_unsafe_cast = expected.dtype in _dtypes
        else:
            check_unsafe_cast = reduce(lambda cur, t: cur or t.dtype in _dtypes, expected, False)

        if check_unsafe_cast:
            def _case_five_transform(t):
                return make_tensor(t.shape, dtype=torch.long, device=t.device)

            out = _apply_out_transform(_case_five_transform, expected)
            with self.assertRaises(RuntimeError):
                op_out(out=out)
# Hands each generic test class to instantiate_device_type_tests together with
# this module's globals(), which it populates with the generated test classes.
for _generic_test_class in (TestOpInfo, TestGradients, TestCommon):
    instantiate_device_type_tests(_generic_test_class, globals())
# Unbinds the loop variable so test discovery doesn't find an extra module-level
# reference to a generic test class.
del _generic_test_class

if __name__ == '__main__':
    run_tests()