pytorch/torch/legacy/nn/Normalize.py
Tongzhou Wang 3684cc4e52 cherry pick #9500 and #9590 into 0.4.1 (#9599)
* Use _six for inf and nan (#9500)

Summary:
Things like `float('inf')` are actually quite expensive.
```py
In [1]: import math

In [2]: %timeit -n 200 math.inf
49.3 ns ± 1.42 ns per loop (mean ± std. dev. of 7 runs, 200 loops each)

In [3]: %timeit -n 200 float('inf')
194 ns ± 39.1 ns per loop (mean ± std. dev. of 7 runs, 200 loops each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9500

Reviewed By: soumith

Differential Revision: D8876229

Pulled By: SsnL

fbshipit-source-id: 78602b76bb53d5588910b58270930c0bd413d2d7

* use int64_t for im2col
2018-07-19 17:33:33 -04:00
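
A quick stand-in for the `%timeit` comparison above, using only the standard library `timeit` module; exact numbers are machine-dependent, and the point is just the relative cost of looking up an existing constant versus re-parsing the string `'inf'`:

```py
# Rough reproduction of the benchmark idea from the commit summary.
# Numbers vary by machine; only the ratio matters.
import timeit

t_const = timeit.timeit("math.inf", setup="import math", number=1000000)
t_parse = timeit.timeit("float('inf')", number=1000000)

print("math.inf     : %.3f s per 1e6 lookups" % t_const)
print("float('inf') : %.3f s per 1e6 constructions" % t_parse)

# The file below applies the same idea: `from torch._six import inf` binds the
# constant once at import time, so hot comparisons such as `self.p == inf`
# avoid constructing float('inf') on every call.
```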

import torch
from torch._six import inf
from .Module import Module
from .utils import clear


class Normalize(Module):
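    """Normalizes each row of a 2D input to have unit L_p norm.

    For each sample ``x`` (a row of the input), the output is
    ``x / ||x||_p``, with a small ``eps`` folded into the norm computation
    to guard against division by zero for all-zero rows.
    """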

    def __init__(self, p, eps=1e-10):
        super(Normalize, self).__init__()
        assert p > 0
        self.p = p
        self.eps = eps
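
        # Work buffers, allocated lazily in updateOutput / updateGradInput and
        # reused across calls so hot paths avoid re-allocating temporaries.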
        self._output = None
        self.norm = None
        self.buffer = None
        self._indices = None
        self.normp = None
        self._gradInput = None
        self.cross = None
        self.buffer2 = None

    def updateOutput(self, input):
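        # Per-row normalization: norm = max(|input|) + eps when p == inf,
        # otherwise (sum(|input|^p) + eps)^(1/p); output = input / norm.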
        assert input.dim() == 2
        input_size = input.size()

        if self._output is None:
            self._output = input.new()
        if self.norm is None:
            self.norm = input.new()
        if self.buffer is None:
            self.buffer = input.new()

        self._output.resize_as_(input)

        # specialization for the infinity norm
        if self.p == inf:
            if not self._indices:
                self._indices = torch.cuda.FloatTensor() if torch.typename(self.output) == 'torch.cuda.FloatTensor' \
                    else torch.LongTensor()

            torch.abs(input, out=self.buffer)
            torch.max(self._indices, self.buffer, 1, out=self.norm, keepdim=True)
            self.norm.add_(self.eps)
        else:
            if self.normp is None:
                self.normp = input.new()
            if self.p % 2 != 0:
                torch.abs(input, out=self.buffer).pow_(self.p)
            else:
                torch.pow(input, self.p, out=self.buffer)

            torch.sum(self.buffer, 1, out=self.normp, keepdim=True).add_(self.eps)
            torch.pow(self.normp, 1. / self.p, out=self.norm)

        torch.div(input, self.norm.view(-1, 1).expand_as(input), out=self._output)

        self.output = self._output.view(input_size)
        return self.output

    def updateGradInput(self, input, gradOutput):
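        # Backprop through y = x / ||x||_p, evaluated per row as
        #   gradInput = gradOutput / ||x||_p
        #               - x * |x|^(p-2) * <x, gradOutput> / ||x||_p^(p+1)
        # reusing the norms cached by updateOutput; p == inf gets its own branch
        # since only the max-magnitude entry of each row carries norm gradient.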
        assert input.dim() == 2
        assert gradOutput.dim() == 2
        input_size = input.size()

        n = input.size(0)  # batch size
        d = input.size(1)  # dimensionality of vectors

        if self._gradInput is None:
            self._gradInput = input.new()
        if self.cross is None:
            self.cross = input.new()

        # compute diagonal term with gradOutput
        self._gradInput.resize_(n, d)
        if self.p == inf:
            # specialization for the inf case
            torch.mul(self.norm.view(n, 1, 1).expand(n, d, 1), gradOutput, out=self._gradInput)
            self.buffer.resize_as_(input).zero_()
            self.cross.resize_(n, 1)
            torch.gather(input, 1, self._indices, out=self.cross)
            self.cross.div_(self.norm)
            self.buffer.scatter_(1, self._indices, self.cross)
        else:
            torch.mul(self.normp.view(n, 1).expand(n, d), gradOutput, out=self._gradInput)

            # small optimizations for different p
            # buffer = input*|input|^(p-2)
            # for non-even p, need to add absolute value
            if self.p % 2 != 0:
                if self.p < 2:
                    # add eps to avoid possible division by 0
                    torch.abs(input, out=self.buffer).add_(self.eps).pow_(self.p - 2).mul_(input)
                else:
                    torch.abs(input, out=self.buffer).pow_(self.p - 2).mul_(input)
            # special case for p == 2, pow(x, 0) = 1
            elif self.p == 2:
                self.buffer.copy_(input)
            else:
                # p is even and > 2, pow(x, p) is always positive
                torch.pow(input, self.p - 2, out=self.buffer).mul_(input)

        # compute cross term in two steps
        self.cross.resize_(n, 1)

        # instead of having a huge temporary matrix (b1*b2),
        # do the computations as b1*(b2*gradOutput). This avoids redundant
        # computation and also a huge buffer of size n*d^2
        if self.buffer2 is None:
            self.buffer2 = input.new()  # nxd
        torch.mul(input, gradOutput, out=self.buffer2)
        torch.sum(self.buffer2, 1, out=self.cross, keepdim=True)

        self.buffer.mul_(self.cross.expand_as(self.buffer))
        self._gradInput.add_(-1, self.buffer)

        # reuse cross buffer for normalization
        if self.p == inf:
            torch.mul(self.norm, self.norm, out=self.cross)
        else:
            torch.mul(self.normp, self.norm, out=self.cross)

        self._gradInput.div_(self.cross.expand(n, d))

        self.gradInput = self._gradInput.view(input_size)
        return self.gradInput

    def __repr__(self):
        return super(Normalize, self).__repr__() + '({})'.format(self.p)

    def type(self, type, tensorCache=None):
        if not type:
            return self._type

        # torch.max expects a LongTensor as indices, whereas cutorch.max expects a CudaTensor.
        if type == 'torch.cuda.FloatTensor':
            super(Normalize, self).type(type, tensorCache)
        else:
            # self._indices must be a LongTensor. Setting it to None temporarily avoids
            # unnecessary memory allocations.
            indices, self._indices = self._indices, None
            super(Normalize, self).type(type, tensorCache)
            self._indices = indices.long() if indices else None

        return self

    def clearState(self):
        clear(self, [
            '_output',
            '_indices',
            '_gradInput',
            'buffer',
            'norm',
            'normp',
            'cross',
        ])
        return super(Normalize, self).clearState()