Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Mark bucketize as not subject to autograd (#44102)
Summary: bucketize returns integer tensors, which currently trigger an internal assert in autograd. We therefore apply the same mechanism already used for argmax and other integer-valued functions, so that bucketize's output is marked as not requiring gradients.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44102

Reviewed By: zou3519

Differential Revision: D23500048

Pulled By: albanD

fbshipit-source-id: fdd869cd1feead6616b532b3e188bd5512adedea
Committed by: Facebook GitHub Bot
Parent: 91b0d1866a
Commit: 42f9897983
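For illustration, a minimal sketch of the behavior this change enforces: bucketize produces integer bucket indices, so its output is detached from autograd even when both inputs require grad. The shapes, values, and the steps argument below are arbitrary assumptions for the example, not taken from the patch.

    import torch

    # Boundaries must be sorted; linspace gives an ascending grid.
    bins = torch.linspace(0, 1.0, steps=100, requires_grad=True)
    vals = torch.rand(5, 5, requires_grad=True)

    # bucketize returns the index of the bucket each value falls into.
    out = torch.bucketize(vals, bins)

    print(out.dtype)          # torch.int64 (integer output)
    print(out.requires_grad)  # False: integer outputs never require grad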
@@ -4560,6 +4560,11 @@ for shape in [(1,), ()]:
         self.assertFalse(out.dtype.is_floating_point)
         self.assertFalse(out.requires_grad)
 
+        bins = torch.linspace(0, 1.0, requires_grad=True)
+        vals = torch.rand(5, 5, requires_grad=True)
+        out = torch.bucketize(vals, bins)
+        self.assertFalse(out.dtype.is_floating_point)
+        self.assertFalse(out.requires_grad)
 
 def index_variable(shape, max_indices):
     if not isinstance(shape, tuple):
@@ -138,7 +138,8 @@ DONT_REQUIRE_DERIVATIVE = {
     # Quantize functions should not record gradients
     'quantize_per_tensor', 'quantize_per_channel',
     # Functions that return integers should not have output that require gradients
-    'argmax', 'argmin', 'argsort', 'searchsorted'
+    'argmax', 'argmin', 'argsort', 'searchsorted',
+    'bucketize'
 }
 
 # Some operators invalidate the grad_accumulator. Let's reset it.
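The hunk above extends the DONT_REQUIRE_DERIVATIVE set consulted by the autograd code generator. Below is a rough, hypothetical simplification of the check this set enables; the helper output_requires_grad is invented for illustration and is not the generator's actual API.

    # Hypothetical simplification: ops listed here produce integer (or
    # quantized) outputs, so their results are excluded from autograd
    # regardless of whether their inputs require grad.
    DONT_REQUIRE_DERIVATIVE = {
        'quantize_per_tensor', 'quantize_per_channel',
        'argmax', 'argmin', 'argsort', 'searchsorted',
        'bucketize',
    }

    def output_requires_grad(op_name, inputs_require_grad):
        if op_name in DONT_REQUIRE_DERIVATIVE:
            return False
        return inputs_require_grad

    assert output_requires_grad('bucketize', True) is False
    assert output_requires_grad('add', True) is True

Note that the diff also adds a trailing comma after 'searchsorted'; without it, appending 'bucketize' on the next line would have silently concatenated the two adjacent string literals into 'searchsortedbucketize' instead of adding a new entry to the set.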