mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Convert ForeachFuncInfo to dataclass (#125001)"
This reverts commit 9466335ae4cb049efd3f4c2b32b2115ba00694f3.
Reverted https://github.com/pytorch/pytorch/pull/125001 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think it is breaking on ROCm 9466335ae4 ([comment](https://github.com/pytorch/pytorch/pull/125001#issuecomment-2086640674))
This commit is contained in:
@ -164,22 +164,20 @@ class TestForeach(TestCase):
|
||||
wrapped_op, _, inplace_op, _ = self._get_funcs(op)
|
||||
|
||||
for sample in op.sample_zero_size_inputs(device, dtype):
|
||||
if op.method_variant is not None:
|
||||
if op.supports_out:
|
||||
wrapped_op(
|
||||
(sample.input, *sample.args),
|
||||
is_cuda=self.is_cuda,
|
||||
expect_fastpath=True,
|
||||
zero_size=True,
|
||||
)
|
||||
|
||||
if op.inplace_variant is not None:
|
||||
with InplaceForeachVersionBumpCheck(self, sample.input):
|
||||
inplace_op(
|
||||
(sample.input, *sample.args),
|
||||
is_cuda=self.is_cuda,
|
||||
expect_fastpath=True,
|
||||
zero_size=True,
|
||||
)
|
||||
with InplaceForeachVersionBumpCheck(self, sample.input):
|
||||
inplace_op(
|
||||
(sample.input, *sample.args),
|
||||
is_cuda=self.is_cuda,
|
||||
expect_fastpath=True,
|
||||
zero_size=True,
|
||||
)
|
||||
|
||||
@skipIfRocmVersionLessThan((6, 0))
|
||||
@ops(
|
||||
@ -1227,16 +1225,12 @@ class TestForeach(TestCase):
|
||||
"inplace", (False, True), name_fn=lambda x: "inplace" if x else "outplace"
|
||||
)
|
||||
def test_autodiff(self, device, dtype, op, inplace):
|
||||
if not (op.supports_autograd or op.supports_forward_ad):
|
||||
self.skipTest("neither reverse mode nor forward mode supported")
|
||||
if (not inplace) and not op.supports_out:
|
||||
self.skipTest("out-of-place not implemented")
|
||||
if inplace and op.has_no_in_place:
|
||||
self.skipTest("in-place not implemented")
|
||||
if not (
|
||||
op.supports_autograd
|
||||
or op.supports_inplace_autograd
|
||||
or op.supports_forward_ad
|
||||
):
|
||||
self.skipTest("neither reverse mode nor forward mode supported")
|
||||
|
||||
# note(crcrpar): without this, some unary functions fail, unlike inplace and/or complex.
|
||||
if (
|
||||
|
||||
Reference in New Issue
Block a user