mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Revert "_foreach_copy with different src/dst dtypes (#121717)"
This reverts commit da2a9a05127c2b44e447e734d99e727d856cb36f. Reverted https://github.com/pytorch/pytorch/pull/121717 on behalf of https://github.com/janeyx99 due to Causing IMAs on V100s internally :C ([comment](https://github.com/pytorch/pytorch/pull/121717#issuecomment-2025553295))
This commit is contained in:
@@ -838,20 +838,6 @@ class TestForeach(TestCase):
|
||||
copy_(t, s, non_blocking)
|
||||
self.assertEqual(ref_input, sample.input)
|
||||
|
||||
@onlyCUDA
|
||||
@ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
|
||||
def test_foreach_copy_with_multi_dtypes(self, device, dtype, op):
|
||||
# check (a) multi_tensor_apply is called and (b) numerical parity with for-loop and Tensor.copy_
|
||||
foreach_copy_ = ForeachFuncWrapper(op.inplace_variant)
|
||||
for sample in op.sample_inputs(device, dtype, noncontiguous=False):
|
||||
for src_dtype in floating_types_and(torch.half, torch.bfloat16):
|
||||
if src_dtype == dtype:
|
||||
continue
|
||||
self_tensors = [t.clone() for t in sample.input]
|
||||
src_tensors = [t.to(src_dtype) for t in self_tensors]
|
||||
out = foreach_copy_((self_tensors, src_tensors), is_cuda=True, expect_fastpath=True)
|
||||
self.assertEqual(out, [torch.empty_like(t).copy_(s) for t, s in zip(self_tensors, src_tensors)])
|
||||
|
||||
# Test reverse-mode & forward-mode AD if supported.
|
||||
@onlyCUDA
|
||||
@ops(
|
||||
|
Reference in New Issue
Block a user