import torch
from .Module import Module
from .utils import clear


class CosineDistance(Module):
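    """Computes the row-wise cosine similarity between the two 2-D tensors
    of an ``[input1, input2]`` pair:

        output[i] = <input1[i], input2[i]> / (||input1[i]|| * ||input2[i]||)

    This is the Python port of Lua Torch's ``nn.CosineDistance``; despite
    the name, it returns the cosine *similarity*, not a distance.
    """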

    def __init__(self):
        super(CosineDistance, self).__init__()
        # Two gradient slots, one per tensor in the [input1, input2] pair.
        self.gradInput = [torch.Tensor(), torch.Tensor()]

        # Lazily allocated contiguous copies of the inputs.
        self._input1 = None
        self._input2 = None
        # Scratch tensors reused across forward/backward calls.
        self.buffer = None
        self.w1 = None   # row-wise dot products <input1, input2>
        self.w22 = None  # 1 / ||input1||^2, per row
        self.w = None    # 1 / (||input1|| * ||input2||), per row
        self.w32 = None  # 1 / ||input2||^2, per row
        self.ones = None

    def _makeContiguous(self, input1, input2):
        # The buffer arithmetic below assumes contiguous inputs; copy into
        # the cached contiguous tensors only when needed.
        if not input1.is_contiguous():
            if self._input1 is None:
                self._input1 = input1.new()
            self._input1.resize_as_(input1).copy_(input1)
            input1 = self._input1

        if not input2.is_contiguous():
            if self._input2 is None:
                self._input2 = input2.new()
            self._input2.resize_as_(input2).copy_(input2)
            input2 = self._input2

        return input1, input2

    def updateOutput(self, input):
        input1, input2 = input[0], input[1]
        input1, input2 = self._makeContiguous(input1, input2)

        if self.buffer is None:
            self.buffer = input1.new()
            self.w1 = input1.new()
            self.w22 = input1.new()
            self.w = input1.new()
            self.w32 = input1.new()
            self.ones = input1.new()

        # w1[i] = <input1[i], input2[i]>
        torch.mul(input1, input2, out=self.buffer)
        torch.sum(self.buffer, 1, out=self.w1, keepdim=True)

        epsilon = 1e-12
        # w22[i] = 1 / (||input1[i]||^2 + eps)
        torch.mul(input1, input1, out=self.buffer)
        torch.sum(self.buffer, 1, out=self.w22, keepdim=True).add_(epsilon)
        self.w22.reciprocal_()
        self.w.resize_as_(self.w22).copy_(self.w22)

        # w32[i] = 1 / (||input2[i]||^2 + eps);
        # w[i] = 1 / (||input1[i]|| * ||input2[i]||)
        torch.mul(input2, input2, out=self.buffer)
        torch.sum(self.buffer, 1, out=self.w32, keepdim=True).add_(epsilon)
        self.w32.reciprocal_()
        self.w.mul_(self.w32)
        self.w.sqrt_()

        # output[i] = <input1[i], input2[i]> / (||input1[i]|| * ||input2[i]||)
        torch.mul(self.w1, self.w, out=self.output)
        self.output.resize_(input1.size(0))

        return self.output

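    # Backward pass, sketched: with cos = <v1, v2> / (||v1|| * ||v2||),
    #
    #   d cos / d v1 = v2 / (||v1|| * ||v2||) - cos * v1 / ||v1||^2
    #
    # The code below builds this as (v2 - (w1 * w22) * v1) * w, reusing the
    # buffers cached by updateOutput: w1 = <v1, v2>, w22 = 1 / ||v1||^2,
    # w = 1 / (||v1|| * ||v2||); d cos / d v2 is symmetric via w32.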
    def updateGradInput(self, input, gradOutput):
        v1 = input[0]
        v2 = input[1]
        v1, v2 = self._makeContiguous(v1, v2)

        if len(self.gradInput) != 2:
            if self.gradInput[0] is None:
                self.gradInput[0] = v1.new()
            if self.gradInput[1] is None:
                self.gradInput[1] = v1.new()
            self.gradInput = self.gradInput[:2]

        gw1 = self.gradInput[0]
        gw2 = self.gradInput[1]
        gw1.resize_as_(v1).copy_(v2)
        gw2.resize_as_(v1).copy_(v1)

        # gw1 = (v2 - (w1 * w22) * v1) * w
        torch.mul(self.w1, self.w22, out=self.buffer)
        gw1.addcmul_(-1, self.buffer.expand_as(v1), v1)
        gw1.mul_(self.w.expand_as(v1))

        # gw2 = (v1 - (w1 * w32) * v2) * w
        torch.mul(self.w1, self.w32, out=self.buffer)
        gw2.addcmul_(-1, self.buffer.expand_as(v1), v2)
        gw2.mul_(self.w.expand_as(v1))

        # Chain rule: scale both gradients by the incoming gradOutput,
        # broadcast across each row.
        go = gradOutput.contiguous().view(-1, 1).expand_as(v1)
        gw1.mul_(go)
        gw2.mul_(go)

        return self.gradInput

    def clearState(self):
        # Free the scratch tensors; they are re-created on the next forward.
        clear(self, [
            'buffer',
            'w1',
            'w22',
            'w',
            'w32',
            'ones',
        ])
        return super(CosineDistance, self).clearState()
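
# A minimal usage sketch (hypothetical, assuming the surrounding legacy
# ``torch.legacy.nn`` package is importable):
#
#     import torch
#     from torch.legacy import nn
#
#     m = nn.CosineDistance()
#     x1 = torch.randn(4, 10)
#     x2 = torch.randn(4, 10)
#     out = m.forward([x1, x2])                       # (4,) cosine similarities
#     gx1, gx2 = m.backward([x1, x2], torch.ones(4))  # gradients w.r.t. x1, x2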