Stop warning on .names() access in max_pool2d and max_pool2d_backward (#60059)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60059

Fixes #60053.

The problem is that `.names()` always triggers the named-tensor warning, even
when the tensor has no names. To avoid the warning, the call has to be guarded
with `has_names()`:
`x.has_names() ? x.names() : DimnameList{}`
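
For reference, a minimal C++ sketch of that guard pattern as a standalone helper (illustration only, not the exact diff from this PR; the helper name `names_or_empty` is made up):

```cpp
// Sketch of the has_names() guard described above (illustration only,
// not the actual change in this PR). Builds against libtorch/ATen.
#include <ATen/ATen.h>

// t.names() on an unnamed tensor still emits the named-tensor warning
// described above, so only call it when names are actually present.
static at::DimnameList names_or_empty(const at::Tensor& t) {
  return t.has_names() ? t.names() : at::DimnameList{};
}
```

A call site that previously called `input.names()` unconditionally could then use `names_or_empty(input)` instead.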

This is not the first time this has happened; we should probably
make it so that `.names()` doesn't raise a warning unless the tensor is
actually populated with names. That's a little tricky to implement, so
I'm leaving it for the future.

Test Plan:
- New test; also run `python test/test_nn.py -v -k "max_pool"` and
confirm there are no warnings.

Reviewed By: gchanan

Differential Revision: D29152737

Pulled By: zou3519

fbshipit-source-id: 89a2fdbe6a6064a7044b5b75f7d0c58e51e57509
Author: Richard Zou
Date: 2021-06-17 10:33:08 -07:00
Committed by: Facebook GitHub Bot
Commit: ebafd2aadf (parent ef09428804)
2 changed files with 16 additions and 5 deletions

@@ -296,6 +296,15 @@ class TestNamedTensor(TestCase):
             check_tuple_return(F.max_pool2d_with_indices, [named_tensor_2d, [2, 2]], named_tensor_2d.names)
             check_tuple_return(F.max_pool3d_with_indices, [named_tensor_3d, [2, 2, 2]], named_tensor_3d.names)
 
+    def test_max_pooling_without_names_does_not_warn(self):
+        for device in torch.testing.get_all_device_types():
+            tensor_2d = torch.zeros(2, 3, 5, 7, device=device, requires_grad=True)
+            with warnings.catch_warnings(record=True) as warns:
+                warnings.simplefilter("always")
+                result = F.max_pool2d(tensor_2d, [2, 2])
+                result.sum().backward()
+                self.assertEqual(len(warns), 0)
+
     def test_no_save_support(self):
         named_tensor = torch.zeros(2, 3, names=('N', 'C'))
         buf = io.BytesIO()