Correctly keep track of processed tensors for foreach reductions (#140103)

Fixes #140066

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140103
Approved by: https://github.com/janeyx99

Co-authored-by: Jane Xu <janeyx@meta.com>
This commit is contained in:
Natalia Gimelshein
2024-11-08 23:04:51 +00:00
committed by PyTorch MergeBot
parent f3cbf67686
commit 1cdaf1d85f
3 changed files with 41 additions and 15 deletions

View File

@ -1014,20 +1014,34 @@ class TestForeach(TestCase):
@onlyCUDA
@ops(foreach_reduce_op_db, allowed_dtypes=floating_types())
@parametrize("use_cuda_graph", (False, True))
def test_big_num_tensors(self, device, dtype, op, use_cuda_graph):
@parametrize("w_empty", (False, True))
def test_big_num_tensors(self, device, dtype, op, use_cuda_graph, w_empty):
# foreach_max cannot handle empty tensors as max requires an identity
intersperse_empty_tensors = w_empty and op.name != "_foreach_max"
N = 600
indices_with_empty_tensors = (
set()
if not intersperse_empty_tensors
else {200, 300, 301, 400, 401, 402, 404, 598}
)
tensorlist = [
make_tensor((2, 3), dtype=dtype, device=device, noncontiguous=False)
for _ in range(N)
if i not in indices_with_empty_tensors
else torch.empty(0, dtype=dtype, device=device)
for i in range(N)
]
fn, ref_fn, *_ = self._get_funcs(op)
import math
if op.name == "_foreach_norm":
ords = (1, 2, math.inf)
ords = [1, 2]
if not intersperse_empty_tensors:
# inf norm over an empty tensor is not defined by vector norm as it expects an identity
ords.append(math.inf)
else:
ords = (None,)
ords = [None]
for ord in ords:
kwargs = {"ord": ord} if ord else {}
@ -1055,20 +1069,28 @@ class TestForeach(TestCase):
@onlyCUDA
@ops(foreach_reduce_op_db)
def test_foreach_reduce_large_input(self, device, dtype, op):
# test inputs larger than kChunkSize = 65536
N = 65536 * 2
@parametrize("w_empty", (False, True))
def test_foreach_reduce_large_input(self, device, dtype, op, w_empty):
# test inputs larger than kChunkSize (65536) * max_num_blocks (320)
N = 65536 * 320 * 2
disable_fastpath = False
kwargs = {}
if op.name == "_foreach_norm":
ord = 2
disable_fastpath = not (
ord in (1, 2)
and dtype in floating_types_and(torch.half, torch.bfloat16)
kwargs["ord"] = 2
disable_fastpath = dtype not in floating_types_and(
torch.half, torch.bfloat16
)
kwargs["ord"] = ord
inputs = ([make_tensor((N,), dtype=dtype, device=device, noncontiguous=False)],)
tensorlist = [
make_tensor((N,), dtype=dtype, device=device, noncontiguous=False)
]
# foreach_max cannot handle empty tensors as max over empty is undefined
if w_empty and op.name != "_foreach_max":
tensorlist += [
torch.empty(0, dtype=dtype, device=device),
make_tensor((N,), dtype=dtype, device=device, noncontiguous=False),
]
inputs = (tensorlist,)
wrapped_op, ref, _, _ = self._get_funcs(op)
self.assertEqual(
ref(inputs, **kwargs),