mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Stop warning on .names() access in max_pool2d and max_pool2d_backward (#60059)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60059 Fixes #60053. The problem is that `.names()` always triggers the named tensor warning. To not trigger it, one has to guard it with has_names: `x.has_names() ? x.names() : DimnameList{}` This is not the first time this has happened; we should probably make it so that .names() doesn't raise a warning unless it is actually populated with names. That's a little tricky to implement so I'm leaving it for the future. Test Plan: - New test, also run `python test/test_nn.py -v -k "max_pool"` and confirm there are no warnings. Reviewed By: gchanan Differential Revision: D29152737 Pulled By: zou3519 fbshipit-source-id: 89a2fdbe6a6064a7044b5b75f7d0c58e51e57509
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ef09428804
commit
ebafd2aadf
@ -296,6 +296,15 @@ class TestNamedTensor(TestCase):
|
||||
check_tuple_return(F.max_pool2d_with_indices, [named_tensor_2d, [2, 2]], named_tensor_2d.names)
|
||||
check_tuple_return(F.max_pool3d_with_indices, [named_tensor_3d, [2, 2, 2]], named_tensor_3d.names)
|
||||
|
||||
def test_max_pooling_without_names_does_not_warn(self):
|
||||
for device in torch.testing.get_all_device_types():
|
||||
tensor_2d = torch.zeros(2, 3, 5, 7, device=device, requires_grad=True)
|
||||
with warnings.catch_warnings(record=True) as warns:
|
||||
warnings.simplefilter("always")
|
||||
result = F.max_pool2d(tensor_2d, [2, 2])
|
||||
result.sum().backward()
|
||||
self.assertEqual(len(warns), 0)
|
||||
|
||||
def test_no_save_support(self):
|
||||
named_tensor = torch.zeros(2, 3, names=('N', 'C'))
|
||||
buf = io.BytesIO()
|
||||
|
||||
Reference in New Issue
Block a user