correctly keep track of processed tensors for foreach reductions (#140103)
Fixes #140066

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140103
Approved by: https://github.com/janeyx99

Co-authored-by: Jane Xu <janeyx@meta.com>
commit 1cdaf1d85f
parent f3cbf67686
committed by PyTorch MergeBot
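For context, the fix makes the fused foreach reduction kernels keep an accurate count of processed tensors when empty tensors appear in the input list, so each result lands in the right output slot. A minimal sketch of the scenario the updated tests exercise (assuming a CUDA device; the empty-tensor indices {2, 5} are arbitrary illustration, not taken from the test):

import torch

# Intersperse empty tensors among regular ones; the fused kernel must not
# lose track of which tensor each partial result belongs to.
device = "cuda"
tensors = [
    torch.randn(2, 3, device=device) if i not in {2, 5} else torch.empty(0, device=device)
    for i in range(8)
]
fast = torch._foreach_norm(tensors, ord=2)
ref = [torch.linalg.vector_norm(t, ord=2) for t in tensors]
for got, expected in zip(fast, ref):
    torch.testing.assert_close(got, expected)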
test/test_foreach.py
@@ -1014,20 +1014,34 @@ class TestForeach(TestCase):
     @onlyCUDA
     @ops(foreach_reduce_op_db, allowed_dtypes=floating_types())
     @parametrize("use_cuda_graph", (False, True))
-    def test_big_num_tensors(self, device, dtype, op, use_cuda_graph):
+    @parametrize("w_empty", (False, True))
+    def test_big_num_tensors(self, device, dtype, op, use_cuda_graph, w_empty):
+        # foreach_max cannot handle empty tensors as max requires an identity
+        intersperse_empty_tensors = w_empty and op.name != "_foreach_max"
+
         N = 600
+        indices_with_empty_tensors = (
+            set()
+            if not intersperse_empty_tensors
+            else {200, 300, 301, 400, 401, 402, 404, 598}
+        )
         tensorlist = [
             make_tensor((2, 3), dtype=dtype, device=device, noncontiguous=False)
-            for _ in range(N)
+            if i not in indices_with_empty_tensors
+            else torch.empty(0, dtype=dtype, device=device)
+            for i in range(N)
         ]
         fn, ref_fn, *_ = self._get_funcs(op)
 
         import math
 
         if op.name == "_foreach_norm":
-            ords = (1, 2, math.inf)
+            ords = [1, 2]
+            if not intersperse_empty_tensors:
+                # inf norm over an empty tensor is not defined by vector norm as it expects an identity
+                ords.append(math.inf)
         else:
-            ords = (None,)
+            ords = [None]
 
         for ord in ords:
             kwargs = {"ord": ord} if ord else {}
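The hunk above drops math.inf from the tested ords whenever empty tensors are interspersed, per the in-diff comment that the inf norm has no identity on empty input. A small sketch of that behavior using the public vector-norm API (the exact error message may vary across versions):

import math
import torch

# max-style reductions have no identity element, so the inf norm of an
# empty tensor is undefined and vector_norm rejects it.
try:
    torch.linalg.vector_norm(torch.empty(0), ord=math.inf)
except RuntimeError as e:
    print(e)

# ord=1 and ord=2 are sum-style reductions with identity 0, so empty is fine.
print(torch.linalg.vector_norm(torch.empty(0), ord=2))  # tensor(0.)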
@@ -1055,20 +1069,28 @@ class TestForeach(TestCase):
     @onlyCUDA
     @ops(foreach_reduce_op_db)
-    def test_foreach_reduce_large_input(self, device, dtype, op):
-        # test inputs larger than kChunkSize = 65536
-        N = 65536 * 2
+    @parametrize("w_empty", (False, True))
+    def test_foreach_reduce_large_input(self, device, dtype, op, w_empty):
+        # test inputs larger than kChunkSize (65536) * max_num_blocks (320)
+        N = 65536 * 320 * 2
         disable_fastpath = False
         kwargs = {}
         if op.name == "_foreach_norm":
-            ord = 2
-            disable_fastpath = not (
-                ord in (1, 2)
-                and dtype in floating_types_and(torch.half, torch.bfloat16)
+            kwargs["ord"] = 2
+            disable_fastpath = dtype not in floating_types_and(
+                torch.half, torch.bfloat16
             )
-            kwargs["ord"] = ord
 
-        inputs = ([make_tensor((N,), dtype=dtype, device=device, noncontiguous=False)],)
+        tensorlist = [
+            make_tensor((N,), dtype=dtype, device=device, noncontiguous=False)
+        ]
+        # foreach_max cannot handle empty tensors as max over empty is undefined
+        if w_empty and op.name != "_foreach_max":
+            tensorlist += [
+                torch.empty(0, dtype=dtype, device=device),
+                make_tensor((N,), dtype=dtype, device=device, noncontiguous=False),
+            ]
+        inputs = (tensorlist,)
         wrapped_op, ref, _, _ = self._get_funcs(op)
         self.assertEqual(
             ref(inputs, **kwargs),
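The enlarged N makes each input span more chunks than a single launch's worth of blocks (per the diff's kChunkSize and max_num_blocks comment), which is where the processed-tensor bookkeeping fixed by this commit comes into play. A rough sketch of the same check through the public API (assumes a CUDA device with several hundred MB free; the empty tensor in the middle mirrors the w_empty=True case):

import torch

# Two inputs larger than kChunkSize (65536) * max_num_blocks (320), with an
# empty tensor between them, forcing the multi-launch reduction path.
device = "cuda"
N = 65536 * 320 * 2
tensorlist = [
    torch.randn(N, device=device),
    torch.empty(0, device=device),
    torch.randn(N, device=device),
]
fast = torch._foreach_norm(tensorlist, ord=2)
ref = [torch.linalg.vector_norm(t, ord=2) for t in tensorlist]
torch.testing.assert_close(fast, ref)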