Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[pep8] Fix most lint automatically with autopep8
Here's the command I used to invoke autopep8 (in parallel!):

    git ls-files | grep '\.py$' | xargs -n1 -P`nproc` autopep8 -i

Several rules are ignored in setup.cfg. The goal is to let autopep8 handle everything it can handle safely, and to disable any rules that are tricky or controversial to address. We may want to come back and re-enable some of these rules later, but I'm trying to make this patch as safe as possible.

Also configures flake8 to match pep8's behavior. Also configures TravisCI to check the whole project for lint.
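As a rough illustration (not part of this commit, and assuming autopep8's fix_code API), the same class of whitespace-only rewrites can be reproduced on a single source string; the ignore list mirrors the one added to setup.cfg in this diff, and the sample line is taken from the diff below:

    # Illustrative sketch only; requires the autopep8 package, and the exact
    # output can vary slightly between autopep8 versions.
    import autopep8

    before = "dict(reference_fn=lambda i,_: i.clamp(-1, 1))\n"
    after = autopep8.fix_code(before, options={"max_line_length": 120,
                                               "ignore": ["E402", "E721", "E731"]})
    print(after)  # expected: dict(reference_fn=lambda i, _: i.clamp(-1, 1))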
@@ -44,4 +44,4 @@ matrix:
 python: "2.7"
 addons: true
 install: pip install pep8
-script: pep8 setup.py
+script: pep8
@@ -201,12 +201,13 @@ from docutils import nodes
 from sphinx.util.docfields import TypedField
 from sphinx import addnodes

+
 def patched_make_field(self, types, domain, items):
     # type: (List, unicode, Tuple) -> nodes.field
     def handle_item(fieldarg, content):
         par = nodes.paragraph()
         par += addnodes.literal_strong('', fieldarg)  # Patch: this line added
-        #par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
+        # par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
         #                            addnodes.literal_strong))
         if fieldarg in types:
             par += nodes.Text(' (')
@@ -1,2 +1,7 @@
 [pep8]
 max-line-length = 120
+ignore = E402,E721,E731
+
+[flake8]
+max-line-length = 120
+ignore = E305,E402,E721,E731,F401,F403,F405,F811,F812,F821,F841
@@ -12,6 +12,7 @@ from torch.autograd import Variable, Function

 torch.set_default_tensor_type('torch.DoubleTensor')

+
 def run_tests():
     parser = argparse.ArgumentParser(add_help=False)
     parser.add_argument('--seed', type=int, default=123)
@@ -29,6 +30,7 @@ try:
 except ImportError:
     TEST_NUMPY = False

+
 def get_cpu_type(t):
     assert t.__module__ == 'torch.cuda'
     return getattr(torch, t.__class__.__name__)
@@ -155,7 +157,7 @@ def make_jacobian(input, num_out):
         return torch.zeros(input.nelement(), num_out)
     else:
         return type(input)(filter(lambda x: x is not None,
-            (make_jacobian(elem, num_out) for elem in input)))
+                                  (make_jacobian(elem, num_out) for elem in input)))


 def iter_tensors(x, only_requiring_grad=False):
@@ -206,7 +208,7 @@ def get_numerical_jacobian(fn, input, target):
        outb.copy_(fn(input))
        flat_tensor[i] = orig

-       outb.add_(-1,outa).div_(2*perturbation)
+       outb.add_(-1, outa).div_(2 * perturbation)
        d_tensor[i] = outb

    return jacobian
@@ -25,14 +25,14 @@ module_tests = [
         module_name='Linear',
         constructor_args=(10, 8),
         input_size=(4, 10),
-        reference_fn=lambda i,p: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8)
+        reference_fn=lambda i, p: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8)
     ),
     dict(
         module_name='Linear',
         constructor_args=(10, 8, False),
         input_size=(4, 10),
         desc='no_bias',
-        reference_fn=lambda i,p: torch.mm(i, p[0].t())
+        reference_fn=lambda i, p: torch.mm(i, p[0].t())
     ),
     dict(
         module_name='Threshold',
@@ -72,7 +72,7 @@ module_tests = [
     dict(
         module_name='Hardtanh',
         input_size=(3, 2, 5),
-        reference_fn=lambda i,_: i.clamp(-1, 1)
+        reference_fn=lambda i, _: i.clamp(-1, 1)
     ),
     dict(
         module_name='Sigmoid',
@@ -85,22 +85,22 @@ module_tests = [
     dict(
         module_name='Softmax',
         input_size=(10, 20),
-        reference_fn=lambda i,_: torch.exp(i).div(torch.exp(i).sum(1).expand(10, 20))
+        reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1).expand(10, 20))
     ),
     dict(
         module_name='Softmax2d',
         input_size=(1, 3, 10, 20),
-        reference_fn=lambda i,_: torch.exp(i).div(torch.exp(i).sum(1).expand_as(i))
+        reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1).expand_as(i))
     ),
     dict(
         module_name='LogSoftmax',
         input_size=(10, 20),
-        reference_fn=lambda i,_: torch.exp(i).div_(torch.exp(i).sum(1).expand(10, 20)).log_()
+        reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1).expand(10, 20)).log_()
     ),
     dict(
         module_name='LogSoftmax',
         input_size=(1, 3, 10, 20),
-        reference_fn=lambda i,_: torch.exp(i).div_(torch.exp(i).sum(1).expand_as(i)).log_(),
+        reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1).expand_as(i)).log_(),
         desc='multiparam'
     ),
     dict(
@@ -130,18 +130,18 @@ module_tests = [
     dict(
         module_name='LogSigmoid',
         input_size=(2, 3, 4),
-        reference_fn=lambda i,_: i.sigmoid().log()
+        reference_fn=lambda i, _: i.sigmoid().log()
     ),
     dict(
         module_name='Softplus',
         input_size=(10, 20),
-        reference_fn=lambda i,_: torch.log(1 + torch.exp(i))
+        reference_fn=lambda i, _: torch.log(1 + torch.exp(i))
     ),
     dict(
         module_name='Softplus',
         constructor_args=(2,),
         input_size=(10, 20),
-        reference_fn=lambda i,_: 1. / 2. * torch.log(1 + torch.exp(2 * i)),
+        reference_fn=lambda i, _: 1. / 2. * torch.log(1 + torch.exp(2 * i)),
         desc='beta'
     ),
     dict(
@@ -172,7 +172,7 @@ module_tests = [
     dict(
         module_name='Softsign',
         input_size=(3, 2, 5),
-        reference_fn=lambda i,_: i.div(1 + torch.abs(i))
+        reference_fn=lambda i, _: i.div(1 + torch.abs(i))
     ),
     dict(
         module_name='Softmin',
@@ -187,11 +187,11 @@ module_tests = [

 criterion_tests = [
     dict(module_name='L1Loss',
-        input_size=(2, 3, 4),
-        target=torch.randn(2, 3, 4),
-        reference_fn=lambda i,t,_: 1./i.numel() * \
-            sum((a-b).abs().sum() for a,b in zip(i, t))
-        ),
+         input_size=(2, 3, 4),
+         target=torch.randn(2, 3, 4),
+         reference_fn=lambda i, t, _: 1. / i.numel() *
+         sum((a - b).abs().sum() for a, b in zip(i, t))
+         ),
     dict(
         module_name='NLLLoss',
         input=torch.rand(15, 10).log(),
@@ -213,7 +213,7 @@ criterion_tests = [
         module_name='MSELoss',
         input=torch.randn(2, 3, 4, 5),
         target=torch.randn(2, 3, 4, 5),
-        reference_fn=lambda i,t,_: (i-t).abs().pow(2).sum() / i.numel()
+        reference_fn=lambda i, t, _: (i - t).abs().pow(2).sum() / i.numel()
     ),
     dict(
         module_name='BCELoss',
@@ -370,9 +370,9 @@ class NNTestCase(TestCase):

             if jacobian_input:
                 for jacobian_x, d_x in zip(flat_jacobian_input, iter_tensors(d_input)):
-                    jacobian_x[:,i] = d_x
+                    jacobian_x[:, i] = d_x
             if jacobian_parameters:
-                jacobian_param[:,i] = torch.cat(self._flatten_tensors(d_param), 0)
+                jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0)

         res = tuple()
         if jacobian_input:
@@ -433,7 +433,7 @@ class NNTestCase(TestCase):
            fx1 = self._forward_criterion(criterion, input, target)
            x[i] = original - eps
            fx2 = self._forward_criterion(criterion, input, target)
-           deriv = (fx1 - fx2) / (2.*eps)
+           deriv = (fx1 - fx2) / (2. * eps)
            d_x[i] = deriv
            x[i] = original

@@ -447,8 +447,9 @@ class NNTestCase(TestCase):


 class TestBase(object):
+
     def __init__(self, constructor, constructor_args=tuple(), input_size=None,
-            input=None, desc='', reference_fn=None, fullname=None, **kwargs):
+                 input=None, desc='', reference_fn=None, fullname=None, **kwargs):
         if input_size is None and input is None:
             raise RuntimeError("Specify either an input tensor, or it's size!")
         self.constructor = constructor
@@ -496,6 +497,7 @@ class TestBase(object):


 class ModuleTest(TestBase):
+
     def __init__(self, *args, **kwargs):
         super(ModuleTest, self).__init__(*args, **kwargs)
         self.jacobian_input = kwargs.get('jacobian_input', True)
@@ -568,6 +570,7 @@ class ModuleTest(TestBase):


 class CriterionTest(TestBase):
+
     def __init__(self, *args, **kwargs):
         super(CriterionTest, self).__init__(*args, **kwargs)
         self.target = self._get_target(kwargs['target'])
@@ -590,7 +593,7 @@ class CriterionTest(TestBase):
         if isinstance(target, Variable):
             target = target.data
         expected_out = self.reference_fn(deepcopy(self._unpack_input(input)),
-                deepcopy(target), module)
+                                         deepcopy(target), module)
         test_case.assertEqual(out, expected_out)

         test_case.check_criterion_jacobian(module, input, self.target)
@@ -2,6 +2,7 @@ import torch.nn as nn


 class Net(nn.Module):
+
     def __init__(self):
         super(Net, self).__init__()
         self.linear = nn.Linear(10, 20)
@@ -2,6 +2,7 @@ import torch.nn as nn


 class Net(nn.Module):
+
     def __init__(self):
         super(Net, self).__init__()
         self.linear = nn.Linear(10, 20)
@@ -1,5 +1,6 @@
 import torch

+
 def check_error(desc, fn, *required_substrings):
     try:
         fn()
@@ -16,54 +17,55 @@ def check_error(desc, fn, *required_substrings):
     assert False, "given function ({}) didn't raise an error".format(desc)

 check_error(
-        'Wrong argument types',
-        lambda: torch.FloatStorage(object()),
-        'object')
+    'Wrong argument types',
+    lambda: torch.FloatStorage(object()),
+    'object')

 check_error('Unknown keyword argument',
-        lambda: torch.FloatStorage(content=1234.),
-        'keyword')
+            lambda: torch.FloatStorage(content=1234.),
+            'keyword')

 check_error('Invalid types inside a sequence',
-        lambda: torch.FloatStorage(['a', 'b']),
-        'list', 'str')
+            lambda: torch.FloatStorage(['a', 'b']),
+            'list', 'str')

 check_error('Invalid size type',
-        lambda: torch.FloatStorage(1.5),
-        'float')
+            lambda: torch.FloatStorage(1.5),
+            'float')

 check_error('Invalid offset',
-        lambda: torch.FloatStorage(torch.FloatStorage(2), 4),
-        '2', '4')
+            lambda: torch.FloatStorage(torch.FloatStorage(2), 4),
+            '2', '4')

 check_error('Negative offset',
-        lambda: torch.FloatStorage(torch.FloatStorage(2), -1),
-        '2', '-1')
+            lambda: torch.FloatStorage(torch.FloatStorage(2), -1),
+            '2', '-1')

 check_error('Invalid size',
-        lambda: torch.FloatStorage(torch.FloatStorage(3), 1, 5),
-        '2', '1', '5')
+            lambda: torch.FloatStorage(torch.FloatStorage(3), 1, 5),
+            '2', '1', '5')

 check_error('Negative size',
-        lambda: torch.FloatStorage(torch.FloatStorage(3), 1, -5),
-        '2', '1', '-5')
+            lambda: torch.FloatStorage(torch.FloatStorage(3), 1, -5),
+            '2', '1', '-5')

 check_error('Invalid index type',
-        lambda: torch.FloatStorage(10)['first item'],
-        'str')
+            lambda: torch.FloatStorage(10)['first item'],
+            'str')

+
 def assign():
     torch.FloatStorage(10)[1:-1] = '1'
 check_error('Invalid value type',
-        assign,
-        'str')
+            assign,
+            'str')

 check_error('resize_ with invalid type',
-        lambda: torch.FloatStorage(10).resize_(1.5),
-        'float')
+            lambda: torch.FloatStorage(10).resize_(1.5),
+            'float')

 check_error('fill_ with invalid type',
-        lambda: torch.IntStorage(10).fill_('asdf'),
-        'str')
+            lambda: torch.IntStorage(10).fill_('asdf'),
+            'str')

 # TODO: frombuffer
@@ -3,10 +3,12 @@ import torch
 import torch.legacy.optim as optim
 from pprint import pprint

+
 def rosenbrock(tensor):
     x, y = tensor
     return (1 - x)**2 + 100 * (y - x**2)**2

+
 def drosenbrock(tensor):
     x, y = tensor
     return torch.DoubleTensor((-400 * x * (y - x**2) - 2 * (1 - x), 200 * x * (y - x**2)))
@@ -8,7 +8,7 @@ from copy import deepcopy
 from collections import OrderedDict

 from common import make_jacobian, TestCase, iter_tensors, \
-        get_numerical_jacobian, run_tests
+    get_numerical_jacobian, run_tests
 from torch.autograd._functions import *
 from torch.autograd import Variable, Function

@@ -46,7 +46,7 @@ def get_analytical_jacobian(input, output):
         zero_gradients(input)
         output.backward(grad_output, retain_variables=True)
         for jacobian_x, d_x in zip(jacobian, iter_gradients(input)):
-            jacobian_x[:,i] = d_x
+            jacobian_x[:, i] = d_x

     return jacobian

@ -68,6 +68,7 @@ class TestAutograd(TestCase):
|
||||
y = Variable(torch.ones(5, 5) * 4, requires_grad=True)
|
||||
|
||||
counter = [0]
|
||||
|
||||
def bw_hook(inc, grad):
|
||||
self.assertIsInstance(grad, Variable)
|
||||
counter[0] += inc
|
||||
@ -103,6 +104,7 @@ class TestAutograd(TestCase):
|
||||
# WARNING: this is a test for autograd internals.
|
||||
# You should never have to use such things in your code.
|
||||
class NoneGradientFunction(Function):
|
||||
|
||||
def forward(self, x, y):
|
||||
assert self.needs_input_grad[0]
|
||||
assert not self.needs_input_grad[1]
|
||||
@ -114,6 +116,7 @@ class TestAutograd(TestCase):
|
||||
fn = NoneGradientFunction()
|
||||
fn._backward_hooks = OrderedDict()
|
||||
was_called = [False]
|
||||
|
||||
def hook(grad_input, grad_output):
|
||||
self.assertIsInstance(grad_input, tuple)
|
||||
self.assertIsInstance(grad_output, tuple)
|
||||
@ -242,6 +245,7 @@ class TestAutograd(TestCase):
|
||||
self.assertFalse(a.requires_grad)
|
||||
b = a + z
|
||||
self.assertTrue(b.requires_grad)
|
||||
|
||||
def error():
|
||||
raise RuntimeError
|
||||
# Make sure backward isn't called on these
|
||||
@ -379,6 +383,7 @@ class TestAutograd(TestCase):
|
||||
segfault.
|
||||
"""
|
||||
class CollectOnDelete(Function):
|
||||
|
||||
def __del__(self):
|
||||
gc.collect()
|
||||
|
||||
@ -386,7 +391,7 @@ class TestAutograd(TestCase):
|
||||
Variable(torch.randn(10, 10), creator=CollectOnDelete())
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available() or torch.cuda.device_count() < 2,
|
||||
"CUDA not available or <2 GPUs detected")
|
||||
"CUDA not available or <2 GPUs detected")
|
||||
def test_unused_output_gpu(self):
|
||||
from torch.nn.parallel._functions import Broadcast
|
||||
x = Variable(torch.randn(5, 5).float().cuda(), requires_grad=True)
|
||||
@ -436,6 +441,7 @@ class TestAutograd(TestCase):
|
||||
|
||||
def test_return_leaf(self):
|
||||
class Identity(Function):
|
||||
|
||||
def forward(self, a, b):
|
||||
return a, a + b
|
||||
|
||||
@ -443,6 +449,7 @@ class TestAutograd(TestCase):
|
||||
return grad_a + grad_b, grad_b
|
||||
|
||||
class Inplace(InplaceFunction):
|
||||
|
||||
def forward(self, a, b):
|
||||
self.mark_dirty(a)
|
||||
return a.add_(b), b + 2
|
||||
@ -464,6 +471,7 @@ class TestAutograd(TestCase):
|
||||
|
||||
def test_return_leaf_inplace(self):
|
||||
class Inplace(InplaceFunction):
|
||||
|
||||
def forward(self, a, b):
|
||||
self.mark_dirty(a)
|
||||
return a.add_(b), b + 2
|
||||
@ -496,51 +504,51 @@ class TestAutograd(TestCase):
|
||||
self.assertEqual(z.grad.data, torch.ones(5) * 2)
|
||||
|
||||
def test_backward_copy(self):
|
||||
# This tests checks backward engine for a very subtle bug that appreared
|
||||
# in one of the initial versions of autograd. Gradients tensors were
|
||||
# simply stored in lists while the function waited for all its gradients
|
||||
# to be computed. However, sometimes an output was used multiple times,
|
||||
# so the gradients needed to be summed. Engine used to keep a need_copy
|
||||
# set of tensors that will need a clone upon next addition and removed
|
||||
# them from the set as soon as the clone was performed. However, this
|
||||
# could lead to incorrect results if the same gradient tensor was
|
||||
# buffered in three places in the graph:
|
||||
# 1. When accumulating gradients in one of these places it was cloned
|
||||
# and removed from need_copy set.
|
||||
# 2. When accumulating in second place, it wasn't in the need_copy set,
|
||||
# so the gradients were simply accumulated in-place (which already
|
||||
# modified the grad in 3rd place)
|
||||
# 3. When accumulating in the third place, it wasn't in the need_copy set
|
||||
# as well, so the incoming gradient was summed in-place, yielding
|
||||
# incorrect results in all functions, except the first one.
|
||||
x = Variable(torch.ones(5, 5), requires_grad=True)
|
||||
y = Variable(torch.ones(5, 5), requires_grad=True)
|
||||
# Simulate that we're in the middle of the graph
|
||||
a = x + 2
|
||||
b = y + 2
|
||||
c = x + 2
|
||||
# This op will just return grad_output two times in backward
|
||||
add1 = a + b
|
||||
add2 = add1 + c
|
||||
# Simulate a long branch, so grad_output will get buffered.
|
||||
for i in range(4):
|
||||
a = a * 2
|
||||
b = b * 2
|
||||
c = c * 2
|
||||
branch = a + b + c
|
||||
out = add2 + branch
|
||||
# expected gradients are:
|
||||
# for x: 34 (16 from final a, 16 from final c, 2 from add2)
|
||||
# for y: 17 (16 from final b, 1 from add2)
|
||||
grad_output = torch.ones(5, 5)
|
||||
out.backward(grad_output)
|
||||
self.assertEqual(x.grad.data, torch.ones(5, 5) * 34)
|
||||
self.assertEqual(y.grad.data, torch.ones(5, 5) * 17)
|
||||
# This tests checks backward engine for a very subtle bug that appreared
|
||||
# in one of the initial versions of autograd. Gradients tensors were
|
||||
# simply stored in lists while the function waited for all its gradients
|
||||
# to be computed. However, sometimes an output was used multiple times,
|
||||
# so the gradients needed to be summed. Engine used to keep a need_copy
|
||||
# set of tensors that will need a clone upon next addition and removed
|
||||
# them from the set as soon as the clone was performed. However, this
|
||||
# could lead to incorrect results if the same gradient tensor was
|
||||
# buffered in three places in the graph:
|
||||
# 1. When accumulating gradients in one of these places it was cloned
|
||||
# and removed from need_copy set.
|
||||
# 2. When accumulating in second place, it wasn't in the need_copy set,
|
||||
# so the gradients were simply accumulated in-place (which already
|
||||
# modified the grad in 3rd place)
|
||||
# 3. When accumulating in the third place, it wasn't in the need_copy set
|
||||
# as well, so the incoming gradient was summed in-place, yielding
|
||||
# incorrect results in all functions, except the first one.
|
||||
x = Variable(torch.ones(5, 5), requires_grad=True)
|
||||
y = Variable(torch.ones(5, 5), requires_grad=True)
|
||||
# Simulate that we're in the middle of the graph
|
||||
a = x + 2
|
||||
b = y + 2
|
||||
c = x + 2
|
||||
# This op will just return grad_output two times in backward
|
||||
add1 = a + b
|
||||
add2 = add1 + c
|
||||
# Simulate a long branch, so grad_output will get buffered.
|
||||
for i in range(4):
|
||||
a = a * 2
|
||||
b = b * 2
|
||||
c = c * 2
|
||||
branch = a + b + c
|
||||
out = add2 + branch
|
||||
# expected gradients are:
|
||||
# for x: 34 (16 from final a, 16 from final c, 2 from add2)
|
||||
# for y: 17 (16 from final b, 1 from add2)
|
||||
grad_output = torch.ones(5, 5)
|
||||
out.backward(grad_output)
|
||||
self.assertEqual(x.grad.data, torch.ones(5, 5) * 34)
|
||||
self.assertEqual(y.grad.data, torch.ones(5, 5) * 17)
|
||||
|
||||
def test_functional_blas(self):
|
||||
def compare(fn, *args):
|
||||
unpacked_args = tuple(arg.data if isinstance(arg, Variable) else arg
|
||||
for arg in args)
|
||||
for arg in args)
|
||||
self.assertEqual(fn(*args).data, fn(*unpacked_args))
|
||||
|
||||
def test_blas_add(fn, x, y, z):
|
||||
@ -553,27 +561,29 @@ class TestAutograd(TestCase):
|
||||
compare(fn, x, y)
|
||||
|
||||
test_blas(torch.mm, Variable(torch.randn(2, 10)),
|
||||
Variable(torch.randn(10, 4)))
|
||||
Variable(torch.randn(10, 4)))
|
||||
test_blas_add(torch.addmm, Variable(torch.randn(2, 4)),
|
||||
Variable(torch.randn(2, 10)), Variable(torch.randn(10, 4)))
|
||||
Variable(torch.randn(2, 10)), Variable(torch.randn(10, 4)))
|
||||
test_blas(torch.bmm, Variable(torch.randn(4, 2, 10)),
|
||||
Variable(torch.randn(4, 10, 4)))
|
||||
Variable(torch.randn(4, 10, 4)))
|
||||
test_blas_add(torch.addbmm, Variable(torch.randn(2, 4)),
|
||||
Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
|
||||
Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
|
||||
test_blas_add(torch.baddbmm, Variable(torch.randn(4, 2, 4)),
|
||||
Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
|
||||
Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
|
||||
test_blas(torch.mv, Variable(torch.randn(2, 10)),
|
||||
Variable(torch.randn(10)))
|
||||
Variable(torch.randn(10)))
|
||||
test_blas_add(torch.addmv, Variable(torch.randn(2)),
|
||||
Variable(torch.randn(2, 10)), Variable(torch.randn(10)))
|
||||
Variable(torch.randn(2, 10)), Variable(torch.randn(10)))
|
||||
test_blas(torch.ger, Variable(torch.randn(5)),
|
||||
Variable(torch.randn(6)))
|
||||
Variable(torch.randn(6)))
|
||||
test_blas_add(torch.addr, Variable(torch.randn(5, 6)),
|
||||
Variable(torch.randn(5)), Variable(torch.randn(6)))
|
||||
Variable(torch.randn(5)), Variable(torch.randn(6)))
|
||||
|
||||
def test_save_none_for_backward(self):
|
||||
test_case = self
|
||||
|
||||
class MyFn(Function):
|
||||
|
||||
def forward(self, input):
|
||||
self.save_for_backward(None, input, None)
|
||||
return input * input
|
||||
@ -591,6 +601,7 @@ class TestAutograd(TestCase):
|
||||
|
||||
def test_too_many_grads(self):
|
||||
class MyFn(Function):
|
||||
|
||||
def forward(self, input):
|
||||
return input
|
||||
|
||||
@ -679,6 +690,7 @@ class TestAutograd(TestCase):
|
||||
|
||||
def test_dep_nograd(self):
|
||||
class F1(Function):
|
||||
|
||||
def forward(self, input):
|
||||
out = torch.randn(input.size())
|
||||
self.mark_non_differentiable(out)
|
||||
@ -688,6 +700,7 @@ class TestAutograd(TestCase):
|
||||
return grad_output
|
||||
|
||||
class F2(Function):
|
||||
|
||||
def forward(self, input, ignored):
|
||||
return input
|
||||
|
||||
@ -710,6 +723,7 @@ def index_variable(shape, max_indices):
|
||||
index = torch.rand(*shape).mul_(max_indices).floor_().long()
|
||||
return Variable(index, requires_grad=False)
|
||||
|
||||
|
||||
def gather_variable(shape, index_dim, max_indices):
|
||||
assert len(shape) == 2
|
||||
assert index_dim < 2
|
||||
@ -717,7 +731,7 @@ def gather_variable(shape, index_dim, max_indices):
|
||||
index = torch.LongTensor(*shape)
|
||||
for i in range(shape[index_dim]):
|
||||
index.select(index_dim, i).copy_(
|
||||
torch.randperm(max_indices)[:shape[batch_dim]])
|
||||
torch.randperm(max_indices)[:shape[batch_dim]])
|
||||
return Variable(index, requires_grad=False)
|
||||
|
||||
|
||||
@ -725,215 +739,215 @@ L = 20
|
||||
M = 10
|
||||
S = 5
|
||||
function_tests = [
|
||||
(Add, (), ((M, M), (M, M)) ),
|
||||
(Sub, (), ((M, M), (M, M)) ),
|
||||
(Mul, (), ((M, M), (M, M)) ),
|
||||
(Div, (), ((M, M), torch.rand(M, M) + 5e-2) ),
|
||||
(Pow, (), (torch.rand(M, M) + 1e-3, torch.rand(M, M) + 0.1)),
|
||||
(AddConstant, (3.14,), ((L, L),) ),
|
||||
(SubConstant, (3.14,), ((L, L),) ),
|
||||
(SubConstant, (3.14, True), ((L, L),), 'from_tensor' ),
|
||||
(MulConstant, (3.14,), ((L, L),) ),
|
||||
(DivConstant, (3.14, True), (torch.rand(L, L) + 1e-1,), 'by_tensor' ),
|
||||
(PowConstant, (3.14,), (torch.rand(L, L),) ),
|
||||
(PowConstant, (3.14, True), (torch.rand(L, L),), 'tensor_power' ),
|
||||
(Transpose, (0, 1), (torch.rand(L, L),) ),
|
||||
(Transpose, (2, 0), (torch.rand(S, S, S),), '3d' ),
|
||||
(Permute, ((0, 4, 3, 5, 1, 2),), ((1, 2, 3, 4, 5, 6),) ),
|
||||
(Index, ((1, 2),), (torch.rand(S, S, S),) ),
|
||||
(Index, (slice(0, 3),), (torch.rand(S, S, S),), 'slice' ),
|
||||
(Index, ((slice(0, 3), 1),),(torch.rand(S, S, S),), 'slice_index' ),
|
||||
(View, (S*S, S), (torch.rand(S, S, S),) ),
|
||||
(Expand, ((S, 5, S, 5),), ((S, 1, S, 1),) ),
|
||||
(Exp, (), (torch.rand(S, S, S),) ),
|
||||
(Log, (), (torch.rand(S, S, S) + 1e-2,) ),
|
||||
(Log1p, (), (torch.rand(S, S, S),) ),
|
||||
(Tanh, (), ((S, S, S),) ),
|
||||
(Sigmoid, (), ((S, S, S),) ),
|
||||
(Sinh, (), ((S, S, S),) ),
|
||||
(Cosh, (), ((S, S, S),) ),
|
||||
(Abs, (), ((S, S, S),) ),
|
||||
(Clamp, (0, 1), ((S, S, S),) ),
|
||||
(Sqrt, (), (torch.rand(S, S, S) + 5e-4,) ),
|
||||
(Sin, (), ((S, S, S),) ),
|
||||
(Cos, (), ((S, S, S),) ),
|
||||
(Tan, (), (torch.randn(S, S, S).clamp(-1, 1),) ),
|
||||
(Asin, (), (torch.randn(S, S, S).clamp(-0.9, 0.9),) ),
|
||||
(Acos, (), (torch.randn(S, S, S).clamp(-0.9, 0.9),) ),
|
||||
(Atan, (), ((S, S, S),) ),
|
||||
(Reciprocal, (), (torch.rand(S, S, S) + 0.1,) ),
|
||||
(Cmax, (), ((S, S, S), (S, S, S)) ),
|
||||
(Cmin, (), ((S, S, S), (S, S, S)) ),
|
||||
(Round, (), ((S, S, S),) ),
|
||||
(Sign, (), ((S, S, S),) ),
|
||||
(Trunc, (), ((S, S, S),) ),
|
||||
(Floor, (), ((S, S, S),) ),
|
||||
(Ceil, (), ((S, S, S),) ),
|
||||
(Frac, (), ((S, S, S),) ),
|
||||
(Fmod, (1.5,), ((S, S, S),) ),
|
||||
(Lerp, (0.2,), ((S, S, S), (S, S, S)) ),
|
||||
(Rsqrt, (), (torch.rand(S, S, S) + 1e-2,) ),
|
||||
(Remainder, (1.5,), ((S, S, S),) ),
|
||||
(CmaxConstant, (0.5,), ((S, S, S),) ),
|
||||
(CminConstant, (0.5,), ((S, S, S),) ),
|
||||
(Mean, (), ((S, S, S),) ),
|
||||
(Mean, (1,), ((S, S, S),), 'dim' ),
|
||||
(Sum, (), ((S, S, S),) ),
|
||||
(Sum, (1,), ((S, S, S),), 'dim' ),
|
||||
(Prod, (), ((S, S, S),) ),
|
||||
(Prod, (1,), ((S, S, S),), 'dim' ),
|
||||
(Addmm, (), ((S, M), (S, S), (S, M)), ),
|
||||
(Addmm, (0.1, 1), ((S, M), (S, S), (S, M)), 'coef' ),
|
||||
(Addbmm, (), ((S, M), (S, S, S), (S, S, M)), ),
|
||||
(Addbmm, (0.1, 0.4), ((S, M), (S, S, S), (S, S, M)), 'coef' ),
|
||||
(Baddbmm, (), ((S, S, M), (S, S, S), (S, S, M)), ),
|
||||
(Baddbmm, (0.1, 0.4), ((S, S, M), (S, S, S), (S, S, M)), 'coef' ),
|
||||
(Addmv, (), ((S,), (S, M), (M,)), ),
|
||||
(Addmv, (0.1, 0.4), ((S,), (S, M), (M,)), 'coef' ),
|
||||
(Addr, (), ((S, M), (S,), (M,)), ),
|
||||
(Addr, (0.1, 0.4), ((S, M), (S,), (M,)), 'coef' ),
|
||||
(Dot, (), ((L,), (L,)), ),
|
||||
(Max, (), ((S, S, S),), ),
|
||||
(Min, (), ((S, S, S),), ),
|
||||
(Max, (0,), ((S, S, S),), 'dim' ),
|
||||
(Min, (0,), ((S, S, S),), 'dim' ),
|
||||
(Mode, (0,), ((S, S, S),), ),
|
||||
(Kthvalue, (2, 0), ((S, S, S),), ),
|
||||
(Median, (0,), ((S, S, S),), ),
|
||||
(Norm, (1.5,), (torch.rand(S, S, S),), '1_5' ),
|
||||
(Norm, (), ((S, S, S),), '2' ),
|
||||
(Norm, (3,), ((S, S, S),), '3' ),
|
||||
(Norm, (1.5, 0), (torch.rand(S, S, S),), '1_5_dim' ),
|
||||
(Norm, (2, 0), ((S, S, S),), '2_dim' ),
|
||||
(Norm, (3, 0), ((S, S, S),), '3_dim' ),
|
||||
(Addcmul, (), ((S, S), (S, S), (S, S)) ),
|
||||
(Addcmul, (0.6,), ((S, S), (S, S), (S, S)), 'scale' ),
|
||||
(Addcdiv, (), ((S, S), (S, S), torch.rand(S, S) + 1e-2) ),
|
||||
(Addcdiv, (0.6,), ((S, S), (S, S), torch.rand(S, S) + 1e-2), 'scale'),
|
||||
(IndexAdd, (0,), ((S, S), index_variable(2, S), (2, S)) ),
|
||||
(Add, (), ((M, M), (M, M))),
|
||||
(Sub, (), ((M, M), (M, M))),
|
||||
(Mul, (), ((M, M), (M, M))),
|
||||
(Div, (), ((M, M), torch.rand(M, M) + 5e-2)),
|
||||
(Pow, (), (torch.rand(M, M) + 1e-3, torch.rand(M, M) + 0.1)),
|
||||
(AddConstant, (3.14,), ((L, L),)),
|
||||
(SubConstant, (3.14,), ((L, L),)),
|
||||
(SubConstant, (3.14, True), ((L, L),), 'from_tensor'),
|
||||
(MulConstant, (3.14,), ((L, L),)),
|
||||
(DivConstant, (3.14, True), (torch.rand(L, L) + 1e-1,), 'by_tensor'),
|
||||
(PowConstant, (3.14,), (torch.rand(L, L),)),
|
||||
(PowConstant, (3.14, True), (torch.rand(L, L),), 'tensor_power'),
|
||||
(Transpose, (0, 1), (torch.rand(L, L),)),
|
||||
(Transpose, (2, 0), (torch.rand(S, S, S),), '3d'),
|
||||
(Permute, ((0, 4, 3, 5, 1, 2),), ((1, 2, 3, 4, 5, 6),)),
|
||||
(Index, ((1, 2),), (torch.rand(S, S, S),)),
|
||||
(Index, (slice(0, 3),), (torch.rand(S, S, S),), 'slice'),
|
||||
(Index, ((slice(0, 3), 1),), (torch.rand(S, S, S),), 'slice_index'),
|
||||
(View, (S * S, S), (torch.rand(S, S, S),)),
|
||||
(Expand, ((S, 5, S, 5),), ((S, 1, S, 1),)),
|
||||
(Exp, (), (torch.rand(S, S, S),)),
|
||||
(Log, (), (torch.rand(S, S, S) + 1e-2,)),
|
||||
(Log1p, (), (torch.rand(S, S, S),)),
|
||||
(Tanh, (), ((S, S, S),)),
|
||||
(Sigmoid, (), ((S, S, S),)),
|
||||
(Sinh, (), ((S, S, S),)),
|
||||
(Cosh, (), ((S, S, S),)),
|
||||
(Abs, (), ((S, S, S),)),
|
||||
(Clamp, (0, 1), ((S, S, S),)),
|
||||
(Sqrt, (), (torch.rand(S, S, S) + 5e-4,)),
|
||||
(Sin, (), ((S, S, S),)),
|
||||
(Cos, (), ((S, S, S),)),
|
||||
(Tan, (), (torch.randn(S, S, S).clamp(-1, 1),)),
|
||||
(Asin, (), (torch.randn(S, S, S).clamp(-0.9, 0.9),)),
|
||||
(Acos, (), (torch.randn(S, S, S).clamp(-0.9, 0.9),)),
|
||||
(Atan, (), ((S, S, S),)),
|
||||
(Reciprocal, (), (torch.rand(S, S, S) + 0.1,)),
|
||||
(Cmax, (), ((S, S, S), (S, S, S))),
|
||||
(Cmin, (), ((S, S, S), (S, S, S))),
|
||||
(Round, (), ((S, S, S),)),
|
||||
(Sign, (), ((S, S, S),)),
|
||||
(Trunc, (), ((S, S, S),)),
|
||||
(Floor, (), ((S, S, S),)),
|
||||
(Ceil, (), ((S, S, S),)),
|
||||
(Frac, (), ((S, S, S),)),
|
||||
(Fmod, (1.5,), ((S, S, S),)),
|
||||
(Lerp, (0.2,), ((S, S, S), (S, S, S))),
|
||||
(Rsqrt, (), (torch.rand(S, S, S) + 1e-2,)),
|
||||
(Remainder, (1.5,), ((S, S, S),)),
|
||||
(CmaxConstant, (0.5,), ((S, S, S),)),
|
||||
(CminConstant, (0.5,), ((S, S, S),)),
|
||||
(Mean, (), ((S, S, S),)),
|
||||
(Mean, (1,), ((S, S, S),), 'dim'),
|
||||
(Sum, (), ((S, S, S),)),
|
||||
(Sum, (1,), ((S, S, S),), 'dim'),
|
||||
(Prod, (), ((S, S, S),)),
|
||||
(Prod, (1,), ((S, S, S),), 'dim'),
|
||||
(Addmm, (), ((S, M), (S, S), (S, M)),),
|
||||
(Addmm, (0.1, 1), ((S, M), (S, S), (S, M)), 'coef'),
|
||||
(Addbmm, (), ((S, M), (S, S, S), (S, S, M)),),
|
||||
(Addbmm, (0.1, 0.4), ((S, M), (S, S, S), (S, S, M)), 'coef'),
|
||||
(Baddbmm, (), ((S, S, M), (S, S, S), (S, S, M)),),
|
||||
(Baddbmm, (0.1, 0.4), ((S, S, M), (S, S, S), (S, S, M)), 'coef'),
|
||||
(Addmv, (), ((S,), (S, M), (M,)),),
|
||||
(Addmv, (0.1, 0.4), ((S,), (S, M), (M,)), 'coef'),
|
||||
(Addr, (), ((S, M), (S,), (M,)),),
|
||||
(Addr, (0.1, 0.4), ((S, M), (S,), (M,)), 'coef'),
|
||||
(Dot, (), ((L,), (L,)),),
|
||||
(Max, (), ((S, S, S),),),
|
||||
(Min, (), ((S, S, S),),),
|
||||
(Max, (0,), ((S, S, S),), 'dim'),
|
||||
(Min, (0,), ((S, S, S),), 'dim'),
|
||||
(Mode, (0,), ((S, S, S),),),
|
||||
(Kthvalue, (2, 0), ((S, S, S),),),
|
||||
(Median, (0,), ((S, S, S),),),
|
||||
(Norm, (1.5,), (torch.rand(S, S, S),), '1_5'),
|
||||
(Norm, (), ((S, S, S),), '2'),
|
||||
(Norm, (3,), ((S, S, S),), '3'),
|
||||
(Norm, (1.5, 0), (torch.rand(S, S, S),), '1_5_dim'),
|
||||
(Norm, (2, 0), ((S, S, S),), '2_dim'),
|
||||
(Norm, (3, 0), ((S, S, S),), '3_dim'),
|
||||
(Addcmul, (), ((S, S), (S, S), (S, S))),
|
||||
(Addcmul, (0.6,), ((S, S), (S, S), (S, S)), 'scale'),
|
||||
(Addcdiv, (), ((S, S), (S, S), torch.rand(S, S) + 1e-2)),
|
||||
(Addcdiv, (0.6,), ((S, S), (S, S), torch.rand(S, S) + 1e-2), 'scale'),
|
||||
(IndexAdd, (0,), ((S, S), index_variable(2, S), (2, S))),
|
||||
# (IndexCopy, (0,), ((S, S), index_variable(2, S), (2, S)) ),
|
||||
(IndexFill, (0, 2), ((S, S), index_variable(2, S)) ),
|
||||
(IndexSelect, (0,), ((S, S), index_variable(2, S)) ),
|
||||
(Gather, (0,), ((M, S), gather_variable((S, S), 1, M)) ),
|
||||
(Gather, (1,), ((M, S), gather_variable((M, S//2), 0, S)), 'dim1'),
|
||||
(Scatter, (0,), ((M, S), gather_variable((S, S), 1, M), (S, S))),
|
||||
(Scatter, (1,), ((M, S), gather_variable((M, S//2), 0, S), (M, S//2)), 'dim1'),
|
||||
(Concat, (0,), ((1, S, S), (2, S, S), (3, S, S)) ),
|
||||
(Resize, (S*S, S), ((S, S, S),) ),
|
||||
(Diag, (), ((S, S),), '2d' ),
|
||||
(Diag, (), ((S,),), '1d' ),
|
||||
(Tril, (), ((S, S),) ),
|
||||
(Tril, (2,), ((S, S),), 'idx' ),
|
||||
(Triu, (), ((S, S),) ),
|
||||
(Triu, (2,), ((S, S),), 'idx' ),
|
||||
(Clone, (), ((S, M, S),) ),
|
||||
(Squeeze, (), ((S, 1, M, 1),) ),
|
||||
(Squeeze, (1,), ((S, 1, M, 1),), 'dim' ),
|
||||
(Unsqueeze, (0,), ((S, M, S),), '0' ),
|
||||
(Unsqueeze, (1,), ((S, M, S),), '1' ),
|
||||
(IndexFill, (0, 2), ((S, S), index_variable(2, S))),
|
||||
(IndexSelect, (0,), ((S, S), index_variable(2, S))),
|
||||
(Gather, (0,), ((M, S), gather_variable((S, S), 1, M))),
|
||||
(Gather, (1,), ((M, S), gather_variable((M, S // 2), 0, S)), 'dim1'),
|
||||
(Scatter, (0,), ((M, S), gather_variable((S, S), 1, M), (S, S))),
|
||||
(Scatter, (1,), ((M, S), gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1'),
|
||||
(Concat, (0,), ((1, S, S), (2, S, S), (3, S, S))),
|
||||
(Resize, (S * S, S), ((S, S, S),)),
|
||||
(Diag, (), ((S, S),), '2d'),
|
||||
(Diag, (), ((S,),), '1d'),
|
||||
(Tril, (), ((S, S),)),
|
||||
(Tril, (2,), ((S, S),), 'idx'),
|
||||
(Triu, (), ((S, S),)),
|
||||
(Triu, (2,), ((S, S),), 'idx'),
|
||||
(Clone, (), ((S, M, S),)),
|
||||
(Squeeze, (), ((S, 1, M, 1),)),
|
||||
(Squeeze, (1,), ((S, 1, M, 1),), 'dim'),
|
||||
(Unsqueeze, (0,), ((S, M, S),), '0'),
|
||||
(Unsqueeze, (1,), ((S, M, S),), '1'),
|
||||
# (MaskedCopy, (), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False), (S, S),)),
|
||||
(MaskedFill, (10,), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False))),
|
||||
(MaskedSelect, (), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False))),
|
||||
(Sort, (), ((S, M, S),) ),
|
||||
(Sort, (1,), ((S, M, S),), 'dim' ),
|
||||
(Sort, (1, True), ((S, M, S),), 'dim_desc' ),
|
||||
(Topk, (3,), ((S, M, S),) ),
|
||||
(Topk, (3, 1), ((S, M, S),), 'dim' ),
|
||||
(Topk, (3, 1, True), ((S, M, S),), 'dim_desc' ),
|
||||
(Topk, (3, 1, True, True), ((S, M, S),), 'dim_desc_sort' ),
|
||||
(MaskedFill, (10,), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False))),
|
||||
(MaskedSelect, (), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False))),
|
||||
(Sort, (), ((S, M, S),)),
|
||||
(Sort, (1,), ((S, M, S),), 'dim'),
|
||||
(Sort, (1, True), ((S, M, S),), 'dim_desc'),
|
||||
(Topk, (3,), ((S, M, S),)),
|
||||
(Topk, (3, 1), ((S, M, S),), 'dim'),
|
||||
(Topk, (3, 1, True), ((S, M, S),), 'dim_desc'),
|
||||
(Topk, (3, 1, True, True), ((S, M, S),), 'dim_desc_sort'),
|
||||
]
|
||||
|
||||
|
||||
method_tests = [
|
||||
('add', (S, S, S), ((S, S, S),) ),
|
||||
('add', (S, S, S), (3.14,), 'constant' ),
|
||||
('sub', (S, S, S), ((S, S, S),) ),
|
||||
('sub', (S, S, S), (3.14,), 'constant' ),
|
||||
('mul', (S, S, S), ((S, S, S),) ),
|
||||
('mul', (S, S, S), (3.14,), 'constant' ),
|
||||
('div', (S, S, S), ((S, S, S),) ),
|
||||
('div', (S, S, S), (3.14,), 'constant' ),
|
||||
('pow', (S, S, S), ((S, S, S),) ),
|
||||
('pow', (S, S, S), (3.14,), 'constant' ),
|
||||
('transpose', (1, 2, 3), (1, 2) ),
|
||||
('t', (1, 2), () ),
|
||||
('view', (S, S, S), (S*S, S), ),
|
||||
('view_as', (S, S, S), ((S*S, S),) ),
|
||||
('expand', (S, 1, S), (S, S, S) ),
|
||||
('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size' ),
|
||||
('exp', (S, S, S), () ),
|
||||
('log', (S, S, S), () ),
|
||||
('log1p', (S, S, S), () ),
|
||||
('tanh', (S, S, S), () ),
|
||||
('sigmoid', (S, S, S), () ),
|
||||
('sinh', (S, S, S), () ),
|
||||
('cosh', (S, S, S), () ),
|
||||
('abs', (S, S, S), () ),
|
||||
('clamp', (S, S, S), (0, 1) ),
|
||||
('sqrt', (S, S, S), () ),
|
||||
('sin', (S, S, S), () ),
|
||||
('cos', (S, S, S), () ),
|
||||
('tan', (S, S, S), () ),
|
||||
('asin', (S, S, S), () ),
|
||||
('acos', (S, S, S), () ),
|
||||
('atan', (S, S, S), () ),
|
||||
('reciprocal', (S, S, S), () ),
|
||||
('round', (S, S, S), () ),
|
||||
('sign', (S, S, S), () ),
|
||||
('trunc', (S, S, S), () ),
|
||||
('floor', (S, S, S), () ),
|
||||
('ceil', (S, S, S), () ),
|
||||
('rsqrt', (S, S, S), () ),
|
||||
('fmod', (S, S, S), (1.5,) ),
|
||||
('remainder', (S, S, S), (1.5,) ),
|
||||
('lerp', (S, S, S), ((S, S, S), 0.4) ),
|
||||
('max', (S, S, S), () ),
|
||||
('max', (S, S, S), ((S, S, S),), 'elementwise' ),
|
||||
('min', (S, S, S), () ),
|
||||
('min', (S, S, S), ((S, S, S),), 'elementwise' ),
|
||||
('mean', (S, S, S), () ),
|
||||
('mean', (S, S, S), (1,), 'dim' ),
|
||||
('sum', (S, S, S), () ),
|
||||
('sum', (S, S, S), (1,), 'dim' ),
|
||||
('prod', (S, S, S), () ),
|
||||
('prod', (S, S, S), (1,), 'dim' ),
|
||||
('addmm', (S, M), ((S, S), (S, M)), ),
|
||||
('addmm', (S, M), (0.2, 0.6, (S, S), (S, M)), 'coef' ),
|
||||
('addbmm', (S, M), ((S, S, S), (S, S, M)), ),
|
||||
('addbmm', (S, M), (0.2, 0.6, (S, S, S), (S, S, M)), 'coef' ),
|
||||
('baddbmm', (S, S, M), ((S, S, S), (S, S, M)), ),
|
||||
('baddbmm', (S, S, M), (0.2, 0.6, (S, S, S), (S, S, M)), 'coef' ),
|
||||
('addmv', (S,), ((S, M), (M,)), ),
|
||||
('addmv', (S,), (0.2, 0.6, (S, M), (M,)), 'coef' ),
|
||||
('addr', (S, M), ((S,), (M,)), ),
|
||||
('addr', (S, M), (0.2, 0.6, (S,), (M,)), 'coef' ),
|
||||
('dot', (L,), ((L,),), ),
|
||||
('addcmul', (S, S), ((S, S), (S, S)) ),
|
||||
('addcmul', (S, S), (0.5, (S, S), (S, S)), 'scale' ),
|
||||
('addcdiv', (S, S), ((S, S), (S, S)) ),
|
||||
('addcdiv', (S, S), (0.5, (S, S), (S, S)), 'scale' ),
|
||||
('norm', (S, S, S), (2,) ),
|
||||
('norm', (S, S, S), (2, 1), 'dim' ),
|
||||
('dist', (S, S, S), ((S, S, S),) ),
|
||||
('dist', (S, S, S), ((S, S, S), 4), '4' ),
|
||||
('index_select', (S, S, S), (0, index_variable(2, S)) ),
|
||||
('diag', (M, M), (), '2d' ),
|
||||
('diag', (M,), (), '1d' ),
|
||||
('tril', (M, M), () ),
|
||||
('triu', (M, M), () ),
|
||||
('clone', (S, M, S), () ),
|
||||
('permute', (1, 2, 3, 4), (0, 2, 3, 1) ),
|
||||
('select', (S, S, S), (1, 2) ),
|
||||
('narrow', (S, S, S), (1, 2, 2) ),
|
||||
('squeeze', (S, 1, S, 1), () ),
|
||||
('squeeze', (S, 1, S, 1), (1,), '1_dim' ),
|
||||
('squeeze', (S, 1, S, 1), (2,), 'not_1_dim' ),
|
||||
('unsqueeze', (S, S, S), (0,), 'first' ),
|
||||
('unsqueeze', (S, S, S), (1,), 'middle' ),
|
||||
('unsqueeze', (S, S, S), (3,), 'last' ),
|
||||
('masked_select', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False),) ),
|
||||
('masked_fill_', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), 10) ),
|
||||
('masked_copy_', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), (M, M)) ),
|
||||
('add', (S, S, S), ((S, S, S),)),
|
||||
('add', (S, S, S), (3.14,), 'constant'),
|
||||
('sub', (S, S, S), ((S, S, S),)),
|
||||
('sub', (S, S, S), (3.14,), 'constant'),
|
||||
('mul', (S, S, S), ((S, S, S),)),
|
||||
('mul', (S, S, S), (3.14,), 'constant'),
|
||||
('div', (S, S, S), ((S, S, S),)),
|
||||
('div', (S, S, S), (3.14,), 'constant'),
|
||||
('pow', (S, S, S), ((S, S, S),)),
|
||||
('pow', (S, S, S), (3.14,), 'constant'),
|
||||
('transpose', (1, 2, 3), (1, 2)),
|
||||
('t', (1, 2), ()),
|
||||
('view', (S, S, S), (S * S, S),),
|
||||
('view_as', (S, S, S), ((S * S, S),)),
|
||||
('expand', (S, 1, S), (S, S, S)),
|
||||
('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size'),
|
||||
('exp', (S, S, S), ()),
|
||||
('log', (S, S, S), ()),
|
||||
('log1p', (S, S, S), ()),
|
||||
('tanh', (S, S, S), ()),
|
||||
('sigmoid', (S, S, S), ()),
|
||||
('sinh', (S, S, S), ()),
|
||||
('cosh', (S, S, S), ()),
|
||||
('abs', (S, S, S), ()),
|
||||
('clamp', (S, S, S), (0, 1)),
|
||||
('sqrt', (S, S, S), ()),
|
||||
('sin', (S, S, S), ()),
|
||||
('cos', (S, S, S), ()),
|
||||
('tan', (S, S, S), ()),
|
||||
('asin', (S, S, S), ()),
|
||||
('acos', (S, S, S), ()),
|
||||
('atan', (S, S, S), ()),
|
||||
('reciprocal', (S, S, S), ()),
|
||||
('round', (S, S, S), ()),
|
||||
('sign', (S, S, S), ()),
|
||||
('trunc', (S, S, S), ()),
|
||||
('floor', (S, S, S), ()),
|
||||
('ceil', (S, S, S), ()),
|
||||
('rsqrt', (S, S, S), ()),
|
||||
('fmod', (S, S, S), (1.5,)),
|
||||
('remainder', (S, S, S), (1.5,)),
|
||||
('lerp', (S, S, S), ((S, S, S), 0.4)),
|
||||
('max', (S, S, S), ()),
|
||||
('max', (S, S, S), ((S, S, S),), 'elementwise'),
|
||||
('min', (S, S, S), ()),
|
||||
('min', (S, S, S), ((S, S, S),), 'elementwise'),
|
||||
('mean', (S, S, S), ()),
|
||||
('mean', (S, S, S), (1,), 'dim'),
|
||||
('sum', (S, S, S), ()),
|
||||
('sum', (S, S, S), (1,), 'dim'),
|
||||
('prod', (S, S, S), ()),
|
||||
('prod', (S, S, S), (1,), 'dim'),
|
||||
('addmm', (S, M), ((S, S), (S, M)),),
|
||||
('addmm', (S, M), (0.2, 0.6, (S, S), (S, M)), 'coef'),
|
||||
('addbmm', (S, M), ((S, S, S), (S, S, M)),),
|
||||
('addbmm', (S, M), (0.2, 0.6, (S, S, S), (S, S, M)), 'coef'),
|
||||
('baddbmm', (S, S, M), ((S, S, S), (S, S, M)),),
|
||||
('baddbmm', (S, S, M), (0.2, 0.6, (S, S, S), (S, S, M)), 'coef'),
|
||||
('addmv', (S,), ((S, M), (M,)),),
|
||||
('addmv', (S,), (0.2, 0.6, (S, M), (M,)), 'coef'),
|
||||
('addr', (S, M), ((S,), (M,)),),
|
||||
('addr', (S, M), (0.2, 0.6, (S,), (M,)), 'coef'),
|
||||
('dot', (L,), ((L,),),),
|
||||
('addcmul', (S, S), ((S, S), (S, S))),
|
||||
('addcmul', (S, S), (0.5, (S, S), (S, S)), 'scale'),
|
||||
('addcdiv', (S, S), ((S, S), (S, S))),
|
||||
('addcdiv', (S, S), (0.5, (S, S), (S, S)), 'scale'),
|
||||
('norm', (S, S, S), (2,)),
|
||||
('norm', (S, S, S), (2, 1), 'dim'),
|
||||
('dist', (S, S, S), ((S, S, S),)),
|
||||
('dist', (S, S, S), ((S, S, S), 4), '4'),
|
||||
('index_select', (S, S, S), (0, index_variable(2, S))),
|
||||
('diag', (M, M), (), '2d'),
|
||||
('diag', (M,), (), '1d'),
|
||||
('tril', (M, M), ()),
|
||||
('triu', (M, M), ()),
|
||||
('clone', (S, M, S), ()),
|
||||
('permute', (1, 2, 3, 4), (0, 2, 3, 1)),
|
||||
('select', (S, S, S), (1, 2)),
|
||||
('narrow', (S, S, S), (1, 2, 2)),
|
||||
('squeeze', (S, 1, S, 1), ()),
|
||||
('squeeze', (S, 1, S, 1), (1,), '1_dim'),
|
||||
('squeeze', (S, 1, S, 1), (2,), 'not_1_dim'),
|
||||
('unsqueeze', (S, S, S), (0,), 'first'),
|
||||
('unsqueeze', (S, S, S), (1,), 'middle'),
|
||||
('unsqueeze', (S, S, S), (3,), 'last'),
|
||||
('masked_select', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False),)),
|
||||
('masked_fill_', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), 10)),
|
||||
('masked_copy_', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), (M, M))),
|
||||
]
|
||||
# TODO: mm, bmm, mv, ger
|
||||
# TODO: max, min with dim (problem with indices)
|
||||
@ -946,6 +960,7 @@ method_tests = [
|
||||
def create_input(call_args):
|
||||
if not isinstance(call_args, tuple):
|
||||
call_args = (call_args,)
|
||||
|
||||
def map_arg(arg):
|
||||
if isinstance(arg, tuple) and not isinstance(arg[0], Variable):
|
||||
return Variable(torch.randn(*arg).double(), requires_grad=True)
|
||||
@ -976,8 +991,9 @@ ignore_inplace = set((
|
||||
for test in function_tests:
|
||||
cls, constructor_args, call_args = test[:3]
|
||||
test_name = 'test_' + cls.__name__ + ('_' + test[3] if len(test) == 4 else '')
|
||||
|
||||
def do_test(self, cls=cls, constructor_args=constructor_args,
|
||||
call_args=call_args, test_name=test_name):
|
||||
call_args=call_args, test_name=test_name):
|
||||
input = create_input(call_args)
|
||||
output = cls(*constructor_args)(*input)
|
||||
if not isinstance(output, tuple):
|
||||
@ -986,6 +1002,7 @@ for test in function_tests:
|
||||
if not o.requires_grad:
|
||||
continue
|
||||
analytical = get_analytical_jacobian(input, o)
|
||||
|
||||
def fn(input):
|
||||
tmp = cls(*constructor_args)(*input)
|
||||
if not isinstance(tmp, tuple):
|
||||
@ -1032,6 +1049,7 @@ EXCLUDE_FUNCTIONAL = {
|
||||
for test in method_tests:
|
||||
name, self_size, args = test[:3]
|
||||
test_name = 'test_' + name + ('_' + test[3] if len(test) == 4 else '')
|
||||
|
||||
def do_test(self, name=name, self_size=self_size, args=args, test_name=test_name):
|
||||
def check(name):
|
||||
self_variable = create_input((self_size,))[0]
|
||||
@ -1064,7 +1082,6 @@ for test in method_tests:
|
||||
if not 'only supports scalar' in e.args[0]:
|
||||
raise
|
||||
|
||||
|
||||
assert not hasattr(TestAutograd, test_name), 'Two tests have the same name: ' + test_name
|
||||
setattr(TestAutograd, test_name, do_test)
|
||||
|
||||
|
@ -14,6 +14,7 @@ if not torch.cuda.is_available():
|
||||
import sys
|
||||
sys.exit()
|
||||
|
||||
|
||||
def is_floating(t):
|
||||
return type(t) in [torch.FloatTensor, torch.DoubleTensor,
|
||||
torch.cuda.FloatTensor, torch.cuda.DoubleTensor]
|
||||
@ -31,7 +32,8 @@ types = [
|
||||
float_types = [
|
||||
torch.FloatTensor,
|
||||
torch.DoubleTensor
|
||||
] # TODO: add half...
|
||||
] # TODO: add half...
|
||||
|
||||
|
||||
def number(floating, integer, t):
|
||||
name = type(t).__name__
|
||||
@ -44,188 +46,204 @@ def number(floating, integer, t):
|
||||
S = 10
|
||||
M = 50
|
||||
|
||||
|
||||
def make_tensor(t, *sizes):
|
||||
return t(*sizes).copy_(torch.randn(*sizes))
|
||||
|
||||
|
||||
def small_2d(t):
|
||||
return make_tensor(t, S, S)
|
||||
|
||||
|
||||
def small_2d_scaled(t, scale=10):
|
||||
return make_tensor(t, S, S).mul(scale)
|
||||
|
||||
|
||||
def small_3d(t):
|
||||
return make_tensor(t, S, S, S)
|
||||
|
||||
|
||||
def medium_1d(t):
|
||||
return make_tensor(t, M)
|
||||
|
||||
|
||||
def medium_2d(t):
|
||||
return make_tensor(t, M, M)
|
||||
|
||||
|
||||
def medium_2d_scaled(t, scale=10):
|
||||
return make_tensor(t, M, M).mul(scale)
|
||||
|
||||
|
||||
def small_3d_ones(t):
|
||||
return t(S, S, S).copy_(torch.ones(S, S, S))
|
||||
|
||||
|
||||
def small_3d_positive(t):
|
||||
min_val = 1e-3 if is_floating(t) else 2
|
||||
return make_tensor(t, S, S, S).clamp_(min_val, 120)
|
||||
|
||||
|
||||
def small_3d_unique(t):
|
||||
return t(S, S, S).copy_(torch.range(1, S*S*S))
|
||||
return t(S, S, S).copy_(torch.range(1, S * S * S))
|
||||
|
||||
|
||||
def small_1d_lapack(t):
|
||||
return t(1, 3).copy_(torch.range(1, 3).view(3))
|
||||
|
||||
|
||||
def small_2d_lapack(t):
|
||||
return t(3, 3).copy_(torch.range(1, 9).view(3, 3))
|
||||
|
||||
|
||||
def small_2d_lapack_skinny(t):
|
||||
return t(3, 4).copy_(torch.range(1, 12).view(3, 4))
|
||||
|
||||
|
||||
def small_2d_lapack_fat(t):
|
||||
return t(4, 3).copy_(torch.range(1, 12).view(4, 3))
|
||||
|
||||
|
||||
def new_t(*sizes):
|
||||
def tmp(t):
|
||||
return t(*sizes).copy_(torch.randn(*sizes))
|
||||
return tmp
|
||||
|
||||
tests = [
|
||||
('add', small_3d, lambda t: [number(3.14, 3, t)] ),
|
||||
('add', small_3d, lambda t: [small_3d_positive(t)], 'tensor' ),
|
||||
('add', small_3d, lambda t: [number(0.2, 2, t), small_3d_positive(t)], 'scalar_tensor' ),
|
||||
('sub', small_3d, lambda t: [number(3.14, 3, t)], ),
|
||||
('sub', small_3d, lambda t: [small_3d_positive(t)], 'tensor' ),
|
||||
('mul', small_3d, lambda t: [number(3.14, 3, t)], ),
|
||||
('mul', small_3d, lambda t: [small_3d_positive(t)], 'tensor' ),
|
||||
('div', small_3d, lambda t: [number(3.14, 3, t)], ),
|
||||
('div', small_3d, lambda t: [small_3d_positive(t)], 'tensor' ),
|
||||
('pow', small_3d, lambda t: [number(3.14, 3, t)], None, float_types),
|
||||
('pow', small_3d, lambda t: [small_3d(t).abs_()], 'tensor', float_types),
|
||||
('addbmm', small_2d, lambda t: [small_3d(t), small_3d(t)], None, float_types),
|
||||
('addbmm', small_2d, lambda t: [number(0.4, 2, t), small_3d(t), small_3d(t)], 'scalar' ),
|
||||
('addbmm', small_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), small_3d(t), small_3d(t)], 'two_scalars' ),
|
||||
('baddbmm', small_3d, lambda t: [small_3d(t), small_3d(t)], ),
|
||||
('baddbmm', small_3d, lambda t: [number(0.4, 2, t), small_3d(t), small_3d(t)], 'scalar' ),
|
||||
('baddbmm', small_3d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), small_3d(t), small_3d(t)], 'two_scalars' ),
|
||||
('addcdiv', small_2d_lapack, lambda t: [small_2d_lapack(t).mul(2), small_2d_lapack(t)], ),
|
||||
('addcdiv', small_2d_lapack, lambda t: [number(2.8, 1, t), small_2d_lapack(t).mul(2), small_2d_lapack(t)], 'scalar' ),
|
||||
('addcmul', small_3d, lambda t: [small_3d(t), small_3d(t)], ),
|
||||
('addcmul', small_3d, lambda t: [number(0.4, 2, t), small_3d(t), small_3d(t)], 'scalar' ),
|
||||
('addmm', medium_2d, lambda t: [medium_2d(t), medium_2d(t)], ),
|
||||
('addmm', medium_2d, lambda t: [number(0.4, 2, t), medium_2d(t), medium_2d(t)], 'scalar' ),
|
||||
('addmm', medium_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_2d(t), medium_2d(t)], 'two_scalars' ),
|
||||
('addmv', medium_1d, lambda t: [medium_2d(t), medium_1d(t)], ),
|
||||
('addmv', medium_1d, lambda t: [number(0.4, 2, t), medium_2d(t), medium_1d(t)], 'scalar' ),
|
||||
('addmv', medium_1d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_2d(t), medium_1d(t)], 'two_scalars' ),
|
||||
('addr', medium_2d, lambda t: [medium_1d(t), medium_1d(t)], ),
|
||||
('addr', medium_2d, lambda t: [number(0.4, 2, t), medium_1d(t), medium_1d(t)], 'scalar' ),
|
||||
('addr', medium_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_1d(t), medium_1d(t)], 'two_scalars' ),
|
||||
('atan2', medium_2d, lambda t: [medium_2d(t)], None, float_types),
|
||||
('fmod', small_3d, lambda t: [3], 'value' ),
|
||||
('fmod', small_3d, lambda t: [small_3d_positive(t)], 'tensor' ),
|
||||
('chunk', medium_2d, lambda t: [4], ),
|
||||
('chunk', medium_2d, lambda t: [4, 1], 'dim' ),
|
||||
('clamp', medium_2d_scaled, lambda t: [-1, 5], ),
|
||||
('clone', medium_2d, lambda t: [], ),
|
||||
('contiguous', medium_2d, lambda t: [], ),
|
||||
('cross', new_t(M, 3, M), lambda t: [new_t(M, 3, M)(t)], ),
|
||||
('cumprod', small_3d, lambda t: [1], ),
|
||||
('cumsum', small_3d, lambda t: [1], ),
|
||||
('dim', small_3d, lambda t: [], ),
|
||||
('dist', small_2d, lambda t: [small_2d(t)], ),
|
||||
('dist', small_2d, lambda t: [small_2d(t), 3], '3_norm' ),
|
||||
('dist', small_2d, lambda t: [small_2d(t), 2.5], '2_5_norm' ),
|
||||
('dot', medium_1d, lambda t: [medium_1d(t)], ),
|
||||
('element_size', medium_1d, lambda t: [], ),
|
||||
('eq', small_3d_ones, lambda t: [small_3d(t)], ),
|
||||
('eq', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal' ),
|
||||
('ne', small_3d_ones, lambda t: [small_3d(t)], ),
|
||||
('ne', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal' ),
|
||||
('equal', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal' ),
|
||||
('equal', small_3d_ones, lambda t: [small_3d(t)], ),
|
||||
('expand', new_t(M, 1, M), lambda t: [M, 4, M], ),
|
||||
('expand_as', new_t(M, 1, M), lambda t: [new_t(M, 4, M)(t)], ),
|
||||
('fill', medium_2d, lambda t: [number(3.14, 3, t)], ),
|
||||
('ge', medium_2d, lambda t: [medium_2d(t)], ),
|
||||
('le', medium_2d, lambda t: [medium_2d(t)], ),
|
||||
('gt', medium_2d, lambda t: [medium_2d(t)], ),
|
||||
('lt', medium_2d, lambda t: [medium_2d(t)], ),
|
||||
('is_contiguous', medium_2d, lambda t: [], ),
|
||||
('add', small_3d, lambda t: [number(3.14, 3, t)]),
|
||||
('add', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
|
||||
('add', small_3d, lambda t: [number(0.2, 2, t), small_3d_positive(t)], 'scalar_tensor'),
|
||||
('sub', small_3d, lambda t: [number(3.14, 3, t)],),
|
||||
('sub', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
|
||||
('mul', small_3d, lambda t: [number(3.14, 3, t)],),
|
||||
('mul', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
|
||||
('div', small_3d, lambda t: [number(3.14, 3, t)],),
|
||||
('div', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
|
||||
('pow', small_3d, lambda t: [number(3.14, 3, t)], None, float_types),
|
||||
('pow', small_3d, lambda t: [small_3d(t).abs_()], 'tensor', float_types),
|
||||
('addbmm', small_2d, lambda t: [small_3d(t), small_3d(t)], None, float_types),
|
||||
('addbmm', small_2d, lambda t: [number(0.4, 2, t), small_3d(t), small_3d(t)], 'scalar'),
|
||||
('addbmm', small_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), small_3d(t), small_3d(t)], 'two_scalars'),
|
||||
('baddbmm', small_3d, lambda t: [small_3d(t), small_3d(t)],),
|
||||
('baddbmm', small_3d, lambda t: [number(0.4, 2, t), small_3d(t), small_3d(t)], 'scalar'),
|
||||
('baddbmm', small_3d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), small_3d(t), small_3d(t)], 'two_scalars'),
|
||||
('addcdiv', small_2d_lapack, lambda t: [small_2d_lapack(t).mul(2), small_2d_lapack(t)],),
|
||||
('addcdiv', small_2d_lapack, lambda t: [number(2.8, 1, t),
|
||||
small_2d_lapack(t).mul(2), small_2d_lapack(t)], 'scalar'),
|
||||
('addcmul', small_3d, lambda t: [small_3d(t), small_3d(t)],),
|
||||
('addcmul', small_3d, lambda t: [number(0.4, 2, t), small_3d(t), small_3d(t)], 'scalar'),
|
||||
('addmm', medium_2d, lambda t: [medium_2d(t), medium_2d(t)],),
|
||||
('addmm', medium_2d, lambda t: [number(0.4, 2, t), medium_2d(t), medium_2d(t)], 'scalar'),
|
||||
('addmm', medium_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_2d(t), medium_2d(t)], 'two_scalars'),
|
||||
('addmv', medium_1d, lambda t: [medium_2d(t), medium_1d(t)],),
|
||||
('addmv', medium_1d, lambda t: [number(0.4, 2, t), medium_2d(t), medium_1d(t)], 'scalar'),
('addmv', medium_1d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_2d(t), medium_1d(t)], 'two_scalars'),
('addr', medium_2d, lambda t: [medium_1d(t), medium_1d(t)],),
('addr', medium_2d, lambda t: [number(0.4, 2, t), medium_1d(t), medium_1d(t)], 'scalar'),
('addr', medium_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_1d(t), medium_1d(t)], 'two_scalars'),
('atan2', medium_2d, lambda t: [medium_2d(t)], None, float_types),
('fmod', small_3d, lambda t: [3], 'value'),
('fmod', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
('chunk', medium_2d, lambda t: [4],),
('chunk', medium_2d, lambda t: [4, 1], 'dim'),
('clamp', medium_2d_scaled, lambda t: [-1, 5],),
('clone', medium_2d, lambda t: [],),
('contiguous', medium_2d, lambda t: [],),
('cross', new_t(M, 3, M), lambda t: [new_t(M, 3, M)(t)],),
('cumprod', small_3d, lambda t: [1],),
('cumsum', small_3d, lambda t: [1],),
('dim', small_3d, lambda t: [],),
('dist', small_2d, lambda t: [small_2d(t)],),
('dist', small_2d, lambda t: [small_2d(t), 3], '3_norm'),
('dist', small_2d, lambda t: [small_2d(t), 2.5], '2_5_norm'),
('dot', medium_1d, lambda t: [medium_1d(t)],),
('element_size', medium_1d, lambda t: [],),
('eq', small_3d_ones, lambda t: [small_3d(t)],),
('eq', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal'),
('ne', small_3d_ones, lambda t: [small_3d(t)],),
('ne', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal'),
('equal', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal'),
('equal', small_3d_ones, lambda t: [small_3d(t)],),
('expand', new_t(M, 1, M), lambda t: [M, 4, M],),
('expand_as', new_t(M, 1, M), lambda t: [new_t(M, 4, M)(t)],),
('fill', medium_2d, lambda t: [number(3.14, 3, t)],),
('ge', medium_2d, lambda t: [medium_2d(t)],),
('le', medium_2d, lambda t: [medium_2d(t)],),
('gt', medium_2d, lambda t: [medium_2d(t)],),
('lt', medium_2d, lambda t: [medium_2d(t)],),
('is_contiguous', medium_2d, lambda t: [],),
# TODO: can't check negative case - GPU copy will be contiguous
('is_same_size', medium_2d, lambda t: [small_3d(t)], 'negative' ),
('is_same_size', medium_2d, lambda t: [medium_2d(t)], 'positive' ),
('is_set_to', medium_2d, lambda t: [medium_2d(t)], ),
('is_same_size', medium_2d, lambda t: [small_3d(t)], 'negative'),
('is_same_size', medium_2d, lambda t: [medium_2d(t)], 'positive'),
('is_set_to', medium_2d, lambda t: [medium_2d(t)],),
# TODO: positive case
('kthvalue', small_3d_unique, lambda t: [3], ),
('kthvalue', small_3d_unique, lambda t: [3, 1], 'dim' ),
('lerp', small_3d, lambda t: [small_3d(t), 0.3], ),
('max', small_3d_unique, lambda t: [], ),
('max', small_3d_unique, lambda t: [1], 'dim' ),
('max', medium_2d, lambda t: [medium_2d(t)], 'elementwise' ),
('min', small_3d_unique, lambda t: [], ),
('min', small_3d_unique, lambda t: [1], 'dim' ),
('min', medium_2d, lambda t: [medium_2d(t)], 'elementwise' ),
('mean', small_3d, lambda t: [], ),
('mean', small_3d, lambda t: [1], 'dim' ),
('mode', small_3d, lambda t: [], ),
('mode', small_3d, lambda t: [1], 'dim' ),
('remainder', small_3d, lambda t: [3], 'value' ),
('remainder', small_3d, lambda t: [small_3d_positive(t)], 'tensor' ),
('std', small_3d, lambda t: [], ),
('std', small_3d, lambda t: [1], 'dim' ),
('var', small_3d, lambda t: [], ),
('var', small_3d, lambda t: [1], 'dim' ),
('ndimension', small_3d, lambda t: [], ),
('nelement', small_3d, lambda t: [], ),
('numel', small_3d, lambda t: [], ),
('narrow', small_3d, lambda t: [1, 3, 2], ),
('nonzero', small_3d, lambda t: [], ),
('norm', small_3d, lambda t: [], ),
('norm', small_3d, lambda t: [3], '3_norm' ),
('norm', small_3d, lambda t: [3, 0], '3_norm_dim' ),
('ones', small_3d, lambda t: [1, 2, 3, 4, 5], ),
('permute', new_t(1, 2, 3, 4), lambda t: [2, 1, 3, 0], ),
('prod', small_3d, lambda t: [], ),
('prod', small_3d, lambda t: [1], 'dim' ),
('sum', small_2d, lambda t: [], ),
('sum', small_3d, lambda t: [1], 'dim' ),
('renorm', small_3d, lambda t: [2, 1, 1], '2_norm' ),
('renorm', small_3d, lambda t: [1.5, 1, 1], '1_5_norm' ),
('repeat', small_2d, lambda t: [2, 2, 2], ),
('size', new_t(1, 2, 3, 4), lambda t: [], ),
('sort', small_3d_unique, lambda t: [], ),
('sort', small_3d_unique, lambda t: [1], 'dim' ),
('sort', small_3d_unique, lambda t: [1, True], 'dim_descending'),
('split', small_3d, lambda t: [2], ),
('split', small_3d, lambda t: [2, 1], 'dim' ),
('squeeze', new_t(1, 2, 1, 4), lambda t: [], ),
('squeeze', new_t(1, 2, 1, 4), lambda t: [2], 'dim' ),
('t', new_t(1, 2), lambda t: [], ),
('transpose', new_t(1, 2, 3, 4), lambda t: [1, 2], ),
('to_list', small_3d, lambda t: [], ),
('topk', small_3d, lambda t: [2, 1, False, True], 'dim_sort' ),
('topk', small_3d, lambda t: [2, 1, True, True], 'dim_desc_sort' ),
('trace', medium_2d, lambda t: [], ),
('tril', medium_2d, lambda t: [], ),
('tril', medium_2d, lambda t: [2], 'positive' ),
('tril', medium_2d, lambda t: [-2], 'negative' ),
('triu', medium_2d, lambda t: [], ),
('triu', medium_2d, lambda t: [2], 'positive' ),
('triu', medium_2d, lambda t: [-2], 'negative' ),
('view', small_3d, lambda t: [100, 10], ),
('view_as', small_3d, lambda t: [t(100, 10)], ),
('zero', small_3d, lambda t: [], ),
('zeros', small_3d, lambda t: [1, 2, 3, 4], ),
('rsqrt', lambda t: small_3d(t) + 1, lambda t: [], None, float_types),
('sinh', lambda t: small_3d(t).clamp(-1, 1), lambda t: [], None, float_types),
('tan', lambda t: small_3d(t).clamp(-1, 1), lambda t: [], None, float_types),
('kthvalue', small_3d_unique, lambda t: [3],),
('kthvalue', small_3d_unique, lambda t: [3, 1], 'dim'),
('lerp', small_3d, lambda t: [small_3d(t), 0.3],),
('max', small_3d_unique, lambda t: [],),
('max', small_3d_unique, lambda t: [1], 'dim'),
('max', medium_2d, lambda t: [medium_2d(t)], 'elementwise'),
('min', small_3d_unique, lambda t: [],),
('min', small_3d_unique, lambda t: [1], 'dim'),
('min', medium_2d, lambda t: [medium_2d(t)], 'elementwise'),
('mean', small_3d, lambda t: [],),
('mean', small_3d, lambda t: [1], 'dim'),
('mode', small_3d, lambda t: [],),
('mode', small_3d, lambda t: [1], 'dim'),
('remainder', small_3d, lambda t: [3], 'value'),
('remainder', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
('std', small_3d, lambda t: [],),
('std', small_3d, lambda t: [1], 'dim'),
('var', small_3d, lambda t: [],),
('var', small_3d, lambda t: [1], 'dim'),
('ndimension', small_3d, lambda t: [],),
('nelement', small_3d, lambda t: [],),
('numel', small_3d, lambda t: [],),
('narrow', small_3d, lambda t: [1, 3, 2],),
('nonzero', small_3d, lambda t: [],),
('norm', small_3d, lambda t: [],),
('norm', small_3d, lambda t: [3], '3_norm'),
('norm', small_3d, lambda t: [3, 0], '3_norm_dim'),
('ones', small_3d, lambda t: [1, 2, 3, 4, 5],),
('permute', new_t(1, 2, 3, 4), lambda t: [2, 1, 3, 0],),
('prod', small_3d, lambda t: [],),
('prod', small_3d, lambda t: [1], 'dim'),
('sum', small_2d, lambda t: [],),
('sum', small_3d, lambda t: [1], 'dim'),
('renorm', small_3d, lambda t: [2, 1, 1], '2_norm'),
('renorm', small_3d, lambda t: [1.5, 1, 1], '1_5_norm'),
('repeat', small_2d, lambda t: [2, 2, 2],),
('size', new_t(1, 2, 3, 4), lambda t: [],),
('sort', small_3d_unique, lambda t: [],),
('sort', small_3d_unique, lambda t: [1], 'dim'),
('sort', small_3d_unique, lambda t: [1, True], 'dim_descending'),
('split', small_3d, lambda t: [2],),
('split', small_3d, lambda t: [2, 1], 'dim'),
('squeeze', new_t(1, 2, 1, 4), lambda t: [],),
('squeeze', new_t(1, 2, 1, 4), lambda t: [2], 'dim'),
('t', new_t(1, 2), lambda t: [],),
('transpose', new_t(1, 2, 3, 4), lambda t: [1, 2],),
('to_list', small_3d, lambda t: [],),
('topk', small_3d, lambda t: [2, 1, False, True], 'dim_sort'),
('topk', small_3d, lambda t: [2, 1, True, True], 'dim_desc_sort'),
('trace', medium_2d, lambda t: [],),
('tril', medium_2d, lambda t: [],),
('tril', medium_2d, lambda t: [2], 'positive'),
('tril', medium_2d, lambda t: [-2], 'negative'),
('triu', medium_2d, lambda t: [],),
('triu', medium_2d, lambda t: [2], 'positive'),
('triu', medium_2d, lambda t: [-2], 'negative'),
('view', small_3d, lambda t: [100, 10],),
('view_as', small_3d, lambda t: [t(100, 10)],),
('zero', small_3d, lambda t: [],),
('zeros', small_3d, lambda t: [1, 2, 3, 4],),
('rsqrt', lambda t: small_3d(t) + 1, lambda t: [], None, float_types),
('sinh', lambda t: small_3d(t).clamp(-1, 1), lambda t: [], None, float_types),
('tan', lambda t: small_3d(t).clamp(-1, 1), lambda t: [], None, float_types),
# lapack tests
('qr', small_2d_lapack, lambda t: [], 'square', float_types),
('qr', small_2d_lapack_skinny, lambda t: [], 'skinny', float_types),
('qr', small_2d_lapack_fat, lambda t: [], 'fat', float_types),
('qr', small_2d_lapack, lambda t: [], 'square', float_types),
('qr', small_2d_lapack_skinny, lambda t: [], 'skinny', float_types),
('qr', small_2d_lapack_fat, lambda t: [], 'fat', float_types),

]
@ -275,6 +293,8 @@ for fn in simple_pointwise_float:
tests.append((fn, small_3d, lambda t: [], None, float_types))

_cycles_per_ms = None

def get_cycles_per_ms():
"""Approximate number of cycles per millisecond for torch.cuda._sleep"""
global _cycles_per_ms
@ -288,6 +308,7 @@ def get_cycles_per_ms():
_cycles_per_ms = 1000000 / start.elapsed_time(end)
return _cycles_per_ms

def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5):
def tmp(self):
cpu_tensor = tensor_constructor(t)
@ -314,6 +335,7 @@ def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5):
self.assertEqual(cpu_result, gpu_result, precision)
return tmp

class TestCuda(TestCase):

def test_autogpu(self):
@ -412,7 +434,7 @@ class TestCuda(TestCase):
y_cuda = y.cuda(1)
result = comm.reduce_add((x_cuda, y_cuda))
self.assertEqual(result.get_device(), 0)
self.assertEqual(result.cpu(), x+y)
self.assertEqual(result.cpu(), x + y)

def _test_scatter(self, input, chunk_sizes=None, dim=0):
if torch.cuda.device_count() < 2:
@ -473,7 +495,7 @@ class TestCuda(TestCase):
self._test_gather(1)

def test_from_sequence(self):
seq = [list(range(i*4,i*4+4)) for i in range(5)]
seq = [list(range(i * 4, i * 4 + 4)) for i in range(5)]
reference = torch.range(0, 19).resize_(5, 4)
for t in types:
cuda_type = get_gpu_type(t)
@ -526,6 +548,7 @@ class TestCuda(TestCase):
@unittest.skipIf(torch.cuda.device_count() < 2, "detected only one GPU")
def test_multigpu_serialization_remap(self):
x = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)]

def gpu_remap(storage, location):
if location == 'cuda:1':
return storage.cuda(0)
@ -666,7 +689,8 @@ for decl in tests:
if not hasattr(tensor, name_inner):
continue
if not hasattr(gpu_tensor, name_inner):
print("Ignoring {}, because it's not implemented by torch.cuda.{}".format(name_inner, gpu_tensor.__class__.__name__))
print("Ignoring {}, because it's not implemented by torch.cuda.{}".format(
name_inner, gpu_tensor.__class__.__name__))
continue

test_name = 'test_' + t.__name__ + '_' + name_inner
@ -27,11 +27,12 @@ class TestTensorDataset(TestCase):
l = torch.randn(15)
source = TensorDataset(t, l)
for i in range(15):
self.assertEqual(t[i:i+1], source[i][0])
self.assertEqual(l[i:i+1], source[i][1])
self.assertEqual(t[i:i + 1], source[i][0])
self.assertEqual(l[i:i + 1], source[i][1])

class ErrorDataset(Dataset):

def __init__(self, size):
self.size = size

@ -50,9 +51,9 @@ class TestDataLoader(TestCase):
batch_size = loader.batch_size
for i, (sample, target) in enumerate(loader):
idx = i * batch_size
self.assertEqual(sample, self.data[idx:idx+batch_size])
self.assertEqual(target, self.labels[idx:idx+batch_size].view(-1, 1))
self.assertEqual(i, math.floor((len(self.dataset)-1) / batch_size))
self.assertEqual(sample, self.data[idx:idx + batch_size])
self.assertEqual(target, self.labels[idx:idx + batch_size].view(-1, 1))
self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

def _test_shuffle(self, loader):
found_data = {i: 0 for i in range(self.data.size(0))}
@ -67,9 +68,9 @@ class TestDataLoader(TestCase):
break
self.assertEqual(target, self.labels.narrow(0, data_point_idx, 1))
found_labels[data_point_idx] += 1
self.assertEqual(sum(found_data.values()), (i+1) * batch_size)
self.assertEqual(sum(found_labels.values()), (i+1) * batch_size)
self.assertEqual(i, math.floor((len(self.dataset)-1) / batch_size))
self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

def _test_error(self, loader):
it = iter(loader)
@ -81,10 +82,9 @@ class TestDataLoader(TestCase):
errors += 1
except StopIteration:
self.assertEqual(errors,
math.ceil(float(len(loader.dataset))/loader.batch_size))
math.ceil(float(len(loader.dataset)) / loader.batch_size))
return

def test_sequential(self):
self._test_sequential(DataLoader(self.dataset))

File diff suppressed because it is too large
@ -16,8 +16,8 @@ from common import TestCase, run_tests

HAS_SHM_FILES = os.path.isdir('/dev/shm')
TEST_CUDA_IPC = torch.cuda.is_available() and \
sys.version_info[0] == 3 and \
sys.platform != 'darwin'
sys.version_info[0] == 3 and \
sys.platform != 'darwin'

def simple_fill(queue, event):
@ -74,7 +74,7 @@ def autograd_sharing(queue, ready, master_modified):
master_modified.wait()

expected_var = torch.range(1, 25).view(5, 5)
expected_var[0,0] = 1000
expected_var[0, 0] = 1000
is_ok = var.data.equal(expected_var)
var.data[:] = torch.ones(5, 5)

@ -189,7 +189,7 @@ class TestMultiprocessing(TestCase):
def _test_preserve_sharing(self, ctx=mp, repeat=1):
def do_test():
x = torch.randn(5, 5)
data = [x.storage(), x.storage()[1:4], x, x[2], x[:,1]]
data = [x.storage(), x.storage()[1:4], x, x[2], x[:, 1]]
q = ctx.Queue()
q.put(data)
new_data = q.get()
@ -268,6 +268,7 @@ class TestMultiprocessing(TestCase):

def test_inherit_tensor(self):
class SubProcess(mp.Process):

def __init__(self, tensor):
super(SubProcess, self).__init__()
self.tensor = tensor
@ -286,7 +287,6 @@ class TestMultiprocessing(TestCase):
torch.cuda.FloatTensor([1]) # initialize CUDA outside of leak checker
self._test_sharing(mp.get_context('spawn'), torch.cuda.FloatTensor)

@unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
def test_cuda_small_tensors(self):
# Check multiple small tensors which will likely use the same
@ -359,7 +359,7 @@ class TestMultiprocessing(TestCase):
queue.put(var)

ready.wait()
var.data[0,0] = 1000
var.data[0, 0] = 1000
if var.grad is not None:
var.grad.data[:] = torch.ones(5, 5) * 4
master_modified.set()
@ -380,8 +380,8 @@ class TestMultiprocessing(TestCase):
]
for requires_grad, volatile in configs:
var = Variable(torch.range(1, 25).view(5, 5),
requires_grad=requires_grad,
volatile=volatile)
requires_grad=requires_grad,
volatile=volatile)
self._test_autograd_sharing(var)

def test_parameter_sharing(self):
@ -16,8 +16,10 @@ from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
module_tests, criterion_tests, TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PRECISION
from common import freeze_rng_state, run_tests

def default_tensor_type(type):
type_str = torch.typename(type)

def decorator(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
@ -30,9 +32,12 @@ def default_tensor_type(type):
return wrapper
return decorator

class InputVariableMixin(object):

def _get_input(self):
input = TestBase._get_input(self)

def map_variables(i):
if isinstance(i, Variable):
return i
@ -44,6 +49,7 @@ class InputVariableMixin(object):

class NewModuleTest(InputVariableMixin, ModuleTest):

def __init__(self, *args, **kwargs):
super(NewModuleTest, self).__init__(*args, **kwargs)
self.cudnn = kwargs.get('cudnn', False)
@ -356,21 +362,21 @@ class TestNN(NNTestCase):

def _test_dropout(self, cls, input):
p = 0.2
input.fill_(1-p)
input.fill_(1 - p)

module = cls(p)
input_var = Variable(input, requires_grad=True)
output = module(input_var)
self.assertLess(abs(output.data.mean() - (1-p)), 0.05)
self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
output.backward(input)
self.assertLess(abs(input_var.grad.data.mean() - (1-p)), 0.05)
self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)

module = cls(p, True)
input_var = Variable(input.clone(), requires_grad=True)
output = module(input_var + 0)
self.assertLess(abs(output.data.mean() - (1-p)), 0.05)
self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
output.backward(input)
self.assertLess(abs(input_var.grad.data.mean() - (1-p)), 0.05)
self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)

# Check that these don't raise errors
module.__repr__()
@ -379,7 +385,9 @@ class TestNN(NNTestCase):
def test_parameters(self):
def num_params(module):
return len(list(module.parameters()))

class Net(nn.Module):

def __init__(self):
super(Net, self).__init__()
self.l1 = l
@ -394,6 +402,7 @@ class TestNN(NNTestCase):

def test_modules(self):
class Net(nn.Module):

def __init__(self):
super(Net, self).__init__()
self.l1 = l
@ -455,6 +464,7 @@ class TestNN(NNTestCase):
def test_non_leaf_parameters(self):
l1 = nn.Linear(10, 10)
l2 = nn.Linear(10, 10)

def assign_weight():
l2.weight = l1.weight + 2
self.assertRaises(TypeError, assign_weight)
@ -462,8 +472,8 @@ class TestNN(NNTestCase):
l2.weight = Parameter(torch.randn(10, 10))

def test_embedding_padding_idx(self):
embedding = nn.Embedding(10, 20, padding_idx = 0)
input = Variable(torch.LongTensor([[0,2,4,5],[4,3,0,9]]))
embedding = nn.Embedding(10, 20, padding_idx=0)
input = Variable(torch.LongTensor([[0, 2, 4, 5], [4, 3, 0, 9]]))
output = embedding(input)
self.assertEqual(output[0][0].sum().data[0], 0)
self.assertEqual(output[1][2].sum().data[0], 0)
@ -493,14 +503,14 @@ class TestNN(NNTestCase):
def expected_indices(dim):
if dim == 1:
return torch.DoubleTensor([1, 3])
lower_dim = expected_indices(dim-1)
lower_dim = expected_indices(dim - 1)
lower_dim = lower_dim.view(1, *lower_dim.size())
return torch.cat((lower_dim+4, lower_dim+12), 0)
return torch.cat((lower_dim + 4, lower_dim + 12), 0)

def expected_grad(dim):
if dim == 1:
return torch.DoubleTensor([0, 1, 0, 1])
lower_dim_grad = expected_grad(dim-1)
lower_dim_grad = expected_grad(dim - 1)
grad = lower_dim_grad.view(1, *lower_dim_grad.size())
zero = torch.zeros(grad.size())
return torch.cat((zero, grad, zero, grad), 0)
@ -671,7 +681,9 @@ class TestNN(NNTestCase):
def test_data_parallel_nested_output(self):
def fn(input):
return [input, (input.sin(), input.cos(), [input.add(1)]), input]

class Net(nn.Module):

def forward(self, input):
return fn(input)
i = Variable(torch.randn(2, 2).float().cuda(1))
@ -690,7 +702,9 @@ class TestNN(NNTestCase):
def test_data_parallel_nested_input(self):
def fn(input):
return input[1][0]

class Net(nn.Module):

def forward(self, input):
return fn(input)
i = Variable(torch.randn(20, 3).float().cuda(1))
@ -712,7 +726,7 @@ class TestNN(NNTestCase):
def test_state_dict(self):
l = nn.Linear(5, 5)
block = nn.Module()
block.conv=nn.Conv2d(3, 3, 3, bias=False)
block.conv = nn.Conv2d(3, 3, 3, bias=False)
net = nn.Module()
net.linear1 = l
net.linear2 = l
@ -781,6 +795,7 @@ class TestNN(NNTestCase):

def test_parameter_assignment(self):
l = nn.Linear(5, 5)

def num_params():
return len(list(l.parameters()))
self.assertEqual(num_params(), 2)
@ -814,9 +829,9 @@ class TestNN(NNTestCase):
# These sizes require huge cuDNN workspaces. Make sure we choose a
# reasonable algorithm that does not run out of memory
sizes = [
(1, 256, 109, 175),
(1, 256, 80, 128),
(1, 256, 120, 192),
(1, 256, 109, 175),
(1, 256, 80, 128),
(1, 256, 120, 192),
]
dtype = torch.cuda.FloatTensor

@ -887,7 +902,7 @@ class TestNN(NNTestCase):
small_t = torch.rand(1, 1, 5, 5)
for i in range(0, 4, 2):
for j in range(0, 4, 2):
small_t[:,:,i,j] = 100
small_t[:, :, i, j] = 100
output_small, indices_small = m(Variable(small_t))
for h in range(3, 10):
for w in range(3, 10):
@ -900,10 +915,11 @@ class TestNN(NNTestCase):
mu(output_small, indices_small, output_size=size)
else:
self.assertRaises(ValueError, lambda:
mu(output_small, indices_small, (h, w)))
mu(output_small, indices_small, (h, w)))

def test_container_copy(self):
class Model(nn.Module):

def __init__(self):
super(Model, self).__init__()
self.linear = nn.Linear(4, 5)
@ -955,7 +971,7 @@ class TestNN(NNTestCase):
for i in range(6):
hx, cx = lstm(input, (hx, cx))

(hx+cx).sum().backward()
(hx + cx).sum().backward()

@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
@default_tensor_type(torch.FloatTensor) # FIXME: just until torch.cuda.DoubleTensor.sum() implemented
@ -987,9 +1003,9 @@ class TestNN(NNTestCase):
output, hy = rnn(input, hx)
# FIXME this is because of a pytorch bug
if is_lstm:
fake_loss = 0*(hy[0] + hy[1]).sum()
fake_loss = 0 * (hy[0] + hy[1]).sum()
else:
fake_loss = 0*hy.sum()
fake_loss = 0 * hy.sum()

loss = output.sum() + fake_loss
loss.backward()
@ -1019,11 +1035,10 @@ class TestNN(NNTestCase):
for (cpu_weight, gpu_weight) in zip(cpu_layer_weight, gpu_layer_weight):
self.assertEqual(cpu_weight.grad.data, gpu_weight.grad.data, prec=5e-5)

for module in (nn.RNN, nn.LSTM, nn.GRU):
for bias in (True, False):
for bidirectional in (False, True):
for dropout in (0, 1): # Because of dropout randomness, can only compare 0 and 1
for dropout in (0, 1):  # Because of dropout randomness, can only compare 0 and 1
for batch_first in (False, True):
num_directions = 2 if bidirectional else 1
if batch_first:
@ -1038,7 +1053,7 @@ class TestNN(NNTestCase):
bias=bias,
dropout=dropout,
bidirectional=bidirectional,
batch_first = batch_first)
batch_first=batch_first)

outputs_cpu = forward_backward(
False, rnn, input_val, hx_val, rnn.all_weights)
@ -1049,7 +1064,7 @@ class TestNN(NNTestCase):
bias=bias,
dropout=dropout,
bidirectional=bidirectional,
batch_first = batch_first)
batch_first=batch_first)

outputs_gpu = forward_backward(
True, rnn_gpu, input_val, hx_val, rnn.all_weights)
@ -1087,8 +1102,8 @@ class TestNN(NNTestCase):
rnn.weight_hh_l0.data.fill_(1)
rnn.weight_ih_l1.data.fill_(1)
rnn.weight_hh_l1.data.fill_(1)
input = Variable(torch.Tensor(1,1,10).fill_(1))
hx = Variable(torch.Tensor(2,1,1000).fill_(0))
input = Variable(torch.Tensor(1, 1, 10).fill_(1))
hx = Variable(torch.Tensor(2, 1, 1000).fill_(0))
if cuda:
input = input.cuda()
hx = hx.cuda()
@ -1129,8 +1144,8 @@ class TestNN(NNTestCase):
rnn.train()
else:
rnn.eval()
input = Variable(torch.Tensor(1,1,100).uniform_())
hx = Variable(torch.Tensor(2,1,100).uniform_())
input = Variable(torch.Tensor(1, 1, 100).uniform_())
hx = Variable(torch.Tensor(2, 1, 100).uniform_())
if cuda:
input = input.cuda()
hx = hx.cuda()
@ -1185,8 +1200,8 @@ class TestNN(NNTestCase):
module = nn.BatchNorm1d(3).type(tp)
module.eval()

data = Variable(torch.rand(4,3).type(tp), requires_grad=True)
grad = torch.rand(4,3).type(tp)
data = Variable(torch.rand(4, 3).type(tp), requires_grad=True)
grad = torch.rand(4, 3).type(tp)

# 1st pass
res1 = module(data)
@ -1210,8 +1225,8 @@ def add_test(test):
raise RuntimeError('Found two tests with the same name: ' + test_name)
if hasattr(TestNN, cuda_test_name):
raise RuntimeError('Found two tests with the same name: ' + cuda_test_name)
setattr(TestNN, test_name, lambda self,test=test: test(self))
setattr(TestNN, cuda_test_name, lambda self,test=test: test.test_cuda(self))
setattr(TestNN, test_name, lambda self, test=test: test(self))
setattr(TestNN, cuda_test_name, lambda self, test=test: test.test_cuda(self))

new_module_tests = [
@ -1528,13 +1543,15 @@ new_module_tests = [
jacobian_input=False
),
dict(
constructor=lambda: nn.FractionalMaxPool2d(2, output_ratio=0.5, _random_samples=torch.DoubleTensor(1, 3, 2).uniform_()),
constructor=lambda: nn.FractionalMaxPool2d(
2, output_ratio=0.5, _random_samples=torch.DoubleTensor(1, 3, 2).uniform_()),
input_size=(1, 3, 5, 5),
fullname='FractionalMaxPool2d_ratio',
test_cuda=False
),
dict(
constructor=lambda: nn.FractionalMaxPool2d((2, 2), output_size=(4, 4), _random_samples=torch.DoubleTensor(1, 3, 2).uniform_()),
constructor=lambda: nn.FractionalMaxPool2d((2, 2), output_size=(
4, 4), _random_samples=torch.DoubleTensor(1, 3, 2).uniform_()),
input_size=(1, 3, 7, 7),
fullname='FractionalMaxPool2d_size',
test_cuda=False
@ -1596,6 +1613,7 @@ for test_params in criterion_tests:

class UnpoolingNet(nn.Module):

def __init__(self, pool, unpool):
super(UnpoolingNet, self).__init__()
self.pool = pool
@ -53,7 +53,7 @@ class TestOptim(TestCase):
for i in range(2000):
optimizer.step(eval)
old_fn(lambda _: (rosenbrock(params_t), drosenbrock(params_t)),
params_t, state)
params_t, state)
self.assertEqual(params.data, params_t)

self.assertLessEqual(params.data.dist(solution), initial_dist)
@ -128,8 +128,8 @@ class TestOptim(TestCase):
)
# non-contiguous parameters
self._test_basic_cases_template(
torch.randn(10, 5, 2)[...,0],
torch.randn(10, 2)[...,0],
torch.randn(10, 5, 2)[..., 0],
torch.randn(10, 2)[..., 0],
torch.randn(5),
constructor
)
@ -11,6 +11,7 @@ SparseTensor = sparse.DoubleTensor

class TestSparse(TestCase):

@staticmethod
def _gen_sparse(d, nnz, with_size):
v = torch.randn(nnz)
@ -19,7 +20,7 @@ class TestSparse(TestCase):
x = SparseTensor(i, v)
else:
i = torch.rand(d, nnz) * \
torch.Tensor(with_size).repeat(nnz, 1).transpose(0, 1)
torch.Tensor(with_size).repeat(nnz, 1).transpose(0, 1)
i = i.type(torch.LongTensor)
x = SparseTensor(i, v, torch.Size(with_size))

@ -74,13 +75,13 @@ class TestSparse(TestCase):

def test_contig(self):
i = torch.LongTensor([
[1, 0, 35, 14, 39, 6, 71, 66, 40, 27],
[1, 0, 35, 14, 39, 6, 71, 66, 40, 27],
[92, 31, 62, 50, 22, 65, 89, 74, 56, 34],
])
v = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
x = SparseTensor(i, v, torch.Size([100, 100]))
exp_i = torch.LongTensor([
[0, 1, 6, 14, 27, 35, 39, 40, 66, 71],
[0, 1, 6, 14, 27, 35, 39, 40, 66, 71],
[31, 92, 65, 50, 34, 62, 22, 56, 74, 89],
])
exp_v = torch.Tensor([2, 1, 6, 4, 10, 3, 5, 9, 8, 7])

File diff suppressed because it is too large
@ -28,7 +28,9 @@ try:
except ImportError:
HAS_CFFI = False

class SimplePlugin(Plugin):

def __init__(self, interval):
super(SimplePlugin, self).__init__(interval)
self.trainer = None
@ -58,6 +60,7 @@ class SimplePlugin(Plugin):

class ModelMock(object):

def __init__(self):
self.num_calls = 0
self.output = Variable(torch.ones(1, 1), requires_grad=True)
@ -68,6 +71,7 @@ class ModelMock(object):

class CriterionMock(object):

def __init__(self):
self.num_calls = 0

@ -95,6 +99,7 @@ class OptimizerMock(object):

class DatasetMock(object):

def __iter__(self):
for i in range(10):
yield torch.randn(2, 10), torch.randperm(10)[:2]
@ -183,6 +188,7 @@ class TestTrainer(TestCase):

test_dir = os.path.abspath(os.path.dirname(str(__file__)))

class TestFFI(TestCase):

def setUp(self):
@ -196,13 +202,13 @@ class TestFFI(TestCase):
@unittest.skipIf(not HAS_CFFI, "ffi tests require cffi package")
def test_cpu(self):
compile_extension(
name='test_extensions.cpulib',
header=test_dir + '/ffi/src/cpu/lib.h',
sources=[
test_dir + '/ffi/src/cpu/lib1.c',
test_dir + '/ffi/src/cpu/lib2.c',
],
verbose=False,
name='test_extensions.cpulib',
header=test_dir + '/ffi/src/cpu/lib.h',
sources=[
test_dir + '/ffi/src/cpu/lib1.c',
test_dir + '/ffi/src/cpu/lib2.c',
],
verbose=False,
)
from test_extensions import cpulib
tensor = torch.ones(2, 2).float()
@ -217,20 +223,20 @@ class TestFFI(TestCase):
self.assertIs(type(f), float)

self.assertRaises(TypeError,
lambda: cpulib.good_func(tensor.double(), 2, 1.5))
lambda: cpulib.good_func(tensor.double(), 2, 1.5))
self.assertRaises(torch.FatalError,
lambda: cpulib.bad_func(tensor, 2, 1.5))
lambda: cpulib.bad_func(tensor, 2, 1.5))

@unittest.skipIf(not HAS_CFFI or not HAS_CUDA, "ffi tests require cffi package")
def test_gpu(self):
compile_extension(
name='gpulib',
header=test_dir + '/ffi/src/cuda/cudalib.h',
sources=[
test_dir + '/ffi/src/cuda/cudalib.c',
],
with_cuda=True,
verbose=False,
name='gpulib',
header=test_dir + '/ffi/src/cuda/cudalib.h',
sources=[
test_dir + '/ffi/src/cuda/cudalib.c',
],
with_cuda=True,
verbose=False,
)
import gpulib
tensor = torch.ones(2, 2).float()
@ -243,9 +249,9 @@ class TestFFI(TestCase):
self.assertEqual(ctensor, torch.ones(2, 2) * 2 + 1.5)

self.assertRaises(TypeError,
lambda: gpulib.cuda_func(tensor, 2, 1.5))
lambda: gpulib.cuda_func(tensor, 2, 1.5))
self.assertRaises(TypeError,
lambda: gpulib.cuda_func(ctensor.storage(), 2, 1.5))
lambda: gpulib.cuda_func(ctensor.storage(), 2, 1.5))

class TestLuaReader(TestCase):
@ -320,7 +326,7 @@ class TestLuaReader(TestCase):
cls._download_data(test_file_path)
except urllib.URLError as e:
warnings.warn(("Couldn't download the test file for TestLuaReader! "
"Tests will be incomplete!"), RuntimeWarning)
"Tests will be incomplete!"), RuntimeWarning)
return

tests = load_lua(test_file_path)
@ -20,13 +20,14 @@ class cwrap(object):
""")

OPTION_CODE_TEMPLATE = [
'$call',
'$return_result',
'$call',
'$return_result',
]

FUNCTION_CALL_TEMPLATE = Template("$capture_result$cname($arg_unpack);")

DEFAULT_PLUGIN_CLASSES = [ArgcountChecker, ConstantArguments, OptionalArguments, ArgumentReferences, BeforeAfterCall, ReturnArguments, GILRelease]
DEFAULT_PLUGIN_CLASSES = [ArgcountChecker, ConstantArguments, OptionalArguments,
ArgumentReferences, BeforeAfterCall, ReturnArguments, GILRelease]

def __init__(self, source, destination=None, plugins=[], default_plugins=True):
if destination is None:
@ -87,7 +88,7 @@ class cwrap(object):
with open(fname, 'r') as f:
included = f.read().split('\n')
# insert it into lines at position i+1
lines[i+1:i+1] = included
lines[i + 1:i + 1] = included
else:
output.append(line)
i += 1
@ -136,10 +137,10 @@ class cwrap(object):
return fallback(*args)

def get_type_check(self, arg, option):
return self.search_plugins('get_type_check', (arg, option), lambda arg,_: None)
return self.search_plugins('get_type_check', (arg, option), lambda arg, _: None)

def get_type_unpack(self, arg, option):
return self.search_plugins('get_type_unpack', (arg, option), lambda arg,_: None)
return self.search_plugins('get_type_unpack', (arg, option), lambda arg, _: None)

def get_return_wrapper(self, option):
return self.search_plugins('get_return_wrapper', (option,), lambda _: self.RETURN_WRAPPERS[option['return']])
@ -193,14 +194,14 @@ class cwrap(object):

# Generate checks
arg_checks = self.map_selected_arguments('get_type_check',
'process_single_check', option, checked_args)
'process_single_check', option, checked_args)
arg_checks = ' &&\n '.join(arg_checks)
for plugin in self.plugins:
arg_checks = plugin.process_all_checks(arg_checks, option)

# Generate unpacks
arg_unpack = self.map_selected_arguments('get_type_unpack',
'process_single_unpack', option, option['arguments'])
'process_single_unpack', option, option['arguments'])
arg_unpack = ', '.join(arg_unpack)
for plugin in self.plugins:
arg_unpack = plugin.process_all_unpacks(arg_unpack, option)
@ -209,16 +210,16 @@ class cwrap(object):
try:
return_result = self.get_return_wrapper(option).substitute()
call = self.FUNCTION_CALL_TEMPLATE.substitute(capture_result='',
cname=option['cname'], arg_unpack=arg_unpack)
cname=option['cname'], arg_unpack=arg_unpack)
except KeyError:
return_result = self.get_return_wrapper(option).substitute(result='__result')
call = self.FUNCTION_CALL_TEMPLATE.substitute(capture_result=(option['return'] + ' __result = '),
cname=option['cname'], arg_unpack=arg_unpack)
cname=option['cname'], arg_unpack=arg_unpack)

code_template = deepcopy(self.OPTION_CODE_TEMPLATE)
for plugin in self.plugins:
code_template = plugin.process_option_code_template(code_template,
option)
option)
code_template = Template('\n'.join(code_template))
code = code_template.substitute(call=call, return_result=return_result)
code_lines = map(lambda s: s.strip(), code.split('\n'))
@ -1,5 +1,6 @@
from . import CWrapPlugin

class ArgcountChecker(CWrapPlugin):

def process_all_checks(self, checks, option):

@ -1,5 +1,6 @@
from . import CWrapPlugin

class ArgcountSortPlugin(CWrapPlugin):

def __init__(self, descending=True):
@ -11,4 +12,3 @@ class ArgcountSortPlugin(CWrapPlugin):
for declaration in declarations:
declaration['options'].sort(key=num_checked_args, reverse=self.descending)
return declarations

@ -1,6 +1,7 @@
from . import CWrapPlugin
from string import Template

class ArgumentReferences(CWrapPlugin):

def initialize(self, cwrap):

@ -1,5 +1,6 @@
from . import CWrapPlugin

class AutoGPU(CWrapPlugin):

def __init__(self, has_self=True, condition=None):

@ -1,6 +1,7 @@
from . import CWrapPlugin
from string import Template

class BeforeAfterCall(CWrapPlugin):

def initialize(self, cwrap):
@ -13,7 +14,7 @@ class BeforeAfterCall(CWrapPlugin):
if '$' in prepend_str:
before_call_template = Template(option[name])
args = {'arg' + str(i): self.cwrap.get_arg_accessor(arg, option) for i, arg
in enumerate(option['arguments'])}
in enumerate(option['arguments'])}
prepend_str = before_call_template.substitute(args)
template.insert(offset, prepend_str)

@ -23,5 +24,5 @@ class BeforeAfterCall(CWrapPlugin):
self.insert_snippet(template, option, call_idx, 'before_call')
# call position might have changed
call_idx = template.index('$call')
self.insert_snippet(template, option, call_idx+1, 'after_call')
self.insert_snippet(template, option, call_idx + 1, 'after_call')
return template

@ -1,6 +1,7 @@
from . import CWrapPlugin
from string import Template

class BoolOption(CWrapPlugin):

UNPACK_TEMPLATE = Template('$arg == Py_True ? $if_true : $if_false')
@ -16,4 +17,3 @@ class BoolOption(CWrapPlugin):
if self.is_bool_option(arg):
return Template(self.UNPACK_TEMPLATE.safe_substitute(
if_true=arg['if_true'], if_false=arg['if_false']))

@ -1,6 +1,7 @@
from . import CWrapPlugin
from string import Template

class ConstantArguments(CWrapPlugin):

def process_declarations(self, declarations):
@ -18,5 +19,3 @@ class ConstantArguments(CWrapPlugin):
def get_arg_accessor(self, arg, option):
if arg['type'] == 'CONSTANT':
return arg['name']

@ -3,30 +3,31 @@ from copy import deepcopy
from . import CWrapPlugin
from itertools import product

class CuDNNPlugin(CWrapPlugin):

TYPE_UNPACK = {
'THTensor*': Template('((THPVoidTensor*)$arg)->cdata'),
'int': Template('THPUtils_unpackLong($arg)'),
'THTensor*': Template('((THPVoidTensor*)$arg)->cdata'),
'int': Template('THPUtils_unpackLong($arg)'),
'std::vector<int>': Template('THPUtils_unpackIntTuple($arg)'),
'cudnnDataType_t': Template('$arg'),
'cudnnHandle_t': Template('$arg'),
'Convolution*': Template('(Convolution*)THPWrapper_get($arg)'),
'bool': Template('$arg == Py_True'),
'double': Template('THPDoubleUtils_unpackReal($arg)'),
'cudnnDataType_t': Template('$arg'),
'cudnnHandle_t': Template('$arg'),
'Convolution*': Template('(Convolution*)THPWrapper_get($arg)'),
'bool': Template('$arg == Py_True'),
'double': Template('THPDoubleUtils_unpackReal($arg)'),
}

TYPE_CHECK = {
'Convolution*': Template('THPWrapper_check($arg)'),
'THTensor*': Template('(PyObject*)Py_TYPE($arg) == tensorClass'),
'int': Template('THPUtils_checkLong($arg)'),
'Convolution*': Template('THPWrapper_check($arg)'),
'THTensor*': Template('(PyObject*)Py_TYPE($arg) == tensorClass'),
'int': Template('THPUtils_checkLong($arg)'),
'std::vector<int>': Template('THPUtils_checkIntTuple($arg)'),
'bool': Template('PyBool_Check($arg)'),
'double': Template('THPDoubleUtils_checkReal($arg)'),
'bool': Template('PyBool_Check($arg)'),
'double': Template('THPDoubleUtils_checkReal($arg)'),
}

RETURN_WRAPPER = {
'Convolution*': Template('return THPWrapper_New($result, [](void* arg) { delete (Convolution*)arg; });'),
'Convolution*': Template('return THPWrapper_New($result, [](void* arg) { delete (Convolution*)arg; });'),
}

METHODS_DECLARATION = Template("""
@ -151,8 +152,8 @@ static PyObject * $name(PyObject *self, PyObject *args, PyObject *kwargs)
if not declaration.get('only_register'):
extra_flags += ' | METH_KEYWORDS'
entry = Template(' {"$python_name", (PyCFunction)$name, METH_VARARGS$extra_flags, NULL},\n').substitute(
python_name=declaration['python_name'], name=declaration['name'], extra_flags=extra_flags
)
python_name=declaration['python_name'], name=declaration['name'], extra_flags=extra_flags
)
if 'defined_if' in declaration:
entry = self.preprocessor_guard(entry, declaration['defined_if'])
methods += entry

@ -1,6 +1,7 @@
from . import CWrapPlugin
from string import Template

class GILRelease(CWrapPlugin):

OPTION_START = [
@ -24,6 +25,5 @@ class GILRelease(CWrapPlugin):
def process_option_code_template(self, template, option):
call_idx = template.index('$call')
template.insert(call_idx, self.BEFORE_CALL)
template.insert(call_idx+2, self.AFTER_CALL)
template.insert(call_idx + 2, self.AFTER_CALL)
return self.OPTION_START + template + self.OPTION_END

@ -1,6 +1,7 @@
from . import CWrapPlugin
from string import Template

class KwargsPlugin(CWrapPlugin):

ACCESSOR_TEMPLATE = Template('(__tuplecount > $idx ? PyTuple_GET_ITEM(args, $idx) : __kw_$name)')
@ -53,7 +54,8 @@ class KwargsPlugin(CWrapPlugin):
seen_args.add(name)
args.append(name)
declarations = '\n '.join(['PyObject *__kw_{} = NULL;'.format(name) for name in args])
lookups = '\n '.join(['__kw_{name} = PyDict_GetItemString(kwargs, "{name}");'.format(name=name) for name in args])
lookups = '\n '.join(
['__kw_{name} = PyDict_GetItemString(kwargs, "{name}");'.format(name=name) for name in args])
start_idx = code.find('{') + 1
new_code = self.WRAPPER_TEMPLATE.substitute(declarations=declarations, lookups=lookups)
return code[:start_idx] + new_code + code[start_idx:]

@ -1,6 +1,8 @@
from . import CWrapPlugin

class NullableArguments(CWrapPlugin):

def process_single_check(self, code, arg, arg_accessor):
if 'nullable' in arg and arg['nullable']:
return '({} || {} == Py_None)'.format(code, arg_accessor)
@ -10,5 +12,3 @@ class NullableArguments(CWrapPlugin):
if 'nullable' in arg and arg['nullable']:
return '({} == Py_None ? NULL : {})'.format(arg_accessor, code)
return code

@ -2,6 +2,7 @@ from copy import deepcopy
from . import CWrapPlugin
from itertools import product

class OptionalArguments(CWrapPlugin):

def process_declarations(self, declarations):
@ -32,20 +33,20 @@ class OptionalArguments(CWrapPlugin):
else:
kwarg_only_count = -kwarg_only_count
arg_signature = '#'.join(
arg['type']
for arg in option['arguments'][:kwarg_only_count]
if not arg.get('ignore_check'))
arg['type']
for arg in option['arguments'][:kwarg_only_count]
if not arg.get('ignore_check'))
if kwarg_only_count is None:
return arg_signature
kwarg_only_signature = '#'.join(
arg['name'] + '#' + arg['type']
for arg in option['arguments'][kwarg_only_count:]
if not arg.get('ignore_check'))
arg['name'] + '#' + arg['type']
for arg in option['arguments'][kwarg_only_count:]
if not arg.get('ignore_check'))
return arg_signature + "#-#" + kwarg_only_signature
seen_signatures = set()
unique = []
for option in options:
for num_kwarg_only in range(0, len(option['arguments'])+1):
for num_kwarg_only in range(0, len(option['arguments']) + 1):
sig = signature(option, num_kwarg_only)
if sig not in seen_signatures:
if num_kwarg_only > 0:
@ -55,4 +56,3 @@ class OptionalArguments(CWrapPlugin):
seen_signatures.add(sig)
break
return unique
@ -1,9 +1,10 @@
from . import CWrapPlugin
from string import Template

class ReturnArguments(CWrapPlugin):
ARGUMENT_RETURN_TEMPLATE = Template("Py_INCREF($arg);\nreturn (PyObject*)($arg);")
TUPLE_RETURN_TEMPLATE = Template("return PyTuple_Pack($num_args, $args);")
ARGUMENT_RETURN_TEMPLATE = Template("Py_INCREF($arg);\nreturn (PyObject*)($arg);")
TUPLE_RETURN_TEMPLATE = Template("return PyTuple_Pack($num_args, $args);")

def initialize(self, cwrap):
self.cwrap = cwrap
@ -26,41 +26,41 @@ $METHODS
class StandaloneExtension(CWrapPlugin):

TYPE_UNPACK = {
'THFloatTensor*': Template('THPFloatTensor_CData((THPFloatTensor*)$arg)'),
'THDoubleTensor*': Template('THPDoubleTensor_CData((THPDoubleTensor*)$arg)'),
'THLongTensor*': Template('THPLongTensor_CData((THPLongTensor*)$arg)'),
'THIntTensor*': Template('THPIntTensor_CData((THPIntTensor*)$arg)'),
'THFloatTensor*': Template('THPFloatTensor_CData((THPFloatTensor*)$arg)'),
'THDoubleTensor*': Template('THPDoubleTensor_CData((THPDoubleTensor*)$arg)'),
'THLongTensor*': Template('THPLongTensor_CData((THPLongTensor*)$arg)'),
'THIntTensor*': Template('THPIntTensor_CData((THPIntTensor*)$arg)'),
'THCudaHalfTensor*': Template('THCPHalfTensor_CData((THCPHalfTensor*)$arg)'),
'THCudaTensor*': Template('THCPFloatTensor_CData((THCPFloatTensor*)$arg)'),
'THCudaTensor*': Template('THCPFloatTensor_CData((THCPFloatTensor*)$arg)'),
'THCudaDoubleTensor*': Template('THCPDoubleTensor_CData((THCPDoubleTensor*)$arg)'),
'THCudaLongTensor*': Template('THCPLongTensor_CData((THCPLongTensor*)$arg)'),
'half': Template('THPHalfUtils_unpackReal($arg)'),
'float': Template('THPFloatUtils_unpackReal($arg)'),
'double': Template('THPDoubleUtils_unpackReal($arg)'),
'bool': Template('($arg == Py_True ? true : false)'),
'int': Template('THPUtils_unpackLong($arg)'),
'long': Template('THPUtils_unpackLong($arg)'),
'void*': Template('(void*)THPUtils_unpackLong($arg)'),
'THGenerator*': Template('THPGenerator_CData((THPGenerator*)$arg)'),
'half': Template('THPHalfUtils_unpackReal($arg)'),
'float': Template('THPFloatUtils_unpackReal($arg)'),
'double': Template('THPDoubleUtils_unpackReal($arg)'),
'bool': Template('($arg == Py_True ? true : false)'),
'int': Template('THPUtils_unpackLong($arg)'),
'long': Template('THPUtils_unpackLong($arg)'),
'void*': Template('(void*)THPUtils_unpackLong($arg)'),
'THGenerator*': Template('THPGenerator_CData((THPGenerator*)$arg)'),
}

TYPE_CHECK = {
'THDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THPDoubleTensorClass'),
'THFloatTensor*': Template('(PyObject*)Py_TYPE($arg) == THPFloatTensorClass'),
'THLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THPLongTensorClass'),
'THIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THPIntTensorClass'),
'THDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THPDoubleTensorClass'),
'THFloatTensor*': Template('(PyObject*)Py_TYPE($arg) == THPFloatTensorClass'),
'THLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THPLongTensorClass'),
'THIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THPIntTensorClass'),
'THCudaHalfTensor*': Template('THCPHalfTensor_Check($arg)'),
'THCudaTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPFloatTensorClass'),
'THCudaTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPFloatTensorClass'),
'THCudaDoubleTensor*': Template('THCPDoubleTensor_Check($arg)'),
'THCudaLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPLongTensorClass'),
'half': Template('THPHalfUtils_checkReal($arg)'),
'float': Template('THPFloatUtils_checkReal($arg)'),
'double': Template('THPDoubleUtils_checkReal($arg)'),
'bool': Template('PyBool_Check($arg)'),
'int': Template('THPUtils_checkLong($arg)'),
'long': Template('THPUtils_checkLong($arg)'),
'void*': Template('THPUtils_checkLong($arg)'),
'THGenerator*': Template('(PyObject*)Py_TYPE($arg) == THPGeneratorClass'),
'half': Template('THPHalfUtils_checkReal($arg)'),
'float': Template('THPFloatUtils_checkReal($arg)'),
'double': Template('THPDoubleUtils_checkReal($arg)'),
'bool': Template('PyBool_Check($arg)'),
'int': Template('THPUtils_checkLong($arg)'),
'long': Template('THPUtils_checkLong($arg)'),
'void*': Template('THPUtils_checkLong($arg)'),
'THGenerator*': Template('(PyObject*)Py_TYPE($arg) == THPGeneratorClass'),
}

WRAPPER_TEMPLATE = Template("""
@ -131,6 +131,7 @@ PyObject * $name(PyObject *_unused, PyObject *args)

def get_wrapper_template(self, declaration):
arg_desc = []

def describe_arg(arg):
desc = self.TYPE_NAMES[arg['type']] + ' ' + arg['name']
if arg.get('nullable'):
@ -138,8 +139,8 @@ PyObject * $name(PyObject *_unused, PyObject *args)
return desc
for option in declaration['options']:
option_desc = [describe_arg(arg)
for arg in option['arguments']
if not arg.get('ignore_check', False)]
for arg in option['arguments']
if not arg.get('ignore_check', False)]
if option_desc:
arg_desc.append('({})'.format(', '.join(option_desc)))
else:
@ -4,85 +4,86 @@ from . import CWrapPlugin
|
||||
from itertools import product, chain
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class THPPlugin(CWrapPlugin):
|
||||
|
||||
TYPE_UNPACK = {
|
||||
'THFloatTensor*': Template('((THPFloatTensor*)$arg)->cdata'),
|
||||
'THDoubleTensor*': Template('((THPDoubleTensor*)$arg)->cdata'),
|
||||
'THLongTensor*': Template('((THPLongTensor*)$arg)->cdata'),
|
||||
'THIntTensor*': Template('((THPIntTensor*)$arg)->cdata'),
|
||||
'THTensor*': Template('((THPTensor*)$arg)->cdata'),
|
||||
'THBoolTensor*': Template('((THPBoolTensor*)$arg)->cdata'),
|
||||
'THIndexTensor*': Template('((THPIndexTensor*)$arg)->cdata'),
|
||||
'THFloatTensor*': Template('((THPFloatTensor*)$arg)->cdata'),
|
||||
'THDoubleTensor*': Template('((THPDoubleTensor*)$arg)->cdata'),
|
||||
'THLongTensor*': Template('((THPLongTensor*)$arg)->cdata'),
|
||||
'THIntTensor*': Template('((THPIntTensor*)$arg)->cdata'),
|
||||
'THTensor*': Template('((THPTensor*)$arg)->cdata'),
|
||||
'THBoolTensor*': Template('((THPBoolTensor*)$arg)->cdata'),
|
||||
'THIndexTensor*': Template('((THPIndexTensor*)$arg)->cdata'),
|
||||
|
||||
'THSFloatTensor*': Template('((THSPFloatTensor*)$arg)->cdata'),
|
||||
'THSFloatTensor*': Template('((THSPFloatTensor*)$arg)->cdata'),
|
||||
'THSDoubleTensor*': Template('((THSPDoubleTensor*)$arg)->cdata'),
|
||||
'THSLongTensor*': Template('((THSPLongTensor*)$arg)->cdata'),
|
||||
'THSIntTensor*': Template('((THSPIntTensor*)$arg)->cdata'),
|
||||
'THSTensor*': Template('((THSPTensor*)$arg)->cdata'),
|
||||
'THSBoolTensor*': Template('((THSPBoolTensor*)$arg)->cdata'),
|
||||
'THSIndexTensor*': Template('((THSPIndexTensor*)$arg)->cdata'),
|
||||
'THSLongTensor*': Template('((THSPLongTensor*)$arg)->cdata'),
|
||||
'THSIntTensor*': Template('((THSPIntTensor*)$arg)->cdata'),
|
||||
'THSTensor*': Template('((THSPTensor*)$arg)->cdata'),
|
||||
'THSBoolTensor*': Template('((THSPBoolTensor*)$arg)->cdata'),
|
||||
'THSIndexTensor*': Template('((THSPIndexTensor*)$arg)->cdata'),
|
||||
|
||||
'THLongStorage*': Template('((THPLongStorage*)$arg)->cdata'),
|
||||
'THStorage*': Template('((THPStorage*)$arg)->cdata'),
|
||||
'THGenerator*': Template('((THPGenerator*)$arg)->cdata'),
|
||||
'THSize*': Template('__size.get()'),
|
||||
'THStride*': Template('__stride.get()'),
|
||||
'void*': Template('THPUtils_unpackLong($arg)'),
|
||||
'long': Template('THPUtils_unpackLong($arg)'),
|
||||
'int': Template('THPUtils_unpackLong($arg)'),
|
||||
'bool': Template('($arg == Py_True ? true : false)'),
|
||||
'float': Template('THPFloatUtils_unpackReal($arg)'),
|
||||
'double': Template('THPDoubleUtils_unpackReal($arg)'),
|
||||
'real': Template('THPUtils_(unpackReal)($arg)'),
|
||||
'accreal': Template('THPUtils_(unpackAccreal)($arg)'),
|
||||
'THLongStorage*': Template('((THPLongStorage*)$arg)->cdata'),
|
||||
'THStorage*': Template('((THPStorage*)$arg)->cdata'),
|
||||
'THGenerator*': Template('((THPGenerator*)$arg)->cdata'),
|
||||
'THSize*': Template('__size.get()'),
|
||||
'THStride*': Template('__stride.get()'),
|
||||
'void*': Template('THPUtils_unpackLong($arg)'),
|
||||
'long': Template('THPUtils_unpackLong($arg)'),
|
||||
'int': Template('THPUtils_unpackLong($arg)'),
|
||||
'bool': Template('($arg == Py_True ? true : false)'),
|
||||
'float': Template('THPFloatUtils_unpackReal($arg)'),
|
||||
'double': Template('THPDoubleUtils_unpackReal($arg)'),
|
||||
'real': Template('THPUtils_(unpackReal)($arg)'),
|
||||
'accreal': Template('THPUtils_(unpackAccreal)($arg)'),
|
||||
}
|
||||
|
||||
TYPE_CHECK = {
|
||||
'THDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THPDoubleTensorClass'),
|
||||
'THFloatTensor*': Template('(PyObject*)Py_TYPE($arg) == THPFloatTensorClass'),
|
||||
'THLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THPLongTensorClass'),
|
||||
'THIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THPIntTensorClass'),
|
||||
'THCudaTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPFloatTensorClass'),
|
||||
'THTensor*': Template('(PyObject*)Py_TYPE($arg) == THPTensorClass'),
|
||||
'THBoolTensor*': Template('(PyObject*)Py_TYPE($arg) == THPBoolTensorClass'),
|
||||
'THIndexTensor*': Template('(PyObject*)Py_TYPE($arg) == THPIndexTensorClass'),
|
||||
'THDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THPDoubleTensorClass'),
|
||||
'THFloatTensor*': Template('(PyObject*)Py_TYPE($arg) == THPFloatTensorClass'),
|
||||
'THLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THPLongTensorClass'),
|
||||
'THIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THPIntTensorClass'),
|
||||
'THCudaTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPFloatTensorClass'),
|
||||
'THTensor*': Template('(PyObject*)Py_TYPE($arg) == THPTensorClass'),
|
||||
'THBoolTensor*': Template('(PyObject*)Py_TYPE($arg) == THPBoolTensorClass'),
|
||||
'THIndexTensor*': Template('(PyObject*)Py_TYPE($arg) == THPIndexTensorClass'),
|
||||
|
||||
'THSDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPDoubleTensorClass'),
|
||||
'THSFloatTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPFloatTensorClass'),
|
||||
'THSLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPLongTensorClass'),
|
||||
'THSIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPIntTensorClass'),
|
||||
'THSTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPTensorClass'),
|
||||
'THSBoolTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPBoolTensorClass'),
|
||||
'THSIndexTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPIndexTensorClass'),
|
||||
'THSFloatTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPFloatTensorClass'),
|
||||
'THSLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPLongTensorClass'),
|
||||
'THSIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPIntTensorClass'),
|
||||
'THSTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPTensorClass'),
|
||||
'THSBoolTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPBoolTensorClass'),
|
||||
'THSIndexTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPIndexTensorClass'),
|
||||
|
||||
'THLongStorage*': Template('(PyObject*)Py_TYPE($arg) == THPLongStorageClass'),
|
||||
'THStorage*': Template('(PyObject*)Py_TYPE($arg) == THPStorageClass'),
|
||||
'THGenerator*': Template('(PyObject*)Py_TYPE($arg) == THPGeneratorClass'),
|
||||
'THSize*': Template('THPUtils_tryUnpackLongs($arg, __size)'),
|
||||
'THStride*': Template('THPUtils_tryUnpackLongs($arg, __stride)'),
|
||||
'void*': Template('THPUtils_checkLong($arg)'),
|
||||
'long': Template('THPUtils_checkLong($arg)'),
|
||||
'int': Template('THPUtils_checkLong($arg)'),
|
||||
'bool': Template('PyBool_Check($arg)'),
|
||||
'float': Template('THPFloatUtils_checkReal($arg)'),
|
||||
'double': Template('THPDoubleUtils_checkReal($arg)'),
|
||||
'real': Template('THPUtils_(checkReal)($arg)'),
|
||||
'accreal': Template('THPUtils_(checkAccreal)($arg)'),
|
||||
'THLongStorage*': Template('(PyObject*)Py_TYPE($arg) == THPLongStorageClass'),
|
||||
'THStorage*': Template('(PyObject*)Py_TYPE($arg) == THPStorageClass'),
|
||||
'THGenerator*': Template('(PyObject*)Py_TYPE($arg) == THPGeneratorClass'),
|
||||
'THSize*': Template('THPUtils_tryUnpackLongs($arg, __size)'),
|
||||
'THStride*': Template('THPUtils_tryUnpackLongs($arg, __stride)'),
|
||||
'void*': Template('THPUtils_checkLong($arg)'),
|
||||
'long': Template('THPUtils_checkLong($arg)'),
|
||||
'int': Template('THPUtils_checkLong($arg)'),
|
||||
'bool': Template('PyBool_Check($arg)'),
|
||||
'float': Template('THPFloatUtils_checkReal($arg)'),
|
||||
'double': Template('THPDoubleUtils_checkReal($arg)'),
|
||||
'real': Template('THPUtils_(checkReal)($arg)'),
|
||||
'accreal': Template('THPUtils_(checkAccreal)($arg)'),
|
||||
}

SIZE_VARARG_CHECK = Template('THPUtils_tryUnpackLongVarArgs(args, $idx, __size)')

RETURN_WRAPPER = {
    'THTensor*': Template('return THPTensor_(New)($result);'),
    'THSTensor*': Template('return THSPTensor_(New)($result);'),
    'THLongTensor*': Template('return THPLongTensor_New($result);'),
    'THLongStorage*': Template('return THPLongStorage_New($result);'),
    # TODO: make it smarter - it should return python long if result doesn't fit into an int
    'long': Template('return PyInt_FromLong($result);'),
    'accreal': Template('return THPUtils_(newAccreal)($result);'),
    'self': Template('Py_INCREF(self);\nreturn (PyObject*)self;'),
    'real': Template('return THPUtils_(newReal)($result);'),
}

TENSOR_METHODS_DECLARATION = Template("""
@ -138,13 +139,13 @@ ${cpu}
    return Template(code)


ALLOCATE_TYPE = {
    'THTensor*': _allocate('', ALLOCATE_TMPL),
    'THLongTensor*': _allocate('Long', ALLOCATE_TMPL),
    'THIntTensor*': _allocate('Int', ALLOCATE_TMPL),
    'THBoolTensor*': _allocate('Byte', ALLOCATE_TMPL, ALLOCATE_CUDA),
    'THIndexTensor*': _allocate('Long', ALLOCATE_TMPL, ALLOCATE_CUDA),

    'THSTensor*': _allocate('', ALLOCATE_TMPL, sparse=True),
}
|
||||
|
||||
TYPE_NAMES = {
|
||||
@ -205,7 +206,7 @@ ${cpu}
|
||||
if len(output_args) > 1:
|
||||
out_type = 'tuple['
|
||||
out_type += ', '.join(
|
||||
self.TYPE_NAMES[arg['type']] for arg in output_args)
|
||||
self.TYPE_NAMES[arg['type']] for arg in output_args)
|
||||
out_type += ']'
|
||||
option_desc += ['#' + out_type + ' out']
|
||||
else:
|
||||
@ -287,7 +288,7 @@ ${cpu}
|
||||
if not output_provided:
|
||||
arg['ignore_check'] = True
|
||||
else:
|
||||
option_copy['argcount_offset'] = -len(out_idx) + 1
|
||||
option_copy['argcount_offset'] = -len(out_idx) + 1
|
||||
arg['no_kwargs'] = True
|
||||
arg['no_idx'] = True
|
||||
new_options.append(option_copy)
|
||||
@ -345,7 +346,6 @@ ${cpu}
|
||||
if arg['name'] == 'self':
|
||||
arg['ignore_check'] = True
|
||||
|
||||
|
||||
declarations = [d for d in declarations if not d.get('only_stateless', False)]
|
||||
self.declarations.extend(filter(lambda x: not x.get('only_stateless', False), register_only))
|
||||
self.stateless_declarations.extend(filter(lambda x: x.get('only_stateless', False), register_only))
|
||||
@ -377,9 +377,9 @@ ${cpu}
|
||||
if declaration.get('override_method_flags'):
|
||||
flags = declaration['override_method_flags']
|
||||
entry = Template(' {"$python_name", (PyCFunction)$name, $flags, $docstring},\n').substitute(
|
||||
python_name=declaration['python_name'], name=declaration['name'], flags=flags,
|
||||
docstring=declaration.get('docstring_var', 'NULL')
|
||||
)
|
||||
python_name=declaration['python_name'], name=declaration['name'], flags=flags,
|
||||
docstring=declaration.get('docstring_var', 'NULL')
|
||||
)
|
||||
if 'defined_if' in declaration:
|
||||
entry = self.preprocessor_guard(entry, declaration['defined_if'])
|
||||
tensor_methods += entry
|
||||
@ -401,7 +401,7 @@ ${cpu}
|
||||
)
|
||||
|
||||
def preprocessor_guard(self, code, condition):
|
||||
return '#if ' + condition + '\n' + code + '#endif\n'
|
||||
return '#if ' + condition + '\n' + code + '#endif\n'
|
||||
|
||||
def process_wrapper(self, code, declaration):
|
||||
if 'defined_if' in declaration:
|
||||
@ -419,7 +419,7 @@ ${cpu}
|
||||
if option['output_count'] > 1:
|
||||
checks += "PyTuple_Check(__out) &&\n" + indent
|
||||
length_check = "PyTuple_GET_SIZE(__out) == {} &&\n".format(
|
||||
option['output_count'])
|
||||
option['output_count'])
|
||||
checks += length_check + indent
|
||||
code = checks + code
|
||||
else:
|
||||
@ -443,13 +443,13 @@ ${cpu}
|
||||
def generate_docstrings_cpp(self):
|
||||
template = Template('char* $name = "$content";')
|
||||
return '\n\n'.join(
|
||||
template.substitute(name=decl['docstring_var'], content=decl['docstring_content'])
|
||||
for decl in chain(self.declarations, self.stateless_declarations)
|
||||
if 'docstring_var' in decl)
|
||||
template.substitute(name=decl['docstring_var'], content=decl['docstring_content'])
|
||||
for decl in chain(self.declarations, self.stateless_declarations)
|
||||
if 'docstring_var' in decl)
|
||||
|
||||
def generate_docstrings_h(self):
|
||||
template = Template('extern char* $name;')
|
||||
return '\n\n'.join(
|
||||
template.substitute(name=decl['docstring_var'])
|
||||
for decl in chain(self.declarations, self.stateless_declarations)
|
||||
if 'docstring_var' in decl)
|
||||
template.substitute(name=decl['docstring_var'])
|
||||
for decl in chain(self.declarations, self.stateless_declarations)
|
||||
if 'docstring_var' in decl)
|
||||
|
@ -8,6 +8,7 @@ BASE_PATH = os.path.realpath(os.path.join(__file__, '..', '..', '..'))
|
||||
WRAPPER_PATH = os.path.join(BASE_PATH, 'torch', 'csrc', 'nn')
|
||||
THNN_UTILS_PATH = os.path.join(BASE_PATH, 'torch', '_thnn', 'utils.py')
|
||||
|
||||
|
||||
def import_module(name, path):
|
||||
if sys.version_info >= (3, 5):
|
||||
import importlib.util
|
||||
@ -81,7 +82,8 @@ for t in ['CudaHalf', 'Cuda', 'CudaDouble']:
|
||||
def wrap_function(name, type, arguments):
|
||||
cname = 'THNN_' + type + name
|
||||
declaration = ''
|
||||
declaration += 'extern "C" void ' + cname + '(' + ', '.join(TYPE_TRANSFORMS[type].get(arg.type, arg.type) for arg in arguments) + ');\n'
|
||||
declaration += 'extern "C" void ' + cname + \
|
||||
'(' + ', '.join(TYPE_TRANSFORMS[type].get(arg.type, arg.type) for arg in arguments) + ');\n'
|
||||
declaration += FUNCTION_TEMPLATE.substitute(name=type + name, cname=cname)
|
||||
indent = ' ' * 4
|
||||
dict_indent = ' ' * 6
|
||||
@ -92,15 +94,17 @@ def wrap_function(name, type, arguments):
|
||||
else:
|
||||
t = TYPE_TRANSFORMS[type].get(arg.type, arg.type)
|
||||
declaration += prefix + 'type: ' + t + '\n' + \
|
||||
dict_indent + 'name: ' + arg.name + '\n' + \
|
||||
dict_indent + 'nullable: True' + '\n'
|
||||
dict_indent + 'name: ' + arg.name + '\n' + \
|
||||
dict_indent + 'nullable: True' + '\n'
|
||||
declaration += ']]\n\n\n'
|
||||
return declaration
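To make the string-building above concrete, this is roughly the shape of the extern declaration the first line emits; the function name and argument types below are illustrative only:

cname = 'THNN_FloatAbs_updateOutput'   # example THNN symbol
arg_types = ['THNNState*', 'THFloatTensor*', 'THFloatTensor*']
decl = 'extern "C" void ' + cname + '(' + ', '.join(arg_types) + ');\n'
# -> extern "C" void THNN_FloatAbs_updateOutput(THNNState*, THFloatTensor*, THFloatTensor*);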
|
||||
|
||||
|
||||
def generate_wrappers():
|
||||
wrap_nn()
|
||||
wrap_cunn()
|
||||
|
||||
|
||||
def wrap_nn():
|
||||
wrapper = '#include <TH/TH.h>\n\n\n'
|
||||
nn_functions = thnn_utils.parse_header(thnn_utils.THNN_H_PATH)
|
||||
@ -114,6 +118,7 @@ def wrap_nn():
|
||||
NullableArguments(),
|
||||
])
|
||||
|
||||
|
||||
def wrap_cunn():
|
||||
wrapper = '#include <TH/TH.h>\n'
|
||||
wrapper += '#include <THC/THC.h>\n\n\n'
|
||||
|
@ -1,4 +1,5 @@
|
||||
import os
|
||||
|
||||
|
||||
def check_env_flag(name):
|
||||
return os.getenv(name) in ['ON', '1', 'YES', 'TRUE', 'Y']
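A quick sketch of how this helper behaves (the variable name is made up for illustration):

import os

os.environ['WITH_FEATURE'] = '1'
assert check_env_flag('WITH_FEATURE')      # '1' is in the accepted list
os.environ['WITH_FEATURE'] = 'off'
assert not check_env_flag('WITH_FEATURE')  # anything else counts as False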
|
||||
|
@ -56,6 +56,7 @@ del old_flags
|
||||
# Define basic utilities
|
||||
################################################################################
|
||||
|
||||
|
||||
def typename(o):
|
||||
module = ''
|
||||
class_name = ''
|
||||
@ -91,7 +92,7 @@ def set_default_tensor_type(t):
|
||||
|
||||
def set_rng_state(new_state):
|
||||
r"""Sets the random number generator state.
|
||||
|
||||
|
||||
Args:
|
||||
new_state (torch.ByteTensor): The desired state
|
||||
"""
|
||||
@ -106,7 +107,7 @@ def get_rng_state():
|
||||
def manual_seed(seed):
|
||||
r"""Sets the seed for generating random numbers. And returns a
|
||||
`torch._C.Generator` object.
|
||||
|
||||
|
||||
Args:
|
||||
seed (int or long): The desired seed.
|
||||
"""
|
||||
@ -130,61 +131,101 @@ from ._tensor_str import set_printoptions
|
||||
from .storage import _StorageBase
|
||||
from .tensor import _TensorBase
|
||||
|
||||
|
||||
class DoubleStorage(_C.DoubleStorageBase, _StorageBase):
    pass


class FloatStorage(_C.FloatStorageBase, _StorageBase):
    pass


class LongStorage(_C.LongStorageBase, _StorageBase):
    pass


class IntStorage(_C.IntStorageBase, _StorageBase):
    pass


class ShortStorage(_C.ShortStorageBase, _StorageBase):
    pass


class CharStorage(_C.CharStorageBase, _StorageBase):
    pass


class ByteStorage(_C.ByteStorageBase, _StorageBase):
    pass


class DoubleTensor(_C.DoubleTensorBase, _TensorBase):

    def is_signed(self):
        return True

    @classmethod
    def storage_type(cls):
        return DoubleStorage


class FloatTensor(_C.FloatTensorBase, _TensorBase):

    def is_signed(self):
        return True

    @classmethod
    def storage_type(cls):
        return FloatStorage


class LongTensor(_C.LongTensorBase, _TensorBase):

    def is_signed(self):
        return True

    @classmethod
    def storage_type(cls):
        return LongStorage


class IntTensor(_C.IntTensorBase, _TensorBase):

    def is_signed(self):
        return True

    @classmethod
    def storage_type(cls):
        return IntStorage


class ShortTensor(_C.ShortTensorBase, _TensorBase):

    def is_signed(self):
        return True

    @classmethod
    def storage_type(cls):
        return ShortStorage


class CharTensor(_C.CharTensorBase, _TensorBase):

    def is_signed(self):
        # TODO
        return False

    @classmethod
    def storage_type(cls):
        return CharStorage


class ByteTensor(_C.ByteTensorBase, _TensorBase):

    def is_signed(self):
        return False

    @classmethod
    def storage_type(cls):
        return ByteStorage
File diff suppressed because it is too large
@ -22,7 +22,7 @@ def set_printoptions(
        edgeitems=None,
        linewidth=None,
        profile=None,
):
"""Set options for printing. Items shamelessly taken from Numpy
|
||||
|
||||
Args:
|
||||
@ -119,7 +119,7 @@ def _number_format(tensor, min_sz=-1):
|
||||
else:
|
||||
if exp_max > prec + 1 or exp_max < 0:
|
||||
sz = max(min_sz, 7)
|
||||
scale = math.pow(10, exp_max-1)
|
||||
scale = math.pow(10, exp_max - 1)
|
||||
else:
|
||||
if exp_max == 0:
|
||||
sz = 7
|
||||
@ -132,19 +132,19 @@ def _number_format(tensor, min_sz=-1):
|
||||
|
||||
def _tensor_str(self):
|
||||
n = PRINT_OPTS.edgeitems
|
||||
has_hdots = self.size()[-1] > 2*n
|
||||
has_vdots = self.size()[-2] > 2*n
|
||||
has_hdots = self.size()[-1] > 2 * n
|
||||
has_vdots = self.size()[-2] > 2 * n
|
||||
print_full_mat = not has_hdots and not has_vdots
|
||||
formatter = _number_format(self, min_sz=3 if not print_full_mat else 0)
|
||||
print_dots = self.numel() >= PRINT_OPTS.threshold
|
||||
|
||||
dim_sz = max(2, max(len(str(x)) for x in self.size()))
|
||||
dim_fmt = "{:^" + str(dim_sz) + "}"
|
||||
dot_fmt = u"{:^" + str(dim_sz+1) + "}"
|
||||
dot_fmt = u"{:^" + str(dim_sz + 1) + "}"
|
||||
|
||||
counter_dim = self.ndimension() - 2
|
||||
counter = torch.LongStorage(counter_dim).fill_(0)
|
||||
counter[counter.size()-1] = -1
|
||||
counter[counter.size() - 1] = -1
|
||||
finished = False
|
||||
strt = ''
|
||||
while True:
|
||||
@ -152,7 +152,7 @@ def _tensor_str(self):
|
||||
nskipped = [False for i in counter]
|
||||
for i in _range(counter_dim - 1, -1, -1):
|
||||
counter[i] += 1
|
||||
if print_dots and counter[i] == n and self.size(i) > 2*n:
|
||||
if print_dots and counter[i] == n and self.size(i) > 2 * n:
|
||||
counter[i] = self.size(i) - n
|
||||
nskipped[i] = True
|
||||
if counter[i] == self.size(i):
|
||||
@ -188,18 +188,18 @@ def __repr_row(row, indent, fmt, scale, sz, truncate=None):
|
||||
if truncate is not None:
|
||||
dotfmt = " {:^5} "
|
||||
return (indent +
|
||||
' '.join(fmt.format(val/scale) for val in row[:truncate]) +
|
||||
' '.join(fmt.format(val / scale) for val in row[:truncate]) +
|
||||
dotfmt.format('...') +
|
||||
' '.join(fmt.format(val/scale) for val in row[-truncate:]) +
|
||||
' '.join(fmt.format(val / scale) for val in row[-truncate:]) +
|
||||
'\n')
|
||||
else:
|
||||
return indent + ' '.join(fmt.format(val/scale) for val in row) + '\n'
|
||||
return indent + ' '.join(fmt.format(val / scale) for val in row) + '\n'
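For illustration, the core of the row formatting above reduces to joining scaled values through a single format string (the values here are made up):

fmt, scale = '{:8.4f}', 1
row = [1.0, 2.5, -3.25]
print(' '.join(fmt.format(val / scale) for val in row))
#   1.0000   2.5000  -3.2500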
|
||||
|
||||
|
||||
def _matrix_str(self, indent='', formatter=None, force_truncate=False):
|
||||
n = PRINT_OPTS.edgeitems
|
||||
has_hdots = self.size(1) > 2*n
|
||||
has_vdots = self.size(0) > 2*n
|
||||
has_hdots = self.size(1) > 2 * n
|
||||
has_vdots = self.size(0) > 2 * n
|
||||
print_full_mat = not has_hdots and not has_vdots
|
||||
|
||||
if formatter is None:
|
||||
@ -207,14 +207,14 @@ def _matrix_str(self, indent='', formatter=None, force_truncate=False):
|
||||
min_sz=5 if not print_full_mat else 0)
|
||||
else:
|
||||
fmt, scale, sz = formatter
|
||||
nColumnPerLine = int(math.floor((PRINT_OPTS.linewidth-len(indent))/(sz+1)))
|
||||
nColumnPerLine = int(math.floor((PRINT_OPTS.linewidth - len(indent)) / (sz + 1)))
|
||||
strt = ''
|
||||
firstColumn = 0
|
||||
|
||||
if not force_truncate and \
|
||||
(self.numel() < PRINT_OPTS.threshold or print_full_mat):
|
||||
while firstColumn < self.size(1):
|
||||
lastColumn = min(firstColumn + nColumnPerLine - 1, self.size(1)-1)
|
||||
lastColumn = min(firstColumn + nColumnPerLine - 1, self.size(1) - 1)
|
||||
if nColumnPerLine < self.size(1):
|
||||
strt += '\n' if firstColumn != 1 else ''
|
||||
strt += 'Columns {} to {} \n{}'.format(
|
||||
@ -223,15 +223,15 @@ def _matrix_str(self, indent='', formatter=None, force_truncate=False):
|
||||
strt += SCALE_FORMAT.format(scale)
|
||||
for l in _range(self.size(0)):
|
||||
strt += indent + (' ' if scale != 1 else '')
|
||||
row_slice = self[l, firstColumn:lastColumn+1]
|
||||
strt += ' '.join(fmt.format(val/scale) for val in row_slice)
|
||||
row_slice = self[l, firstColumn:lastColumn + 1]
|
||||
strt += ' '.join(fmt.format(val / scale) for val in row_slice)
|
||||
strt += '\n'
|
||||
firstColumn = lastColumn + 1
|
||||
else:
|
||||
if scale != 1:
|
||||
strt += SCALE_FORMAT.format(scale)
|
||||
if has_vdots and has_hdots:
|
||||
vdotfmt = "{:^" + str((sz+1)*n-1) + "}"
|
||||
vdotfmt = "{:^" + str((sz + 1) * n - 1) + "}"
|
||||
ddotfmt = u"{:^5}"
|
||||
for row in self[:n]:
|
||||
strt += __repr_row(row, indent, fmt, scale, sz, n)
|
||||
@ -245,8 +245,8 @@ def _matrix_str(self, indent='', formatter=None, force_truncate=False):
|
||||
strt += __repr_row(row, indent, fmt, scale, sz, n)
|
||||
elif has_vdots and not has_hdots:
|
||||
vdotfmt = u"{:^" + \
|
||||
str(len(__repr_row(self[0], '', fmt, scale, sz))) + \
|
||||
"}\n"
|
||||
str(len(__repr_row(self[0], '', fmt, scale, sz))) + \
|
||||
"}\n"
|
||||
for row in self[:n]:
|
||||
strt += __repr_row(row, indent, fmt, scale, sz)
|
||||
strt += vdotfmt.format(u'\u22EE')
|
||||
@ -269,13 +269,13 @@ def _vector_str(self):
|
||||
ident = ' '
|
||||
if self.numel() < PRINT_OPTS.threshold:
|
||||
return (strt +
|
||||
'\n'.join(ident + fmt.format(val/scale) for val in self) +
|
||||
'\n'.join(ident + fmt.format(val / scale) for val in self) +
|
||||
'\n')
|
||||
else:
|
||||
return (strt +
|
||||
'\n'.join(ident + fmt.format(val/scale) for val in self[:n]) +
|
||||
'\n'.join(ident + fmt.format(val / scale) for val in self[:n]) +
|
||||
'\n' + (ident + dotfmt.format(u"\u22EE")) +
|
||||
'\n'.join(ident + fmt.format(val/scale) for val in self[-n:]) +
|
||||
'\n'.join(ident + fmt.format(val / scale) for val in self[-n:]) +
|
||||
'\n')
|
||||
|
||||
|
||||
@ -295,4 +295,3 @@ def _str(self):
|
||||
strt += '[{} of size {}{}]\n'.format(torch.typename(self),
|
||||
size_str, device_str)
|
||||
return '\n' + strt
|
||||
|
||||
|
@ -2,7 +2,9 @@ import threading
|
||||
import torch.cuda
|
||||
from .utils import THNN_H_PATH, THCUNN_H_PATH, parse_header, load_backend
|
||||
|
||||
|
||||
class Backends(object):
|
||||
|
||||
def __init__(self):
|
||||
self.backends = {}
|
||||
|
||||
@ -14,6 +16,7 @@ class Backends(object):
|
||||
|
||||
|
||||
class Backend(object):
|
||||
|
||||
def __init__(self, lib_prefix, lib_name, functions, mixins=tuple()):
|
||||
self.lib_prefix = lib_prefix
|
||||
self.lib_name = lib_name
|
||||
@ -32,11 +35,12 @@ class Backend(object):
|
||||
with self.loading_lock:
|
||||
if self.backend is None:
|
||||
self.backend = load_backend(self.lib_prefix, self.lib_name,
|
||||
self.functions, self.mixins)
|
||||
self.functions, self.mixins)
|
||||
return self.backend
|
||||
|
||||
|
||||
class THNNCudaBackendStateMixin(object):
|
||||
|
||||
@property
|
||||
def library_state(self):
|
||||
return torch.cuda._state_cdata
|
||||
|
@ -12,6 +12,7 @@ def _unpickle_backend(backend_name):
|
||||
|
||||
|
||||
class THNNBackendBase(object):
|
||||
|
||||
def __init__(self):
|
||||
self.methods = {}
|
||||
|
||||
@ -33,6 +34,7 @@ class THNNBackendBase(object):
|
||||
|
||||
|
||||
class Function(object):
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.arguments = []
|
||||
@ -46,6 +48,7 @@ class Function(object):
|
||||
|
||||
|
||||
class Argument(object):
|
||||
|
||||
def __init__(self, _type, name, is_optional):
|
||||
self.type = _type
|
||||
self.name = name
|
||||
|
@ -4,7 +4,7 @@ import torch._C
|
||||
from torch._C import _add_docstr as add_docstr
|
||||
|
||||
add_docstr(torch._C.abs,
"""abs(input, out=None) -> Tensor

Computes the element-wise absolute value of the given :attr:`input` tensor.
|
||||
|
||||
@ -15,7 +15,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.acos,
|
||||
"""
|
||||
"""
|
||||
acos(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the arccosine of the elements of :attr:`input`.
|
||||
@ -44,7 +44,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.add,
|
||||
"""
|
||||
"""
|
||||
.. function:: add(input, value, out=None)
|
||||
|
||||
Adds the scalar :attr:`value` to each element of the input :attr:`input`
|
||||
@ -127,7 +127,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.addbmm,
|
||||
"""
|
||||
"""
|
||||
addbmm(beta=1, mat, alpha=1, batch1, batch2, out=None) -> Tensor
|
||||
|
||||
Performs a batch matrix-matrix product of matrices stored
|
||||
@ -167,7 +167,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.addcdiv,
|
||||
"""
|
||||
"""
|
||||
addcdiv(tensor, value=1, tensor1, tensor2, out=None) -> Tensor
|
||||
|
||||
Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`,
|
||||
@ -195,7 +195,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.addcmul,
|
||||
"""
|
||||
"""
|
||||
addcmul(tensor, value=1, tensor1, tensor2, out=None) -> Tensor
|
||||
|
||||
Performs the element-wise multiplication of :attr:`tensor1`
|
||||
@ -224,7 +224,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.addmm,
|
||||
"""
|
||||
"""
|
||||
addmm(beta=1, mat, alpha=1, mat1, mat2, out=None) -> Tensor
|
||||
|
||||
Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`.
|
||||
@ -259,7 +259,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.addmv,
|
||||
"""
|
||||
"""
|
||||
addmv(beta=1, tensor, alpha=1, mat, vec, out=None) -> Tensor
|
||||
|
||||
Performs a matrix-vector product of the matrix :attr:`mat` and
|
||||
@ -296,7 +296,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.addr,
|
||||
r"""
|
||||
r"""
|
||||
addr(beta=1, mat, alpha=1, vec1, vec2, out=None) -> Tensor
|
||||
|
||||
Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2`
|
||||
@ -332,7 +332,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.asin,
|
||||
"""
|
||||
"""
|
||||
asin(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the arcsine of the elements of :attr:`input`.
|
||||
@ -360,7 +360,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.atan,
|
||||
"""
|
||||
"""
|
||||
atan(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the arctangent of the elements of :attr:`input`.
|
||||
@ -388,7 +388,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.atan2,
|
||||
"""
|
||||
"""
|
||||
atan2(input1, input2, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the arctangent of the elements of :attr:`input1`
|
||||
@ -418,7 +418,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.baddbmm,
|
||||
r"""
|
||||
r"""
|
||||
baddbmm(beta=1, mat, alpha=1, batch1, batch2, out=None) -> Tensor
|
||||
|
||||
Performs a batch matrix-matrix product of matrices in :attr:`batch1`
|
||||
@ -452,7 +452,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.bernoulli,
|
||||
"""
|
||||
"""
|
||||
bernoulli(input, out=None) -> Tensor
|
||||
|
||||
Draws binary random numbers (0 or 1) from a bernoulli distribution.
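A short usage sketch (shapes and values are illustrative):

probs = torch.Tensor(3, 3).uniform_(0, 1)  # per-element probabilities in [0, 1]
samples = torch.bernoulli(probs)           # 0/1 draws with those probabilities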
|
||||
@ -508,7 +508,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.bmm,
|
||||
"""
|
||||
"""
|
||||
bmm(batch1, batch2, out=None) -> Tensor
|
||||
|
||||
Performs a batch matrix-matrix product of matrices stored in :attr:`batch1` and :attr:`batch2`.
|
||||
@ -533,7 +533,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.cat,
|
||||
"""
|
||||
"""
|
||||
cat(inputs, dimension=0) -> Tensor
|
||||
|
||||
Concatenates the given sequence of :attr:`inputs` Tensors in the given dimension.
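For example (a sketch with arbitrarily chosen sizes):

x = torch.randn(2, 3)
torch.cat((x, x), 0).size()  # torch.Size([4, 3])
torch.cat((x, x), 1).size()  # torch.Size([2, 6])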
|
||||
@ -574,7 +574,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.ceil,
|
||||
"""
|
||||
"""
|
||||
ceil(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the ceil of the elements of :attr:`input`, the smallest integer greater than or equal to each element.
|
||||
@ -605,7 +605,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.reciprocal,
|
||||
"""
|
||||
"""
|
||||
reciprocal(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the reciprocal of the elements of :attr:`input`, i.e. :math:`1.0 / x`
|
||||
@ -636,7 +636,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.clamp,
|
||||
"""
|
||||
"""
|
||||
clamp(input, min, max, out=None) -> Tensor
|
||||
|
||||
Clamp all elements in :attr:`input` into the range `[min, max]` and return a resulting Tensor.
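For example (values are illustrative):

a = torch.Tensor([-1.5, 0.2, 2.7])
torch.clamp(a, -0.5, 0.5)   # -> [-0.5, 0.2, 0.5]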
|
||||
@ -731,7 +731,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.cos,
|
||||
"""
|
||||
"""
|
||||
cos(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the cosine of the elements of :attr:`input`.
|
||||
@ -759,7 +759,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.cosh,
|
||||
"""
|
||||
"""
|
||||
cosh(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the hyperbolic cosine of the elements of :attr:`input`.
|
||||
@ -787,7 +787,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.cross,
|
||||
"""
|
||||
"""
|
||||
cross(input, other, dim=-1, out=None) -> Tensor
|
||||
|
||||
|
||||
@ -841,7 +841,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.cumprod,
|
||||
"""
|
||||
"""
|
||||
cumprod(input, dim, out=None) -> Tensor
|
||||
|
||||
Returns the cumulative product of elements of :attr:`input` in the dimension :attr:`dim`.
|
||||
@ -903,7 +903,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.cumsum,
|
||||
"""
|
||||
"""
|
||||
cumsum(input, dim, out=None) -> Tensor
|
||||
|
||||
Returns the cumulative sum of elements of :attr:`input` in the dimension :attr:`dim`.
|
||||
@ -951,7 +951,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.diag,
|
||||
"""
|
||||
"""
|
||||
diag(input, diagonal=0, out=None) -> Tensor
|
||||
|
||||
- If :attr:`input` is a vector (1D Tensor), then returns a 2D square Tensor with the elements of :attr:`input` as the diagonal.
|
||||
@ -1022,7 +1022,7 @@ Get the k-th diagonal of a given matrix::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.dist,
|
||||
"""
|
||||
"""
|
||||
dist(input, other, p=2, out=None) -> Tensor
|
||||
|
||||
Returns the p-norm of (:attr:`input` - :attr:`other`)
|
||||
@ -1066,7 +1066,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.div,
|
||||
"""
|
||||
"""
|
||||
.. function:: div(input, value, out=None)
|
||||
|
||||
Divides each element of the input :attr:`input` with the scalar :attr:`value` and returns a new resulting tensor.
|
||||
@ -1150,7 +1150,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.dot,
|
||||
"""
|
||||
"""
|
||||
dot(tensor1, tensor2) -> float
|
||||
|
||||
Computes the dot product (inner product) of two tensors. Both tensors are
|
||||
@ -1163,7 +1163,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.eig,
|
||||
"""
|
||||
"""
|
||||
eig(a, eigenvectors=False, out=None) -> (Tensor, Tensor)
|
||||
|
||||
Computes the eigenvalues and eigenvectors of a real square matrix.
|
||||
@ -1183,7 +1183,7 @@ Returns:
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.eq,
|
||||
"""
|
||||
"""
|
||||
eq(input, other, out=None) -> Tensor
|
||||
|
||||
Computes element-wise equality
|
||||
@ -1208,7 +1208,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.equal,
|
||||
"""
|
||||
"""
|
||||
equal(tensor1, tensor2) -> bool
|
||||
|
||||
True if two tensors have the same size and elements, False otherwise.
|
||||
@ -1220,7 +1220,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.exp,
|
||||
"""
|
||||
"""
|
||||
exp(tensor, out=None) -> Tensor
|
||||
|
||||
Computes the exponential of each element.
|
||||
@ -1232,7 +1232,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.eye,
|
||||
"""
|
||||
"""
|
||||
eye(n, m=None, out=None)
|
||||
|
||||
Returns a 2-D tensor with ones on the diagonal and zeros elsewhere.
|
||||
@ -1255,7 +1255,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.floor,
|
||||
"""
|
||||
"""
|
||||
floor(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the floor of the elements of :attr:`input`, the largest integer less than or equal to each element.
|
||||
@ -1287,7 +1287,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.fmod,
|
||||
"""
|
||||
"""
|
||||
fmod(input, divisor, out=None) -> Tensor
|
||||
|
||||
Computes the element-wise remainder of division.
|
||||
@ -1315,7 +1315,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.frac,
|
||||
"""
|
||||
"""
|
||||
frac(tensor, out=None) -> Tensor
|
||||
|
||||
Computes the fractional portion of each element in `tensor`.
|
||||
@ -1327,7 +1327,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.from_numpy,
|
||||
"""
|
||||
"""
|
||||
from_numpy(ndarray) -> Tensor
|
||||
|
||||
Creates a :class:`Tensor` from a :class:`numpy.ndarray`.
|
||||
@ -1348,7 +1348,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.gather,
|
||||
"""
|
||||
"""
|
||||
gather(input, dim, index, out=None) -> Tensor
|
||||
|
||||
Gathers values along an axis specified by `dim`.
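As a small illustration of the indexing rule (for a 2D input and dim=1, out[i][j] = input[i][index[i][j]]):

t = torch.Tensor([[1, 2], [3, 4]])
idx = torch.LongTensor([[0, 0], [1, 0]])
torch.gather(t, 1, idx)
#  1  1
#  4  3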
|
||||
@ -1375,7 +1375,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.ge,
|
||||
"""
|
||||
"""
|
||||
ge(input, other, out=None) -> Tensor
|
||||
|
||||
Computes `tensor >= other` element-wise.
|
||||
@ -1400,7 +1400,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.gels,
|
||||
r"""
|
||||
r"""
|
||||
gels(B, A, out=None) -> Tensor
|
||||
|
||||
Computes the solution to the least squares and least norm problems for a full
|
||||
@ -1466,7 +1466,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.geqrf,
|
||||
r"""
|
||||
r"""
|
||||
geqrf(input, out=None) -> (Tensor, Tensor)
|
||||
|
||||
This is a low-level function for calling LAPACK directly.
|
||||
@ -1489,7 +1489,7 @@ Args:
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.ger,
|
||||
"""
|
||||
"""
|
||||
ger(vec1, vec2, out=None) -> Tensor
|
||||
Outer product of :attr:`vec1` and :attr:`vec2`. If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector of size `m`, then :attr:`out` must be a matrix of size `n x m`.
|
||||
|
||||
@ -1513,7 +1513,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.gesv,
|
||||
"""
|
||||
"""
|
||||
gesv(B, A, out=None) -> (Tensor, Tensor)
|
||||
|
||||
`X, LU = torch.gesv(B, A)` returns the solution to the system of linear
|
||||
@ -1552,14 +1552,14 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.get_num_threads,
|
||||
"""
|
||||
"""
|
||||
get_num_threads() -> int
|
||||
|
||||
Gets the number of OpenMP threads used for parallelizing CPU operations
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.gt,
|
||||
"""
|
||||
"""
|
||||
gt(input, other, out=None) -> Tensor
|
||||
|
||||
Computes `tensor > other` element-wise.
|
||||
@ -1584,7 +1584,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.histc,
|
||||
"""
|
||||
"""
|
||||
histc(input, bins=100, min=0, max=0, out=None) -> Tensor
|
||||
|
||||
Computes the histogram of a tensor.
|
||||
@ -1610,7 +1610,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.index_select,
|
||||
"""
|
||||
"""
|
||||
index_select(input, dim, index, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` which indexes the :attr:`input` `Tensor` along dimension :attr:`dim`
|
||||
@ -1653,7 +1653,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.inverse,
|
||||
"""
|
||||
"""
|
||||
inverse(input, out=None) -> Tensor
|
||||
|
||||
Takes the inverse of the square matrix :attr:`input`.
|
||||
@ -1704,7 +1704,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.kthvalue,
|
||||
"""
|
||||
"""
|
||||
kthvalue(input, k, dim=None, out=None) -> (Tensor, LongTensor)
|
||||
|
||||
Returns the :attr:`k`th smallest element of the given :attr:`input` Tensor along a given dimension.
|
||||
@ -1745,7 +1745,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.le,
|
||||
"""
|
||||
"""
|
||||
le(input, other, out=None) -> Tensor
|
||||
|
||||
Computes `tensor <= other` element-wise.
|
||||
@ -1770,7 +1770,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.lerp,
|
||||
"""
|
||||
"""
|
||||
lerp(start, end, weight, out=None)
|
||||
|
||||
Does a linear interpolation of two tensors :attr:`start` and :attr:`end` based on a scalar :attr:`weight` and returns the resulting :attr:`out` Tensor.
|
||||
@ -1814,7 +1814,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.linspace,
|
||||
"""
|
||||
"""
|
||||
linspace(start, end, steps=100, out=None) -> Tensor
|
||||
|
||||
Returns a one-dimensional Tensor of :attr:`steps`
|
||||
@ -1860,7 +1860,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.log,
|
||||
"""
|
||||
"""
|
||||
log(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the natural logarithm of the elements of :attr:`input`.
|
||||
@ -1893,7 +1893,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.log1p,
|
||||
"""
|
||||
"""
|
||||
log1p(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the natural logarithm of (1 + :attr:`input`).
|
||||
@ -1930,7 +1930,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.logspace,
|
||||
"""
|
||||
"""
|
||||
logspace(start, end, steps=100, out=None) -> Tensor
|
||||
|
||||
Returns a one-dimensional Tensor of :attr:`steps` points
|
||||
@ -1967,7 +1967,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.lt,
|
||||
"""
|
||||
"""
|
||||
lt(input, other, out=None) -> Tensor
|
||||
|
||||
Computes `tensor < other` element-wise.
|
||||
@ -1992,7 +1992,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.masked_select,
|
||||
"""
|
||||
"""
|
||||
masked_select(input, mask, out=None) -> Tensor
|
||||
|
||||
Returns a new 1D `Tensor` which indexes the :attr:`input` `Tensor` according to the binary mask :attr:`mask` which is a `ByteTensor`.
|
||||
@ -2038,7 +2038,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.max,
|
||||
"""
|
||||
"""
|
||||
.. function:: max(input) -> float
|
||||
|
||||
Returns the maximum value of all elements in the :attr:`input` Tensor.
|
||||
@ -2144,7 +2144,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.mean,
|
||||
"""
|
||||
"""
|
||||
.. function:: mean(input) -> float
|
||||
|
||||
Returns the mean value of all elements in the :attr:`input` Tensor.
|
||||
@ -2197,7 +2197,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.median,
|
||||
"""
|
||||
"""
|
||||
median(input, dim=-1, values=None, indices=None) -> (Tensor, LongTensor)
|
||||
|
||||
Returns the median value of each row of the :attr:`input` Tensor in the given dimension :attr:`dim`.
|
||||
@ -2252,7 +2252,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.min,
|
||||
"""
|
||||
"""
|
||||
.. function:: min(input) -> float
|
||||
|
||||
Returns the minimum value of all elements in the :attr:`input` Tensor.
|
||||
@ -2357,7 +2357,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.mm,
|
||||
"""
|
||||
"""
|
||||
mm(mat1, mat2, out=None) -> Tensor
|
||||
|
||||
Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`.
|
||||
@ -2380,7 +2380,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.mode,
|
||||
"""
|
||||
"""
|
||||
mode(input, dim=-1, values=None, indices=None) -> (Tensor, LongTensor)
|
||||
|
||||
Returns the mode value of each row of the :attr:`input` Tensor in the given dimension :attr:`dim`.
|
||||
@ -2435,7 +2435,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.mul,
|
||||
"""
|
||||
"""
|
||||
.. function:: mul(input, value, out=None)
|
||||
|
||||
Multiplies each element of the input :attr:`input` with the scalar :attr:`value` and returns a new resulting tensor.
|
||||
@ -2508,7 +2508,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.multinomial,
|
||||
u"""
|
||||
u"""
|
||||
multinomial(input, num_samples, replacement=False, out=None) -> LongTensor
|
||||
|
||||
Returns a Tensor where each row
|
||||
@ -2562,7 +2562,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.mv,
|
||||
"""
|
||||
"""
|
||||
mv(mat, vec, out=None) -> Tensor
|
||||
|
||||
Performs a matrix-vector product of the matrix :attr:`mat` and the vector :attr:`vec`.
|
||||
@ -2585,7 +2585,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.ne,
|
||||
"""
|
||||
"""
|
||||
ne(input, other, out=None) -> Tensor
|
||||
|
||||
Computes `tensor != other` element-wise.
|
||||
@ -2610,7 +2610,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.neg,
|
||||
"""
|
||||
"""
|
||||
neg(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the negative of the elements of :attr:`input`.
|
||||
@ -2645,7 +2645,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.nonzero,
|
||||
"""
|
||||
"""
|
||||
nonzero(input, out=None) -> LongTensor
|
||||
|
||||
Returns a tensor containing the indices of all non-zero elements of :attr:`input`.
|
||||
@ -2681,7 +2681,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.norm,
|
||||
"""
|
||||
"""
|
||||
.. function:: norm(input, p=2) -> float
|
||||
|
||||
Returns the p-norm of the :attr:`input` Tensor.
|
||||
@ -2743,7 +2743,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.normal,
|
||||
"""
|
||||
"""
|
||||
.. function:: normal(means, stddevs, out=None)
|
||||
|
||||
Returns a Tensor of random numbers drawn from separate normal distributions
|
||||
@ -2825,7 +2825,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.numel,
|
||||
"""
|
||||
"""
|
||||
numel(input) -> int
|
||||
|
||||
Returns the total number of elements in the :attr:`input` Tensor.
|
||||
@ -2845,7 +2845,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.ones,
|
||||
"""
|
||||
"""
|
||||
ones(*sizes, out=None) -> Tensor
|
||||
|
||||
Returns a Tensor filled with the scalar value `1`, with the shape defined
|
||||
@ -2896,7 +2896,7 @@ Example::
|
||||
# """)
|
||||
|
||||
add_docstr(torch._C.pow,
|
||||
"""
|
||||
"""
|
||||
.. function:: pow(input, exponent, out=None)
|
||||
|
||||
Takes the power of each element in :attr:`input` with :attr:`exponent` and returns a Tensor with the result.
|
||||
@ -2991,7 +2991,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.prod,
|
||||
"""
|
||||
"""
|
||||
.. function:: prod(input) -> float
|
||||
|
||||
Returns the product of all elements in the :attr:`input` Tensor.
|
||||
@ -3049,7 +3049,7 @@ Example::
|
||||
# """)
|
||||
|
||||
add_docstr(torch._C.qr,
|
||||
"""
|
||||
"""
|
||||
qr(input, out=None) -> (Tensor, Tensor)
|
||||
|
||||
Computes the QR decomposition of a matrix :attr:`input`: returns matrices
|
||||
@ -3106,7 +3106,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.rand,
|
||||
"""
|
||||
"""
|
||||
rand(*sizes, out=None) -> Tensor
|
||||
|
||||
Returns a Tensor filled with random numbers from a uniform distribution
|
||||
@ -3137,7 +3137,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.randn,
|
||||
"""
|
||||
"""
|
||||
randn(*sizes, out=None) -> Tensor
|
||||
|
||||
Returns a Tensor filled with random numbers from a normal distribution
|
||||
@ -3168,7 +3168,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.randperm,
|
||||
"""
|
||||
"""
|
||||
randperm(n, out=None) -> LongTensor
|
||||
|
||||
Returns a random permutation of integers from ``0`` to ``n - 1``.
|
||||
@ -3188,7 +3188,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.range,
|
||||
"""
|
||||
"""
|
||||
range(start, end, step=1, out=None) -> Tensor
|
||||
|
||||
Returns a 1D Tensor of size :math:`floor((end - start) / step) + 1` with values
|
||||
@ -3225,7 +3225,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.remainder,
|
||||
"""
|
||||
"""
|
||||
remainder(input, divisor, out=None) -> Tensor
|
||||
|
||||
Computes the element-wise remainder of division.
|
||||
@ -3253,7 +3253,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.renorm,
|
||||
"""
|
||||
"""
|
||||
renorm(input, p, dim, maxnorm, out=None) -> Tensor
|
||||
|
||||
Returns a Tensor where each sub-tensor of :attr:`input` along dimension :attr:`dim`
|
||||
@ -3290,7 +3290,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.round,
|
||||
"""
|
||||
"""
|
||||
round(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with each of the elements of :attr:`input` rounded to the closest integer.
|
||||
@ -3321,7 +3321,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.rsqrt,
|
||||
"""
|
||||
"""
|
||||
rsqrt(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the reciprocal of the square-root of each of the elements of :attr:`input`.
|
||||
@ -3352,14 +3352,14 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.set_num_threads,
|
||||
"""
|
||||
"""
|
||||
set_num_threads(int)
|
||||
|
||||
Sets the number of OpenMP threads used for parallelizing CPU operations
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.sigmoid,
|
||||
"""
|
||||
"""
|
||||
sigmoid(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the sigmoid of the elements of :attr:`input`.
|
||||
@ -3390,7 +3390,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.sign,
|
||||
"""
|
||||
"""
|
||||
sign(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the sign of the elements of :attr:`input`.
|
||||
@ -3420,7 +3420,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.sin,
|
||||
"""
|
||||
"""
|
||||
sin(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the sine of the elements of :attr:`input`.
|
||||
@ -3448,7 +3448,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.sinh,
|
||||
"""
|
||||
"""
|
||||
sinh(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the hyperbolic sine of the elements of :attr:`input`.
|
||||
@ -3476,7 +3476,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.sort,
|
||||
"""
|
||||
"""
|
||||
sort(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)
|
||||
|
||||
Sorts the elements of the :attr:`input` Tensor along a given dimension in ascending order by value.
|
||||
@ -3530,7 +3530,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.sqrt,
|
||||
"""
|
||||
"""
|
||||
sqrt(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the square-root of the elements of :attr:`input`.
|
||||
@ -3561,7 +3561,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.squeeze,
|
||||
"""
|
||||
"""
|
||||
squeeze(input, dim=None, out=None)
|
||||
|
||||
Returns a `Tensor` with all the dimensions of :attr:`input` of size `1` removed.
|
||||
@ -3599,7 +3599,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.std,
|
||||
"""
|
||||
"""
|
||||
.. function:: std(input) -> float
|
||||
|
||||
Returns the standard-deviation of all elements in the :attr:`input` Tensor.
|
||||
@ -3652,7 +3652,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.sum,
|
||||
"""
|
||||
"""
|
||||
.. function:: sum(input) -> float
|
||||
|
||||
Returns the sum of all elements in the :attr:`input` Tensor.
|
||||
@ -3705,7 +3705,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.svd,
|
||||
"""
|
||||
"""
|
||||
svd(input, some=True, out=None) -> (Tensor, Tensor, Tensor)
|
||||
|
||||
`U, S, V = torch.svd(A)` returns the singular value decomposition of a
|
||||
@ -3780,7 +3780,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.symeig,
|
||||
"""
|
||||
"""
|
||||
symeig(input, eigenvectors=False, upper=True, out=None) -> (Tensor, Tensor)
|
||||
|
||||
`e, V = torch.symeig(input)` returns eigenvalues and eigenvectors
|
||||
@ -3842,7 +3842,7 @@ Examples::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.t,
|
||||
"""
|
||||
"""
|
||||
t(input, out=None) -> Tensor
|
||||
|
||||
Expects :attr:`input` to be a matrix (2D Tensor) and transposes dimensions 0 and 1.
|
||||
@ -3872,7 +3872,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.tan,
|
||||
"""
|
||||
"""
|
||||
tan(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the tangent of the elements of :attr:`input`.
|
||||
@ -3900,7 +3900,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.tanh,
|
||||
"""
|
||||
"""
|
||||
tanh(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the hyperbolic tangent of the elements of :attr:`input`.
|
||||
@ -3928,7 +3928,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.topk,
|
||||
"""
|
||||
"""
|
||||
topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
|
||||
|
||||
Returns the :attr:`k` largest elements of the given :attr:`input` Tensor along a given dimension.
|
||||
@ -3992,7 +3992,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.trace,
|
||||
"""
|
||||
"""
|
||||
trace(input) -> float
|
||||
|
||||
Returns the sum of the elements of the diagonal of the input 2D matrix.
|
||||
@ -4013,7 +4013,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.transpose,
|
||||
"""
|
||||
"""
|
||||
transpose(input, dim0, dim1, out=None) -> Tensor
|
||||
|
||||
Returns a `Tensor` that is a transposed version of :attr:`input`. The given dimensions :attr:`dim0` and :attr:`dim1` are swapped.
|
||||
@ -4044,7 +4044,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.tril,
|
||||
"""
|
||||
"""
|
||||
tril(input, k=0, out=None) -> Tensor
|
||||
|
||||
Returns the lower triangular part of the matrix (2D Tensor) :attr:`input`,
|
||||
@ -4097,7 +4097,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.triu,
|
||||
"""
|
||||
"""
|
||||
triu(input, k=0, out=None) -> Tensor
|
||||
|
||||
Returns the upper triangular part of the matrix (2D Tensor) :attr:`input`,
|
||||
@ -4155,7 +4155,7 @@ Example::
|
||||
# """)
|
||||
|
||||
add_docstr(torch._C.trunc,
|
||||
"""
|
||||
"""
|
||||
trunc(input, out=None) -> Tensor
|
||||
|
||||
Returns a new `Tensor` with the truncated integer values of the elements of :attr:`input`.
|
||||
@ -4186,7 +4186,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.var,
|
||||
"""
|
||||
"""
|
||||
.. function:: var(input) -> float
|
||||
|
||||
Returns the variance of all elements in the :attr:`input` Tensor.
|
||||
@ -4239,7 +4239,7 @@ Example::
|
||||
""")
|
||||
|
||||
add_docstr(torch._C.zeros,
|
||||
"""
|
||||
"""
|
||||
zeros(*sizes, out=None) -> Tensor
|
||||
|
||||
Returns a Tensor filled with the scalar value `0`, with the shape defined
|
||||
|
@ -12,6 +12,7 @@ from .stochastic_function import StochasticFunction
|
||||
|
||||
__all__ = ['Variable', 'Function', 'StochasticFunction', 'backward']
|
||||
|
||||
|
||||
def backward(variables, grad_variables, retain_variables=False):
|
||||
"""Computes the sum of gradients of given variables w.r.t. graph leaves.
|
||||
|
||||
@ -37,6 +38,6 @@ def backward(variables, grad_variables, retain_variables=False):
|
||||
times.
|
||||
"""
|
||||
Variable._execution_engine.run_backward(
|
||||
tuple(variables), tuple(grad_variables), retain_variables)
|
||||
tuple(variables), tuple(grad_variables), retain_variables)
|
||||
|
||||
assert torch._C._autograd_init()
|
||||
|
@ -5,4 +5,3 @@ from .reduce import *
|
||||
from .linalg import *
|
||||
from .blas import *
|
||||
from .stochastic import *
|
||||
|
||||
|
@ -59,7 +59,7 @@ class Pow(Function):
|
||||
|
||||
def backward(self, grad_output):
|
||||
a, b = self.saved_tensors
|
||||
return grad_output.mul(b).mul_(a.pow(b-1)), grad_output.mul(a.pow(b)).mul_(a.log())
|
||||
return grad_output.mul(b).mul_(a.pow(b - 1)), grad_output.mul(a.pow(b)).mul_(a.log())
|
||||
|
||||
|
||||
class AddConstant(InplaceFunction):
|
||||
@ -174,7 +174,7 @@ class PowConstant(Function):
|
||||
return grad_output.mul(self.fw_result).mul_(math.log(self.constant))
|
||||
else:
|
||||
a = self.saved_tensors[0]
|
||||
return grad_output.mul(self.constant).mul_(a.pow(self.constant-1))
|
||||
return grad_output.mul(self.constant).mul_(a.pow(self.constant - 1))
|
||||
|
||||
|
||||
class Negate(InplaceFunction):
|
||||
|
@ -25,7 +25,7 @@ class Addmm(_BlasBase):
|
||||
self.save_for_backward(matrix1, matrix2)
|
||||
output = self._get_output(add_matrix)
|
||||
return torch.addmm(self.alpha, add_matrix, self.beta,
|
||||
matrix1, matrix2, out=output)
|
||||
matrix1, matrix2, out=output)
|
||||
|
||||
def backward(self, grad_output):
|
||||
matrix1, matrix2 = self.saved_tensors
|
||||
@ -55,7 +55,7 @@ class Addbmm(_BlasBase):
|
||||
self.save_for_backward(batch1, batch2)
|
||||
output = self._get_output(add_matrix)
|
||||
return torch.addbmm(self.alpha, add_matrix, self.beta,
|
||||
batch1, batch2, out=output)
|
||||
batch1, batch2, out=output)
|
||||
|
||||
def backward(self, grad_output):
|
||||
batch1, batch2 = self.saved_tensors
|
||||
@ -68,8 +68,8 @@ class Addbmm(_BlasBase):
|
||||
|
||||
if any(self.needs_input_grad[1:]):
|
||||
batch_grad_output = (grad_output
|
||||
.unsqueeze(0)
|
||||
.expand(batch1.size(0), batch1.size(1), batch2.size(2)))
|
||||
.unsqueeze(0)
|
||||
.expand(batch1.size(0), batch1.size(1), batch2.size(2)))
|
||||
|
||||
if self.needs_input_grad[1]:
|
||||
grad_batch1 = torch.bmm(batch_grad_output, batch2.transpose(1, 2))
|
||||
@ -90,7 +90,7 @@ class Baddbmm(_BlasBase):
|
||||
self.save_for_backward(batch1, batch2)
|
||||
output = self._get_output(add_batch)
|
||||
return torch.baddbmm(self.alpha, add_batch, self.beta,
|
||||
batch1, batch2, out=output)
|
||||
batch1, batch2, out=output)
|
||||
|
||||
def backward(self, grad_output):
|
||||
batch1, batch2 = self.saved_tensors
|
||||
@ -120,7 +120,7 @@ class Addmv(_BlasBase):
|
||||
self.save_for_backward(matrix, vector)
|
||||
output = self._get_output(add_vector)
|
||||
return torch.addmv(self.alpha, add_vector, self.beta,
|
||||
matrix, vector, out=output)
|
||||
matrix, vector, out=output)
|
||||
|
||||
def backward(self, grad_output):
|
||||
matrix, vector = self.saved_tensors
|
||||
@ -150,7 +150,7 @@ class Addr(_BlasBase):
|
||||
self.save_for_backward(vector1, vector2)
|
||||
output = self._get_output(add_matrix)
|
||||
return torch.addr(self.alpha, add_matrix, self.beta,
|
||||
vector1, vector2, out=output)
|
||||
vector1, vector2, out=output)
|
||||
|
||||
def backward(self, grad_output):
|
||||
vector1, vector2 = self.saved_tensors
|
||||
@ -199,4 +199,3 @@ class Dot(Function):
|
||||
# TODO: trace
|
||||
# TODO: tril
|
||||
# TODO: triu
|
||||
|
||||
|
@ -42,4 +42,3 @@ class Triu(Function):
|
||||
return grad_output.triu(self.diagonal_idx)
|
||||
|
||||
# TODO: trace
|
||||
|
||||
|
@ -165,6 +165,7 @@ class Tan(Function):
|
||||
|
||||
|
||||
class Asin(Function):
|
||||
|
||||
def forward(self, i):
|
||||
self.save_for_backward(i)
|
||||
return i.asin()
|
||||
@ -175,6 +176,7 @@ class Asin(Function):
|
||||
|
||||
|
||||
class Acos(Function):
|
||||
|
||||
def forward(self, i):
|
||||
self.save_for_backward(i)
|
||||
return i.acos()
|
||||
@ -185,6 +187,7 @@ class Acos(Function):
|
||||
|
||||
|
||||
class Atan(Function):
|
||||
|
||||
def forward(self, i):
|
||||
self.save_for_backward(i)
|
||||
return i.atan()
|
||||
|
@ -4,6 +4,7 @@ from ..function import Function
|
||||
|
||||
|
||||
class _DimReduceFunction(Function):
|
||||
|
||||
def __init__(self, dim=None):
|
||||
super(_DimReduceFunction, self).__init__()
|
||||
self.dim = dim
|
||||
@ -139,6 +140,7 @@ class Kthvalue(_SelectionFunction):
|
||||
|
||||
|
||||
class Norm(Function):
|
||||
|
||||
def __init__(self, norm_type=2, dim=None):
|
||||
super(Norm, self).__init__()
|
||||
self.norm_type = norm_type
|
||||
|
@ -65,7 +65,7 @@ class Normal(StochasticFunction):
|
||||
output.mul_(stddevs)
|
||||
else:
|
||||
raise RuntimeError("Normal function requires specifying a common "
|
||||
"stddev, or per-sample stddev")
|
||||
"stddev, or per-sample stddev")
|
||||
output.add_(means)
|
||||
self.save_for_backward(output, means, stddevs)
|
||||
self.mark_non_differentiable(output)
|
||||
@ -74,7 +74,7 @@ class Normal(StochasticFunction):
|
||||
def backward(self, reward):
|
||||
output, means, stddevs = self.saved_tensors
|
||||
grad_stddevs = None
|
||||
grad_means = means - output # == -(output - means)
|
||||
grad_means = means - output # == -(output - means)
|
||||
assert self.stddev is not None or stddevs is not None
|
||||
if self.stddev is not None:
|
||||
grad_means /= 1e-6 + self.stddev ** 2
|
||||
@ -88,4 +88,3 @@ class Normal(StochasticFunction):
|
||||
grad_means /= stddevs_sq
|
||||
grad_means *= reward
|
||||
return grad_means, grad_stddevs
|
||||
|
||||
|
@ -103,6 +103,7 @@ class View(Function):
|
||||
|
||||
|
||||
class Expand(Function):
|
||||
|
||||
def __init__(self, sizes):
|
||||
super(Expand, self).__init__()
|
||||
self.sizes = sizes
|
||||
@ -110,8 +111,8 @@ class Expand(Function):
|
||||
|
||||
def forward(self, i):
|
||||
self.expanded_dims = [dim for dim, (expanded, original)
|
||||
in enumerate(zip(self.sizes, i.size()))
|
||||
if expanded != original]
|
||||
in enumerate(zip(self.sizes, i.size()))
|
||||
if expanded != original]
|
||||
result = i.expand(*self.sizes)
|
||||
self.mark_shared_storage((i, result))
|
||||
return result
|
||||
@ -304,8 +305,8 @@ class Concat(Function):
|
||||
return torch.cat(inputs, self.dim)
|
||||
|
||||
def backward(self, grad_output):
|
||||
return tuple(grad_output.narrow(self.dim, end-size, size) for size, end
|
||||
in zip(self.input_sizes, _accumulate(self.input_sizes)))
|
||||
return tuple(grad_output.narrow(self.dim, end - size, size) for size, end
|
||||
in zip(self.input_sizes, _accumulate(self.input_sizes)))
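The slicing done here can be sketched standalone; assuming `_accumulate` is `itertools.accumulate`, each gradient slice is recovered from its running end offset:

from itertools import accumulate
import torch

input_sizes = [2, 3]
grad = torch.randn(5, 4)                      # gradient of the concatenated output
grads = tuple(grad.narrow(0, end - size, size)
              for size, end in zip(input_sizes, accumulate(input_sizes)))
[g.size(0) for g in grads]                    # [2, 3] -- one slice per input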
|
||||
|
||||
|
||||
class Resize(Function):
|
||||
@ -318,11 +319,11 @@ class Resize(Function):
|
||||
def forward(self, tensor):
|
||||
if tensor.numel() != self.numel:
|
||||
raise RuntimeError(("requested resize to {} ({} elements in total), "
|
||||
"but the given tensor has a size of {} ({} elements). "
|
||||
"autograd's resize can only change the shape of a given "
|
||||
"tensor, while preserving the number of elements. ").format(
|
||||
'x'.join(map(str, self.sizes)), self.numel,
|
||||
'x'.join(map(str, tensor.size())), tensor.numel()))
|
||||
"but the given tensor has a size of {} ({} elements). "
|
||||
"autograd's resize can only change the shape of a given "
|
||||
"tensor, while preserving the number of elements. ").format(
|
||||
'x'.join(map(str, self.sizes)), self.numel,
|
||||
'x'.join(map(str, tensor.size())), tensor.numel()))
|
||||
self.input_sizes = tensor.size()
|
||||
result = tensor.new(tensor).resize_(*self.sizes)
|
||||
self.mark_shared_storage((tensor, result))
|
||||
@ -493,7 +494,7 @@ class Topk(_MultiSelectionFunction):
|
||||
self.sort = sort
|
||||
|
||||
def forward(self, input):
|
||||
dim = self.dim if self.dim is not None else input.dim()-1
|
||||
dim = self.dim if self.dim is not None else input.dim() - 1
|
||||
self.args = (self.k, dim, self.largest, self.sort)
|
||||
return super(Topk, self).forward(input)
|
||||
|
||||
|
@ -71,8 +71,8 @@ class BasicEngine(object):
|
||||
else:
|
||||
if prev_fn.num_outputs != 1:
|
||||
raise RuntimeError("one of the function outputs "
|
||||
"wasn't used - this is an error not, but "
|
||||
"it's going to be fixed soon")
|
||||
"wasn't used - this is an error not, but "
|
||||
"it's going to be fixed soon")
|
||||
prev_grad = (d_prev_fn,)
|
||||
ready.appendleft((prev_fn, prev_grad))
|
||||
else:
|
||||
|
@ -154,9 +154,10 @@ def _nested_map(condition, fn):
|
||||
return type(obj)(_map(x) for x in obj)
|
||||
else:
|
||||
raise ValueError("NestedIOFunction doesn't know how to process "
|
||||
"an input object of type " + torch.typename(obj))
|
||||
"an input object of type " + torch.typename(obj))
|
||||
return _map
|
||||
|
||||
|
||||
def _iter_filter(condition):
|
||||
def _iter(obj):
|
||||
if condition(obj):
|
||||
@ -169,7 +170,7 @@ def _iter_filter(condition):
|
||||
yield var
|
||||
else:
|
||||
raise ValueError("NestedIOFunction doesn't know how to process "
|
||||
"an input object of type " + torch.typename(obj))
|
||||
"an input object of type " + torch.typename(obj))
|
||||
return _iter
|
||||
|
||||
|
||||
@ -178,8 +179,10 @@ _iter_tensors = _iter_filter(torch.is_tensor)
|
||||
_iter_None_tensors = _iter_filter(lambda o: o is None or torch.is_tensor(o))
|
||||
_map_variable_tensor = _nested_map(lambda o: isinstance(o, torch.autograd.Variable), lambda o: o.data)
|
||||
|
||||
|
||||
def _map_tensor_fromiter(itr):
|
||||
return _nested_map(lambda o: torch.is_tensor(o), lambda o: next(itr))
|
||||
return _nested_map(lambda o: torch.is_tensor(o), lambda o: next(itr))
|
||||
|
||||
|
||||
class NestedIOFunction(Function):
|
||||
|
||||
|
@ -2,6 +2,7 @@ from .function import Function
|
||||
|
||||
_NOT_PROVIDED = object()
|
||||
|
||||
|
||||
class StochasticFunction(Function):
|
||||
|
||||
def __init__(self):
|
||||
@ -10,7 +11,7 @@ class StochasticFunction(Function):
|
||||
def _do_backward(self, grad_output, retain_variables):
|
||||
if self.reward is _NOT_PROVIDED:
|
||||
raise RuntimeError("differentiating stochastic functions requires "
|
||||
"providing a reward")
|
||||
"providing a reward")
|
||||
result = super(StochasticFunction, self)._do_backward((self.reward,), retain_variables)
|
||||
if not retain_variables:
|
||||
self.reward = None
|
||||
@ -18,4 +19,3 @@ class StochasticFunction(Function):
|
||||
|
||||
def _reinforce(self, reward):
|
||||
self.reward = reward
|
||||
|
||||
|
@ -72,12 +72,12 @@ class Variable(_C._VariableBase):
|
||||
if self.creator is not None:
|
||||
if value is False:
|
||||
hint = (" If you want to use a computed variable in a subgraph "
|
||||
"that doesn't require differentiation use "
|
||||
"var_no_grad = var.detach().")
|
||||
"that doesn't require differentiation use "
|
||||
"var_no_grad = var.detach().")
|
||||
else:
|
||||
hint = ''
|
||||
raise RuntimeError("you can only change requires_grad flags of "
|
||||
"leaf variables." + hint)
|
||||
"leaf variables." + hint)
|
||||
self._requires_grad = value
|
||||
|
||||
def __getattr__(self, name):
|
||||
@ -87,13 +87,13 @@ class Variable(_C._VariableBase):
|
||||
|
||||
def __getitem__(self, key):
|
||||
if (isinstance(key, Variable) and
|
||||
type(key.data).__name__ == 'ByteTensor'):
|
||||
type(key.data).__name__ == 'ByteTensor'):
|
||||
return MaskedSelect()(self, key)
|
||||
return Index(key)(self)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if (isinstance(key, Variable) and
|
||||
type(key.data).__name__ == 'ByteTensor'):
|
||||
type(key.data).__name__ == 'ByteTensor'):
|
||||
if isinstance(value, Variable):
|
||||
return MaskedCopy(inplace=True)(self, key, value)
|
||||
else:
|
||||
@ -107,9 +107,9 @@ class Variable(_C._VariableBase):
def __deepcopy__(self, memo):
if self.creator is not None:
raise RuntimeError("Only Variables created explicitly by the user "
"(graph leaves) support the deepcopy protocol at the moment")
"(graph leaves) support the deepcopy protocol at the moment")
result = type(self)(self.data.clone(), requires_grad=self.requires_grad,
volatile=self.volatile)
volatile=self.volatile)
memo[id(self)] = result
return result

@ -151,7 +151,8 @@ class Variable(_C._VariableBase):
raise RuntimeError('calling backward on a volatile variable')
if gradient is None and self.requires_grad:
if self.data.numel() != 1:
raise RuntimeError('backward should be called only on a scalar (i.e. 1-element tensor) or with gradient w.r.t. the variable')
raise RuntimeError(
'backward should be called only on a scalar (i.e. 1-element tensor) or with gradient w.r.t. the variable')
gradient = self.data.new().resize_as_(self.data).fill_(1)
self._execution_engine.run_backward((self,), (gradient,), retain_variables)

@ -219,7 +220,7 @@ class Variable(_C._VariableBase):
"""
if not isinstance(self.creator, StochasticFunction):
raise RuntimeError("reinforce() can be only called on outputs "
"of stochastic functions")
"of stochastic functions")
self.creator._reinforce(reward)

def detach(self):

@ -392,7 +393,7 @@ class Variable(_C._VariableBase):
def clamp(self, min=None, max=None):
if min is None and max is None:
raise ValueError("clamp requires specifying at least one of "
"min and max arguments")
"min and max arguments")
elif min is None and max is not None:
return CminConstant(max)(self)
elif min is not None and max is None:

@ -503,7 +504,7 @@ class Variable(_C._VariableBase):

def bmm(self, batch):
output = Variable(self.data.new(self.data.size(0), self.data.size(1),
batch.data.size(2)))
batch.data.size(2)))
return self._static_blas(Baddbmm, (output, 0, 1, self, batch), False)

def mv(self, vector):

@ -622,7 +623,7 @@ class Variable(_C._VariableBase):
if isinstance(sizes[0], torch.Size):
if len(sizes) > 1:
raise ValueError("expand expects a several ints or a single "
"torch.Size argument")
"torch.Size argument")
sizes = sizes[0]
return Expand(sizes)(self)

@ -641,7 +642,7 @@ class Variable(_C._VariableBase):

def narrow(self, dim, start_index, length):
index = tuple(slice(None, None) for _ in range(dim)) + \
(slice(start_index, start_index+length),)
(slice(start_index, start_index + length),)

return Index(index)(self)

@ -710,7 +711,7 @@ class Variable(_C._VariableBase):
elif dim_self == 2 and dim_other == 2:
return self.mm(other)
raise ValueError("both arguments to __matmul__ need to be 1D or 2D, "
"but they are {}D and {}D".format(dim_self, dim_other))
"but they are {}D and {}D".format(dim_self, dim_other))

def __div__(self, other):
return self.div(other)

@ -20,6 +20,7 @@ elif sys.platform == 'darwin':
else:
libnames = []

def _loadlib():
global lib
loaded = False

@ -39,6 +40,7 @@ def _loadlib():
lib = None
raise OSError("Could not load cuDNN")

def is_acceptable(tensor):
if not enabled:
return False

@ -58,13 +60,15 @@ def is_acceptable(tensor):
return False
if not _C.has_cudnn:
warnings.warn("cuDNN library has been detected, but your pytorch "
"installation was compiled without support for it. You "
"might want to rebuild pytorch, making sure the library "
"is visible to the build system.")
"installation was compiled without support for it. You "
"might want to rebuild pytorch, making sure the library "
"is visible to the build system.")
return False
return True

__cudnn_version = []

def version():
if not lib:
raise RuntimeError("cuDNN not initialized")

@ -108,7 +112,9 @@ CUDNN_GRU = 3
CUDNN_LINEAR_INPUT = 0
CUDNN_SKIP_INPUT = 1

class CuDNNHandle:

def __init__(self):
ptr = ctypes.c_void_p()
check_error(lib.cudnnCreate(ctypes.byref(ptr)))

@ -117,7 +123,9 @@ class CuDNNHandle:
def __del__(self):
check_error(lib.cudnnDestroy(self))

class CuDNNError(RuntimeError):

def __init__(self, status):
self.status = status
msg = '{}: {}'.format(status, get_error_string(status))

@ -125,6 +133,7 @@ class CuDNNError(RuntimeError):

class TensorDescriptor(object):

def __init__(self):
ptr = ctypes.c_void_p()
check_error(lib.cudnnCreateTensorDescriptor(ctypes.byref(ptr)))

@ -147,6 +156,7 @@ class TensorDescriptor(object):

class TensorDescriptorArray(object):

def __init__(self, N):
self.ptrs = (ctypes.c_void_p * N)()
for i in range(N):

@ -175,6 +185,7 @@ class TensorDescriptorArray(object):

class ConvolutionDescriptor(object):

def __init__(self):
ptr = ctypes.c_void_p()
check_error(lib.cudnnCreateConvolutionDescriptor(ctypes.byref(ptr)))

@ -195,7 +206,9 @@ class ConvolutionDescriptor(object):
def as_tuple(self):
return (self._pad, self._stride)

class FilterDescriptor(object):

def __init__(self):
ptr = ctypes.c_void_p()
check_error(lib.cudnnCreateFilterDescriptor(ctypes.byref(ptr)))

@ -216,6 +229,7 @@ class FilterDescriptor(object):

class DropoutDescriptor(object):

def __init__(self, handle, dropout, seed):
ptr = ctypes.c_void_p()
check_error(lib.cudnnCreateDropoutDescriptor(ctypes.byref(ptr)))

@ -241,10 +255,10 @@ class DropoutDescriptor(object):
check_error(lib.cudnnDestroyDropoutDescriptor(self))

class RNNDescriptor(object):

def __init__(self, hidden_size, num_layers, dropout_desc, input_mode,
bidirectional, mode, datatype):
bidirectional, mode, datatype):
ptr = ctypes.c_void_p()
check_error(lib.cudnnCreateRNNDescriptor(ctypes.byref(ptr)))
self._as_parameter_ = ptr

@ -272,13 +286,16 @@ class ConvolutionAlgoPerf(ctypes.Structure):
("memory", ctypes.c_size_t),
]

def check_error(status):
if status is not 0:
raise CuDNNError(status)

def get_error_string(status):
return lib.cudnnGetErrorString(status)

def get_handle():
if lib is None:
_loadlib()

@ -296,11 +313,12 @@ _typemap = {
}

_sizeofmap = {
CUDNN_DATA_HALF : 2,
CUDNN_DATA_FLOAT : 4,
CUDNN_DATA_DOUBLE : 8,
CUDNN_DATA_HALF: 2,
CUDNN_DATA_FLOAT: 4,
CUDNN_DATA_DOUBLE: 8,
}

def c_type(tensor):
if isinstance(tensor, torch.cuda.HalfTensor):
return ctypes.c_float

@ -311,10 +329,12 @@ def c_type(tensor):
else:
raise ValueError("unknown type '{}'".format(type(tensor)))

def int_array(itr):
array_type = ctypes.c_int * len(itr)
return array_type(*itr)

def descriptor(tensor, N=None):
if N is not None:
descriptor = TensorDescriptorArray(N)

@ -331,9 +351,11 @@ _autotuner_forward = {}
_autotuner_backward_data = {}
_autotuner_backward_filter = {}

def convolution_autotuner_key(idesc, weight_desc, conv_desc):
return (idesc.as_tuple(), weight_desc.as_tuple(), conv_desc.as_tuple())

def convolution_forward_algorithm(idesc, weight_desc, conv_desc, odesc):
k = convolution_autotuner_key(idesc, weight_desc, conv_desc)
if k in _autotuner_forward:

@ -360,15 +382,19 @@ def convolution_forward_algorithm(idesc, weight_desc, conv_desc, odesc):
wlimit, ctypes.byref(fwd_alg)))
return fwd_alg

def convolution_forward_workspace_size(*args):
check_error(lib.cudnnGetConvolutionForwardWorkspaceSize(*args))

def convolution_forward(*args):
check_error(lib.cudnnConvolutionForward(*args))

def convolution_backward_data(*args):
return check_error(lib.cudnnConvolutionBackwardData(*args))

def convolution_backward_data_algorithm(weight_desc, odesc, conv_desc, idesc):
k = convolution_autotuner_key(idesc, weight_desc, conv_desc)
if k in _autotuner_backward_data:

@ -395,12 +421,15 @@ def convolution_backward_data_algorithm(weight_desc, odesc, conv_desc, idesc):
wlimit, ctypes.byref(bwd_data_alg)))
return bwd_data_alg

def convolution_backward_data_workspace_size(*args):
return check_error(lib.cudnnGetConvolutionBackwardDataWorkspaceSize(*args))

def convolution_backward_filter(*args):
return check_error(lib.cudnnConvolutionBackwardFilter(*args))

def convolution_backward_filter_algorithm(idesc, odesc, conv_desc, weight_desc):
k = convolution_autotuner_key(idesc, weight_desc, conv_desc)
if k in _autotuner_backward_filter:

@ -427,11 +456,14 @@ def convolution_backward_filter_algorithm(idesc, odesc, conv_desc, weight_desc):
wlimit, ctypes.byref(bwd_filter_alg)))
return bwd_filter_alg

def convolution_backward_filter_workspace_size(*args):
return check_error(lib.cudnnGetConvolutionBackwardFilterWorkspaceSize(*args))

def convolution_backward_bias(*args):
check_error(lib.cudnnConvolutionBackwardBias(*args))

def add_tensor(*args):
check_error(lib.cudnnAddTensor(*args))

@ -3,6 +3,7 @@ import torch.backends.cudnn as cudnn
from torch.backends.cudnn import check_error
import ctypes

def get_cudnn_mode(mode):
if mode == 'RNN_RELU':
return cudnn.CUDNN_RNN_RELU

@ -17,9 +18,10 @@ def get_cudnn_mode(mode):

class Unserializable(object):

def __init__(self, inner):
self.inner = inner

def get(self):
return self.inner

@ -39,6 +41,7 @@ def init_dropout_descriptor(fn, handle):
fn.dropout_seed
)

def init_rnn_descriptor(fn):
return cudnn.RNNDescriptor(
fn.hidden_size,

@ -161,7 +164,6 @@ def get_parameters(fn, handle, weight_buf):

cur_offset = offset + filter_dim_a[0]

params.append(layer_params)

return params

@ -237,7 +239,7 @@ def forward(fn, input, hx, weight, output, hy):

if tuple(hx.size()) != hidden_size:
raise RuntimeError('Expected hidden size {}, got {}'.format(
hidden_size, tuple(hx.size())))
hidden_size, tuple(hx.size())))
if cx is not None and tuple(cx.size()) != hidden_size:
raise RuntimeError('Expected cell size {}, got {}'.format(
hidden_size, tuple(cx.size())))

@ -295,7 +297,6 @@ def forward(fn, input, hx, weight, output, hy):
output = output.transpose_(0, 1)

def backward_grad(fn, input, hx, weight, output, grad_output, grad_hy, grad_input, grad_hx):
with torch.cuda.device_of(input):
handle = cudnn.get_handle()

@ -51,9 +51,9 @@ def _load_cudart():
except OSError:
pass
raise RuntimeError("couldn't find libcudart. Make sure CUDA libraries "
"are installed in a default location, or that they're in " +
("DYLD_LIBRARY_PATH" if system == 'Darwin' else "LD_LIBRARY_PATH") +
".")
"are installed in a default location, or that they're in " +
("DYLD_LIBRARY_PATH" if system == 'Darwin' else "LD_LIBRARY_PATH") +
".")

def _check_driver():
|
||||
@ -259,67 +259,112 @@ class _CudaBase(object):
|
||||
|
||||
class DoubleStorage(_CudaBase, torch._C.CudaDoubleStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
class FloatStorage(_CudaBase, torch._C.CudaFloatStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
class LongStorage(_CudaBase, torch._C.CudaLongStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
class IntStorage(_CudaBase, torch._C.CudaIntStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
class ShortStorage(_CudaBase, torch._C.CudaShortStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
class CharStorage(_CudaBase, torch._C.CudaCharStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
class ByteStorage(_CudaBase, torch._C.CudaByteStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
class HalfStorage(_CudaBase, torch._C.CudaHalfStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
class DoubleTensor(_CudaBase, torch._C.CudaDoubleTensorBase, _TensorBase):
|
||||
|
||||
def is_signed(self):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def storage_type(cls):
|
||||
return DoubleStorage
|
||||
|
||||
|
||||
class FloatTensor(_CudaBase, torch._C.CudaFloatTensorBase, _TensorBase):
|
||||
|
||||
def is_signed(self):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def storage_type(cls):
|
||||
return FloatStorage
|
||||
|
||||
|
||||
class LongTensor(_CudaBase, torch._C.CudaLongTensorBase, _TensorBase):
|
||||
|
||||
def is_signed(self):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def storage_type(cls):
|
||||
return LongStorage
|
||||
|
||||
|
||||
class IntTensor(_CudaBase, torch._C.CudaIntTensorBase, _TensorBase):
|
||||
|
||||
def is_signed(self):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def storage_type(cls):
|
||||
return IntStorage
|
||||
|
||||
|
||||
class ShortTensor(_CudaBase, torch._C.CudaShortTensorBase, _TensorBase):
|
||||
|
||||
def is_signed(self):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def storage_type(cls):
|
||||
return ShortStorage
|
||||
|
||||
|
||||
class CharTensor(_CudaBase, torch._C.CudaCharTensorBase, _TensorBase):
|
||||
|
||||
def is_signed(self):
|
||||
# TODO
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def storage_type(cls):
|
||||
return CharStorage
|
||||
|
||||
|
||||
class ByteTensor(_CudaBase, torch._C.CudaByteTensorBase, _TensorBase):
|
||||
|
||||
def is_signed(self):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def storage_type(cls):
|
||||
return ByteStorage
|
||||
|
||||
|
||||
class HalfTensor(_CudaBase, torch._C.CudaHalfTensorBase, _TensorBase):
|
||||
|
||||
def is_signed(self):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def storage_type():
|
||||
return HalfStorage
|
||||
|
@ -4,6 +4,7 @@ from torch._utils import _accumulate
|
||||
|
||||
# TODO: sync streams when implemented
|
||||
|
||||
|
||||
def broadcast(tensor, devices):
|
||||
"""Broadcasts a tensor to a number of GPUs.
|
||||
|
||||
|
@ -92,6 +92,7 @@ nccl_types = {
|
||||
|
||||
|
||||
class NcclError(RuntimeError):
|
||||
|
||||
def __init__(self, status):
|
||||
self.status = status
|
||||
msg = '{0} ({1})'.format(status_codes.get(status), status)
|
||||
@ -103,6 +104,7 @@ class NcclComm(ctypes.c_void_p):
|
||||
|
||||
|
||||
class NcclCommList(object):
|
||||
|
||||
def __init__(self, devices):
|
||||
self.devices = devices
|
||||
ptrs = (NcclComm * len(devices))()
|
||||
@ -141,7 +143,7 @@ def communicator(inputs, outputs=None):
|
||||
|
||||
def cudaStream():
|
||||
# TODO: return the current stream
|
||||
#ffi.C.THCState_getCurrentStream(cutorch.getState())
|
||||
# ffi.C.THCState_getCurrentStream(cutorch.getState())
|
||||
return None
|
||||
|
||||
|
||||
@ -202,7 +204,7 @@ def all_gather(inputs, outputs):
|
||||
|
||||
|
||||
def reduce_scatter(inputs, outputs, op=SUM):
|
||||
_check_inputs(inputs, outputs, 1.0/len(inputs))
|
||||
_check_inputs(inputs, outputs, 1.0 / len(inputs))
|
||||
comm = communicator(inputs, outputs)
|
||||
count = inputs[0].numel() // len(inputs)
|
||||
data_type = nccl_types[inputs[0].type()]
|
||||
|
@ -35,4 +35,3 @@ def seed_all():
|
||||
def initial_seed():
|
||||
_lazy_init()
|
||||
return _C._cuda_initialSeed()
|
||||
|
||||
|
@ -8,6 +8,7 @@ ERROR_NOT_READY = 34
|
||||
|
||||
|
||||
class CudaError(RuntimeError):
|
||||
|
||||
def __init__(self, code):
|
||||
msg = cudart().cudaGetErrorString(code).decode('utf-8')
|
||||
super(CudaError, self).__init__('{0} ({1})'.format(msg, code))
|
||||
|
@ -1,16 +1,18 @@
|
||||
import torch
|
||||
from ._utils import _range
|
||||
|
||||
|
||||
def split(tensor, split_size, dim=0):
|
||||
if dim < 0:
|
||||
dim += tensor.dim()
|
||||
dim_size = tensor.size(dim)
|
||||
num_splits = (dim_size + split_size - 1) // split_size
|
||||
last_split_size = split_size - (split_size * num_splits - dim_size)
|
||||
|
||||
def get_split_size(i):
|
||||
return split_size if i < num_splits-1 else last_split_size
|
||||
return tuple(tensor.narrow(int(dim), int(i*split_size), int(get_split_size(i))) for i
|
||||
in _range(0, num_splits))
|
||||
return split_size if i < num_splits - 1 else last_split_size
|
||||
return tuple(tensor.narrow(int(dim), int(i * split_size), int(get_split_size(i))) for i
|
||||
in _range(0, num_splits))
|
||||
|
||||
|
||||
def chunk(tensor, n_chunks, dim=0):
|
||||
|
@ -1,24 +1,25 @@
|
||||
import torch
|
||||
from .Module import Module
|
||||
|
||||
|
||||
class Abs(Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Abs, self).__init__()
|
||||
|
||||
def updateOutput(self, input):
|
||||
self._backend.Abs_updateOutput(
|
||||
self._backend.library_state,
|
||||
input,
|
||||
self.output
|
||||
self._backend.library_state,
|
||||
input,
|
||||
self.output
|
||||
)
|
||||
return self.output
|
||||
|
||||
def updateGradInput(self, input, gradOutput):
|
||||
self._backend.Abs_updateGradInput(
|
||||
self._backend.library_state,
|
||||
input,
|
||||
gradOutput,
|
||||
self.gradInput
|
||||
self._backend.library_state,
|
||||
input,
|
||||
gradOutput,
|
||||
self.gradInput
|
||||
)
|
||||
return self.gradInput
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from .Criterion import Criterion
|
||||
|
||||
|
||||
class AbsCriterion(Criterion):
|
||||
|
||||
def __init__(self, sizeAverage=True):
|
||||
@ -10,7 +11,7 @@ class AbsCriterion(Criterion):
|
||||
|
||||
def updateOutput(self, input, target):
|
||||
if self.output_tensor is None:
|
||||
self.output_tensor = input.new(1)
|
||||
self.output_tensor = input.new(1)
|
||||
self._backend.AbsCriterion_updateOutput(
|
||||
self._backend.library_state,
|
||||
input,
|
||||
@ -21,7 +22,6 @@ class AbsCriterion(Criterion):
|
||||
self.output = self.output_tensor[0]
|
||||
return self.output
|
||||
|
||||
|
||||
def updateGradInput(self, input, target):
|
||||
self._backend.AbsCriterion_updateGradInput(
|
||||
self._backend.library_state,
|
||||
@ -31,4 +31,3 @@ class AbsCriterion(Criterion):
|
||||
self.sizeAverage
|
||||
)
|
||||
return self.gradInput
|
||||
|
||||
|
@ -2,6 +2,7 @@ import math
|
||||
import torch
|
||||
from .Module import Module
|
||||
|
||||
|
||||
class Add(Module):
|
||||
|
||||
def __init__(self, inputSize, scalar=False):
|
||||
@ -19,16 +20,16 @@ class Add(Module):
|
||||
|
||||
def reset(self, stdv=None):
|
||||
if stdv is not None:
|
||||
stdv = stdv * math.sqrt(3)
|
||||
stdv = stdv * math.sqrt(3)
|
||||
else:
|
||||
stdv = 1./math.sqrt(self.bias.size(0))
|
||||
stdv = 1. / math.sqrt(self.bias.size(0))
|
||||
|
||||
self.bias.uniform_(-stdv, stdv)
|
||||
|
||||
def updateOutput(self, input):
|
||||
self.output.resize_as_(input).copy_(input)
|
||||
if self.scalar:
|
||||
self.output.add_(self.bias[0]);
|
||||
self.output.add_(self.bias[0])
|
||||
else:
|
||||
batchSize = input.size(0)
|
||||
if self._ones.size(0) != batchSize:
|
||||
@ -42,16 +43,15 @@ class Add(Module):
|
||||
|
||||
def updateGradInput(self, input, gradOutput):
|
||||
if self.gradInput is not None:
|
||||
self.gradInput.resize_as_(gradOutput).copy_(gradOutput)
|
||||
return self.gradInput
|
||||
self.gradInput.resize_as_(gradOutput).copy_(gradOutput)
|
||||
return self.gradInput
|
||||
|
||||
def accGradParameters(self, input, gradOutput, scale=1):
|
||||
if self.gradBias.size(0) == 1:
|
||||
self.gradBias[0] = self.gradBias[0] + scale*gradOutput.sum();
|
||||
self.gradBias[0] = self.gradBias[0] + scale * gradOutput.sum()
|
||||
else:
|
||||
if input.is_same_size(self.bias):
|
||||
self.gradBias.add_(scale, gradOutput)
|
||||
else:
|
||||
gradOutput = gradOutput.view(input.size(0), -1)
|
||||
self.gradBias.view(-1).addmv_(scale, gradOutput.t(), self._ones)
|
||||
|
||||
if input.is_same_size(self.bias):
|
||||
self.gradBias.add_(scale, gradOutput)
|
||||
else:
|
||||
gradOutput = gradOutput.view(input.size(0), -1)
|
||||
self.gradBias.view(-1).addmv_(scale, gradOutput.t(), self._ones)
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from .Module import Module
|
||||
|
||||
|
||||
class AddConstant(Module):
|
||||
|
||||
def __init__(self, constant_scalar, inplace=False):
|
||||
@ -29,4 +30,3 @@ class AddConstant(Module):
|
||||
self.gradInput.copy_(gradOutput)
|
||||
|
||||
return self.gradInput
|
||||
|
||||
|
@ -2,6 +2,8 @@ import torch
|
||||
from .Criterion import Criterion
|
||||
|
||||
# TODO: use THNN
|
||||
|
||||
|
||||
class BCECriterion(Criterion):
|
||||
eps = 1e-12
|
||||
|
||||
@ -20,7 +22,7 @@ class BCECriterion(Criterion):
|
||||
raise RuntimeError("input and target size mismatch")
|
||||
|
||||
if self.buffer is None:
|
||||
self.buffer = input.new()
|
||||
self.buffer = input.new()
|
||||
|
||||
buffer = self.buffer
|
||||
weights = self.weights
|
||||
@ -38,7 +40,7 @@ class BCECriterion(Criterion):
|
||||
output = torch.dot(target, buffer)
|
||||
|
||||
# log(1 - input) * (1 - target)
|
||||
torch.mul(input, -1, out=buffer).add_(1+self.eps).log_()
|
||||
torch.mul(input, -1, out=buffer).add_(1 + self.eps).log_()
|
||||
if weights is not None:
|
||||
buffer.mul_(weights)
|
||||
|
||||
@ -52,42 +54,39 @@ class BCECriterion(Criterion):
|
||||
|
||||
return self.output
|
||||
|
||||
|
||||
def updateGradInput(self, input, target):
|
||||
# - (target - input) / ( input (1 - input) )
|
||||
# The gradient is slightly incorrect:
|
||||
# It should have be divided by (input + self.eps) (1 - input + self.eps)
|
||||
# but it is divided by input (1 - input + self.eps) + self.eps
|
||||
# This modification requires less memory to be computed.
|
||||
if input.nelement() != target.nelement():
|
||||
if input.nelement() != target.nelement():
|
||||
raise RuntimeError("input and target size mismatch")
|
||||
|
||||
if self.buffer is None:
|
||||
self.buffer = input.new()
|
||||
if self.buffer is None:
|
||||
self.buffer = input.new()
|
||||
|
||||
buffer = self.buffer
|
||||
weights = self.weights
|
||||
gradInput = self.gradInput
|
||||
buffer = self.buffer
|
||||
weights = self.weights
|
||||
gradInput = self.gradInput
|
||||
|
||||
if weights is not None and target.dim() != 1:
|
||||
weights = self.weights.view(1, target.size(1)).expand_as(target)
|
||||
if weights is not None and target.dim() != 1:
|
||||
weights = self.weights.view(1, target.size(1)).expand_as(target)
|
||||
|
||||
buffer.resize_as_(input)
|
||||
# - x ( 1 + self.eps -x ) + self.eps
|
||||
torch.add(input, -1, out=buffer).add_(-self.eps).mul_(input).add_(-self.eps)
|
||||
|
||||
buffer.resize_as_(input)
|
||||
# - x ( 1 + self.eps -x ) + self.eps
|
||||
torch.add(input, -1, out=buffer).add_(-self.eps).mul_(input).add_(-self.eps)
|
||||
gradInput.resize_as_(input)
|
||||
# y - x
|
||||
torch.add(target, -1, input, out=gradInput)
|
||||
# - (y - x) / ( x ( 1 + self.eps -x ) + self.eps )
|
||||
gradInput.div_(buffer)
|
||||
|
||||
gradInput.resize_as_(input)
|
||||
# y - x
|
||||
torch.add(target, -1, input, out=gradInput)
|
||||
# - (y - x) / ( x ( 1 + self.eps -x ) + self.eps )
|
||||
gradInput.div_(buffer)
|
||||
if weights is not None:
|
||||
gradInput.mul_(weights)
|
||||
|
||||
if weights is not None:
|
||||
gradInput.mul_(weights)
|
||||
|
||||
if self.sizeAverage:
|
||||
gradInput.div_(target.nelement())
|
||||
|
||||
return gradInput
|
||||
if self.sizeAverage:
|
||||
gradInput.div_(target.nelement())
|
||||
|
||||
return gradInput
|
||||
|
@ -32,6 +32,7 @@ import torch
|
||||
from .Module import Module
|
||||
from .utils import clear
|
||||
|
||||
|
||||
class BatchNormalization(Module):
|
||||
# expected dimension of input
|
||||
nDim = 2
|
||||
@ -51,44 +52,45 @@ class BatchNormalization(Module):
|
||||
self.save_std = None
|
||||
|
||||
if self.affine:
|
||||
self.weight = torch.Tensor(nOutput)
|
||||
self.bias = torch.Tensor(nOutput)
|
||||
self.gradWeight = torch.Tensor(nOutput)
|
||||
self.gradBias = torch.Tensor(nOutput)
|
||||
self.reset()
|
||||
self.weight = torch.Tensor(nOutput)
|
||||
self.bias = torch.Tensor(nOutput)
|
||||
self.gradWeight = torch.Tensor(nOutput)
|
||||
self.gradBias = torch.Tensor(nOutput)
|
||||
self.reset()
|
||||
else:
|
||||
self.weight = None
|
||||
self.bias = None
|
||||
self.gradWeight = None
|
||||
self.gradBias = None
|
||||
self.weight = None
|
||||
self.bias = None
|
||||
self.gradWeight = None
|
||||
self.gradBias = None
|
||||
|
||||
def reset(self):
|
||||
if self.weight is not None:
|
||||
self.weight.uniform_()
|
||||
self.weight.uniform_()
|
||||
|
||||
if self.bias is not None:
|
||||
self.bias.zero_()
|
||||
self.bias.zero_()
|
||||
|
||||
self.running_mean.zero_()
|
||||
self.running_var.fill_(1)
|
||||
|
||||
def _checkInputDim(self, input):
|
||||
if input.dim() != self.nDim:
|
||||
raise RuntimeError('only mini-batch supported ({}D tensor), got {}D tensor instead'.format(self.nDim, input.dim()))
|
||||
raise RuntimeError(
|
||||
'only mini-batch supported ({}D tensor), got {}D tensor instead'.format(self.nDim, input.dim()))
|
||||
if input.size(1) != self.running_mean.nelement():
|
||||
raise RuntimeError('got {}-feature tensor, expected {}'.format(input.size(1), self.running_mean.nelement()))
|
||||
|
||||
def _makeContiguous(self, input, gradOutput=None):
|
||||
if not input.is_contiguous():
|
||||
if self._input is None:
|
||||
self._input = input.new()
|
||||
self._input = input.new()
|
||||
self._input.resize_as_(input).copy_(input)
|
||||
input = self._input
|
||||
|
||||
if gradOutput is not None:
|
||||
if not gradOutput.is_contiguous():
|
||||
if self._gradOutput is None:
|
||||
self._gradOutput = gradOutput.new()
|
||||
self._gradOutput = gradOutput.new()
|
||||
self._gradOutput.resize_as_(gradOutput).copy_(gradOutput)
|
||||
gradOutput = self._gradOutput
|
||||
|
||||
@ -101,10 +103,10 @@ class BatchNormalization(Module):
|
||||
|
||||
self.output.resize_as_(input)
|
||||
if self.save_mean is None:
|
||||
self.save_mean = input.new()
|
||||
self.save_mean = input.new()
|
||||
self.save_mean.resize_as_(self.running_mean)
|
||||
if self.save_std is None:
|
||||
self.save_std = input.new()
|
||||
self.save_std = input.new()
|
||||
self.save_std.resize_as_(self.running_var)
|
||||
|
||||
self._backend.BatchNormalization_updateOutput(
|
||||
@ -124,7 +126,6 @@ class BatchNormalization(Module):
|
||||
|
||||
return self.output
|
||||
|
||||
|
||||
def _backward(self, input, gradOutput, scale, gradInput=None, gradWeight=None, gradBias=None):
|
||||
self._checkInputDim(input)
|
||||
self._checkInputDim(gradOutput)
|
||||
@ -135,8 +136,7 @@ class BatchNormalization(Module):
|
||||
|
||||
scale = scale or 1.
|
||||
if gradInput is not None:
|
||||
gradInput.resize_as_(gradOutput)
|
||||
|
||||
gradInput.resize_as_(gradOutput)
|
||||
|
||||
self._backend.BatchNormalization_backward(
|
||||
self._backend.library_state,
|
||||
@ -177,15 +177,14 @@ class BatchNormalization(Module):
|
||||
# first 5 buffers are not present in the current implementation,
|
||||
# but we keep them for cleaning old saved models
|
||||
clear(self, [
|
||||
'buffer',
|
||||
'buffer2',
|
||||
'centered',
|
||||
'std',
|
||||
'normalized',
|
||||
'_input',
|
||||
'_gradOutput',
|
||||
'save_mean',
|
||||
'save_std',
|
||||
'buffer',
|
||||
'buffer2',
|
||||
'centered',
|
||||
'std',
|
||||
'normalized',
|
||||
'_input',
|
||||
'_gradOutput',
|
||||
'save_mean',
|
||||
'save_std',
|
||||
])
|
||||
return super(BatchNormalization, self).clearState()
|
||||
|
||||
|
@ -3,6 +3,7 @@ import torch
|
||||
from .Module import Module
|
||||
from .utils import clear
|
||||
|
||||
|
||||
class Bilinear(Module):
|
||||
|
||||
def _assertInput(self, input):
|
||||
@ -23,14 +24,13 @@ class Bilinear(Module):
|
||||
if gradOutput.size(1) != self.weight.size(0):
|
||||
raise RuntimeError('number of columns in gradOutput does not match layer\'s output size')
|
||||
|
||||
|
||||
def __init__(self, inputSize1, inputSize2, outputSize, bias=True):
|
||||
# set up model:
|
||||
super(Bilinear, self).__init__()
|
||||
self.weight = torch.Tensor(outputSize, inputSize1, inputSize2)
|
||||
self.weight = torch.Tensor(outputSize, inputSize1, inputSize2)
|
||||
self.gradWeight = torch.Tensor(outputSize, inputSize1, inputSize2)
|
||||
if bias:
|
||||
self.bias = torch.Tensor(outputSize)
|
||||
self.bias = torch.Tensor(outputSize)
|
||||
self.gradBias = torch.Tensor(outputSize)
|
||||
else:
|
||||
self.bias = None
|
||||
@ -53,13 +53,12 @@ class Bilinear(Module):
|
||||
self.bias.uniform_(-stdv, stdv)
|
||||
return self
|
||||
|
||||
|
||||
def updateOutput(self, input):
|
||||
self._assertInput(input)
|
||||
|
||||
# set up buffer:
|
||||
if self.buff2 is None:
|
||||
self.buff2 = input[0].new()
|
||||
self.buff2 = input[0].new()
|
||||
self.buff2.resize_as_(input[1])
|
||||
|
||||
# compute output scores:
|
||||
@ -74,7 +73,6 @@ class Bilinear(Module):
|
||||
|
||||
return self.output
|
||||
|
||||
|
||||
def updateGradInput(self, input, gradOutput):
|
||||
if self.gradInput is None:
|
||||
return
|
||||
@ -87,38 +85,36 @@ class Bilinear(Module):
|
||||
#: first slice of weight tensor (k = 1)
|
||||
self.gradInput[0].addmm_(input[1], self.weight[0].t())
|
||||
self.gradInput[0].mul_(gradOutput.narrow(1, 0, 1).expand(self.gradInput[0].size(0),
|
||||
self.gradInput[0].size(1)))
|
||||
self.gradInput[0].size(1)))
|
||||
self.gradInput[1].addmm_(input[0], self.weight[0])
|
||||
self.gradInput[1].mul_(gradOutput.narrow(1, 0, 1).expand(self.gradInput[1].size(0),
|
||||
self.gradInput[1].size(1)))
|
||||
self.gradInput[1].size(1)))
|
||||
|
||||
#: remaining slices of weight tensor
|
||||
if self.weight.size(0) > 1:
|
||||
if self.buff1 is None:
|
||||
self.buff1 = input[0].new()
|
||||
self.buff1 = input[0].new()
|
||||
self.buff1.resize_as_(input[0])
|
||||
|
||||
for k in range(1, self.weight.size(0)):
|
||||
torch.mm(input[1], self.weight[k].t(), out=self.buff1)
|
||||
self.buff1.mul_(gradOutput.narrow(1, k, 1).expand(self.gradInput[0].size(0),
|
||||
self.gradInput[0].size(1)))
|
||||
self.gradInput[0].size(1)))
|
||||
self.gradInput[0].add_(self.buff1)
|
||||
|
||||
torch.mm(input[0], self.weight[k], out=self.buff2)
|
||||
self.buff2.mul_(gradOutput.narrow(1, k, 1).expand(self.gradInput[1].size(0),
|
||||
self.gradInput[1].size(1)))
|
||||
self.gradInput[1].size(1)))
|
||||
self.gradInput[1].add_(self.buff2)
|
||||
|
||||
return self.gradInput
|
||||
|
||||
|
||||
|
||||
def accGradParameters(self, input, gradOutput, scale=1):
|
||||
self._assertInputGradOutput(input, gradOutput)
|
||||
|
||||
# make sure we have buffer:
|
||||
if self.buff1 is None:
|
||||
self.buff1 = input[0].new()
|
||||
self.buff1 = input[0].new()
|
||||
self.buff1.resize_as_(input[0])
|
||||
|
||||
# accumulate parameter gradients:
|
||||
@ -129,15 +125,13 @@ class Bilinear(Module):
|
||||
if self.bias is not None:
|
||||
self.gradBias.add_(scale, gradOutput.sum(0))
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return str(type(self)) + \
|
||||
'({}x{} -> {}) {}'.format(
|
||||
self.weight.size(1), self.weight.size(2), self.weight.size(0),
|
||||
(' without bias' if self.bias is None else '')
|
||||
)
|
||||
'({}x{} -> {}) {}'.format(
|
||||
self.weight.size(1), self.weight.size(2), self.weight.size(0),
|
||||
(' without bias' if self.bias is None else '')
|
||||
)
|
||||
|
||||
def clearState(self):
|
||||
clear(self, 'buff1', 'buff2')
|
||||
return super(Bilinear, self).clearState()
|
||||
|
||||
|
@ -1,25 +1,25 @@
|
||||
import torch
|
||||
from .Module import Module
|
||||
|
||||
|
||||
class CAddTable(Module):
|
||||
|
||||
def __init__(self, inplace=False):
|
||||
super(CAddTable, self).__init__()
|
||||
self.inplace = inplace
|
||||
self.gradInput = []
|
||||
|
||||
|
||||
def updateOutput(self, input):
|
||||
if self.inplace:
|
||||
self.output.set_(input[0])
|
||||
self.output.set_(input[0])
|
||||
else:
|
||||
self.output.resize_as_(input[0]).copy_(input[0])
|
||||
self.output.resize_as_(input[0]).copy_(input[0])
|
||||
|
||||
for i in range(1, len(input)):
|
||||
self.output.add_(input[i])
|
||||
self.output.add_(input[i])
|
||||
|
||||
return self.output
|
||||
|
||||
|
||||
def updateGradInput(self, input, gradOutput):
|
||||
for i in range(len(input)):
|
||||
if i >= len(self.gradInput):
|
||||
@ -34,4 +34,3 @@ class CAddTable(Module):
|
||||
del self.gradInput[len(input):]
|
||||
|
||||
return self.gradInput
|
||||
|
||||
|
@ -1,7 +1,9 @@
|
||||
import torch
|
||||
from .Module import Module
|
||||
|
||||
|
||||
class CDivTable(Module):
|
||||
|
||||
def __init__(self, ):
|
||||
super(CDivTable, self).__init__()
|
||||
self.gradInput = []
|
||||
@ -20,4 +22,3 @@ class CDivTable(Module):
|
||||
del self.gradInput[len(input):]
|
||||
|
||||
return self.gradInput
|
||||
|
||||
|
@ -4,6 +4,7 @@ import torch
|
||||
from .Module import Module
|
||||
from .utils import clear, contiguousView
|
||||
|
||||
|
||||
class CMul(Module):
|
||||
|
||||
def __init__(self, *args):
|
||||
@ -33,11 +34,10 @@ class CMul(Module):
|
||||
if stdv is not None:
|
||||
stdv = stdv * math.sqrt(3)
|
||||
else:
|
||||
stdv = 1./math.sqrt(self.weight.nelement())
|
||||
stdv = 1. / math.sqrt(self.weight.nelement())
|
||||
|
||||
self.weight.uniform_(-stdv, stdv)
|
||||
|
||||
|
||||
def updateOutput(self, input):
|
||||
# lazy-initialize
|
||||
if self._output is None:
|
||||
@ -61,10 +61,9 @@ class CMul(Module):
|
||||
|
||||
return self.output
|
||||
|
||||
|
||||
def updateGradInput(self, input, gradOutput):
|
||||
if self.gradInput is None:
|
||||
return
|
||||
return
|
||||
|
||||
if self._gradOutput is None:
|
||||
self._gradOutput = input.new()
|
||||
@ -85,7 +84,6 @@ class CMul(Module):
|
||||
|
||||
return self.gradInput
|
||||
|
||||
|
||||
def accGradParameters(self, input, gradOutput, scale=1):
|
||||
if self._input is None:
|
||||
self._input = input.new()
|
||||
@ -103,17 +101,17 @@ class CMul(Module):
|
||||
|
||||
def type(self, type=None, tensorCache=None):
|
||||
if type:
|
||||
self.clearState()
|
||||
self.clearState()
|
||||
return super(CMul, self).type(type, tensorCache)
|
||||
|
||||
def clearState(self):
|
||||
clear(self, [
|
||||
'_input',
|
||||
'_output',
|
||||
'_weight',
|
||||
'_gradWeight',
|
||||
'_expand',
|
||||
'_repeat',
|
||||
'_sum',
|
||||
'_input',
|
||||
'_output',
|
||||
'_weight',
|
||||
'_gradWeight',
|
||||
'_expand',
|
||||
'_repeat',
|
||||
'_sum',
|
||||
])
|
||||
return super(CMul, self).clearState()
|
||||
|
@ -2,6 +2,7 @@ import torch
|
||||
from .Module import Module
|
||||
from .utils import clear
|
||||
|
||||
|
||||
class CMulTable(Module):
|
||||
|
||||
def __init__(self, ):
|
||||
@ -17,7 +18,7 @@ class CMulTable(Module):
|
||||
|
||||
def updateGradInput_efficient(self, input, gradOutput):
|
||||
if self.tout is None:
|
||||
self.tout = input[0].new()
|
||||
self.tout = input[0].new()
|
||||
self.tout.resize_as_(self.output)
|
||||
for i in range(len(input)):
|
||||
if len(self.gradInput) <= i:
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from .Module import Module
|
||||
|
||||
|
||||
class CSubTable(Module):
|
||||
|
||||
def __init__(self, ):
|
||||
@ -14,12 +15,11 @@ class CSubTable(Module):
|
||||
|
||||
def updateGradInput(self, input, gradOutput):
|
||||
if self.gradInput[0] is None:
|
||||
self.gradInput[0] = input[0].new()
|
||||
self.gradInput[0] = input[0].new()
|
||||
if self.gradInput[1] is None:
|
||||
self.gradInput[1] = input[1].new()
|
||||
self.gradInput[1] = input[1].new()
|
||||
self.gradInput[0].resize_as_(input[0]).copy_(gradOutput)
|
||||
self.gradInput[1].resize_as_(input[1]).copy_(gradOutput).mul_(-1)
|
||||
|
||||
self.gradInput = self.gradInput[:2]
|
||||
return self.gradInput
|
||||
|
||||
|
@ -1,6 +1,8 @@
|
||||
import torch
|
||||
from .HardTanh import HardTanh
|
||||
|
||||
|
||||
class Clamp(HardTanh):
|
||||
|
||||
def __init__(self, min_value, max_value):
|
||||
super(Clamp, self,).__init__(min_value, max_value)
|
||||
|
@ -1,7 +1,9 @@
|
||||
import torch
|
||||
from .Criterion import Criterion
|
||||
|
||||
|
||||
class ClassNLLCriterion(Criterion):
|
||||
|
||||
def __init__(self, weights=None, sizeAverage=True):
|
||||
super(ClassNLLCriterion, self).__init__()
|
||||
self.sizeAverage = sizeAverage
|
||||
@ -27,7 +29,6 @@ class ClassNLLCriterion(Criterion):
|
||||
self.output = self.output_tensor[0]
|
||||
return self.output
|
||||
|
||||
|
||||
def updateGradInput(self, input, target):
|
||||
self.gradInput.resize_as_(input).zero_()
|
||||
target = target.long()
|
||||
|
@ -12,19 +12,20 @@ from .MSECriterion import MSECriterion
|
||||
Reference: http.//arxiv.org/abs/1506.08230
|
||||
"""
|
||||
|
||||
|
||||
class ClassSimplexCriterion(MSECriterion):
|
||||
|
||||
def __init__(self, nClasses):
|
||||
super(ClassSimplexCriterion, self).__init__()
|
||||
self.nClasses = nClasses
|
||||
super(ClassSimplexCriterion, self).__init__()
|
||||
self.nClasses = nClasses
|
||||
|
||||
# embedding the simplex in a space of dimension strictly greater than
|
||||
# the minimum possible (nClasses-1) is critical for effective training.
|
||||
simp = self._regsplex(nClasses - 1)
|
||||
self.simplex = torch.cat((simp, torch.zeros(simp.size(0), nClasses - simp.size(1))), 1)
|
||||
self._target = torch.Tensor(nClasses)
|
||||
# embedding the simplex in a space of dimension strictly greater than
|
||||
# the minimum possible (nClasses-1) is critical for effective training.
|
||||
simp = self._regsplex(nClasses - 1)
|
||||
self.simplex = torch.cat((simp, torch.zeros(simp.size(0), nClasses - simp.size(1))), 1)
|
||||
self._target = torch.Tensor(nClasses)
|
||||
|
||||
self.output_tensor = None
|
||||
self.output_tensor = None
|
||||
|
||||
def _regsplex(self, n):
|
||||
"""
|
||||
@ -51,11 +52,11 @@ class ClassSimplexCriterion(MSECriterion):
|
||||
if k == 0:
|
||||
a[k][k] = 1
|
||||
else:
|
||||
a[k][k] = math.sqrt(1 - a[k:k+1, 0:k+1].norm()**2)
|
||||
a[k][k] = math.sqrt(1 - a[k:k + 1, 0:k + 1].norm()**2)
|
||||
|
||||
# fill_ the k-th coordinates for the vectors of the remaining vertices
|
||||
c = (a[k][k]**2 - 1 - 1/n) / a[k][k]
|
||||
a[k+1:n+2, k:k+1].fill_(c)
|
||||
c = (a[k][k]**2 - 1 - 1 / n) / a[k][k]
|
||||
a[k + 1:n + 2, k:k + 1].fill_(c)
|
||||
|
||||
return a
|
||||
|
||||
@ -69,20 +70,20 @@ class ClassSimplexCriterion(MSECriterion):
|
||||
self._target[i].copy_(self.simplex[int(target[i])])
|
||||
|
||||
def updateOutput(self, input, target):
|
||||
self._transformTarget(target)
|
||||
self._transformTarget(target)
|
||||
|
||||
assert input.nelement() == self._target.nelement()
|
||||
if self.output_tensor is None:
|
||||
self.output_tensor = input.new(1)
|
||||
self._backend.MSECriterion_updateOutput(
|
||||
assert input.nelement() == self._target.nelement()
|
||||
if self.output_tensor is None:
|
||||
self.output_tensor = input.new(1)
|
||||
self._backend.MSECriterion_updateOutput(
|
||||
self._backend.library_state,
|
||||
input,
|
||||
self._target,
|
||||
self.output_tensor,
|
||||
self.sizeAverage
|
||||
)
|
||||
self.output = self.output_tensor[0]
|
||||
return self.output
|
||||
)
|
||||
self.output = self.output_tensor[0]
|
||||
return self.output
|
||||
|
||||
def updateGradInput(self, input, target):
|
||||
assert input.nelement() == self._target.nelement()
|
||||
@ -100,6 +101,5 @@ class ClassSimplexCriterion(MSECriterion):
|
||||
|
||||
def getTopPrediction(self, input):
|
||||
prod = self.getPredictions(input)
|
||||
_, maxs = prod.max(prod.ndimension()-1)
|
||||
_, maxs = prod.max(prod.ndimension() - 1)
|
||||
return maxs.view(-1)
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from .Container import Container
|
||||
|
||||
|
||||
class Concat(Container):
|
||||
|
||||
def __init__(self, dimension):
|
||||
@ -22,9 +23,9 @@ class Concat(Container):
|
||||
|
||||
offset = 0
|
||||
for i, module in enumerate(self.modules):
|
||||
currentOutput = outs[i]
|
||||
self.output.narrow(self.dimension, offset, currentOutput.size(self.dimension)).copy_(currentOutput)
|
||||
offset = offset + currentOutput.size(self.dimension)
|
||||
currentOutput = outs[i]
|
||||
self.output.narrow(self.dimension, offset, currentOutput.size(self.dimension)).copy_(currentOutput)
|
||||
offset = offset + currentOutput.size(self.dimension)
|
||||
|
||||
return self.output
|
||||
|
||||
@ -34,9 +35,11 @@ class Concat(Container):
|
||||
offset = 0
|
||||
for i, module in enumerate(self.modules):
|
||||
currentOutput = module.output
|
||||
currentGradInput = module.updateGradInput(input, gradOutput.narrow(self.dimension, offset, currentOutput.size(self.dimension)))
|
||||
currentGradInput = module.updateGradInput(input, gradOutput.narrow(
|
||||
self.dimension, offset, currentOutput.size(self.dimension)))
|
||||
|
||||
if currentGradInput: # if the module does not produce a gradInput (for example first layer),: ignore it and move on.
|
||||
# if the module does not produce a gradInput (for example first layer),: ignore it and move on.
|
||||
if currentGradInput:
|
||||
if i == 0:
|
||||
self.gradInput.copy_(currentGradInput)
|
||||
else:
|
||||
@ -46,24 +49,25 @@ class Concat(Container):
|
||||
|
||||
return self.gradInput
|
||||
|
||||
|
||||
def accGradParameters(self, input, gradOutput, scale=1):
|
||||
offset = 0
|
||||
for i, module in enumerate(self.modules):
|
||||
currentOutput = module.output
|
||||
module.accGradParameters(
|
||||
input,
|
||||
gradOutput.narrow(self.dimension, offset, currentOutput.size(self.dimension)),
|
||||
scale)
|
||||
offset = offset + currentOutput.size(self.dimension)
|
||||
currentOutput = module.output
|
||||
module.accGradParameters(
|
||||
input,
|
||||
gradOutput.narrow(self.dimension, offset, currentOutput.size(self.dimension)),
|
||||
scale)
|
||||
offset = offset + currentOutput.size(self.dimension)
|
||||
|
||||
def backward(self, input, gradOutput, scale=1):
|
||||
self.gradInput.resize_as_(input)
|
||||
offset = 0
|
||||
for i, module in enumerate(self.modules):
|
||||
currentOutput = module.output
|
||||
currentGradInput = module.backward(input, gradOutput.narrow(self.dimension, offset, currentOutput.size(self.dimension)), scale)
|
||||
if currentGradInput is not None: # if the module.es not produce a gradInput (for example first layer),: ignore it and move on.
|
||||
currentGradInput = module.backward(input, gradOutput.narrow(
|
||||
self.dimension, offset, currentOutput.size(self.dimension)), scale)
|
||||
# if the module.es not produce a gradInput (for example first layer),: ignore it and move on.
|
||||
if currentGradInput is not None:
|
||||
if i == 0:
|
||||
self.gradInput.copy_(currentGradInput)
|
||||
else:
|
||||
@ -75,12 +79,12 @@ class Concat(Container):
|
||||
def accUpdateGradParameters(self, input, gradOutput, lr):
|
||||
offset = 0
|
||||
for i, module in enumerate(self.modules):
|
||||
currentOutput = module.output
|
||||
module.accUpdateGradParameters(
|
||||
input,
|
||||
gradOutput.narrow(self.dimension, offset, currentOutput.size(self.dimension)),
|
||||
lr)
|
||||
offset = offset + currentOutput.size(self.dimension)
|
||||
currentOutput = module.output
|
||||
module.accUpdateGradParameters(
|
||||
input,
|
||||
gradOutput.narrow(self.dimension, offset, currentOutput.size(self.dimension)),
|
||||
lr)
|
||||
offset = offset + currentOutput.size(self.dimension)
|
||||
|
||||
def __tostring__(self):
|
||||
tab = ' '
|
||||
@ -92,7 +96,7 @@ class Concat(Container):
|
||||
res = torch.type(self)
|
||||
res += ' {' + line + tab + 'input'
|
||||
for i in range(len(self.modules)):
|
||||
if i == len(self.modules)-1:
|
||||
if i == len(self.modules) - 1:
|
||||
res += line + tab + next + '(' + i + '): ' + str(self.modules[i]).replace(line, line + tab + extlast)
|
||||
else:
|
||||
res += line + tab + next + '(' + i + '): ' + str(self.modules[i]).replace(line, line + tab + ext)
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from .Container import Container
|
||||
|
||||
|
||||
class ConcatTable(Container):
|
||||
|
||||
def __init__(self, ):
|
||||
@ -23,7 +24,7 @@ class ConcatTable(Container):
|
||||
l1[i] = res
|
||||
else:
|
||||
f(l1, i, v)
|
||||
for i in range(len(l1)-1, len(l2)-1, -1):
|
||||
for i in range(len(l1) - 1, len(l2) - 1, -1):
|
||||
del l1[i]
|
||||
return l1
|
||||
|
||||
@ -44,6 +45,7 @@ class ConcatTable(Container):
|
||||
|
||||
if i == 0:
|
||||
self.gradInput = self.gradInput if wasTable else []
|
||||
|
||||
def fn(l, i, v):
|
||||
if i >= len(l):
|
||||
assert len(l) == i
|
||||
@ -82,11 +84,11 @@ class ConcatTable(Container):
|
||||
|
||||
def accGradParameters(self, input, gradOutput, scale=1):
|
||||
for i, module in ipairs(self.modules):
|
||||
self.rethrowErrors(module, i, 'accGradParameters', input, gradOutput[i], scale)
|
||||
self.rethrowErrors(module, i, 'accGradParameters', input, gradOutput[i], scale)
|
||||
|
||||
def accUpdateGradParameters(self, input, gradOutput, lr):
|
||||
for i, module in ipairs(self.modules):
|
||||
self.rethrowErrors(module, i, 'accUpdateGradParameters', input, gradOutput[i], lr)
|
||||
self.rethrowErrors(module, i, 'accUpdateGradParameters', input, gradOutput[i], lr)
|
||||
|
||||
def __repr__(self):
|
||||
tab = ' '
|
||||
@ -98,14 +100,13 @@ class ConcatTable(Container):
|
||||
res = torch.typename(self)
|
||||
res = res + ' {' + line + tab + 'input'
|
||||
for i in range(len(self.modules)):
|
||||
if i == len(self.modules)-1:
|
||||
res = res + line + tab + next + '(' + str(i) + '): ' + str(self.modules[i]).replace(line, line + tab + extlast)
|
||||
else:
|
||||
res = res + line + tab + next + '(' + str(i) + '): ' + str(self.modules[i]).replace(line, line + tab + ext)
|
||||
|
||||
if i == len(self.modules) - 1:
|
||||
res = res + line + tab + next + '(' + str(i) + '): ' + \
|
||||
str(self.modules[i]).replace(line, line + tab + extlast)
|
||||
else:
|
||||
res = res + line + tab + next + '(' + str(i) + '): ' + \
|
||||
str(self.modules[i]).replace(line, line + tab + ext)
|
||||
|
||||
res = res + line + tab + last + 'output'
|
||||
res = res + line + '}'
|
||||
return res
|
||||
|
||||
|
||||
|
@ -4,11 +4,12 @@ from .utils import clear
|
||||
from functools import wraps
|
||||
import sys
|
||||
|
||||
|
||||
class Container(Module):
|
||||
|
||||
def __init__(self, *args):
|
||||
super(Container, self).__init__(*args)
|
||||
self.modules = []
|
||||
super(Container, self).__init__(*args)
|
||||
self.modules = []
|
||||
|
||||
def add(self, module):
|
||||
self.modules.append(module)
|
||||
@ -18,11 +19,11 @@ class Container(Module):
|
||||
return self.modules[index]
|
||||
|
||||
def size(self):
|
||||
return len(self.modules)
|
||||
return len(self.modules)
|
||||
|
||||
def applyToModules(self, func):
|
||||
for module in self.modules:
|
||||
func(module)
|
||||
for module in self.modules:
|
||||
func(module)
|
||||
|
||||
def zeroGradParameters(self):
|
||||
self.applyToModules(lambda m: m.zeroGradParameters())
|
||||
@ -46,16 +47,16 @@ class Container(Module):
|
||||
self.applyToModules(lambda m: m.reset(stdv))
|
||||
|
||||
def parameters(self):
|
||||
w = []
|
||||
gw = []
|
||||
for module in self.modules:
|
||||
mparam = module.parameters()
|
||||
if mparam is not None:
|
||||
w.extend(mparam[0])
|
||||
gw.extend(mparam[1])
|
||||
if not w:
|
||||
return
|
||||
return w, gw
|
||||
w = []
|
||||
gw = []
|
||||
for module in self.modules:
|
||||
mparam = module.parameters()
|
||||
if mparam is not None:
|
||||
w.extend(mparam[0])
|
||||
gw.extend(mparam[1])
|
||||
if not w:
|
||||
return
|
||||
return w, gw
|
||||
|
||||
def clearState(self):
|
||||
clear('output')
|
||||
@ -63,4 +64,3 @@ class Container(Module):
|
||||
for module in self.modules:
|
||||
module.clearState()
|
||||
return self
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from .Module import Module
|
||||
|
||||
|
||||
class Contiguous(Module):
|
||||
|
||||
def updateOutput(self, input):
|
||||
@ -11,7 +12,6 @@ class Contiguous(Module):
|
||||
|
||||
return self.output
|
||||
|
||||
|
||||
def updateGradInput(self, input, gradOutput):
|
||||
if not gradOutput.is_contiguous():
|
||||
self.gradInput.resize_as_(gradOutput).copy_(gradOutput)
|
||||
@ -19,4 +19,3 @@ class Contiguous(Module):
|
||||
self.gradInput.set_(gradOutput)
|
||||
|
||||
return self.gradInput
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from .Module import Module
|
||||
|
||||
|
||||
class Copy(Module):
|
||||
|
||||
def __init__(self, intype, outtype, dontCast=False):
|
||||
@ -13,15 +14,12 @@ class Copy(Module):
|
||||
self.output.resize_(input.size()).copy_(input)
|
||||
return self.output
|
||||
|
||||
|
||||
def updateGradInput(self, input, gradOutput):
|
||||
self.gradInput.resize_(gradOutput.size()).copy_(gradOutput)
|
||||
return self.gradInput
|
||||
|
||||
|
||||
def type(self, type=None, tensorCache=None):
|
||||
if type and self.dontCast:
|
||||
return self
|
||||
return self
|
||||
|
||||
return super(Copy, self).type(self, type, tensorCache)
|
||||
|
||||
|
@ -3,6 +3,7 @@ import torch
|
||||
from .Module import Module
|
||||
from .utils import clear
|
||||
|
||||
|
||||
class Cosine(Module):
|
||||
|
||||
def __init__(self, inputSize, outputSize):
|
||||
@ -22,7 +23,7 @@ class Cosine(Module):
|
||||
if stdv is not None:
|
||||
stdv = stdv * math.sqrt(3)
|
||||
else:
|
||||
stdv = 1./math.sqrt(self.weight.size(0))
|
||||
stdv = 1. / math.sqrt(self.weight.size(0))
|
||||
self.weight.uniform_(-stdv, stdv)
|
||||
|
||||
def updateOutput(self, input):
|
||||
@ -32,9 +33,9 @@ class Cosine(Module):
|
||||
outputSize = self.weight.size(0)
|
||||
|
||||
if self._weightNorm is None:
|
||||
self._weightNorm = self.weight.new()
|
||||
self._weightNorm = self.weight.new()
|
||||
if self._inputNorm is None:
|
||||
self._inputNorm = self.weight.new()
|
||||
self._inputNorm = self.weight.new()
|
||||
|
||||
# y_j = (w_j * x) / ( || w_j || * || x || )
|
||||
|
||||
@ -53,12 +54,11 @@ class Cosine(Module):
|
||||
self.output.div_(self._inputNorm.expand_as(self.output))
|
||||
return self.output
|
||||
|
||||
|
||||
def updateGradInput(self, input, gradOutput):
|
||||
assert input.dim() == 2
|
||||
|
||||
if self.gradInput is None:
|
||||
return
|
||||
return
|
||||
|
||||
inputSize = self.weight.size(1)
|
||||
outputSize = self.weight.size(0)
|
||||
@ -72,15 +72,15 @@ class Cosine(Module):
|
||||
nelement = self.gradInput.nelement()
|
||||
self.gradInput.resize_as_(input)
|
||||
if self.gradInput.nelement() != nelement:
|
||||
self.gradInput.zero_()
|
||||
self.gradInput.zero_()
|
||||
|
||||
inputNorm = self._inputNorm.expand_as(input)
|
||||
weightNorm = self._weightNorm.view(1, outputSize).expand_as(gradOutput)
|
||||
|
||||
if self._gradOutput is None:
|
||||
self._gradOutput = gradOutput.new()
|
||||
self._gradOutput = gradOutput.new()
|
||||
if self._sum is None:
|
||||
self._sum = input.new()
|
||||
self._sum = input.new()
|
||||
|
||||
self.gradInput.copy_(input).div_(inputNorm)
|
||||
self._gradOutput.resize_as_(gradOutput).copy_(gradOutput)
|
||||
@ -107,13 +107,13 @@ class Cosine(Module):
|
||||
"""
|
||||
|
||||
if self._weight is None:
|
||||
self._weight = self.weight.new()
|
||||
self._weight = self.weight.new()
|
||||
if self._sum is None:
|
||||
self._sum = input.new()
|
||||
self._sum = input.new()
|
||||
|
||||
self._weight.resize_as_(self.weight).copy_(self.weight)
|
||||
if self._gradOutput is None:
|
||||
self._gradOutput = gradOutput.new()
|
||||
self._gradOutput = gradOutput.new()
|
||||
self._gradOutput.resize_as_(gradOutput).copy_(gradOutput)
|
||||
self._gradOutput.mul_(self.output)
|
||||
torch.sum(self._gradOutput, 0, out=self._sum)
|
||||
@ -131,25 +131,23 @@ class Cosine(Module):
|
||||
|
||||
def type(self, type=None, tensorCache=None):
|
||||
if type is not None:
|
||||
# prevent premature memory allocations
|
||||
self._input = None
|
||||
self._weight = None
|
||||
self._inputNorm = None
|
||||
self._weightNorm = None
|
||||
self._gradOutput = None
|
||||
self._sum = None
|
||||
# prevent premature memory allocations
|
||||
self._input = None
|
||||
self._weight = None
|
||||
self._inputNorm = None
|
||||
self._weightNorm = None
|
||||
self._gradOutput = None
|
||||
self._sum = None
|
||||
|
||||
return super(Cosine, self).type(type, tensorCache)
|
||||
|
||||
|
||||
def clearState(self):
|
||||
clear(self, [
|
||||
'_input',
|
||||
'_weight',
|
||||
'_gradOutput',
|
||||
'_sum',
|
||||
'_inputNorm',
|
||||
'_weightNorm',
|
||||
'_input',
|
||||
'_weight',
|
||||
'_gradOutput',
|
||||
'_sum',
|
||||
'_inputNorm',
|
||||
'_weightNorm',
|
||||
])
|
||||
return super(Cosine, self).clearState()
|
||||
|
||||
|
@ -2,6 +2,7 @@ import torch
from .Module import Module
from .utils import clear


class CosineDistance(Module):

    def __init__(self, ):
@ -11,39 +12,38 @@ class CosineDistance(Module):
        self._input1 = None
        self._input2 = None
        self.buffer = None
        self.w1 = None
        self.w1 = None
        self.w22 = None
        self.w = None
        self.w = None
        self.w32 = None
        self.ones = None

    def _makeContiguous(self, input1, input2):
        if not input1.is_contiguous():
            if self._input1 is None:
                self._input1 = input1.new()
            self._input1.resize_as_(input1).copy_(input1)
            input1 = self._input1
            if self._input1 is None:
                self._input1 = input1.new()
            self._input1.resize_as_(input1).copy_(input1)
            input1 = self._input1

        if not input2.is_contiguous():
            if self._input2 is None:
                self._input2 = input2.new()
            self._input2.resize_as_(input2).copy_(input2)
            input2 = self._input2
            if self._input2 is None:
                self._input2 = input2.new()
            self._input2.resize_as_(input2).copy_(input2)
            input2 = self._input2

        return input1, input2

    def updateOutput(self, input):
        input1, input2 = input[0], input[1]
        input1, input2 = self._makeContiguous(input1, input2)

        if self.buffer is None:
            self.buffer = input1.new()
            self.w1 = input1.new()
            self.w22 = input1.new()
            self.w = input1.new()
            self.w32 = input1.new()
            self.ones = input1.new()
            self.buffer = input1.new()
            self.w1 = input1.new()
            self.w22 = input1.new()
            self.w = input1.new()
            self.w32 = input1.new()
            self.ones = input1.new()

        torch.mul(input1, input2, out=self.buffer)
        torch.sum(self.buffer, 1, out=self.w1)
@ -65,18 +65,17 @@ class CosineDistance(Module):

        return self.output

    def updateGradInput(self, input, gradOutput):
        v1 = input[0]
        v2 = input[1]
        v1 = input[0]
        v2 = input[1]
        v1, v2 = self._makeContiguous(v1, v2)

        if len(self.gradInput) != 2:
            if self.gradInput[0] is None:
                self.gradInput[0] = v1.new()
            if self.gradInput[1] is None:
                self.gradInput[1] = v1.new()
            self.gradInput = self.gradInput[:2]
            if self.gradInput[0] is None:
                self.gradInput[0] = v1.new()
            if self.gradInput[1] is None:
                self.gradInput[1] = v1.new()
            self.gradInput = self.gradInput[:2]

        gw1 = self.gradInput[0]
        gw2 = self.gradInput[1]
@ -97,15 +96,13 @@ class CosineDistance(Module):

        return self.gradInput

    def clearState(self):
        clear(self, [
            'buffer',
            'w1',
            'w22',
            'w',
            'w32',
            'ones',
            'buffer',
            'w1',
            'w22',
            'w',
            'w32',
            'ones',
        ])
        return super(CosineDistance, self).clearState()
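
As a side note (not part of the diff), the forward pass above amounts to a per-row cosine similarity between the two inputs. A minimal sketch with plain tensor ops, assuming two 2-D inputs of the same shape:

    import torch

    a = torch.randn(4, 10)
    b = torch.randn(4, 10)
    # numerator: row-wise dot product; denominator: product of the row norms
    cos = (a * b).sum(1) / (a.pow(2).sum(1).sqrt() * b.pow(2).sum(1).sqrt())
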
@ -1,6 +1,7 @@
import torch
from .Criterion import Criterion


class CosineEmbeddingCriterion(Criterion):

    def __init__(self, margin=0, sizeAverage=True):
@ -9,23 +10,22 @@ class CosineEmbeddingCriterion(Criterion):
        self.sizeAverage = sizeAverage
        self.gradInput = [torch.Tensor(), torch.Tensor()]
        self.buffer = None
        self.w1 = None
        self.w1 = None
        self.w22 = None
        self.w = None
        self.w = None
        self.w32 = None
        self._outputs = None
        self._idx = None

    def updateOutput(self, input, y):
        input1, input2 = input[0], input[1]

        # keep backward compatibility
        if self.buffer is None:
            self.buffer = input1.new()
            self.w1 = input1.new()
            self.w1 = input1.new()
            self.w22 = input1.new()
            self.w = input1.new()
            self.w = input1.new()
            self.w32 = input1.new()
            self._outputs = input1.new()

@ -64,14 +64,13 @@ class CosineEmbeddingCriterion(Criterion):
        self.output = self._outputs.sum()

        if self.sizeAverage:
            self.output = self.output / y.size(0)
            self.output = self.output / y.size(0)

        return self.output

    def updateGradInput(self, input, y):
        v1 = input[0]
        v2 = input[1]
        v1 = input[0]
        v2 = input[1]

        gw1 = self.gradInput[0]
        gw2 = self.gradInput[1]
@ -98,22 +97,21 @@ class CosineEmbeddingCriterion(Criterion):
        gw2[self._idx] = gw2[self._idx].mul_(-1)

        if self.sizeAverage:
            gw1.div_(y.size(0))
            gw2.div_(y.size(0))
            gw1.div_(y.size(0))
            gw2.div_(y.size(0))

        return self.gradInput

    def type(self, type=None, tensorCache=None):
        if not type:
            return self._type
            return self._type

        self._idx = None
        super(CosineEmbeddingCriterion, self).type(type, tensorCache)
        # comparison operators behave differently from cuda/c implementations
        if type == 'torch.cuda.FloatTensor':
            self._idx = torch.cuda.ByteTensor()
            self._idx = torch.cuda.ByteTensor()
        else:
            self._idx = torch.ByteTensor()
            self._idx = torch.ByteTensor()

        return self
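
For orientation only (not part of the diff), the loss this criterion computes can be sketched with plain tensor ops. The function name below is hypothetical; assumed inputs are 2-D tensors x1, x2 and a label vector y of +1/-1 entries:

    import torch

    def cosine_embedding_loss(x1, x2, y, margin=0.0):
        # per-row cosine similarity
        cos = (x1 * x2).sum(1) / (x1.pow(2).sum(1).sqrt() * x2.pow(2).sum(1).sqrt())
        pos = 1 - cos                       # loss for y = +1 pairs
        neg = (cos - margin).clamp(min=0)   # loss for y = -1 pairs
        return torch.where(y > 0, pos, neg).mean()  # mean matches sizeAverage=True
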
@ -3,6 +3,7 @@ from .Module import Module
from .utils import recursiveType
import torch._thnn


class Criterion(object):

    def __init__(self):

@ -1,6 +1,7 @@
import torch
from .Module import Module


class CriterionTable(Module):

    def __init__(self, criterion):
@ -15,4 +16,3 @@ class CriterionTable(Module):
    def updateGradInput(self, input, grad_output):
        self.criterion.updateGradInput(*input)
        return self.gradInput

@ -3,6 +3,7 @@ from .Criterion import Criterion
from .LogSoftMax import LogSoftMax
from .ClassNLLCriterion import ClassNLLCriterion


class CrossEntropyCriterion(Criterion):

    def __init__(self, weights=None):
@ -26,4 +27,3 @@ class CrossEntropyCriterion(Criterion):
        self.lsm.updateGradInput(input, self.nll.gradInput)
        self.gradInput = self.lsm.gradInput.view(size)
        return self.gradInput
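
For context (not part of the diff), this criterion composes LogSoftMax with ClassNLLCriterion; the same decomposition holds in the modern functional API, shown here only as a sketch:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 10)
    target = torch.randint(0, 10, (4,))
    # log-softmax followed by NLL equals cross-entropy on the raw logits
    assert torch.allclose(F.nll_loss(F.log_softmax(logits, dim=1), target),
                          F.cross_entropy(logits, target))
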
@ -14,18 +14,19 @@ import math
import torch
from .Concat import Concat


class DepthConcat(Concat):

    def windowNarrow(self, output, currentOutput, offset):
        outputWindow = output.narrow(self.dimension, offset, currentOutput.size(self.dimension))
        for dim in range(len(self.outputSize)):
            currentSize = currentOutput.size(dim)
            if dim != self.dimension and self.outputSize[dim] != currentSize:
                # 5x5 vs 3x3 -> start = [(5-3)/2] + 1 = 2 (1 pad each side)
                # 9x9 vs 5x5 -> start = [(9-5)/2] + 1 = 3 (2 pad each side)
                # 9x9 vs 4x4 -> start = [(9-4)/2] + 1 = 3.5 (2 pad, 3 pad)
                start = int(math.floor(((self.outputSize[dim] - currentSize) / 2)))
                outputWindow = outputWindow.narrow(dim, start, currentSize)
            currentSize = currentOutput.size(dim)
            if dim != self.dimension and self.outputSize[dim] != currentSize:
                # 5x5 vs 3x3 -> start = [(5-3)/2] + 1 = 2 (1 pad each side)
                # 9x9 vs 5x5 -> start = [(9-5)/2] + 1 = 3 (2 pad each side)
                # 9x9 vs 4x4 -> start = [(9-4)/2] + 1 = 3.5 (2 pad, 3 pad)
                start = int(math.floor(((self.outputSize[dim] - currentSize) / 2)))
                outputWindow = outputWindow.narrow(dim, start, currentSize)
        return outputWindow

    def updateOutput(self, input):
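
As a quick aside (not from the diff), the centering arithmetic in windowNarrow is easier to see with concrete numbers; the inline comments use Lua's 1-based "+ 1", while the Python code drops it:

    import math

    # output map 9 wide, current module's map 5 wide: the window starts at
    # floor((9 - 5) / 2) = 2, so the 5 columns occupy positions 2..6,
    # leaving 2 columns of zero padding on each side.
    print(math.floor((9 - 5) / 2))   # 2
    # 9 wide vs 4 wide: floor((9 - 4) / 2) = 2, giving 2 columns before and 3 after.
    print(math.floor((9 - 4) / 2))   # 2
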
@ -1,6 +1,7 @@
import torch
from .Criterion import Criterion


class DistKLDivCriterion(Criterion):

    def __init__(self, sizeAverage=True):
@ -11,7 +12,7 @@ class DistKLDivCriterion(Criterion):
    def updateOutput(self, input, target):
        assert input.is_same_size(target)
        if self.output_tensor is None:
            self.output_tensor = input.new(1)
            self.output_tensor = input.new(1)
        self._backend.DistKLDivCriterion_updateOutput(
            self._backend.library_state,
            input,
@ -32,4 +33,3 @@ class DistKLDivCriterion(Criterion):
            self.sizeAverage
        )
        return self.gradInput

@ -2,6 +2,7 @@ import torch
from .Module import Module
from .utils import clear


class DotProduct(Module):

    def __init__(self):
@ -13,7 +14,7 @@ class DotProduct(Module):
        input1, input2 = input[0], input[1]

        if self.buffer is None:
            self.buffer = input1.new()
            self.buffer = input1.new()

        torch.mul(input1, input2, out=self.buffer)
        torch.sum(self.buffer, 1, out=self.output)
@ -26,11 +27,11 @@ class DotProduct(Module):
        not_batch = False

        if len(self.gradInput) != 2:
            if self.gradInput[0] is None:
                self.gradInput[0] = input[0].new()
            if self.gradInput[1] is None:
                self.gradInput[1] = input[1].new()
            self.gradInput = self.gradInput[:2]
            if self.gradInput[0] is None:
                self.gradInput[0] = input[0].new()
            if self.gradInput[1] is None:
                self.gradInput[1] = input[1].new()
            self.gradInput = self.gradInput[:2]

        gw1 = self.gradInput[0]
        gw2 = self.gradInput[1]
@ -46,4 +47,3 @@ class DotProduct(Module):
    def clearState(self):
        clear(self, 'buffer')
        return super(DotProduct, self).clearState()
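
For reference only (not from the diff), the forward pass above is a row-wise dot product; a minimal sketch with plain tensor ops, assuming two 2-D inputs of matching shape:

    import torch

    a = torch.randn(4, 10)
    b = torch.randn(4, 10)
    out = (a * b).sum(1)   # one scalar per row, shape (4,)
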
@ -2,6 +2,7 @@ import torch
from .Module import Module
from .utils import clear


class Dropout(Module):

    def __init__(self, p=0.5, inplace=False):
@ -19,8 +20,8 @@ class Dropout(Module):

        if self.p > 0 and self.train:
            self.noise.resize_as_(input)
            self.noise.bernoulli_(1-self.p)
            self.noise.div_(1-self.p)
            self.noise.bernoulli_(1 - self.p)
            self.noise.div_(1 - self.p)
            self.output.mul_(self.noise)

        return self.output
@ -32,7 +33,7 @@ class Dropout(Module):
        self.gradInput.resize_as_(gradOutput).copy_(gradOutput)

        if self.p > 0 and self.train:
            self.gradInput.mul_(self.noise)  # simply mask the gradients with the noise vector
            self.gradInput.mul_(self.noise)  # simply mask the gradients with the noise vector

        return self.gradInput

@ -45,4 +46,3 @@ class Dropout(Module):
    def clearState(self):
        clear(self, 'noise')
        return super(Dropout, self).clearState()
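
A side note (not part of the diff): the bernoulli_/div_ pair above implements "inverted" dropout, where surviving activations are scaled by 1 / (1 - p) at training time so no rescaling is needed at evaluation time. A minimal sketch, assuming p = 0.5:

    import torch

    p = 0.5
    x = torch.ones(3, 4)
    keep = torch.bernoulli(torch.full_like(x, 1 - p))  # 1 = keep, 0 = drop
    y_train = x * keep / (1 - p)   # masked and rescaled during training
    y_eval = x                     # identity at evaluation time
    # E[y_train] equals y_eval element-wise, because E[keep] = 1 - p
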
@ -2,6 +2,7 @@
import torch
from .Module import Module


class ELU(Module):
    """
    Djork-Arné Clevert, Thomas Unterthiner, Sepp Hochreiter
@ -39,4 +40,3 @@ class ELU(Module):

    def __repr__(self):
        return '{}(alpha={:.3f})'.format(str(type(self)), self.alpha)

@ -3,6 +3,7 @@ import torch
from .Module import Module
from .utils import clear


class Euclidean(Module):

    def __init__(self, inputSize, outputSize):
@ -18,11 +19,11 @@ class Euclidean(Module):
        self.fastBackward = True
        self.reset()

        self._input = None
        self._weight = None
        self._expand = None
        self._input = None
        self._weight = None
        self._expand = None
        self._expand2 = None
        self._repeat = None
        self._repeat = None
        self._repeat2 = None
        self._div = None
        self._output = None
@ -32,32 +33,32 @@ class Euclidean(Module):

    def reset(self, stdv=None):
        if stdv is not None:
            stdv = stdv * math.sqrt(3)
            stdv = stdv * math.sqrt(3)
        else:
            stdv = 1./math.sqrt(self.weight.size(0))
            stdv = 1. / math.sqrt(self.weight.size(0))

        self.weight.uniform_(-stdv, stdv)

    def _view(self, res, src, *args):
        if src.is_contiguous():
            res.set_(src.view(*args))
            res.set_(src.view(*args))
        else:
            res.set_(src.contiguous().view(*args))
            res.set_(src.contiguous().view(*args))

    def updateOutput(self, input):
        # lazy initialize buffers
        if self._input is None:
            self._input = input.new()
            self._input = input.new()
        if self._weight is None:
            self._weight = self.weight.new()
            self._weight = self.weight.new()
        if self._expand is None:
            self._expand = self.output.new()
            self._expand = self.output.new()
        if self._expand2 is None:
            self._expand2 = self.output.new()
            self._expand2 = self.output.new()
        if self._repeat is None:
            self._repeat = self.output.new()
            self._repeat = self.output.new()
        if self._repeat2 is None:
            self._repeat2 = self.output.new()
            self._repeat2 = self.output.new()

        inputSize, outputSize = self.weight.size(0), self.weight.size(1)

@ -88,19 +89,19 @@ class Euclidean(Module):

    def updateGradInput(self, input, gradOutput):
        if self.gradInput is None:
            return
            return

        if self._div is None:
            self._div = input.new()
            self._div = input.new()
        if self._output is None:
            self._output = self.output.new()
            self._output = self.output.new()
        if self._gradOutput is None:
            self._gradOutput = input.new()
            self._gradOutput = input.new()
        if self._expand3 is None:
            self._expand3 = input.new()
            self._expand3 = input.new()

        if not self.fastBackward:
            self.updateOutput(input)
            self.updateOutput(input)

        inputSize, outputSize = self.weight.size(0), self.weight.size(1)

@ -126,13 +127,11 @@ class Euclidean(Module):
        else:
            torch.mul(self._repeat, self._expand3, out=self._repeat2)

        torch.sum(self._repeat2, 2, out=self.gradInput)
        self.gradInput.resize_as_(input)

        return self.gradInput

    def accGradParameters(self, input, gradOutput, scale=1):
        inputSize, outputSize = self.weight.size(0), self.weight.size(1)

@ -144,32 +143,30 @@ class Euclidean(Module):
        # assumes a preceding call to updateGradInput
        assert input.dim() == 2
        if self._sum is None:
            self._sum = input.new()
            self._sum = input.new()
        torch.sum(self._repeat2, 0, out=self._sum)
        self._sum.resize_(inputSize, outputSize)
        self.gradWeight.add_(-scale, self._sum)

    def type(self, type=None, tensorCache=None):
        if type:
            # prevent premature memory allocations
            self.clearState()
            # prevent premature memory allocations
            self.clearState()

        return super(Euclidean, self).type(type, tensorCache)

    def clearState(self):
        clear(self, [
            '_input',
            '_output',
            '_gradOutput',
            '_weight',
            '_div',
            '_sum',
            '_expand',
            '_expand2',
            '_expand3',
            '_repeat',
            '_repeat2',
            '_input',
            '_output',
            '_gradOutput',
            '_weight',
            '_div',
            '_sum',
            '_expand',
            '_expand2',
            '_expand3',
            '_repeat',
            '_repeat2',
        ])
        return super(Euclidean, self).clearState()
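
As context (not part of the diff), the Euclidean layer produces, for each input row, its L2 distance to every weight column. A minimal sketch with broadcasting, assuming shapes (batch, inputSize) and (inputSize, outputSize):

    import torch

    x = torch.randn(4, 10)    # (batch, inputSize)
    w = torch.randn(10, 8)    # (inputSize, outputSize)
    # distances[b, o] = || x[b] - w[:, o] ||_2
    distances = (x.unsqueeze(2) - w.unsqueeze(0)).pow(2).sum(1).sqrt()  # (4, 8)
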
@ -1,6 +1,7 @@
import torch
from .Module import Module


class Exp(Module):

    def updateOutput(self, input):
@ -8,4 +9,3 @@ class Exp(Module):

    def updateGradInput(self, input, gradOutput):
        return torch.mul(self.output, gradOutput, out=self.gradInput)
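
A brief aside (not from the diff): the backward pass above reuses the stored output because d/dx exp(x) = exp(x). The same identity, checked with modern autograd as a sketch:

    import torch

    x = torch.randn(3, requires_grad=True)
    y = x.exp()
    y.sum().backward()
    # the gradient of exp is exp itself
    assert torch.allclose(x.grad, y.detach())
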
@ -1,6 +1,7 @@
import torch
from .Module import Module


class FlattenTable(Module):

    def __init__(self):
@ -59,7 +60,6 @@ class FlattenTable(Module):

        return self.output

    def updateGradInput(self, input, gradOutput):
        assert isinstance(input, list)
        assert isinstance(gradOutput, list)
@ -69,11 +69,10 @@ class FlattenTable(Module):

        # However, we should check that the gradInput is valid:
        if not self._checkMapping(gradOutput, self.gradInput, self.input_map):
            self.gradInput = self._inverseFlatten(gradOutput, self.input_map)
            self.gradInput = self._inverseFlatten(gradOutput, self.input_map)

        return self.gradInput

    def type(self, type=None, tensorCache=None):
        if not type:
            return self._type
@ -81,8 +80,6 @@ class FlattenTable(Module):
        # conversions. Just force the tables to be empty.
        self.clearState()

    def clearState(self):
        self.input_map = []
        return super(FlattenTable, self).clearState()

@ -1,6 +1,7 @@
import torch
from .Module import Module


class GradientReversal(Module):

    def __init__(self, lambd=1):
@ -19,4 +20,3 @@ class GradientReversal(Module):
        self.gradInput.copy_(gradOutput)
        self.gradInput.mul_(-self.lambd)
        return self.gradInput
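
As an aside (not part of the diff), the same identity-forward / negated-backward behaviour can be sketched with a modern autograd Function; this is an assumed rough equivalent, not the legacy module itself:

    import torch

    class GradReverse(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, lambd=1.0):
            ctx.lambd = lambd
            return x.view_as(x)          # identity in the forward pass

        @staticmethod
        def backward(ctx, grad_output):
            return -ctx.lambd * grad_output, None   # reverse and scale gradients

    # usage sketch: y = GradReverse.apply(x, 0.5)
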
@ -1,6 +1,7 @@
import torch
from .Module import Module


class HardShrink(Module):

    def __init__(self, lambd=0.5):
@ -26,4 +27,3 @@ class HardShrink(Module):
            self.lambd
        )
        return self.gradInput

Some files were not shown because too many files have changed in this diff.