Merge branch 'logical' into newpack

Conflicts:
	pkg/torch/TensorMathWrap.lua
This commit is contained in:
Ronan Collobert
2012-02-03 10:12:30 +01:00
parent 477587f566
commit fc16a68f48
8 changed files with 282 additions and 2 deletions

View File

@ -30,7 +30,13 @@ for _,tensortype in ipairs({'ByteTensor',
'ceil',
'floor',
'abs',
'sign'
'sign',
'lt',
'gt',
'le',
'ge',
'eq',
'ne'
}) do
local torchfunc = torch[tensortype].torch[func]

View File

@ -660,6 +660,18 @@ static void THTensor_random1__(THTensor *self, long b)
{name='charoption', default="X", invisible=true}}
)
for _,name in pairs({'lt','gt','le','ge','eq','ne'}) do
interface:wrap(name,
cname(name .. 'Value'),
{{name='ByteTensor',default=true, returned=true},
{name=Tensor},
{name=real}},
cname(name .. 'Tensor'),
{{name='ByteTensor',default=true, returned=true},
{name=Tensor},
{name=Tensor}})
end
if Tensor == 'FloatTensor' or Tensor == 'DoubleTensor' then
interface:wrap("mean",

View File

@ -11,6 +11,7 @@ categories:
* [[#torch.matrixwide.dok|matrix-wide operations]] like [[#torch.trace|trace]] and [[#torch.norm|norm]].
* [[#torch.conv.dok|Convolution and cross-correlation]] operations like [[#torch.conv2|conv2]].
* [[#torch.linalg.dok|Basic linear algebra operations]] like [[#torch.eig|eigen value/vector calculation]], [[#torch.svd|singular value decomposition (svd)]] and [[#torch.gesv|linear system solution]].
* [[#torch.logical.dok|Logical Operations on Tensors]].
By default, all operations allocate a new tensor to return the
result. However, all functions also support passing the resulting(s)
@ -802,3 +803,161 @@ u,s,v = torch.svd(a)
</file>
====== Logical Operations on Tensors ======
{{anchor:torch.logical.dok}}
These functions implement logical comparison operators that take a
tensor as input and another tensor or a number as the comparison
target. They return a ''ByteTensor'' in which each element is 0 or 1
indicating if the comparison for the corresponding element was
''false'' or ''true'' respectively.
===== torch.lt(a, b) =====
{{anchor:torch.lt}}
Implements %%<%% operator comparing each element in ''a'' with ''b''
(if ''b'' is a number) or each element in ''a'' with corresponding element in ''b''.
===== torch.le(a, b) =====
{{anchor:torch.lt}}
Implements %%<=%% operator comparing each element in ''a'' with ''b''
(if ''b'' is a number) or each element in ''a'' with corresponding element in ''b''.
===== torch.gt(a, b) =====
{{anchor:torch.lt}}
Implements %%>%% operator comparing each element in ''a'' with ''b''
(if ''b'' is a number) or each element in ''a'' with corresponding element in ''b''.
===== torch.ge(a, b) =====
{{anchor:torch.lt}}
Implements %%>=%% operator comparing each element in ''a'' with ''b''
(if ''b'' is a number) or each element in ''a'' with corresponding element in ''b''.
===== torch.eq(a, b) =====
{{anchor:torch.lt}}
Implements %%==%% operator comparing each element in ''a'' with ''b''
(if ''b'' is a number) or each element in ''a'' with corresponding element in ''b''.
===== torch.ne(a, b) =====
{{anchor:torch.lt}}
Implements %%!=%% operator comparing each element in ''a'' with ''b''
(if ''b'' is a number) or each element in ''a'' with corresponding element in ''b''.
<file lua>
> a = torch.rand(10)
> b = torch.rand(10)
> =a
0.5694
0.5264
0.3041
0.4159
0.1677
0.7964
0.0257
0.2093
0.6564
0.0740
[torch.DoubleTensor of dimension 10]
> =b
0.2950
0.4867
0.9133
0.1291
0.1811
0.3921
0.7750
0.3259
0.2263
0.1737
[torch.DoubleTensor of dimension 10]
> =torch.lt(a,b)
0
0
1
0
1
0
1
1
0
1
[torch.ByteTensor of dimension 10]
> return torch.eq(a,b)
0
0
0
0
0
0
0
0
0
0
[torch.ByteTensor of dimension 10]
> return torch.ne(a,b)
1
1
1
1
1
1
1
1
1
1
[torch.ByteTensor of dimension 10]
> return torch.gt(a,b)
1
1
0
1
0
1
0
0
1
0
[torch.ByteTensor of dimension 10]
> a[torch.gt(a,b)] = 10
> =a
10.0000
10.0000
0.3041
10.0000
0.1677
10.0000
0.0257
0.2093
10.0000
0.0740
[torch.DoubleTensor of dimension 10]
> a[torch.gt(a,1)] = -1
> =a
-1.0000
-1.0000
0.3041
-1.0000
0.1677
-1.0000
0.0257
0.2093
-1.0000
0.0740
[torch.DoubleTensor of dimension 10]
</file>

View File

@ -470,6 +470,7 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
{
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
THLongStorage *idx = NULL;
THByteTensor *mask;
if(lua_isnumber(L, 2))
{
@ -556,6 +557,22 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
THStorage_(set)(tensor->storage, index, value);
lua_pushboolean(L, 1);
}
else if((mask = luaT_toudata(L, 2, torch_ByteTensor_id)))
{
THTensor *vals;
if (lua_isnumber(L, 3))
{
THTensor_(maskedFill)(tensor, mask, (real)(luaL_checknumber(L,3)));
}
else if((vals = luaT_toudata(L, 3, torch_Tensor_id)))
{
THTensor_(maskedCopy)(tensor, mask, vals);
}
else
{
luaL_error(L,"number or tensor expected");
}
}
else
lua_pushboolean(L, 0);

View File

@ -9,7 +9,7 @@
TYPE2 *TENSOR2##_data = NULL; \
long *TENSOR2##_counter = NULL; \
long TENSOR2##_stride = 0, TENSOR2##_size = 0, TENSOR2##_dim = 0, TENSOR2##_i, TENSOR2##_n; \
TYPE2 *TENSOR3##_data = NULL; \
TYPE3 *TENSOR3##_data = NULL; \
long *TENSOR3##_counter = NULL; \
long TENSOR3##_stride = 0, TENSOR3##_size = 0, TENSOR3##_dim = 0, TENSOR3##_i, TENSOR3##_n; \
int TH_TENSOR_APPLY_hasFinished = 0; \

View File

@ -14,6 +14,36 @@ void THTensor_(zero)(THTensor *r_)
THVector_(fill)(r__data, 0, r__size); break;);
}
void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, real value)
{
TH_TENSOR_APPLY2(real, tensor, unsigned char, mask,
if (*mask_data > 1) THError("Mask tensor can take 0 and 1 values only");
else if (*mask_data == 1) *tensor_data = value;);
}
void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src )
{
THTensor *srct = THTensor_(newContiguous)(src);
real *src_data = srct->storage->data;
long cntr = 0;
long nelem = THTensor_(nElement)(srct);
TH_TENSOR_APPLY2(real, tensor, unsigned char, mask,
if (*mask_data > 1)
{
THError("Mask tensor can take 0 and 1 values only");
}
else if (*mask_data == 1)
{
*tensor_data = *src_data;
src_data++;
cntr++;
if (cntr > nelem)
THError("Number of elements of src != mask");
});
if (cntr != nelem)
THError("Number of elements of src != mask");
}
accreal THTensor_(dot)(THTensor *tensor, THTensor *src)
{
accreal sum = 0;
@ -847,6 +877,29 @@ void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension)
}
}
#define TENSOR_IMPLEMENT_LOGICAL(NAME,OP) \
void THTensor_(NAME##Value)(THByteTensor *r_, THTensor* t, real value) \
{ \
THByteTensor_rawResize(r_, t->nDimension, t->size, NULL); \
THByteTensor_zero(r_); \
TH_TENSOR_APPLY2(unsigned char, r_, real, t, \
if (*t_data OP value) *r__data = 1;); \
} \
void THTensor_(NAME##Tensor)(THByteTensor *r_, THTensor *ta, THTensor *tb) \
{ \
THByteTensor_rawResize(r_, ta->nDimension, ta->size, NULL); \
THByteTensor_zero(r_); \
TH_TENSOR_APPLY3(unsigned char, r_, real, ta, real, tb, \
if(*ta_data OP *tb_data) *r__data = 1;); \
} \
TENSOR_IMPLEMENT_LOGICAL(lt,<)
TENSOR_IMPLEMENT_LOGICAL(gt,>)
TENSOR_IMPLEMENT_LOGICAL(le,<=)
TENSOR_IMPLEMENT_LOGICAL(ge,>=)
TENSOR_IMPLEMENT_LOGICAL(eq,==)
TENSOR_IMPLEMENT_LOGICAL(ne,!=)
/* floating point only now */
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)

View File

@ -5,6 +5,9 @@
TH_API void THTensor_(fill)(THTensor *r_, real value);
TH_API void THTensor_(zero)(THTensor *r_);
TH_API void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, real value);
TH_API void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src);
TH_API accreal THTensor_(dot)(THTensor *t, THTensor *src);
TH_API real THTensor_(minall)(THTensor *t);
@ -50,6 +53,20 @@ TH_API void THTensor_(tril)(THTensor *r_, THTensor *t, long k);
TH_API void THTensor_(triu)(THTensor *r_, THTensor *t, long k);
TH_API void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension);
TH_API void THTensor_(ltValue)(THByteTensor *r_, THTensor* t, real value);
TH_API void THTensor_(leValue)(THByteTensor *r_, THTensor* t, real value);
TH_API void THTensor_(gtValue)(THByteTensor *r_, THTensor* t, real value);
TH_API void THTensor_(geValue)(THByteTensor *r_, THTensor* t, real value);
TH_API void THTensor_(neValue)(THByteTensor *r_, THTensor* t, real value);
TH_API void THTensor_(eqValue)(THByteTensor *r_, THTensor* t, real value);
TH_API void THTensor_(ltTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
TH_API void THTensor_(leTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
TH_API void THTensor_(gtTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
TH_API void THTensor_(geTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
TH_API void THTensor_(neTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
TH_API void THTensor_(eqTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
TH_API void THTensor_(log)(THTensor *r_, THTensor *t);

View File

@ -344,6 +344,22 @@ function torchtest.conv3()
mytester:asserteq(maxdiff(immfc[1],imfc),0,'torch.conv3')
end
function torchtest.logical()
local x = torch.rand(100,100)*2-1;
local xx = x:clone()
local xgt = torch.gt(x,1)
local xlt = torch.lt(x,1)
local xeq = torch.eq(x,1)
local xne = torch.ne(x,1)
local neqs = xgt+xlt
local all = neqs + xeq
mytester:asserteq(neqs:sumall(), xne:sumall(), 'torch.logical')
mytester:asserteq(x:nElement(),all:double():sumall() , 'torch.logical')
end
function torch.test()
math.randomseed(os.time())
mytester = torch.Tester()