remove torch.equal usages (#89527)

Preparation for the next PR in this stack: #89559.

I replaced

- `self.assertTrue(torch.equal(...))` with `self.assertEqual(..., rtol=0, atol=0, exact_device=True)`,
- the same for `self.assertFalse(...)` with `self.assertNotEqual(...)`, and
- `assert torch.equal(...)` with `torch.testing.assert_close(..., rtol=0, atol=0)` (note that we don't need to set `check_device=True` here since that is the default).

There were a few instances where the result of `torch.equal` is used directly. In that cases I've replaced with `(... == ...).all().item()` while sometimes also dropping the `.item()` depending on the context.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89527
Approved by: https://github.com/mruberry
This commit is contained in:
Philip Meier
2022-12-01 09:28:03 +01:00
committed by PyTorch MergeBot
parent 0acbcef4ab
commit 4095ef8b80
38 changed files with 169 additions and 154 deletions

View File

@ -59,7 +59,7 @@ class TestAutocastCPU(TestCase):
# For example, lstm_cell returns a tuple and equal returns bool.
def compare(first, second):
if isinstance(first, torch.Tensor):
return torch.equal(first, second)
return (first == second).all().item()
elif isinstance(first, collections.abc.Iterable):
return all(compare(f, s) for f, s in zip(first, second))
else: