Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42537

[First PR: Add private API to support tensor lists: _foreach_add(TensorList tensors, Scalar scalar)](https://github.com/pytorch/pytorch/pull/41554).

**Motivation**
[GitHub issue](https://github.com/pytorch/pytorch/issues/38655)
Current PyTorch optimizer implementations are not efficient when we work with many small feature tensors: launching a large number of kernels slows down the whole process, so we need to reduce the number of kernel launches. As an example, we should be looking at [NVIDIA's Apex](https://github.com/NVIDIA/apex).
To track progress, we will pick PyTorch's DCGAN model with the Adam optimizer and, once the optimizer is reimplemented with tensor lists, benchmark the model against the original version, Apex's version with the original Adam optimizer, and its FusedAdam optimizer.

**Current API restrictions**
- The list can't be empty (will be fixed in upcoming PRs).
- All tensors in the list must have the same dtype, device and size.

**Broadcasting**
At this point we don't support broadcasting.

**What the 'fast' and 'slow' routes are**
In some cases we can't process an op with a fast list CUDA kernel. We can still fall back to a regular for-loop in which the op is applied to each tensor individually through the dispatch mechanism. A few checks decide whether the op takes the 'fast' or the 'slow' path. To take the fast route:
- All tensors must have a strided layout.
- All tensors must be dense and must not have overlapping memory.
- The resulting tensor type must be the same.

----------------

**In this PR**
Adding APIs:
```
torch._foreach_exp(TensorList tl1)
torch._foreach_exp_(TensorList tl1)
torch._foreach_sqrt(TensorList tl1)
torch._foreach_sqrt_(TensorList tl1)
```

**Tests**
Tested via unit tests.

**TODO**
1. Properly handle empty lists
2. Properly handle bool tensors

**Plan for the next PRs**
1. APIs - Pointwise Ops
2. Complete tasks from TODO
3. Rewrite PyTorch optimizers to use for-each operators for performance gains.

Test Plan: Imported from OSS

Reviewed By: cpuhrsch

Differential Revision: D23331889

Pulled By: izdeby

fbshipit-source-id: 8b04673b8412957472ed56361954ca3884eb9376
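As a quick orientation (not part of the diff), here is a minimal usage sketch of the new unary foreach ops; the list length, tensor shape, and float32 CPU tensors are illustrative assumptions:

```
import torch

# Non-empty list of same-dtype, same-device, same-size tensors (current API restriction).
tensors = [torch.rand(4, 4) for _ in range(8)]

# Out-of-place: one call returns a new list with one result tensor per input.
res = torch._foreach_exp(tensors)
assert all(torch.allclose(r, t.exp()) for r, t in zip(res, tensors))

# In-place variant: mutates every tensor in the list.
torch._foreach_sqrt_(tensors)
```

The same out-of-place/in-place pairing applies to the existing binary and scalar foreach ops; on CUDA, lists that pass the fast-route checks above are handled by a fused kernel, otherwise the op falls back to the per-tensor slow route.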
360 lines
16 KiB
Python
import torch
import unittest
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes

class TestForeach(TestCase):
    bin_ops = [
        torch._foreach_add,
        torch._foreach_add_,
        torch._foreach_sub,
        torch._foreach_sub_,
        torch._foreach_mul,
        torch._foreach_mul_,
        torch._foreach_div,
        torch._foreach_div_,
    ]

    def _get_test_data(self, device, dtype, N):
        # torch.randn can't produce bool/half/bfloat16 directly on every device,
        # so sample float noise and cast; integer dtypes get random values in [1, 100).
        if dtype in [torch.bfloat16, torch.bool, torch.float16]:
            tensors = [torch.randn(N, N, device=device).to(dtype) for _ in range(N)]
        elif dtype in torch.testing.get_all_int_dtypes():
            tensors = [torch.randint(1, 100, (N, N), device=device, dtype=dtype) for _ in range(N)]
        else:
            tensors = [torch.randn(N, N, device=device, dtype=dtype) for _ in range(N)]

        return tensors

    def _test_bin_op_list(self, device, dtype, foreach_op, foreach_op_, torch_op, N=20):
        tensors1 = self._get_test_data(device, dtype, N)
        tensors2 = self._get_test_data(device, dtype, N)

        # The foreach result must match the per-tensor reference op, and the
        # in-place variant must leave tensors1 equal to the out-of-place result.
        expected = [torch_op(tensors1[i], tensors2[i]) for i in range(N)]
        res = foreach_op(tensors1, tensors2)
        foreach_op_(tensors1, tensors2)
        self.assertEqual(res, tensors1)
        self.assertEqual(tensors1, expected)

    def _test_unary_op(self, device, dtype, foreach_op, foreach_op_, torch_op, N=20):
        tensors1 = self._get_test_data(device, dtype, N)
        expected = [torch_op(tensors1[i]) for i in range(N)]
        res = foreach_op(tensors1)
        foreach_op_(tensors1)
        self.assertEqual(res, tensors1)
        self.assertEqual(tensors1, expected)

    def _test_pointwise_op(self, device, dtype, foreach_op, foreach_op_, torch_op, N=20):
        tensors = self._get_test_data(device, dtype, N)
        tensors1 = self._get_test_data(device, dtype, N)
        tensors2 = self._get_test_data(device, dtype, N)
        value = 2

        expected = [torch_op(tensors[i], tensors1[i], tensors2[i], value=value) for i in range(N)]

        res = foreach_op(tensors, tensors1, tensors2, value)
        foreach_op_(tensors, tensors1, tensors2, value)
        self.assertEqual(res, tensors)
        self.assertEqual(tensors, expected)

    #
    # Unary ops
    #
    @dtypes(*[torch.float, torch.double, torch.complex64, torch.complex128])
    def test_sqrt(self, device, dtype):
        self._test_unary_op(device, dtype, torch._foreach_sqrt, torch._foreach_sqrt_, torch.sqrt)

    @dtypes(*[torch.float, torch.double, torch.complex64, torch.complex128])
    def test_exp(self, device, dtype):
        self._test_unary_op(device, dtype, torch._foreach_exp, torch._foreach_exp_, torch.exp)

    #
    # Pointwise ops
    #
    @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False))
    def test_addcmul(self, device, dtype):
        if device == 'cpu':
            if dtype == torch.half:
                with self.assertRaisesRegex(RuntimeError, r"\"addcmul_cpu_out\" not implemented for \'Half\'"):
                    self._test_pointwise_op(device, dtype, torch._foreach_addcmul,
                                            torch._foreach_addcmul_, torch.addcmul)
                return

        self._test_pointwise_op(device, dtype, torch._foreach_addcmul, torch._foreach_addcmul_, torch.addcmul)

    @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False))
    def test_addcdiv(self, device, dtype):
        if dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8]:
            with self.assertRaisesRegex(RuntimeError,
                                        "Integer division with addcdiv is no longer supported, and in a future"):
                self._test_pointwise_op(device, dtype, torch._foreach_addcdiv, torch._foreach_addcdiv_, torch.addcdiv)
            return

        if device == 'cpu':
            if dtype == torch.half:
                with self.assertRaisesRegex(RuntimeError, r"\"addcdiv_cpu_out\" not implemented for \'Half\'"):
                    self._test_pointwise_op(device, dtype, torch._foreach_addcdiv,
                                            torch._foreach_addcdiv_, torch.addcdiv)
                return
        self._test_pointwise_op(device, dtype, torch._foreach_addcdiv, torch._foreach_addcdiv_, torch.addcdiv)

    #
    # Ops with scalar
    #
    @dtypes(*torch.testing.get_all_dtypes())
    def test_int_scalar(self, device, dtype):
        tensors = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
        int_scalar = 1

        # bool tensor + 1 will result in int64 tensor
        if dtype == torch.bool:
            expected = [torch.ones(10, 10, device=device, dtype=torch.int64) for _ in range(10)]
        else:
            expected = [torch.ones(10, 10, device=device, dtype=dtype) for _ in range(10)]

        res = torch._foreach_add(tensors, int_scalar)
        self.assertEqual(res, expected)

        if dtype in [torch.bool]:
            with self.assertRaisesRegex(RuntimeError,
                                        "result type Long can't be cast to the desired output type Bool"):
                torch._foreach_add_(tensors, int_scalar)
        else:
            torch._foreach_add_(tensors, int_scalar)
            self.assertEqual(res, tensors)

    @dtypes(*torch.testing.get_all_dtypes())
    def test_float_scalar(self, device, dtype):
        tensors = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
        float_scalar = 1.

        # float scalar + integral tensor will result in float tensor
        if dtype in [torch.uint8, torch.int8, torch.int16,
                     torch.int32, torch.int64, torch.bool]:
            expected = [torch.ones(10, 10, device=device, dtype=torch.float32) for _ in range(10)]
        else:
            expected = [torch.ones(10, 10, device=device, dtype=dtype) for _ in range(10)]

        res = torch._foreach_add(tensors, float_scalar)
        self.assertEqual(res, expected)

        if dtype in [torch.uint8, torch.int8, torch.int16,
                     torch.int32, torch.int64, torch.bool]:
            self.assertRaises(RuntimeError, lambda: torch._foreach_add_(tensors, float_scalar))
        else:
            torch._foreach_add_(tensors, float_scalar)
            self.assertEqual(res, tensors)

    @dtypes(*torch.testing.get_all_dtypes())
    def test_complex_scalar(self, device, dtype):
        tensors = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
        complex_scalar = 3 + 5j

        # adding a complex scalar promotes the result to a complex dtype
        expected = [torch.add(complex_scalar, torch.zeros(10, 10, device=device, dtype=dtype)) for _ in range(10)]

        if dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16] and device == 'cuda:0':
            # value cannot be converted to dtype without overflow:
            self.assertRaises(RuntimeError, lambda: torch._foreach_add_(tensors, complex_scalar))
            self.assertRaises(RuntimeError, lambda: torch._foreach_add(tensors, complex_scalar))
            return

        res = torch._foreach_add(tensors, complex_scalar)
        self.assertEqual(res, expected)

        if dtype not in [torch.complex64, torch.complex128]:
            self.assertRaises(RuntimeError, lambda: torch._foreach_add_(tensors, complex_scalar))
        else:
            torch._foreach_add_(tensors, complex_scalar)
            self.assertEqual(res, tensors)

    @dtypes(*torch.testing.get_all_dtypes())
    def test_bool_scalar(self, device, dtype):
        tensors = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
        bool_scalar = True

        expected = [torch.ones(10, 10, device=device, dtype=dtype) for _ in range(10)]

        res = torch._foreach_add(tensors, bool_scalar)
        self.assertEqual(res, expected)

        torch._foreach_add_(tensors, bool_scalar)
        self.assertEqual(res, tensors)

    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_with_different_size_tensors(self, device, dtype):
        if dtype == torch.bool:
            return
        tensors = [torch.zeros(10 + n, 10 + n, device=device, dtype=dtype) for n in range(10)]
        expected = [torch.ones(10 + n, 10 + n, device=device, dtype=dtype) for n in range(10)]

        torch._foreach_add_(tensors, 1)
        self.assertEqual(expected, tensors)

    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
        # TODO: enable empty list case
        for tensors in [[torch.randn([0])]]:
            res = torch._foreach_add(tensors, 1)
            self.assertEqual(res, tensors)

            torch._foreach_add_(tensors, 1)
            self.assertEqual(res, tensors)

    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_scalar_with_overlapping_tensors(self, device, dtype):
        tensors = [torch.ones(1, 1, device=device, dtype=dtype).expand(2, 1, 3)]
        expected = [torch.tensor([[[2, 2, 2]], [[2, 2, 2]]], dtype=dtype, device=device)]

        # bool tensor + 1 will result in int64 tensor
        if dtype == torch.bool:
            expected[0] = expected[0].to(torch.int64).add(1)

        res = torch._foreach_add(tensors, 1)
        self.assertEqual(res, expected)

    def test_bin_op_scalar_with_different_tensor_dtypes(self, device):
        tensors = [torch.tensor([1.1], dtype=torch.float, device=device),
                   torch.tensor([1], dtype=torch.long, device=device)]
        self.assertRaises(RuntimeError, lambda: torch._foreach_add(tensors, 1))

    #
    # Ops with list
    #
    def test_add_list_error_cases(self, device):
        tensors1 = []
        tensors2 = []

        # Empty lists
        with self.assertRaises(RuntimeError):
            torch._foreach_add(tensors1, tensors2)
        with self.assertRaises(RuntimeError):
            torch._foreach_add_(tensors1, tensors2)

        # One empty list
        tensors1.append(torch.tensor([1], device=device))
        with self.assertRaisesRegex(RuntimeError, "Tensor list must have at least one tensor."):
            torch._foreach_add(tensors1, tensors2)
        with self.assertRaisesRegex(RuntimeError, "Tensor list must have at least one tensor."):
            torch._foreach_add_(tensors1, tensors2)

        # Lists have a different number of tensors
        tensors2.append(torch.tensor([1], device=device))
        tensors2.append(torch.tensor([1], device=device))
        with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"):
            torch._foreach_add(tensors1, tensors2)
        with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"):
            torch._foreach_add_(tensors1, tensors2)

        # Different dtypes
        tensors1 = [torch.zeros(10, 10, device=device, dtype=torch.float) for _ in range(10)]
        tensors2 = [torch.ones(10, 10, device=device, dtype=torch.int) for _ in range(10)]

        with self.assertRaisesRegex(RuntimeError, "All tensors in the tensor list must have the same dtype."):
            torch._foreach_add(tensors1, tensors2)
        with self.assertRaisesRegex(RuntimeError, "All tensors in the tensor list must have the same dtype."):
            torch._foreach_add_(tensors1, tensors2)

        # Different devices
        if torch.cuda.is_available() and torch.cuda.device_count() > 1:
            tensor1 = torch.zeros(10, 10, device="cuda:0")
            tensor2 = torch.ones(10, 10, device="cuda:1")
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
                torch._foreach_add([tensor1], [tensor2])
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
                torch._foreach_add_([tensor1], [tensor2])

        # Corresponding tensors with different sizes
        tensors1 = [torch.zeros(10, 10, device=device) for _ in range(10)]
        tensors2 = [torch.ones(11, 11, device=device) for _ in range(10)]
        with self.assertRaisesRegex(RuntimeError, "Corresponding tensors in lists must have the same size"):
            torch._foreach_add(tensors1, tensors2)
        with self.assertRaisesRegex(RuntimeError, r", got \[10, 10\] and \[11, 11\]"):
            torch._foreach_add_(tensors1, tensors2)

    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_list(self, device, dtype):
        self._test_bin_op_list(device, dtype, torch._foreach_add, torch._foreach_add_, torch.add)

    @dtypes(*torch.testing.get_all_dtypes())
    def test_sub_list(self, device, dtype):
        if dtype == torch.bool:
            with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool tensors is not supported."):
                self._test_bin_op_list(device, dtype, torch._foreach_sub, torch._foreach_sub_, torch.sub)
        else:
            self._test_bin_op_list(device, dtype, torch._foreach_sub, torch._foreach_sub_, torch.sub)

    @dtypes(*torch.testing.get_all_dtypes())
    def test_mul_list(self, device, dtype):
        self._test_bin_op_list(device, dtype, torch._foreach_mul, torch._foreach_mul_, torch.mul)

    @dtypes(*torch.testing.get_all_dtypes())
    def test_div_list(self, device, dtype):
        if dtype in torch.testing.integral_types_and(torch.bool):
            with self.assertRaisesRegex(RuntimeError, "Integer division of tensors using div or / is no longer"):
                self._test_bin_op_list(device, dtype, torch._foreach_div, torch._foreach_div_, torch.div)
            return

        self._test_bin_op_list(device, dtype, torch._foreach_div, torch._foreach_div_, torch.div)

    def test_bin_op_list_error_cases(self, device):
        tensors1 = []
        tensors2 = []

        for bin_op in self.bin_ops:
            # Empty lists
            with self.assertRaises(RuntimeError):
                bin_op(tensors1, tensors2)

            # One empty list
            tensors1.append(torch.tensor([1], device=device))
            with self.assertRaises(RuntimeError):
                bin_op(tensors1, tensors2)

            # Lists have a different number of tensors
            tensors2.append(torch.tensor([1], device=device))
            tensors2.append(torch.tensor([1], device=device))
            with self.assertRaises(RuntimeError):
                bin_op(tensors1, tensors2)

            # Different dtypes
            tensors1 = [torch.zeros(2, 2, device=device, dtype=torch.float) for _ in range(2)]
            tensors2 = [torch.ones(2, 2, device=device, dtype=torch.int) for _ in range(2)]

            with self.assertRaises(RuntimeError):
                bin_op(tensors1, tensors2)

    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_list_different_sizes(self, device, dtype):
        tensors1 = [torch.zeros(10 + n, 10 + n, device=device, dtype=dtype) for n in range(10)]
        tensors2 = [torch.ones(10 + n, 10 + n, device=device, dtype=dtype) for n in range(10)]

        res = torch._foreach_add(tensors1, tensors2)
        torch._foreach_add_(tensors1, tensors2)
        self.assertEqual(res, tensors1)
        self.assertEqual(res, [torch.ones(10 + n, 10 + n, device=device, dtype=dtype) for n in range(10)])

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_list_slow_path(self, device, dtype):
        # different strides
        tensor1 = torch.zeros(10, 10, device=device, dtype=dtype)
        tensor2 = torch.ones(10, 10, device=device, dtype=dtype)
        res = torch._foreach_add([tensor1], [tensor2.t()])
        torch._foreach_add_([tensor1], [tensor2])
        self.assertEqual(res, [tensor1])

        # non contiguous
        tensor1 = torch.randn(5, 2, 1, 3, device=device)[:, 0]
        tensor2 = torch.randn(5, 2, 1, 3, device=device)[:, 0]
        self.assertFalse(tensor1.is_contiguous())
        self.assertFalse(tensor2.is_contiguous())
        res = torch._foreach_add([tensor1], [tensor2])
        torch._foreach_add_([tensor1], [tensor2])
        self.assertEqual(res, [tensor1])

instantiate_device_type_tests(TestForeach, globals())

if __name__ == '__main__':
    run_tests()