mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Merge branch 'logical' into newpack
Conflicts: pkg/torch/TensorMathWrap.lua
This commit is contained in:
@ -30,7 +30,13 @@ for _,tensortype in ipairs({'ByteTensor',
|
||||
'ceil',
|
||||
'floor',
|
||||
'abs',
|
||||
'sign'
|
||||
'sign',
|
||||
'lt',
|
||||
'gt',
|
||||
'le',
|
||||
'ge',
|
||||
'eq',
|
||||
'ne'
|
||||
}) do
|
||||
|
||||
local torchfunc = torch[tensortype].torch[func]
|
||||
|
@ -660,6 +660,18 @@ static void THTensor_random1__(THTensor *self, long b)
|
||||
{name='charoption', default="X", invisible=true}}
|
||||
)
|
||||
|
||||
for _,name in pairs({'lt','gt','le','ge','eq','ne'}) do
|
||||
interface:wrap(name,
|
||||
cname(name .. 'Value'),
|
||||
{{name='ByteTensor',default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name=real}},
|
||||
cname(name .. 'Tensor'),
|
||||
{{name='ByteTensor',default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name=Tensor}})
|
||||
end
|
||||
|
||||
if Tensor == 'FloatTensor' or Tensor == 'DoubleTensor' then
|
||||
|
||||
interface:wrap("mean",
|
||||
|
159
dok/maths.dok
159
dok/maths.dok
@ -11,6 +11,7 @@ categories:
|
||||
* [[#torch.matrixwide.dok|matrix-wide operations]] like [[#torch.trace|trace]] and [[#torch.norm|norm]].
|
||||
* [[#torch.conv.dok|Convolution and cross-correlation]] operations like [[#torch.conv2|conv2]].
|
||||
* [[#torch.linalg.dok|Basic linear algebra operations]] like [[#torch.eig|eigen value/vector calculation]], [[#torch.svd|singular value decomposition (svd)]] and [[#torch.gesv|linear system solution]].
|
||||
* [[#torch.logical.dok|Logical Operations on Tensors]].
|
||||
|
||||
By default, all operations allocate a new tensor to return the
|
||||
result. However, all functions also support passing the resulting(s)
|
||||
@ -802,3 +803,161 @@ u,s,v = torch.svd(a)
|
||||
|
||||
</file>
|
||||
|
||||
====== Logical Operations on Tensors ======
|
||||
{{anchor:torch.logical.dok}}
|
||||
|
||||
These functions implement logical comparison operators that take a
|
||||
tensor as input and another tensor or a number as the comparison
|
||||
target. They return a ''ByteTensor'' in which each element is 0 or 1
|
||||
indicating if the comparison for the corresponding element was
|
||||
''false'' or ''true'' respectively.
|
||||
|
||||
===== torch.lt(a, b) =====
|
||||
{{anchor:torch.lt}}
|
||||
|
||||
Implements %%<%% operator comparing each element in ''a'' with ''b''
|
||||
(if ''b'' is a number) or each element in ''a'' with corresponding element in ''b''.
|
||||
|
||||
===== torch.le(a, b) =====
|
||||
{{anchor:torch.le}}
|
||||
|
||||
Implements %%<=%% operator comparing each element in ''a'' with ''b''
|
||||
(if ''b'' is a number) or each element in ''a'' with corresponding element in ''b''.
|
||||
|
||||
===== torch.gt(a, b) =====
|
||||
{{anchor:torch.gt}}
|
||||
|
||||
Implements %%>%% operator comparing each element in ''a'' with ''b''
|
||||
(if ''b'' is a number) or each element in ''a'' with corresponding element in ''b''.
|
||||
|
||||
===== torch.ge(a, b) =====
|
||||
{{anchor:torch.ge}}
|
||||
|
||||
Implements %%>=%% operator comparing each element in ''a'' with ''b''
|
||||
(if ''b'' is a number) or each element in ''a'' with corresponding element in ''b''.
|
||||
|
||||
===== torch.eq(a, b) =====
|
||||
{{anchor:torch.eq}}
|
||||
|
||||
Implements %%==%% operator comparing each element in ''a'' with ''b''
|
||||
(if ''b'' is a number) or each element in ''a'' with corresponding element in ''b''.
|
||||
|
||||
===== torch.ne(a, b) =====
|
||||
{{anchor:torch.ne}}
|
||||
|
||||
Implements %%!=%% operator comparing each element in ''a'' with ''b''
|
||||
(if ''b'' is a number) or each element in ''a'' with corresponding element in ''b''.
|
||||
|
||||
|
||||
<file lua>
|
||||
|
||||
> a = torch.rand(10)
|
||||
> b = torch.rand(10)
|
||||
> =a
|
||||
0.5694
|
||||
0.5264
|
||||
0.3041
|
||||
0.4159
|
||||
0.1677
|
||||
0.7964
|
||||
0.0257
|
||||
0.2093
|
||||
0.6564
|
||||
0.0740
|
||||
[torch.DoubleTensor of dimension 10]
|
||||
|
||||
> =b
|
||||
0.2950
|
||||
0.4867
|
||||
0.9133
|
||||
0.1291
|
||||
0.1811
|
||||
0.3921
|
||||
0.7750
|
||||
0.3259
|
||||
0.2263
|
||||
0.1737
|
||||
[torch.DoubleTensor of dimension 10]
|
||||
|
||||
> =torch.lt(a,b)
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
1
|
||||
[torch.ByteTensor of dimension 10]
|
||||
|
||||
> return torch.eq(a,b)
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
[torch.ByteTensor of dimension 10]
|
||||
|
||||
> return torch.ne(a,b)
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
[torch.ByteTensor of dimension 10]
|
||||
|
||||
> return torch.gt(a,b)
|
||||
1
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
[torch.ByteTensor of dimension 10]
|
||||
|
||||
> a[torch.gt(a,b)] = 10
|
||||
> =a
|
||||
10.0000
|
||||
10.0000
|
||||
0.3041
|
||||
10.0000
|
||||
0.1677
|
||||
10.0000
|
||||
0.0257
|
||||
0.2093
|
||||
10.0000
|
||||
0.0740
|
||||
[torch.DoubleTensor of dimension 10]
|
||||
|
||||
> a[torch.gt(a,1)] = -1
|
||||
> =a
|
||||
-1.0000
|
||||
-1.0000
|
||||
0.3041
|
||||
-1.0000
|
||||
0.1677
|
||||
-1.0000
|
||||
0.0257
|
||||
0.2093
|
||||
-1.0000
|
||||
0.0740
|
||||
[torch.DoubleTensor of dimension 10]
|
||||
|
||||
|
||||
</file>
|
@ -470,6 +470,7 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
THLongStorage *idx = NULL;
|
||||
THByteTensor *mask;
|
||||
|
||||
if(lua_isnumber(L, 2))
|
||||
{
|
||||
@ -556,6 +557,22 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
|
||||
THStorage_(set)(tensor->storage, index, value);
|
||||
lua_pushboolean(L, 1);
|
||||
}
|
||||
else if((mask = luaT_toudata(L, 2, torch_ByteTensor_id)))
|
||||
{
|
||||
THTensor *vals;
|
||||
if (lua_isnumber(L, 3))
|
||||
{
|
||||
THTensor_(maskedFill)(tensor, mask, (real)(luaL_checknumber(L,3)));
|
||||
}
|
||||
else if((vals = luaT_toudata(L, 3, torch_Tensor_id)))
|
||||
{
|
||||
THTensor_(maskedCopy)(tensor, mask, vals);
|
||||
}
|
||||
else
|
||||
{
|
||||
luaL_error(L,"number or tensor expected");
|
||||
}
|
||||
}
|
||||
else
|
||||
lua_pushboolean(L, 0);
|
||||
|
||||
|
@ -9,7 +9,7 @@
|
||||
TYPE2 *TENSOR2##_data = NULL; \
|
||||
long *TENSOR2##_counter = NULL; \
|
||||
long TENSOR2##_stride = 0, TENSOR2##_size = 0, TENSOR2##_dim = 0, TENSOR2##_i, TENSOR2##_n; \
|
||||
TYPE2 *TENSOR3##_data = NULL; \
|
||||
TYPE3 *TENSOR3##_data = NULL; \
|
||||
long *TENSOR3##_counter = NULL; \
|
||||
long TENSOR3##_stride = 0, TENSOR3##_size = 0, TENSOR3##_dim = 0, TENSOR3##_i, TENSOR3##_n; \
|
||||
int TH_TENSOR_APPLY_hasFinished = 0; \
|
||||
|
@ -14,6 +14,36 @@ void THTensor_(zero)(THTensor *r_)
|
||||
THVector_(fill)(r__data, 0, r__size); break;);
|
||||
}
|
||||
|
||||
/* Set every element of `tensor` whose corresponding `mask` element is 1
   to `value`; elements with mask 0 are left untouched.
   Raises an error if the mask contains any value other than 0 or 1. */
void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, real value)
|
||||
{
|
||||
TH_TENSOR_APPLY2(real, tensor, unsigned char, mask,
|
||||
if (*mask_data > 1) THError("Mask tensor can take 0 and 1 values only");
|
||||
else if (*mask_data == 1) *tensor_data = value;);
|
||||
}
|
||||
|
||||
/* Copy elements of `src` (in contiguous order) into the positions of
   `tensor` where `mask` is 1. The number of 1s in `mask` must equal
   the number of elements in `src`; otherwise an error is raised, as is
   any mask value other than 0 or 1.
   NOTE(review): `srct` (the contiguous copy of `src`) does not appear
   to be freed on any path through this function, including the THError
   paths — looks like a tensor leak; confirm against THError semantics
   and add THTensor_(free)(srct) before returning. */
void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src )
|
||||
{
|
||||
/* Work on a contiguous clone so `src` can be walked with a flat pointer. */
THTensor *srct = THTensor_(newContiguous)(src);
|
||||
real *src_data = srct->storage->data;
|
||||
/* cntr counts how many src elements have been consumed so far. */
long cntr = 0;
|
||||
long nelem = THTensor_(nElement)(srct);
|
||||
TH_TENSOR_APPLY2(real, tensor, unsigned char, mask,
|
||||
if (*mask_data > 1)
|
||||
{
|
||||
THError("Mask tensor can take 0 and 1 values only");
|
||||
}
|
||||
else if (*mask_data == 1)
|
||||
{
|
||||
*tensor_data = *src_data;
|
||||
src_data++;
|
||||
cntr++;
|
||||
/* More 1s in the mask than elements in src: bail out early. */
if (cntr > nelem)
|
||||
THError("Number of elements of src != mask");
|
||||
});
|
||||
/* Fewer 1s in the mask than elements in src is also an error:
   every src element must be consumed exactly once. */
if (cntr != nelem)
|
||||
THError("Number of elements of src != mask");
|
||||
}
|
||||
|
||||
accreal THTensor_(dot)(THTensor *tensor, THTensor *src)
|
||||
{
|
||||
accreal sum = 0;
|
||||
@ -847,6 +877,29 @@ void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension)
|
||||
}
|
||||
}
|
||||
|
||||
/* Generates the element-wise comparison pair for operator OP under the
   name NAME:
     - THTensor_(NAME##Value): compare each element of `t` against a
       scalar `value`;
     - THTensor_(NAME##Tensor): compare corresponding elements of `ta`
       and `tb`.
   Both resize the ByteTensor result `r_` to the input's shape, zero it,
   then write 1 wherever the comparison holds. */
#define TENSOR_IMPLEMENT_LOGICAL(NAME,OP) \
|
||||
void THTensor_(NAME##Value)(THByteTensor *r_, THTensor* t, real value) \
|
||||
{ \
|
||||
THByteTensor_rawResize(r_, t->nDimension, t->size, NULL); \
|
||||
THByteTensor_zero(r_); \
|
||||
TH_TENSOR_APPLY2(unsigned char, r_, real, t, \
|
||||
if (*t_data OP value) *r__data = 1;); \
|
||||
} \
|
||||
void THTensor_(NAME##Tensor)(THByteTensor *r_, THTensor *ta, THTensor *tb) \
|
||||
{ \
|
||||
THByteTensor_rawResize(r_, ta->nDimension, ta->size, NULL); \
|
||||
THByteTensor_zero(r_); \
|
||||
TH_TENSOR_APPLY3(unsigned char, r_, real, ta, real, tb, \
|
||||
if(*ta_data OP *tb_data) *r__data = 1;); \
|
||||
} \
|
||||
|
|
||||
/* Instantiate the six comparison operators (lt/gt/le/ge/eq/ne). */
TENSOR_IMPLEMENT_LOGICAL(lt,<)
|
||||
TENSOR_IMPLEMENT_LOGICAL(gt,>)
|
||||
TENSOR_IMPLEMENT_LOGICAL(le,<=)
|
||||
TENSOR_IMPLEMENT_LOGICAL(ge,>=)
|
||||
TENSOR_IMPLEMENT_LOGICAL(eq,==)
|
||||
TENSOR_IMPLEMENT_LOGICAL(ne,!=)
|
||||
|
||||
/* floating point only now */
|
||||
|
||||
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
|
||||
|
@ -5,6 +5,9 @@
|
||||
TH_API void THTensor_(fill)(THTensor *r_, real value);
|
||||
TH_API void THTensor_(zero)(THTensor *r_);
|
||||
|
||||
TH_API void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, real value);
|
||||
TH_API void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src);
|
||||
|
||||
TH_API accreal THTensor_(dot)(THTensor *t, THTensor *src);
|
||||
|
||||
TH_API real THTensor_(minall)(THTensor *t);
|
||||
@ -50,6 +53,20 @@ TH_API void THTensor_(tril)(THTensor *r_, THTensor *t, long k);
|
||||
TH_API void THTensor_(triu)(THTensor *r_, THTensor *t, long k);
|
||||
TH_API void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension);
|
||||
|
||||
TH_API void THTensor_(ltValue)(THByteTensor *r_, THTensor* t, real value);
|
||||
TH_API void THTensor_(leValue)(THByteTensor *r_, THTensor* t, real value);
|
||||
TH_API void THTensor_(gtValue)(THByteTensor *r_, THTensor* t, real value);
|
||||
TH_API void THTensor_(geValue)(THByteTensor *r_, THTensor* t, real value);
|
||||
TH_API void THTensor_(neValue)(THByteTensor *r_, THTensor* t, real value);
|
||||
TH_API void THTensor_(eqValue)(THByteTensor *r_, THTensor* t, real value);
|
||||
|
||||
TH_API void THTensor_(ltTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
|
||||
TH_API void THTensor_(leTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
|
||||
TH_API void THTensor_(gtTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
|
||||
TH_API void THTensor_(geTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
|
||||
TH_API void THTensor_(neTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
|
||||
TH_API void THTensor_(eqTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
|
||||
|
||||
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
|
||||
|
||||
TH_API void THTensor_(log)(THTensor *r_, THTensor *t);
|
||||
|
@ -344,6 +344,22 @@ function torchtest.conv3()
|
||||
mytester:asserteq(maxdiff(immfc[1],imfc),0,'torch.conv3')
|
||||
end
|
||||
|
||||
-- Sanity-checks the tensor comparison operators (gt/lt/eq/ne) against
-- each other: for any x and threshold, (x>1) + (x<1) must equal (x~=1),
-- and adding (x==1) must account for every element.
function torchtest.logical()
|
||||
local x = torch.rand(100,100)*2-1;
|
||||
local xx = x:clone()
|
||||
|
|
||||
local xgt = torch.gt(x,1)
|
||||
local xlt = torch.lt(x,1)
|
||||
|
|
||||
local xeq = torch.eq(x,1)
|
||||
local xne = torch.ne(x,1)
|
||||
|
|
||||
-- gt and lt are disjoint, so their sum is the "not equal" mask.
local neqs = xgt+xlt
|
||||
-- adding the "equal" mask must cover every element exactly once
local all = neqs + xeq
|
||||
mytester:asserteq(neqs:sumall(), xne:sumall(), 'torch.logical')
|
||||
mytester:asserteq(x:nElement(),all:double():sumall() , 'torch.logical')
|
||||
end
|
||||
|
||||
function torch.test()
|
||||
math.randomseed(os.time())
|
||||
mytester = torch.Tester()
|
||||
|
Reference in New Issue
Block a user