mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: The tests were using the old args, which caused them to emit a lot of deprecation warnings. closes #9103. Reviewed By: ezyang Differential Revision: D8720581 Pulled By: li-roy fbshipit-source-id: 3b79527f6fe862fb48b99a6394e8d7b89fc7a8c8
109 lines
3.8 KiB
Python
109 lines
3.8 KiB
Python
import math
|
|
import torch
|
|
from torch.nn.functional import _Reduction
|
|
from .MSECriterion import MSECriterion
|
|
|
|
"""
|
|
This file implements a criterion for multi-class classification.
|
|
It learns an embedding per class, where each class' embedding
|
|
is a point on an (N-1)-dimensional simplex, where N is
|
|
the number of classes.
|
|
For example usage of this class, look at doc/criterion.md
|
|
|
|
Reference: http://arxiv.org/abs/1506.08230
|
|
"""
|
|
|
|
|
|
class ClassSimplexCriterion(MSECriterion):
    """MSE-based criterion for multi-class classification.

    Each of the nClasses classes is assigned a fixed embedding: a vertex of a
    regular (nClasses-1)-simplex, zero-padded up to dimension nClasses.
    Training regresses the network output onto the embedding of the target
    class with an MSE loss.

    Reference: http://arxiv.org/abs/1506.08230
    """

    def __init__(self, nClasses):
        super(ClassSimplexCriterion, self).__init__()
        self.nClasses = nClasses

        # embedding the simplex in a space of dimension strictly greater than
        # the minimum possible (nClasses-1) is critical for effective training.
        simp = self._regsplex(nClasses - 1)
        self.simplex = torch.cat((simp, torch.zeros(simp.size(0), nClasses - simp.size(1))), 1)
        # buffer holding the per-sample target embeddings; resized per batch
        # in _transformTarget
        self._target = torch.Tensor(nClasses)

        # lazily allocated in updateOutput so it matches the input's type
        self.output_tensor = None

    def _regsplex(self, n):
        """Return the coordinates of the vertices of a regular simplex
        centered at the origin.

        The Euclidean norms of the vectors specifying the vertices are
        all equal to 1. The input n is the dimension of the vectors;
        the simplex has n + 1 vertices.

        input:
        n # dimension of the vectors specifying the vertices of the simplex

        output:
        a # tensor of size (n + 1, n) whose rows are vectors specifying
          # the vertices

        reference:
        http://en.wikipedia.org/wiki/Simplex#Cartesian_coordinates_for_regular_n-dimensional_simplex_in_Rn
        """
        a = torch.zeros(n + 1, n)

        for k in range(n):
            # determine the last nonzero entry in the vector for the k-th vertex
            if k == 0:
                a[k][k] = 1
            else:
                # unit norm: the remaining coordinate absorbs what the first
                # k coordinates of this row have not used
                a[k][k] = math.sqrt(1 - a[k:k + 1, 0:k + 1].norm() ** 2)

            # fill the k-th coordinate for the vectors of the remaining
            # vertices so that all pairwise dot products are equal (-1/n)
            c = (a[k][k] ** 2 - 1 - 1 / n) / a[k][k]
            a[k + 1:n + 2, k:k + 1].fill_(c)

        return a

    # NOTE(review): the upstream comment claimed this handles both 1D and 2D
    # targets, but the assert below restricts it to 1D class-index vectors.
    def _transformTarget(self, target):
        """Replace each class index in target with that class' simplex vertex,
        writing the result into the self._target buffer."""
        assert target.dim() == 1
        nSamples = target.size(0)
        self._target.resize_(nSamples, self.nClasses)
        for i in range(nSamples):
            self._target[i].copy_(self.simplex[int(target[i])])

    def updateOutput(self, input, target):
        """Compute the MSE between input and the target class embeddings.

        target is a 1D tensor of class indices; returns a Python scalar and
        caches it in self.output.
        """
        self._transformTarget(target)

        assert input.nelement() == self._target.nelement()
        if self.output_tensor is None:
            self.output_tensor = input.new(1)
        self._backend.MSECriterion_updateOutput(
            self._backend.library_state,
            input,
            self._target,
            self.output_tensor,
            # emit_warning=False: legacy call site, deprecation already known
            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
        )
        self.output = self.output_tensor[0].item()
        return self.output

    def updateGradInput(self, input, target):
        """Compute the gradient of the loss w.r.t. input.

        Reuses self._target as filled by the preceding updateOutput call;
        writes into and returns self.gradInput.
        """
        assert input.nelement() == self._target.nelement()
        implicit_gradOutput = torch.Tensor([1]).type(input.type())
        self._backend.MSECriterion_updateGradInput(
            self._backend.library_state,
            input,
            self._target,
            implicit_gradOutput,
            self.gradInput,
            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
        )
        return self.gradInput

    def getPredictions(self, input):
        """Return per-class scores: the dot product of each input row with
        each class' simplex vertex (shape: nSamples x nClasses)."""
        return torch.mm(input, self.simplex.t())

    def getTopPrediction(self, input):
        """Return the index of the highest-scoring class for each input row
        as a 1D tensor."""
        prod = self.getPredictions(input)
        # max over the last dimension yields (values, indices); keep indices
        _, maxs = prod.max(prod.ndimension() - 1)
        return maxs.view(-1)