mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-27 00:54:52 +08:00 
			
		
		
		
	https://github.com/pytorch/pytorch/issues/145570 breaking https://github.com/pytorch/pytorch/pull/140793/ into eager and inductor benchmarks to unblock Pull Request resolved: https://github.com/pytorch/pytorch/pull/148602 Approved by: https://github.com/atalman, https://github.com/malfet Co-authored-by: atalman <atalman@fb.com>
		
			
				
	
	
		
			1575 lines
		
	
	
		
			58 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			1575 lines
		
	
	
		
			58 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Owner(s): ["module: mta"]
 | |
| # ruff: noqa: F841
 | |
| import itertools
 | |
| import os
 | |
| import random
 | |
| import re
 | |
| import unittest
 | |
| import weakref
 | |
| from contextlib import nullcontext
 | |
| from numbers import Number
 | |
| 
 | |
| import torch
 | |
| from torch.testing import make_tensor
 | |
| from torch.testing._comparison import default_tolerances
 | |
| from torch.testing._internal.common_cuda import _get_torch_cuda_version, TEST_MULTIGPU
 | |
| from torch.testing._internal.common_device_type import (
 | |
|     dtypes,
 | |
|     instantiate_device_type_tests,
 | |
|     onlyCUDA,
 | |
|     OpDTypes,
 | |
|     ops,
 | |
| )
 | |
| from torch.testing._internal.common_dtype import (
 | |
|     all_types_and_complex_and,
 | |
|     floating_types,
 | |
|     floating_types_and,
 | |
|     integral_types_and,
 | |
| )
 | |
| from torch.testing._internal.common_methods_invocations import (
 | |
|     foreach_binary_op_db,
 | |
|     foreach_other_op_db,
 | |
|     foreach_pointwise_op_db,
 | |
|     foreach_reduce_op_db,
 | |
|     foreach_unary_op_db,
 | |
| )
 | |
| from torch.testing._internal.common_utils import (
 | |
|     gradcheck,
 | |
|     parametrize,
 | |
|     run_tests,
 | |
|     skipIfRocmVersionLessThan,
 | |
|     skipIfTorchDynamo,
 | |
|     TEST_WITH_ROCM,
 | |
|     TestCase,
 | |
| )
 | |
| 
 | |
| 
 | |
| _BOOL_SUB_ERR_MSG = "Subtraction, the `-` operator"
 | |
| 
 | |
| 
 | |
| class RegularFuncWrapper:
 | |
|     def __init__(self, func):
 | |
|         self.func = func
 | |
| 
 | |
|     def __call__(self, inputs, scalars=None, **kwargs):
 | |
|         if scalars is not None:
 | |
|             assert len(inputs) == 3
 | |
|             # We need to distribute each scalar to the regular func and it needs
 | |
|             # special consideration as it is a keyword only argument to the
 | |
|             # regular func. (Strangely, it is not a keyword only argument to the
 | |
|             # foreach func)
 | |
|             return [
 | |
|                 self.func(*i, value=scalars[idx], **kwargs)
 | |
|                 for idx, i in enumerate(zip(*inputs))
 | |
|             ]
 | |
|         if len(inputs) == 2 and isinstance(inputs[1], (Number, torch.Tensor)):
 | |
|             # binary op with tensorlist and scalar.
 | |
|             inputs[1] = [inputs[1] for _ in range(len(inputs[0]))]
 | |
|         return [self.func(*i, **kwargs) for i in zip(*inputs)]
 | |
| 
 | |
| 
 | |
| class ForeachFuncWrapper:
 | |
|     def __init__(self, func):
 | |
|         self.func = func
 | |
|         # Some foreach functions don't have in-place implementations.
 | |
|         self.is_inplace = False if func is None else func.__name__.endswith("_")
 | |
| 
 | |
|     def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs):
 | |
|         actual = None
 | |
|         zero_size = kwargs.pop("zero_size", False)
 | |
|         if (
 | |
|             is_cuda
 | |
|             and torch.autograd.kineto_available()
 | |
|             and torch.profiler.ProfilerActivity.CUDA
 | |
|             in torch.profiler.supported_activities()
 | |
|         ):
 | |
|             with torch.profiler.profile() as p:
 | |
|                 actual = self.func(*inputs, **kwargs)
 | |
|             keys = tuple([e.key for e in p.key_averages()])
 | |
|             mta_called = any("multi_tensor_apply_kernel" in k for k in keys)
 | |
|             assert mta_called == (expect_fastpath and (not zero_size)), (
 | |
|                 f"{mta_called=}, {expect_fastpath=}, {zero_size=}, {self.func.__name__=}, {keys=}"
 | |
|             )
 | |
|         else:
 | |
|             actual = self.func(*inputs, **kwargs)
 | |
|         if self.is_inplace:
 | |
|             assert id(inputs[0]) == id(actual)
 | |
|         return actual
 | |
| 
 | |
| 
 | |
| class InplaceForeachVersionBumpCheck:
 | |
|     def __init__(
 | |
|         self,
 | |
|         testcase: TestCase,
 | |
|         tensorlist: "List[torch.Tensor]",  # noqa: F821
 | |
|     ) -> None:
 | |
|         self._testcase = testcase
 | |
|         self._tensorlist = tensorlist
 | |
|         self._orig_version_counts = [t._version for t in tensorlist]
 | |
| 
 | |
|     def __enter__(self):
 | |
|         pass
 | |
| 
 | |
|     def __exit__(self, exc_type, exc_value, traceback):
 | |
|         # note(crcrpar): some methods e.g. `_binary_test` could call the given inplace function multiple times
 | |
|         self._testcase.assertGreaterEqual(
 | |
|             [t._version for t in self._tensorlist], self._orig_version_counts
 | |
|         )
 | |
| 
 | |
| 
 | |
| def get_transform_func(num_tensors, dtype, device, is_fastpath):
 | |
|     def transform(t):
 | |
|         if not torch.is_tensor(t):
 | |
|             return t
 | |
|         if torch.is_tensor(t) and t.ndim == 0:
 | |
|             return t
 | |
|         return make_tensor(
 | |
|             (num_tensors, num_tensors),
 | |
|             dtype=dtype,
 | |
|             device=device,
 | |
|             requires_grad=True,
 | |
|             noncontiguous=not is_fastpath,
 | |
|         )
 | |
| 
 | |
|     return transform
 | |
| 
 | |
| 
 | |
| # note(crcrpar): `zero_size` is `False` unless (dtype, device) == (torch.float32, "cuda")
 | |
| # as the pair would go through `multi_tensor_apply_kernel` if inputs are not zero size.
 | |
| @unittest.mock.patch.dict(os.environ, {"KINETO_LOG_LEVEL": "5"})
 | |
| class TestForeach(TestCase):
 | |
|     @property
 | |
|     def is_cuda(self):
 | |
|         return self.device_type == "cuda"
 | |
| 
 | |
|     def _get_funcs(self, op):
 | |
|         return (
 | |
|             ForeachFuncWrapper(op.method_variant),
 | |
|             RegularFuncWrapper(op.ref),
 | |
|             ForeachFuncWrapper(op.inplace_variant),
 | |
|             RegularFuncWrapper(op.ref_inplace),
 | |
|         )
 | |
| 
 | |
|     # note(crcrpar): Make sure 0-size tensors are appropriately ignored by `multi_tensor_apply`
 | |
|     # which is originally reported in https://github.com/pytorch/pytorch/issues/94865.
 | |
|     # rel:
 | |
|     #   - https://github.com/pytorch/pytorch/pull/94655
 | |
|     #   - https://github.com/pytorch/pytorch/issues/100701
 | |
|     #   - https://github.com/pytorch/pytorch/pull/100811
 | |
|     @onlyCUDA
 | |
|     @ops(
 | |
|         foreach_unary_op_db
 | |
|         + foreach_binary_op_db
 | |
|         + foreach_pointwise_op_db
 | |
|         + foreach_reduce_op_db
 | |
|         + foreach_other_op_db,
 | |
|         dtypes=(torch.float32,),
 | |
|     )
 | |
|     def test_all_zero_size_tensors_do_not_launch_kernel(self, device, dtype, op):
 | |
|         wrapped_op, _, inplace_op, _ = self._get_funcs(op)
 | |
| 
 | |
|         for sample in op.sample_zero_size_inputs(device, dtype):
 | |
|             if op.method_variant is not None:
 | |
|                 wrapped_op(
 | |
|                     (sample.input, *sample.args),
 | |
|                     is_cuda=self.is_cuda,
 | |
|                     expect_fastpath=True,
 | |
|                     zero_size=True,
 | |
|                 )
 | |
| 
 | |
|             if op.inplace_variant is not None:
 | |
|                 with InplaceForeachVersionBumpCheck(self, sample.input):
 | |
|                     inplace_op(
 | |
|                         (sample.input, *sample.args),
 | |
|                         is_cuda=self.is_cuda,
 | |
|                         expect_fastpath=True,
 | |
|                         zero_size=True,
 | |
|                     )
 | |
| 
 | |
|     @skipIfRocmVersionLessThan((6, 0))
 | |
|     @ops(
 | |
|         foreach_unary_op_db
 | |
|         + foreach_binary_op_db
 | |
|         + foreach_pointwise_op_db
 | |
|         + foreach_reduce_op_db
 | |
|         + foreach_other_op_db,
 | |
|     )
 | |
|     @parametrize(
 | |
|         "noncontiguous,inplace",
 | |
|         [(False, False), (False, True), (True, False), (True, True)],
 | |
|         name_fn=lambda x, y: "{}_{}".format(
 | |
|             "fastpath" if not x else "slowpath", "inplace" if y else "outplace"
 | |
|         ),
 | |
|     )
 | |
|     def test_parity(self, device, dtype, op, noncontiguous, inplace):
 | |
|         if inplace:
 | |
|             _, _, func, ref = self._get_funcs(op)
 | |
|         else:
 | |
|             func, ref, _, _ = self._get_funcs(op)
 | |
|         for sample in op.sample_inputs(
 | |
|             device, dtype, noncontiguous=noncontiguous, allow_higher_dtype_scalars=True
 | |
|         ):
 | |
|             ref_kwargs = sample.kwargs
 | |
|             # div promotes ints to floats, so we cannot go on the fastpath there
 | |
|             div_slowpath = (
 | |
|                 dtype in integral_types_and(torch.bool) and op.name == "_foreach_div"
 | |
|             )
 | |
|             expect_fastpath = not (
 | |
|                 noncontiguous or sample.disable_fastpath or div_slowpath
 | |
|             )
 | |
|             ref_input, ctxmgr = sample.input, nullcontext()
 | |
|             if inplace:
 | |
|                 with torch.no_grad():
 | |
|                     ref_input = [t.detach().clone() for t in sample.input]
 | |
|                 ctxmgr = InplaceForeachVersionBumpCheck(self, sample.input)
 | |
|             try:
 | |
|                 with ctxmgr:
 | |
|                     actual = func(
 | |
|                         [sample.input, *sample.args],
 | |
|                         self.is_cuda,
 | |
|                         expect_fastpath,
 | |
|                         **sample.kwargs,
 | |
|                     )
 | |
|             except Exception as e:
 | |
|                 with self.assertRaises(type(e)):
 | |
|                     ref([ref_input, *sample.ref_args], **ref_kwargs)
 | |
|             else:
 | |
|                 expected = ref([ref_input, *sample.ref_args], **ref_kwargs)
 | |
|                 self.assertEqual(expected, actual)
 | |
| 
 | |
|     def _binary_test(
 | |
|         self,
 | |
|         dtype,
 | |
|         op,
 | |
|         ref,
 | |
|         inputs,
 | |
|         is_fastpath,
 | |
|         is_inplace,
 | |
|         *,
 | |
|         alpha,
 | |
|         scalar_self_arg: bool,
 | |
|     ):
 | |
|         ref_inputs = (
 | |
|             [[t.detach().clone() for t in inputs[0]], inputs[1]]
 | |
|             if is_inplace
 | |
|             else inputs
 | |
|         )
 | |
|         try:
 | |
|             with (
 | |
|                 InplaceForeachVersionBumpCheck(self, inputs[0])
 | |
|                 if op.is_inplace
 | |
|                 else nullcontext()
 | |
|             ):
 | |
|                 actual = op(inputs, self.is_cuda, is_fastpath)
 | |
|         except RuntimeError as e:
 | |
|             with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
 | |
|                 if not scalar_self_arg:
 | |
|                     ref(ref_inputs)
 | |
|                 else:
 | |
|                     [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
 | |
|         else:
 | |
|             expected = (
 | |
|                 ref(ref_inputs)
 | |
|                 if not scalar_self_arg
 | |
|                 else [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
 | |
|             )
 | |
|             self.assertEqual(actual, expected)
 | |
|         if alpha is not None and not scalar_self_arg:
 | |
|             kwargs = {"alpha": alpha}
 | |
|             ref_inputs = inputs
 | |
|             try:
 | |
|                 op_kwargs = {}
 | |
|                 op_kwargs.update(kwargs)
 | |
|                 with (
 | |
|                     InplaceForeachVersionBumpCheck(self, inputs[0])
 | |
|                     if op.is_inplace
 | |
|                     else nullcontext()
 | |
|                 ):
 | |
|                     actual = op(inputs, self.is_cuda, is_fastpath, **op_kwargs)
 | |
|             except RuntimeError as e:
 | |
|                 with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
 | |
|                     ref(ref_inputs, **kwargs)
 | |
|             else:
 | |
|                 expected = ref(ref_inputs, **kwargs)
 | |
|                 if dtype in (torch.float16, torch.bfloat16) and TEST_WITH_ROCM:
 | |
|                     self.assertEqual(
 | |
|                         expected, actual, atol=1.0e-3, rtol=default_tolerances(dtype)[0]
 | |
|                     )
 | |
|                 else:
 | |
|                     self.assertEqual(expected, actual)
 | |
| 
 | |
|     @ops(filter(lambda op: op.supports_scalar_self_arg, foreach_binary_op_db))
 | |
|     @parametrize("is_fastpath", (True, False))
 | |
|     def test_binary_op_with_scalar_self_support(self, device, dtype, op, is_fastpath):
 | |
|         def clone(arg):
 | |
|             if isinstance(arg, (list, tuple)):
 | |
|                 return [clone(a) for a in arg]
 | |
|             if torch.is_tensor(arg):
 | |
|                 return arg.detach().clone().requires_grad_()
 | |
|             else:
 | |
|                 return arg
 | |
| 
 | |
|         scalar_self_arg_test_complete = False
 | |
|         for i, sample in enumerate(
 | |
|             op.sample_inputs(
 | |
|                 device,
 | |
|                 dtype,
 | |
|                 noncontiguous=not is_fastpath,
 | |
|                 allow_higher_dtype_scalars=True,
 | |
|             )
 | |
|         ):
 | |
|             (rhs_arg,) = sample.args
 | |
|             kwargs = {} or sample.kwargs
 | |
|             alpha = kwargs.pop("alpha", None)
 | |
|             wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
 | |
|             if isinstance(rhs_arg, Number) and not scalar_self_arg_test_complete:
 | |
|                 scalar_self_arg_test_complete = True
 | |
|                 self._binary_test(
 | |
|                     dtype,
 | |
|                     wrapped_op,
 | |
|                     ref,
 | |
|                     [rhs_arg, sample.input],
 | |
|                     is_fastpath,
 | |
|                     False,
 | |
|                     alpha=alpha,
 | |
|                     scalar_self_arg=True,
 | |
|                 )
 | |
|                 if op.supports_autograd and dtype == torch.float32:
 | |
|                     transformed_sample = sample.transform(
 | |
|                         get_transform_func(
 | |
|                             len(sample.input), dtype, device, is_fastpath
 | |
|                         )
 | |
|                     )
 | |
|                     tensors = transformed_sample.input
 | |
|                     (rhs_arg,) = transformed_sample.args
 | |
|                     ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
 | |
|                     sum(
 | |
|                         wrapped_op(
 | |
|                             [rhs_arg, tensors], is_cuda=False, expect_fastpath=False
 | |
|                         )
 | |
|                     ).mean().backward()
 | |
|                     sum(ref.func(ref_rhs_arg, t) for t in ref_tensors).mean().backward()
 | |
|                     self.assertEqual(
 | |
|                         [t.grad for t in tensors], [t.grad for t in ref_tensors]
 | |
|                     )
 | |
| 
 | |
|     @ops(foreach_pointwise_op_db)
 | |
|     @parametrize("is_fastpath", (True, False))
 | |
|     # TODO: Remove skip CUDA 12.6 once resolved: https://github.com/pytorch/pytorch/issues/148681
 | |
|     @unittest.skipIf(_get_torch_cuda_version() >= (12, 6), "Failure on CUDA 12.6")
 | |
|     def test_pointwise_op_with_tensor_of_scalarlist_overload(
 | |
|         self, device, dtype, op, is_fastpath
 | |
|     ):
 | |
|         for sample in op.sample_inputs(
 | |
|             device,
 | |
|             dtype,
 | |
|             noncontiguous=not is_fastpath,
 | |
|             allow_higher_dtype_scalars=True,
 | |
|         ):
 | |
|             assert isinstance(sample.args, tuple)
 | |
|             assert len(sample.args) == 2
 | |
|             inputs = [sample.input, *sample.args]
 | |
|             kwargs = sample.kwargs.copy()
 | |
|             disable_fastpath = sample.disable_fastpath and is_fastpath
 | |
|             wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
 | |
|             scalars = kwargs.pop("scalars", None)
 | |
| 
 | |
|             if is_fastpath and scalars:
 | |
|                 sample = sample.transform(
 | |
|                     lambda t: t.detach().clone() if torch.is_tensor(t) else t
 | |
|                 )
 | |
|                 inputs = [sample.input, *sample.args]
 | |
|                 tensor_values = torch.tensor(scalars)
 | |
|                 # 1D Tensor of scalars
 | |
|                 for is_inplace, op_, ref_ in (
 | |
|                     (False, wrapped_op, ref),
 | |
|                     (True, inplace_op, inplace_ref),
 | |
|                 ):
 | |
|                     self._pointwise_test(
 | |
|                         op_,
 | |
|                         ref_,
 | |
|                         inputs,
 | |
|                         is_fastpath and not disable_fastpath,
 | |
|                         is_inplace,
 | |
|                         scalars=tensor_values,
 | |
|                         **kwargs,
 | |
|                     )
 | |
|                     self._pointwise_test(
 | |
|                         op_,
 | |
|                         ref_,
 | |
|                         inputs,
 | |
|                         is_fastpath and not disable_fastpath,
 | |
|                         is_inplace,
 | |
|                         scalars=tensor_values[0],
 | |
|                         custom_values_err="Expected packed scalar Tensor to be of dimension 1. Got 0 instead.",
 | |
|                         **kwargs,
 | |
|                     )
 | |
|                     if self.is_cuda:
 | |
|                         self._pointwise_test(
 | |
|                             op_,
 | |
|                             ref_,
 | |
|                             inputs,
 | |
|                             is_fastpath and not disable_fastpath,
 | |
|                             is_inplace,
 | |
|                             scalars=tensor_values.cuda(),
 | |
|                             custom_values_err="Expected scalars to be on CPU, got cuda:0 instead.",
 | |
|                             **kwargs,
 | |
|                         )
 | |
|                     self._pointwise_test(
 | |
|                         op_,
 | |
|                         ref_,
 | |
|                         inputs,
 | |
|                         is_fastpath and not disable_fastpath,
 | |
|                         is_inplace,
 | |
|                         scalars=tensor_values[:2],
 | |
|                         custom_values_err=f"Expected length of scalars to match input of length {len(scalars)} but got 2 instead.",
 | |
|                         **kwargs,
 | |
|                     )
 | |
|                     self._pointwise_test(
 | |
|                         op_,
 | |
|                         ref_,
 | |
|                         inputs,
 | |
|                         is_fastpath and not disable_fastpath,
 | |
|                         is_inplace,
 | |
|                         scalars=torch.tensor([[0, 1], [2, 3]])[:, 1],
 | |
|                         custom_values_err="Expected scalars to be contiguous.",
 | |
|                         **kwargs,
 | |
|                     )
 | |
| 
 | |
|             # Tests of implicit broadcasting
 | |
|             N = len(sample.input)
 | |
|             inputs = [
 | |
|                 [
 | |
|                     make_tensor(
 | |
|                         (N, N),
 | |
|                         device=device,
 | |
|                         dtype=dtype,
 | |
|                         noncontiguous=not is_fastpath,
 | |
|                     )
 | |
|                     for _ in range(N)
 | |
|                 ],
 | |
|                 [
 | |
|                     make_tensor(
 | |
|                         (N - i, 1),
 | |
|                         device=device,
 | |
|                         dtype=dtype,
 | |
|                         noncontiguous=not is_fastpath,
 | |
|                     )
 | |
|                     for i in range(N)
 | |
|                 ],
 | |
|                 [
 | |
|                     make_tensor(
 | |
|                         (1, N - i),
 | |
|                         device=device,
 | |
|                         dtype=dtype,
 | |
|                         noncontiguous=not is_fastpath,
 | |
|                     )
 | |
|                     for i in range(N)
 | |
|                 ],
 | |
|             ]
 | |
|             self._pointwise_test(
 | |
|                 wrapped_op,
 | |
|                 ref,
 | |
|                 inputs,
 | |
|                 is_fastpath and disable_fastpath,
 | |
|                 is_inplace=False,
 | |
|                 scalars=scalars,
 | |
|                 **kwargs,
 | |
|             )
 | |
|             self._pointwise_test(
 | |
|                 inplace_op,
 | |
|                 inplace_ref,
 | |
|                 inputs,
 | |
|                 is_fastpath and disable_fastpath,
 | |
|                 is_inplace=True,
 | |
|                 scalars=scalars,
 | |
|                 **kwargs,
 | |
|             )
 | |
| 
 | |
|     def _pointwise_test(
 | |
|         self,
 | |
|         op,
 | |
|         ref,
 | |
|         inputs,
 | |
|         is_fastpath,
 | |
|         is_inplace,
 | |
|         *,
 | |
|         scalars=None,
 | |
|         custom_values_err=None,
 | |
|         **kwargs,
 | |
|     ):
 | |
|         ref_inputs = (
 | |
|             [[t.detach().clone() for t in inputs[0]], inputs[1], inputs[2]]
 | |
|             if is_inplace
 | |
|             else inputs
 | |
|         )
 | |
|         try:
 | |
|             with (
 | |
|                 InplaceForeachVersionBumpCheck(self, inputs[0])
 | |
|                 if is_inplace
 | |
|                 else nullcontext()
 | |
|             ):
 | |
|                 actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
 | |
|         except RuntimeError as e:
 | |
|             with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
 | |
|                 ref(ref_inputs, **kwargs)
 | |
|         else:
 | |
|             expected = ref(ref_inputs, **kwargs)
 | |
|             self.assertEqual(expected, actual)
 | |
|         if scalars is not None:
 | |
|             kwargs = kwargs.copy()
 | |
|             kwargs["scalars"] = scalars
 | |
|             try:
 | |
|                 actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
 | |
|             except RuntimeError as e:
 | |
|                 # Match with error messages from regular non-foreach reference if no
 | |
|                 # custom error message was provided.
 | |
|                 if custom_values_err is None:
 | |
|                     with self.assertRaisesRegex(
 | |
|                         type(e), re.escape(str(e).splitlines()[0])
 | |
|                     ):
 | |
|                         ref(ref_inputs, **kwargs)
 | |
|                 else:
 | |
|                     self.assertEqual(re.escape(str(e)), re.escape(custom_values_err))
 | |
|             else:
 | |
|                 expected = ref(ref_inputs, **kwargs)
 | |
|                 self.assertEqual(expected, actual)
 | |
| 
 | |
|     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
 | |
|     def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
 | |
|         # TODO: enable empty list case
 | |
|         for tensors in [
 | |
|             [torch.randn([0], device=device, dtype=dtype)],
 | |
|             [torch.empty_strided((0, 1), (0, 0), dtype=dtype, device=device)],
 | |
|         ]:
 | |
|             res = torch._foreach_add(tensors, 1)
 | |
|             self.assertEqual(res, tensors)
 | |
| 
 | |
|             torch._foreach_add_(tensors, 1)
 | |
|             self.assertEqual(res, tensors)
 | |
| 
 | |
|             # Regression test for https://github.com/pytorch/pytorch/issues/113156
 | |
|             torch._foreach_mul_(tensors, 1)
 | |
| 
 | |
|     @onlyCUDA
 | |
|     @dtypes(torch.float32)
 | |
|     def test_foreach_check_stride_ignore_dims_of_one(self, device, dtype):
 | |
|         # default tensor stride is (9, 9, 3, 1).
 | |
|         tensor = torch.ones((2, 1, 3, 3), device=device, dtype=dtype)
 | |
|         strided_tensor = torch.ones(
 | |
|             (2, 1, 3, 3), device=device, dtype=dtype
 | |
|         ).as_strided((2, 1, 3, 3), (9, 1, 3, 1))
 | |
|         left_inputs = [tensor, strided_tensor]
 | |
|         right_inputs = [strided_tensor, tensor]
 | |
|         compare_result = tensor + strided_tensor
 | |
|         foreach_add_check_ = ForeachFuncWrapper(torch._foreach_add)
 | |
|         out = foreach_add_check_(
 | |
|             (left_inputs, right_inputs), is_cuda=True, expect_fastpath=True
 | |
|         )
 | |
|         for res in out:
 | |
|             self.assertEqual(res, compare_result)
 | |
| 
 | |
|     @ops(
 | |
|         filter(lambda op: op.supports_out, foreach_binary_op_db),
 | |
|         dtypes=OpDTypes.supported,
 | |
|     )
 | |
|     def test_binary_op_scalar_with_overlapping_tensors(self, device, dtype, op):
 | |
|         foreach_op, ref = op.method_variant, op.ref
 | |
|         tensors = [torch.ones(1, 1, device=device, dtype=dtype).expand(2, 1, 3)]
 | |
| 
 | |
|         if ref == torch.sub and dtype == torch.bool:
 | |
|             with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
 | |
|                 [ref(t, 1) for t in tensors]
 | |
|             with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
 | |
|                 foreach_op(tensors, 1)
 | |
|             return
 | |
| 
 | |
|         expected = [ref(t, 1) for t in tensors]
 | |
|         res = foreach_op(tensors, 1)
 | |
|         self.assertEqual(res, expected)
 | |
| 
 | |
|     @ops(
 | |
|         filter(lambda op: op.supports_out, foreach_binary_op_db),
 | |
|         allowed_dtypes=[torch.float],
 | |
|     )
 | |
|     def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op):
 | |
|         foreach_op = op.method_variant
 | |
|         tensors = [
 | |
|             torch.tensor([1.1], dtype=torch.float, device=device),
 | |
|             torch.tensor([1], dtype=torch.long, device=device),
 | |
|         ]
 | |
|         runtime_error = None
 | |
|         try:
 | |
|             foreach_op(tensors, 1)
 | |
|         except RuntimeError as e:
 | |
|             runtime_error = e
 | |
|         self.assertIsNone(runtime_error)
 | |
| 
 | |
|     @skipIfTorchDynamo("Different error msgs, TODO")
 | |
|     @ops(
 | |
|         filter(lambda op: op.supports_out, foreach_binary_op_db),
 | |
|         dtypes=OpDTypes.supported,
 | |
|     )
 | |
|     def test_binary_op_list_error_cases(self, device, dtype, op):
 | |
|         foreach_op, foreach_op_, ref, ref_ = (
 | |
|             op.method_variant,
 | |
|             op.inplace_variant,
 | |
|             op.ref,
 | |
|             op.ref_inplace,
 | |
|         )
 | |
|         tensors1 = []
 | |
|         tensors2 = []
 | |
|         ops_to_test = [foreach_op, foreach_op_]
 | |
| 
 | |
|         # Empty lists
 | |
|         for fop in ops_to_test:
 | |
|             with self.assertRaisesRegex(
 | |
|                 RuntimeError, "Tensor list must have at least one tensor."
 | |
|             ):
 | |
|                 fop(tensors1, tensors2)
 | |
| 
 | |
|         # One empty list
 | |
|         tensors1.append(torch.tensor([1], device=device, dtype=dtype))
 | |
|         for fop in ops_to_test:
 | |
|             with self.assertRaisesRegex(
 | |
|                 RuntimeError,
 | |
|                 "Tensor list must have same number of elements as scalar list.",
 | |
|             ):
 | |
|                 fop(tensors1, tensors2)
 | |
| 
 | |
|         # Lists have different amount of tensors
 | |
|         tensors2.append(torch.tensor([1], device=device))
 | |
|         tensors2.append(torch.tensor([1], device=device))
 | |
|         for fop in ops_to_test:
 | |
|             with self.assertRaisesRegex(
 | |
|                 RuntimeError,
 | |
|                 "Tensor lists must have the same number of tensors, got 1 and 2",
 | |
|             ):
 | |
|                 fop(tensors1, tensors2)
 | |
|             with self.assertRaisesRegex(
 | |
|                 RuntimeError,
 | |
|                 "Tensor lists must have the same number of tensors, got 2 and 1",
 | |
|             ):
 | |
|                 fop(tensors2, tensors1)
 | |
| 
 | |
|         # Corresponding tensors with different sizes that aren't compatible with broadcast
 | |
|         # If sizes are different then foreach chooses slow path, thus error messages are expected
 | |
|         # to be the same as torch regular function.
 | |
|         tensors1 = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
 | |
|         tensors2 = [torch.ones(11, 11, device=device, dtype=dtype) for _ in range(10)]
 | |
| 
 | |
|         if dtype == torch.bool and foreach_op == torch._foreach_sub:
 | |
|             for fop in ops_to_test:
 | |
|                 with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
 | |
|                     fop(tensors1, tensors2)
 | |
|             return
 | |
|         with self.assertRaisesRegex(
 | |
|             RuntimeError,
 | |
|             r"The size of tensor a \(10\) must match the size of tensor b \(11\) at non-singleton dimension 1",
 | |
|         ):
 | |
|             foreach_op(tensors1, tensors2)
 | |
|         with self.assertRaisesRegex(
 | |
|             RuntimeError,
 | |
|             r"The size of tensor a \(10\) must match the size of tensor b \(11\) at non-singleton dimension 1",
 | |
|         ):
 | |
|             foreach_op_(tensors1, tensors2)
 | |
| 
 | |
|         # different devices
 | |
|         if self.device_type == "cuda" and torch.cuda.device_count() > 1:
 | |
|             tensor1 = torch.zeros(10, 10, device="cuda:0", dtype=dtype)
 | |
|             tensor2 = torch.ones(10, 10, device="cuda:1", dtype=dtype)
 | |
|             with self.assertRaisesRegex(
 | |
|                 RuntimeError, "Expected all tensors to be on the same device"
 | |
|             ):
 | |
|                 foreach_op([tensor1], [tensor2])
 | |
|             if (
 | |
|                 dtype in integral_types_and(torch.bool)
 | |
|                 and foreach_op == torch._foreach_div
 | |
|             ):
 | |
|                 with self.assertRaisesRegex(RuntimeError, "result type"):
 | |
|                     foreach_op_([tensor1], [tensor2])
 | |
|             else:
 | |
|                 with self.assertRaisesRegex(
 | |
|                     RuntimeError, "Expected all tensors to be on the same device"
 | |
|                 ):
 | |
|                     foreach_op_([tensor1], [tensor2])
 | |
| 
 | |
|     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
 | |
|     @ops(
 | |
|         filter(lambda op: op.supports_out, foreach_binary_op_db),
 | |
|         dtypes=OpDTypes.supported,
 | |
|     )
 | |
|     def test_binary_op_list_slow_path(self, device, dtype, op):
 | |
|         foreach_op, native_op, foreach_op_, native_op_ = self._get_funcs(op)
 | |
|         # 0-strides
 | |
|         tensor1 = make_tensor((10, 10), dtype=dtype, device=device)
 | |
|         tensor2 = make_tensor((1,), device=device, dtype=dtype).expand_as(tensor1)
 | |
|         inputs = ([tensor1], [tensor2])
 | |
|         self._binary_test(
 | |
|             dtype,
 | |
|             foreach_op,
 | |
|             native_op,
 | |
|             inputs,
 | |
|             is_fastpath=False,
 | |
|             is_inplace=False,
 | |
|             alpha=None,
 | |
|             scalar_self_arg=False,
 | |
|         )
 | |
|         self._binary_test(
 | |
|             dtype,
 | |
|             foreach_op_,
 | |
|             native_op_,
 | |
|             inputs,
 | |
|             is_fastpath=False,
 | |
|             is_inplace=True,
 | |
|             alpha=None,
 | |
|             scalar_self_arg=False,
 | |
|         )
 | |
| 
 | |
|         # different strides
 | |
|         tensor1 = torch.zeros(10, 10, device=device, dtype=dtype)
 | |
|         tensor2 = torch.ones(10, 10, device=device, dtype=dtype)
 | |
|         inputs = ([tensor1], [tensor2.t()])
 | |
|         self._binary_test(
 | |
|             dtype,
 | |
|             foreach_op,
 | |
|             native_op,
 | |
|             inputs,
 | |
|             is_fastpath=False,
 | |
|             is_inplace=False,
 | |
|             alpha=None,
 | |
|             scalar_self_arg=False,
 | |
|         )
 | |
|         self._binary_test(
 | |
|             dtype,
 | |
|             foreach_op_,
 | |
|             native_op_,
 | |
|             inputs,
 | |
|             is_fastpath=False,
 | |
|             is_inplace=True,
 | |
|             alpha=None,
 | |
|             scalar_self_arg=False,
 | |
|         )
 | |
| 
 | |
|         # non contiguous
 | |
|         tensor1 = make_tensor(
 | |
|             (5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True
 | |
|         )
 | |
|         tensor2 = make_tensor(
 | |
|             (5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True
 | |
|         )
 | |
|         self.assertFalse(tensor1.is_contiguous())
 | |
|         self.assertFalse(tensor2.is_contiguous())
 | |
|         inputs = ([tensor1], [tensor2])
 | |
|         self._binary_test(
 | |
|             dtype,
 | |
|             foreach_op,
 | |
|             native_op,
 | |
|             inputs,
 | |
|             is_fastpath=False,
 | |
|             is_inplace=False,
 | |
|             alpha=None,
 | |
|             scalar_self_arg=False,
 | |
|         )
 | |
|         self._binary_test(
 | |
|             dtype,
 | |
|             foreach_op_,
 | |
|             native_op_,
 | |
|             inputs,
 | |
|             is_fastpath=False,
 | |
|             is_inplace=True,
 | |
|             alpha=None,
 | |
|             scalar_self_arg=False,
 | |
|         )
 | |
| 
 | |
|         # sliced tensor
 | |
|         tensor1 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype)
 | |
|         tensor2 = make_tensor((5, 2, 1, 3 * 7), device=device, dtype=dtype)[
 | |
|             :, :, :, ::7
 | |
|         ]
 | |
|         inputs = ([tensor1], [tensor2])
 | |
|         self._binary_test(
 | |
|             dtype,
 | |
|             foreach_op,
 | |
|             native_op,
 | |
|             inputs,
 | |
|             is_fastpath=False,
 | |
|             is_inplace=False,
 | |
|             alpha=None,
 | |
|             scalar_self_arg=False,
 | |
|         )
 | |
|         self._binary_test(
 | |
|             dtype,
 | |
|             foreach_op_,
 | |
|             native_op_,
 | |
|             inputs,
 | |
|             is_fastpath=False,
 | |
|             is_inplace=True,
 | |
|             alpha=None,
 | |
|             scalar_self_arg=False,
 | |
|         )
 | |
| 
 | |
|     @ops(
 | |
|         filter(lambda op: op.supports_out, foreach_binary_op_db),
 | |
|         dtypes=floating_types_and(torch.half, torch.bfloat16),
 | |
|     )
 | |
|     def test_binary_op_float_inf_nan(self, device, dtype, op):
 | |
|         inputs = (
 | |
|             [
 | |
|                 torch.tensor([float("inf")], device=device, dtype=dtype),
 | |
|                 torch.tensor([-float("inf")], device=device, dtype=dtype),
 | |
|                 torch.tensor([float("nan")], device=device, dtype=dtype),
 | |
|                 torch.tensor([float("nan")], device=device, dtype=dtype),
 | |
|             ],
 | |
|             [
 | |
|                 torch.tensor([-float("inf")], device=device, dtype=dtype),
 | |
|                 torch.tensor([float("inf")], device=device, dtype=dtype),
 | |
|                 torch.tensor([float("inf")], device=device, dtype=dtype),
 | |
|                 torch.tensor([float("nan")], device=device, dtype=dtype),
 | |
|             ],
 | |
|         )
 | |
|         op, ref, inplace_op, inplace_ref = self._get_funcs(op)
 | |
|         self._binary_test(
 | |
|             dtype, op, ref, inputs, True, False, alpha=None, scalar_self_arg=False
 | |
|         )
 | |
|         self._binary_test(
 | |
|             dtype,
 | |
|             inplace_op,
 | |
|             inplace_ref,
 | |
|             inputs,
 | |
|             True,
 | |
|             True,
 | |
|             alpha=None,
 | |
|             scalar_self_arg=False,
 | |
|         )
 | |
| 
 | |
|     # note: Below three tests (postfixed with `_tensors_on_different_devices`)
 | |
|     # checks whether foreach works with lists of tensors on different devices
 | |
|     # but tensors of the same index are on the same device, e.g., ['cuda', 'cpu].
 | |
|     @onlyCUDA
 | |
|     @ops(foreach_unary_op_db)
 | |
|     def test_unary_op_tensors_on_different_devices(self, device, dtype, op):
 | |
|         method, ref, inplace_method, ref_inplace = self._get_funcs(op)
 | |
|         # tensors: ['cuda', 'cpu]
 | |
|         tensors = next(
 | |
|             iter(
 | |
|                 op.sample_inputs(
 | |
|                     device,
 | |
|                     dtype,
 | |
|                     num_input_tensors=[2],
 | |
|                     allow_higher_dtype_scalars=True,
 | |
|                 )
 | |
|             )
 | |
|         ).input
 | |
|         tensors[1] = tensors[1].to("cpu")
 | |
|         if not op.supports_out:
 | |
|             try:
 | |
|                 actual = method((tensors,), False, False, zero_size=False)
 | |
|             except RuntimeError as e:
 | |
|                 with self.assertRaisesRegex(type(e), str(e).splitlines()[0]):
 | |
|                     ref((tensors,))
 | |
|             else:
 | |
|                 expected = ref((tensors,))
 | |
|                 self.assertEqual(expected, actual)
 | |
| 
 | |
|         try:
 | |
|             inplace_method((tensors,), False, False, zero_size=False)
 | |
|         except RuntimeError as e:
 | |
|             with self.assertRaisesRegex(type(e), str(e).splitlines()[0]):
 | |
|                 ref_inplace((tensors,))
 | |
|         else:
 | |
|             if not op.supports_out:
 | |
|                 self.assertEqual(expected, tensors)
 | |
|             else:
 | |
|                 self.assertEqual([torch.zeros_like(t) for t in tensors], tensors)
 | |
| 
 | |
|     @onlyCUDA
 | |
|     @ops(filter(lambda op: op.supports_out, foreach_binary_op_db))
 | |
|     def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
 | |
|         _cuda_tensors = next(
 | |
|             iter(
 | |
|                 op.sample_inputs(
 | |
|                     device,
 | |
|                     dtype,
 | |
|                     num_input_tensors=[2],
 | |
|                     same_size=True,
 | |
|                     allow_higher_dtype_scalars=True,
 | |
|                 )
 | |
|             )
 | |
|         ).input
 | |
|         _cpu_tensors = next(
 | |
|             iter(
 | |
|                 op.sample_inputs(
 | |
|                     "cpu",
 | |
|                     dtype,
 | |
|                     num_input_tensors=[2],
 | |
|                     same_size=True,
 | |
|                     allow_higher_dtype_scalars=True,
 | |
|                 )
 | |
|             )
 | |
|         ).input
 | |
|         tensors1, tensors2 = list(zip(_cuda_tensors, _cpu_tensors))
 | |
| 
 | |
|         foreach_op, foreach_op_ = op.method_variant, op.inplace_variant
 | |
|         native_op, native_op_ = op.ref, op.ref_inplace
 | |
|         try:
 | |
|             actual = foreach_op(tensors1, tensors2)
 | |
|         except RuntimeError as e:
 | |
|             with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
 | |
|                 [native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
 | |
|         else:
 | |
|             expected = [native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
 | |
|             self.assertEqual(expected, actual)
 | |
|         try:
 | |
|             foreach_op_(tensors1, tensors2)
 | |
|         except RuntimeError as e:
 | |
|             with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
 | |
|                 [native_op_(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
 | |
|         else:
 | |
|             self.assertEqual(actual, tensors1)
 | |
| 
 | |
|     @onlyCUDA
 | |
|     @ops(foreach_pointwise_op_db, allowed_dtypes=floating_types())
 | |
|     def test_pointwise_op_tensors_on_different_devices(self, device, dtype, op):
 | |
|         # tensors1: ['cuda', 'cpu]
 | |
|         # tensors2: ['cuda', 'cpu]
 | |
|         # tensors3: ['cuda', 'cpu]
 | |
|         # first tensorlist is zero-size when float32
 | |
|         _cuda_tensors = list(
 | |
|             op.sample_inputs(
 | |
|                 device,
 | |
|                 dtype,
 | |
|                 num_input_tensors=[3],
 | |
|                 same_size=True,
 | |
|                 allow_higher_dtype_scalars=True,
 | |
|             )
 | |
|         )[int(dtype == torch.float32)].input
 | |
|         _cpu_tensors = next(
 | |
|             iter(
 | |
|                 op.sample_inputs(
 | |
|                     "cpu",
 | |
|                     dtype,
 | |
|                     num_input_tensors=[3],
 | |
|                     same_size=True,
 | |
|                     allow_higher_dtype_scalars=True,
 | |
|                 )
 | |
|             )
 | |
|         ).input
 | |
|         tensors1, tensors2, tensors3 = list(zip(_cuda_tensors, _cpu_tensors))
 | |
| 
 | |
|         foreach_op, foreach_op_, native_op = (
 | |
|             op.method_variant,
 | |
|             op.inplace_variant,
 | |
|             op.ref,
 | |
|         )
 | |
|         actual = foreach_op(tensors1, tensors2, tensors3)
 | |
|         expected = [native_op(*_cuda_tensors), native_op(*_cpu_tensors)]
 | |
|         self.assertEqual(expected, actual)
 | |
| 
 | |
|         # note(mkozuki): Limiting dtypes to FP32&FP64, we can safely run inplace ops.
 | |
|         foreach_op_(tensors1, tensors2, tensors3)
 | |
|         self.assertEqual(expected, tensors1)
 | |
| 
 | |
|     # note: BFloat16 has the same number of exponent bits as FP32
 | |
|     # so if squared L2 norm overflows in BF16, then it also overflows in FP32.
 | |
|     @onlyCUDA
 | |
|     @ops(
 | |
|         [o for o in foreach_reduce_op_db if "norm" in o.name],
 | |
|         allowed_dtypes=(torch.half, torch.bfloat16),
 | |
|     )
 | |
|     def test_foreach_l2_large_value_input(self, device, dtype, op):
 | |
|         ord, N = 2, 10
 | |
|         max_value = torch.finfo(dtype).max
 | |
|         scaler = torch.tensor([max_value]).sqrt().to(device=device, dtype=dtype)
 | |
|         inputs = (
 | |
|             [
 | |
|                 t * scaler
 | |
|                 for t in next(
 | |
|                     iter(
 | |
|                         op.sample_inputs(
 | |
|                             device,
 | |
|                             dtype,
 | |
|                             requries_grad=True,
 | |
|                             num_input_tensors=[N],
 | |
|                             low=1,
 | |
|                         )
 | |
|                     )
 | |
|                 ).input
 | |
|             ][:-1],
 | |
|         )
 | |
|         # make sure that the min. of squared L2 norm value per tensor is greater than the max value of `dtype`.
 | |
|         self.assertTrue(scaler * scaler * N > max_value)
 | |
|         fn, ref_fn, *_ = self._get_funcs(op)
 | |
|         actual = fn(
 | |
|             inputs, is_cuda=True, expect_fastpath=True, ord=ord, zero_size=False
 | |
|         )
 | |
|         expect = ref_fn(inputs, ord=ord)
 | |
| 
 | |
|         if dtype == torch.float16:
 | |
|             # making sure the reference L2 norm values are in the range of FP16.
 | |
|             self.assertFalse(any(torch.isinf(e) for e in expect))
 | |
|         else:
 | |
|             self.assertTrue(
 | |
|                 all(
 | |
|                     inputs[0][i].numel() == 0 or torch.isinf(e)
 | |
|                     for i, e in enumerate(expect)
 | |
|                 )
 | |
|             )
 | |
|         self.assertEqual(expect, actual, equal_nan=False)
 | |
| 
 | |
|     @onlyCUDA
 | |
|     @ops(foreach_reduce_op_db, allowed_dtypes=floating_types())
 | |
|     @parametrize("use_cuda_graph", (False, True))
 | |
|     @parametrize("w_empty", (False, True))
 | |
|     def test_big_num_tensors(self, device, dtype, op, use_cuda_graph, w_empty):
 | |
|         # foreach_max cannot handle empty tensors as max requires an identity
 | |
|         intersperse_empty_tensors = w_empty and op.name != "_foreach_max"
 | |
| 
 | |
|         N = 600
 | |
|         indices_with_empty_tensors = (
 | |
|             set()
 | |
|             if not intersperse_empty_tensors
 | |
|             else {200, 300, 301, 400, 401, 402, 404, 598}
 | |
|         )
 | |
|         tensorlist = [
 | |
|             make_tensor((2, 3), dtype=dtype, device=device, noncontiguous=False)
 | |
|             if i not in indices_with_empty_tensors
 | |
|             else torch.empty(0, dtype=dtype, device=device)
 | |
|             for i in range(N)
 | |
|         ]
 | |
|         fn, ref_fn, *_ = self._get_funcs(op)
 | |
| 
 | |
|         import math
 | |
| 
 | |
|         if op.name == "_foreach_norm":
 | |
|             ords = [1, 2]
 | |
|             if not intersperse_empty_tensors:
 | |
|                 # inf norm over an empty tensor is not defined by vector norm as it expects an identity
 | |
|                 ords.append(math.inf)
 | |
|         else:
 | |
|             ords = [None]
 | |
| 
 | |
|         for ord in ords:
 | |
|             kwargs = {"ord": ord} if ord else {}
 | |
|             if not use_cuda_graph:
 | |
|                 actual = fn(
 | |
|                     inputs=[tensorlist],
 | |
|                     is_cuda=True,
 | |
|                     expect_fastpath=True,
 | |
|                     zero_size=False,
 | |
|                     **kwargs,
 | |
|                 )
 | |
|             else:
 | |
|                 # When using CUDA graphs and the tensor metadata doesn't fit in
 | |
|                 # the static kernel argument space, multi_tensor_apply creates
 | |
|                 # the launch arguments once, uses cudaUserObject_t to tie its
 | |
|                 # lifetime to the graph, and reuses it throughout replays. This
 | |
|                 # test verifies multi_tensor_apply's behavior in the scenario.
 | |
|                 g = torch.cuda.CUDAGraph()
 | |
|                 with torch.cuda.graph(g):
 | |
|                     actual = fn.func(tensorlist, **kwargs)
 | |
|                 g.replay()
 | |
|             expect = ref_fn(inputs=[tensorlist], **kwargs)
 | |
| 
 | |
|             self.assertEqual(expect, actual, equal_nan=True)
 | |
| 
 | |
|     @onlyCUDA
 | |
|     @ops(foreach_reduce_op_db)
 | |
|     @parametrize("w_empty", (False, True))
 | |
|     def test_foreach_reduce_large_input(self, device, dtype, op, w_empty):
 | |
|         # test inputs larger than kChunkSize (65536) * max_num_blocks (320)
 | |
|         N = 65536 * 320 * 2
 | |
|         disable_fastpath = False
 | |
|         kwargs = {}
 | |
|         if op.name == "_foreach_norm":
 | |
|             kwargs["ord"] = 2
 | |
|             disable_fastpath = dtype not in floating_types_and(
 | |
|                 torch.half, torch.bfloat16
 | |
|             )
 | |
| 
 | |
|         tensorlist = [
 | |
|             make_tensor((N,), dtype=dtype, device=device, noncontiguous=False)
 | |
|         ]
 | |
|         # foreach_max cannot handle empty tensors as max over empty is undefined
 | |
|         if w_empty and op.name != "_foreach_max":
 | |
|             tensorlist += [
 | |
|                 torch.empty(0, dtype=dtype, device=device),
 | |
|                 make_tensor((N,), dtype=dtype, device=device, noncontiguous=False),
 | |
|             ]
 | |
|         inputs = (tensorlist,)
 | |
|         wrapped_op, ref, _, _ = self._get_funcs(op)
 | |
|         self.assertEqual(
 | |
|             ref(inputs, **kwargs),
 | |
|             wrapped_op(
 | |
|                 inputs, self.is_cuda, not disable_fastpath, zero_size=False, **kwargs
 | |
|             ),
 | |
|         )
 | |
| 
 | |
|     @onlyCUDA
 | |
|     @ops(
 | |
|         foreach_unary_op_db
 | |
|         + foreach_binary_op_db
 | |
|         + foreach_pointwise_op_db
 | |
|         + foreach_other_op_db,
 | |
|         dtypes=(torch.float,),
 | |
|     )
 | |
|     def test_inplace_foreach_leaf_check_and_grad_fn(self, device, dtype, op):
 | |
|         inplace_op = op.inplace_variant
 | |
|         if inplace_op is None:
 | |
|             self.skipTest("no in-place op available")
 | |
| 
 | |
|         sample = next(
 | |
|             iter(
 | |
|                 op.sample_inputs(
 | |
|                     dtype=dtype, device=device, num_input_tensors=[2], same_size=True
 | |
|                 )
 | |
|             )
 | |
|         )
 | |
|         sample.input[0].requires_grad_(True)
 | |
|         with self.assertRaisesRegex(RuntimeError, "a leaf Variable that requires grad"):
 | |
|             inplace_op(sample.input, *sample.args)
 | |
|         sample.input[1].requires_grad_(True)
 | |
|         with self.assertRaisesRegex(RuntimeError, "a leaf Variable that requires grad"):
 | |
|             inplace_op(sample.input, *sample.args)
 | |
| 
 | |
|         _tensors = [
 | |
|             t.detach().clone().requires_grad_(i == 0)
 | |
|             for i, t in enumerate(sample.input)
 | |
|         ]
 | |
|         tensors = [t.clone() for t in _tensors]
 | |
|         inplace_op(tensors, *sample.args)
 | |
|         self.assertIsNotNone(tensors[0].grad_fn)
 | |
|         self.assertIsNone(tensors[1].grad_fn)
 | |
| 
 | |
|     @onlyCUDA
 | |
|     @ops(
 | |
|         filter(
 | |
|             lambda op: op.supports_out,
 | |
|             foreach_unary_op_db
 | |
|             + foreach_binary_op_db
 | |
|             + foreach_pointwise_op_db
 | |
|             + foreach_other_op_db,
 | |
|         ),
 | |
|         dtypes=(torch.float,),
 | |
|     )
 | |
|     def test_outplace_with_invalid_grads(self, device, dtype, op):
 | |
|         func, *_ = self._get_funcs(op)
 | |
|         sample = next(
 | |
|             iter(
 | |
|                 op.sample_inputs(
 | |
|                     dtype=dtype,
 | |
|                     device=device,
 | |
|                     requires_grad=True,
 | |
|                     num_input_tensors=[2],
 | |
|                     same_size=True,
 | |
|                 )
 | |
|             )
 | |
|         )
 | |
|         self.assertTrue(all(t.requires_grad for t in sample.input))
 | |
|         (out1, out2) = func(
 | |
|             [sample.input, *sample.args],
 | |
|             is_cuda=False,
 | |
|             expect_fastpath=False,
 | |
|             **sample.kwargs,
 | |
|         )
 | |
|         out1.backward(torch.ones_like(out1))
 | |
|         self.assertIsNotNone(sample.input[0].grad)
 | |
|         self.assertIsNone(sample.input[1].grad)
 | |
| 
 | |
|     @ops(
 | |
|         filter(
 | |
|             lambda op: op.backward_requires_result,
 | |
|             foreach_unary_op_db
 | |
|             + foreach_binary_op_db
 | |
|             + foreach_pointwise_op_db
 | |
|             + foreach_other_op_db,
 | |
|         ),
 | |
|         dtypes=(torch.float32,),
 | |
|     )
 | |
|     def test_lifetime_of_grad_fn_when_result_is_saved(self, device, dtype, op):
 | |
|         def get_ref(func, sample):
 | |
|             class Foo:
 | |
|                 pass
 | |
| 
 | |
|             out = func(
 | |
|                 (sample.input, *sample.args),
 | |
|                 is_cuda=False,
 | |
|                 expect_fastpath=False,
 | |
|                 **sample.kwargs,
 | |
|             )
 | |
|             foo = Foo()
 | |
|             meta_dict = out[0].grad_fn.metadata
 | |
|             meta_dict[0] = foo
 | |
|             ref = weakref.ref(foo)
 | |
|             return out, ref
 | |
| 
 | |
|         def _test(func, sample):
 | |
|             out, ref = get_ref(func, sample)
 | |
|             self.assertIsNotNone(ref())
 | |
|             del out
 | |
|             self.assertIsNone(ref())
 | |
| 
 | |
|         func = self._get_funcs(op)[0]
 | |
|         for sample in op.sample_inputs(
 | |
|             device, dtype, requires_grad=True, num_input_tensors=[1]
 | |
|         ):
 | |
|             for key in ("is_fastpath", "disable_fastpath"):
 | |
|                 if key in sample.kwargs:
 | |
|                     del sample.kwargs[key]
 | |
|             # note: `_foreach_pow.Scalar` and `_foreach_pow.ScalarList` don't depend on `result`
 | |
|             # see: https://github.com/pytorch/pytorch/blob/5403c777/tools/autograd/derivatives.yaml#L3048-L3049
 | |
|             if op.name == "_foreach_pow":
 | |
|                 if (
 | |
|                     isinstance(sample.args[0], list)
 | |
|                     and isinstance(sample.args[0][0], Number)
 | |
|                 ) or (
 | |
|                     isinstance(sample.args[0], Number)
 | |
|                     and not isinstance(sample.args[0], float)
 | |
|                 ):
 | |
|                     continue
 | |
|                 if isinstance(sample.args[0], float):
 | |
|                     new_args = (sample.input,)
 | |
|                     sample.input = sample.args[0]
 | |
|                     sample.args = new_args
 | |
|             _test(func, sample)
 | |
| 
 | |
|     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
 | |
|     def test_tensors_grouping(self):
 | |
|         num_tensors_per_list = 10
 | |
|         num_devices = torch.cuda.device_count()
 | |
|         dtypes = (torch.float16, torch.float32, torch.float64)
 | |
|         list1 = [
 | |
|             torch.tensor(
 | |
|                 i,
 | |
|                 device=torch.device("cuda", random.randint(0, num_devices - 1)),
 | |
|                 dtype=dtypes[random.randint(0, 2)],
 | |
|             )
 | |
|             for i in range(num_tensors_per_list)
 | |
|         ]
 | |
|         list2 = [None for _ in list1]
 | |
|         list3 = [torch.rand_like(t) for t in list1]
 | |
|         nested_tensorlists = [list1, list2, list3]
 | |
|         grouped_tensors = torch.utils._foreach_utils._group_tensors_by_device_and_dtype(
 | |
|             nested_tensorlists, with_indices=True
 | |
|         )
 | |
|         num_tensors_seen = 0
 | |
|         for (device, dtype), ([l1, l2, l3], indices) in grouped_tensors.items():
 | |
|             for t in itertools.chain(l1, l3):
 | |
|                 self.assertEqual(t.device, device)
 | |
|                 self.assertEqual(t.dtype, dtype)
 | |
|                 num_tensors_seen += 1
 | |
|             self.assertEqual(len(l1), len(l2))
 | |
|             self.assertTrue(all(p is None for p in l2))
 | |
|             for i, index in enumerate(indices):
 | |
|                 self.assertEqual(l1[i], list1[index])
 | |
|                 self.assertEqual(l2[i], list2[index])
 | |
|                 self.assertEqual(l3[i], list3[index])
 | |
|         self.assertEqual(num_tensors_seen, 2 * num_tensors_per_list)
 | |
| 
 | |
|     @onlyCUDA
 | |
|     def test_0dim_tensor_overload_cpu_ok(self):
 | |
|         tensors = [torch.ones((), device="cuda", dtype=torch.float32) for _ in range(2)]
 | |
|         scalar_cpu_tensor = torch.tensor(4.0, device="cpu")
 | |
| 
 | |
|         # For mul and div, the scalar is allowed to be on CPU too
 | |
|         actual = torch._foreach_mul(tensors, scalar_cpu_tensor)
 | |
|         self.assertEqual(actual, [t.mul(scalar_cpu_tensor) for t in tensors])
 | |
|         actual = torch._foreach_div(tensors, scalar_cpu_tensor)
 | |
|         self.assertEqual(actual, [t.div(scalar_cpu_tensor) for t in tensors])
 | |
| 
 | |
|     @onlyCUDA
 | |
|     def test_div_reciprocal(self):
 | |
|         expect_m, expect_e = torch.frexp(
 | |
|             torch.div(torch.tensor(0.1, device="cuda"), 10.0)
 | |
|         )
 | |
|         actual_m, actual_e = torch.frexp(
 | |
|             torch._foreach_div([torch.tensor(0.1, device="cuda")], [10.0])[0]
 | |
|         )
 | |
|         self.assertEqual(expect_m, actual_m)
 | |
|         self.assertEqual(expect_e, actual_e)
 | |
| 
 | |
|     @onlyCUDA
 | |
|     def test_0dim_tensor_overload_exception(self):
 | |
|         # check exceptions of fast path
 | |
|         tensors = [
 | |
|             make_tensor((2, 2), dtype=torch.float, device="cuda") for _ in range(2)
 | |
|         ]
 | |
|         with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be on"):
 | |
|             torch._foreach_add(tensors, torch.tensor(1.0, device="cpu"), alpha=1.0)
 | |
| 
 | |
|         tensors = [
 | |
|             make_tensor((2, 2), dtype=torch.float, device=d) for d in ("cpu", "cuda")
 | |
|         ]
 | |
|         with self.assertRaisesRegex(
 | |
|             RuntimeError, "scalar tensor expected to be 0 dim but"
 | |
|         ):
 | |
|             torch._foreach_mul(tensors, torch.tensor([1.0, 1.0], device="cuda"))
 | |
|         with self.assertRaisesRegex(
 | |
|             RuntimeError, "scalar tensor expected to be 0 dim but"
 | |
|         ):
 | |
|             torch._foreach_add(tensors, torch.tensor([1.0, 1.0], device="cuda"))
 | |
| 
 | |
|     @onlyCUDA
 | |
|     @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
 | |
|     def test_foreach_copy_with_multi_device_inputs(self, device, dtype, op):
 | |
|         foreach_copy_ = op.inplace_variant
 | |
|         copy_ = op.ref_inplace
 | |
|         for non_blocking in (False, True):
 | |
|             for sample in op.sample_inputs(
 | |
|                 device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True
 | |
|             ):
 | |
|                 with torch.no_grad():
 | |
|                     ref_input = [t.detach().clone() for t in sample.input]
 | |
|                 foreach_copy_(sample.input, sample.args[0], non_blocking)
 | |
|                 for t, s in zip(ref_input, sample.args[0]):
 | |
|                     copy_(t, s, non_blocking)
 | |
|                 self.assertEqual(sample.input, ref_input)
 | |
|                 if torch.cuda.device_count() > 1:
 | |
|                     device = torch.device("cuda", 1)
 | |
|                     rhs_tensors = [t.to(device) for t in sample.args[0]]
 | |
|                     foreach_copy_(sample.input, rhs_tensors, non_blocking)
 | |
|                     for t, s in zip(ref_input, rhs_tensors):
 | |
|                         copy_(t, s, non_blocking)
 | |
|                     self.assertEqual(ref_input, sample.input)
 | |
| 
 | |
|     @onlyCUDA
 | |
|     @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
 | |
|     def test_foreach_copy_with_multi_dtypes(self, device, dtype, op):
 | |
|         # check (a) multi_tensor_apply is called and (b) numerical parity with for-loop and Tensor.copy_
 | |
|         foreach_copy_ = ForeachFuncWrapper(op.inplace_variant)
 | |
|         for sample in op.sample_inputs(
 | |
|             device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True
 | |
|         ):
 | |
|             for src_dtype in floating_types_and(torch.half, torch.bfloat16):
 | |
|                 if src_dtype == dtype:
 | |
|                     continue
 | |
|                 self_tensors = [t.clone() for t in sample.input]
 | |
|                 src_tensors = [t.to(src_dtype) for t in self_tensors]
 | |
|                 out = foreach_copy_(
 | |
|                     (self_tensors, src_tensors), is_cuda=True, expect_fastpath=True
 | |
|                 )
 | |
|                 ref_out = [
 | |
|                     torch.empty_like(t).copy_(s)
 | |
|                     for t, s in zip(self_tensors, src_tensors)
 | |
|                 ]
 | |
|                 for t, ref_t in zip(out, ref_out):
 | |
|                     self.assertTrue(torch.equal(t, ref_t))
 | |
| 
 | |
|     # Test reverse-mode & forward-mode AD if supported.
 | |
|     @onlyCUDA
 | |
|     @ops(
 | |
|         foreach_unary_op_db
 | |
|         + foreach_binary_op_db
 | |
|         + foreach_pointwise_op_db
 | |
|         + foreach_reduce_op_db
 | |
|         + foreach_other_op_db,
 | |
|         dtypes=OpDTypes.supported,
 | |
|         allowed_dtypes=(torch.float64, torch.complex128),
 | |
|     )
 | |
|     @parametrize(
 | |
|         "inplace", (False, True), name_fn=lambda x: "inplace" if x else "outplace"
 | |
|     )
 | |
|     def test_autodiff(self, device, dtype, op, inplace):
 | |
|         if (not inplace) and not op.supports_out:
 | |
|             self.skipTest("out-of-place not implemented")
 | |
|         if inplace and op.has_no_in_place:
 | |
|             self.skipTest("in-place not implemented")
 | |
|         if not (
 | |
|             op.supports_autograd
 | |
|             or op.supports_inplace_autograd
 | |
|             or op.supports_forward_ad
 | |
|         ):
 | |
|             self.skipTest("neither reverse mode nor forward mode supported")
 | |
| 
 | |
|         # note(crcrpar): without this, some unary functions fail, unlike inplace and/or complex.
 | |
|         if (
 | |
|             (not inplace)
 | |
|             and dtype == torch.float64
 | |
|             and op.name
 | |
|             in (
 | |
|                 "_foreach_acos",
 | |
|                 "_foreach_asin",
 | |
|                 "_foreach_log10",
 | |
|                 "_foreach_log1p",
 | |
|                 "_foreach_log2",
 | |
|                 "_foreach_log",
 | |
|                 "_foreach_pow",
 | |
|                 "_foreach_sqrt",
 | |
|                 "_foreach_rsqrt",
 | |
|             )
 | |
|         ):
 | |
|             value_range = {"low": 0.5, "high": 1.0}
 | |
|         else:
 | |
|             value_range = {}
 | |
|         for sample in op.sample_inputs(
 | |
|             device,
 | |
|             dtype,
 | |
|             requires_grad=True,
 | |
|             num_input_tensors=[5],
 | |
|             allow_higher_dtype_scalars=True,
 | |
|             **value_range,
 | |
|         ):
 | |
|             # Skip `_foreach_pow.ScalarAndTensor(Scalar, Tensor[])`
 | |
|             if op.name == "_foreach_pow" and isinstance(sample.input, Number):
 | |
|                 continue
 | |
| 
 | |
|             func = None
 | |
|             if inplace:
 | |
|                 # Call `clone` to avoid inplace modifications likewise
 | |
|                 # `torch.testing._internal.common_utils.TestGradients._get_safe_inplace`
 | |
|                 def inplace_func(*tensorlist):
 | |
|                     kwargs = (
 | |
|                         {"alpha": sample.kwargs["alpha"]}
 | |
|                         if "alpha" in sample.kwargs
 | |
|                         else {}
 | |
|                     )
 | |
|                     op.inplace_variant(
 | |
|                         tuple(t.clone() for t in tensorlist), *sample.args, **kwargs
 | |
|                     )
 | |
|                     return tensorlist
 | |
| 
 | |
|                 func = inplace_func
 | |
|             else:
 | |
| 
 | |
|                 def outplace_func(*tensorlist):
 | |
|                     kwargs = (
 | |
|                         {"alpha": sample.kwargs["alpha"]}
 | |
|                         if "alpha" in sample.kwargs
 | |
|                         else {}
 | |
|                     )
 | |
|                     return op.method_variant(tensorlist, *sample.args, **kwargs)
 | |
| 
 | |
|                 func = outplace_func
 | |
| 
 | |
|             working_sample, err_msg_pattern = check_autodiff_sample(
 | |
|                 op, sample, dtype, inplace
 | |
|             )
 | |
| 
 | |
|             def call_gradcheck():
 | |
|                 gradcheck(
 | |
|                     func,
 | |
|                     sample.input,
 | |
|                     raise_exception=True,
 | |
|                     check_forward_ad=op.supports_forward_ad,
 | |
|                     check_batched_forward_grad=False,
 | |
|                     check_backward_ad=op.supports_autograd,
 | |
|                     check_batched_grad=False,
 | |
|                 )
 | |
| 
 | |
|             if not working_sample:
 | |
|                 if not err_msg_pattern:
 | |
|                     # lhs of float64 and rhs of complex.
 | |
|                     continue
 | |
|                 with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
 | |
|                     call_gradcheck()
 | |
|                 continue
 | |
|             call_gradcheck()
 | |
| 
 | |
|             # Test per-tensor `grad_fn` behavior.
 | |
|             if inplace and op.supports_inplace_autograd:
 | |
|                 # per-tensor `grad_fn` check.
 | |
|                 hook_buffer = []
 | |
| 
 | |
|                 def get_grad_fn_hook(i):
 | |
|                     def hook(grad_inputs, grad_outputs) -> None:
 | |
|                         hook_buffer.append(i)
 | |
| 
 | |
|                     return hook
 | |
| 
 | |
|                 _inputs = [t.detach().clone().requires_grad_() for t in sample.input]
 | |
|                 inputs = [t.clone() for t in _inputs]
 | |
|                 kwargs = (
 | |
|                     {"alpha": sample.kwargs["alpha"]}
 | |
|                     if "alpha" in sample.kwargs
 | |
|                     else {}
 | |
|                 )
 | |
|                 op.inplace_variant(inputs, *sample.args, **kwargs)
 | |
| 
 | |
|                 self.assertEqual(len({t.grad_fn for t in inputs}), len(inputs))
 | |
| 
 | |
|                 for i, t in enumerate(inputs):
 | |
|                     t.grad_fn.register_hook(get_grad_fn_hook(i))
 | |
| 
 | |
|                 torch.autograd.grad(
 | |
|                     inputs[0],
 | |
|                     inputs=(_inputs[0],),
 | |
|                     grad_outputs=(torch.rand_like(inputs[0]),),
 | |
|                     retain_graph=True,
 | |
|                 )
 | |
|                 self.assertEqual(hook_buffer, [0])
 | |
|                 hook_buffer.clear()
 | |
| 
 | |
|                 # tensors have different shapes.
 | |
|                 sum_of_cloned_tensors = torch.cat([t.view(-1) for t in inputs]).sum()
 | |
|                 grad_output = torch.rand_like(sum_of_cloned_tensors)
 | |
|                 torch.autograd.grad(
 | |
|                     sum_of_cloned_tensors,
 | |
|                     inputs=tuple(_inputs),
 | |
|                     grad_outputs=(grad_output,),
 | |
|                     retain_graph=False,
 | |
|                 )
 | |
|                 self.assertEqual(hook_buffer, list(reversed(range(len(inputs)))))
 | |
| 
 | |
| 
 | |
| # TODO(crcrpar): Hide this inside torch/testing/_internal.
 | |
| # would end up adding another layer to `foreach_inputs_sample_func.__call__`
 | |
| # so that we can use this function as something like the first argument of `filter` function.
 | |
| # Even after moving this function to testing, I personally think it'd be better to check the error message.
 | |
| def check_autodiff_sample(op, sample, dtype, is_inplace):
 | |
|     if op.name == "_foreach_abs" and is_inplace and dtype == torch.complex128:
 | |
|         return False, "In-place abs is not supported for complex tensors."
 | |
|     if op.name == "_foreach_sub" and (
 | |
|         (
 | |
|             isinstance(sample.args[-1], list)
 | |
|             and any(isinstance(a, bool) for a in sample.args[-1])
 | |
|         )
 | |
|         or isinstance(sample.args[-1], bool)
 | |
|     ):
 | |
|         return False, _BOOL_SUB_ERR_MSG
 | |
|     if op.name == "_foreach_norm" and (not is_inplace):
 | |
|         return (
 | |
|             False,
 | |
|             "Trying to set a forward gradient that has a different size than that of the original Tensor, "
 | |
|             "this is not supported. Tensor is of size [] while the given forward gradient is of size [1",
 | |
|         )
 | |
|     rhs_arg_has_complex_number = sample.args and (
 | |
|         (
 | |
|             isinstance(sample.args[-1], list)
 | |
|             and any(isinstance(a, complex) for a in sample.args[-1])
 | |
|         )
 | |
|         or (isinstance(sample.args[-1], complex))
 | |
|     )
 | |
|     if rhs_arg_has_complex_number and dtype == torch.float64:
 | |
|         if op.name == "_foreach_lerp":
 | |
|             return False, "value cannot be converted to type double without overflow"
 | |
|         if op.name in (
 | |
|             "_foreach_clamp_max",
 | |
|             "_foreach_clamp_min",
 | |
|             "_foreach_maximum",
 | |
|             "_foreach_minimum",
 | |
|         ):
 | |
|             return False, "clamp is not supported for complex types"
 | |
|         if op.name == "_foreach_lerp" and is_inplace:
 | |
|             return False, "value cannot be converted to type double without overflow"
 | |
|         if not is_inplace:
 | |
|             return False, ""
 | |
|         elif op.name in (
 | |
|             "_foreach_add",
 | |
|             "_foreach_sub",
 | |
|             "_foreach_mul",
 | |
|             "_foreach_div",
 | |
|             "_foreach_pow",
 | |
|         ):
 | |
|             return (
 | |
|                 False,
 | |
|                 "result type ComplexDouble can't be cast to the desired output type Double",
 | |
|             )
 | |
|     return True, ""
 | |
| 
 | |
| 
 | |
| instantiate_device_type_tests(TestForeach, globals())
 | |
| 
 | |
| 
 | |
| if __name__ == "__main__":
 | |
|     run_tests()
 |