Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53669

This PR does two things:
* Ports `pow` to be structured
* Fixes a bug in how `pow` handles mixed cpu and cuda tensors

**bug fix**

Pow is a binary op, and all binary ops that use TensorIterator are currently written to handle the case where one of the inputs is a CUDA tensor and the other is a zero-dimensional cpu tensor. `pow` incidentally handles only one of the two cases: it fails when the CUDA tensor is passed as the exponent, e.g. `at::pow(torch.tensor(2.0, device='cpu'), torch.tensor([2, 2], device='cuda'))`. Porting `pow` to structured happened to change the error that was output from a `TORCH_CHECK` in TensorIterator to an `INTERNAL_ASSERT` in loop.cuh, so I ended up trying to fix the error and update the tests. I added more details in a comment on the PR.

**notes on the structured port**

Pow is a little weird, so I wrote down a couple of issues I noticed during the port:
* Multiple independent overloads. `pow` has two overloads that have their own cpu/cuda kernels, meaning one doesn't call the other. I had to update the names of the kernel overloads to make the compiler happy, since the codegen would otherwise try to generate two classes with the same name. `pow` actually has 3 overloads that all have `out` variants, so I ported all 3 to structured; one of them just happens to redispatch to one of the others in most cases.
* Name propagation. Is name propagation implemented per operator, or is it expected to work for most/all ops by default? Right now it looks like it happens for TensorIterator ops by default. For ops that don't use TensorIterator, we need to explicitly pass the names through to the `set_output()` call in the meta function. This happened to matter for `pow` because it has 3 overloads, but only two of them directly use TensorIterator. I had to pass names directly to `set_output` in the 3rd overload to make tests happy.
* Lack of `const Tensor &` in the C++ API. It's a goal to slowly make all `Tensor &` arguments const as part of the structured port, but in this case I needed to explicitly cast constness away because one structured kernel called back into the C++ API, which still has ordinary `Tensor &` arguments. This probably isn't something we'll fix soon, since we have boxing logic that actually relies on the `Tensor &` / `const Tensor &` distinction in some places.

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D27029821

Pulled By: bdhirsh

fbshipit-source-id: c1786e770de6e6c2474b9a48210b88057ab1018e
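A minimal illustrative sketch (not part of the PR or of the test file below) of the mixed-device `pow` behavior described above; it assumes a CUDA-enabled build and mirrors the cases exercised by `test_cpu_tensor_pow_cuda_scalar_tensor` further down in the file.

```python
import torch

if torch.cuda.is_available():
    # A zero-dim CPU tensor may be mixed with a CUDA tensor in a binary op;
    # after this fix that also holds when the CUDA tensor is the exponent.
    base = torch.tensor(2.0, device='cpu')     # zero-dim CPU base
    exp = torch.tensor([2, 2], device='cuda')  # CUDA exponent
    print(torch.pow(base, exp))                # tensor([4., 4.], device='cuda:0')

    # A non-zero-dim CPU tensor mixed with a CUDA tensor still raises the usual
    # "Expected all tensors to be on the same device" error.
    try:
        torch.pow(torch.randn(3, device='cpu'), exp)
    except RuntimeError as e:
        print(e)
```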
import torch
import numpy as np

import itertools
from itertools import product
import math
import random
import unittest
import warnings
import operator
from functools import partial

from torch._six import inf, nan
from torch.testing._internal.common_utils import (
    TestCase, iter_indices, TEST_WITH_ASAN, run_tests,
    torch_to_numpy_dtype_dict, make_tensor, TEST_SCIPY, set_default_dtype)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, onlyCUDA, onlyCPU, dtypes, dtypesIfCUDA,
    dtypesIfCPU, deviceCountAtLeast, precisionOverride, onlyOnCPUAndCUDA,
    skipCUDAIfRocm, skipIf)

if TEST_SCIPY:
    import scipy.special

# TODO: remove this
|
|
def _generate_input(shape, dtype, device, with_extremal):
|
|
if shape == ():
|
|
x = torch.tensor((), dtype=dtype, device=device)
|
|
else:
|
|
if dtype.is_floating_point or dtype.is_complex:
|
|
# work around torch.randn not being implemented for bfloat16
|
|
if dtype == torch.bfloat16:
|
|
x = torch.randn(*shape, device=device) * random.randint(30, 100)
|
|
x = x.to(torch.bfloat16)
|
|
else:
|
|
x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100)
|
|
x[torch.randn(*shape) > 0.5] = 0
|
|
if with_extremal and dtype.is_floating_point:
|
|
# Use extremal values
|
|
x[torch.randn(*shape) > 0.5] = float('nan')
|
|
x[torch.randn(*shape) > 0.5] = float('inf')
|
|
x[torch.randn(*shape) > 0.5] = float('-inf')
|
|
elif with_extremal and dtype.is_complex:
|
|
x[torch.randn(*shape) > 0.5] = complex('nan')
|
|
x[torch.randn(*shape) > 0.5] = complex('inf')
|
|
x[torch.randn(*shape) > 0.5] = complex('-inf')
|
|
elif dtype == torch.bool:
|
|
x = torch.zeros(shape, dtype=dtype, device=device)
|
|
x[torch.randn(*shape) > 0.5] = True
|
|
else:
|
|
x = torch.randint(15, 100, shape, dtype=dtype, device=device)
|
|
|
|
return x
|
|
|
|
# TODO: refactor this out
|
|
# Converts half/bfloat16 dtype to float when device is cpu
|
|
def _convert_t(dtype, device):
|
|
if device == 'cpu' and dtype in {torch.half, torch.bfloat16}:
|
|
return torch.float
|
|
return dtype
|
|
|
|
# TODO: revise the tests to use make_tensor in common_utils.py instead
|
|
# Returns a tensor of the requested shape, dtype, and device
|
|
# Requesting a half CPU tensor returns a float CPU tensor with
|
|
# values representable by a half.
|
|
# Initialization uses randint for non-float types and randn for float types.
|
|
def _make_tensor(shape, dtype, device, fill_ones=False) -> torch.Tensor:
|
|
# Returns a tensor filled with ones
|
|
if fill_ones:
|
|
return torch.ones(*shape, dtype=_convert_t(dtype, device), device=device)
|
|
|
|
# Returns a tensor with random integer values
|
|
if not (dtype.is_floating_point or dtype.is_complex):
|
|
t = torch.randint(0, 10, shape, device=device)
|
|
if dtype != torch.uint8:
|
|
t = t - 5 # generate negative values also
|
|
return t.to(_convert_t(dtype, device))
|
|
|
|
# Populates the CPU tensor with floats representable as half/bfloat16
|
|
if dtype == torch.half and device == 'cpu':
|
|
return torch.randn(*shape, dtype=torch.float, device=device).half().float()
|
|
if dtype == torch.bfloat16 and device == 'cpu':
|
|
return torch.randn(*shape, dtype=torch.float, device=device).bfloat16().float()
|
|
|
|
# Default: returns a tensor with random float values
|
|
return torch.randn(shape, dtype=dtype, device=device).to(dtype=dtype)
|
|
|
|
# TODO: update to use opinfos consistently
|
|
class TestBinaryUfuncs(TestCase):
|
|
|
|
def test_add_broadcast_empty(self, device):
|
|
# empty + empty
|
|
self.assertRaises(RuntimeError, lambda: torch.randn(5, 0, device=device) + torch.randn(0, 5, device=device))
|
|
self.assertEqual(torch.randn(5, 0, device=device), torch.randn(0, device=device) + torch.randn(5, 0, device=device))
|
|
self.assertEqual(torch.randn(5, 0, 0, device=device), torch.randn(0, device=device) + torch.randn(5, 0, 1, device=device))
|
|
|
|
# scalar + empty
|
|
self.assertEqual(torch.randn(5, 0, 6, device=device), torch.randn((), device=device) + torch.randn(5, 0, 6, device=device))
|
|
|
|
# non-empty, empty
|
|
self.assertEqual(torch.randn(0, device=device), torch.randn(0, device=device) + torch.randn(1, device=device))
|
|
self.assertEqual(torch.randn(0, 7, 0, 6, 5, 0, 7, device=device),
|
|
torch.randn(0, 7, 0, 6, 5, 0, 1, device=device) + torch.randn(1, 1, 5, 1, 7, device=device))
|
|
self.assertRaises(RuntimeError, lambda: torch.randn(7, 0, device=device) + torch.randn(2, 1, device=device))
|
|
|
|
def test_addcmul_scalars_as_floats(self, device):
|
|
# zero-dim variables that don't require grad should bind to scalar arguments
|
|
x = torch.tensor(2.)
|
|
y = torch.tensor(3., device=device)
|
|
# 3 + (3 * 3) * 2
|
|
self.assertEqual(y.addcmul(y, y, value=x), 21)
|
|
|
|
x = torch.tensor(2., requires_grad=True)
|
|
self.assertRaises(Exception, lambda: y.addcmul(y, y, value=x))
|
|
|
|
# TODO: update to work on CUDA, too
|
|
@onlyCPU
|
|
def test_comparison_ops(self, device):
|
|
x = torch.randn(5, 5)
|
|
y = torch.randn(5, 5)
|
|
|
|
eq = x == y
|
|
for idx in iter_indices(x):
|
|
self.assertEqual(x[idx] == y[idx], eq[idx] == 1)
|
|
|
|
ne = x != y
|
|
for idx in iter_indices(x):
|
|
self.assertEqual(x[idx] != y[idx], ne[idx] == 1)
|
|
|
|
lt = x < y
|
|
for idx in iter_indices(x):
|
|
self.assertEqual(x[idx] < y[idx], lt[idx] == 1)
|
|
|
|
le = x <= y
|
|
for idx in iter_indices(x):
|
|
self.assertEqual(x[idx] <= y[idx], le[idx] == 1)
|
|
|
|
gt = x > y
|
|
for idx in iter_indices(x):
|
|
self.assertEqual(x[idx] > y[idx], gt[idx] == 1)
|
|
|
|
ge = x >= y
|
|
for idx in iter_indices(x):
|
|
self.assertEqual(x[idx] >= y[idx], ge[idx] == 1)
|
|
|
|
# TODO: update to work on CUDA, too
|
|
@onlyCPU
|
|
def test_comparison_ops_must_take_bool_output(self, device):
|
|
for op in [torch.lt, torch.le, torch.gt, torch.ge, torch.eq, torch.ne,
|
|
torch.logical_and, torch.logical_or, torch.logical_xor]:
|
|
self.assertEqual(op(torch.tensor([True]), torch.tensor([False])).dtype, torch.bool)
|
|
|
|
# TODO: update to work on CUDA, too
|
|
@onlyCPU
|
|
def test_inplace_comparison_ops_require_inputs_have_same_dtype(self, device):
|
|
with self.assertRaisesRegex(RuntimeError, 'Expected object of scalar type'):
|
|
for op in ['lt_', 'le_', 'gt_', 'ge_', 'eq_', 'ne_', 'logical_xor_', 'logical_and_', 'logical_or_']:
|
|
x = torch.tensor([1], dtype=torch.int)
|
|
y = torch.tensor([2], dtype=torch.long)
|
|
in_place_method = getattr(x, op)
|
|
in_place_method(y)
|
|
|
|
# TODO: update to work on CUDA, too
|
|
@onlyCPU
|
|
def test_comparison_ops_check_for_scalar_overflow(self, device):
|
|
s = 1 << 20
|
|
t = torch.tensor([1 << 5], dtype=torch.uint8)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(t < s)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(s < t)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(t <= s)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(s <= t)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(t > s)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(s > t)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(t >= s)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(s >= t)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(t == s)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(s == t)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(t != s)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(s != t)
|
|
|
|
# TODO: update to work on CUDA, too
|
|
@onlyCPU
|
|
def test_comparison_ops_check_for_zerodim_tensor_overflow(self, device):
|
|
t1 = torch.tensor([1 << 5], dtype=torch.uint8)
|
|
t2 = torch.tensor([1 << 30], dtype=torch.int32)
|
|
ts1 = torch.tensor(1 << 20, dtype=torch.int32)
|
|
ts2 = torch.tensor(1 << 40, dtype=torch.int64)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(t1 < ts1)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(ts2 < t2)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(t1 <= ts1)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(ts2 <= t2)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(t1 > ts1)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(ts2 > t2)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(t1 >= ts1)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(ts2 >= t2)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(t1 == ts1)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(ts2 == t2)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(t1 != ts1)
|
|
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
|
self.assertTrue(ts2 != t2)
|
|
|
|
# TODO: update to work on CUDA, too
|
|
@onlyCPU
|
|
def test_bitwise_ops(self, device):
|
|
x = torch.randn(5, 5).gt(0)
|
|
y = torch.randn(5, 5).gt(0)
|
|
|
|
and_result = x & y
|
|
for idx in iter_indices(x):
|
|
if and_result[idx]:
|
|
self.assertTrue(x[idx] and y[idx])
|
|
else:
|
|
self.assertFalse(x[idx] and y[idx])
|
|
|
|
or_result = x | y
|
|
for idx in iter_indices(x):
|
|
if or_result[idx]:
|
|
self.assertTrue(x[idx] or y[idx])
|
|
else:
|
|
self.assertFalse(x[idx] or y[idx])
|
|
|
|
xor_result = x ^ y
|
|
for idx in iter_indices(x):
|
|
if xor_result[idx]:
|
|
self.assertTrue(x[idx] ^ y[idx])
|
|
else:
|
|
self.assertFalse(x[idx] ^ y[idx])
|
|
|
|
x_clone = x.clone()
|
|
x_clone &= y
|
|
self.assertEqual(x_clone, and_result)
|
|
|
|
x_clone = x.clone()
|
|
x_clone |= y
|
|
self.assertEqual(x_clone, or_result)
|
|
|
|
x_clone = x.clone()
|
|
x_clone ^= y
|
|
self.assertEqual(x_clone, xor_result)
|
|
|
|
def test_inplace_division(self, device):
|
|
t = torch.rand(5, 5, device=device)
|
|
id_before = id(t)
|
|
t /= 2
|
|
id_after = id(t)
|
|
self.assertEqual(id_before, id_after)
|
|
|
|
@dtypes(*torch.testing.get_all_dtypes(include_bool=False, include_complex=False))
|
|
def test_div_rounding_modes(self, device, dtype):
|
|
if dtype.is_floating_point:
|
|
low, high = -10.0, 10.0
|
|
else:
|
|
info = torch.iinfo(dtype)
|
|
low, high = info.min, info.max
|
|
|
|
a = make_tensor((100,), device, dtype, low=low, high=high)
|
|
b = make_tensor((100,), device, dtype, low=low, high=high)
|
|
|
|
# Avoid division by zero so we can test (a / b) * b == a
|
|
if dtype.is_floating_point:
|
|
eps = 0.1
|
|
b[(-eps < b) & (b < eps)] = eps
|
|
else:
|
|
b[b == 0] = 1
|
|
|
|
if not dtype.is_floating_point:
|
|
# floor(a / b) * b can be < a, so fixup slightly to avoid underflow
|
|
a = torch.where(a < 0, a + b, a)
|
|
|
|
d_true = torch.divide(a, b, rounding_mode='true')
|
|
self.assertTrue(d_true.is_floating_point())
|
|
self.assertEqual(d_true * b, a.to(d_true.dtype))
|
|
|
|
d_floor = torch.divide(a, b, rounding_mode='floor')
|
|
if dtype not in (torch.bfloat16, torch.half):
|
|
self.assertEqual(d_floor * b + torch.remainder(a, b), a)
|
|
else:
|
|
self.assertEqual(d_floor * b + torch.remainder(a.float(), b.float()), a,
|
|
exact_dtype=False)
|
|
|
|
d_trunc = torch.divide(a, b, rounding_mode='trunc')
|
|
rounding_unsupported = (
|
|
dtype == torch.half and device != 'cuda' or
|
|
dtype == torch.bfloat16 and device != 'cpu')
|
|
d_ref = d_true.float() if rounding_unsupported else d_true
|
|
self.assertEqual(d_trunc, d_ref.trunc().to(dtype))
|
|
|
|
@dtypes(torch.bfloat16, torch.half, torch.float32, torch.float64)
|
|
def test_div_rounding_nonfinite(self, device, dtype):
|
|
|
|
# Compare division of special floating point values against NumPy
|
|
x = torch.tensor([1.0, -1.0, 0, 0.1, -0.1, np.pi, -np.pi, np.inf, -np.inf, np.nan],
|
|
dtype=dtype)
|
|
|
|
a, b = x[None, :].clone(), x[:, None].clone()
|
|
|
|
# Compare bfloat16 against NumPy float
|
|
exact_dtype = dtype != torch.bfloat16
|
|
if exact_dtype:
|
|
an, bn = a.cpu().numpy(), b.cpu().numpy()
|
|
else:
|
|
an, bn = a.float().cpu().numpy(), b.float().cpu().numpy()
|
|
|
|
for mode, np_ref in (("true", np.true_divide), ("floor", np.floor_divide)):
|
|
with np.errstate(all='ignore'):
|
|
expect = np_ref(an, bn)
|
|
with set_default_dtype(torch.double):
|
|
actual = torch.divide(a, b, rounding_mode=mode)
|
|
self.assertEqual(actual, torch.from_numpy(expect),
|
|
exact_device=False, exact_dtype=exact_dtype)
|
|
|
|
# Compare contiguous (likely vectorized) against non-contiguous (not vectorized)
|
|
storage = torch.empty((20, 20), dtype=dtype, device=device)
|
|
storage[::2, ::2] = a
|
|
storage[1::2, 1::2] = b
|
|
|
|
for rounding_mode in ("true", "trunc", "floor"):
|
|
expect = torch.divide(storage[::2, ::2], storage[1::2, 1::2], rounding_mode=rounding_mode)
|
|
actual = torch.divide(a, b, rounding_mode=rounding_mode)
|
|
self.assertEqual(actual, expect)
|
|
|
|
@dtypes(*torch.testing.get_all_dtypes(
|
|
include_bool=False, include_complex=False, include_bfloat16=False))
|
|
def test_div_rounding_numpy(self, device, dtype):
|
|
info = (torch.finfo(dtype) if dtype.is_floating_point
|
|
else torch.iinfo(dtype))
|
|
low, high = info.min, info.max
|
|
|
|
# Compare division of random values against NumPy
|
|
a = make_tensor((4096,), device, dtype, low=low, high=high)
|
|
b = make_tensor((4096,), device, dtype, low=low, high=high)
|
|
|
|
# Avoid integer division by zero which raises
|
|
if not dtype.is_floating_point:
|
|
b[b == 0] = 1
|
|
|
|
# Compare bfloat16 against NumPy float
|
|
exact_dtype = dtype != torch.bfloat16
|
|
|
|
if exact_dtype:
|
|
an, bn = a.cpu().numpy(), b.cpu().numpy()
|
|
else:
|
|
an, bn = a.float().cpu().numpy(), b.float().cpu().numpy()
|
|
|
|
for mode, np_ref in (
|
|
("true", np.true_divide),
|
|
("floor", np.floor_divide),
|
|
("trunc", lambda a, b: np.trunc(np.true_divide(a, b)).astype(a.dtype))
|
|
):
|
|
with np.errstate(all='ignore'):
|
|
expect = torch.from_numpy(np_ref(an, bn))
|
|
|
|
# Contiguous (likely vectorized)
|
|
with set_default_dtype(torch.double):
|
|
actual = torch.divide(a, b, rounding_mode=mode)
|
|
self.assertEqual(actual, expect, exact_device=False, exact_dtype=exact_dtype)
|
|
|
|
# Non-contiguous (not vectorized)
|
|
expect = expect[::2]
|
|
with set_default_dtype(torch.double):
|
|
actual = torch.divide(a[::2], b[::2], rounding_mode=mode)
|
|
|
|
self.assertEqual(actual, expect, exact_device=False, exact_dtype=exact_dtype)
|
|
|
|
# Tests that trying to add, inplace, a CUDA tensor to a CPU tensor
|
|
# throws the correct error message
|
|
@onlyCUDA
|
|
def test_cross_device_inplace_error_msg(self, device):
|
|
a = torch.tensor(2.)
|
|
b = torch.tensor(2., device=device)
|
|
with self.assertRaisesRegex(RuntimeError,
|
|
"Expected all tensors to be on the same device"):
|
|
a += b
|
|
|
|
# TODO: refactor this test into a more generic one, it's parked here currently
|
|
@onlyOnCPUAndCUDA
|
|
def test_out_resize_warning(self, device):
|
|
a = torch.tensor((1, 2, 3), device=device, dtype=torch.float32)
|
|
b = torch.tensor((4, 5, 6), device=device, dtype=torch.float32)
|
|
|
|
unary_inputs = (a,)
|
|
binary_inputs = (a, b)
|
|
unary_ops = (torch.ceil, torch.exp)
|
|
binary_ops = (torch.add, torch.sub)
|
|
for op in (unary_ops + binary_ops):
|
|
with warnings.catch_warnings(record=True) as w:
|
|
warnings.simplefilter("always")
|
|
inputs = unary_inputs if op in unary_ops else binary_inputs
|
|
|
|
# No warnings
|
|
op(*inputs, out=torch.empty(3, device=device))
|
|
op(*inputs, out=torch.empty(0, device=device))
|
|
self.assertEqual(len(w), 0)
|
|
|
|
# Cases that throw warnings
|
|
op(*inputs, out=torch.empty(2, device=device))
|
|
self.assertEqual(len(w), 1)
|
|
|
|
# Verifies that the inplace dunders (like idiv) actually are in place
|
|
@onlyOnCPUAndCUDA
|
|
def test_inplace_dunders(self, device):
|
|
t = torch.randn((1,), device=device)
|
|
expected = t.data_ptr()
|
|
t += 1
|
|
t -= 1
|
|
t *= 1
|
|
t /= 1
|
|
with self.assertWarnsOnceRegex(UserWarning, 'floor_divide'):
|
|
t //= 1
|
|
t %= 1
|
|
self.assertEqual(expected, t.data_ptr())
|
|
|
|
def check_internal_mem_overlap(self, inplace_op, num_inputs,
|
|
dtype, device,
|
|
expected_failure=False):
|
|
if isinstance(inplace_op, str):
|
|
inplace_op = getattr(torch.Tensor, inplace_op)
|
|
input = torch.randn(1, dtype=dtype, device=device).expand(3, 3)
|
|
inputs = [input] + [torch.randn_like(input)
|
|
for i in range(num_inputs - 1)]
|
|
if not expected_failure:
|
|
with self.assertRaisesRegex(RuntimeError, 'single memory location'):
|
|
inplace_op(*inputs)
|
|
else:
|
|
with self.assertRaises(AssertionError):
|
|
with self.assertRaisesRegex(RuntimeError, 'single memory location'):
|
|
inplace_op(*inputs)
|
|
|
|
def unary_check_input_output_mem_overlap(self, data, sz, op,
|
|
expected_failure=False):
|
|
|
|
def _test(op, output, input):
|
|
output_exp = torch.empty_like(output)
|
|
op(input, out=output_exp)
|
|
self.assertEqual(op(input, out=output), output_exp, msg=op.__name__)
|
|
|
|
# output is identical to input:
|
|
_test(op, output=data[0:sz], input=data[0:sz])
|
|
# output and input are independent:
|
|
_test(op, output=data[0:sz], input=data[sz:2 * sz])
|
|
# output partially overlaps with input:
|
|
if not expected_failure:
|
|
with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
|
|
_test(op, data[0:sz], data[1:sz + 1])
|
|
else:
|
|
with self.assertRaises(AssertionError):
|
|
with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
|
|
_test(op, data[0:sz], data[1:sz + 1])
|
|
|
|
def binary_check_input_output_mem_overlap(self, op, device,
|
|
expected_failure=False):
|
|
sz = 3
|
|
data = torch.randn(2 * sz, device=device)
|
|
other = torch.randn(sz, device=device)
|
|
|
|
self.unary_check_input_output_mem_overlap(
|
|
data, sz, lambda input, out: op(other, input, out=out),
|
|
expected_failure=expected_failure)
|
|
|
|
self.unary_check_input_output_mem_overlap(
|
|
data, sz, lambda input, out: op(input, other, out=out),
|
|
expected_failure=expected_failure)
|
|
|
|
@dtypes(torch.double)
|
|
def test_binary_op_mem_overlap(self, device, dtype):
|
|
ops = [
|
|
("add", True, True, 'cpu'),
|
|
("add", True, True, 'cuda'),
|
|
("mul", True, True, 'cpu'),
|
|
("mul", True, True, 'cuda'),
|
|
("sub", True, True, 'cpu'),
|
|
("sub", True, True, 'cuda'),
|
|
("div", True, True, 'cpu'),
|
|
("div", True, True, 'cuda'),
|
|
("pow", True, True, 'cpu'),
|
|
("pow", True, True, 'cuda'),
|
|
("fmod", True, True, 'cpu'),
|
|
("fmod", True, True, 'cuda'),
|
|
("atan2", True, True, 'cpu'),
|
|
("atan2", True, True, 'cuda'),
|
|
("hypot", True, True, 'cpu'),
|
|
("hypot", True, True, 'cuda'),
|
|
("igamma", True, True, 'cpu'),
|
|
("igamma", True, True, 'cuda'),
|
|
("igammac", True, True, 'cpu'),
|
|
("igammac", True, True, 'cuda'),
|
|
("nextafter", True, True, 'cpu'),
|
|
("nextafter", True, True, 'cuda'),
|
|
("le", True, True, 'cpu'),
|
|
("le", True, True, 'cuda'),
|
|
("lt", True, True, 'cpu'),
|
|
("lt", True, True, 'cuda'),
|
|
("ge", True, True, 'cpu'),
|
|
("ge", True, True, 'cuda'),
|
|
("gt", True, True, 'cpu'),
|
|
("gt", True, True, 'cuda'),
|
|
("eq", True, True, 'cpu'),
|
|
("eq", True, True, 'cuda'),
|
|
("ne", True, True, 'cpu'),
|
|
("ne", True, True, 'cuda'),
|
|
("logical_and", True, True, 'cpu'),
|
|
("logical_and", True, True, 'cuda'),
|
|
("logical_or", True, True, 'cpu'),
|
|
("logical_or", True, True, 'cuda'),
|
|
("logical_xor", True, True, 'cpu'),
|
|
("logical_xor", True, True, 'cuda'),
|
|
]
|
|
|
|
for (fn, has_input_output_mem_overlap_check,
|
|
has_internal_mem_overlap_check, dev) in ops:
|
|
if dev != device:
|
|
continue
|
|
out_op = getattr(torch, fn)
|
|
inplace_op = getattr(torch.Tensor, fn + '_')
|
|
self.check_internal_mem_overlap(
|
|
inplace_op, 2, dtype, device,
|
|
expected_failure=not has_internal_mem_overlap_check)
|
|
|
|
self.binary_check_input_output_mem_overlap(out_op, device,
|
|
expected_failure=not has_input_output_mem_overlap_check)
|
|
|
|
def _do_pow_for_exponents(self, m1, exponents, pow_fn, atol):
|
|
for num in exponents:
|
|
if isinstance(num, int) and num < 0 and not m1.is_floating_point() and not m1.is_complex():
|
|
with self.assertRaisesRegex(RuntimeError,
|
|
r'Integers to negative integer powers are not allowed\.'):
|
|
torch.pow(m1[4], num)
|
|
else:
|
|
# base - tensor, exponent - number
|
|
# contiguous
|
|
res1 = torch.pow(m1[4], num)
|
|
res2 = res1.clone().zero_()
|
|
# `math.pow` has issues with complex exponentiation so we need to resort to normal `pow`.
|
|
for i in range(res2.size(0)):
|
|
res2[i] = pow_fn(m1[4][i], num)
|
|
rtol = 0 if atol is not None else None
|
|
self.assertEqual(res1, res2, atol=atol, rtol=rtol)
|
|
|
|
# non-contiguous
|
|
res1 = torch.pow(m1[:, 4], num)
|
|
res2 = res1.clone().zero_()
|
|
for i in range(res2.size(0)):
|
|
res2[i] = pow_fn(m1[i, 4], num)
|
|
self.assertEqual(res1, res2, atol=atol, rtol=rtol)
|
|
|
|
# scalar ** tensor to enforce correct handling of dtypes for __rpow__().
|
|
expected_dtype = torch.result_type(num, m1)
|
|
res1 = num ** m1[4]
|
|
res2 = torch.tensor(num, dtype=expected_dtype, device=m1.device) ** m1[4]
|
|
self.assertEqual(res1, res2)
|
|
self.assertEqual(res1.dtype, expected_dtype)
|
|
|
|
def test_pow(self, device):
|
|
# [res] torch.pow([res,] x)
|
|
|
|
# pow has dedicated implementation for different exponents
|
|
for dtype in torch.testing.get_all_math_dtypes(device):
|
|
|
|
# This test won't work on torch.half because math.pow will generate a much more accurate result. We skip it
|
|
# for now.
|
|
if dtype == torch.half:
|
|
continue
|
|
|
|
# deferring to https://github.com/pytorch/pytorch/pull/36793
|
|
if dtype.is_complex:
|
|
continue
|
|
|
|
m1 = torch.empty(0, dtype=dtype, device=device)
|
|
if m1.is_floating_point() or m1.is_complex():
|
|
m1 = torch.rand(100, 100, dtype=dtype, device=device) + 0.5
|
|
else:
|
|
# math.pow will overflow and throw exceptions for large integers
|
|
range_high = 4 if dtype in (torch.int8, torch.uint8) else 10
|
|
m1 = torch.randint(1, range_high, (100, 100), dtype=dtype, device=device)
|
|
|
|
exponents = [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3]
|
|
complex_exponents = [-2.5j, -1.0j, 0j, 1.0j, 2.5j, 1.0 + 1.0j, -1.0 - 1.5j, 3.3j]
|
|
if m1.is_complex():
|
|
self._do_pow_for_exponents(m1, exponents + complex_exponents, pow, 10e-4)
|
|
else:
|
|
self._do_pow_for_exponents(m1, exponents, math.pow, None)
|
|
self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4)
|
|
|
|
# base - number, exponent - tensor
|
|
# contiguous
|
|
res1 = torch.pow(3, m1[4])
|
|
res2 = res1.clone().zero_()
|
|
for i in range(res2.size(0)):
|
|
res2[i] = math.pow(3, m1[4, i])
|
|
self.assertEqual(res1, res2)
|
|
|
|
# non-contiguous
|
|
res1 = torch.pow(3, m1[:, 4])
|
|
res2 = res1.clone().zero_()
|
|
for i in range(res2.size(0)):
|
|
res2[i] = math.pow(3, m1[i][4])
|
|
self.assertEqual(res1, res2)
|
|
|
|
# resize behavior for exp == 1
|
|
out = torch.zeros(1, dtype=dtype, device=device)
|
|
torch.pow(m1, 1, out=out)
|
|
self.assertEqual(out, m1)
|
|
|
|
# TODO: refactor all these tests using opinfos properly
|
|
def _test_pow(self, base, exponent, np_exponent=None):
|
|
if np_exponent is None:
|
|
np_exponent = exponent
|
|
|
|
def to_np(value):
|
|
if isinstance(value, torch.Tensor):
|
|
return value.cpu().numpy()
|
|
return value
|
|
|
|
try:
|
|
np_res = np.power(to_np(base), to_np(np_exponent))
|
|
expected = torch.from_numpy(np_res) if isinstance(np_res, np.ndarray) else torch.tensor(np_res, dtype=base.dtype)
|
|
except ValueError as e:
|
|
err_msg = "Integers to negative integer powers are not allowed."
|
|
self.assertEqual(str(e), err_msg)
|
|
out = torch.empty_like(base)
|
|
test_cases = [
|
|
lambda: base.pow(exponent),
|
|
lambda: base.pow_(exponent),
|
|
lambda: torch.pow(base, exponent),
|
|
lambda: torch.pow(base, exponent, out=out)
|
|
]
|
|
for test_case in test_cases:
|
|
self.assertRaisesRegex(RuntimeError, err_msg, test_case)
|
|
else:
|
|
if isinstance(base, torch.Tensor):
|
|
actual = base.pow(exponent)
|
|
self.assertEqual(actual, expected.to(actual))
|
|
actual = base.clone()
|
|
                # When base is a 0-dim cpu tensor and exp is a cuda tensor, we expect `pow` to work but `pow_` to fail, since
|
|
# `pow` will try to create the output tensor on a cuda device, but `pow_` needs to use the cpu tensor as the output
|
|
if base.dim() == 0 and base.device.type == 'cpu' and exponent.device.type == 'cuda':
|
|
regex = 'Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!'
|
|
elif torch.can_cast(torch.result_type(base, exponent), base.dtype):
|
|
actual2 = actual.pow_(exponent)
|
|
self.assertEqual(actual, expected)
|
|
self.assertEqual(actual2, expected)
|
|
else:
|
|
self.assertRaisesRegex(RuntimeError, "Found dtype \\w+ but expected \\w+", lambda: actual.pow_(exponent))
|
|
|
|
actual = torch.pow(base, exponent)
|
|
self.assertEqual(actual, expected.to(actual))
|
|
|
|
actual2 = torch.pow(base, exponent, out=actual)
|
|
self.assertEqual(actual, expected.to(actual))
|
|
self.assertEqual(actual2, expected.to(actual))
|
|
|
|
def test_int_pow(self, device):
|
|
|
|
def _test_integral_pow(dt, range, dev):
|
|
tensor = torch.tensor((3, 3), dtype=dt, device=dev).random_(*range)
|
|
exps = [0, 1, 2, 4,
|
|
torch.tensor((3, 3), dtype=dt, device=dev).random_(0, 5)]
|
|
for exp in exps:
|
|
self._test_pow(tensor, exp)
|
|
|
|
_test_integral_pow(torch.int8, (-3, 4), device)
|
|
_test_integral_pow(torch.uint8, (0, 4), device)
|
|
_test_integral_pow(torch.int16, (-5, 5), device)
|
|
_test_integral_pow(torch.int64, (-10, 10), device)
|
|
_test_integral_pow(torch.int32, (-10, 10), device)
|
|
|
|
def test_int_tensor_pow_neg_ints(self, device):
|
|
ints = [torch.iinfo(torch.int32).min,
|
|
-3, -2, -1, 0, 1, 2, 3,
|
|
torch.iinfo(torch.int32).max]
|
|
neg_ints = [torch.iinfo(torch.int32).min, -3, -2, -1]
|
|
tensor = torch.tensor(ints, dtype=torch.int32, device=device)
|
|
for pow in neg_ints:
|
|
self._test_pow(tensor, pow)
|
|
|
|
def test_long_tensor_pow_floats(self, device):
|
|
ints = [0, 1, 23, 4567]
|
|
floats = [0.0, 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0]
|
|
tensor = torch.tensor(ints, dtype=torch.int64, device=device)
|
|
for pow in floats:
|
|
self._test_pow(tensor, pow)
|
|
|
|
def test_float_scalar_pow_float_tensor(self, device):
|
|
floats = [2.0, -3 / 2, -1.0, -1 / 2, -1 / 3, 0.0,
|
|
1 / 3, 1 / 2, 1.0, 3 / 2, 2.0]
|
|
tensor = torch.tensor(floats, dtype=torch.float32, device=device)
|
|
for base in floats:
|
|
self._test_pow(base, tensor)
|
|
|
|
@onlyCUDA
|
|
def test_cuda_tensor_pow_scalar_tensor(self, device):
|
|
cuda_tensors = [torch.randn((3, 3), device=device), torch.tensor(3.0, device=device)]
|
|
scalar_tensors = [torch.tensor(5.0, device='cpu'), torch.tensor(-3), torch.tensor(1)]
|
|
for base, exp in product(cuda_tensors, scalar_tensors):
|
|
self._test_pow(base, exp)
|
|
|
|
@onlyCUDA
|
|
def test_cpu_tensor_pow_cuda_scalar_tensor(self, device):
|
|
cuda_tensors = [torch.tensor(5.0, device='cuda'), torch.tensor(-3, device='cuda')]
|
|
for exp in cuda_tensors:
|
|
base = torch.randn((3, 3), device='cpu')
|
|
regex = 'Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!'
|
|
self.assertRaisesRegex(RuntimeError, regex, torch.pow, base, exp)
|
|
for exp in cuda_tensors:
|
|
# Binary ops with a cpu + cuda tensor are allowed if the cpu tensor has 0 dimension
|
|
base = torch.tensor(3.0, device='cpu')
|
|
self._test_pow(base, exp)
|
|
|
|
@onlyOnCPUAndCUDA
|
|
@dtypes(*(torch.testing.get_all_dtypes(include_bool=False, include_bfloat16=False)))
|
|
def test_complex_scalar_pow_tensor(self, device, dtype):
|
|
complexes = [0.5j, 1. + 1.j, -1.5j, 2.2 - 1.6j, 1 + 0j]
|
|
exp = make_tensor((100,), device, dtype, low=-2, high=2)
|
|
exp[0] = exp[10] = exp[20] = 0
|
|
for base in complexes:
|
|
self._test_pow(base, exp)
|
|
|
|
def test_tensor_pow_tensor(self, dev):
|
|
def rotate(l, n):
|
|
return l[-n:] + l[:-n]
|
|
|
|
def test_tensor_pow_tensor(values, torch_type, numpy_type):
|
|
vals_tensor = torch.tensor(values, dtype=torch_type, device=dev)
|
|
for i in range(len(values)):
|
|
pows = rotate(values, i)
|
|
pows_tensor = torch.tensor(pows, dtype=torch_type, device=dev)
|
|
self._test_pow(vals_tensor, pows_tensor)
|
|
|
|
ints = [0, 1, 2, 3]
|
|
test_tensor_pow_tensor(ints, torch.int32, np.int32)
|
|
test_tensor_pow_tensor(ints, torch.int64, np.int64)
|
|
|
|
floats = [-3.0, -2.0, -1.0, -1 / 2, -1 / 3,
|
|
0.0,
|
|
1 / 3, 1 / 2, 1.0, 2.0, 3.0]
|
|
test_tensor_pow_tensor(floats, torch.float32, np.float32)
|
|
test_tensor_pow_tensor(floats, torch.float64, np.float64)
|
|
|
|
def test_logical_xor_with_nontrivial_alignment(self, device):
|
|
# test tensor that is not aligned to multiple of 16 bytes
|
|
size = 128
|
|
a = (torch.randn(size, device=device) > 0)
|
|
b = (torch.randn(size, device=device) > 0)
|
|
c = (torch.randn(size, device=device) > 0)
|
|
non_trivial_alignment = [1, 2, 4, 8, 15]
|
|
for i in non_trivial_alignment:
|
|
for j in non_trivial_alignment:
|
|
for k in non_trivial_alignment:
|
|
a_ = a[i: 100 + i]
|
|
b_ = b[j: 100 + j]
|
|
c_ = c[k: 100 + k]
|
|
torch.logical_xor(a_, b_, out=c_)
|
|
for x, y, z in zip(a_.tolist(), b_.tolist(), c_.tolist()):
|
|
self.assertEqual(x ^ y, z)
|
|
|
|
@dtypes(torch.float)
|
|
def test_add_with_tail(self, device, dtype):
|
|
# test tensor where there is a tail which is not a multiple
|
|
# of GPU warp size
|
|
for tail_size in [1, 63, 67, 130]:
|
|
size = 4096 + tail_size
|
|
a = torch.randn(size, device=device, dtype=dtype)
|
|
b = torch.randn(size, device=device, dtype=dtype)
|
|
c = a + b
|
|
for x, y, z in zip(a.tolist(), b.tolist(), c.tolist()):
|
|
self.assertEqual(x + y, z)
|
|
|
|
# Tests that CUDA tensors on different devices cannot be used in the same
|
|
# binary operation, and that CUDA "scalars" cannot be used in the same
|
|
# binary operation as non-scalar CPU tensors.
|
|
@deviceCountAtLeast(2)
|
|
@onlyCUDA
|
|
def test_cross_device_binary_ops(self, devices):
|
|
vals = (1., (2.,))
|
|
cpu_tensor = torch.randn(2, 2)
|
|
|
|
def do_test(op, a, b):
|
|
with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
|
|
op(a, b)
|
|
with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
|
|
op(b, a)
|
|
with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
|
|
op(a, cpu_tensor)
|
|
with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
|
|
op(cpu_tensor, a)
|
|
|
|
for op in (operator.add, torch.add,
|
|
operator.sub, torch.sub,
|
|
operator.mul, torch.mul,
|
|
operator.truediv, torch.true_divide,
|
|
operator.floordiv, torch.floor_divide):
|
|
for a, b in product(vals, vals):
|
|
a = torch.tensor(a, device=devices[0])
|
|
b = torch.tensor(b, device=devices[1])
|
|
|
|
if op in (operator.floordiv, torch.floor_divide):
|
|
with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
|
|
do_test(op, a, b)
|
|
else:
|
|
do_test(op, a, b)
|
|
|
|
# This test ensures that a scalar Tensor can be safely used
|
|
# in a binary operation in conjunction with a Tensor on all
|
|
# available CUDA devices
|
|
@deviceCountAtLeast(2)
|
|
@onlyCUDA
|
|
def test_binary_op_scalar_device_unspecified(self, devices):
|
|
scalar_val = torch.tensor(1.)
|
|
for default_device in devices:
|
|
with torch.cuda.device(default_device):
|
|
for device in devices:
|
|
device_obj = torch.device(device)
|
|
x = torch.rand(3, device=device)
|
|
y0 = x * scalar_val
|
|
self.assertEqual(y0.device, device_obj)
|
|
y1 = scalar_val * x
|
|
self.assertEqual(y1.device, device_obj)
|
|
self.assertEqual(y0, y1)
|
|
|
|
def test_div_and_floordiv_vs_python(self, device):
|
|
# Tests torch division ops which can handle both arguments being
|
|
# scalars.
|
|
        # NOTE: torch.floor_divide currently truncates instead of flooring
|
|
# the quotient. See https://github.com/pytorch/pytorch/issues/43874.
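        # (Added illustrative note) e.g. for -7 / 2, Python's floor division gives
        # -7 // 2 == -4, while truncation gives math.trunc(-7 / 2) == -3, which is
        # what torch.floor_divide returns as of this version.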
|
|
def _scalar_helper(python_op, torch_op):
|
|
for a, b in product(range(-10, 10), range(-10, 10)):
|
|
for op in (lambda x: x * .5, lambda x: math.floor(x)):
|
|
a = op(a)
|
|
b = op(b)
|
|
|
|
# Skips zero divisors
|
|
if b == 0:
|
|
continue
|
|
|
|
                    expected_div = python_op(a, b)
|
|
|
|
for op in (operator.truediv, torch.true_divide):
|
|
actual_scalar = torch_op(a, b)
|
|
|
|
a_t = torch.tensor(a, device=device)
|
|
b_t = torch.tensor(b, device=device)
|
|
|
|
actual_tensor = torch_op(a_t, b_t)
|
|
actual_first_tensor = torch_op(a_t, b)
|
|
actual_second_tensor = torch_op(a, b_t)
|
|
|
|
self.assertEqual(actual_scalar, expected_div)
|
|
self.assertEqual(actual_tensor.item(), expected_div)
|
|
self.assertEqual(actual_first_tensor, actual_tensor)
|
|
self.assertEqual(actual_second_tensor, actual_tensor)
|
|
|
|
_scalar_helper(operator.truediv, operator.truediv)
|
|
_scalar_helper(operator.truediv, torch.true_divide)
|
|
with self.assertWarnsOnceRegex(UserWarning, 'floor_divide'):
|
|
_scalar_helper(lambda a, b: math.trunc(a / b), operator.floordiv)
|
|
_scalar_helper(lambda a, b: math.trunc(a / b), torch.floor_divide)
|
|
|
|
# NOTE: torch.floor_divide currently truncates instead of flooring.
|
|
# See https://github.com/pytorch/pytorch/issues/43874.
|
|
@onlyOnCPUAndCUDA
|
|
def test_div_and_floordiv_script_vs_python(self, device):
|
|
# Creates jitted functions of two tensors
|
|
def _wrapped_div(a, b):
|
|
return a / b
|
|
|
|
def _wrapped_floordiv(a, b):
|
|
return a // b
|
|
|
|
scripted_div = torch.jit.script(_wrapped_div)
|
|
scripted_floordiv = torch.jit.script(_wrapped_floordiv)
|
|
for a, b in product(range(-10, 10), range(-10, 10)):
|
|
for op in (lambda x: x * .5, lambda x: math.floor(x)):
|
|
a = op(a)
|
|
b = op(b)
|
|
|
|
# Skips zero divisors
|
|
if b == 0:
|
|
continue
|
|
|
|
expected_div = a / b
|
|
expected_truncdiv = math.trunc(a / b)
|
|
a_t = torch.tensor(a, device=device)
|
|
b_t = torch.tensor(b, device=device)
|
|
|
|
self.assertEqual(scripted_div(a_t, b_t), expected_div)
|
|
with self.assertWarnsOnceRegex(UserWarning, 'floor_divide'):
|
|
self.assertEqual(scripted_floordiv(a_t, b_t), expected_truncdiv)
|
|
|
|
# Creates jitted functions of one tensor
|
|
def _wrapped_div_scalar(a):
|
|
return a / 5
|
|
|
|
# NOTE: the JIT implements division as torch.reciprocal(a) * 5
|
|
def _wrapped_rdiv_scalar(a):
|
|
return 5 / a
|
|
|
|
def _wrapped_floordiv_scalar(a):
|
|
return a // 5
|
|
|
|
# NOTE: this fails if the input is not an integer tensor
|
|
# See https://github.com/pytorch/pytorch/issues/45199
|
|
def _wrapped_rfloordiv_scalar(a):
|
|
return 5 // a
|
|
|
|
scripted_div_scalar = torch.jit.script(_wrapped_div_scalar)
|
|
scripted_rdiv_scalar = torch.jit.script(_wrapped_rdiv_scalar)
|
|
scripted_floordiv_scalar = torch.jit.script(_wrapped_floordiv_scalar)
|
|
scripted_rfloordiv_scalar = torch.jit.script(_wrapped_rfloordiv_scalar)
|
|
|
|
for a in range(-10, 10):
|
|
for op in (lambda x: x * .5, lambda x: math.floor(x)):
|
|
a = op(a)
|
|
|
|
a_t = torch.tensor(a, device=device)
|
|
|
|
self.assertEqual(a / 5, scripted_div_scalar(a_t))
|
|
with self.assertWarnsOnceRegex(UserWarning, 'floor_divide'):
|
|
self.assertEqual(math.trunc(a / 5), scripted_floordiv_scalar(a_t))
|
|
|
|
# Skips zero divisors
|
|
if a == 0:
|
|
continue
|
|
|
|
self.assertEqual(5 / a, scripted_rdiv_scalar(a_t))
|
|
|
|
# Handles Issue 45199 (see comment above)
|
|
if a_t.is_floating_point():
|
|
with self.assertRaises(RuntimeError):
|
|
scripted_rfloordiv_scalar(a_t)
|
|
else:
|
|
                    # This should emit a UserWarning; why doesn't it?
|
|
# See issue gh-52387
|
|
self.assertEqual(5 // a, scripted_rfloordiv_scalar(a_t))
|
|
|
|
# NOTE: torch.floor_divide currently truncates instead of flooring
|
|
# the quotient. See https://github.com/pytorch/pytorch/issues/43874.
|
|
@onlyOnCPUAndCUDA
|
|
def test_idiv_and_ifloordiv_vs_python(self, device):
|
|
def _wrapped_idiv_tensor(a, b):
|
|
a /= b
|
|
return a
|
|
|
|
def _wrapped_idiv_scalar(a):
|
|
a /= 5
|
|
return a
|
|
|
|
def _wrapped_true_divide__tensor(a, b):
|
|
a.true_divide_(b)
|
|
return a
|
|
|
|
def _wrapped_true_divide__scalar(a):
|
|
a.true_divide_(5)
|
|
return a
|
|
|
|
def _wrapped_floor_divide__tensor(a, b):
|
|
a.floor_divide_(b)
|
|
return a
|
|
|
|
def _wrapped_floor_divide__scalar(a):
|
|
a.floor_divide_(5)
|
|
return a
|
|
|
|
# The following functions are unsupported by the JIT
|
|
def _wrapped_ifloordiv_tensor(a, b):
|
|
a //= b
|
|
return a
|
|
|
|
def _wrapped_ifloordiv_scalar(a):
|
|
a //= 5
|
|
return a
|
|
|
|
with self.assertRaises(torch.jit.frontend.NotSupportedError):
|
|
scripted_ifloordiv_tensor = torch.jit.script(_wrapped_ifloordiv_tensor)
|
|
|
|
with self.assertRaises(torch.jit.frontend.NotSupportedError):
|
|
scripted_ifloordiv_scalar = torch.jit.script(_wrapped_ifloordiv_scalar)
|
|
|
|
scripted_idiv_tensor = torch.jit.script(_wrapped_idiv_tensor)
|
|
scripted_idiv_scalar = torch.jit.script(_wrapped_idiv_scalar)
|
|
scripted_true_divide__tensor = torch.jit.script(_wrapped_true_divide__tensor)
|
|
scripted_true_divide__scalar = torch.jit.script(_wrapped_true_divide__scalar)
|
|
scripted_floor_divide__tensor = torch.jit.script(_wrapped_floor_divide__tensor)
|
|
scripted_floor_divide__scalar = torch.jit.script(_wrapped_floor_divide__scalar)
|
|
|
|
for a, b in product(range(-10, 10), range(-10, 10)):
|
|
for op in (lambda x: x * .5, lambda x: math.floor(x)):
|
|
a = op(a)
|
|
b = op(b)
|
|
|
|
# Skips zero divisors
|
|
if b == 0:
|
|
continue
|
|
|
|
expected_idiv = a / b
|
|
expected_ifloordiv = a // b
|
|
expected_itruncdiv = math.trunc(a / b)
|
|
|
|
a_t = torch.tensor(a, device=device)
|
|
b_t = torch.tensor(b, device=device)
|
|
|
|
if a_t.is_floating_point():
|
|
tmp0 = a_t.clone()
|
|
tmp0 /= b
|
|
|
|
tmp1 = a_t.clone()
|
|
tmp1 /= b_t
|
|
|
|
self.assertEqual(tmp0.item(), expected_idiv)
|
|
self.assertEqual(tmp1.item(), expected_idiv)
|
|
self.assertEqual(scripted_true_divide__tensor(a_t.clone(), b_t).item(), expected_idiv)
|
|
self.assertEqual(scripted_true_divide__scalar(a_t.clone()).item(), a / 5)
|
|
else:
|
|
tmp = a_t.clone()
|
|
with self.assertRaises(RuntimeError):
|
|
tmp /= b
|
|
with self.assertRaises(RuntimeError):
|
|
tmp /= b_t
|
|
with self.assertRaises(RuntimeError):
|
|
scripted_true_divide__tensor(tmp, b_t)
|
|
with self.assertRaises(RuntimeError):
|
|
scripted_true_divide__scalar(tmp)
|
|
|
|
|
|
if not a_t.is_floating_point() and b_t.is_floating_point():
|
|
# Inplace modification fails because a float tensor is required
|
|
# if the divisor is a float tensor
|
|
with self.assertRaises(RuntimeError), self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
|
|
a_t.clone().floor_divide_(b_t)
|
|
with self.assertRaises(RuntimeError), self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
|
|
                        scripted_floor_divide__tensor(a_t.clone(), b_t)
|
|
tmp = a_t.clone()
|
|
with self.assertRaises(RuntimeError), self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
|
|
tmp //= b_t
|
|
else:
|
|
# Inplace modification is OK when both or neither tensor is
|
|
# a float tensor
|
|
with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
|
|
self.assertEqual(a_t.clone().floor_divide_(b_t).item(), expected_itruncdiv)
|
|
self.assertEqual(scripted_floor_divide__tensor(a_t.clone(), b_t).item(), expected_itruncdiv)
|
|
tmp = a_t.clone()
|
|
with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
|
|
tmp //= b_t
|
|
self.assertEqual(tmp.item(), expected_itruncdiv)
|
|
|
|
with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
|
|
self.assertEqual(scripted_floor_divide__scalar(a_t), math.trunc(a / 5))
|
|
|
|
# Tests binary op equivalence with Python builtin ops
|
|
# Also tests that reverse operations are equivalent to forward ops
|
|
# NOTE: division ops are tested separately above
|
|
def test_binary_ops_with_scalars(self, device):
|
|
for ops in ((operator.add, torch.add),
|
|
(operator.sub, torch.sub),
|
|
(operator.mul, torch.mul),
|
|
(operator.truediv, torch.div)):
|
|
python_op, torch_op = ops
|
|
|
|
for a, b in product(range(-10, 10), range(-10, 10)):
|
|
for op in (lambda x: x * .5, lambda x: math.floor(x)):
|
|
a = op(a)
|
|
b = op(b)
|
|
|
|
# Skips zero divisors
|
|
if b == 0 or a == 0:
|
|
continue
|
|
|
|
a_tensor = torch.tensor(a, device=device)
|
|
b_tensor = torch.tensor(b, device=device)
|
|
a_tensor_cpu = a_tensor.cpu()
|
|
b_tensor_cpu = b_tensor.cpu()
|
|
vals = (a, b, a_tensor, b_tensor, a_tensor_cpu, b_tensor_cpu)
|
|
|
|
for args in product(vals, vals):
|
|
first, second = args
|
|
|
|
first_scalar = first if not isinstance(first, torch.Tensor) else first.item()
|
|
second_scalar = second if not isinstance(second, torch.Tensor) else second.item()
|
|
expected = python_op(first_scalar, second_scalar)
|
|
|
|
self.assertEqual(expected, python_op(first, second))
|
|
self.assertEqual(expected, torch_op(first, second))
|
|
|
|
@dtypes(*product(torch.testing.get_all_dtypes(include_complex=False), torch.testing.get_all_dtypes(include_complex=False)))
|
|
def test_maximum_minimum_type_promotion(self, device, dtypes):
|
|
a = torch.tensor((0, 1), device=device, dtype=dtypes[0])
|
|
b = torch.tensor((1, 0), device=device, dtype=dtypes[1])
|
|
for op in (torch.maximum, torch.max, torch.fmax, torch.minimum, torch.min, torch.fmin):
|
|
result = op(a, b)
|
|
self.assertEqual(result.dtype, torch.result_type(a, b))
|
|
|
|
@dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool]))
|
|
def test_maximum_minimum_int_and_bool(self, device, dtype):
|
|
ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum),
|
|
(torch.fmax, None, np.fmax), (torch.fmin, None, np.fmin))
|
|
rng = np.random.default_rng()
|
|
a_np = np.array(rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype])
|
|
b_np = np.array(rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype])
|
|
|
|
for torch_op, alias, numpy_op in ops:
|
|
a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype)
|
|
b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype)
|
|
tensor_result = torch_op(a_tensor, b_tensor)
|
|
|
|
out = torch.empty_like(a_tensor)
|
|
torch_op(a_tensor, b_tensor, out=out)
|
|
|
|
numpy_result = numpy_op(a_np, b_np)
|
|
|
|
if alias is not None:
|
|
alias_result = alias(a_tensor, b_tensor)
|
|
self.assertEqual(alias_result, tensor_result)
|
|
|
|
self.assertEqual(tensor_result, numpy_result)
|
|
self.assertEqual(out, numpy_result)
|
|
|
|
@precisionOverride({torch.bfloat16: 1e-2})
|
|
@dtypes(*(torch.testing.get_all_fp_dtypes()))
|
|
def test_maximum_minimum_float(self, device, dtype):
|
|
ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum),
|
|
(torch.fmax, None, np.fmax), (torch.fmin, None, np.fmin))
|
|
|
|
if dtype == torch.bfloat16:
|
|
a_np = np.random.randn(10).astype(np.float64)
|
|
b_np = np.random.randn(10).astype(np.float64)
|
|
else:
|
|
a_np = np.random.randn(10).astype(torch_to_numpy_dtype_dict[dtype])
|
|
b_np = np.random.randn(10).astype(torch_to_numpy_dtype_dict[dtype])
|
|
|
|
for torch_op, alias, numpy_op in ops:
|
|
numpy_result = numpy_op(a_np, b_np)
|
|
|
|
a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype)
|
|
b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype)
|
|
tensor_result = torch_op(a_tensor, b_tensor)
|
|
out = torch.empty_like(a_tensor)
|
|
torch_op(a_tensor, b_tensor, out=out)
|
|
|
|
if alias is not None:
|
|
alias_result = alias(a_tensor, b_tensor)
|
|
self.assertEqual(alias_result, tensor_result)
|
|
|
|
self.assertEqual(tensor_result, numpy_result)
|
|
self.assertEqual(out, numpy_result)
|
|
|
|
@dtypes(*(torch.testing.get_all_fp_dtypes()))
|
|
def test_maximum_minimum_float_nan_and_inf(self, device, dtype):
|
|
        # np.maximum and np.minimum compare input arrays element-wise.
|
|
# if one of the elements being compared is a NaN, then that element is returned.
|
|
ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum),
|
|
(torch.fmax, None, np.fmax), (torch.fmin, None, np.fmin))
|
|
a_vals = (float('inf'), -float('inf'), float('nan'), float('inf'), float('nan'), float('nan'), 1, float('nan'))
|
|
b_vals = (-float('inf'), float('inf'), float('inf'), float('nan'), float('nan'), 0, float('nan'), -5)
|
|
if dtype == torch.bfloat16:
|
|
a_np = np.array(a_vals, dtype=np.float64)
|
|
b_np = np.array(b_vals, dtype=np.float64)
|
|
else:
|
|
a_np = np.array(a_vals, dtype=torch_to_numpy_dtype_dict[dtype])
|
|
b_np = np.array(b_vals, dtype=torch_to_numpy_dtype_dict[dtype])
|
|
|
|
for torch_op, alias, numpy_op in ops:
|
|
numpy_result = numpy_op(a_np, b_np)
|
|
|
|
a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype)
|
|
b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype)
|
|
tensor_result = torch_op(a_tensor, b_tensor)
|
|
|
|
out = torch.empty_like(a_tensor)
|
|
torch_op(a_tensor, b_tensor, out=out)
|
|
|
|
if alias is not None:
|
|
alias_result = alias(a_tensor, b_tensor)
|
|
self.assertEqual(alias_result, tensor_result)
|
|
|
|
if dtype == torch.bfloat16:
|
|
self.assertEqual(tensor_result, numpy_result, exact_dtype=False)
|
|
self.assertEqual(out, numpy_result, exact_dtype=False)
|
|
else:
|
|
self.assertEqual(tensor_result, numpy_result)
|
|
self.assertEqual(out, numpy_result)
|
|
|
|
@dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes()))
|
|
def test_maximum_minimum_complex(self, device, dtypes):
|
|
for torch_op in (torch.maximum, torch.minimum, torch.max, torch.min, torch.fmax, torch.fmin):
|
|
with self.assertRaisesRegex(RuntimeError, '.+not implemented for.+'):
|
|
torch_op(torch.ones(1, device=device, dtype=dtypes[0]),
|
|
torch.ones(1, device=device, dtype=dtypes[1]))
|
|
|
|
with self.assertRaisesRegex(RuntimeError, '.+not implemented for.+'):
|
|
torch_op(torch.ones(1, device=device, dtype=dtypes[1]),
|
|
torch.ones(1, device=device, dtype=dtypes[0]))
|
|
|
|
@onlyCUDA
|
|
def test_maximum_minimum_cross_device(self, device):
|
|
a = torch.tensor((1, 2, -1))
|
|
b = torch.tensor((3, 0, 4), device=device)
|
|
ops = (torch.maximum, torch.minimum)
|
|
|
|
for torch_op in ops:
|
|
with self.assertRaisesRegex(RuntimeError,
|
|
"Expected all tensors to be on the same device"):
|
|
torch_op(a, b)
|
|
|
|
with self.assertRaisesRegex(RuntimeError,
|
|
"Expected all tensors to be on the same device"):
|
|
torch_op(b, a)
|
|
|
|
# test cuda tensor and cpu scalar
|
|
ops = ((torch.maximum, np.maximum), (torch.minimum, np.minimum))
|
|
a_np = np.array(1)
|
|
b_np = np.array([3, 0, 4])
|
|
|
|
for torch_op, numpy_op in ops:
|
|
a_tensor = torch.from_numpy(a_np)
|
|
b_tensor = torch.from_numpy(b_np).to(device=device)
|
|
tensor_result_1 = torch_op(a_tensor, b_tensor)
|
|
numpy_result_1 = numpy_op(a_np, b_np)
|
|
tensor_result_2 = torch_op(b_tensor, a_tensor)
|
|
numpy_result_2 = numpy_op(b_np, a_np)
|
|
|
|
self.assertEqual(tensor_result_1, numpy_result_1)
|
|
self.assertEqual(tensor_result_2, numpy_result_2)
|
|
|
|
# TODO: tests like this should be generic
|
|
@dtypesIfCUDA(torch.half, torch.float, torch.double)
|
|
@dtypes(torch.float, torch.double)
|
|
def test_mul_intertype_scalar(self, device, dtype):
|
|
x = torch.tensor(1.5, dtype=dtype, device=device)
|
|
y = torch.tensor(3, dtype=torch.int32, device=device)
|
|
|
|
self.assertEqual(x * y, 4.5)
|
|
self.assertEqual(y * x, 4.5)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
|
|
y *= x
|
|
x *= y
|
|
self.assertEqual(x, 4.5)
|
|
|
|
@onlyCPU
|
|
@dtypes(*torch.testing.get_all_dtypes())
|
|
def test_sub(self, device, dtype):
|
|
m1 = torch.tensor([2.34, 4.44], dtype=dtype, device=device)
|
|
m2 = torch.tensor([1.23, 2.33], dtype=dtype, device=device)
|
|
|
|
if dtype == torch.bool:
|
|
self.assertRaises(RuntimeError, lambda: m1 - m2)
|
|
elif (dtype == torch.bfloat16 or dtype == torch.half):
|
|
            # bfloat16 and half have lower precision, so they need a looser tolerance
|
|
self.assertEqual(m1 - m2, torch.tensor([1.11, 2.11], dtype=dtype), atol=0.01, rtol=0)
|
|
else:
|
|
self.assertEqual(m1 - m2, torch.tensor([1.11, 2.11], dtype=dtype))
|
|
|
|
# TODO: what is this test testing?
|
|
@onlyCPU
|
|
@dtypes(torch.float)
|
|
def test_csub(self, device, dtype):
|
|
# with a tensor
|
|
a = torch.randn(100, 90, dtype=dtype, device=device)
|
|
b = a.clone().normal_()
|
|
|
|
res_add = torch.add(a, b, alpha=-1)
|
|
res_csub = a.clone()
|
|
res_csub.sub_(b)
|
|
self.assertEqual(res_add, res_csub)
|
|
|
|
# with a scalar
|
|
a = torch.randn(100, 100, dtype=dtype, device=device)
|
|
|
|
scalar = 123.5
|
|
res_add = torch.add(a, -scalar)
|
|
res_csub = a.clone()
|
|
res_csub.sub_(scalar)
|
|
self.assertEqual(res_add, res_csub)
|
|
|
|
# TODO: reconcile with minimum/maximum tests
|
|
@dtypesIfCUDA(torch.half, torch.float, torch.double)
|
|
@dtypes(torch.float, torch.double)
|
|
def test_min_max_binary_op_nan(self, device, dtype):
|
|
a = torch.rand(1000, dtype=dtype, device=device)
|
|
b = torch.rand(1000, dtype=dtype, device=device)
|
|
|
|
# 0:250: a -- nan, b -- not nan
|
|
a[:250] = float('nan')
|
|
# 250:500: a -- not nan, b -- nan
|
|
b[250:500] = float('nan')
|
|
# 500:750: a and b both nan
|
|
a[500:750] = float('nan')
|
|
b[500:750] = float('nan')
|
|
# 750:1000: neither nan
|
|
|
|
ma = torch.max(a, b)
|
|
mi = torch.min(a, b)
|
|
|
|
for i in range(750):
|
|
self.assertTrue(torch.isnan(ma[i]), "max(a, b): {}, a: {}, b: {}".format(ma[i], a[i], b[i]))
|
|
self.assertTrue(torch.isnan(mi[i]), "min(a, b): {}, a: {}, b: {}".format(mi[i], a[i], b[i]))
|
|
|
|
for i in range(750, 1000):
|
|
self.assertFalse(torch.isnan(ma[i]), "max(a, b): {}, a: {}, b: {}".format(ma[i], a[i], b[i]))
|
|
self.assertFalse(torch.isnan(mi[i]), "min(a, b): {}, a: {}, b: {}".format(mi[i], a[i], b[i]))
|
|
|
|
@dtypes(*product(torch.testing.get_all_dtypes(include_complex=False),
|
|
torch.testing.get_all_dtypes(include_complex=False)))
|
|
def test_copysign(self, device, dtypes):
|
|
def _test_copysign_numpy(a, b):
|
|
torch_result = torch.copysign(a, b)
|
|
|
|
if a.dtype == torch.bfloat16:
|
|
np_a = a.to(torch.float).cpu().numpy()
|
|
else:
|
|
np_a = a.cpu().numpy()
|
|
|
|
if b.dtype == torch.bfloat16:
|
|
np_b = b.to(torch.float).cpu().numpy()
|
|
else:
|
|
np_b = b.cpu().numpy()
|
|
expected = torch.from_numpy(np.copysign(np_a, np_b))
|
|
# To handle inconsistencies of type promotion between PyTorch and Numpy
|
|
# Applied for both arguments having integral precision and bfloat16
|
|
types = [torch.bool, torch.bfloat16] + torch.testing.get_all_int_dtypes()
|
|
if a.dtype in types or b.dtype in types:
|
|
promoted_type = torch.promote_types(torch_result.dtype, expected.dtype)
|
|
torch_result = torch_result.to(promoted_type)
|
|
expected = expected.to(promoted_type)
|
|
|
|
# Verify Value
|
|
self.assertEqual(torch_result, expected)
|
|
# Verify Sign
|
|
            # Use double copysign to verify the correctness of 0.0 and -0.0, since
            # it is always True for self.assertEqual(0.0 == -0.0). So, we use 1 as the
|
|
# magnitude to verify the sign between torch and numpy results, elementwise.
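            # (Added illustrative note) copysign propagates the sign of zero:
            # torch.copysign(torch.tensor(1.0), torch.tensor(-0.0)) is -1.0 while
            # torch.copysign(torch.tensor(1.0), torch.tensor(0.0)) is 1.0, so comparing
            # the copysign(1, ...) values distinguishes -0.0 from 0.0.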
|
|
# Special case: NaN conversions between FP32 and FP16 is not bitwise
|
|
# equivalent to pass this assertion.
|
|
if a.dtype != torch.float16 and b.dtype != torch.float16:
|
|
self.assertEqual(torch.copysign(torch.tensor(1.0), torch_result),
|
|
torch.copysign(torch.tensor(1.0), expected))
|
|
|
|
# Compare Result with NumPy
|
|
# Type promotion
|
|
a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9)
|
|
b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9)
|
|
_test_copysign_numpy(a, b)
|
|
|
|
# Broadcast
|
|
a = make_tensor((10, 1, 10), device=device, dtype=dtypes[0], low=-9, high=9)
|
|
b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9)
|
|
_test_copysign_numpy(a, b)
|
|
|
|
a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9)
|
|
b = make_tensor((10, 1, 10), device=device, dtype=dtypes[1], low=-9, high=9)
|
|
_test_copysign_numpy(a, b)
|
|
|
|
# 0.0/-0.0/inf/-inf/nan
|
|
cases = [0.0, -0.0, float('inf'), float('-inf'), float('nan')]
|
|
# torch.bfloat16 can not hold '-nan'
|
|
# torch.half can not hold '-nan' on CUDA
|
|
types = [torch.float32, torch.float64]
|
|
if device == 'cpu':
|
|
types.append(torch.float16)
|
|
if dtypes[0] in types:
|
|
b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9)
|
|
for case in cases:
|
|
_test_copysign_numpy(torch.tensor([case], device=device, dtype=dtypes[0]), b)
|
|
|
|
if dtypes[1] in torch.testing.get_all_fp_dtypes():
|
|
a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9)
|
|
for case in cases:
|
|
_test_copysign_numpy(a, torch.tensor([case], device=device, dtype=dtypes[1]))
|
|
|
|
@dtypes(torch.bfloat16, torch.float)
|
|
def test_div(self, device, dtype):
|
|
for op, method, inplace in ((torch.div, torch.Tensor.div, torch.Tensor.div_),
|
|
(torch.true_divide, torch.Tensor.true_divide,
|
|
torch.Tensor.true_divide_)):
|
|
m1 = torch.randn(10, 10, dtype=torch.float, device=device).to(dtype=dtype)
|
|
res1 = m1.clone()
|
|
inplace(res1[:, 3], 2)
|
|
res2 = m1.clone()
|
|
for i in range(m1.size(0)):
|
|
res2[i, 3] = res2[i, 3] / 2
|
|
self.assertEqual(res1, res2)
|
|
|
|
if dtype == torch.bfloat16:
|
|
a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device)
|
|
a2 = torch.tensor([2., 2.], dtype=dtype, device=device)
|
|
self.assertEqual(op(a1, a2),
|
|
torch.tensor([2.1, 3.1], dtype=dtype, device=device),
|
|
atol=0.01, rtol=0)
|
|
self.assertEqual(method(a1, a2), op(a1, a2))
|
|
|
|
@dtypes(torch.bfloat16, torch.float)
|
|
def test_true_divide_out(self, device, dtype):
|
|
a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device)
|
|
a2 = torch.tensor([2., 2.], dtype=dtype, device=device)
|
|
res = torch.empty_like(a1)
|
|
self.assertEqual(torch.true_divide(a1, a2, out=res),
|
|
torch.tensor([2.1, 3.1], dtype=dtype, device=device),
|
|
atol=0.01, rtol=0)
|
|
|
|
@onlyCUDA
|
|
@dtypes(torch.half)
|
|
def test_divmul_scalar(self, device, dtype):
|
|
x = torch.tensor(100., device=device, dtype=dtype)
|
|
x_ref = x.float()
|
|
scale = 1e5
|
|
res = x.div(scale)
|
|
expected = x_ref.div(scale)
|
|
self.assertEqual(res, expected.to(dtype), atol=0., rtol=0.)
|
|
x = torch.tensor(1e-5, device=device, dtype=dtype)
|
|
x_ref = x.float()
|
|
res = x.mul(scale)
|
|
expected = x_ref.mul(scale)
|
|
self.assertEqual(res, expected.to(dtype), atol=0., rtol=0.)
|
|
res = scale * x
|
|
self.assertEqual(res, expected.to(dtype), atol=0., rtol=0.)
|
|
|
|
@dtypesIfCUDA(*set(torch.testing.get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128})
|
|
@dtypes(*set(torch.testing.get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128})
|
|
def test_floor_divide_tensor(self, device, dtype):
|
|
x = torch.randn(10, device=device).mul(30).to(dtype)
|
|
y = torch.arange(1, 11, dtype=dtype, device=device)
|
|
|
|
with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
|
|
z = x // y
|
|
z_alt = torch.trunc(x.double() / y.double()).to(dtype)
|
|
|
|
self.assertEqual(z.dtype, x.dtype)
|
|
self.assertEqual(z, z_alt)
|
|
|
|
@dtypesIfCUDA(*set(torch.testing.get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128})
|
|
@dtypes(*set(torch.testing.get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128})
|
|
def test_floor_divide_scalar(self, device, dtype):
|
|
x = torch.randn(100, device=device).mul(10).to(dtype)
|
|
|
|
with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
|
|
z = x // 3
|
|
z_alt = torch.tensor([math.trunc(v.item() / 3.) for v in x], dtype=x.dtype, device=device)
|
|
|
|
self.assertEqual(z.dtype, x.dtype)
|
|
self.assertEqual(z, z_alt)
|
|
|
|
# Note: this test fails on XLA
|
|
@onlyOnCPUAndCUDA
|
|
@dtypes(torch.float, torch.long)
|
|
def test_floor_divide_out(self, device, dtype):
|
|
x = torch.randn(10, device=device).mul(10).to(dtype)
|
|
y = torch.arange(1, 11, dtype=dtype, device=device)
|
|
o = torch.empty(10, dtype=dtype, device=device)
|
|
|
|
with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
|
|
torch.floor_divide(x, y, out=o)
|
|
self.assertEqual(o, x // y)
|
|
|
|
# Tests scalar with out
|
|
torch.floor_divide(x, 2, out=o)
|
|
self.assertEqual(o, x // 2)
|
|
|
|
if dtype == torch.int:
|
|
o = torch.empty(10, dtype=torch.float, device=device)
|
|
torch.floor_divide(x, y, out=o)
|
|
self.assertEqual(o, torch.floor_divide(x.float(), y.float()))
|
|
|
|
@onlyCPU
|
|
@dtypes(*torch.testing.get_all_math_dtypes('cpu'))
|
|
def test_rdiv(self, device, dtype):
|
|
if dtype is torch.float16:
|
|
return
|
|
elif dtype.is_complex:
|
|
x = torch.rand(100, dtype=dtype, device=device).add(1).mul(4)
|
|
else:
|
|
x = torch.rand(100, device=device).add(1).mul(4).to(dtype)
|
|
y = 30 / x
|
|
z = torch.tensor([30 / v.item() for v in x], device=device)
|
|
self.assertEqual(y, z, exact_dtype=False)
|
|
|
|
@dtypes(*torch.testing.get_all_fp_dtypes(include_bfloat16=False))
|
|
def test_fmod_remainder_by_zero_float(self, device, dtype):
|
|
fn_list = (torch.fmod, torch.remainder)
|
|
for fn in fn_list:
|
|
# check that floating-point tensor fmod/remainder by zero is nan on both CPU and GPU
|
|
x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
|
|
zero = torch.zeros_like(x)
|
|
self.assertTrue(torch.all(fn(x, 0.0).isnan()))
|
|
self.assertTrue(torch.all(fn(x, zero).isnan()))
|
|
|
|
@onlyOnCPUAndCUDA # Check Issue https://github.com/pytorch/pytorch/issues/48130
|
|
@skipCUDAIfRocm # Error happens on both ROCM and XLA
|
|
@dtypes(*torch.testing.get_all_int_dtypes())
|
|
def test_fmod_remainder_by_zero_integral(self, device, dtype):
|
|
fn_list = (torch.fmod, torch.remainder)
|
|
for fn in fn_list:
|
|
# check integral tensor fmod/remainder by zero
|
|
x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
|
|
zero = torch.zeros_like(x)
|
|
# RuntimeError on CPU
|
|
if self.device_type == 'cpu':
|
|
with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"):
|
|
fn(x, zero)
|
|
# Different values for different dtypes on CUDA:
# Because integer division by zero is undefined behavior, CUDA returns an
# all-ones bit pattern for integral dividends (other than int64) divided by
# zero. For int64, CUDA returns all ones for negative dividends and only the
# lower 32 bits set for positive dividends.
# uint8: 0xff -> 255
# int32: 0xffffffff -> -1
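# (Illustrative, mirroring the assertions below: an int64 tensor with a
# positive dividend gives 0x00000000FFFFFFFF == 4294967295, while a negative
# dividend gives 0xFFFFFFFFFFFFFFFF == -1.)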
else:
|
|
if dtype == torch.int64:
|
|
self.assertEqual(fn(x, zero) == 4294967295, x >= 0)
|
|
self.assertEqual(fn(x, zero) == -1, x < 0)
|
|
else:
|
|
value = 255 if dtype == torch.uint8 else -1
|
|
self.assertTrue(torch.all(fn(x, zero) == value))
|
|
|
|
@dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False))
|
|
def test_fmod_remainder(self, device, dtype):
|
|
# Use numpy as reference
|
|
def _helper(x, mod):
|
|
fns_list = ((torch.fmod, torch.Tensor.fmod_, np.fmod),
|
|
(torch.remainder, torch.Tensor.remainder_, np.remainder))
|
|
for fn, inplace_fn, ref_fn in fns_list:
|
|
np_x = x.cpu().numpy()
|
|
np_mod = mod.cpu().numpy() if torch.is_tensor(mod) else mod
|
|
exp = ref_fn(np_x, np_mod)
|
|
exp = torch.from_numpy(exp)
|
|
res = fn(x, mod)
|
|
|
|
self.assertEqual(res, exp, exact_dtype=False)
|
|
# out
|
|
out = torch.empty(0, device=device, dtype=res.dtype)
|
|
fn(x, mod, out=out)
|
|
self.assertEqual(out, exp, exact_dtype=False)
|
|
self.assertEqual(out.size(), torch.Size([10, 10]))
|
|
# in-place (Type cast runtime error)
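# e.g. fmod_/remainder_ on an integral tensor with a float mod would have to
# cast the Float result back to the integral dtype, which raises the error
# matched below.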
try:
|
|
inplace_fn(x, mod)
|
|
self.assertEqual(x, exp, exact_dtype=False)
|
|
except RuntimeError as e:
|
|
self.assertRegex(str(e), "result type (Half|Float|Double) "
|
|
"can't be cast to the desired output "
|
|
"type (Byte|Char|Short|Int|Long)")
|
|
|
|
x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
|
|
# mod with same dtype as x
|
|
mod = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
|
|
# Exclude 0
|
|
mod[mod == 0] = 1
|
|
|
|
# Mods: Integer, Float, Tensor, Non-contiguous Tensor
|
|
mods = [3, 2.3, mod, mod.t()]
|
|
# mod with floating-point dtype
|
|
if dtype in torch.testing.get_all_int_dtypes():
|
|
mod_float = make_tensor((10, 10), device=device, dtype=torch.float, low=-9, high=9)
|
|
mod_float[mod_float == 0] = 1  # exclude 0 from the float mod as well
|
|
mods.append(mod_float)
|
|
|
|
for dividend, mod in product([x, x.t()], mods):
|
|
_helper(dividend, mod)
|
|
|
|
@dtypes(torch.float, torch.double)
|
|
def test_remainder_fmod_large_dividend(self, device, dtype):
|
|
alarge = 1e9
|
|
pi = 3.14159265358979
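# remainder follows the sign of the divisor (like Python's %) and fmod follows
# the sign of the dividend (like C's fmod); when the signs of a and b differ,
# the two results differ by exactly one divisor (remainder == fmod + b).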
for avalue in [alarge, -alarge]:
|
|
for bvalue in [pi, -pi]:
|
|
a = torch.tensor([avalue], dtype=dtype, device=device)
|
|
b = torch.tensor([bvalue], dtype=dtype, device=device)
|
|
c = torch.remainder(a, b)
|
|
d = torch.fmod(a, b)
|
|
self.assertTrue((b[0] > 0) == (c[0] > 0)) # remainder has same sign as divisor
|
|
self.assertTrue((a[0] > 0) == (d[0] > 0)) # fmod has same sign as dividend
|
|
self.assertTrue(abs(c[0]) < abs(b[0])) # remainder is within range of divisor
|
|
self.assertTrue(abs(d[0]) < abs(b[0])) # fmod is within range of divisor
|
|
if ((a[0] > 0) == (b[0] > 0)):
|
|
self.assertTrue(c[0] == d[0]) # remainder is same as fmod
|
|
else:
|
|
self.assertTrue(abs(c[0] - d[0]) == abs(b[0])) # differ by one divisor
|
|
|
|
@dtypesIfCPU(torch.bfloat16, torch.float32, torch.float64)
|
|
@dtypes(torch.float32, torch.float64)
|
|
def test_hypot(self, device, dtype):
|
|
inputs = [
|
|
(torch.randn(10, device=device).to(dtype), torch.randn(10, device=device).to(dtype)),
|
|
(torch.randn((3, 3, 3), device=device).to(dtype), torch.randn((3, 3, 3), device=device).to(dtype)),
|
|
(torch.randn((10, 1), device=device).to(dtype), torch.randn((10, 1), device=device).to(dtype).transpose(0, 1)),
|
|
(torch.randint(100, (10, ), device=device, dtype=torch.long), torch.randn(10, device=device).to(dtype))
|
|
]
|
|
for input in inputs:
|
|
actual = torch.hypot(input[0], input[1])
|
|
if dtype == torch.bfloat16:
|
|
expected = torch.sqrt(input[0] * input[0] + input[1] * input[1])
|
|
else:
|
|
expected = np.hypot(input[0].cpu().numpy(), input[1].cpu().numpy())
|
|
self.assertEqual(actual, expected)
|
|
|
|
@onlyOnCPUAndCUDA
|
|
@dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
|
|
def test_gcd(self, device, dtype):
|
|
# Tests gcd(0, 0), gcd(0, a) cases
|
|
t1 = torch.tensor([0, 10, 0], dtype=dtype, device=device)
|
|
t2 = torch.tensor([0, 0, 10], dtype=dtype, device=device)
|
|
actual = torch.gcd(t1, t2)
|
|
expected = np.gcd([0, 10, 0], [0, 0, 10])
|
|
self.assertEqual(actual, expected)
|
|
|
|
if dtype == torch.uint8:
|
|
# Test unsigned integers with potential sign issues (i.e., uint8 with value >= 128)
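# (190 and 210 are >= 128, so they would read as negative values if the kernel
# incorrectly reinterpreted uint8 as a signed 8-bit type.)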
a = torch.tensor([190, 210], device=device, dtype=dtype)
|
|
b = torch.tensor([190, 220], device=device, dtype=dtype)
|
|
actual = torch.gcd(a, b)
|
|
expected = torch.tensor([190, 10], device=device, dtype=dtype)
|
|
else:
|
|
# Compares with NumPy
|
|
a = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
|
|
b = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
|
|
actual = torch.gcd(a, b)
|
|
expected = np.gcd(a.cpu().numpy(), b.cpu().numpy())
|
|
self.assertEqual(actual, expected)
|
|
|
|
@onlyOnCPUAndCUDA
|
|
@dtypes(torch.int16, torch.int32, torch.int64)
|
|
def test_lcm(self, device, dtype):
|
|
# Tests lcm(0, 0), lcm(0, a) cases
|
|
t1 = torch.tensor([0, 10, 0], dtype=dtype, device=device)
|
|
t2 = torch.tensor([0, 0, 10], dtype=dtype, device=device)
|
|
actual = torch.lcm(t1, t2)
|
|
expected = np.lcm([0, 10, 0], [0, 0, 10])
|
|
self.assertEqual(actual, expected)
|
|
|
|
# Compares with NumPy
|
|
a = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
|
|
b = torch.randint(-20, 20, (1024,), device=device, dtype=dtype)
|
|
actual = torch.lcm(a, b)
|
|
expected = np.lcm(a.cpu().numpy(), b.cpu().numpy())
|
|
self.assertEqual(actual, expected)
|
|
|
|
@onlyOnCPUAndCUDA
|
|
@dtypes(torch.float32, torch.float64)
|
|
def test_nextafter(self, device, dtype):
|
|
# Test special cases
|
|
t1 = torch.tensor([0, 0, 10], device=device, dtype=dtype)
|
|
t2 = torch.tensor([inf, -inf, 10], device=device, dtype=dtype)
|
|
actual = torch.nextafter(t1, t2)
|
|
expected = np.nextafter(t1.cpu().numpy(), t2.cpu().numpy())
|
|
self.assertEqual(actual, expected, atol=0, rtol=0)
|
|
|
|
actual = torch.nextafter(t2, t1)
|
|
expected = np.nextafter(t2.cpu().numpy(), t1.cpu().numpy())
|
|
self.assertEqual(actual, expected, atol=0, rtol=0)
|
|
|
|
t1 = torch.tensor([0, nan], device=device, dtype=dtype)
|
|
t2 = torch.tensor([nan, 0], device=device, dtype=dtype)
|
|
self.assertTrue(torch.nextafter(t1, t2).isnan().all())
|
|
|
|
a = torch.randn(100, device=device, dtype=dtype)
|
|
b = torch.randn(100, device=device, dtype=dtype)
|
|
actual = torch.nextafter(a, b)
|
|
expected = np.nextafter(a.cpu().numpy(), b.cpu().numpy())
|
|
self.assertEqual(actual, expected, atol=0, rtol=0)
|
|
|
|
def _test_cop(self, torchfn, mathfn, dtype, device):
|
|
def reference_implementation(res2):
|
|
for i, j in iter_indices(sm1):
|
|
idx1d = i * sm1.size(0) + j
|
|
res2[i, j] = mathfn(sm1[i, j], sm2[idx1d])
|
|
return res2
|
|
|
|
# contiguous
|
|
m1 = torch.randn(10, 10, 10, dtype=dtype, device=device)
|
|
m2 = torch.randn(10, 10 * 10, dtype=dtype, device=device)
|
|
sm1 = m1[4]
|
|
sm2 = m2[4]
|
|
|
|
res1 = torchfn(sm1, sm2.view(10, 10))
|
|
res2 = reference_implementation(res1.clone())
|
|
self.assertEqual(res1, res2)
|
|
|
|
# non-contiguous
|
|
m1 = torch.randn(10, 10, 10, dtype=dtype, device=device)
|
|
m2 = torch.randn(10 * 10, 10 * 10, dtype=dtype, device=device)
|
|
sm1 = m1[:, 4]
|
|
sm2 = m2[:, 4]
|
|
# view as sm1.size()
|
|
sm2.set_(sm2.storage(), sm2.storage_offset(), sm1.size(), (sm2.stride()[0] * 10, sm2.stride()[0]))
|
|
res1 = torchfn(sm1, sm2)
|
|
# reference_implementation assumes 1-d sm2
|
|
sm2.set_(sm2.storage(), sm2.storage_offset(), m2[:, 4].size(), m2[:, 4].stride())
|
|
res2 = reference_implementation(res1.clone())
|
|
self.assertEqual(res1, res2)
|
|
|
|
@onlyCPU
|
|
@dtypes(torch.float)
|
|
def test_cdiv(self, device, dtype):
|
|
self._test_cop(torch.div, lambda x, y: x / y, dtype, device)
|
|
|
|
@onlyCPU
|
|
@dtypes(torch.float)
|
|
def test_cremainder(self, device, dtype):
|
|
self._test_cop(torch.remainder, lambda x, y: x % y, dtype, device)
|
|
|
|
@onlyCPU
|
|
@dtypes(torch.float)
|
|
def test_cmul(self, device, dtype):
|
|
self._test_cop(torch.mul, lambda x, y: x * y, dtype, device)
|
|
|
|
@onlyCPU
|
|
@dtypes(torch.float)
|
|
def test_cpow(self, device, dtype):
|
|
self._test_cop(torch.pow, lambda x, y: nan if x < 0 else math.pow(x, y), dtype, device)
|
|
|
|
@onlyCPU
|
|
@dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
|
|
def test_floor_divide_zero(self, device, dtype):
|
|
a = torch.tensor([0, 1], dtype=dtype, device=device)
|
|
b = torch.tensor([0, 1], dtype=dtype, device=device)
|
|
with self.assertRaisesRegex(RuntimeError, 'ZeroDivisionError'):
|
|
with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
|
|
a // b
|
|
|
|
@unittest.skipIf(TEST_WITH_ASAN, "Integer overflows are not allowed under ASAN")
|
|
@dtypes(*torch.testing.get_all_dtypes())
|
|
def test_muldiv_scalar(self, device, dtype):
|
|
x = make_tensor((10, 3), device, dtype, low=None, high=None)
|
|
s = make_tensor((1,), 'cpu', dtype, low=None, high=None).item()
|
|
y = torch.full_like(x, s)
|
|
self.assertEqual(x * s, x * y)
|
|
self.assertEqual(s * x, y * x)
|
|
self.assertEqual(x / s, x / y)
|
|
self.assertEqual(s / x, y / x)
|
|
|
|
@dtypes(*tuple(itertools.combinations_with_replacement(torch.testing.get_all_dtypes(), 2)))
|
|
def test_comparison_ops_type_promotion_and_broadcasting(self, device, dtypes):
|
|
# issue #42660
|
|
# testing all combinations of broadcasting and type promotion
|
|
# with a range of dtypes and input shapes, and with extremal values
|
|
def compare_with_numpy_bin_op(torch_fn, np_fn, x, y, out=None):
|
|
# work around the fact that numpy doesn't support bfloat16
# by letting numpy treat bfloat16 inputs as float32
|
|
x_np = x if x.dtype != torch.bfloat16 else x.to(torch.float32)
|
|
y_np = y.cpu().numpy() if y.dtype != torch.bfloat16 else y.to(torch.float32).cpu().numpy()
|
|
self.compare_with_numpy(lambda inp: torch_fn(inp, y, out=out) if out else torch_fn(inp, y),
|
|
lambda inp: np_fn(inp, y_np, out=out) if out else np_fn(inp, y_np),
|
|
x_np)
|
|
|
|
complex_op_denylist = [torch.lt, torch.le, torch.gt, torch.ge] # complex not supported
|
|
input_sizes = [
|
|
(1,),
|
|
(10,),
|
|
(10, 1),
|
|
(1, 10),
|
|
(4, 10),
|
|
(64, 10),
|
|
(12, 3)]
|
|
op_pairs = [(torch.lt, np.less),
|
|
(torch.le, np.less_equal),
|
|
(torch.gt, np.greater),
|
|
(torch.ge, np.greater_equal),
|
|
(torch.eq, np.equal),
|
|
(torch.ne, np.not_equal),
|
|
(torch.logical_and, np.logical_and),
|
|
(torch.logical_or, np.logical_or),
|
|
(torch.logical_xor, np.logical_xor)]
|
|
|
|
for size1 in input_sizes:
|
|
size2 = (2,) + size1 # perform broadcasting
|
|
for with_extremal in [False, True]:
|
|
a = _generate_input(size1, dtypes[0], device, with_extremal)
|
|
b = _generate_input(size2, dtypes[1], device, with_extremal)
|
|
for torch_op, numpy_op in op_pairs:
|
|
if (dtypes[0].is_complex or dtypes[1].is_complex) and torch_op in complex_op_denylist:
|
|
continue
|
|
# functional version of op
|
|
compare_with_numpy_bin_op(torch_op, numpy_op, a, b)
|
|
|
|
# functional comparison ops always return bool tensors
|
|
self.assertEqual(torch_op(a, b).dtype, torch.bool)
|
|
|
|
# out version of op
|
|
out = torch.zeros(1, dtype=torch.complex128) # all casts to complex128 are safe
|
|
compare_with_numpy_bin_op(torch_op, numpy_op, a, b, out=out)
|
|
|
|
@onlyOnCPUAndCUDA
|
|
@dtypes(torch.int8, torch.int16, torch.int32, torch.int64)
|
|
def test_signed_shift(self, device, dtype):
|
|
"Ensure that signed integer bit shifting works as expected."
a = torch.tensor([-10, 10], device=device, dtype=dtype) # [11...1110110, 1010]
|
|
expected_l = torch.tensor([-40, 40], device=device, dtype=dtype) # [11...11011000, 101000]
|
|
self.assertEqual(a << 2, expected_l)
|
|
self.compare_with_numpy(lambda x: x << 2, lambda x: np.left_shift(x, 2), a)
|
|
expected_r = torch.tensor([-5, 5], device=device, dtype=dtype) # [1111...111011, 101]
|
|
self.assertEqual(a >> 1, expected_r)
|
|
self.compare_with_numpy(lambda x: x >> 1, lambda x: np.right_shift(x, 1), a)
|
|
|
|
def test_bitwise_and(self, device):
|
|
for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
|
|
a = torch.tensor([1, -2, 3], dtype=dtype, device=device)
|
|
b = torch.tensor([2, 1, 3], dtype=dtype, device=device)
|
|
expected_res = torch.tensor([0, 0, 3], dtype=dtype, device=device)
|
|
b_scalar = 2
|
|
expected_res_scalar = torch.tensor([0, 2, 2], dtype=dtype, device=device)
|
|
|
|
# standard version
|
|
self.assertEqual(torch.bitwise_and(a, b), expected_res)
|
|
self.assertEqual(torch.bitwise_and(a, b_scalar), expected_res_scalar)
|
|
|
|
# out
|
|
c = torch.empty(0, dtype=dtype, device=device)
|
|
torch.bitwise_and(a, b, out=c)
|
|
self.assertEqual(c, expected_res)
|
|
torch.bitwise_and(a, b_scalar, out=c)
|
|
self.assertEqual(c, expected_res_scalar)
|
|
|
|
# in-place
|
|
a1 = a.clone()
|
|
a1.bitwise_and_(b)
|
|
self.assertEqual(a1, expected_res)
|
|
a.bitwise_and_(b_scalar)
|
|
self.assertEqual(a, expected_res_scalar)
|
|
|
|
self.assertEqual(torch.tensor([False, True, False], device=device),
|
|
torch.bitwise_and(torch.tensor([True, True, False], device=device),
|
|
torch.tensor([False, True, False], device=device)))
|
|
|
|
def test_bitwise_or(self, device):
|
|
for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
|
|
a = torch.tensor([1, -2, 3], dtype=dtype, device=device)
|
|
b = torch.tensor([2, 1, 3], dtype=dtype, device=device)
|
|
expected_res = torch.tensor([3, -1, 3], dtype=dtype, device=device)
|
|
b_scalar = 2
|
|
expected_res_scalar = torch.tensor([3, -2, 3], dtype=dtype, device=device)
|
|
|
|
# standard version
|
|
self.assertEqual(torch.bitwise_or(a, b), expected_res)
|
|
self.assertEqual(torch.bitwise_or(a, b_scalar), expected_res_scalar)
|
|
|
|
# out
|
|
c = torch.empty(0, dtype=dtype, device=device)
|
|
torch.bitwise_or(a, b, out=c)
|
|
self.assertEqual(c, expected_res)
|
|
torch.bitwise_or(a, b_scalar, out=c)
|
|
self.assertEqual(c, expected_res_scalar)
|
|
|
|
# in-place
|
|
a1 = a.clone()
|
|
a1.bitwise_or_(b)
|
|
self.assertEqual(a1, expected_res)
|
|
a.bitwise_or_(b_scalar)
|
|
self.assertEqual(a, expected_res_scalar)
|
|
|
|
self.assertEqual(torch.tensor([True, True, False], device=device),
|
|
torch.bitwise_or(torch.tensor([True, True, False], device=device),
|
|
torch.tensor([False, True, False], device=device)))
|
|
|
|
def test_bitwise_xor(self, device):
|
|
for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
|
|
a = torch.tensor([1, -2, 3], dtype=dtype, device=device)
|
|
b = torch.tensor([2, 1, 3], dtype=dtype, device=device)
|
|
expected_res = torch.tensor([3, -1, 0], dtype=dtype, device=device)
|
|
b_scalar = 2
|
|
expected_res_scalar = torch.tensor([3, -4, 1], dtype=dtype, device=device)
|
|
|
|
# standard version
|
|
self.assertEqual(torch.bitwise_xor(a, b), expected_res)
|
|
self.assertEqual(torch.bitwise_xor(a, b_scalar), expected_res_scalar)
|
|
|
|
# out
|
|
c = torch.empty(0, dtype=dtype, device=device)
|
|
torch.bitwise_xor(a, b, out=c)
|
|
self.assertEqual(c, expected_res)
|
|
torch.bitwise_xor(a, b_scalar, out=c)
|
|
self.assertEqual(c, expected_res_scalar)
|
|
|
|
# in-place
|
|
a1 = a.clone()
|
|
a1.bitwise_xor_(b)
|
|
self.assertEqual(a1, expected_res)
|
|
a.bitwise_xor_(b_scalar)
|
|
self.assertEqual(a, expected_res_scalar)
|
|
|
|
self.assertEqual(torch.tensor([True, False, False], device=device),
|
|
torch.bitwise_xor(torch.tensor([True, True, False], device=device),
|
|
torch.tensor([False, True, False], device=device)))
|
|
|
|
@onlyOnCPUAndCUDA
|
|
@dtypes(*list(product(torch.testing.get_all_dtypes(include_complex=False),
|
|
torch.testing.get_all_dtypes(include_complex=False))))
|
|
def test_heaviside(self, device, dtypes):
|
|
input_dtype = dtypes[0]
|
|
values_dtype = dtypes[1]
|
|
|
|
rng = np.random.default_rng()
|
|
input = np.array(rng.integers(-10, 10, size=10),
|
|
dtype=torch_to_numpy_dtype_dict[input_dtype if (input_dtype != torch.bfloat16) else torch.float64])
|
|
input[0] = input[3] = input[7] = 0
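# np.heaviside(x, values) (and torch.heaviside) returns 0 where x < 0,
# 1 where x > 0, and `values` where x == 0; zeroing a few entries above
# exercises the x == 0 branch.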
values = np.array(rng.integers(-10, 10, size=10),
|
|
dtype=torch_to_numpy_dtype_dict[values_dtype if (values_dtype != torch.bfloat16) else torch.float64])
|
|
np_result = torch.from_numpy(np.heaviside(input, values)).to(device=device, dtype=input_dtype)
|
|
|
|
input = torch.from_numpy(input).to(device=device, dtype=input_dtype)
|
|
values = torch.from_numpy(values).to(device=device, dtype=values_dtype)
|
|
out = torch.empty_like(input)
|
|
|
|
if input_dtype == values_dtype:
|
|
torch_result = torch.heaviside(input, values)
|
|
self.assertEqual(np_result, torch_result)
|
|
|
|
torch_result = input.heaviside(values)
|
|
self.assertEqual(np_result, torch_result)
|
|
|
|
torch.heaviside(input, values, out=out)
|
|
self.assertEqual(np_result, out)
|
|
|
|
input.heaviside_(values)
|
|
self.assertEqual(np_result, input)
|
|
else:
|
|
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
|
|
torch.heaviside(input, values)
|
|
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
|
|
input.heaviside(values)
|
|
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
|
|
torch.heaviside(input, values, out=out)
|
|
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
|
|
input.heaviside_(values)
|
|
|
|
@onlyCUDA
|
|
def test_heaviside_cross_device(self, device):
|
|
x = torch.tensor([-9, 5, 0, 6, -2, 2], device='cuda')
|
|
y = torch.tensor(0)
|
|
result = torch.heaviside(x, y)
|
|
expect = torch.tensor([0, 1, 0, 1, 0, 1], device='cuda')
|
|
self.assertEqual(result, expect)
|
|
|
|
result = torch.heaviside(y, x)
|
|
expect = torch.tensor([-9, 5, 0, 6, -2, 2], device='cuda')
|
|
self.assertEqual(result, expect)
|
|
|
|
x = torch.tensor([-9, 5, 0, 6, -2, 2])
|
|
y = torch.tensor(0, device='cuda')
|
|
with self.assertRaisesRegex(RuntimeError, 'Expected all tensors to be on the same device'):
|
|
torch.heaviside(x, y)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'Expected all tensors to be on the same device'):
|
|
torch.heaviside(y, x)
|
|
|
|
@dtypes(*list(product(torch.testing.get_all_complex_dtypes(),
|
|
torch.testing.get_all_complex_dtypes())))
|
|
def test_heaviside_complex(self, device, dtypes):
|
|
input_dtype = dtypes[0]
|
|
values_dtype = dtypes[1]
|
|
|
|
data = (complex(0, -6), complex(-1, 3), complex(1, 1))
|
|
input = torch.tensor(data, device=device, dtype=input_dtype)
|
|
values = torch.tensor(data, device=device, dtype=values_dtype)
|
|
out = torch.empty_like(input)
|
|
real = input.real
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
|
|
torch.heaviside(input, real)
|
|
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
|
|
real.heaviside(values)
|
|
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
|
|
input.heaviside_(values)
|
|
with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
|
|
torch.heaviside(real, real, out=out)
|
|
|
|
def _test_logical(self, device, dtypes, op, a_, b_, expected_res_):
|
|
expected_res = torch.tensor(expected_res_, dtype=dtypes[0], device=device)
|
|
a = torch.tensor(a_, dtype=dtypes[0], device=device)
|
|
b = torch.tensor(b_, dtype=dtypes[1], device=device)
|
|
|
|
# new tensor
|
|
self.assertEqual(expected_res.bool(), getattr(a, op)(b))
|
|
# out
|
|
c = torch.empty(0, dtype=torch.bool, device=device)
|
|
getattr(torch, op)(a, b, out=c)
|
|
self.assertEqual(expected_res.bool(), c)
|
|
|
|
# in-place
|
|
# TODO: remove when different dtypes as operands are supported
|
|
if dtypes[0] != dtypes[1]:
|
|
with self.assertRaises(RuntimeError):
|
|
getattr(a, op + '_')(b)
|
|
return
|
|
|
|
getattr(a, op + '_')(b)
|
|
self.assertEqual(expected_res, a)
|
|
|
|
@dtypes(*product(torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes()))
|
|
def test_logical_xor(self, device, dtypes):
|
|
self._test_logical(device, dtypes, 'logical_xor', [10, 0, 1, 0], [1, 0, 0, 10], [0, 0, 1, 1])
|
|
|
|
@dtypes(*product(torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes()))
|
|
def test_logical_and(self, device, dtypes):
|
|
self._test_logical(device, dtypes, 'logical_and', [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 0, 0])
|
|
|
|
@dtypes(*product(torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes()))
|
|
def test_logical_or(self, device, dtypes):
|
|
self._test_logical(device, dtypes, 'logical_or', [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 1, 1])
|
|
|
|
def test_remainder_overflow(self, device):
|
|
# Check Integer Overflows
|
|
x = torch.tensor(23500, dtype=torch.int64, device=device)
|
|
q = 392486996410368
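# torch.remainder matches Python's %, so the result takes the sign of the
# divisor: e.g. -23500 % 392486996410368 == 392486996410368 - 23500. The large
# divisor is chosen to expose integer-overflow bugs in the kernel.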
self.assertEqual(x % q, x)
|
|
self.assertEqual(-x % q, q - x)
|
|
self.assertEqual(x % -q, x - q)
|
|
self.assertEqual(-x % -q, -x)
|
|
|
|
def test_rpow(self, device):
|
|
m = torch.randn(10, 10, device=device)
|
|
self.assertEqual(torch.pow(2, m), 2**m)
|
|
|
|
# test with scalar
|
|
m = torch.randn(1, device=device).squeeze()
|
|
assert m.dim() == 0, "m is intentionally a scalar"
|
|
self.assertEqual(torch.pow(2, m), 2**m)
|
|
|
|
@onlyCPU
|
|
def test_ldexp(self, device):
|
|
# random values
|
|
mantissas = torch.randn(64, device=device)
|
|
exponents = torch.randint(-31, 31, (64,), device=device, dtype=torch.int32)
|
|
|
|
# basic test
|
|
np_outcome = np.ldexp(mantissas.numpy(), exponents.numpy())
|
|
pt_outcome_1 = torch.ldexp(mantissas, exponents)
|
|
pt_outcome_2 = mantissas.ldexp(exponents)
|
|
self.assertEqual(np_outcome, pt_outcome_1)
|
|
self.assertEqual(np_outcome, pt_outcome_2)
|
|
mantissas.ldexp_(exponents)
|
|
self.assertEqual(np_outcome, mantissas)
|
|
|
|
# test bounds
|
|
mantissas = torch.tensor([float('inf'), float('-inf'), float('inf'), float('nan')], device=device)
|
|
exponents = torch.randint(0, 31, (4,), device=device, dtype=torch.int32)
|
|
np_outcome = np.ldexp(mantissas.numpy(), exponents.numpy())
|
|
pt_outcome = torch.ldexp(mantissas, exponents)
|
|
self.assertEqual(np_outcome, pt_outcome)
|
|
|
|
def test_lerp(self, device):
|
|
start_end_weight_shapes = [(), (5,), (5, 5)]
|
|
for shapes in product(start_end_weight_shapes, start_end_weight_shapes, start_end_weight_shapes):
|
|
start = torch.randn(shapes[0], device=device)
|
|
end = torch.randn(shapes[1], device=device)
|
|
|
|
# Tensor and Python-scalar weights
|
|
for weight in [torch.randn(shapes[2], device=device), random.random()]:
|
|
actual = torch.lerp(start, end, weight)
|
|
actual_method = start.lerp(end, weight)
|
|
self.assertEqual(actual, actual_method)
|
|
actual_out = torch.Tensor().to(device)
|
|
torch.lerp(start, end, weight, out=actual_out)
|
|
self.assertEqual(actual, actual_out)
|
|
expected = start + weight * (end - start)
|
|
self.assertEqual(expected, actual)
|
|
|
|
def _test_logaddexp(self, device, dtype, base2):
|
|
if base2:
|
|
ref_func = np.logaddexp2
|
|
our_func = torch.logaddexp2
|
|
else:
|
|
ref_func = np.logaddexp
|
|
our_func = torch.logaddexp
|
|
|
|
def _test_helper(a, b):
|
|
ref = ref_func(a.cpu().numpy(), b.cpu().numpy())
|
|
v = our_func(a, b)
|
|
self.assertEqual(ref, v)
|
|
|
|
# simple test
|
|
a = torch.randn(64, 2, dtype=dtype, device=device) - 0.5
|
|
b = torch.randn(64, 2, dtype=dtype, device=device) - 0.5
|
|
_test_helper(a, b)
|
|
_test_helper(a[:3], b[:3])
|
|
|
|
# large value test for numerical stability
|
|
a *= 10000
|
|
b *= 10000
|
|
_test_helper(a, b)
|
|
_test_helper(a[:3], b[:3])
|
|
|
|
a = torch.tensor([float('inf'), float('-inf'), float('inf'), float("nan")], dtype=dtype, device=device)
|
|
b = torch.tensor([float('inf'), float('-inf'), float('-inf'), float("nan")], dtype=dtype, device=device)
|
|
_test_helper(a, b)
|
|
|
|
@dtypes(torch.float32, torch.float64)
|
|
def test_logaddexp(self, device, dtype):
|
|
self._test_logaddexp(device, dtype, base2=False)
|
|
|
|
@dtypes(torch.float32, torch.float64)
|
|
def test_logaddexp2(self, device, dtype):
|
|
self._test_logaddexp(device, dtype, base2=True)
|
|
|
|
def test_add(self, device):
|
|
dtypes = [torch.float, torch.double] + torch.testing.get_all_complex_dtypes()
|
|
for dtype in dtypes:
|
|
# [res] torch.add([res,] tensor1, tensor2)
|
|
m1 = torch.randn(100, 100, dtype=dtype, device=device)
|
|
v1 = torch.randn(100, dtype=dtype, device=device)
|
|
|
|
# contiguous
|
|
res1 = torch.add(m1[4], v1)
|
|
res2 = res1.clone().zero_()
|
|
for i in range(m1.size(1)):
|
|
res2[i] = m1[4, i] + v1[i]
|
|
self.assertEqual(res1, res2)
|
|
|
|
m1 = torch.randn(100, 100, device=device)
|
|
v1 = torch.randn(100, device=device)
|
|
|
|
# non-contiguous
|
|
res1 = torch.add(m1[:, 4], v1)
|
|
res2 = res1.clone().zero_()
|
|
for i in range(m1.size(0)):
|
|
res2[i] = m1[i, 4] + v1[i]
|
|
self.assertEqual(res1, res2)
|
|
|
|
# [res] torch.add([res,] tensor, value)
|
|
m1 = torch.randn(10, 10, device=device)
|
|
|
|
# contiguous
|
|
res1 = m1.clone()
|
|
res1[3].add_(2)
|
|
res2 = m1.clone()
|
|
for i in range(m1.size(1)):
|
|
res2[3, i] = res2[3, i] + 2
|
|
self.assertEqual(res1, res2)
|
|
|
|
# non-contiguous
|
|
m1 = torch.randn(10, 10, device=device)
|
|
res1 = m1.clone()
|
|
res1[:, 3].add_(2)
|
|
res2 = m1.clone()
|
|
for i in range(m1.size(0)):
|
|
res2[i, 3] = res2[i, 3] + 2
|
|
self.assertEqual(res1, res2)
|
|
|
|
# inter-type
|
|
m1 = torch.randn(10, 10, dtype=dtype, device=device)
|
|
self.assertEqual(m1 + 3, m1 + torch.tensor(3))
|
|
self.assertEqual(3 + m1, torch.tensor(3) + m1)
|
|
|
|
# contiguous + non-contiguous
|
|
m1 = torch.randn(10, 10, dtype=dtype, device=device)
|
|
m2 = torch.randn(10, 10, dtype=dtype, device=device).t()
|
|
res = m1 + m2
|
|
self.assertTrue(res.is_contiguous())
|
|
self.assertEqual(res, m1 + m2.contiguous())
|
|
|
|
# 1d + empty
|
|
m1 = torch.tensor([1.0], dtype=dtype, device=device)
|
|
m2 = torch.tensor([], dtype=dtype, device=device)
|
|
self.assertEqual(m1 + m2, [])
|
|
|
|
# inter-type uint8
|
|
one = torch.tensor(1, dtype=torch.uint8, device=device)
|
|
self.assertEqual(torch.add(one, 1), 2)
|
|
self.assertEqual(torch.add(one, 1).dtype, torch.uint8)
|
|
|
|
# bool
|
|
m1 = torch.tensor([True, False, False, True, False, False], dtype=torch.bool, device=device)
|
|
m2 = torch.tensor([True, True, False, False, False, True], dtype=torch.bool, device=device)
|
|
expected = torch.tensor([True, True, False, True, False, True], dtype=torch.bool, device=device)
|
|
self.assertEqual(m1 + m2, expected)
|
|
|
|
# fused multiply-add (add with alpha) on bool tensors
|
|
a = torch.zeros(2, 3, dtype=torch.bool, device=device)
|
|
res = torch.add(a, a, alpha=0)
|
|
expected = torch.zeros(2, 3, device=device).bool()
|
|
self.assertEqual(res, expected)
|
|
|
|
# bfloat16
|
|
m1 = torch.tensor([1., 2.], dtype=torch.bfloat16)
|
|
m2 = torch.tensor([3., 4.], dtype=torch.bfloat16)
|
|
self.assertEqual(m1 + m2, torch.tensor([4., 6.], dtype=torch.bfloat16))
|
|
|
|
# different alpha types
|
|
m1 = torch.tensor([2 + 3j, 4 + 5j], dtype=torch.complex64, device=device)
|
|
m2 = torch.tensor([4 + 5j, 2 + 3j], dtype=torch.complex64, device=device)
|
|
# add complex numbers with float alpha
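# expected is m1 + alpha * m2 elementwise, e.g. (2+3j) + 0.1 * (4+5j) = 2.4+3.5j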
res = torch.add(m1, m2, alpha=0.1)
|
|
expected = torch.tensor([2.4000 + 3.5000j, 4.2000 + 5.3000j], dtype=torch.complex64, device=device)
|
|
self.assertEqual(res, expected)
|
|
|
|
# add complex numbers with complex alpha
|
|
res = torch.add(m1, m2, alpha=complex(0.1, 0.2))
|
|
expected = torch.tensor([1.4000 + 4.3000j, 3.6000 + 5.7000j], dtype=torch.complex64, device=device)
|
|
self.assertEqual(res, expected)
|
|
|
|
# add complex numbers with integer alpha
|
|
res = torch.add(m1, m2, alpha=2)
|
|
expected = torch.tensor([10. + 13.j, 8. + 11.j], dtype=torch.complex64, device=device)
|
|
self.assertEqual(res, expected)
|
|
|
|
# mismatched alpha
|
|
m1 = torch.tensor([1], dtype=torch.int8, device=device)
|
|
m2 = torch.tensor([2], dtype=torch.int8, device=device)
|
|
self.assertRaisesRegex(RuntimeError,
|
|
r"Boolean alpha only supported for Boolean results\.",
|
|
lambda: torch.add(m1, m2, alpha=True))
|
|
self.assertRaisesRegex(RuntimeError,
|
|
r"For integral input tensors, argument alpha must not be a floating point number\.",
|
|
lambda: torch.add(m1, m2, alpha=1.0))
|
|
|
|
# mismatched alpha, float / double tensor and complex alpha
|
|
m1 = torch.tensor([3., 4.], device=device)
|
|
m2 = torch.tensor([4., 3.], device=device)
|
|
self.assertRaises(RuntimeError, lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2)))
|
|
|
|
m1 = torch.tensor([3., 4.], dtype=torch.double, device=device)
|
|
m2 = torch.tensor([4., 3.], dtype=torch.double, device=device)
|
|
self.assertRaises(RuntimeError, lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2)))
|
|
|
|
# complex
|
|
m1 = torch.tensor((4.0000 + 4.0000j), dtype=torch.complex64)
|
|
m2 = torch.tensor(4., dtype=torch.float64)
|
|
self.assertRaisesRegex(RuntimeError, r"result type ComplexFloat can't be cast to the desired output type Double",
|
|
lambda: torch.add(m1, m1, out=m2))
|
|
|
|
|
|
def test_sub_typing(self, device):
|
|
m1 = torch.tensor([True, False, False, True, False, False], dtype=torch.bool, device=device)
|
|
m2 = torch.tensor([True, True, False, False, False, True], dtype=torch.bool, device=device)
|
|
self.assertRaisesRegex(RuntimeError,
|
|
r"Subtraction, the `\-` operator, with two bool tensors is not supported. "
|
|
r"Use the `\^` or `logical_xor\(\)` operator instead.",
|
|
lambda: m1 - m2)
|
|
self.assertRaisesRegex(RuntimeError,
|
|
r"Subtraction, the `\-` operator, with a bool tensor is not supported. "
|
|
r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.",
|
|
lambda: 1 - m1)
|
|
self.assertRaisesRegex(RuntimeError,
|
|
r"Subtraction, the `\-` operator, with a bool tensor is not supported. "
|
|
r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.",
|
|
lambda: m2 - 1)
|
|
|
|
# mismatched alpha
|
|
m1 = torch.tensor([1], dtype=torch.int8, device=device)
|
|
m2 = torch.tensor([2], dtype=torch.int8, device=device)
|
|
self.assertRaisesRegex(RuntimeError,
|
|
r"Boolean alpha only supported for Boolean results\.",
|
|
lambda: torch.sub(m1, m2, alpha=True))
|
|
self.assertRaisesRegex(RuntimeError,
|
|
r"For integral input tensors, argument alpha must not be a floating point number\.",
|
|
lambda: torch.sub(m1, m2, alpha=1.0))
|
|
|
|
def test_mul(self, device):
|
|
m1 = torch.randn(10, 10, device=device)
|
|
res1 = m1.clone()
|
|
res1[:, 3].mul_(2)
|
|
res2 = m1.clone()
|
|
for i in range(res1.size(0)):
|
|
res2[i, 3] = res2[i, 3] * 2
|
|
self.assertEqual(res1, res2)
|
|
|
|
a1 = torch.tensor([True, False, False, True], dtype=torch.bool, device=device)
|
|
a2 = torch.tensor([True, False, True, False], dtype=torch.bool, device=device)
|
|
self.assertEqual(a1 * a2, torch.tensor([True, False, False, False], dtype=torch.bool, device=device))
|
|
|
|
if device == 'cpu':
|
|
a1 = torch.tensor([0.1, 0.1], dtype=torch.bfloat16, device=device)
|
|
a2 = torch.tensor([1.1, 0.1], dtype=torch.bfloat16, device=device)
|
|
self.assertEqual(a1 * a2, torch.tensor([0.11, 0.01], dtype=torch.bfloat16, device=device), atol=0.01, rtol=0)
|
|
self.assertEqual(a1.mul(a2), a1 * a2)
|
|
|
|
def test_bool_tensor_comparison_ops(self, device):
|
|
a = torch.tensor([True, False, True, False, True, False], dtype=torch.bool, device=device)
|
|
b = torch.tensor([True, False, True, True, True, True], dtype=torch.bool, device=device)
|
|
self.assertEqual(a == b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device))
|
|
self.assertEqual(a != b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device))
|
|
self.assertEqual(a < b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device))
|
|
self.assertEqual(a > b, torch.tensor([0, 0, 0, 0, 0, 0], dtype=torch.bool, device=device))
|
|
self.assertEqual(a >= b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device))
|
|
self.assertEqual(a <= b, torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.bool, device=device))
|
|
self.assertEqual(a > False, torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device))
|
|
self.assertEqual(a == torch.tensor(True, dtype=torch.bool, device=device),
|
|
torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device))
|
|
self.assertEqual(a == torch.tensor(0, dtype=torch.bool, device=device),
|
|
torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool, device=device))
|
|
self.assertFalse(a.equal(b))
|
|
|
|
@dtypes(*torch.testing.get_all_dtypes(include_complex=False))
|
|
def test_logical(self, device, dtype):
|
|
if dtype != torch.bool:
|
|
x = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype)
|
|
b = torch.tensor([2], device=device, dtype=dtype)
|
|
self.assertEqual(x.lt(2), torch.tensor([True, False, False, False]))
|
|
self.assertEqual(x.le(2), torch.tensor([True, True, False, False]))
|
|
self.assertEqual(x.ge(2), torch.tensor([False, True, True, True]))
|
|
self.assertEqual(x.gt(2), torch.tensor([False, False, True, True]))
|
|
self.assertEqual(x.eq(2), torch.tensor([False, True, False, False]))
|
|
self.assertEqual(x.ne(2), torch.tensor([True, False, True, True]))
|
|
|
|
self.assertEqual(x.lt(b), torch.tensor([True, False, False, False]))
|
|
self.assertEqual(x.le(b), torch.tensor([True, True, False, False]))
|
|
self.assertEqual(x.ge(b), torch.tensor([False, True, True, True]))
|
|
self.assertEqual(x.gt(b), torch.tensor([False, False, True, True]))
|
|
self.assertEqual(x.eq(b), torch.tensor([False, True, False, False]))
|
|
self.assertEqual(x.ne(b), torch.tensor([True, False, True, True]))
|
|
else:
|
|
x = torch.tensor([True, False, True, False], device=device)
|
|
self.assertEqual(x.lt(True), torch.tensor([False, True, False, True]))
|
|
self.assertEqual(x.le(True), torch.tensor([True, True, True, True]))
|
|
self.assertEqual(x.ge(True), torch.tensor([True, False, True, False]))
|
|
self.assertEqual(x.gt(True), torch.tensor([False, False, False, False]))
|
|
self.assertEqual(x.eq(True), torch.tensor([True, False, True, False]))
|
|
self.assertEqual(x.ne(True), torch.tensor([False, True, False, True]))
|
|
|
|
def test_atan2(self, device):
|
|
def _test_atan2_with_size(size, device):
|
|
a = torch.rand(size=size, device=device, dtype=torch.double)
|
|
b = torch.rand(size=size, device=device, dtype=torch.double)
|
|
actual = a.atan2(b)
|
|
x = a.view(-1)
|
|
y = b.view(-1)
|
|
expected = torch.tensor([math.atan2(x[i].item(), y[i].item()) for i in range(x.numel())],
|
|
device=device, dtype=torch.double)
|
|
self.assertEqual(expected, actual.view(-1), rtol=0, atol=0.02)
|
|
|
|
_test_atan2_with_size((2, 2), device)
|
|
_test_atan2_with_size((3, 3), device)
|
|
_test_atan2_with_size((5, 5), device)
|
|
|
|
def test_atan2_edgecases(self, device):
|
|
def _test_atan2(x, y, expected, device, dtype):
|
|
expected_tensor = torch.tensor([expected], dtype=dtype, device=device)
|
|
x_tensor = torch.tensor([x], dtype=dtype, device=device)
|
|
y_tensor = torch.tensor([y], dtype=dtype, device=device)
|
|
actual = torch.atan2(y_tensor, x_tensor)
|
|
self.assertEqual(expected_tensor, actual, rtol=0, atol=0.02)
|
|
|
|
for dtype in [torch.float, torch.double]:
|
|
_test_atan2(0, 0, 0, device, dtype)
|
|
_test_atan2(0, 1, math.pi / 2, device, dtype)
|
|
_test_atan2(0, -1, math.pi / -2, device, dtype)
|
|
_test_atan2(-1, 0, math.pi, device, dtype)
|
|
_test_atan2(1, 0, 0, device, dtype)
|
|
_test_atan2(-1, -1, math.pi * -3 / 4 , device, dtype)
|
|
_test_atan2(1, 1, math.pi / 4 , device, dtype)
|
|
_test_atan2(1, -1, math.pi / -4 , device, dtype)
|
|
_test_atan2(-1, 1, math.pi * 3 / 4 , device, dtype)
|
|
|
|
def test_trapz(self, device):
|
|
def test_dx(sizes, dim, dx, device):
|
|
t = torch.randn(sizes, device=device)
|
|
actual = torch.trapz(t, dx=dx, dim=dim)
|
|
expected = np.trapz(t.cpu().numpy(), dx=dx, axis=dim)
|
|
self.assertEqual(expected.shape, actual.shape)
|
|
self.assertEqual(expected, actual)
|
|
|
|
def test_x(sizes, dim, x, device):
|
|
t = torch.randn(sizes, device=device)
|
|
actual = torch.trapz(t, x=torch.tensor(x, device=device), dim=dim)
|
|
expected = np.trapz(t.cpu().numpy(), x=x, axis=dim)
|
|
self.assertEqual(expected.shape, actual.shape)
|
|
self.assertEqual(expected, actual.cpu())
|
|
|
|
test_dx((2, 3, 4), 1, 1, device)
|
|
test_dx((10, 2), 0, 0.1, device)
|
|
test_dx((1, 10), 0, 2.3, device)
|
|
test_dx((0, 2), 0, 1.0, device)
|
|
test_dx((0, 2), 1, 1.0, device)
|
|
test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
|
|
test_x((10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device)
|
|
test_x((1, 10), 0, [1.0], device)
|
|
test_x((0, 2), 0, [], device)
|
|
test_x((0, 2), 1, [1.0, 2.0], device)
|
|
with self.assertRaisesRegex(
|
|
IndexError,
|
|
'Dimension out of range'):
|
|
test_x((2, 3), 2, [], device)
|
|
test_dx((2, 3), 2, 1.0, device)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
'There must be one `x` value for each sample point'):
|
|
test_x((2, 3), 1, [1.0, 2.0], device)
|
|
test_x((2, 3), 1, [1.0, 2.0, 3.0, 4.0], device)
|
|
|
|
@dtypes(torch.double)
|
|
def test_pow_scalar_overloads_mem_overlap(self, device, dtype):
|
|
sz = 3
|
|
doubles = torch.randn(2 * sz, dtype=dtype, device=device)
|
|
self.check_internal_mem_overlap(
|
|
lambda t: t.pow_(42), 1, dtype, device)
|
|
self.unary_check_input_output_mem_overlap(
|
|
doubles, sz, lambda input, out: torch.pow(input, 42, out=out))
|
|
self.unary_check_input_output_mem_overlap(
|
|
doubles, sz, lambda input, out: torch.pow(42, input, out=out))
|
|
|
|
@dtypes(*list(product(torch.testing.get_all_dtypes(include_bool=False),
|
|
torch.testing.get_all_dtypes(include_bool=False))))
|
|
def test_float_power(self, device, dtypes):
|
|
def to_np(value):
|
|
if isinstance(value, torch.Tensor) and value.dtype == torch.bfloat16:
|
|
return value.to(torch.float).cpu().numpy()
|
|
return value.cpu().numpy() if isinstance(value, torch.Tensor) else value
|
|
|
|
base_dtype = dtypes[0]
|
|
exp_dtype = dtypes[1]
|
|
out_dtype = torch.complex128 if base_dtype.is_complex or exp_dtype.is_complex else torch.float64
|
|
|
|
base = make_tensor((30,), device, base_dtype, low=1, high=100)
|
|
# Complex and real results do not agree between PyTorch and NumPy when computing negative and zero power of 0
|
|
# Related: https://github.com/pytorch/pytorch/issues/48000
|
|
# base[0] = base[3] = base[7] = 0
|
|
exp = make_tensor((30,), device, exp_dtype, low=-2, high=2)
|
|
exp[0] = exp[4] = exp[6] = 0
|
|
|
|
expected = torch.from_numpy(np.float_power(to_np(base), to_np(exp)))
|
|
|
|
exponents = [-2.8, -2, -1, -0.5, 0.5, 1, 2]
|
|
complex_exponents = exponents + [-2.5j, -1.0j, 1.0j, 2.5j, 1.0 + 1.0j, -1.0 - 1.5j, 3.3j]
|
|
|
|
for op in (torch.float_power, torch.Tensor.float_power, torch.Tensor.float_power_):
|
|
|
|
# Case of Tensor x Tensor
|
|
if op is torch.Tensor.float_power_ and base_dtype != out_dtype:
|
|
with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"):
|
|
op(base.clone(), exp)
|
|
else:
|
|
result = op(base.clone(), exp)
|
|
self.assertEqual(expected, result)
|
|
|
|
if op is torch.float_power:
|
|
out = torch.empty_like(base).to(device=device, dtype=out_dtype)
|
|
op(base, exp, out=out)
|
|
self.assertEqual(expected, out)
|
|
|
|
# Case of Tensor x Scalar
|
|
for i in complex_exponents if exp_dtype.is_complex else exponents:
|
|
out_dtype_scalar_exp = torch.complex128 if base_dtype.is_complex or type(i) == complex else torch.float64
|
|
expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i))
|
|
|
|
if op is torch.Tensor.float_power_ and base_dtype != out_dtype_scalar_exp:
|
|
with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"):
|
|
op(base.clone(), i)
|
|
else:
|
|
result = op(base.clone(), i)
|
|
self.assertEqual(expected_scalar_exp, result)
|
|
|
|
if op is torch.float_power:
|
|
out = torch.empty_like(base).to(device=device, dtype=out_dtype_scalar_exp)
|
|
op(base, i, out=out)
|
|
self.assertEqual(expected_scalar_exp, out)
|
|
|
|
# Case of Scalar x Tensor
|
|
for i in complex_exponents if base_dtype.is_complex else exponents:
|
|
out_dtype_scalar_base = torch.complex128 if exp_dtype.is_complex or type(i) == complex else torch.float64
|
|
expected_scalar_base = torch.from_numpy(np.float_power(i, to_np(exp)))
|
|
|
|
result = torch.float_power(i, exp)
|
|
self.assertEqual(expected_scalar_base, result)
|
|
|
|
out = torch.empty_like(exp).to(device=device, dtype=out_dtype_scalar_base)
|
|
torch.float_power(i, exp, out=out)
|
|
self.assertEqual(expected_scalar_base, out)
|
|
|
|
def test_float_power_exceptions(self, device):
|
|
def _promo_helper(x, y):
|
|
for i in (x, y):
|
|
if type(i) == complex:
|
|
return torch.complex128
|
|
elif type(i) == torch.Tensor and i.is_complex():
|
|
return torch.complex128
|
|
return torch.double
|
|
|
|
test_cases = ((torch.tensor([-2, -1, 0, 1, 2], device=device), -.25),
|
|
(torch.tensor([-1.0j, 0j, 1.0j, 1.0 + 1.0j, -1.0 - 1.5j], device=device), 2.))
|
|
for base, exp in test_cases:
|
|
for out_dtype in (torch.long, torch.float, torch.double, torch.cdouble):
|
|
out = torch.empty(1, device=device, dtype=out_dtype)
|
|
required_dtype = _promo_helper(base, exp)
|
|
|
|
if out.dtype == required_dtype:
|
|
torch.float_power(base, exp, out=out)
|
|
else:
|
|
with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"):
|
|
torch.float_power(base, exp, out=out)
|
|
|
|
if base.dtype == required_dtype:
|
|
torch.Tensor.float_power_(base.clone(), exp)
|
|
else:
|
|
with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"):
|
|
torch.Tensor.float_power_(base.clone(), exp)
|
|
|
|
@skipIf(not TEST_SCIPY, "Scipy required for the test.")
|
|
@dtypes(*product(torch.testing.get_all_dtypes(include_complex=False, include_bfloat16=False),
|
|
torch.testing.get_all_dtypes(include_complex=False, include_bfloat16=False)))
|
|
def test_xlogy(self, device, dtypes):
|
|
def out_variant_helper(torch_fn, x, y):
|
|
expected = torch_fn(x, y)
|
|
out = torch.empty_like(expected)
|
|
torch_fn(x, y, out=out)
|
|
self.assertEqual(expected, out)
|
|
|
|
def inplace_variant_helper(x, y):
|
|
if x.dtype in torch.testing.get_all_int_dtypes() + [torch.bool]:
|
|
with self.assertRaisesRegex(RuntimeError,
|
|
"can't be cast to the desired output type"):
|
|
x.clone().xlogy_(y)
|
|
else:
|
|
expected = torch.empty_like(x)
|
|
torch.xlogy(x, y, out=expected)
|
|
inplace_out = x.clone().xlogy_(y)
|
|
self.assertEqual(expected, inplace_out)
|
|
|
|
x_dtype, y_dtype = dtypes
|
|
|
|
# Tensor-Tensor Test (tensors of the same and of different shapes)
|
|
x = make_tensor((3, 2, 4, 5), device, x_dtype, low=0.5, high=1000)
|
|
y = make_tensor((3, 2, 4, 5), device, y_dtype, low=0.5, high=1000)
|
|
z = make_tensor((4, 5), device, y_dtype, low=0.5, high=1000)
|
|
|
|
torch_fn = partial(torch.xlogy, x)
|
|
reference_fn = partial(scipy.special.xlogy, x.cpu().numpy())
|
|
|
|
self.compare_with_numpy(torch_fn, reference_fn, x, exact_dtype=False)
|
|
self.compare_with_numpy(torch_fn, reference_fn, y, exact_dtype=False)
|
|
self.compare_with_numpy(torch_fn, reference_fn, z, exact_dtype=False)
|
|
out_variant_helper(torch.xlogy, x, x)
|
|
out_variant_helper(torch.xlogy, x, y)
|
|
out_variant_helper(torch.xlogy, x, z)
|
|
inplace_variant_helper(x, x)
|
|
inplace_variant_helper(x, y)
|
|
inplace_variant_helper(x, z)
|
|
|
|
# Scalar-Tensor Test
|
|
torch_fn = partial(torch.xlogy, 3.14)
|
|
reference_fn = partial(scipy.special.xlogy, 3.14)
|
|
|
|
self.compare_with_numpy(torch_fn, reference_fn, x, exact_dtype=False)
|
|
self.compare_with_numpy(torch_fn, reference_fn, y, exact_dtype=False)
|
|
self.compare_with_numpy(torch_fn, reference_fn, z, exact_dtype=False)
|
|
out_variant_helper(torch.xlogy, 3.14, x)
|
|
out_variant_helper(torch.xlogy, 3.14, y)
|
|
out_variant_helper(torch.xlogy, 3.14, z)
|
|
|
|
# Special Values Tensor-Tensor
|
|
t = torch.tensor([0., 1., 2., float('inf'), -float('inf'), float('nan')], device=device)
|
|
zeros = torch.zeros(6, dtype=y_dtype, device=device)
|
|
|
|
torch_fn = partial(torch.xlogy, zeros)
|
|
reference_fn = partial(scipy.special.xlogy, zeros.cpu().numpy())
|
|
self.compare_with_numpy(torch_fn, reference_fn, t, exact_dtype=False)
|
|
out_variant_helper(torch.xlogy, zeros, t)
|
|
inplace_variant_helper(zeros, t)
|
|
|
|
# Special Values Scalar-Tensor
|
|
torch_fn = partial(torch.xlogy, 0)
|
|
reference_fn = partial(scipy.special.xlogy, 0)
|
|
self.compare_with_numpy(torch_fn, reference_fn, t, exact_dtype=False)
|
|
out_variant_helper(torch.xlogy, 0, t)
|
|
|
|
def test_xlogy_scalar_type_promotion(self, device):
|
|
# Test that python numbers don't participate in type promotion at the same
|
|
# priority level as 0-dim tensors
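# e.g. torch.xlogy(t, 5.) keeps t's float32 dtype because Python scalars are
# "weak" in type promotion, whereas a 0-dim float64 tensor in place of the
# scalar would generally promote the result to float64.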
t = torch.randn((), dtype=torch.float32, device=device)
|
|
|
|
self.assertEqual(t.dtype, torch.xlogy(t, 5).dtype)
|
|
self.assertEqual(t.dtype, torch.xlogy(t, 5.).dtype)
|
|
|
|
self.assertEqual(t.dtype, torch.xlogy(5, t).dtype)
|
|
self.assertEqual(t.dtype, torch.xlogy(5., t).dtype)
|
|
|
|
@skipIf(not TEST_SCIPY, "Scipy required for the test.")
|
|
def test_xlogy_bfloat16(self, device):
|
|
def _compare_helper(x, y):
|
|
x_np = x if isinstance(x, float) else x.cpu().to(torch.float).numpy()
|
|
y_np = y if isinstance(y, float) else y.cpu().to(torch.float).numpy()
|
|
expected = torch.from_numpy(scipy.special.xlogy(x_np, y_np))
|
|
actual = torch.xlogy(x, y)
|
|
self.assertEqual(expected, actual, exact_dtype=False)
|
|
|
|
x_dtype, y_dtype = torch.bfloat16, torch.bfloat16
|
|
|
|
# Tensor-Tensor Test (tensors of the same and of different shapes)
|
|
x = make_tensor((3, 2, 4, 5), device, x_dtype, low=0.5, high=1000)
|
|
y = make_tensor((3, 2, 4, 5), device, y_dtype, low=0.5, high=1000)
|
|
z = make_tensor((4, 5), device, y_dtype, low=0.5, high=1000)
|
|
|
|
_compare_helper(x, x)
|
|
_compare_helper(x, y)
|
|
_compare_helper(x, z)
|
|
|
|
_compare_helper(x, 3.14)
|
|
_compare_helper(y, 3.14)
|
|
_compare_helper(z, 3.14)
|
|
|
|
# Special Values Tensor-Tensor
|
|
t = torch.tensor([0., 1., 2., float('inf'), -float('inf'), float('nan')], device=device)
|
|
zeros = torch.tensor(5, dtype=y_dtype, device=device)
|
|
_compare_helper(t, zeros)
|
|
_compare_helper(t, 0.)
|
|
|
|
tensor_binary_ops = [
|
|
'__lt__', '__le__',
|
|
'__gt__', '__ge__',
|
|
'__eq__', '__ne__',
|
|
|
|
'__add__', '__radd__', '__iadd__',
|
|
'__sub__', '__rsub__', '__isub__',
|
|
'__mul__', '__rmul__', '__imul__',
|
|
'__matmul__', '__rmatmul__', '__imatmul__',
|
|
'__truediv__', '__rtruediv__', '__itruediv__',
|
|
'__floordiv__', '__rfloordiv__', '__ifloordiv__',
|
|
'__mod__', '__rmod__', '__imod__',
|
|
'__divmod__', '__rdivmod__', '__idivmod__',
|
|
'__pow__', '__rpow__', '__ipow__',
|
|
'__lshift__', '__rlshift__', '__ilshift__',
|
|
'__rshift__', '__rrshift__', '__irshift__',
|
|
'__and__', '__rand__', '__iand__',
|
|
'__xor__', '__rxor__', '__ixor__',
|
|
'__or__', '__ror__', '__ior__',
|
|
]
|
|
|
|
# Test that binary math operations return NotImplemented for unknown types.
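# Returning NotImplemented (rather than raising) lets Python fall back to the
# reflected method on the other operand, or raise TypeError if no operand
# handles the combination.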
def generate_not_implemented_tests(cls):
|
|
class UnknownType:
|
|
pass
|
|
|
|
# TODO: refactor to inline these
|
|
_types = [
|
|
torch.half, torch.float, torch.double,
|
|
torch.int8, torch.short, torch.int, torch.long,
|
|
torch.uint8
|
|
]
|
|
|
|
# TODO: refactor to use make_tensor
|
|
def _small_2d(dtype, device, has_zeros=True, fill_ones=False, oneish=False):
|
|
t = _make_tensor((5, 5), dtype, device, fill_ones=fill_ones)
|
|
if oneish:
|
|
return t.clamp(min=_number(.99, 1, dtype), max=1.01)
|
|
if not has_zeros:
|
|
return t.clamp(min=(_number(_div_min, 1, dtype)))
|
|
return t
|
|
|
|
for op in tensor_binary_ops:
|
|
@dtypes(*_types)
|
|
def test(self, device, dtype):
|
|
# Generate the inputs
|
|
tensor = _small_2d(dtype, device)
|
|
|
|
# Runs the tensor op on the device
|
|
result = getattr(tensor, op)(UnknownType())
|
|
self.assertEqual(result, NotImplemented)
|
|
|
|
test_name = "test_{}_not_implemented".format(op)
|
|
assert not hasattr(cls, test_name), "{0} already in {1}".format(
|
|
test_name, cls.__name__)
|
|
|
|
setattr(cls, test_name, test)
|
|
|
|
|
|
generate_not_implemented_tests(TestBinaryUfuncs)
|
|
instantiate_device_type_tests(TestBinaryUfuncs, globals())
|
|
|
|
if __name__ == '__main__':
|
|
run_tests()
|