Allow negative dimension arguments for all functions

Commit f0c7124420 by albanD, 2017-03-26 14:25:53 +01:00 (committed by soumith)
Parent: e7f5220dfa
13 changed files with 394 additions and 179 deletions
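In practice, every function that takes a dimension argument now also accepts a negative dim that counts from the last dimension, mirroring Python's negative indexing, and an out-of-range dim fails with an explicit error. A minimal sketch of the user-facing behavior (illustrative shapes, not part of the commit):

    import torch

    x = torch.randn(2, 3, 4)
    # A negative dim wraps to dim + x.dim(), so -1 means the last dimension:
    assert torch.equal(x.sum(-1), x.sum(2))
    assert torch.equal(x.transpose(-2, -1), x.transpose(1, 2))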


@@ -154,10 +154,11 @@ class build_ext(setuptools.command.build_ext.build_ext):
         from tools.cwrap.plugins.KwargsPlugin import KwargsPlugin
         from tools.cwrap.plugins.NullableArguments import NullableArguments
         from tools.cwrap.plugins.CuDNNPlugin import CuDNNPlugin
+        from tools.cwrap.plugins.WrapDim import WrapDim
         thp_plugin = THPPlugin()
         cwrap('torch/csrc/generic/TensorMethods.cwrap', plugins=[
             BoolOption(), thp_plugin, AutoGPU(condition='IS_CUDA'),
-            ArgcountSortPlugin(), KwargsPlugin()
+            ArgcountSortPlugin(), KwargsPlugin(), WrapDim()
         ])
         cwrap('torch/csrc/cudnn/cuDNN.cwrap', plugins=[
             CuDNNPlugin(), NullableArguments()


@@ -6,6 +6,7 @@ import torch
 import unittest
 from copy import deepcopy
 from collections import OrderedDict
+from itertools import product
 from torch.autograd import gradcheck
 from common import TestCase, run_tests
@@ -944,11 +945,12 @@ def gather_variable(shape, index_dim, max_indices):
     return Variable(index, requires_grad=False)


-def prod_zeros(dim_size):
+def prod_zeros(dim_size, dim_select):
+    assert len(dim_select) == 2
     result = torch.randn(dim_size, dim_size, dim_size)
-    result[0, 1] = 0
-    result[2, 3] = 0
-    result[4, 3] = 0
+    result.narrow(dim_select[0], 0, 1).narrow(dim_select[1], 1, 1).zero_()
+    result.narrow(dim_select[0], 2, 1).narrow(dim_select[1], 3, 1).zero_()
+    result.narrow(dim_select[0], 4, 1).narrow(dim_select[1], 3, 1).zero_()
     return Variable(result, requires_grad=True)
@@ -974,8 +976,8 @@ function_tests = [
     (DivConstant, (3.14, True), (torch.rand(L, L) + 1e-1,), 'by_tensor'),
     (PowConstant, (3.14,), (torch.rand(L, L),)),
     (PowConstant, (3.14, True), (torch.rand(L, L),), 'tensor_power'),
-    (Transpose, (0, 1), (torch.rand(L, L),)),
-    (Transpose, (2, 0), (torch.rand(S, S, S),), '3d'),
+    (Transpose, (0, 1), (torch.rand(L, L),), '2d', [0, 1]),
+    (Transpose, (2, 0), (torch.rand(S, S, S),), '3d', [0, 1]),
     (Permute, ((0, 4, 3, 5, 1, 2),), ((1, 2, 3, 4, 5, 6),)),
     (Index, ((1, 2),), (torch.rand(S, S, S),)),
     (Index, (slice(0, 3),), (torch.rand(S, S, S),), 'slice'),
@@ -1017,14 +1019,18 @@ function_tests = [
     (CmaxConstant, (0.5,), ((S, S, S),)),
     (CminConstant, (0.5,), ((S, S, S),)),
     (Mean, (), ((S, S, S),)),
-    (Mean, (1,), ((S, S, S),), 'dim'),
+    (Mean, (1,), ((S, S, S),), 'dim', [0]),
     (Sum, (), ((S, S, S),)),
-    (Sum, (1,), ((S, S, S),), 'dim'),
+    (Sum, (1,), ((S, S, S),), 'dim', [0]),
     (Prod, (), ((S, S, S),)),
-    (Prod, (), (prod_zeros(S),), 'zeros'),
+    (Prod, (), (prod_zeros(S, [0, 1]),), 'zerosdim2'),
+    (Prod, (), (prod_zeros(S, [0, 2]),), 'zerosdim1'),
+    (Prod, (), (prod_zeros(S, [1, 2]),), 'zerosdim0'),
     (Prod, (), (prod_single_zero(S),), 'single_zero'),
-    (Prod, (1,), ((S, S, S),), 'dim'),
-    (Prod, (1,), (prod_zeros(S),), 'zeros_dim'),
+    (Prod, (1,), ((S, S, S),), 'dim', [0]),
+    (Prod, (1,), (prod_zeros(S, [0, 1]),), 'zeros_dim2', [0]),
+    (Prod, (1,), (prod_zeros(S, [0, 2]),), 'zeros_dim1', [0]),
+    (Prod, (1,), (prod_zeros(S, [1, 2]),), 'zeros_dim0', [0]),
     (Addmm, (), ((S, M), (S, S), (S, M)),),
     (Addmm, (0.1, 1), ((S, M), (S, S), (S, M)), 'coef'),
     (Addbmm, (), ((S, M), (S, S, S), (S, S, M)),),
@@ -1042,17 +1048,17 @@ function_tests = [
     (Cumsum, (1,), ((S, S, S),), 'dim1'),
     (Cumsum, (0,), ((S,),), '1d'),
     (Min, (), ((S, S, S),),),
-    (Max, (0,), ((S, S, S),), 'dim'),
-    (Min, (0,), ((S, S, S),), 'dim'),
-    (Mode, (0,), ((S, S, S),),),
+    (Max, (1,), ((S, S, S),), 'dim', [0]),
+    (Min, (1,), ((S, S, S),), 'dim', [0]),
+    (Mode, (1,), ((S, S, S),), 'dim', [0]),
     (Kthvalue, (2, 0), ((S, S, S),),),
     (Median, (0,), ((S, S, S),),),
     (Norm, (1.5,), (torch.rand(S, S, S),), '1_5'),
     (Norm, (), ((S, S, S),), '2'),
     (Norm, (3,), ((S, S, S),), '3'),
-    (Norm, (1.5, 0), (torch.rand(S, S, S),), '1_5_dim'),
-    (Norm, (2, 0), ((S, S, S),), '2_dim'),
-    (Norm, (3, 0), ((S, S, S),), '3_dim'),
+    (Norm, (1.5, 1), (torch.rand(S, S, S),), '1_5_dim', [1]),
+    (Norm, (2, 1), ((S, S, S),), '2_dim', [1]),
+    (Norm, (3, 1), ((S, S, S),), '3_dim', [1]),
     (Addcmul, (), ((S, S), (S, S), (S, S))),
     (Addcmul, (0.6,), ((S, S), (S, S), (S, S)), 'scale'),
     (Addcdiv, (), ((S, S), (S, S), torch.rand(S, S) + 5e-2)),
@@ -1062,10 +1068,12 @@ function_tests = [
     (IndexFill, (0, 2), ((S, S), index_variable(2, S))),
     (IndexSelect, (0,), ((S, S), index_variable(2, S))),
     (Gather, (0,), ((M, S), gather_variable((S, S), 1, M))),
-    (Gather, (1,), ((M, S), gather_variable((M, S // 2), 0, S)), 'dim1'),
+    (Gather, (1,), ((M, S), gather_variable((M, S // 2), 0, S)), 'dim1', [0]),
     (Scatter, (0,), ((M, S), gather_variable((S, S), 1, M), (S, S))),
-    (Scatter, (1,), ((M, S), gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1'),
+    (Scatter, (1,), ((M, S), gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', [0]),
     (Concat, (0,), ((1, S, S), (2, S, S), (3, S, S))),
+    (Concat, (-1,), ((S, S, 1), (S, S, 2), (S, S, 3)), 'negdim-1'),
+    (Concat, (-2,), ((S, 1, S), (S, 2, S), (S, 3, S)), 'negdim-2'),
     (Resize, (S * S, S), ((S, S, S),)),
     (Diag, (), ((S, S),), '2d'),
     (Diag, (), ((S,),), '1d'),
@@ -1078,19 +1086,20 @@ function_tests = [
     (Cross, (1,), ((S, 3, S), (S, 3, S)), 'dim'),
     (Clone, (), ((S, M, S),)),
     (Squeeze, (), ((S, 1, M, 1),)),
-    (Squeeze, (1,), ((S, 1, M, 1),), 'dim'),
+    (Squeeze, (1,), ((S, 1, M, 1),), 'dim', [0]),
     (Unsqueeze, (0,), ((S, M, S),), '0'),
-    (Unsqueeze, (1,), ((S, M, S),), '1'),
+    (Unsqueeze, (1,), ((S, M, S),), '1', [0]),
+    (Unsqueeze, (2,), ((S, M, S),), '2', [0]),
     # (MaskedCopy, (), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False), (S, S),)),
     (MaskedFill, (10,), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False))),
     (MaskedSelect, (), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False))),
     (Sort, (), ((S, M, S),)),
-    (Sort, (1,), ((S, M, S),), 'dim'),
-    (Sort, (1, True), ((S, M, S),), 'dim_desc'),
+    (Sort, (1,), ((S, M, S),), 'dim', [0]),
+    (Sort, (1, True), ((S, M, S),), 'dim_desc', [0]),
     (Topk, (3,), ((S, M, S),)),
-    (Topk, (3, 1), ((S, M, S),), 'dim'),
-    (Topk, (3, 1, True), ((S, M, S),), 'dim_desc'),
-    (Topk, (3, 1, True, True), ((S, M, S),), 'dim_desc_sort'),
+    (Topk, (3, 1), ((S, M, S),), 'dim', [1]),
+    (Topk, (3, 1, True), ((S, M, S),), 'dim_desc', [1]),
+    (Topk, (3, 1, True, True), ((S, M, S),), 'dim_desc_sort', [1]),
 ]
@@ -1105,7 +1114,7 @@ method_tests = [
     ('div', (S, S, S), (3.14,), 'constant'),
     ('pow', (S, S, S), ((S, S, S),)),
     ('pow', (S, S, S), (3.14,), 'constant'),
-    ('transpose', (1, 2, 3), (1, 2)),
+    ('transpose', (1, 2, 3), (1, 2), 'dim', [0, 1]),
     ('t', (1, 2), ()),
     ('view', (S, S, S), (S * S, S),),
     ('view_as', (S, S, S), ((S * S, S),)),
@@ -1140,20 +1149,22 @@ method_tests = [
     ('remainder', (S, S, S), (1.5,)),
     ('lerp', (S, S, S), ((S, S, S), 0.4)),
     ('max', (S, S, S), ()),
+    ('max', (S, S, S), (1,), 'dim', [0]),
     ('max', (S, S, S), ((S, S, S),), 'elementwise'),
     ('min', (S, S, S), ()),
+    ('min', (S, S, S), (1,), 'dim', [0]),
     ('min', (S, S, S), ((S, S, S),), 'elementwise'),
     ('mean', (S, S, S), ()),
-    ('mean', (S, S, S), (1,), 'dim'),
+    ('mean', (S, S, S), (1,), 'dim', [0]),
     ('sum', (S, S, S), ()),
-    ('sum', (S, S, S), (1,), 'dim'),
+    ('sum', (S, S, S), (1,), 'dim', [0]),
     ('prod', (S, S, S), ()),
-    ('prod', (S, S, S), (1,), 'dim'),
+    ('prod', (S, S, S), (1,), 'dim', [0]),
     ('var', (S, S, S), ()),
-    ('var', (S, S, S), (1,), 'dim'),
+    ('var', (S, S, S), (1,), 'dim', [0]),
     ('std', (S, S, S), ()),
-    ('std', (S, S, S), (1,), 'dim'),
-    ('renorm', (S, S, S), (2, 1, 0.5)),
+    ('std', (S, S, S), (1,), 'dim', [0]),
+    ('renorm', (S, S, S), (2, 1, 0.5), 'dim', [1]),
     ('renorm', (S, S, S), (1, 2, 3), 'norm_1'),
     ('repeat', (S, S, S, S), (2, 3, 1, 4)),
     ('cumsum', (S, S, S), (1,)),
@@ -1174,10 +1185,10 @@ method_tests = [
     ('addcdiv', (S, S), ((S, S), (S, S))),
     ('addcdiv', (S, S), (0.5, (S, S), (S, S)), 'scale'),
     ('norm', (S, S, S), (2,)),
-    ('norm', (S, S, S), (2, 1), 'dim'),
+    ('norm', (S, S, S), (2, 1), 'dim', [1]),
     ('dist', (S, S, S), ((S, S, S),)),
     ('dist', (S, S, S), ((S, S, S), 4), '4'),
-    ('index_select', (S, S, S), (0, index_variable(2, S))),
+    ('index_select', (S, S, S), (0, index_variable(2, S)), 'dim', [0]),
     ('diag', (M, M), (), '2d'),
     ('diag', (M,), (), '1d'),
     ('tril', (M, M), ()),
@@ -1199,14 +1210,14 @@ method_tests = [
     ('lt', (S, S, S), (0,), 'scalar'),
     ('le', (S, S, S), (0,), 'scalar'),
     ('permute', (1, 2, 3, 4), (0, 2, 3, 1)),
-    ('select', (S, S, S), (1, 2)),
-    ('narrow', (S, S, S), (1, 2, 2)),
+    ('select', (S, S, S), (1, 2), 'dim', [0]),
+    ('narrow', (S, S, S), (1, 2, 2), 'dim', [0]),
     ('squeeze', (S, 1, S, 1), ()),
-    ('squeeze', (S, 1, S, 1), (1,), '1_dim'),
-    ('squeeze', (S, 1, S, 1), (2,), 'not_1_dim'),
-    ('unsqueeze', (S, S, S), (0,), 'first'),
-    ('unsqueeze', (S, S, S), (1,), 'middle'),
-    ('unsqueeze', (S, S, S), (3,), 'last'),
+    ('squeeze', (S, 1, S, 1), (1,), '1_dim', [0]),
+    ('squeeze', (S, 1, S, 1), (2,), 'not_1_dim', [0]),
+    ('unsqueeze', (S, S, S), (0,), 'first', [0]),
+    ('unsqueeze', (S, S, S), (1,), 'middle', [0]),
+    ('unsqueeze', (S, S, S), (3,), 'last', [0]),
     ('masked_select', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False),)),
     ('masked_fill_', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), 10)),
     ('masked_copy_', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), (M, M))),
@@ -1252,41 +1263,50 @@ ignore_inplace = set((
 for test in function_tests:
     cls, constructor_args, call_args = test[:3]
-    test_name = 'test_{}Function'.format(cls.__name__)
-    if len(test) == 4:
-        test_name += '_' + test[3]
+    basic_test_name = 'test_{}Function'.format(cls.__name__)
+    if len(test) >= 4:
+        basic_test_name += '_' + test[3]
+    dim_args_idx = test[4] if len(test) == 5 else []

-    def do_test(self, cls=cls, constructor_args=constructor_args,
-                call_args=call_args, test_name=test_name):
-        input = create_input(call_args)
-        self.assertEqual(gradcheck(cls(*constructor_args), input, eps=1e-6, atol=PRECISION), True)
-        if test_name not in ignore_inplace and issubclass(cls, InplaceFunction):
-            output = cls(*constructor_args)(*input)
-            if not isinstance(output, tuple):
-                output = (output,)
-            inplace_input = deepcopy(input)
-            inplace_input_copy = tuple(i + 0 for i in inplace_input)
-            fn = cls(*constructor_args, inplace=True)
-            inplace_output = fn(*inplace_input_copy)
-            if not isinstance(inplace_output, tuple):
-                inplace_output = (inplace_output,)
-            self.assertEqual(inplace_output, output)
-            # Check that gradient is the same
-            for inp_i, i in zip(inplace_input, input):
-                if inp_i.grad is not None:
-                    inp_i.grad.data.zero_()
-                if i.grad is not None:
-                    i.grad.data.zero_()
-            for io, o in zip(inplace_output, output):
-                grad = torch.randn(*io.size()).double()
-                io.backward(grad)
-                o.backward(grad)
-            for inp_i, i in zip(inplace_input, input):
-                self.assertEqual(inp_i.grad, i.grad)
-
-    assert not hasattr(TestAutograd, test_name), 'Two tests have the same name: ' + test_name
-    setattr(TestAutograd, test_name, do_test)
+    for dim_perm in product([-1, 1], repeat=len(dim_args_idx)):
+        test_name = basic_test_name
+        new_constructor_args = [arg * dim_perm[dim_args_idx.index(i)] if i in dim_args_idx else arg
+                                for i, arg in enumerate(constructor_args)]
+        test_name = basic_test_name + ''.join('_neg' + str(i) for i, idx in enumerate(dim_perm) if idx < 0)
+        new_constructor_args = tuple(new_constructor_args)
+
+        def do_test(self, cls=cls, constructor_args=new_constructor_args,
+                    call_args=call_args, test_name=test_name):
+            input = create_input(call_args)
+            self.assertEqual(gradcheck(cls(*constructor_args), input, eps=1e-6, atol=PRECISION), True)
+            if test_name not in ignore_inplace and issubclass(cls, InplaceFunction):
+                output = cls(*constructor_args)(*input)
+                if not isinstance(output, tuple):
+                    output = (output,)
+                inplace_input = deepcopy(input)
+                inplace_input_copy = tuple(i + 0 for i in inplace_input)
+                fn = cls(*constructor_args, inplace=True)
+                inplace_output = fn(*inplace_input_copy)
+                if not isinstance(inplace_output, tuple):
+                    inplace_output = (inplace_output,)
+                self.assertEqual(inplace_output, output)
+                # Check that gradient is the same
+                for inp_i, i in zip(inplace_input, input):
+                    if inp_i.grad is not None:
+                        inp_i.grad.data.zero_()
+                    if i.grad is not None:
+                        i.grad.data.zero_()
+                for io, o in zip(inplace_output, output):
+                    grad = torch.randn(*io.size()).double()
+                    io.backward(grad)
+                    o.backward(grad)
+                for inp_i, i in zip(inplace_input, input):
+                    self.assertEqual(inp_i.grad, i.grad)
+
+        assert not hasattr(TestAutograd, test_name), 'Two tests have the same name: ' + test_name
+        setattr(TestAutograd, test_name, do_test)


 EXCLUDE_FUNCTIONAL = {
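To see what the loop above generates, take the entry (Sum, (1,), ((S, S, S),), 'dim', [0]): index 0 of the constructor args is marked as a dimension, so one test is emitted with dim 1 and another with dim -1; this is also why the dim arguments in the tables were moved off 0, whose negation is still 0. A standalone sketch of the name/argument expansion (not part of the commit):

    from itertools import product

    constructor_args, dim_args_idx = (1,), [0]
    for dim_perm in product([-1, 1], repeat=len(dim_args_idx)):
        new_args = tuple(arg * dim_perm[dim_args_idx.index(i)] if i in dim_args_idx else arg
                         for i, arg in enumerate(constructor_args))
        suffix = ''.join('_neg' + str(i) for i, idx in enumerate(dim_perm) if idx < 0)
        print(new_args, 'test_SumFunction_dim' + suffix)
    # prints (-1,) test_SumFunction_dim_neg0, then (1,) test_SumFunction_dim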
@@ -1298,42 +1318,50 @@ EXCLUDE_FUNCTIONAL = {
 }

 for test in method_tests:
     name, self_size, args = test[:3]
-    test_name = 'test_' + name + ('_' + test[3] if len(test) == 4 else '')
+    basic_test_name = 'test_' + name + ('_' + test[3] if len(test) >= 4 else '')
+    dim_args_idx = test[4] if len(test) == 5 else []

-    def do_test(self, name=name, self_size=self_size, args=args, test_name=test_name):
-        def check(name):
-            self_variable = create_input((self_size,), requires_grad=False)[0]
-            args_variable = create_input(args, requires_grad=False)
-            self_tensor = deepcopy(self_variable.data)
-            args_tensor = deepcopy(unpack_variables(args_variable))
-            output_variable = getattr(self_variable, name)(*args_variable)
-            output_tensor = getattr(self_tensor, name)(*args_tensor)
-            if not torch.is_tensor(output_tensor) and not isinstance(output_tensor, tuple):
-                output_tensor = torch.DoubleTensor((output_tensor,))
-            self.assertEqual(unpack_variables(output_variable), output_tensor)
-            # TODO: check that both have changed after adding all inplace ops
-
-            # functional interface tests
-            if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL:
-                f_args_variable = (self_variable,) + args_variable
-                f_args_tensor = (self_tensor,) + args_tensor
-                output_variable = getattr(torch, name)(*f_args_variable)
-                output_tensor = getattr(torch, name)(*f_args_tensor)
-                if not torch.is_tensor(output_tensor) and not isinstance(output_tensor, tuple):
-                    output_tensor = torch.DoubleTensor((output_tensor,))
-                self.assertEqual(unpack_variables(output_variable), output_tensor)
-
-        check(name)
-        inplace_name = name + '_'
-        if hasattr(Variable(torch.ones(1)), inplace_name):
-            try:
-                check(inplace_name)
-            except Exception as e:
-                if 'only supports scalar' not in e.args[0]:
-                    raise
-
-    assert not hasattr(TestAutograd, test_name), 'Two tests have the same name: ' + test_name
-    setattr(TestAutograd, test_name, do_test)
+    for dim_perm in product([-1, 1], repeat=len(dim_args_idx)):
+        test_name = basic_test_name
+        new_args = [arg * dim_perm[dim_args_idx.index(i)] if i in dim_args_idx else arg for i, arg in enumerate(args)]
+        test_name = basic_test_name + ''.join('_neg' + str(i) for i, idx in enumerate(dim_perm) if idx < 0)
+        new_args = tuple(new_args)
+
+        def do_test(self, name=name, self_size=self_size, args=new_args, test_name=test_name):
+            def check(name):
+                self_variable = create_input((self_size,), requires_grad=False)[0]
+                args_variable = create_input(args, requires_grad=False)
+                self_tensor = deepcopy(self_variable.data)
+                args_tensor = deepcopy(unpack_variables(args_variable))
+                output_variable = getattr(self_variable, name)(*args_variable)
+                output_tensor = getattr(self_tensor, name)(*args_tensor)
+                if not torch.is_tensor(output_tensor) and not isinstance(output_tensor, tuple):
+                    output_tensor = torch.DoubleTensor((output_tensor,))
+                self.assertEqual(unpack_variables(output_variable), output_tensor)
+                # TODO: check that both have changed after adding all inplace ops
+
+                # functional interface tests
+                if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL:
+                    f_args_variable = (self_variable,) + args_variable
+                    f_args_tensor = (self_tensor,) + args_tensor
+                    output_variable = getattr(torch, name)(*f_args_variable)
+                    output_tensor = getattr(torch, name)(*f_args_tensor)
+                    if not torch.is_tensor(output_tensor) and not isinstance(output_tensor, tuple):
+                        output_tensor = torch.DoubleTensor((output_tensor,))
+                    self.assertEqual(unpack_variables(output_variable), output_tensor)
+
+            check(name)
+            inplace_name = name + '_'
+            if hasattr(Variable(torch.ones(1)), inplace_name):
+                try:
+                    check(inplace_name)
+                except Exception as e:
+                    if 'only supports scalar' not in e.args[0]:
+                        raise
+
+        assert not hasattr(TestAutograd, test_name), 'Two tests have the same name: ' + test_name
+        setattr(TestAutograd, test_name, do_test)


 if __name__ == '__main__':


@@ -155,12 +155,15 @@ tests = [
     ('fmod', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
     ('chunk', medium_2d, lambda t: [4],),
     ('chunk', medium_2d, lambda t: [4, 1], 'dim'),
+    ('chunk', medium_2d, lambda t: [4, -2], 'neg_dim'),
     ('clamp', medium_2d_scaled, lambda t: [-1, 5],),
     ('clone', medium_2d, lambda t: [],),
     ('contiguous', medium_2d, lambda t: [],),
     ('cross', new_t(M, 3, M), lambda t: [new_t(M, 3, M)(t)],),
     ('cumprod', small_3d, lambda t: [1],),
+    ('cumprod', small_3d, lambda t: [-1], 'neg_dim'),
     ('cumsum', small_3d, lambda t: [1],),
+    ('cumsum', small_3d, lambda t: [-1], 'neg_dim'),
     ('dim', small_3d, lambda t: [],),
     ('dist', small_2d, lambda t: [small_2d(t)],),
     ('dist', small_2d, lambda t: [small_2d(t), 3], '3_norm'),
@@ -188,52 +191,72 @@ tests = [
     # TODO: positive case
     ('kthvalue', small_3d_unique, lambda t: [3],),
     ('kthvalue', small_3d_unique, lambda t: [3, 1], 'dim'),
+    ('kthvalue', small_3d_unique, lambda t: [3, -1], 'neg_dim'),
     ('lerp', small_3d, lambda t: [small_3d(t), 0.3],),
     ('max', small_3d_unique, lambda t: [],),
     ('max', small_3d_unique, lambda t: [1], 'dim'),
+    ('max', small_3d_unique, lambda t: [-1], 'neg_dim'),
     ('max', medium_2d, lambda t: [medium_2d(t)], 'elementwise'),
     ('min', small_3d_unique, lambda t: [],),
     ('min', small_3d_unique, lambda t: [1], 'dim'),
+    ('min', small_3d_unique, lambda t: [-1], 'neg_dim'),
     ('min', medium_2d, lambda t: [medium_2d(t)], 'elementwise'),
     ('mean', small_3d, lambda t: [],),
+    ('mean', small_3d, lambda t: [-1], 'neg_dim'),
     ('mean', small_3d, lambda t: [1], 'dim'),
     ('mode', small_3d, lambda t: [],),
     ('mode', small_3d, lambda t: [1], 'dim'),
+    ('mode', small_3d, lambda t: [-1], 'neg_dim'),
     ('remainder', small_3d, lambda t: [3], 'value'),
     ('remainder', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
     ('std', small_3d, lambda t: [],),
     ('std', small_3d, lambda t: [1], 'dim'),
+    ('std', small_3d, lambda t: [-1], 'neg_dim'),
     ('var', small_3d, lambda t: [],),
     ('var', small_3d, lambda t: [1], 'dim'),
+    ('var', small_3d, lambda t: [-1], 'neg_dim'),
     ('ndimension', small_3d, lambda t: [],),
     ('nelement', small_3d, lambda t: [],),
     ('numel', small_3d, lambda t: [],),
     ('narrow', small_3d, lambda t: [1, 3, 2],),
+    ('narrow', small_3d, lambda t: [-1, 3, 2], 'neg_dim'),
     ('nonzero', small_3d, lambda t: [],),
     ('norm', small_3d, lambda t: [],),
     ('norm', small_3d, lambda t: [3], '3_norm'),
     ('norm', small_3d, lambda t: [3, 0], '3_norm_dim'),
+    ('norm', small_3d, lambda t: [3, -2], '3_norm_neg_dim'),
     ('ones', small_3d, lambda t: [1, 2, 3, 4, 5],),
     ('permute', new_t(1, 2, 3, 4), lambda t: [2, 1, 3, 0],),
     ('prod', small_2d_oneish, lambda t: [],),
     ('prod', small_3d, lambda t: [1], 'dim'),
+    ('prod', small_3d, lambda t: [-1], 'neg_dim'),
     ('sum', small_2d, lambda t: [],),
     ('sum', small_3d, lambda t: [1], 'dim'),
+    ('sum', small_3d, lambda t: [-1], 'neg_dim'),
     ('renorm', small_3d, lambda t: [2, 1, 1], '2_norm'),
+    ('renorm', small_3d, lambda t: [2, -1, 1], '2_norm_neg_dim'),
     ('renorm', small_3d, lambda t: [1.5, 1, 1], '1_5_norm'),
     ('repeat', small_2d, lambda t: [2, 2, 2],),
     ('size', new_t(1, 2, 3, 4), lambda t: [],),
+    ('size', new_t(1, 2, 3, 4), lambda t: [1], 'dim'),
+    ('size', new_t(1, 2, 3, 4), lambda t: [-2], 'neg_dim'),
     ('sort', small_3d_unique, lambda t: [],),
     ('sort', small_3d_unique, lambda t: [1], 'dim'),
+    ('sort', small_3d_unique, lambda t: [-1], 'neg_dim'),
     ('sort', small_3d_unique, lambda t: [1, True], 'dim_descending'),
+    ('sort', small_3d_unique, lambda t: [-1, True], 'neg_dim_descending'),
     ('split', small_3d, lambda t: [2],),
     ('split', small_3d, lambda t: [2, 1], 'dim'),
+    ('split', small_3d, lambda t: [2, -3], 'neg_dim'),
     ('squeeze', new_t(1, 2, 1, 4), lambda t: [],),
     ('squeeze', new_t(1, 2, 1, 4), lambda t: [2], 'dim'),
+    ('squeeze', new_t(1, 2, 1, 4), lambda t: [-2], 'neg_dim'),
     ('t', new_t(1, 2), lambda t: [],),
     ('transpose', new_t(1, 2, 3, 4), lambda t: [1, 2],),
+    ('transpose', new_t(1, 2, 3, 4), lambda t: [-1, -2], 'neg_dim'),
     ('to_list', small_3d, lambda t: [],),
     ('topk', small_3d, lambda t: [2, 1, False, True], 'dim_sort'),
+    ('topk', small_3d, lambda t: [2, -1, False, True], 'neg_dim_sort'),
     ('topk', small_3d, lambda t: [2, 1, True, True], 'dim_desc_sort'),
     ('trace', medium_2d, lambda t: [],),
     ('tril', medium_2d, lambda t: [],),
@@ -243,6 +266,7 @@ tests = [
     ('triu', medium_2d, lambda t: [2], 'positive'),
     ('triu', medium_2d, lambda t: [-2], 'negative'),
     ('unsqueeze', new_t(2, 3, 4), lambda t: [2],),
+    ('unsqueeze', new_t(2, 3, 4), lambda t: [-2], 'neg_dim'),
     ('view', small_3d, lambda t: [100, 10],),
     ('view_as', small_3d, lambda t: [t(100, 10)],),
     ('zero', small_3d, lambda t: [],),
@@ -467,6 +491,9 @@ class TestCuda(TestCase):
     def test_scatter_cpu_dim(self):
         self._test_scatter(torch.randn(4, 4), dim=1)

+    def test_scatter_cpu_neg_dim(self):
+        self._test_scatter(torch.randn(4, 4), dim=-2)
+
     def test_scatter_cpu_sizes(self):
         self._test_scatter(torch.randn(6, 4), chunk_sizes=(2, 4))
@@ -476,6 +503,9 @@ class TestCuda(TestCase):
     def test_scatter_gpu_dim(self):
         self._test_scatter(torch.randn(4, 4).cuda(), dim=1)

+    def test_scatter_gpu_neg_dim(self):
+        self._test_scatter(torch.randn(4, 4).cuda(), dim=-2)
+
     def test_scatter_gpu_sizes(self):
         self._test_scatter(torch.randn(6, 4).cuda(), chunk_sizes=(2, 4))


@@ -2,6 +2,7 @@ import sys
 import os
 import math
 import random
+import copy
 import torch
 import torch.cuda
 import tempfile
@@ -3132,23 +3133,107 @@ class TestTorch(TestCase):
         self.assertIsInstance(x[:-1], torch.Size)
         self.assertIsInstance(x + x, torch.Size)

-    def test_transpose_neg(self):
-        x = torch.randn(10, 20, 30)
-        ndim = 3
-
-        for i, j in combinations(range(ndim), 2):
-            a = x.transpose(i, j)
-            b = x.transpose(i - ndim, j - ndim)
-            self.assertEqual(a, b)
-
-            a = torch.transpose(x, i, j)
-            b = torch.transpose(x, i - ndim, j - ndim)
-            self.assertEqual(a, b)
-
-            a = x.clone()
-            x.transpose_(i, j)
-            x.transpose_(i - ndim, j - ndim)
-            self.assertEqual(a, x)
+
+# Functions to test negative dimension wrapping
+METHOD = 1
+INPLACE_METHOD = 2
+FUNCTIONAL = 4
+DIM_ARG = None
+
+
+def make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim=0):
+    def neg_dim_test(self):
+        if isinstance(tensor_arg, list):
+            assert METHOD not in types and INPLACE_METHOD not in types
+            x = [torch.randn(arg) for arg in tensor_arg]
+            ndim = len(tensor_arg[-1])
+        else:
+            x = torch.randn(*tensor_arg)
+            ndim = len(tensor_arg)
+        ndim += extra_dim
+
+        n_dim_to_test = sum(map(lambda e: e is DIM_ARG, arg_constr()))
+
+        for dims_val in combinations(range(ndim), n_dim_to_test):
+            arg = arg_constr()
+            arg_neg = copy.deepcopy(arg)
+            idx = 0
+            for i, v in enumerate(arg):
+                if v is DIM_ARG:
+                    arg[i] = dims_val[idx]
+                    arg_neg[i] = dims_val[idx] - ndim
+                    idx += 1
+
+            if METHOD in types:
+                a = getattr(x, name)(*arg)
+                b = getattr(x, name)(*arg_neg)
+                self.assertEqual(a, b)
+
+            if INPLACE_METHOD in types:
+                a = x.clone()
+                getattr(a, name + '_')(*arg)
+                b = x.clone()
+                getattr(b, name + '_')(*arg_neg)
+                self.assertEqual(a, b)
+
+            if FUNCTIONAL in types:
+                a = getattr(torch, name)(x, *arg)
+                b = getattr(torch, name)(x, *arg_neg)
+                self.assertEqual(a, b)
+
+    return neg_dim_test
+
+
+def idx_tensor(size, max_val):
+    return torch.LongTensor(*size).random_(0, max_val - 1)
+
+
+neg_dim_tests = [
+    ('narrow', (10, 20, 30), lambda: [DIM_ARG, 0, 5], [METHOD]),
+    ('transpose', (10, 20, 30), lambda: [DIM_ARG, DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL]),
+    ('size', (10, 20, 30), lambda: [DIM_ARG], [METHOD]),
+    ('cat', [(2, 3, 4), (2, 3, 4)], lambda: [DIM_ARG], [FUNCTIONAL]),
+    ('chunk', (10, 20, 30), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('gather', (10, 20), lambda: [DIM_ARG, idx_tensor((10, 20), 10)], [METHOD, FUNCTIONAL]),
+    ('index_select', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10)], [METHOD, FUNCTIONAL]),
+    ('split', (10, 20), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('squeeze', (10, 1, 20, 1), lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL]),
+    ('stack', [(2, 3, 4), (2, 3, 4)], lambda: [DIM_ARG], [FUNCTIONAL]),
+    ('unbind', (2, 3, 4), lambda: [DIM_ARG], [FUNCTIONAL]),
+    ('unsqueeze', (10, 20), lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL], 1),
+    ('cumprod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('cumsum', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('mean', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('median', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('mode', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('norm', (10, 20), lambda: [2, DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('prod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('std', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('sum', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('var', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('kthvalue', (10, 20), lambda: [3, DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('max', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('min', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('sort', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('topk', (10, 20), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('renorm', (10, 20), lambda: [2, DIM_ARG, 1], [METHOD, INPLACE_METHOD, FUNCTIONAL]),
+    ('index_add', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), torch.randn(10, 10)], [INPLACE_METHOD]),
+    ('index_copy', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), torch.randn(10, 10)], [INPLACE_METHOD]),
+    ('index_fill', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), 12], [INPLACE_METHOD]),
+    ('scatter', (10, 10), lambda: [DIM_ARG, idx_tensor((10, 10), 10), torch.randn(10, 10)], [INPLACE_METHOD]),
+    ('select', (10, 20), lambda: [DIM_ARG, 3], [METHOD]),
+    ('unfold', (10, 20), lambda: [DIM_ARG, 5, 2], [METHOD]),
+]
+
+for decl in neg_dim_tests:
+    if len(decl) == 4:
+        name, tensor_arg, arg_constr, types = decl
+        extra_dim = 0
+    elif len(decl) == 5:
+        name, tensor_arg, arg_constr, types, extra_dim = decl
+
+    test_name = 'test_' + name + '_neg_dim'
+
+    assert not hasattr(TestTorch, test_name), "Duplicated test name: " + test_name
+    setattr(TestTorch, test_name, make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim))


 if __name__ == '__main__':
     run_tests()
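Hand-unrolling one entry shows what make_neg_dim_test checks; for ('transpose', (10, 20, 30), ...) every positive dim pair must agree with its wrapped negative counterpart. A sketch of the generated assertions (not the harness itself):

    import torch
    from itertools import combinations

    x = torch.randn(10, 20, 30)
    ndim = x.dim()
    for d0, d1 in combinations(range(ndim), 2):
        # method and functional forms must both accept wrapped dims
        assert torch.equal(x.transpose(d0, d1), x.transpose(d0 - ndim, d1 - ndim))
        assert torch.equal(torch.transpose(x, d0, d1), torch.transpose(x, d0 - ndim, d1 - ndim))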


@@ -198,7 +198,7 @@ class cwrap(object):
         arguments = self.get_assign_args(arguments)
         for arg, unpack in zip(arguments, arg_unpack):
             if arg['type'] == 'CONSTANT':
-                call_arg.append(str(arg['name']))
+                call_arg.append(unpack)
             else:
                 var_name = "arg_" + str(arg.get('assign_name', arg['name']))
                 res = self.ARG_ASSIGN_TEMPLATE.substitute(


@@ -0,0 +1,40 @@
+from . import CWrapPlugin
+from string import Template
+
+
+class WrapDim(CWrapPlugin):
+
+    NDIM_TEMPLATE = Template(
+        """${arg_tensor}->nDimension""")
+
+    CODE_TEMPLATE = Template(
+        """THPUtils_assert(${arg_dim} >= -(${ndim}) && ${arg_dim} < (${ndim}),
+    "dimension out of range (expected to be in range of [%d, %d], but got %d)",
+    -(${ndim}), (${ndim})-1, ${arg_dim});
+if (${arg_dim} < 0) ${arg_dim} += (${ndim});""")
+
+    def initialize(self, cwrap):
+        self.cwrap = cwrap
+
+    def process_option_code_template(self, template, option):
+        new_code = []
+        for i, arg in enumerate(option['arguments']):
+            if 'wrap_dim' not in arg:
+                continue
+
+            params = arg.get('wrap_dim').split("+")
+            arg_tensor = params[0]
+            arg_tensor = "arg_" + arg_tensor
+            arg_dim = "arg_" + arg.get('assign_name', arg['name'])
+
+            params[0] = self.NDIM_TEMPLATE.substitute(arg_tensor=arg_tensor)
+            ndim = "+".join(params)
+
+            new_code.append(self.CODE_TEMPLATE.substitute(
+                arg_dim=arg_dim,
+                ndim=ndim))
+            new_code.append("")
+
+        template = new_code + template
+        return template
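For a cwrap declaration carrying `wrap_dim: self`, the plugin prepends a bounds check and an in-place wrap to the generated C. A small driver sketch (assuming the repository root is on sys.path; the option dict is a hand-built stand-in for what cwrap passes) shows roughly what comes out:

    from tools.cwrap.plugins.WrapDim import WrapDim

    # hypothetical option dict mimicking a parsed cwrap declaration
    option = {'arguments': [
        {'type': 'THTensor*', 'name': 'self'},
        {'type': 'long', 'name': 'dim', 'wrap_dim': 'self'},
    ]}
    generated = WrapDim().process_option_code_template(['<original template>'], option)
    print('\n'.join(generated))
    # Roughly:
    #   THPUtils_assert(arg_dim >= -(arg_self->nDimension) && arg_dim < (arg_self->nDimension), ...);
    #   if (arg_dim < 0) arg_dim += (arg_self->nDimension);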


@@ -65,3 +65,4 @@ from .GILRelease import GILRelease
 from .AutoGPU import AutoGPU
 from .CuDNNPlugin import CuDNNPlugin
 from .GenericNN import GenericNN
+from .WrapDim import WrapDim


@@ -4297,6 +4297,8 @@ specified position.

 The returned tensor shares the same underlying data with this tensor.

+A negative dim value can be used and will correspond to :math:`dim + input.dim() + 1`
+
 Args:
     input (Tensor): the input `Tensor`
     dim (int): The index at which to insert the singleton dimension
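Concretely, under that wrapping rule -1 appends a trailing singleton dimension (illustrative shapes):

    import torch

    x = torch.randn(2, 3)
    print(torch.unsqueeze(x, -1).size())  # -1 wraps to -1 + x.dim() + 1 == 2, giving (2, 3, 1)
    print(torch.unsqueeze(x, 0).size())   # (1, 2, 3)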


@@ -60,8 +60,9 @@ class Prod(_DimReduceFunction):
             return grad_input
         else:
             input, output = self.saved_tensors
+            dim = self.dim if self.dim >= 0 else self.dim + input.dim()
             zero_mask = input == 0
-            slice_zero_count = zero_mask.sum(self.dim)
+            slice_zero_count = zero_mask.sum(dim)
             total_zeros = slice_zero_count.sum()
             grad_input = grad_output.mul(output).expand_as(input).div(input)
             if total_zeros == 0:
@@ -71,9 +72,13 @@ class Prod(_DimReduceFunction):
             grad_input[some_zeros] = 0

             single_zero_idx = slice_zero_count.eq(1).nonzero()
+
+            if len(single_zero_idx) == 0:
+                return grad_input
+
             for idx in single_zero_idx:
                 idx_tuple = tuple(idx.cpu())
-                input_idx_tuple = idx_tuple[:self.dim] + (slice(0, None),) + idx_tuple[self.dim + 1:]
+                input_idx_tuple = idx_tuple[:dim] + (slice(0, None),) + idx_tuple[dim + 1:]

                 # slice_mask and input_copy are 1D
                 slice_mask = zero_mask[input_idx_tuple]
@@ -81,7 +86,7 @@ class Prod(_DimReduceFunction):
                 zero_idx = slice_mask.nonzero()[0, 0]
                 input_copy[zero_idx] = 1.

-                grad_idx_tuple = idx_tuple[:self.dim] + (zero_idx,) + idx_tuple[self.dim + 1:]
+                grad_idx_tuple = idx_tuple[:dim] + (zero_idx,) + idx_tuple[dim + 1:]
                 grad_input[grad_idx_tuple] = grad_output[idx_tuple] * input_copy.prod()

         return grad_input
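The special-casing above follows from the calculus: d(prod)/dx_i is the product of the other entries, so in a slice with exactly one zero only the zero's position receives a nonzero gradient, and the generic output/input division (a 0/0) cannot recover it. A plain-Python check of that fact (not the autograd code):

    xs = [2.0, 0.0, 5.0]          # slice with a single zero
    grads = []
    for i in range(len(xs)):
        others = 1.0
        for j, v in enumerate(xs):
            if j != i:
                others *= v       # product of all entries except xs[i]
        grads.append(others)
    print(grads)                  # [0.0, 10.0, 0.0]: only the zero entry gets a gradient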


@@ -661,10 +661,12 @@ class Variable(_C._VariableBase):
         return Transpose(dim1, dim2)(self)

     def select(self, dim, _index):
+        dim = dim if dim >= 0 else dim + self.dim()
         index = tuple(slice(None, None) for _ in range(dim)) + (_index,)
         return Index(index)(self)

     def narrow(self, dim, start_index, length):
+        dim = dim if dim >= 0 else dim + self.dim()
         index = tuple(slice(None, None) for _ in range(dim)) + \
             (slice(start_index, start_index + length),)
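Because Variable.select and Variable.narrow build an explicit slice tuple from the dim, the wrap has to happen in Python before the tuple is constructed; afterwards negative dims behave as expected (a sketch with illustrative shapes):

    import torch
    from torch.autograd import Variable

    v = Variable(torch.randn(4, 5, 6))
    assert v.select(-1, 2).size() == v.select(2, 2).size()
    assert v.narrow(-2, 0, 3).size() == v.narrow(1, 0, 3).size()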


@@ -191,6 +191,12 @@ static PyObject * THPTensor_(select)(THPTensor *self, PyObject *args)
     return NULL;

   int ndim = THTensor_(nDimension)(LIBRARY_STATE self->cdata);
+  THPUtils_assert(dim >= -(ndim) && dim < (ndim),
+      "dimension out of range (expected to be in range of [%d, %d], but got %d)",
+      -(ndim), (ndim)-1, dim);
+  if (dim<0) dim += ndim;
+
   if(ndim > 1) {
     THTensorPtr selected = THTensor_(newWithTensor)(LIBRARY_STATE self->cdata);
     THTensor_(select)(LIBRARY_STATE selected.get(), NULL, dim, idx);
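This hand-written check is the same pattern the WrapDim plugin emits for cwrap-generated methods; in Python terms it is (a sketch, with a hypothetical helper name):

    def wrap_dim(dim, ndim):
        # mirror of the THPUtils_assert + wrap in the C code above
        if not -ndim <= dim < ndim:
            raise IndexError('dimension out of range (expected to be in range of '
                             '[%d, %d], but got %d)' % (-ndim, ndim - 1, dim))
        return dim + ndim if dim < 0 else dim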
@@ -358,50 +364,34 @@ PyObject * THPTensor_(stride)(PyObject *self, PyObject *args, PyObject *kwargs)
     - THBoolTensor* mask
 ]]

-#if IS_CUDA
-THTensor* THTensor_(transpose_neg)(THCState* state, THTensor *self, THTensor *src, int dim0, int dim1)
-#else
-THTensor* THTensor_(transpose_neg)(THTensor *self, THTensor *src, int dim0, int dim1)
-#endif
-{
-  int ndim = self->nDimension;
-  if (dim0 < 0)
-    dim0 += ndim;
-  if (dim1 < 0)
-    dim1 += ndim;
-
-  if (src != NULL) {
-    THTensor_(transpose)(LIBRARY_STATE self, src, dim0, dim1);
-    return NULL;
-  } else {
-    return THTensor_(newTranspose)(LIBRARY_STATE self, dim0, dim1);
-  }
-}
-
 [[
   name: transpose
   with_stateless: True
-  cname: transpose_neg
+  cname: newTranspose
   cpu_half: True
   auto_gpu: False
   return: THTensor*
   arguments:
     - THTensor* self
-    - CONSTANT NULL
-    - long dim0
-    - long dim1
+    - arg: long dim0
+      wrap_dim: self
+    - arg: long dim1
+      wrap_dim: self
 ]]
 [[
   name: transpose_
-  cname: transpose_neg
+  cname: transpose
   cpu_half: True
   auto_gpu: False
   return: self
   arguments:
     - THTensor* self
     - THTensor* self
-    - long dim0
-    - long dim1
+    - arg: long dim0
+      wrap_dim: self
+    - arg: long dim1
+      wrap_dim: self
 ]]
 [[
@@ -449,7 +439,8 @@ THTensor* THTensor_(transpose_neg)(THTensor *self, THTensor *src, int dim0, int
     - arg: THTensor* result
       output: True
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
 ]]
 [[
@@ -465,7 +456,8 @@ THTensor* THTensor_(transpose_neg)(THTensor *self, THTensor *src, int dim0, int
   arguments:
     - THTensor* self
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
 ]]
 [[
@@ -479,7 +471,8 @@ THTensor* THTensor_(transpose_neg)(THTensor *self, THTensor *src, int dim0, int
     - arg: THTensor* result
      output: True
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self+1
 ]]
 [[
@@ -491,7 +484,8 @@ THTensor* THTensor_(transpose_neg)(THTensor *self, THTensor *src, int dim0, int
   arguments:
     - THTensor* self
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self+1
 ]]
 [[
@@ -550,7 +544,8 @@ THTensor* THTensor_(transpose_neg)(THTensor *self, THTensor *src, int dim0, int
     - arg: THTensor* result
       output: True
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
     - THIndexTensor* index
 ]]
@@ -561,7 +556,8 @@ THTensor* THTensor_(transpose_neg)(THTensor *self, THTensor *src, int dim0, int
   return: argument 0
   arguments:
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
     - THIndexTensor* index
     - THTensor* source
 ]]
@@ -573,7 +569,8 @@ THTensor* THTensor_(transpose_neg)(THTensor *self, THTensor *src, int dim0, int
   return: argument 0
   arguments:
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
     - THIndexTensor* index
     - THTensor* source
 ]]
@@ -585,7 +582,8 @@ THTensor* THTensor_(transpose_neg)(THTensor *self, THTensor *src, int dim0, int
   return: argument 0
   arguments:
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
     - THIndexTensor* index
     - real value
 ]]
@@ -599,7 +597,8 @@ THTensor* THTensor_(transpose_neg)(THTensor *self, THTensor *src, int dim0, int
     - arg: THTensor* result
       output: True
     - THTensor* self
-    - long dimension
+    - arg: long dimension
+      wrap_dim: self
     - long start
     - long length
 ]]
@@ -613,7 +612,8 @@ THTensor* THTensor_(transpose_neg)(THTensor *self, THTensor *src, int dim0, int
     - arg: THTensor* result
       output: True
     - THTensor* self
-    - long dimension
+    - arg: long dimension
+      wrap_dim: self
     - long size
     - long step
 ]]
@@ -672,13 +672,15 @@ THTensor* THTensor_(transpose_neg)(THTensor *self, THTensor *src, int dim0, int
   - cname: scatter
     arguments:
       - THTensor* self
-      - long dim
+      - arg: long dim
+        wrap_dim: self
       - THIndexTensor* index
       - THTensor* src
   - cname: scatterFill
     arguments:
       - THTensor* self
-      - long dim
+      - arg: long dim
+        wrap_dim: self
       - THIndexTensor* index
       - real value
 ]]
@@ -694,7 +696,8 @@ THTensor* THTensor_(transpose_neg)(THTensor *self, THTensor *src, int dim0, int
     - arg: THTensor* result
       output: True
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
     - THIndexTensor* index
 ]]


@@ -416,7 +416,8 @@
     - arg: THIndexTensor* min_indices
       output: True
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
 ]]
 [[
@@ -442,7 +443,8 @@
     - arg: THIndexTensor* max_indices
       output: True
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
 ]]
 [[
@@ -467,7 +469,8 @@
       output: True
     - THTensor* self
     - long k
-    - long dim
+    - arg: long dim
+      wrap_dim: self
 ]]
 [[
@@ -489,7 +492,8 @@
     - arg: THIndexTensor* indices
       output: True
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
 ]]
 [[
@@ -512,7 +516,8 @@
     - arg: THIndexTensor* indices
       output: True
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
 ]]
 [[
@@ -535,7 +540,8 @@
     - arg: THIndexTensor* indices
       output: True
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
     - CONSTANT false
   - arguments:
     - arg: THTensor* values
@@ -543,7 +549,8 @@
     - arg: THIndexTensor* indices
       output: True
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
     - bool descending
 ]]
@@ -571,7 +578,8 @@
       output: True
     - THTensor* self
     - long k
-    - long dim
+    - arg: long dim
+      wrap_dim: self
     - arg: bool largest
       default: "true"
     - arg: bool sorted


@@ -475,7 +475,8 @@
     - arg: THTensor* destination
       output: True
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
 ]]
 [[
@@ -493,7 +494,8 @@
     - arg: THTensor* destination
       output: True
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
     - CONSTANT false
 ]]
@@ -512,7 +514,8 @@
     - arg: THTensor* destination
       output: True
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
     - CONSTANT false
 ]]
@@ -534,7 +537,8 @@
       output: True
     - THTensor* self
     - real p
-    - long dim
+    - arg: long dim
+      wrap_dim: self
 ]]
 [[
@@ -549,7 +553,8 @@
       output: True
     - THTensor* self
     - real p
-    - long dim
+    - arg: long dim
+      wrap_dim: self
     - real maxnorm
 ]]
@@ -563,7 +568,8 @@
     - THTensor* self
     - THTensor* self
     - real p
-    - long dim
+    - arg: long dim
+      wrap_dim: self
     - real maxnorm
 ]]
@@ -813,7 +819,8 @@
     - arg: THTensor* result
       output: True
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
 ]]
 [[
@@ -830,7 +837,8 @@
     - arg: THTensor* result
       output: True
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
 ]]
 [[
@@ -841,7 +849,8 @@
     - arg: THTensor* result
       output: True
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
 ]]
 [[
@@ -852,7 +861,8 @@
     - arg: THTensor* result
       output: True
     - THTensor* self
-    - long dim
+    - arg: long dim
+      wrap_dim: self
 ]]
 [[