mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: The tests were using the old args, which caused them to emit a lot of deprecation warnings. closes #9103. Reviewed By: ezyang Differential Revision: D8720581 Pulled By: li-roy fbshipit-source-id: 3b79527f6fe862fb48b99a6394e8d7b89fc7a8c8
37 lines
1.2 KiB
Python
37 lines
1.2 KiB
Python
import torch
|
|
from torch.nn.functional import _Reduction
|
|
from .Criterion import Criterion
|
|
|
|
|
|
class AbsCriterion(Criterion):
    """Criterion that delegates to the backend's ``AbsCriterion_*`` kernels.

    Presumably an absolute-error (L1) loss, judging by the kernel names; the
    arithmetic itself lives in the backend and is not visible here -- confirm
    against the backend implementation.

    ``sizeAverage`` is stored as given and converted to a reduction enum with
    ``_Reduction.legacy_get_enum`` on every forward and backward call.
    """

    def __init__(self, sizeAverage=True):
        super(AbsCriterion, self).__init__()
        # One-element tensor handed to the forward kernel; its first element
        # is read back as the loss in ``updateOutput``.
        self.output_tensor = torch.Tensor(1)
        self.sizeAverage = sizeAverage

    def updateOutput(self, input, target):
        """Run the backend forward kernel and return the loss as a Python scalar.

        The scalar is also stored on ``self.output``.
        """
        # Recreate the scratch tensor from ``input`` if it has been cleared.
        if self.output_tensor is None:
            self.output_tensor = input.new(1)

        backend = self._backend
        forward_kernel = backend.AbsCriterion_updateOutput
        state = backend.library_state
        scratch = self.output_tensor
        reduction = _Reduction.legacy_get_enum(
            self.sizeAverage, True, emit_warning=False
        )
        forward_kernel(state, input, target, scratch, reduction)

        loss = scratch[0].item()
        self.output = loss
        return loss

    def updateGradInput(self, input, target):
        """Run the backend backward kernel and return ``self.gradInput``.

        The kernel receives a one-element tensor of ones, cast to ``input``'s
        type, as the upstream gradient.
        """
        unit_grad = torch.ones(1).type_as(input)

        backend = self._backend
        backward_kernel = backend.AbsCriterion_updateGradInput
        state = backend.library_state
        grad_input = self.gradInput
        reduction = _Reduction.legacy_get_enum(
            self.sizeAverage, True, emit_warning=False
        )
        backward_kernel(state, input, target, unit_grad, grad_input, reduction)

        return self.gradInput
|