diff --git a/TensorMath.lua b/TensorMath.lua index 7dddb00afb78..b5e95929912b 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -30,7 +30,13 @@ for _,tensortype in ipairs({'ByteTensor', 'ceil', 'floor', 'abs', - 'sign' + 'sign', + 'lt', + 'gt', + 'le', + 'ge', + 'eq', + 'ne' }) do local torchfunc = torch[tensortype].torch[func] diff --git a/TensorMathWrap.lua b/TensorMathWrap.lua index 7f2f6cb0172e..ff82fdf73ec4 100644 --- a/TensorMathWrap.lua +++ b/TensorMathWrap.lua @@ -660,6 +660,18 @@ static void THTensor_random1__(THTensor *self, long b) {name='charoption', default="X", invisible=true}} ) + for _,name in pairs({'lt','gt','le','ge','eq','ne'}) do + interface:wrap(name, + cname(name .. 'Value'), + {{name='ByteTensor',default=true, returned=true}, + {name=Tensor}, + {name=real}}, + cname(name .. 'Tensor'), + {{name='ByteTensor',default=true, returned=true}, + {name=Tensor}, + {name=Tensor}}) + end + if Tensor == 'FloatTensor' or Tensor == 'DoubleTensor' then interface:wrap("mean", diff --git a/dok/maths.dok b/dok/maths.dok index 187331b48122..c2314f044e06 100644 --- a/dok/maths.dok +++ b/dok/maths.dok @@ -11,6 +11,7 @@ categories: * [[#torch.matrixwide.dok|matrix-wide operations]] like [[#torch.trace|trace]] and [[#torch.norm|norm]]. * [[#torch.conv.dok|Convolution and cross-correlation]] operations like [[#torch.conv2|conv2]]. * [[#torch.linalg.dok|Basic linear algebra operations]] like [[#torch.eig|eigen value/vector calculation]], [[#torch.svd|singular value decomposition (svd)]] and [[#torch.gesv|linear system solution]]. + * [[#torch.logical.dok|Logical Operations on Tensors]]. By default, all operations allocate a new tensor to return the result. However, all functions also support passing the resulting(s) @@ -802,3 +803,161 @@ u,s,v = torch.svd(a) +====== Logical Operations on Tensors ====== +{{anchor:torch.logical.dok}} + +These functions implement logical comparison operators that take a +tensor as input and another tensor or a number as the comparison +target. They return a ''ByteTensor'' in which each element is 0 or 1 +indicating if the comparison for the corresponding element was +''false'' or ''true'' respectively. + +===== torch.lt(a, b) ===== +{{anchor:torch.lt}} + +Implements %%<%% operator comparing each element in ''a'' with ''b'' +(if ''b'' is a number) or each element in ''a'' with corresponding element in ''b''. + +===== torch.le(a, b) ===== +{{anchor:torch.lt}} + +Implements %%<=%% operator comparing each element in ''a'' with ''b'' +(if ''b'' is a number) or each element in ''a'' with corresponding element in ''b''. + +===== torch.gt(a, b) ===== +{{anchor:torch.lt}} + +Implements %%>%% operator comparing each element in ''a'' with ''b'' +(if ''b'' is a number) or each element in ''a'' with corresponding element in ''b''. + +===== torch.ge(a, b) ===== +{{anchor:torch.lt}} + +Implements %%>=%% operator comparing each element in ''a'' with ''b'' +(if ''b'' is a number) or each element in ''a'' with corresponding element in ''b''. + +===== torch.eq(a, b) ===== +{{anchor:torch.lt}} + +Implements %%==%% operator comparing each element in ''a'' with ''b'' +(if ''b'' is a number) or each element in ''a'' with corresponding element in ''b''. + +===== torch.ne(a, b) ===== +{{anchor:torch.lt}} + +Implements %%!=%% operator comparing each element in ''a'' with ''b'' +(if ''b'' is a number) or each element in ''a'' with corresponding element in ''b''. + + + + +> a = torch.rand(10) +> b = torch.rand(10) +> =a + 0.5694 + 0.5264 + 0.3041 + 0.4159 + 0.1677 + 0.7964 + 0.0257 + 0.2093 + 0.6564 + 0.0740 +[torch.DoubleTensor of dimension 10] + +> =b + 0.2950 + 0.4867 + 0.9133 + 0.1291 + 0.1811 + 0.3921 + 0.7750 + 0.3259 + 0.2263 + 0.1737 +[torch.DoubleTensor of dimension 10] + +> =torch.lt(a,b) + 0 + 0 + 1 + 0 + 1 + 0 + 1 + 1 + 0 + 1 +[torch.ByteTensor of dimension 10] + +> return torch.eq(a,b) +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +[torch.ByteTensor of dimension 10] + +> return torch.ne(a,b) + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 +[torch.ByteTensor of dimension 10] + +> return torch.gt(a,b) + 1 + 1 + 0 + 1 + 0 + 1 + 0 + 0 + 1 + 0 +[torch.ByteTensor of dimension 10] + +> a[torch.gt(a,b)] = 10 +> =a + 10.0000 + 10.0000 + 0.3041 + 10.0000 + 0.1677 + 10.0000 + 0.0257 + 0.2093 + 10.0000 + 0.0740 +[torch.DoubleTensor of dimension 10] + +> a[torch.gt(a,1)] = -1 +> =a +-1.0000 +-1.0000 + 0.3041 +-1.0000 + 0.1677 +-1.0000 + 0.0257 + 0.2093 +-1.0000 + 0.0740 +[torch.DoubleTensor of dimension 10] + + + \ No newline at end of file diff --git a/generic/Tensor.c b/generic/Tensor.c index 575d4c532164..30b025ef96d4 100644 --- a/generic/Tensor.c +++ b/generic/Tensor.c @@ -470,6 +470,7 @@ static int torch_Tensor_(__newindex__)(lua_State *L) { THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); THLongStorage *idx = NULL; + THByteTensor *mask; if(lua_isnumber(L, 2)) { @@ -556,6 +557,22 @@ static int torch_Tensor_(__newindex__)(lua_State *L) THStorage_(set)(tensor->storage, index, value); lua_pushboolean(L, 1); } + else if((mask = luaT_toudata(L, 2, torch_ByteTensor_id))) + { + THTensor *vals; + if (lua_isnumber(L, 3)) + { + THTensor_(maskedFill)(tensor, mask, (real)(luaL_checknumber(L,3))); + } + else if((vals = luaT_toudata(L, 3, torch_Tensor_id))) + { + THTensor_(maskedCopy)(tensor, mask, vals); + } + else + { + luaL_error(L,"number or tensor expected"); + } + } else lua_pushboolean(L, 0); diff --git a/lib/TH/THTensorApply.h b/lib/TH/THTensorApply.h index 761623c43f9c..f5250885c402 100644 --- a/lib/TH/THTensorApply.h +++ b/lib/TH/THTensorApply.h @@ -9,7 +9,7 @@ TYPE2 *TENSOR2##_data = NULL; \ long *TENSOR2##_counter = NULL; \ long TENSOR2##_stride = 0, TENSOR2##_size = 0, TENSOR2##_dim = 0, TENSOR2##_i, TENSOR2##_n; \ - TYPE2 *TENSOR3##_data = NULL; \ + TYPE3 *TENSOR3##_data = NULL; \ long *TENSOR3##_counter = NULL; \ long TENSOR3##_stride = 0, TENSOR3##_size = 0, TENSOR3##_dim = 0, TENSOR3##_i, TENSOR3##_n; \ int TH_TENSOR_APPLY_hasFinished = 0; \ diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c index e66d6ba9cd2b..ab80e5cbc028 100644 --- a/lib/TH/generic/THTensorMath.c +++ b/lib/TH/generic/THTensorMath.c @@ -14,6 +14,36 @@ void THTensor_(zero)(THTensor *r_) THVector_(fill)(r__data, 0, r__size); break;); } +void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, real value) +{ + TH_TENSOR_APPLY2(real, tensor, unsigned char, mask, + if (*mask_data > 1) THError("Mask tensor can take 0 and 1 values only"); + else if (*mask_data == 1) *tensor_data = value;); +} + +void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src ) +{ + THTensor *srct = THTensor_(newContiguous)(src); + real *src_data = srct->storage->data; + long cntr = 0; + long nelem = THTensor_(nElement)(srct); + TH_TENSOR_APPLY2(real, tensor, unsigned char, mask, + if (*mask_data > 1) + { + THError("Mask tensor can take 0 and 1 values only"); + } + else if (*mask_data == 1) + { + *tensor_data = *src_data; + src_data++; + cntr++; + if (cntr > nelem) + THError("Number of elements of src != mask"); + }); + if (cntr != nelem) + THError("Number of elements of src != mask"); +} + accreal THTensor_(dot)(THTensor *tensor, THTensor *src) { accreal sum = 0; @@ -847,6 +877,29 @@ void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension) } } +#define TENSOR_IMPLEMENT_LOGICAL(NAME,OP) \ + void THTensor_(NAME##Value)(THByteTensor *r_, THTensor* t, real value) \ + { \ + THByteTensor_rawResize(r_, t->nDimension, t->size, NULL); \ + THByteTensor_zero(r_); \ + TH_TENSOR_APPLY2(unsigned char, r_, real, t, \ + if (*t_data OP value) *r__data = 1;); \ + } \ + void THTensor_(NAME##Tensor)(THByteTensor *r_, THTensor *ta, THTensor *tb) \ + { \ + THByteTensor_rawResize(r_, ta->nDimension, ta->size, NULL); \ + THByteTensor_zero(r_); \ + TH_TENSOR_APPLY3(unsigned char, r_, real, ta, real, tb, \ + if(*ta_data OP *tb_data) *r__data = 1;); \ + } \ + +TENSOR_IMPLEMENT_LOGICAL(lt,<) +TENSOR_IMPLEMENT_LOGICAL(gt,>) +TENSOR_IMPLEMENT_LOGICAL(le,<=) +TENSOR_IMPLEMENT_LOGICAL(ge,>=) +TENSOR_IMPLEMENT_LOGICAL(eq,==) +TENSOR_IMPLEMENT_LOGICAL(ne,!=) + /* floating point only now */ #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) diff --git a/lib/TH/generic/THTensorMath.h b/lib/TH/generic/THTensorMath.h index ba0b9913b06d..ec8e8be2579f 100644 --- a/lib/TH/generic/THTensorMath.h +++ b/lib/TH/generic/THTensorMath.h @@ -5,6 +5,9 @@ TH_API void THTensor_(fill)(THTensor *r_, real value); TH_API void THTensor_(zero)(THTensor *r_); +TH_API void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, real value); +TH_API void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src); + TH_API accreal THTensor_(dot)(THTensor *t, THTensor *src); TH_API real THTensor_(minall)(THTensor *t); @@ -50,6 +53,20 @@ TH_API void THTensor_(tril)(THTensor *r_, THTensor *t, long k); TH_API void THTensor_(triu)(THTensor *r_, THTensor *t, long k); TH_API void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension); +TH_API void THTensor_(ltValue)(THByteTensor *r_, THTensor* t, real value); +TH_API void THTensor_(leValue)(THByteTensor *r_, THTensor* t, real value); +TH_API void THTensor_(gtValue)(THByteTensor *r_, THTensor* t, real value); +TH_API void THTensor_(geValue)(THByteTensor *r_, THTensor* t, real value); +TH_API void THTensor_(neValue)(THByteTensor *r_, THTensor* t, real value); +TH_API void THTensor_(eqValue)(THByteTensor *r_, THTensor* t, real value); + +TH_API void THTensor_(ltTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(leTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(gtTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(geTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(neTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(eqTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); + #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) TH_API void THTensor_(log)(THTensor *r_, THTensor *t); diff --git a/test/test.lua b/test/test.lua index 4abc18a61724..917b8362fd0b 100644 --- a/test/test.lua +++ b/test/test.lua @@ -344,6 +344,22 @@ function torchtest.conv3() mytester:asserteq(maxdiff(immfc[1],imfc),0,'torch.conv3') end +function torchtest.logical() + local x = torch.rand(100,100)*2-1; + local xx = x:clone() + + local xgt = torch.gt(x,1) + local xlt = torch.lt(x,1) + + local xeq = torch.eq(x,1) + local xne = torch.ne(x,1) + + local neqs = xgt+xlt + local all = neqs + xeq + mytester:asserteq(neqs:sumall(), xne:sumall(), 'torch.logical') + mytester:asserteq(x:nElement(),all:double():sumall() , 'torch.logical') +end + function torch.test() math.randomseed(os.time()) mytester = torch.Tester()