Mark bucketize as not subject to autograd (#44102)

Summary:
bucketize returns integer tensors; when its inputs require gradients, differentiating through it currently trips an internal assert. This patch applies the existing mechanism for integer-valued ops (also used for argmax etc.): bucketize is added to the DONT_REQUIRE_DERIVATIVE list so that its output never requires gradients.
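
For illustration, a minimal sketch of the behavior this change enables (the shapes and the steps value are arbitrary; this mirrors the new test below):

import torch

# Both inputs take part in autograd, but bucket indices are integers,
# so the output must not be hooked into the autograd graph.
boundaries = torch.linspace(0, 1.0, steps=100, requires_grad=True)
values = torch.rand(5, 5, requires_grad=True)

out = torch.bucketize(values, boundaries)
print(out.dtype)          # torch.int64
print(out.requires_grad)  # False (previously this path hit an internal assert)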

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44102

Reviewed By: zou3519

Differential Revision: D23500048

Pulled By: albanD

fbshipit-source-id: fdd869cd1feead6616b532b3e188bd5512adedea
Author: Thomas Viehmann
Date: 2020-09-03 12:01:57 -07:00
Committed by: Facebook GitHub Bot
Parent: 91b0d1866a
Commit: 42f9897983

2 changed files with 7 additions and 1 deletion

@@ -4560,6 +4560,11 @@ for shape in [(1,), ()]:
         self.assertFalse(out.dtype.is_floating_point)
         self.assertFalse(out.requires_grad)
+        bins = torch.linspace(0, 1.0, requires_grad=True)
+        vals = torch.rand(5, 5, requires_grad=True)
+        out = torch.bucketize(vals, bins)
+        self.assertFalse(out.dtype.is_floating_point)
+        self.assertFalse(out.requires_grad)


 def index_variable(shape, max_indices):
     if not isinstance(shape, tuple):

@@ -138,7 +138,8 @@ DONT_REQUIRE_DERIVATIVE = {
     # Quantize functions should not record gradients
     'quantize_per_tensor', 'quantize_per_channel',
     # Functions that return integers should not have output that require gradients
-    'argmax', 'argmin', 'argsort', 'searchsorted'
+    'argmax', 'argmin', 'argsort', 'searchsorted',
+    'bucketize'
 }
 # Some operators invalidate the grad_accumulator. Let's reset it.
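
For context, a hypothetical sketch of how a skip list like DONT_REQUIRE_DERIVATIVE is consulted; should_require_grad and op_name are illustrative names, not the actual gen_variable_type.py API:

# Illustrative only: ops on the skip list never mark their outputs as
# requiring grad, so no grad_fn is attached and the assert cannot fire.
def should_require_grad(op_name: str) -> bool:
    return op_name not in DONT_REQUIRE_DERIVATIVE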