* Use _six for inf and nan (#9500)

  Summary: Things like `float('inf')` are actually quite expensive.

  ```py
  In [1]: import math

  In [2]: %timeit -n 200 math.inf
  49.3 ns ± 1.42 ns per loop (mean ± std. dev. of 7 runs, 200 loops each)

  In [3]: %timeit -n 200 float('inf')
  194 ns ± 39.1 ns per loop (mean ± std. dev. of 7 runs, 200 loops each)
  ```

  Pull Request resolved: https://github.com/pytorch/pytorch/pull/9500
  Reviewed By: soumith
  Differential Revision: D8876229
  Pulled By: SsnL
  fbshipit-source-id: 78602b76bb53d5588910b58270930c0bd413d2d7

* use int64_t for im2col
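The change itself is mechanical: literals like `float('inf')` are swapped for the constants exported by `torch._six`, as the import in this file shows. A minimal sketch of that substitution (the helper name is invented for illustration):

```py
from torch._six import inf, nan  # cached constants; no string parsing on each use

def is_unbounded(limit):
    # previously: return limit == float('inf'), which builds a new float every call
    return limit == inf
```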
156 lines · 5.4 KiB · Python
import torch
from torch._six import inf
from .Module import Module
from .utils import clear


class Normalize(Module):
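    """Normalize each row of a 2D input to unit L_p norm.

    A small ``eps`` keeps the norm away from zero; ``p == inf`` uses the
    maximum absolute value of each row as the norm.
    """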

    def __init__(self, p, eps=1e-10):
        super(Normalize, self).__init__()
        assert p > 0
        self.p = p
        self.eps = eps

        self._output = None
        self.norm = None
        self.buffer = None
        self._indices = None
        self.normp = None
        self._gradInput = None
        self.cross = None
        self.buffer2 = None

    def updateOutput(self, input):
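        """Divide each row of ``input`` by its eps-stabilized L_p norm (max-abs norm when ``p == inf``)."""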
        assert input.dim() == 2
        input_size = input.size()

        if self._output is None:
            self._output = input.new()
        if self.norm is None:
            self.norm = input.new()
        if self.buffer is None:
            self.buffer = input.new()

        self._output.resize_as_(input)

        # specialization for the infinity norm
        if self.p == inf:
            if not self._indices:
                self._indices = torch.cuda.FloatTensor() if torch.typename(self.output) == 'torch.cuda.FloatTensor' \
                    else torch.LongTensor()

            torch.abs(input, out=self.buffer)
            torch.max(self._indices, self.buffer, 1, out=self.norm, keepdim=True)
            self.norm.add_(self.eps)
        else:
            if self.normp is None:
                self.normp = input.new()
            if self.p % 2 != 0:
                torch.abs(input, out=self.buffer).pow_(self.p)
            else:
                torch.pow(input, self.p, out=self.buffer)

            torch.sum(self.buffer, 1, out=self.normp, keepdim=True).add_(self.eps)
            torch.pow(self.normp, 1. / self.p, out=self.norm)

        torch.div(input, self.norm.view(-1, 1).expand_as(input), out=self._output)

        self.output = self._output.view(input_size)
        return self.output

    def updateGradInput(self, input, gradOutput):
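        """Gradient of the row-wise L_p normalization.

        For finite ``p`` this computes, per row,
        ``gradInput = (normp * gradOutput - x*|x|^(p-2) * <x, gradOutput>) / (normp * norm)``
        with ``normp = sum(|x|^p) + eps`` and ``norm = normp**(1/p)``; for
        ``p == inf`` the cross term is routed only through the max-magnitude
        entry of each row.
        """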
        assert input.dim() == 2
        assert gradOutput.dim() == 2

        input_size = input.size()
        n = input.size(0)  # batch size
        d = input.size(1)  # dimensionality of vectors

        if self._gradInput is None:
            self._gradInput = input.new()
        if self.cross is None:
            self.cross = input.new()
        # compute diagonal term with gradOutput
        self._gradInput.resize_(n, d)
        if self.p == inf:
            # specialization for the inf case
            torch.mul(self.norm.view(n, 1, 1).expand(n, d, 1), gradOutput, out=self._gradInput)
            self.buffer.resize_as_(input).zero_()
            self.cross.resize_(n, 1)
            torch.gather(input, 1, self._indices, out=self.cross)
            self.cross.div_(self.norm)
            self.buffer.scatter_(1, self._indices, self.cross)
        else:
            torch.mul(self.normp.view(n, 1).expand(n, d), gradOutput, out=self._gradInput)
            # small optimizations for different p
            # buffer = input*|input|^(p-2)
            # for non-even p, need to add absolute value
            if self.p % 2 != 0:
                if self.p < 2:
                    # add eps to avoid possible division by 0
                    torch.abs(input, out=self.buffer).add_(self.eps).pow_(self.p - 2).mul_(input)
                else:
                    torch.abs(input, out=self.buffer).pow_(self.p - 2).mul_(input)
            # special case for p == 2, pow(x, 0) = 1
            elif self.p == 2:
                self.buffer.copy_(input)
            else:
                # p is even and > 2, pow(x, p) is always positive
                torch.pow(input, self.p - 2, out=self.buffer).mul_(input)

        # compute cross term in two steps
        self.cross.resize_(n, 1)

        # instead of having a huge temporary matrix (b1*b2), do
        # the computations as b1*(b2*gradOutput). This avoids redundant
        # computation and also a huge buffer of size n*d^2
        if self.buffer2 is None:
            self.buffer2 = input.new()  # nxd
        torch.mul(input, gradOutput, out=self.buffer2)
        torch.sum(self.buffer2, 1, out=self.cross, keepdim=True)

        self.buffer.mul_(self.cross.expand_as(self.buffer))
        self._gradInput.add_(-1, self.buffer)

        # reuse cross buffer for normalization
        if self.p == inf:
            torch.mul(self.norm, self.norm, out=self.cross)
        else:
            torch.mul(self.normp, self.norm, out=self.cross)

        self._gradInput.div_(self.cross.expand(n, d))

        self.gradInput = self._gradInput.view(input_size)
        return self.gradInput

    def __repr__(self):
        return super(Normalize, self).__repr__() + '({})'.format(self.p)

    def type(self, type, tensorCache=None):
        if not type:
            return self._type
        # torch.max expects a LongTensor as indices, whereas cutorch.max expects a CudaTensor.
        if type == 'torch.cuda.FloatTensor':
            super(Normalize, self).type(type, tensorCache)
        else:
            # self._indices must be a LongTensor. Setting it to None temporarily avoids
            # unnecessary memory allocations.
            indices, self._indices = self._indices, None
            super(Normalize, self).type(type, tensorCache)
            self._indices = indices.long() if indices else None

        return self

    def clearState(self):
        clear(self, [
            '_output',
            '_indices',
            '_gradInput',
            'buffer',
            'norm',
            'normp',
            'cross',
        ])
        return super(Normalize, self).clearState()
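For reference, a minimal usage sketch of this module, assuming it is loaded through the `torch.legacy.nn` package that ships these ported Lua modules (shapes and values are invented for illustration):

```py
import torch
from torch.legacy import nn

m = nn.Normalize(2)      # L2-normalize each row; p=inf would use the max-abs norm
x = torch.randn(4, 10)   # a batch of 4 ten-dimensional vectors
y = m.forward(x)         # legacy modules are driven via forward()/updateOutput(), not __call__
print(y.norm(2, 1))      # each row should now have roughly unit L2 norm
```

Backpropagation goes through `m.updateGradInput(x, gradOutput)` in the same hand-rolled style implemented above.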