Files
pytorch/test/test_foreach.py
Jane Xu 5a0926a26e Stop skipping entire foreach tests, just skip the profiler portion (#156871)
Instead of skipping the whole test as the CUPTI team figures out what is wrong, let's temporarily skip the profiler check portion. It is high pri to add it back to ensure foreach ops are actually performant.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156871
Approved by: https://github.com/albanD
ghstack dependencies: #156876
2025-06-27 22:35:34 +00:00

1628 lines
60 KiB
Python

# Owner(s): ["module: mta"]
# ruff: noqa: F841
import itertools
import os
import random
import re
import unittest
import weakref
from contextlib import nullcontext
from numbers import Number
import torch
from torch.testing import make_tensor
from torch.testing._comparison import default_tolerances
from torch.testing._internal.common_cuda import _get_torch_cuda_version, TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
dtypes,
instantiate_device_type_tests,
largeTensorTest,
onlyCUDA,
OpDTypes,
ops,
)
from torch.testing._internal.common_dtype import (
all_types_and_complex_and,
floating_types,
floating_types_and,
integral_types_and,
)
from torch.testing._internal.common_methods_invocations import (
foreach_binary_op_db,
foreach_other_op_db,
foreach_pointwise_op_db,
foreach_reduce_op_db,
foreach_unary_op_db,
)
from torch.testing._internal.common_utils import (
gradcheck,
parametrize,
run_tests,
skipIfRocmVersionLessThan,
skipIfTorchDynamo,
TEST_WITH_ROCM,
TestCase,
)
from torch.testing._internal.triton_utils import requires_cuda
_BOOL_SUB_ERR_MSG = "Subtraction, the `-` operator"
class RegularFuncWrapper:
def __init__(self, func):
self.func = func
def __call__(self, inputs, scalars=None, **kwargs):
if scalars is not None:
assert len(inputs) == 3
# We need to distribute each scalar to the regular func and it needs
# special consideration as it is a keyword only argument to the
# regular func. (Strangely, it is not a keyword only argument to the
# foreach func)
return [
self.func(*i, value=scalars[idx], **kwargs)
for idx, i in enumerate(zip(*inputs))
]
if len(inputs) == 2 and isinstance(inputs[1], (Number, torch.Tensor)):
# binary op with tensorlist and scalar.
inputs[1] = [inputs[1] for _ in range(len(inputs[0]))]
return [self.func(*i, **kwargs) for i in zip(*inputs)]
class ForeachFuncWrapper:
def __init__(self, func):
self.func = func
# Some foreach functions don't have in-place implementations.
self.is_inplace = False if func is None else func.__name__.endswith("_")
def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs):
actual = None
zero_size = kwargs.pop("zero_size", False)
# Skip profiler check for CUDA 12.6, 12.8 as the upgrade makes profiler results flaky
# https://github.com/pytorch/pytorch/issues/148681. TODO: ADD IT BACK!!!
skip_profiler_check = _get_torch_cuda_version() in [(12, 6), (12, 8)]
if (
is_cuda
and not skip_profiler_check
and torch.autograd.kineto_available()
and torch.profiler.ProfilerActivity.CUDA
in torch.profiler.supported_activities()
):
with torch.profiler.profile() as p:
actual = self.func(*inputs, **kwargs)
# synchronize within the profiler context to make sure events happen before exiting
torch.cuda.synchronize()
keys = tuple([e.key for e in p.key_averages()])
mta_called = any("multi_tensor_apply_kernel" in k for k in keys)
assert mta_called == (expect_fastpath and (not zero_size)), (
f"{mta_called=}, {expect_fastpath=}, {zero_size=}, {self.func.__name__=}, {keys=}"
)
else:
actual = self.func(*inputs, **kwargs)
if self.is_inplace:
assert id(inputs[0]) == id(actual)
return actual
class InplaceForeachVersionBumpCheck:
def __init__(
self,
testcase: TestCase,
tensorlist: "List[torch.Tensor]", # noqa: F821
) -> None:
self._testcase = testcase
self._tensorlist = tensorlist
self._orig_version_counts = [t._version for t in tensorlist]
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, traceback):
# note(crcrpar): some methods e.g. `_binary_test` could call the given inplace function multiple times
self._testcase.assertGreaterEqual(
[t._version for t in self._tensorlist], self._orig_version_counts
)
def get_transform_func(num_tensors, dtype, device, is_fastpath):
def transform(t):
if not torch.is_tensor(t):
return t
if torch.is_tensor(t) and t.ndim == 0:
return t
return make_tensor(
(num_tensors, num_tensors),
dtype=dtype,
device=device,
requires_grad=True,
noncontiguous=not is_fastpath,
)
return transform
# note(crcrpar): `zero_size` is `False` unless (dtype, device) == (torch.float32, "cuda")
# as the pair would go through `multi_tensor_apply_kernel` if inputs are not zero size.
@unittest.mock.patch.dict(os.environ, {"KINETO_LOG_LEVEL": "5"})
class TestForeach(TestCase):
@property
def is_cuda(self):
return self.device_type == "cuda"
def _get_funcs(self, op):
return (
ForeachFuncWrapper(op.method_variant),
RegularFuncWrapper(op.ref),
ForeachFuncWrapper(op.inplace_variant),
RegularFuncWrapper(op.ref_inplace),
)
# note(crcrpar): Make sure 0-size tensors are appropriately ignored by `multi_tensor_apply`
# which is originally reported in https://github.com/pytorch/pytorch/issues/94865.
# rel:
# - https://github.com/pytorch/pytorch/pull/94655
# - https://github.com/pytorch/pytorch/issues/100701
# - https://github.com/pytorch/pytorch/pull/100811
@onlyCUDA
@ops(
foreach_unary_op_db
+ foreach_binary_op_db
+ foreach_pointwise_op_db
+ foreach_reduce_op_db
+ foreach_other_op_db,
dtypes=(torch.float32,),
)
def test_all_zero_size_tensors_do_not_launch_kernel(self, device, dtype, op):
wrapped_op, _, inplace_op, _ = self._get_funcs(op)
for sample in op.sample_zero_size_inputs(device, dtype):
if op.method_variant is not None:
wrapped_op(
(sample.input, *sample.args),
is_cuda=self.is_cuda,
expect_fastpath=True,
zero_size=True,
)
if op.inplace_variant is not None:
with InplaceForeachVersionBumpCheck(self, sample.input):
inplace_op(
(sample.input, *sample.args),
is_cuda=self.is_cuda,
expect_fastpath=True,
zero_size=True,
)
@skipIfRocmVersionLessThan((6, 0))
@ops(
foreach_unary_op_db
+ foreach_binary_op_db
+ foreach_pointwise_op_db
+ foreach_reduce_op_db
+ foreach_other_op_db,
)
@parametrize(
"noncontiguous,inplace",
[(False, False), (False, True), (True, False), (True, True)],
name_fn=lambda x, y: "{}_{}".format(
"fastpath" if not x else "slowpath", "inplace" if y else "outplace"
),
)
def test_parity(self, device, dtype, op, noncontiguous, inplace):
if inplace:
_, _, func, ref = self._get_funcs(op)
else:
func, ref, _, _ = self._get_funcs(op)
for sample in op.sample_inputs(
device, dtype, noncontiguous=noncontiguous, allow_higher_dtype_scalars=True
):
ref_kwargs = sample.kwargs
# div promotes ints to floats, so we cannot go on the fastpath there
div_slowpath = (
dtype in integral_types_and(torch.bool) and op.name == "_foreach_div"
)
expect_fastpath = not (
noncontiguous or sample.disable_fastpath or div_slowpath
)
ref_input, ctxmgr = sample.input, nullcontext()
if inplace:
with torch.no_grad():
ref_input = [t.detach().clone() for t in sample.input]
ctxmgr = InplaceForeachVersionBumpCheck(self, sample.input)
try:
with ctxmgr:
actual = func(
[sample.input, *sample.args],
self.is_cuda,
expect_fastpath,
**sample.kwargs,
)
except Exception as e:
with self.assertRaises(type(e)):
ref([ref_input, *sample.ref_args], **ref_kwargs)
else:
expected = ref([ref_input, *sample.ref_args], **ref_kwargs)
self.assertEqual(expected, actual)
def _binary_test(
self,
dtype,
op,
ref,
inputs,
is_fastpath,
is_inplace,
*,
alpha,
scalar_self_arg: bool,
):
ref_inputs = (
[[t.detach().clone() for t in inputs[0]], inputs[1]]
if is_inplace
else inputs
)
try:
with (
InplaceForeachVersionBumpCheck(self, inputs[0])
if op.is_inplace
else nullcontext()
):
actual = op(inputs, self.is_cuda, is_fastpath)
except RuntimeError as e:
with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
if not scalar_self_arg:
ref(ref_inputs)
else:
[ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
else:
expected = (
ref(ref_inputs)
if not scalar_self_arg
else [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
)
self.assertEqual(actual, expected)
if alpha is not None and not scalar_self_arg:
kwargs = {"alpha": alpha}
ref_inputs = inputs
try:
op_kwargs = {}
op_kwargs.update(kwargs)
with (
InplaceForeachVersionBumpCheck(self, inputs[0])
if op.is_inplace
else nullcontext()
):
actual = op(inputs, self.is_cuda, is_fastpath, **op_kwargs)
except RuntimeError as e:
with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
ref(ref_inputs, **kwargs)
else:
expected = ref(ref_inputs, **kwargs)
if dtype in (torch.float16, torch.bfloat16) and TEST_WITH_ROCM:
self.assertEqual(
expected, actual, atol=1.0e-3, rtol=default_tolerances(dtype)[0]
)
else:
self.assertEqual(expected, actual)
@ops(filter(lambda op: op.supports_scalar_self_arg, foreach_binary_op_db))
@parametrize("is_fastpath", (True, False))
def test_binary_op_with_scalar_self_support(self, device, dtype, op, is_fastpath):
def clone(arg):
if isinstance(arg, (list, tuple)):
return [clone(a) for a in arg]
if torch.is_tensor(arg):
return arg.detach().clone().requires_grad_()
else:
return arg
scalar_self_arg_test_complete = False
for i, sample in enumerate(
op.sample_inputs(
device,
dtype,
noncontiguous=not is_fastpath,
allow_higher_dtype_scalars=True,
)
):
(rhs_arg,) = sample.args
kwargs = {} or sample.kwargs
alpha = kwargs.pop("alpha", None)
wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
if isinstance(rhs_arg, Number) and not scalar_self_arg_test_complete:
scalar_self_arg_test_complete = True
self._binary_test(
dtype,
wrapped_op,
ref,
[rhs_arg, sample.input],
is_fastpath,
False,
alpha=alpha,
scalar_self_arg=True,
)
if op.supports_autograd and dtype == torch.float32:
transformed_sample = sample.transform(
get_transform_func(
len(sample.input), dtype, device, is_fastpath
)
)
tensors = transformed_sample.input
(rhs_arg,) = transformed_sample.args
ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
sum(
wrapped_op(
[rhs_arg, tensors], is_cuda=False, expect_fastpath=False
)
).mean().backward()
sum(ref.func(ref_rhs_arg, t) for t in ref_tensors).mean().backward()
self.assertEqual(
[t.grad for t in tensors], [t.grad for t in ref_tensors]
)
@ops(foreach_pointwise_op_db)
@parametrize("is_fastpath", (True, False))
def test_pointwise_op_with_tensor_of_scalarlist_overload(
self, device, dtype, op, is_fastpath
):
for sample in op.sample_inputs(
device,
dtype,
noncontiguous=not is_fastpath,
allow_higher_dtype_scalars=True,
):
assert isinstance(sample.args, tuple)
assert len(sample.args) == 2
inputs = [sample.input, *sample.args]
kwargs = sample.kwargs.copy()
disable_fastpath = sample.disable_fastpath and is_fastpath
wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
scalars = kwargs.pop("scalars", None)
if is_fastpath and scalars:
sample = sample.transform(
lambda t: t.detach().clone() if torch.is_tensor(t) else t
)
inputs = [sample.input, *sample.args]
tensor_values = torch.tensor(scalars)
# 1D Tensor of scalars
for is_inplace, op_, ref_ in (
(False, wrapped_op, ref),
(True, inplace_op, inplace_ref),
):
self._pointwise_test(
op_,
ref_,
inputs,
is_fastpath and not disable_fastpath,
is_inplace,
scalars=tensor_values,
**kwargs,
)
self._pointwise_test(
op_,
ref_,
inputs,
is_fastpath and not disable_fastpath,
is_inplace,
scalars=tensor_values[0],
custom_values_err="Expected packed scalar Tensor to be of dimension 1. Got 0 instead.",
**kwargs,
)
if self.is_cuda:
self._pointwise_test(
op_,
ref_,
inputs,
is_fastpath and not disable_fastpath,
is_inplace,
scalars=tensor_values.cuda(),
custom_values_err="Expected scalars to be on CPU, got cuda:0 instead.",
**kwargs,
)
self._pointwise_test(
op_,
ref_,
inputs,
is_fastpath and not disable_fastpath,
is_inplace,
scalars=tensor_values[:2],
custom_values_err=f"Expected length of scalars to match input of length {len(scalars)} but got 2 instead.",
**kwargs,
)
self._pointwise_test(
op_,
ref_,
inputs,
is_fastpath and not disable_fastpath,
is_inplace,
scalars=torch.tensor([[0, 1], [2, 3]])[:, 1],
custom_values_err="Expected scalars to be contiguous.",
**kwargs,
)
# Tests of implicit broadcasting
N = len(sample.input)
inputs = [
[
make_tensor(
(N, N),
device=device,
dtype=dtype,
noncontiguous=not is_fastpath,
)
for _ in range(N)
],
[
make_tensor(
(N - i, 1),
device=device,
dtype=dtype,
noncontiguous=not is_fastpath,
)
for i in range(N)
],
[
make_tensor(
(1, N - i),
device=device,
dtype=dtype,
noncontiguous=not is_fastpath,
)
for i in range(N)
],
]
self._pointwise_test(
wrapped_op,
ref,
inputs,
is_fastpath and disable_fastpath,
is_inplace=False,
scalars=scalars,
**kwargs,
)
self._pointwise_test(
inplace_op,
inplace_ref,
inputs,
is_fastpath and disable_fastpath,
is_inplace=True,
scalars=scalars,
**kwargs,
)
def _pointwise_test(
self,
op,
ref,
inputs,
is_fastpath,
is_inplace,
*,
scalars=None,
custom_values_err=None,
**kwargs,
):
ref_inputs = (
[[t.detach().clone() for t in inputs[0]], inputs[1], inputs[2]]
if is_inplace
else inputs
)
try:
with (
InplaceForeachVersionBumpCheck(self, inputs[0])
if is_inplace
else nullcontext()
):
actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
except RuntimeError as e:
with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
ref(ref_inputs, **kwargs)
else:
expected = ref(ref_inputs, **kwargs)
self.assertEqual(expected, actual)
if scalars is not None:
kwargs = kwargs.copy()
kwargs["scalars"] = scalars
try:
actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
except RuntimeError as e:
# Match with error messages from regular non-foreach reference if no
# custom error message was provided.
if custom_values_err is None:
with self.assertRaisesRegex(
type(e), re.escape(str(e).splitlines()[0])
):
ref(ref_inputs, **kwargs)
else:
self.assertEqual(re.escape(str(e)), re.escape(custom_values_err))
else:
expected = ref(ref_inputs, **kwargs)
self.assertEqual(expected, actual)
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
# TODO: enable empty list case
for tensors in [
[torch.randn([0], device=device, dtype=dtype)],
[torch.empty_strided((0, 1), (0, 0), dtype=dtype, device=device)],
]:
res = torch._foreach_add(tensors, 1)
self.assertEqual(res, tensors)
torch._foreach_add_(tensors, 1)
self.assertEqual(res, tensors)
# Regression test for https://github.com/pytorch/pytorch/issues/113156
torch._foreach_mul_(tensors, 1)
@onlyCUDA
@dtypes(torch.float32)
def test_foreach_check_stride_ignore_dims_of_one(self, device, dtype):
# default tensor stride is (9, 9, 3, 1).
tensor = torch.ones((2, 1, 3, 3), device=device, dtype=dtype)
strided_tensor = torch.ones(
(2, 1, 3, 3), device=device, dtype=dtype
).as_strided((2, 1, 3, 3), (9, 1, 3, 1))
left_inputs = [tensor, strided_tensor]
right_inputs = [strided_tensor, tensor]
compare_result = tensor + strided_tensor
foreach_add_check_ = ForeachFuncWrapper(torch._foreach_add)
out = foreach_add_check_(
(left_inputs, right_inputs), is_cuda=True, expect_fastpath=True
)
for res in out:
self.assertEqual(res, compare_result)
@ops(
filter(lambda op: op.supports_out, foreach_binary_op_db),
dtypes=OpDTypes.supported,
)
def test_binary_op_scalar_with_overlapping_tensors(self, device, dtype, op):
foreach_op, ref = op.method_variant, op.ref
tensors = [torch.ones(1, 1, device=device, dtype=dtype).expand(2, 1, 3)]
if ref == torch.sub and dtype == torch.bool:
with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
[ref(t, 1) for t in tensors]
with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
foreach_op(tensors, 1)
return
expected = [ref(t, 1) for t in tensors]
res = foreach_op(tensors, 1)
self.assertEqual(res, expected)
@ops(
filter(lambda op: op.supports_out, foreach_binary_op_db),
allowed_dtypes=[torch.float],
)
def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op):
foreach_op = op.method_variant
tensors = [
torch.tensor([1.1], dtype=torch.float, device=device),
torch.tensor([1], dtype=torch.long, device=device),
]
runtime_error = None
try:
foreach_op(tensors, 1)
except RuntimeError as e:
runtime_error = e
self.assertIsNone(runtime_error)
@skipIfTorchDynamo("Different error msgs, TODO")
@ops(
filter(lambda op: op.supports_out, foreach_binary_op_db),
dtypes=OpDTypes.supported,
)
def test_binary_op_list_error_cases(self, device, dtype, op):
foreach_op, foreach_op_, ref, ref_ = (
op.method_variant,
op.inplace_variant,
op.ref,
op.ref_inplace,
)
tensors1 = []
tensors2 = []
ops_to_test = [foreach_op, foreach_op_]
# Empty lists
for fop in ops_to_test:
with self.assertRaisesRegex(
RuntimeError, "Tensor list must have at least one tensor."
):
fop(tensors1, tensors2)
# One empty list
tensors1.append(torch.tensor([1], device=device, dtype=dtype))
for fop in ops_to_test:
with self.assertRaisesRegex(
RuntimeError,
"Tensor list must have same number of elements as scalar list.",
):
fop(tensors1, tensors2)
# Lists have different amount of tensors
tensors2.append(torch.tensor([1], device=device))
tensors2.append(torch.tensor([1], device=device))
for fop in ops_to_test:
with self.assertRaisesRegex(
RuntimeError,
"Tensor lists must have the same number of tensors, got 1 and 2",
):
fop(tensors1, tensors2)
with self.assertRaisesRegex(
RuntimeError,
"Tensor lists must have the same number of tensors, got 2 and 1",
):
fop(tensors2, tensors1)
# Corresponding tensors with different sizes that aren't compatible with broadcast
# If sizes are different then foreach chooses slow path, thus error messages are expected
# to be the same as torch regular function.
tensors1 = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
tensors2 = [torch.ones(11, 11, device=device, dtype=dtype) for _ in range(10)]
if dtype == torch.bool and foreach_op == torch._foreach_sub:
for fop in ops_to_test:
with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
fop(tensors1, tensors2)
return
with self.assertRaisesRegex(
RuntimeError,
r"The size of tensor a \(10\) must match the size of tensor b \(11\) at non-singleton dimension 1",
):
foreach_op(tensors1, tensors2)
with self.assertRaisesRegex(
RuntimeError,
r"The size of tensor a \(10\) must match the size of tensor b \(11\) at non-singleton dimension 1",
):
foreach_op_(tensors1, tensors2)
# different devices
if self.device_type == "cuda" and torch.cuda.device_count() > 1:
tensor1 = torch.zeros(10, 10, device="cuda:0", dtype=dtype)
tensor2 = torch.ones(10, 10, device="cuda:1", dtype=dtype)
with self.assertRaisesRegex(
RuntimeError, "Expected all tensors to be on the same device"
):
foreach_op([tensor1], [tensor2])
if (
dtype in integral_types_and(torch.bool)
and foreach_op == torch._foreach_div
):
with self.assertRaisesRegex(RuntimeError, "result type"):
foreach_op_([tensor1], [tensor2])
else:
with self.assertRaisesRegex(
RuntimeError, "Expected all tensors to be on the same device"
):
foreach_op_([tensor1], [tensor2])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
@ops(
filter(lambda op: op.supports_out, foreach_binary_op_db),
dtypes=OpDTypes.supported,
)
def test_binary_op_list_slow_path(self, device, dtype, op):
foreach_op, native_op, foreach_op_, native_op_ = self._get_funcs(op)
# 0-strides
tensor1 = make_tensor((10, 10), dtype=dtype, device=device)
tensor2 = make_tensor((1,), device=device, dtype=dtype).expand_as(tensor1)
inputs = ([tensor1], [tensor2])
self._binary_test(
dtype,
foreach_op,
native_op,
inputs,
is_fastpath=False,
is_inplace=False,
alpha=None,
scalar_self_arg=False,
)
self._binary_test(
dtype,
foreach_op_,
native_op_,
inputs,
is_fastpath=False,
is_inplace=True,
alpha=None,
scalar_self_arg=False,
)
# different strides
tensor1 = torch.zeros(10, 10, device=device, dtype=dtype)
tensor2 = torch.ones(10, 10, device=device, dtype=dtype)
inputs = ([tensor1], [tensor2.t()])
self._binary_test(
dtype,
foreach_op,
native_op,
inputs,
is_fastpath=False,
is_inplace=False,
alpha=None,
scalar_self_arg=False,
)
self._binary_test(
dtype,
foreach_op_,
native_op_,
inputs,
is_fastpath=False,
is_inplace=True,
alpha=None,
scalar_self_arg=False,
)
# non contiguous
tensor1 = make_tensor(
(5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True
)
tensor2 = make_tensor(
(5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True
)
self.assertFalse(tensor1.is_contiguous())
self.assertFalse(tensor2.is_contiguous())
inputs = ([tensor1], [tensor2])
self._binary_test(
dtype,
foreach_op,
native_op,
inputs,
is_fastpath=False,
is_inplace=False,
alpha=None,
scalar_self_arg=False,
)
self._binary_test(
dtype,
foreach_op_,
native_op_,
inputs,
is_fastpath=False,
is_inplace=True,
alpha=None,
scalar_self_arg=False,
)
# sliced tensor
tensor1 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype)
tensor2 = make_tensor((5, 2, 1, 3 * 7), device=device, dtype=dtype)[
:, :, :, ::7
]
inputs = ([tensor1], [tensor2])
self._binary_test(
dtype,
foreach_op,
native_op,
inputs,
is_fastpath=False,
is_inplace=False,
alpha=None,
scalar_self_arg=False,
)
self._binary_test(
dtype,
foreach_op_,
native_op_,
inputs,
is_fastpath=False,
is_inplace=True,
alpha=None,
scalar_self_arg=False,
)
@ops(
filter(lambda op: op.supports_out, foreach_binary_op_db),
dtypes=floating_types_and(torch.half, torch.bfloat16),
)
def test_binary_op_float_inf_nan(self, device, dtype, op):
inputs = (
[
torch.tensor([float("inf")], device=device, dtype=dtype),
torch.tensor([-float("inf")], device=device, dtype=dtype),
torch.tensor([float("nan")], device=device, dtype=dtype),
torch.tensor([float("nan")], device=device, dtype=dtype),
],
[
torch.tensor([-float("inf")], device=device, dtype=dtype),
torch.tensor([float("inf")], device=device, dtype=dtype),
torch.tensor([float("inf")], device=device, dtype=dtype),
torch.tensor([float("nan")], device=device, dtype=dtype),
],
)
op, ref, inplace_op, inplace_ref = self._get_funcs(op)
self._binary_test(
dtype, op, ref, inputs, True, False, alpha=None, scalar_self_arg=False
)
self._binary_test(
dtype,
inplace_op,
inplace_ref,
inputs,
True,
True,
alpha=None,
scalar_self_arg=False,
)
# note: Below three tests (postfixed with `_tensors_on_different_devices`)
# checks whether foreach works with lists of tensors on different devices
# but tensors of the same index are on the same device, e.g., ['cuda', 'cpu].
@onlyCUDA
@ops(foreach_unary_op_db)
def test_unary_op_tensors_on_different_devices(self, device, dtype, op):
method, ref, inplace_method, ref_inplace = self._get_funcs(op)
# tensors: ['cuda', 'cpu]
tensors = next(
iter(
op.sample_inputs(
device,
dtype,
num_input_tensors=[2],
allow_higher_dtype_scalars=True,
)
)
).input
tensors[1] = tensors[1].to("cpu")
if not op.supports_out:
try:
actual = method((tensors,), False, False, zero_size=False)
except RuntimeError as e:
with self.assertRaisesRegex(type(e), str(e).splitlines()[0]):
ref((tensors,))
else:
expected = ref((tensors,))
self.assertEqual(expected, actual)
try:
inplace_method((tensors,), False, False, zero_size=False)
except RuntimeError as e:
with self.assertRaisesRegex(type(e), str(e).splitlines()[0]):
ref_inplace((tensors,))
else:
if not op.supports_out:
self.assertEqual(expected, tensors)
else:
self.assertEqual([torch.zeros_like(t) for t in tensors], tensors)
@onlyCUDA
@ops(filter(lambda op: op.supports_out, foreach_binary_op_db))
def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
_cuda_tensors = next(
iter(
op.sample_inputs(
device,
dtype,
num_input_tensors=[2],
same_size=True,
allow_higher_dtype_scalars=True,
)
)
).input
_cpu_tensors = next(
iter(
op.sample_inputs(
"cpu",
dtype,
num_input_tensors=[2],
same_size=True,
allow_higher_dtype_scalars=True,
)
)
).input
tensors1, tensors2 = list(zip(_cuda_tensors, _cpu_tensors))
foreach_op, foreach_op_ = op.method_variant, op.inplace_variant
native_op, native_op_ = op.ref, op.ref_inplace
try:
actual = foreach_op(tensors1, tensors2)
except RuntimeError as e:
with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
[native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
else:
expected = [native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
self.assertEqual(expected, actual)
try:
foreach_op_(tensors1, tensors2)
except RuntimeError as e:
with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
[native_op_(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
else:
self.assertEqual(actual, tensors1)
@onlyCUDA
@ops(foreach_pointwise_op_db, allowed_dtypes=floating_types())
def test_pointwise_op_tensors_on_different_devices(self, device, dtype, op):
# tensors1: ['cuda', 'cpu]
# tensors2: ['cuda', 'cpu]
# tensors3: ['cuda', 'cpu]
# first tensorlist is zero-size when float32
_cuda_tensors = list(
op.sample_inputs(
device,
dtype,
num_input_tensors=[3],
same_size=True,
allow_higher_dtype_scalars=True,
)
)[int(dtype == torch.float32)].input
_cpu_tensors = next(
iter(
op.sample_inputs(
"cpu",
dtype,
num_input_tensors=[3],
same_size=True,
allow_higher_dtype_scalars=True,
)
)
).input
tensors1, tensors2, tensors3 = list(zip(_cuda_tensors, _cpu_tensors))
foreach_op, foreach_op_, native_op = (
op.method_variant,
op.inplace_variant,
op.ref,
)
actual = foreach_op(tensors1, tensors2, tensors3)
expected = [native_op(*_cuda_tensors), native_op(*_cpu_tensors)]
self.assertEqual(expected, actual)
# note(mkozuki): Limiting dtypes to FP32&FP64, we can safely run inplace ops.
foreach_op_(tensors1, tensors2, tensors3)
self.assertEqual(expected, tensors1)
# note: BFloat16 has the same number of exponent bits as FP32
# so if squared L2 norm overflows in BF16, then it also overflows in FP32.
@onlyCUDA
@ops(
[o for o in foreach_reduce_op_db if "norm" in o.name],
allowed_dtypes=(torch.half, torch.bfloat16),
)
def test_foreach_l2_large_value_input(self, device, dtype, op):
ord, N = 2, 10
max_value = torch.finfo(dtype).max
scaler = torch.tensor([max_value]).sqrt().to(device=device, dtype=dtype)
inputs = (
[
t * scaler
for t in next(
iter(
op.sample_inputs(
device,
dtype,
requries_grad=True,
num_input_tensors=[N],
low=1,
)
)
).input
][:-1],
)
# make sure that the min. of squared L2 norm value per tensor is greater than the max value of `dtype`.
self.assertTrue(scaler * scaler * N > max_value)
fn, ref_fn, *_ = self._get_funcs(op)
actual = fn(
inputs, is_cuda=True, expect_fastpath=True, ord=ord, zero_size=False
)
expect = ref_fn(inputs, ord=ord)
if dtype == torch.float16:
# making sure the reference L2 norm values are in the range of FP16.
self.assertFalse(any(torch.isinf(e) for e in expect))
else:
self.assertTrue(
all(
inputs[0][i].numel() == 0 or torch.isinf(e)
for i, e in enumerate(expect)
)
)
self.assertEqual(expect, actual, equal_nan=False)
@onlyCUDA
@ops(foreach_reduce_op_db, allowed_dtypes=floating_types())
@parametrize("use_cuda_graph", (False, True))
@parametrize("w_empty", (False, True))
def test_big_num_tensors(self, device, dtype, op, use_cuda_graph, w_empty):
# foreach_max cannot handle empty tensors as max requires an identity
intersperse_empty_tensors = w_empty and op.name != "_foreach_max"
N = 600
indices_with_empty_tensors = (
set()
if not intersperse_empty_tensors
else {200, 300, 301, 400, 401, 402, 404, 598}
)
tensorlist = [
make_tensor((2, 3), dtype=dtype, device=device, noncontiguous=False)
if i not in indices_with_empty_tensors
else torch.empty(0, dtype=dtype, device=device)
for i in range(N)
]
fn, ref_fn, *_ = self._get_funcs(op)
import math
if op.name == "_foreach_norm":
ords = [1, 2]
if not intersperse_empty_tensors:
# inf norm over an empty tensor is not defined by vector norm as it expects an identity
ords.append(math.inf)
else:
ords = [None]
for ord in ords:
kwargs = {"ord": ord} if ord else {}
if not use_cuda_graph:
actual = fn(
inputs=[tensorlist],
is_cuda=True,
expect_fastpath=True,
zero_size=False,
**kwargs,
)
else:
# When using CUDA graphs and the tensor metadata doesn't fit in
# the static kernel argument space, multi_tensor_apply creates
# the launch arguments once, uses cudaUserObject_t to tie its
# lifetime to the graph, and reuses it throughout replays. This
# test verifies multi_tensor_apply's behavior in the scenario.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
actual = fn.func(tensorlist, **kwargs)
g.replay()
expect = ref_fn(inputs=[tensorlist], **kwargs)
self.assertEqual(expect, actual, equal_nan=True)
@onlyCUDA
@ops(foreach_reduce_op_db)
@parametrize("w_empty", (False, True))
def test_foreach_reduce_large_input(self, device, dtype, op, w_empty):
# test inputs larger than kChunkSize (65536) * max_num_blocks (320)
N = 65536 * 320 * 2
disable_fastpath = False
kwargs = {}
if op.name == "_foreach_norm":
kwargs["ord"] = 2
disable_fastpath = dtype not in floating_types_and(
torch.half, torch.bfloat16
)
tensorlist = [
make_tensor((N,), dtype=dtype, device=device, noncontiguous=False)
]
# foreach_max cannot handle empty tensors as max over empty is undefined
if w_empty and op.name != "_foreach_max":
tensorlist += [
torch.empty(0, dtype=dtype, device=device),
make_tensor((N,), dtype=dtype, device=device, noncontiguous=False),
]
inputs = (tensorlist,)
wrapped_op, ref, _, _ = self._get_funcs(op)
self.assertEqual(
ref(inputs, **kwargs),
wrapped_op(
inputs, self.is_cuda, not disable_fastpath, zero_size=False, **kwargs
),
)
@onlyCUDA
@ops(
foreach_unary_op_db
+ foreach_binary_op_db
+ foreach_pointwise_op_db
+ foreach_other_op_db,
dtypes=(torch.float,),
)
def test_inplace_foreach_leaf_check_and_grad_fn(self, device, dtype, op):
inplace_op = op.inplace_variant
if inplace_op is None:
self.skipTest("no in-place op available")
sample = next(
iter(
op.sample_inputs(
dtype=dtype, device=device, num_input_tensors=[2], same_size=True
)
)
)
sample.input[0].requires_grad_(True)
with self.assertRaisesRegex(RuntimeError, "a leaf Variable that requires grad"):
inplace_op(sample.input, *sample.args)
sample.input[1].requires_grad_(True)
with self.assertRaisesRegex(RuntimeError, "a leaf Variable that requires grad"):
inplace_op(sample.input, *sample.args)
_tensors = [
t.detach().clone().requires_grad_(i == 0)
for i, t in enumerate(sample.input)
]
tensors = [t.clone() for t in _tensors]
inplace_op(tensors, *sample.args)
self.assertIsNotNone(tensors[0].grad_fn)
self.assertIsNone(tensors[1].grad_fn)
@onlyCUDA
@ops(
filter(
lambda op: op.supports_out,
foreach_unary_op_db
+ foreach_binary_op_db
+ foreach_pointwise_op_db
+ foreach_other_op_db,
),
dtypes=(torch.float,),
)
def test_outplace_with_invalid_grads(self, device, dtype, op):
func, *_ = self._get_funcs(op)
sample = next(
iter(
op.sample_inputs(
dtype=dtype,
device=device,
requires_grad=True,
num_input_tensors=[2],
same_size=True,
)
)
)
self.assertTrue(all(t.requires_grad for t in sample.input))
(out1, out2) = func(
[sample.input, *sample.args],
is_cuda=False,
expect_fastpath=False,
**sample.kwargs,
)
out1.backward(torch.ones_like(out1))
self.assertIsNotNone(sample.input[0].grad)
self.assertIsNone(sample.input[1].grad)
@ops(
filter(
lambda op: op.backward_requires_result,
foreach_unary_op_db
+ foreach_binary_op_db
+ foreach_pointwise_op_db
+ foreach_other_op_db,
),
dtypes=(torch.float32,),
)
def test_lifetime_of_grad_fn_when_result_is_saved(self, device, dtype, op):
def get_ref(func, sample):
class Foo:
pass
out = func(
(sample.input, *sample.args),
is_cuda=False,
expect_fastpath=False,
**sample.kwargs,
)
foo = Foo()
meta_dict = out[0].grad_fn.metadata
meta_dict[0] = foo
ref = weakref.ref(foo)
return out, ref
def _test(func, sample):
out, ref = get_ref(func, sample)
self.assertIsNotNone(ref())
del out
self.assertIsNone(ref())
func = self._get_funcs(op)[0]
for sample in op.sample_inputs(
device, dtype, requires_grad=True, num_input_tensors=[1]
):
for key in ("is_fastpath", "disable_fastpath"):
if key in sample.kwargs:
del sample.kwargs[key]
# note: `_foreach_pow.Scalar` and `_foreach_pow.ScalarList` don't depend on `result`
# see: https://github.com/pytorch/pytorch/blob/5403c777/tools/autograd/derivatives.yaml#L3048-L3049
if op.name == "_foreach_pow":
if (
isinstance(sample.args[0], list)
and isinstance(sample.args[0][0], Number)
) or (
isinstance(sample.args[0], Number)
and not isinstance(sample.args[0], float)
):
continue
if isinstance(sample.args[0], float):
new_args = (sample.input,)
sample.input = sample.args[0]
sample.args = new_args
_test(func, sample)
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_tensors_grouping(self):
num_tensors_per_list = 10
num_devices = torch.cuda.device_count()
dtypes = (torch.float16, torch.float32, torch.float64)
list1 = [
torch.tensor(
i,
device=torch.device("cuda", random.randint(0, num_devices - 1)),
dtype=dtypes[random.randint(0, 2)],
)
for i in range(num_tensors_per_list)
]
list2 = [None for _ in list1]
list3 = [torch.rand_like(t) for t in list1]
nested_tensorlists = [list1, list2, list3]
grouped_tensors = torch.utils._foreach_utils._group_tensors_by_device_and_dtype(
nested_tensorlists, with_indices=True
)
num_tensors_seen = 0
for (device, dtype), ([l1, l2, l3], indices) in grouped_tensors.items():
for t in itertools.chain(l1, l3):
self.assertEqual(t.device, device)
self.assertEqual(t.dtype, dtype)
num_tensors_seen += 1
self.assertEqual(len(l1), len(l2))
self.assertTrue(all(p is None for p in l2))
for i, index in enumerate(indices):
self.assertEqual(l1[i], list1[index])
self.assertEqual(l2[i], list2[index])
self.assertEqual(l3[i], list3[index])
self.assertEqual(num_tensors_seen, 2 * num_tensors_per_list)
@onlyCUDA
def test_0dim_tensor_overload_cpu_ok(self):
tensors = [torch.ones((), device="cuda", dtype=torch.float32) for _ in range(2)]
scalar_cpu_tensor = torch.tensor(4.0, device="cpu")
# For mul and div, the scalar is allowed to be on CPU too
actual = torch._foreach_mul(tensors, scalar_cpu_tensor)
self.assertEqual(actual, [t.mul(scalar_cpu_tensor) for t in tensors])
actual = torch._foreach_div(tensors, scalar_cpu_tensor)
self.assertEqual(actual, [t.div(scalar_cpu_tensor) for t in tensors])
@onlyCUDA
def test_div_reciprocal(self):
expect_m, expect_e = torch.frexp(
torch.div(torch.tensor(0.1, device="cuda"), 10.0)
)
actual_m, actual_e = torch.frexp(
torch._foreach_div([torch.tensor(0.1, device="cuda")], [10.0])[0]
)
self.assertEqual(expect_m, actual_m)
self.assertEqual(expect_e, actual_e)
@onlyCUDA
def test_0dim_tensor_overload_exception(self):
# check exceptions of fast path
tensors = [
make_tensor((2, 2), dtype=torch.float, device="cuda") for _ in range(2)
]
with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be on"):
torch._foreach_add(tensors, torch.tensor(1.0, device="cpu"), alpha=1.0)
tensors = [
make_tensor((2, 2), dtype=torch.float, device=d) for d in ("cpu", "cuda")
]
with self.assertRaisesRegex(
RuntimeError, "scalar tensor expected to be 0 dim but"
):
torch._foreach_mul(tensors, torch.tensor([1.0, 1.0], device="cuda"))
with self.assertRaisesRegex(
RuntimeError, "scalar tensor expected to be 0 dim but"
):
torch._foreach_add(tensors, torch.tensor([1.0, 1.0], device="cuda"))
@onlyCUDA
@ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
def test_foreach_copy_with_multi_device_inputs(self, device, dtype, op):
foreach_copy_ = op.inplace_variant
copy_ = op.ref_inplace
for non_blocking in (False, True):
for sample in op.sample_inputs(
device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True
):
with torch.no_grad():
ref_input = [t.detach().clone() for t in sample.input]
foreach_copy_(sample.input, sample.args[0], non_blocking)
for t, s in zip(ref_input, sample.args[0]):
copy_(t, s, non_blocking)
self.assertEqual(sample.input, ref_input)
if torch.cuda.device_count() > 1:
device = torch.device("cuda", 1)
rhs_tensors = [t.to(device) for t in sample.args[0]]
foreach_copy_(sample.input, rhs_tensors, non_blocking)
for t, s in zip(ref_input, rhs_tensors):
copy_(t, s, non_blocking)
self.assertEqual(ref_input, sample.input)
@onlyCUDA
@ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
def test_foreach_copy_with_multi_dtypes(self, device, dtype, op):
# check (a) multi_tensor_apply is called and (b) numerical parity with for-loop and Tensor.copy_
foreach_copy_ = ForeachFuncWrapper(op.inplace_variant)
for sample in op.sample_inputs(
device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True
):
for src_dtype in floating_types_and(torch.half, torch.bfloat16):
if src_dtype == dtype:
continue
self_tensors = [t.clone() for t in sample.input]
src_tensors = [t.to(src_dtype) for t in self_tensors]
out = foreach_copy_(
(self_tensors, src_tensors), is_cuda=True, expect_fastpath=True
)
ref_out = [
torch.empty_like(t).copy_(s)
for t, s in zip(self_tensors, src_tensors)
]
for t, ref_t in zip(out, ref_out):
self.assertTrue(torch.equal(t, ref_t))
@onlyCUDA
@largeTensorTest("40GB", device="cuda")
def test_foreach_copy_with_multi_dtypes_large_input(self):
# see https://github.com/pytorch/pytorch/issues/156261
self_tensor = torch.empty(2**31 + 1, device="cuda", dtype=torch.float32)
src_tensor = torch.ones(2**31 + 1, device="cuda", dtype=torch.bfloat16)
torch._foreach_copy_([self_tensor], [src_tensor])
ref_out = torch.empty_like(self_tensor).copy_(src_tensor)
self.assertEqual(self_tensor, ref_out)
@requires_cuda
@ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
def test_foreach_copy_with_different_device_inputs(self, device, dtype, op):
if dtype in (torch.complex128, torch.complex64):
self.skipTest("Complex dtype not supported")
# check foreach_copy when self and src tensorList have different device
foreach_copy = op.method_variant
copy_ = op.ref_inplace
def fn(self_tensor, src_tensor, non_blocking):
return foreach_copy(self_tensor, src_tensor, non_blocking)
fn = torch.compile(fn)
for non_blocking in (False,):
for sample in op.sample_inputs(
device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True
):
with torch.no_grad():
ref_input = [t.detach().clone() for t in sample.input]
ref_input_cpu = [t.detach().clone().to("cpu") for t in sample.input]
rhs_tensors = [t.detach().clone().to("cpu") for t in sample.args[0]]
self_tensors = [t.detach().clone().to("cpu") for t in sample.input]
output1 = fn(sample.input, rhs_tensors, non_blocking)
for t, s in zip(ref_input, rhs_tensors):
copy_(t, s, non_blocking)
self.assertEqual(output1, ref_input)
output2 = fn(self_tensors, sample.args[0], non_blocking)
for t, s in zip(ref_input_cpu, sample.args[0]):
copy_(t, s, non_blocking)
self.assertEqual(output2, ref_input_cpu)
# Test reverse-mode & forward-mode AD if supported.
@onlyCUDA
@ops(
foreach_unary_op_db
+ foreach_binary_op_db
+ foreach_pointwise_op_db
+ foreach_reduce_op_db
+ foreach_other_op_db,
dtypes=OpDTypes.supported,
allowed_dtypes=(torch.float64, torch.complex128),
)
@parametrize(
"inplace", (False, True), name_fn=lambda x: "inplace" if x else "outplace"
)
def test_autodiff(self, device, dtype, op, inplace):
if (not inplace) and not op.supports_out:
self.skipTest("out-of-place not implemented")
if inplace and op.has_no_in_place:
self.skipTest("in-place not implemented")
if not (
op.supports_autograd
or op.supports_inplace_autograd
or op.supports_forward_ad
):
self.skipTest("neither reverse mode nor forward mode supported")
# note(crcrpar): without this, some unary functions fail, unlike inplace and/or complex.
if (
(not inplace)
and dtype == torch.float64
and op.name
in (
"_foreach_acos",
"_foreach_asin",
"_foreach_log10",
"_foreach_log1p",
"_foreach_log2",
"_foreach_log",
"_foreach_pow",
"_foreach_sqrt",
"_foreach_rsqrt",
)
):
value_range = {"low": 0.5, "high": 1.0}
else:
value_range = {}
for sample in op.sample_inputs(
device,
dtype,
requires_grad=True,
num_input_tensors=[5],
allow_higher_dtype_scalars=True,
**value_range,
):
# Skip `_foreach_pow.ScalarAndTensor(Scalar, Tensor[])`
if op.name == "_foreach_pow" and isinstance(sample.input, Number):
continue
func = None
if inplace:
# Call `clone` to avoid inplace modifications likewise
# `torch.testing._internal.common_utils.TestGradients._get_safe_inplace`
def inplace_func(*tensorlist):
kwargs = (
{"alpha": sample.kwargs["alpha"]}
if "alpha" in sample.kwargs
else {}
)
op.inplace_variant(
tuple(t.clone() for t in tensorlist), *sample.args, **kwargs
)
return tensorlist
func = inplace_func
else:
def outplace_func(*tensorlist):
kwargs = (
{"alpha": sample.kwargs["alpha"]}
if "alpha" in sample.kwargs
else {}
)
return op.method_variant(tensorlist, *sample.args, **kwargs)
func = outplace_func
working_sample, err_msg_pattern = check_autodiff_sample(
op, sample, dtype, inplace
)
def call_gradcheck():
gradcheck(
func,
sample.input,
raise_exception=True,
check_forward_ad=op.supports_forward_ad,
check_batched_forward_grad=False,
check_backward_ad=op.supports_autograd,
check_batched_grad=False,
)
if not working_sample:
if not err_msg_pattern:
# lhs of float64 and rhs of complex.
continue
with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
call_gradcheck()
continue
call_gradcheck()
# Test per-tensor `grad_fn` behavior.
if inplace and op.supports_inplace_autograd:
# per-tensor `grad_fn` check.
hook_buffer = []
def get_grad_fn_hook(i):
def hook(grad_inputs, grad_outputs) -> None:
hook_buffer.append(i)
return hook
_inputs = [t.detach().clone().requires_grad_() for t in sample.input]
inputs = [t.clone() for t in _inputs]
kwargs = (
{"alpha": sample.kwargs["alpha"]}
if "alpha" in sample.kwargs
else {}
)
op.inplace_variant(inputs, *sample.args, **kwargs)
self.assertEqual(len({t.grad_fn for t in inputs}), len(inputs))
for i, t in enumerate(inputs):
t.grad_fn.register_hook(get_grad_fn_hook(i))
torch.autograd.grad(
inputs[0],
inputs=(_inputs[0],),
grad_outputs=(torch.rand_like(inputs[0]),),
retain_graph=True,
)
self.assertEqual(hook_buffer, [0])
hook_buffer.clear()
# tensors have different shapes.
sum_of_cloned_tensors = torch.cat([t.view(-1) for t in inputs]).sum()
grad_output = torch.rand_like(sum_of_cloned_tensors)
torch.autograd.grad(
sum_of_cloned_tensors,
inputs=tuple(_inputs),
grad_outputs=(grad_output,),
retain_graph=False,
)
self.assertEqual(hook_buffer, list(reversed(range(len(inputs)))))
# TODO(crcrpar): Hide this inside torch/testing/_internal.
# would end up adding another layer to `foreach_inputs_sample_func.__call__`
# so that we can use this function as something like the first argument of `filter` function.
# Even after moving this function to testing, I personally think it'd be better to check the error message.
def check_autodiff_sample(op, sample, dtype, is_inplace):
if op.name == "_foreach_abs" and is_inplace and dtype == torch.complex128:
return False, "In-place abs is not supported for complex tensors."
if op.name == "_foreach_sub" and (
(
isinstance(sample.args[-1], list)
and any(isinstance(a, bool) for a in sample.args[-1])
)
or isinstance(sample.args[-1], bool)
):
return False, _BOOL_SUB_ERR_MSG
if op.name == "_foreach_norm" and (not is_inplace):
return (
False,
"Trying to set a forward gradient that has a different size than that of the original Tensor, "
"this is not supported. Tensor is of size [] while the given forward gradient is of size [1",
)
rhs_arg_has_complex_number = sample.args and (
(
isinstance(sample.args[-1], list)
and any(isinstance(a, complex) for a in sample.args[-1])
)
or (isinstance(sample.args[-1], complex))
)
if rhs_arg_has_complex_number and dtype == torch.float64:
if op.name == "_foreach_lerp":
return False, "value cannot be converted to type double without overflow"
if op.name in (
"_foreach_clamp_max",
"_foreach_clamp_min",
"_foreach_maximum",
"_foreach_minimum",
):
return False, "clamp is not supported for complex types"
if op.name == "_foreach_lerp" and is_inplace:
return False, "value cannot be converted to type double without overflow"
if not is_inplace:
return False, ""
elif op.name in (
"_foreach_add",
"_foreach_sub",
"_foreach_mul",
"_foreach_div",
"_foreach_pow",
):
return (
False,
"result type ComplexDouble can't be cast to the desired output type Double",
)
return True, ""
instantiate_device_type_tests(TestForeach, globals())
if __name__ == "__main__":
run_tests()