mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-31 04:04:57 +08:00
chunky sync again
This commit is contained in:
@ -106,7 +106,11 @@ class GradientChecker:
|
||||
outputs_with_grads
|
||||
)
|
||||
grad_estimate = np.zeros_like(inputs[input_to_check])
|
||||
assert grad_estimate.shape == grad.shape, input_to_check
|
||||
if grad_estimate.shape != grad.shape:
|
||||
raise Exception(
|
||||
"Mismatched gradient shapes: estimated ({}), grad ({})".format(
|
||||
grad_estimate.shape, grad.shape))
|
||||
|
||||
for current_dim in range(dims_to_check):
|
||||
# Positive gradient
|
||||
inputs[input_to_check].flat[current_dim] += self._stepsize
|
||||
|
||||
Reference in New Issue
Block a user