mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
73 lines
2.1 KiB
Python
73 lines
2.1 KiB
Python
import torch
|
|
from .Module import Module
|
|
|
|
|
|
class MM(Module):
    """Multiply a pair of matrices, or a pair of batches of matrices.

    The module takes a two-element ``input`` ``(a, b)``.  2-D operands are
    multiplied with ``torch.mm`` and 3-D operands with ``torch.bmm``; either
    operand can be transposed over its last two dimensions beforehand.

    NOTE(review): ``self.output`` is written by ``updateOutput`` but is never
    assigned in this class -- presumably the ``Module`` base class creates
    it; confirm against ``Module``.
    """

    def __init__(self, transA=False, transB=False):
        """Record the transpose flags and create the gradient buffers.

        Args:
            transA: when true, the first operand is transposed before the
                product is taken.
            transB: when true, the second operand is transposed before the
                product is taken.
        """
        super(MM, self).__init__()
        self.transA, self.transB = transA, transB
        # One gradient buffer per operand, in the same order as ``input``.
        self.gradInput = [torch.Tensor() for _ in range(2)]

    def updateOutput(self, input):
        """Compute ``op(a) @ op(b)`` into ``self.output`` and return it.

        Args:
            input: a two-element sequence ``(a, b)`` of tensors that are both
                2-D or both 3-D.

        Returns:
            ``self.output``, resized to the shape of the product.

        Raises:
            AssertionError: if ``input`` does not hold exactly two items, if
                ``a`` is neither 2-D nor 3-D, or if ``a`` and ``b`` differ in
                number of dimensions.
        """
        assert len(input) == 2
        a, b = input
        rank = a.ndimension()
        assert rank in (2, 3)
        assert a.dim() == b.dim()

        if rank == 2:
            # Plain matrix product.
            lhs = a.t() if self.transA else a
            rhs = b.t() if self.transB else b
            self.output.resize_(lhs.size(0), rhs.size(1))
            torch.mm(lhs, rhs, out=self.output)
            return self.output

        # Batched product: dimension 0 indexes the batch, so the transpose
        # swaps dimensions 1 and 2.
        lhs = a.transpose(1, 2) if self.transA else a
        rhs = b.transpose(1, 2) if self.transB else b
        self.output.resize_(lhs.size(0), lhs.size(1), rhs.size(2))
        torch.bmm(lhs, rhs, out=self.output)
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Compute the gradients of the product with respect to ``a`` and ``b``.

        Writing ``G`` for ``gradOutput`` and ``^T`` for a transpose of the
        two matrix dimensions, the buffers are filled as follows::

            transA  transB   gradInput[0]   gradInput[1]
            False   False    G   @ b^T      a^T @ G
            True    False    b   @ G^T      a   @ G
            False   True     G   @ b        G^T @ a
            True    True     b^T @ G^T      G^T @ a^T

        Args:
            input: the two-element sequence ``(a, b)`` given to the forward
                pass.
            gradOutput: gradient with respect to the module output; 2-D or
                3-D, with the same number of dimensions as ``a`` and ``b``.

        Returns:
            ``self.gradInput``: a two-element list whose entries are resized
            to match ``a`` and ``b`` respectively.

        Raises:
            AssertionError: if ``input`` does not hold exactly two items, if
                ``gradOutput`` is neither 2-D nor 3-D, or if the three
                tensors differ in number of dimensions.
        """
        # Re-create any buffer that has been reset to None, using the type of
        # the matching operand.
        for slot in (0, 1):
            if self.gradInput[slot] is None:
                self.gradInput[slot] = input[slot].new()

        assert len(input) == 2
        a, b = input
        grad_a, grad_b = self.gradInput[0], self.gradInput[1]
        grad_a.resize_as_(a)
        grad_b.resize_as_(b)

        grad_rank = gradOutput.ndimension()
        assert grad_rank in (2, 3)
        assert a.dim() == b.dim() == gradOutput.dim()

        # Select the matrix axes and the torch product routine by rank.
        if grad_rank == 2:
            row_axis, col_axis, op_name = 0, 1, "mm"
        else:
            row_axis, col_axis, op_name = 1, 2, "bmm"
        matmul = getattr(torch, op_name)

        # When both flags agree, the formulas in the table above use a and b
        # transposed; when they differ, the operands are used unchanged.
        if self.transA == self.transB:
            a = a.transpose(row_axis, col_axis)
            b = b.transpose(row_axis, col_axis)

        # Gradient with respect to the first operand.
        if self.transA:
            left, right = b, gradOutput.transpose(row_axis, col_axis)
        else:
            left, right = gradOutput, b
        matmul(left, right, out=grad_a)

        # Gradient with respect to the second operand.
        if self.transB:
            left, right = gradOutput.transpose(row_axis, col_axis), a
        else:
            left, right = a, gradOutput
        matmul(left, right, out=grad_b)

        return self.gradInput
|