mirror of https://github.com/pytorch/pytorch.git
Remove THTensor::_dim, temporarily remove THTensor_nDimension. (#9895)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9895

The primary goal here was to remove THTensor::_dim, which isn't part of the API moving forward. Instead, we provide three options for getting the dimensionality (this is temporary, although non-trivial to remove!):

```
nDimension                 corresponds to the "true" ATen dimension. TODO: implement.
nDimensionLegacyNoScalars  corresponds to the ATen dimension, except scalars are viewed as 1-dimensional tensors.
nDimensionLegacyAll        corresponds to the ATen dimension, except scalars are viewed as 1-dimensional tensors
                           and tensors with a dimension of size zero are collapsed to 0-dimensional tensors.
```

So in this patch, nDimension -> nDimensionLegacyNoScalars, and _dim/_nDimension go to nDimensionLegacyAll. These are just codemods.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/9835
Reviewed By: ezyang
Differential Revision: D8999338
Pulled By: gchanan
fbshipit-source-id: a4d676ac728f6f36ca09604a41e888d545ae9311
committed by Facebook Github Bot
parent bc66d98248
commit 1af1b0c2a5
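To make the three views concrete, here is a minimal standalone sketch of the semantics described above — a toy `Tensor` struct for illustration only, not the real THTensor API:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy stand-in for THTensor: just enough state to model the note above.
struct Tensor {
  std::vector<int64_t> sizes;  // ATen view: a scalar has sizes == {}
  bool is_zero_dim;            // true for scalars
  bool is_empty;               // true if any dimension has size zero
};

int64_t dim(const Tensor& t) { return static_cast<int64_t>(t.sizes.size()); }

// nDimensionLegacyNoScalars: like dim(), but scalars count as 1-dimensional.
int64_t nDimensionLegacyNoScalars(const Tensor& t) {
  return t.is_zero_dim ? 1 : dim(t);
}

// nDimensionLegacyAll: additionally collapses empty tensors to 0 dimensions.
int64_t nDimensionLegacyAll(const Tensor& t) {
  if (t.is_empty) return 0;
  if (t.is_zero_dim) return 1;
  return dim(t);
}

int main() {
  Tensor scalar{{}, /*is_zero_dim=*/true, /*is_empty=*/false};
  Tensor empty{{0, 4}, /*is_zero_dim=*/false, /*is_empty=*/true};
  Tensor matrix{{2, 3}, /*is_zero_dim=*/false, /*is_empty=*/false};

  assert(nDimensionLegacyNoScalars(scalar) == 1 && nDimensionLegacyAll(scalar) == 1);
  assert(nDimensionLegacyNoScalars(empty) == 2 && nDimensionLegacyAll(empty) == 0);
  assert(nDimensionLegacyNoScalars(matrix) == 2 && nDimensionLegacyAll(matrix) == 2);
  return 0;
}
```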
@@ -52,13 +52,6 @@ struct THTensor
     return storage_->unsafe_data<T>() + storage_offset_;
   }
 
-  // [NOTE: _dim() vs dim()]
-  // _dim() returns the "old" TH dimension view where no dimensions represents an empty tensor.
-  // dim() returns the ATen view of the dimensionality, i.e. 0-sized dimensions are supported.
-  inline int64_t _dim() const {
-    return is_empty() ? 0 : dim();
-  }
-
   inline int64_t dim() const {
     return sizes_.size();
   }
@@ -159,6 +152,31 @@ inline void THTensor_setIsZeroDim(THTensor *tensor, bool is_zero_dim) {
   tensor->is_zero_dim_ = is_zero_dim;
 }
 
+// [NOTE: nDimension vs nDimensionLegacyNoScalars vs nDimensionLegacyAll]
+// nDimension corresponds to the "true" ATen dimension. TODO: implement.
+// nDimensionLegacyNoScalars corresponds to the ATen dimension, except scalars are viewed as 1-dimensional tensors.
+// nDimensionLegacyAll corresponds to the ATen dimension, except scalars are viewed as 1-dimensional tensors
+// and tensors with a dimension of size zero are collapsed to 0-dimensional tensors.
+//
+// Eventually, everything should go through nDimension or tensor->dim().
+inline int THTensor_nDimensionLegacyNoScalars(const THTensor* tensor) {
+  if (THTensor_isZeroDim(tensor)) {
+    return 1;
+  } else {
+    return tensor->dim();
+  }
+}
+
+inline int THTensor_nDimensionLegacyAll(const THTensor* tensor) {
+  if (tensor->is_empty()) {
+    return 0;
+  } else if (THTensor_isZeroDim(tensor)) {
+    return 1;
+  } else {
+    return tensor->dim();
+  }
+}
+
 TH_API void THTensor_free(THTensor *self);
 TH_CPP_API at::optional<std::vector<int64_t>> THTensor_compute_stride(at::IntList oldshape, at::IntList oldstride,
                                                                       at::IntList newshape);
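For call sites the codemod is mechanical; the helpers above simply give the old behavior an explicit name. As a hypothetical illustration (this snippet is not a hunk from the patch), a caller that relied on the old collapsed view

    if (tensor->_dim() == 2) { /* old TH view: empty tensors report 0 dims */ }

now spells the same semantics out:

    if (THTensor_nDimensionLegacyAll(tensor) == 2) { /* identical behavior */ }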
@@ -46,7 +46,7 @@
   TENSOR##_data = THTensor_getStoragePtr(TENSOR)->data<TYPE>()+TENSOR->storage_offset(); \
   TENSOR##_size = 1; \
   TENSOR##_stride = 1; \
-  for(TENSOR##_i = TENSOR->_dim()-1; TENSOR##_i >= 0; TENSOR##_i--) { \
+  for(TENSOR##_i = THTensor_nDimensionLegacyAll(TENSOR)-1; TENSOR##_i >= 0; TENSOR##_i--) { \
     if(TENSOR->size(TENSOR##_i) != 1) { \
       if(TENSOR->stride(TENSOR##_i) == TENSOR##_size && TENSOR##_i != DIM) \
         TENSOR##_size *= TENSOR->size(TENSOR##_i); \
@@ -59,7 +59,7 @@
   if (!TENSOR##_contiguous) { \
     /* Find the dimension of contiguous sections */ \
     TENSOR##_dim = 1; \
-    for(TENSOR##_i = TENSOR->_dim()-2; TENSOR##_i >= 0; TENSOR##_i--) \
+    for(TENSOR##_i = THTensor_nDimensionLegacyAll(TENSOR)-2; TENSOR##_i >= 0; TENSOR##_i--) \
     { \
       if(TENSOR->stride(TENSOR##_i) != TENSOR->stride(TENSOR##_i+1) * TENSOR->size(TENSOR##_i+1) || TENSOR##_i == DIM || TENSOR##_i+1 == DIM) \
         TENSOR##_dim++; \
@@ -69,19 +69,19 @@
     TENSOR##_sizes = TENSOR##_counter + TENSOR##_dim; \
     TENSOR##_strides = TENSOR##_counter + 2*TENSOR##_dim; \
     TH_TENSOR_dim_index = TENSOR##_dim-1; \
-    TENSOR##_dimOffset = (DIM == TENSOR->_dim()-1) ? &TENSOR##_i : &TENSOR##_counter[DIM]; \
-    TENSOR##_sizes[TH_TENSOR_dim_index] = TENSOR->size(TENSOR->_dim()-1); \
-    TENSOR##_strides[TH_TENSOR_dim_index] = TENSOR->stride(TENSOR->_dim()-1); \
+    TENSOR##_dimOffset = (DIM == THTensor_nDimensionLegacyAll(TENSOR)-1) ? &TENSOR##_i : &TENSOR##_counter[DIM]; \
+    TENSOR##_sizes[TH_TENSOR_dim_index] = TENSOR->size(THTensor_nDimensionLegacyAll(TENSOR)-1); \
+    TENSOR##_strides[TH_TENSOR_dim_index] = TENSOR->stride(THTensor_nDimensionLegacyAll(TENSOR)-1); \
     /* TENSOR##_counter tracks where we are in the storage. The offset into the */ \
     /* storage is given by storage_offset + (i * j), where i is the stride */ \
     /* vector and j is tensor_counter vector. This sets the starting position for the loop. */ \
     for(TENSOR##_i = TENSOR##_dim-1; TENSOR##_i >= 0; --TENSOR##_i) { \
       TENSOR##_counter[TENSOR##_i] = 0; \
     } \
-    for(TENSOR##_i = TENSOR->_dim()-2; TENSOR##_i >= 0; --TENSOR##_i) { \
+    for(TENSOR##_i = THTensor_nDimensionLegacyAll(TENSOR)-2; TENSOR##_i >= 0; --TENSOR##_i) { \
      if (TENSOR->stride(TENSOR##_i) == TENSOR->stride(TENSOR##_i+1) * TENSOR->size(TENSOR##_i+1) && TENSOR##_i != DIM && TENSOR##_i+1 != DIM) { \
        TENSOR##_sizes[TH_TENSOR_dim_index] = TENSOR->size(TENSOR##_i) * TENSOR##_sizes[TH_TENSOR_dim_index]; \
-       if (DIM != TENSOR->_dim()-1 && TENSOR##_i < DIM) \
+       if (DIM != THTensor_nDimensionLegacyAll(TENSOR)-1 && TENSOR##_i < DIM) \
          TENSOR##_dimOffset--; \
      } else { \
        --TH_TENSOR_dim_index; \
@@ -146,7 +146,7 @@
   int TH_TENSOR_DIM_APPLY_i; \
 \
   if( (DIMENSION < 0) || (DIMENSION >= TENSOR1->dim()) ) \
-    THError("invalid dimension %d (expected to be 0 <= dim < %d)", DIMENSION, TENSOR1->_dim()); \
+    THError("invalid dimension %d (expected to be 0 <= dim < %d)", DIMENSION, THTensor_nDimensionLegacyAll(TENSOR1)); \
   if( TENSOR1->dim() != TENSOR2->dim() ) { \
     AT_ERROR("inconsistent tensor size, expected ", #TENSOR1, " ", TENSOR1->sizes(), " and ", #TENSOR2, " ", TENSOR2->sizes(), " to have the same number of dimensions"); \
   } \
@@ -266,25 +266,25 @@
   int TH_TENSOR_DIM_APPLY_hasFinished = 0; \
   int TH_TENSOR_DIM_APPLY_i; \
 \
-  if( (DIMENSION < 0) || (DIMENSION >= TENSOR->_dim()) ) \
+  if( (DIMENSION < 0) || (DIMENSION >= THTensor_nDimensionLegacyAll(TENSOR)) ) \
     THError("invalid dimension"); \
 \
   TENSOR##_data = THTensor_getStoragePtr(TENSOR)->data<TYPE>()+(TENSOR)->storage_offset(); \
   TENSOR##_stride = (TENSOR)->stride(DIMENSION); \
   TENSOR##_size = TENSOR->size(DIMENSION); \
   /* Counter stores the indices into the Tensor at any time */ \
-  TH_TENSOR_DIM_APPLY_counter = (int64_t*)THAlloc(sizeof(int64_t)*(TENSOR->_dim())); \
-  for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR->_dim(); TH_TENSOR_DIM_APPLY_i++) \
+  TH_TENSOR_DIM_APPLY_counter = (int64_t*)THAlloc(sizeof(int64_t)*(THTensor_nDimensionLegacyAll(TENSOR))); \
+  for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < THTensor_nDimensionLegacyAll(TENSOR); TH_TENSOR_DIM_APPLY_i++) \
     TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \
 \
   while(!TH_TENSOR_DIM_APPLY_hasFinished) \
   { \
     CODE \
 \
-    if(TENSOR->_dim() == 1) \
+    if(THTensor_nDimensionLegacyAll(TENSOR) == 1) \
       break; \
 \
-    for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR->_dim(); TH_TENSOR_DIM_APPLY_i++) \
+    for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < THTensor_nDimensionLegacyAll(TENSOR); TH_TENSOR_DIM_APPLY_i++) \
     { \
       /* Check if the index is equal to DIMENSION. We don't need to update the */ \
       /* offset if this is the case, and can consider the next index. However, */ \
@@ -292,7 +292,7 @@
       /* we have parsed the entire tensor and can exit */ \
       if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \
       { \
-        if(TH_TENSOR_DIM_APPLY_i == TENSOR->_dim()-1) \
+        if(TH_TENSOR_DIM_APPLY_i == THTensor_nDimensionLegacyAll(TENSOR)-1) \
         { \
           TH_TENSOR_DIM_APPLY_hasFinished = 1; \
           break; \
@@ -307,7 +307,7 @@
       if(TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] == TENSOR->size(TH_TENSOR_DIM_APPLY_i)) \
       { \
         /* Handled TENSOR_size(dim) iterations for DIM_APPLY_i. If this is the last dimension, exit */ \
-        if(TH_TENSOR_DIM_APPLY_i == TENSOR->_dim()-1) \
+        if(TH_TENSOR_DIM_APPLY_i == THTensor_nDimensionLegacyAll(TENSOR)-1) \
         { \
           TH_TENSOR_DIM_APPLY_hasFinished = 1; \
           break; \
@@ -15,27 +15,27 @@ ptrdiff_t THTensor_(storageOffset)(const THTensor *self)
   return self->storage_offset();
 }
 
-int THTensor_(nDimension)(const THTensor *self)
+int THTensor_(nDimensionLegacyNoScalars)(const THTensor *self)
 {
-  return self->dim();
+  return THTensor_nDimensionLegacyNoScalars(self);
 }
 
-int THTensor_(_nDimension)(const THTensor *self)
+int THTensor_(nDimensionLegacyAll)(const THTensor *self)
 {
-  return self->_dim();
+  return THTensor_nDimensionLegacyAll(self);
 }
 
 int64_t THTensor_(size)(const THTensor *self, int dim)
 {
   THArgCheck((dim >= 0) && (dim < self->dim()), 2, "dimension %d out of range of %dD tensor",
-      dim+TH_INDEX_BASE, THTensor_(nDimension)(self));
+      dim+TH_INDEX_BASE, THTensor_(nDimensionLegacyNoScalars)(self));
   return self->size(dim);
 }
 
 int64_t THTensor_(stride)(const THTensor *self, int dim)
 {
   THArgCheck((dim >= 0) && (dim < self->dim()), 2, "dimension %d out of range of %dD tensor",
-      dim+TH_INDEX_BASE, THTensor_(nDimension)(self));
+      dim+TH_INDEX_BASE, THTensor_(nDimensionLegacyNoScalars)(self));
   return self->stride(dim);
 }
@@ -397,7 +397,7 @@ void THTensor_(select)(THTensor *self, THTensor *src, int dimension, int64_t sli
     src = self;
 
 #ifndef USE_TH_SIZE_ZERO_DIM
-  THArgCheck(src->_dim() > 1, 1, "cannot select on a vector");
+  THArgCheck(THTensor_nDimensionLegacyAll(src) > 1, 1, "cannot select on a vector");
 #else
 #ifndef USE_TH_SCALAR
   THArgCheck(src->dim() > 1, 1, "cannot select on a vector");
@@ -575,7 +575,7 @@ int THTensor_(isTransposed)(const THTensor *self)
   int64_t size_max_stride = 1;
   int64_t z = 1;
   int d;
-  for (d = 0; d < self->_dim(); ++d) {
+  for (d = 0; d < THTensor_nDimensionLegacyAll(self); ++d) {
     if (self->stride(d) == 0 && self->size(d) != 1)
       return 0;
     if (self->stride(d) > max_stride) {
@@ -611,10 +611,10 @@ int THTensor_(isContiguous)(const THTensor *self)
 int THTensor_(isSize)(const THTensor *self, const THLongStorage *dims)
 {
   int d;
-  if (self->_dim() != dims->size)
+  if (THTensor_nDimensionLegacyAll(self) != dims->size)
     return 0;
 
-  for(d = 0; d < self->_dim(); ++d)
+  for(d = 0; d < THTensor_nDimensionLegacyAll(self); ++d)
   {
     if(self->size(d) != THLongStorage_data(dims)[d])
       return 0;
@@ -641,10 +641,10 @@ int THTensor_(isSetTo)(const THTensor *self, const THTensor* src)
     return 0;
   if (THTensor_getStoragePtr(self) == THTensor_getStoragePtr(src) &&
       self->storage_offset() == src->storage_offset() &&
-      self->_dim() == src->_dim())
+      THTensor_nDimensionLegacyAll(self) == THTensor_nDimensionLegacyAll(src))
   {
     int d;
-    for (d = 0; d < self->_dim(); ++d)
+    for (d = 0; d < THTensor_nDimensionLegacyAll(self); ++d)
     {
       if (self->size(d) != src->size(d) || self->stride(d) != src->stride(d))
         return 0;
@@ -656,13 +656,13 @@ int THTensor_(isSetTo)(const THTensor *self, const THTensor* src)
 
 ptrdiff_t THTensor_(nElement)(const THTensor *self)
 {
-  if(self->_dim() == 0)
+  if(THTensor_nDimensionLegacyAll(self) == 0)
     return 0;
   else
   {
     ptrdiff_t nElement = 1;
     int d;
-    for(d = 0; d < self->_dim(); d++)
+    for(d = 0; d < THTensor_nDimensionLegacyAll(self); d++)
       nElement *= self->size(d);
     return nElement;
   }
@@ -790,56 +790,56 @@ void THTensor_(resizeNd)(THTensor *self, int nDimension, int64_t *size, int64_t
 
 void THTensor_(set1d)(THTensor *tensor, int64_t x0, real value)
 {
-  THArgCheck(tensor->_dim() == 1, 1, "tensor must have one dimension");
+  THArgCheck(THTensor_nDimensionLegacyAll(tensor) == 1, 1, "tensor must have one dimension");
   THArgCheck( (x0 >= 0) && (x0 < tensor->size(0)), 2, "out of range");
   THStorage_(set)(THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0), value);
 }
 
 real THTensor_(get1d)(const THTensor *tensor, int64_t x0)
 {
-  THArgCheck(tensor->_dim() == 1, 1, "tensor must have one dimension");
+  THArgCheck(THTensor_nDimensionLegacyAll(tensor) == 1, 1, "tensor must have one dimension");
   THArgCheck( (x0 >= 0) && (x0 < tensor->size(0)), 2, "out of range");
   return THStorage_(get)(THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0));
 }
 
 void THTensor_(set2d)(THTensor *tensor, int64_t x0, int64_t x1, real value)
 {
-  THArgCheck(tensor->_dim() == 2, 1, "tensor must have two dimensions");
+  THArgCheck(THTensor_nDimensionLegacyAll(tensor) == 2, 1, "tensor must have two dimensions");
   THArgCheck((x0 >= 0) && (x0 < tensor->size(0)) && (x1 >= 0) && (x1 < tensor->size(1)), 2, "out of range");
   THStorage_(set)(THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1), value);
 }
 
 real THTensor_(get2d)(const THTensor *tensor, int64_t x0, int64_t x1)
 {
-  THArgCheck(tensor->_dim() == 2, 1, "tensor must have two dimensions");
+  THArgCheck(THTensor_nDimensionLegacyAll(tensor) == 2, 1, "tensor must have two dimensions");
   THArgCheck((x0 >= 0) && (x0 < tensor->size(0)) && (x1 >= 0) && (x1 < tensor->size(1)), 2, "out of range");
   return THStorage_(get)(THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1));
 }
 
 void THTensor_(set3d)(THTensor *tensor, int64_t x0, int64_t x1, int64_t x2, real value)
 {
-  THArgCheck(tensor->_dim() == 3, 1, "tensor must have three dimensions");
+  THArgCheck(THTensor_nDimensionLegacyAll(tensor) == 3, 1, "tensor must have three dimensions");
   THArgCheck( (x0 >= 0) && (x0 < tensor->size(0)) && (x1 >= 0) && (x1 < tensor->size(1)) && (x2 >= 0) && (x2 < tensor->size(2)), 2, "out of range");
   THStorage_(set)(THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1)+x2*tensor->stride(2), value);
 }
 
 real THTensor_(get3d)(const THTensor *tensor, int64_t x0, int64_t x1, int64_t x2)
 {
-  THArgCheck(tensor->_dim() == 3, 1, "tensor must have three dimensions");
+  THArgCheck(THTensor_nDimensionLegacyAll(tensor) == 3, 1, "tensor must have three dimensions");
   THArgCheck( (x0 >= 0) && (x0 < tensor->size(0)) && (x1 >= 0) && (x1 < tensor->size(1)) && (x2 >= 0) && (x2 < tensor->size(2)), 2, "out of range");
   return THStorage_(get)(THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1)+x2*tensor->stride(2));
 }
 
 void THTensor_(set4d)(THTensor *tensor, int64_t x0, int64_t x1, int64_t x2, int64_t x3, real value)
 {
-  THArgCheck(tensor->_dim() == 4, 1, "tensor must have four dimensions");
+  THArgCheck(THTensor_nDimensionLegacyAll(tensor) == 4, 1, "tensor must have four dimensions");
   THArgCheck((x0 >= 0) && (x0 < tensor->size(0)) && (x1 >= 0) && (x1 < tensor->size(1)) && (x2 >= 0) && (x2 < tensor->size(2)) && (x3 >= 0) && (x3 < tensor->size(3)), 2, "out of range");
   THStorage_(set)(THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1)+x2*tensor->stride(2)+x3*tensor->stride(3), value);
 }
 
 real THTensor_(get4d)(const THTensor *tensor, int64_t x0, int64_t x1, int64_t x2, int64_t x3)
 {
-  THArgCheck(tensor->_dim() == 4, 1, "tensor must have four dimensions");
+  THArgCheck(THTensor_nDimensionLegacyAll(tensor) == 4, 1, "tensor must have four dimensions");
   THArgCheck((x0 >= 0) && (x0 < tensor->size(0)) && (x1 >= 0) && (x1 < tensor->size(1)) && (x2 >= 0) && (x2 < tensor->size(2)) && (x3 >= 0) && (x3 < tensor->size(3)), 2, "out of range");
   return THStorage_(get)(THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1)+x2*tensor->stride(2)+x3*tensor->stride(3));
 }
@@ -853,10 +853,10 @@ THDescBuff THTensor_(desc)(const THTensor *tensor) {
   n += snprintf(str, L-n, "torch." _stringify(x) "Tensor of size ");
 #undef _stringify
   int i;
-  for(i = 0; i < tensor->_dim(); i++) {
+  for(i = 0; i < THTensor_nDimensionLegacyAll(tensor); i++) {
     if(n >= L) break;
     n += snprintf(str+n, L-n, "%" PRId64, tensor->size(i));
-    if(i < tensor->_dim()-1) {
+    if(i < THTensor_nDimensionLegacyAll(tensor)-1) {
       n += snprintf(str+n, L-n, "x");
     }
   }
@@ -24,9 +24,9 @@ typedef struct THTensor THTensor;
 TH_API THStorage* THTensor_(storage)(const THTensor *self);
 TH_API ptrdiff_t THTensor_(storageOffset)(const THTensor *self);
 
-// See [NOTE: _dim() vs dim()]; _nDimension corresponds to _dim(), nDimension corresponds to dim().
-TH_API int THTensor_(nDimension)(const THTensor *self);
-TH_API int THTensor_(_nDimension)(const THTensor *self);
+// See [NOTE: nDimension vs nDimensionLegacyNoScalars vs nDimensionLegacyAll]
+TH_API int THTensor_(nDimensionLegacyNoScalars)(const THTensor *self);
+TH_API int THTensor_(nDimensionLegacyAll)(const THTensor *self);
 TH_API int64_t THTensor_(size)(const THTensor *self, int dim);
 TH_API int64_t THTensor_(stride)(const THTensor *self, int dim);
 TH_API THLongStorage *THTensor_(newSizeOf)(THTensor *self);
@@ -119,7 +119,7 @@
 #define TH_TENSOR_DIM_APPLY3_SIZE_SCATTER(TENSOR1, TENSOR2, TENSOR3, DIMENSION) \
 { \
   int shape_check_flag = 0; \
-  for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->_dim(); TH_TENSOR_DIM_APPLY_i++) \
+  for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < THTensor_nDimensionLegacyAll(TENSOR1); TH_TENSOR_DIM_APPLY_i++) \
   { \
     int64_t TENSOR3##_dim_size = TENSOR3->size(TH_TENSOR_DIM_APPLY_i); \
     if (TH_TENSOR_DIM_APPLY_i != DIMENSION) { \
@@ -1367,7 +1367,7 @@ void THTensor_(conv2Dmap)(THTensor *r_, real beta, real alpha, THTensor *t_, THT
 
   AT_CHECK(!t_->is_empty() && t_->dim() == 3, "input: non-empty 3D Tensor expected, got size: ", t_->sizes());
   AT_CHECK(!k_->is_empty() && k_->dim() == 3, "kernel: non-empty 3D Tensor expected, got size: ", k_->sizes());
-  THArgCheck(map->_dim() == 2 , 4, "map: 2D Tensor expected");
+  THArgCheck(THTensor_nDimensionLegacyAll(map) == 2 , 4, "map: 2D Tensor expected");
   THArgCheck(srow >= 1, 6, "Stride should be a positive integer");
   THArgCheck(scol >= 1, 7, "Stride should be a positive integer");
 
@@ -1880,7 +1880,7 @@ void THTensor_(conv3Dmap)(THTensor *r_, real beta, real alpha, THTensor *t_, THT
 
   AT_CHECK(!t_->is_empty() && t_->dim() == 4, "input: non-empty 4D Tensor expected, got size: ", t_->sizes());
   AT_CHECK(!k_->is_empty() && k_->dim() == 4, "kernel: non-empty 4D Tensor expected, got size: ", k_->sizes());
-  THArgCheck(map->_dim() == 2 , 4, "map: 2D Tensor expected");
+  THArgCheck(THTensor_nDimensionLegacyAll(map) == 2 , 4, "map: 2D Tensor expected");
   THArgCheck(srow >= 1, 6, "Stride should be a positive integer");
   THArgCheck(scol >= 1, 7, "Stride should be a positive integer");
   THArgCheck(*vf == 'V' || *vf == 'F', 8, "type of convolution can 'V' or 'F'");
@@ -17,7 +17,7 @@ int THTensor_(copyTransposeValid)(THTensor *tensor, THTensor *src) {
   const int MIN_SZ = 60 * 60;
   return THTensor_(isContiguous)(tensor) &&
          !src->is_empty() &&
-         THTensor_(nDimension)(src) == 2 &&
+         THTensor_(nDimensionLegacyNoScalars)(src) == 2 &&
          THTensor_(stride)(src, 0) == 1 &&
          THTensor_(stride)(src, 1) == THTensor_(size)(src, 0) &&
          THTensor_(nElement)(tensor) >= MIN_SZ;
@@ -150,9 +150,9 @@ void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTens
   real *tensor_data, *src_data;
 
 #ifndef USE_TH_SIZE_ZERO_DIM
-  THArgCheck(index->_dim() <= 1, 3, "Index is supposed to be an empty tensor or a vector");
-  THArgCheck(dim < src->_dim(), 4, "Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE);
-  THArgCheck(src->_dim() > 0, 2, "Source tensor is empty");
+  THArgCheck(THTensor_nDimensionLegacyAll(index) <= 1, 3, "Index is supposed to be an empty tensor or a vector");
+  THArgCheck(dim < THTensor_nDimensionLegacyAll(src), 4, "Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE);
+  THArgCheck(THTensor_nDimensionLegacyAll(src) > 0, 2, "Source tensor is empty");
 #else
   THArgCheck(index->dim() == 1, 3, "Index is supposed to be 1-dimensional");
   THArgCheck(dim < src->dim(), 4, "Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE);
@@ -261,7 +261,7 @@ void THTensor_(indexCopy)(THTensor *tensor, int dim, THLongTensor *index, THTens
 static ptrdiff_t THTensor_(dataOffset)(THTensor* tensor, ptrdiff_t linearIndex) {
   auto size = tensor->sizes();
   auto stride = tensor->strides();
-  int nDim = tensor->_dim();
+  int nDim = THTensor_nDimensionLegacyAll(tensor);
   ptrdiff_t dataOffset = 0;
   for (int i = nDim - 1; i >= 0; i--) {
     dataOffset += (linearIndex % size[i]) * stride[i];
@@ -355,8 +355,8 @@ void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTenso
 
   numel = THLongTensor_nElement(index);
 #ifndef USE_TH_SIZE_ZERO_DIM
-  THArgCheck(index->_dim() == 1, 3, "Index is supposed to be a vector");
-  THArgCheck(dim < src->_dim(), 4,"Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE);
+  THArgCheck(THTensor_nDimensionLegacyAll(index) == 1, 3, "Index is supposed to be a vector");
+  THArgCheck(dim < THTensor_nDimensionLegacyAll(src), 4,"Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE);
 #else
   THArgCheck(index->dim() == 1, 3, "Index is supposed to be a vector");
   THArgCheck(dim < src->dim(), 4,"Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE);
@@ -401,8 +401,8 @@ void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, real v
 
   numel = THLongTensor_nElement(index);
 #ifndef USE_TH_SIZE_ZERO_DIM
-  THArgCheck(index->_dim() == 1, 3, "Index is supposed to be a vector");
-  THArgCheck(dim < tensor->_dim(), 4,"Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE);
+  THArgCheck(THTensor_nDimensionLegacyAll(index) == 1, 3, "Index is supposed to be a vector");
+  THArgCheck(dim < THTensor_nDimensionLegacyAll(tensor), 4,"Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE);
 #else
   THArgCheck(index->dim() == 1, 3, "Index is supposed to be a vector");
   THArgCheck(dim < tensor->dim(), 4,"Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE);
@@ -432,11 +432,11 @@ void THTensor_(gather)(THTensor *tensor, THTensor *src, int dim, THLongTensor *i
 {
   int64_t elems_per_row, i, idx;
 
-  THArgCheck(THLongTensor_nDimension(index) == THTensor_(nDimension)(src), 4,
+  THArgCheck(THLongTensor_nDimensionLegacyNoScalars(index) == THTensor_(nDimensionLegacyNoScalars)(src), 4,
              "Index tensor must have same dimensions as input tensor");
-  THArgCheck(dim >= 0 && dim < THTensor_(nDimension)(tensor), 3,
+  THArgCheck(dim >= 0 && dim < THTensor_(nDimensionLegacyNoScalars)(tensor), 3,
              "Index dimension is out of bounds");
-  THArgCheck(THTensor_(nDimension)(src) == THTensor_(nDimension)(tensor), 2,
+  THArgCheck(THTensor_(nDimensionLegacyNoScalars)(src) == THTensor_(nDimensionLegacyNoScalars)(tensor), 2,
              "Input tensor must have same dimensions as output tensor");
 
   elems_per_row = THLongTensor_size(index, dim);
@@ -460,16 +460,16 @@ void THTensor_(scatter)(THTensor *tensor, int dim, THLongTensor *index, THTensor
   int64_t elems_per_row, i, idx;
 
 #ifndef USE_TH_SIZE_ZERO_DIM
-  THArgCheck(dim < THTensor_(_nDimension)(tensor), 2, "Index dimension is out of bounds");
-  THArgCheck(THLongTensor__nDimension(index) == THTensor_(_nDimension)(tensor), 3,
+  THArgCheck(dim < THTensor_(nDimensionLegacyAll)(tensor), 2, "Index dimension is out of bounds");
+  THArgCheck(THLongTensor_nDimensionLegacyAll(index) == THTensor_(nDimensionLegacyAll)(tensor), 3,
              "Index tensor must have same dimensions as output tensor");
-  THArgCheck(THTensor_(_nDimension)(src) == THTensor_(_nDimension)(tensor), 4,
+  THArgCheck(THTensor_(nDimensionLegacyAll)(src) == THTensor_(nDimensionLegacyAll)(tensor), 4,
              "Input tensor must have same dimensions as output tensor");
 #else
-  THArgCheck(dim < THTensor_(nDimension)(tensor), 2, "Index dimension is out of bounds");
-  THArgCheck(THLongTensor_nDimension(index) == THTensor_(nDimension)(tensor), 3,
+  THArgCheck(dim < THTensor_(nDimensionLegacyNoScalars)(tensor), 2, "Index dimension is out of bounds");
+  THArgCheck(THLongTensor_nDimensionLegacyNoScalars(index) == THTensor_(nDimensionLegacyNoScalars)(tensor), 3,
              "Index tensor must have same dimensions as output tensor");
-  THArgCheck(THTensor_(nDimension)(src) == THTensor_(nDimension)(tensor), 4,
+  THArgCheck(THTensor_(nDimensionLegacyNoScalars)(src) == THTensor_(nDimensionLegacyNoScalars)(tensor), 4,
              "Input tensor must have same dimensions as output tensor");
 #endif
 
@@ -493,10 +493,10 @@ void THTensor_(scatterAdd)(THTensor *tensor, int dim, THLongTensor *index, THTen
 {
   int64_t elems_per_row, i, idx;
 
-  THArgCheck(dim < THTensor_(nDimension)(tensor), 2, "Index dimension is out of bounds");
-  THArgCheck(THLongTensor_nDimension(index) == THTensor_(nDimension)(tensor), 3,
+  THArgCheck(dim < THTensor_(nDimensionLegacyNoScalars)(tensor), 2, "Index dimension is out of bounds");
+  THArgCheck(THLongTensor_nDimensionLegacyNoScalars(index) == THTensor_(nDimensionLegacyNoScalars)(tensor), 3,
             "Index tensor must have same dimensions as output tensor");
-  THArgCheck(THTensor_(nDimension)(src) == THTensor_(nDimension)(tensor), 4,
+  THArgCheck(THTensor_(nDimensionLegacyNoScalars)(src) == THTensor_(nDimensionLegacyNoScalars)(tensor), 4,
             "Input tensor must have same dimensions as output tensor");
 
   elems_per_row = THLongTensor_size(index, dim);
@@ -519,8 +519,8 @@ void THTensor_(scatterFill)(THTensor *tensor, int dim, THLongTensor *index, real
 {
   int64_t elems_per_row, i, idx;
 
-  THArgCheck(dim < THTensor_(_nDimension)(tensor), 2, "Index dimension is out of bounds");
-  THArgCheck(THLongTensor__nDimension(index) == THTensor_(_nDimension)(tensor), 3,
+  THArgCheck(dim < THTensor_(nDimensionLegacyAll)(tensor), 2, "Index dimension is out of bounds");
+  THArgCheck(THLongTensor_nDimensionLegacyAll(index) == THTensor_(nDimensionLegacyAll)(tensor), 3,
             "Index tensor must have same dimensions as output tensor");
 
   elems_per_row = THLongTensor_size(index, dim);
@@ -558,7 +558,7 @@ real THTensor_(minall)(THTensor *tensor)
   real theMin;
   real value;
 
-  THArgCheck(tensor->_dim() > 0, 1, "tensor must have one dimension");
+  THArgCheck(THTensor_nDimensionLegacyAll(tensor) > 0, 1, "tensor must have one dimension");
   theMin = THTensor_(data)(tensor)[0];
   TH_TENSOR_APPLY(real, tensor,
                   value = *tensor_data;
@@ -576,7 +576,7 @@ real THTensor_(maxall)(THTensor *tensor)
   real theMax;
   real value;
 
-  THArgCheck(tensor->_dim() > 0, 1, "tensor must have one dimension");
+  THArgCheck(THTensor_nDimensionLegacyAll(tensor) > 0, 1, "tensor must have one dimension");
   theMax = THTensor_(data)(tensor)[0];
   TH_TENSOR_APPLY(real, tensor,
                   value = *tensor_data;
@@ -161,16 +161,16 @@ void THTensor_(trtrs)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a,
   int free_b = 0;
   if (a == NULL) a = ra_;
   if (b == NULL) b = rb_;
-  THArgCheck(a->_dim() == 2, 2, "A should have 2 dimensions, but has %d",
-             a->_dim());
-  THArgCheck(b->_dim() == 1 || b->_dim() == 2, 1, "B should have 1 or 2 "
-             "dimensions, but has %d", b->_dim());
+  THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 2, "A should have 2 dimensions, but has %d",
+             THTensor_nDimensionLegacyAll(a));
+  THArgCheck(THTensor_nDimensionLegacyAll(b) == 1 || THTensor_nDimensionLegacyAll(b) == 2, 1, "B should have 1 or 2 "
+             "dimensions, but has %d", THTensor_nDimensionLegacyAll(b));
   THArgCheck(a->size(0) == a->size(1), 2, "A should be square, but is %ldx%ld",
              a->size(0), a->size(1));
   THArgCheck(a->size(0) == b->size(0), 2, "A,B size incompatible - A has %ld "
             "rows, B has %ld", a->size(0), b->size(0));
 
-  if (b->_dim() == 1) {
+  if (THTensor_nDimensionLegacyAll(b) == 1) {
     b = THTensor_(newWithStorage2d)(THTensor_getStoragePtr(b), b->storage_offset(), b->size(0),
                                     b->stride(0), 1, 0);
     free_b = 1;
@@ -220,7 +220,7 @@ void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a)
   THArgCheck(a->size(0) == b->size(0), 2, "A,B size incompatible - A has %ld "
             "rows, B has %ld", a->size(0), b->size(0));
 
-  if (b->_dim() == 1) {
+  if (THTensor_nDimensionLegacyAll(b) == 1) {
     b = THTensor_(newWithStorage2d)(THTensor_getStoragePtr(b), b->storage_offset(), b->size(0),
                                     b->stride(0), 1, 0);
     free_b = 1;
@@ -498,7 +498,7 @@ void THTensor_(gesvd2)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *ra
 void THTensor_(getri)(THTensor *ra_, THTensor *a)
 {
   if (a == NULL) a = ra_;
-  THArgCheck(a->_dim() == 2, 1, "A should be 2 dimensional");
+  THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1, "A should be 2 dimensional");
   THArgCheck(a->size(0) == a->size(1), 1, "A should be square");
 
   int m, n, lda, info, lwork;
@@ -541,7 +541,7 @@ void THTensor_(getri)(THTensor *ra_, THTensor *a)
 
 void THTensor_(clearUpLoTriangle)(THTensor *a, const char *uplo)
 {
-  THArgCheck(a->_dim() == 2, 1, "A should be 2 dimensional");
+  THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1, "A should be 2 dimensional");
   THArgCheck(a->size(0) == a->size(1), 1, "A should be square");
 
   int n = a->size(0);
@@ -574,7 +574,7 @@ void THTensor_(clearUpLoTriangle)(THTensor *a, const char *uplo)
 
 void THTensor_(copyUpLoTriangle)(THTensor *a, const char *uplo)
 {
-  THArgCheck(a->_dim() == 2, 1, "A should be 2 dimensional");
+  THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1, "A should be 2 dimensional");
   THArgCheck(a->size(0) == a->size(1), 1, "A should be square");
 
   int n = a->size(0);
@@ -608,7 +608,7 @@ void THTensor_(copyUpLoTriangle)(THTensor *a, const char *uplo)
 void THTensor_(potrf)(THTensor *ra_, THTensor *a, const char *uplo)
 {
   if (a == NULL) a = ra_;
-  THArgCheck(a->_dim() == 2, 1, "A should be 2 dimensional");
+  THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1, "A should be 2 dimensional");
   THArgCheck(a->size(0) == a->size(1), 1, "A should be square");
 
   int n, lda, info;
@@ -634,16 +634,16 @@ void THTensor_(potrs)(THTensor *rb_, THTensor *b, THTensor *a, const char *uplo)
   int free_b = 0;
   if (b == NULL) b = rb_;
 
-  THArgCheck(a->_dim() == 2, 2, "A should have 2 dimensions, but has %d",
-             a->_dim());
-  THArgCheck(b->_dim() == 1 || b->_dim() == 2, 1, "B should have 1 or 2 "
-             "dimensions, but has %d", b->_dim());
+  THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 2, "A should have 2 dimensions, but has %d",
+             THTensor_nDimensionLegacyAll(a));
+  THArgCheck(THTensor_nDimensionLegacyAll(b) == 1 || THTensor_nDimensionLegacyAll(b) == 2, 1, "B should have 1 or 2 "
+             "dimensions, but has %d", THTensor_nDimensionLegacyAll(b));
   THArgCheck(a->size(0) == a->size(1), 2, "A should be square, but is %ldx%ld",
              a->size(0), a->size(1));
   THArgCheck(a->size(0) == b->size(0), 2, "A,B size incompatible - A has %ld "
             "rows, B has %ld", a->size(0), b->size(0));
 
-  if (b->_dim() == 1) {
+  if (THTensor_nDimensionLegacyAll(b) == 1) {
     b = THTensor_(newWithStorage2d)(THTensor_getStoragePtr(b), b->storage_offset(), b->size(0),
                                     b->stride(0), 1, 0);
     free_b = 1;
@@ -680,7 +680,7 @@ void THTensor_(potrs)(THTensor *rb_, THTensor *b, THTensor *a, const char *uplo)
 void THTensor_(potri)(THTensor *ra_, THTensor *a, const char *uplo)
 {
   if (a == NULL) a = ra_;
-  THArgCheck(a->_dim() == 2, 1, "A should be 2 dimensional");
+  THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1, "A should be 2 dimensional");
   THArgCheck(a->size(0) == a->size(1), 1, "A should be square");
 
   int n, lda, info;
@@ -718,7 +718,7 @@ void THTensor_(potri)(THTensor *ra_, THTensor *a, const char *uplo)
   The algorithm terminates when the pivot <= tol.
 */
 void THTensor_(pstrf)(THTensor *ra_, THIntTensor *rpiv_, THTensor *a, const char *uplo, real tol) {
-  THArgCheck(a->_dim() == 2, 1, "A should be 2 dimensional");
+  THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1, "A should be 2 dimensional");
   THArgCheck(a->size(0) == a->size(1), 1, "A should be square");
 
   int n = a->size(0);
@@ -861,7 +861,7 @@ void THTensor_(geqrf)(THTensor *ra_, THTensor *rtau_, THTensor *a)
 void THTensor_(orgqr)(THTensor *ra_, THTensor *a, THTensor *tau)
 {
   if (a == NULL) a = ra_;
-  THArgCheck(a->_dim() == 2, 1, "A should be 2 dimensional");
+  THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1, "A should be 2 dimensional");
 
   THTensor *ra__ = NULL;
   ra__ = THTensor_(cloneColumnMajor)(ra_, a);
@@ -914,7 +914,7 @@ void THTensor_(orgqr)(THTensor *ra_, THTensor *a, THTensor *tau)
 void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, const char *side, const char *trans)
 {
   if (a == NULL) a = ra_;
-  THArgCheck(a->_dim() == 2, 1, "A should be 2 dimensional");
+  THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1, "A should be 2 dimensional");
 
   THTensor *ra__ = NULL;
   ra__ = THTensor_(cloneColumnMajor)(ra_, c);
@@ -958,7 +958,7 @@ void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, co
 
 void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinfo_, int pivot, THTensor *a)
 {
-  AT_CHECK(THTensor_(nDimension)(a) == 3, "expected 3D tensor, got size: ", a->sizes());
+  AT_CHECK(THTensor_(nDimensionLegacyNoScalars)(a) == 3, "expected 3D tensor, got size: ", a->sizes());
   if (!pivot) {
     THError("btrifact without pivoting is not implemented on the CPU");
   }
@@ -1033,10 +1033,10 @@ void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinf
 
 void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor *pivots)
 {
-  AT_CHECK(!atf->is_empty() && THTensor_(nDimension)(atf) == 3, "expected non-empty 3D tensor, got size: ",
+  AT_CHECK(!atf->is_empty() && THTensor_(nDimensionLegacyNoScalars)(atf) == 3, "expected non-empty 3D tensor, got size: ",
            atf->sizes());
-  AT_CHECK(!b->is_empty() && (THTensor_(nDimension)(b) == 3 ||
-           THTensor_(nDimension)(b) == 2), "expected non-empty 2D or 3D tensor, got size: ", b->sizes());
+  AT_CHECK(!b->is_empty() && (THTensor_(nDimensionLegacyNoScalars)(b) == 3 ||
+           THTensor_(nDimensionLegacyNoScalars)(b) == 2), "expected non-empty 2D or 3D tensor, got size: ", b->sizes());
   THArgCheck(THTensor_(size)(atf, 0) ==
              THTensor_(size)(b, 0), 3, "number of batches must be equal");
   THArgCheck(THTensor_(size)(atf, 1) ==
@@ -1051,7 +1051,7 @@ void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor
 
   int64_t num_batches = atf->size(0);
   int64_t n = atf->size(1);
-  int nrhs = rb_->_dim() > 2 ? rb_->size(2) : 1;
+  int nrhs = THTensor_nDimensionLegacyAll(rb_) > 2 ? rb_->size(2) : 1;
 
   int lda, ldb;
   THTensor *atf_;
@@ -1077,7 +1077,7 @@ void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor
   // correct ordering of B
   if (rb_->stride(1) == 1) {
     // column ordered
-    if (rb_->_dim() == 2 || rb_->size(2) == 1) {
+    if (THTensor_nDimensionLegacyAll(rb_) == 2 || rb_->size(2) == 1) {
       ldb = n;
     } else {
       ldb = rb_->stride(2);
@@ -1085,7 +1085,7 @@ void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor
     rb__ = rb_;
   } else {
     // make column ordered
-    if (rb_->_dim() > 2) {
+    if (THTensor_nDimensionLegacyAll(rb_) > 2) {
       THTensor *transp_r_ = THTensor_(newTranspose)(rb_, 1, 2);
       rb__ = THTensor_(newClone)(transp_r_);
       THTensor_(free)(transp_r_);
@@ -1120,8 +1120,8 @@ void THTensor_(addbmm)(THTensor *result, real beta, THTensor *t, real alpha, THT
 {
   int64_t batch;
 
-  THArgCheck(THTensor_(nDimension)(batch1) == 3, 1, "expected 3D tensor");
-  THArgCheck(THTensor_(nDimension)(batch2) == 3, 2, "expected 3D tensor");
+  THArgCheck(THTensor_(nDimensionLegacyNoScalars)(batch1) == 3, 1, "expected 3D tensor");
+  THArgCheck(THTensor_(nDimensionLegacyNoScalars)(batch2) == 3, 2, "expected 3D tensor");
   THArgCheck(THTensor_(size)(batch1, 0) == THTensor_(size)(batch2, 0), 2,
              "equal number of batches expected, got %d, %d",
              THTensor_(size)(batch1, 0), THTensor_(size)(batch2, 0));
@@ -8,8 +8,8 @@ void THTensor_(baddbmm)(THTensor *result, real beta, THTensor *t, real alpha, TH
 {
   int64_t batch;
 
-  THArgCheck(THTensor_(nDimension)(batch1) == 3, 1, "expected 3D tensor, got %dD", THTensor_(nDimension)(batch1));
-  THArgCheck(THTensor_(nDimension)(batch2) == 3, 2, "expected 3D tensor, got %dD", THTensor_(nDimension)(batch2));
+  THArgCheck(THTensor_(nDimensionLegacyNoScalars)(batch1) == 3, 1, "expected 3D tensor, got %dD", THTensor_(nDimensionLegacyNoScalars)(batch1));
+  THArgCheck(THTensor_(nDimensionLegacyNoScalars)(batch2) == 3, 2, "expected 3D tensor, got %dD", THTensor_(nDimensionLegacyNoScalars)(batch2));
   THArgCheck(THTensor_(size)(batch1, 0) == THTensor_(size)(batch2, 0), 2,
              "equal number of batches expected, got %d, %d",
             THTensor_(size)(batch1, 0), THTensor_(size)(batch2, 0));
@@ -65,8 +65,8 @@ ptrdiff_t THTensor_(numel)(THTensor *t)
 void THTensor_(preserveReduceDimSemantics)(
     THTensor *r_, int in_dims, int reduce_dimension, int keepdim) {
   if (r_ && !keepdim &&
-      THTensor_(_nDimension)(r_) == in_dims - 1 &&
-      THTensor_(_nDimension)(r_) != 0) {
+      THTensor_(nDimensionLegacyAll)(r_) == in_dims - 1 &&
+      THTensor_(nDimensionLegacyAll)(r_) != 0) {
     THTensor_(unsqueeze1d)(r_, r_, reduce_dimension);
   }
 }
@@ -75,10 +75,10 @@ void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
 {
   THLongStorage *dim;
 
-  THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "dimension %d out of range",
+  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 2, "dimension %d out of range",
       dimension + TH_INDEX_BASE);
 
-  int in_dims = THTensor_(_nDimension)(t);
+  int in_dims = THTensor_(nDimensionLegacyAll)(t);
   THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
   THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
   dim = THTensor_(newSizeOf)(t);
@@ -112,7 +112,7 @@ void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
                          *indices__data = theIndex;
                          *values__data = theMax;);
   } else {
-    if (THTensor_(_nDimension)(t) > 1) {
+    if (THTensor_(nDimensionLegacyAll)(t) > 1) {
       THTensor *t0 = THTensor_(newSelect)(t, dimension, 0);
       THTensor_(copy)(values_, t0);
       THTensor_(free)(t0);
@@ -159,10 +159,10 @@ void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
 {
   THLongStorage *dim;
 
-  THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "dimension %d out of range",
+  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 2, "dimension %d out of range",
      dimension + TH_INDEX_BASE);
 
-  int in_dims = THTensor_(_nDimension)(t);
+  int in_dims = THTensor_(nDimensionLegacyAll)(t);
   THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
   THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
   dim = THTensor_(newSizeOf)(t);
@@ -196,7 +196,7 @@ void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
                          *indices__data = theIndex;
                          *values__data = theMax;);
   } else {
-    if (THTensor_(_nDimension)(t) > 1) {
+    if (THTensor_(nDimensionLegacyAll)(t) > 1) {
       THTensor *t0 = THTensor_(newSelect)(t, dimension, 0);
       THTensor_(copy)(values_, t0);
       THTensor_(free)(t0);
@@ -243,10 +243,10 @@ void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension, int keepdim)
 {
   THLongStorage *dim;
 
-  THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "dimension %d out of range",
+  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 2, "dimension %d out of range",
      dimension + TH_INDEX_BASE);
 
-  THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim);
+  THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimensionLegacyAll)(t), dimension, keepdim);
   dim = THTensor_(newSizeOf)(t);
   THLongStorage_set(dim, dimension, 1);
   THTensor_(resize)(r_, dim, NULL);
@@ -264,7 +264,7 @@ void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension, int keepdim)
   if(r_Contig && (tp != rp)){
     ptrdiff_t iter = 0;
     ptrdiff_t r_Size = THTensor_(nElement)(r_);
-    int r_Dim = r_->_dim();
+    int r_Dim = THTensor_nDimensionLegacyAll(r_);
     #pragma omp parallel for if ( r_Size > HYPER_TH_OMP_OVERHEAD_THRESHOLD)
     for (iter = 0; iter < r_Size; iter++) {
       int j;
@@ -323,10 +323,10 @@ void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension, int keepdim)
 {
   THLongStorage *dim;
 
-  THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "dimension %d out of range",
+  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 2, "dimension %d out of range",
      dimension + TH_INDEX_BASE);
 
-  THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim);
+  THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimensionLegacyAll)(t), dimension, keepdim);
   dim = THTensor_(newSizeOf)(t);
   THLongStorage_set(dim, dimension, 1);
   THTensor_(resize)(r_, dim, NULL);
@@ -344,7 +344,7 @@ void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension, int keepdim)
   if(r_Contig && (tp != rp)){
     ptrdiff_t iter = 0;
     ptrdiff_t r_Size = THTensor_(nElement)(r_);
-    int r_Dim = r_->_dim();
+    int r_Dim = THTensor_nDimensionLegacyAll(r_);
     #pragma omp parallel for if ( r_Size > HYPER_TH_OMP_OVERHEAD_THRESHOLD)
     for (iter = 0; iter < r_Size; iter++) {
       int j;
@@ -401,7 +401,7 @@ void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension, int keepdim)
 
 void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension)
 {
-  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
+  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyNoScalars)(t), 2, "dimension %d out of range",
      dimension + TH_INDEX_BASE);
 
   THTensor_(resizeAs)(r_, t);
@@ -418,7 +418,7 @@ void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension)
 
 void THTensor_(cumprod)(THTensor *r_, THTensor *t, int dimension)
 {
-  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
+  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyNoScalars)(t), 2, "dimension %d out of range",
      dimension + TH_INDEX_BASE);
 
   THTensor_(resizeAs)(r_, t);
@@ -458,7 +458,7 @@ accreal THTensor_(trace)(THTensor *t)
   int64_t i = 0;
   int64_t t_stride_0, t_stride_1, t_diag_size;
 
-  THArgCheck(THTensor_(_nDimension)(t) == 2, 1, "expected a matrix");
+  THArgCheck(THTensor_(nDimensionLegacyAll)(t) == 2, 1, "expected a matrix");
 
   t_stride_0 = THTensor_(stride)(t, 0);
   t_stride_1 = THTensor_(stride)(t, 1);
@@ -476,11 +476,11 @@ void THTensor_(cross)(THTensor *r_, THTensor *a, THTensor *b, int dimension)
 {
   int i;
 
-  if(THTensor_(nDimension)(a) != THTensor_(nDimension)(b))
+  if(THTensor_(nDimensionLegacyNoScalars)(a) != THTensor_(nDimensionLegacyNoScalars)(b))
     THError("inconsistent tensor dimension %dD, %dD",
-            THTensor_(nDimension)(a), THTensor_(nDimension)(b));
+            THTensor_(nDimensionLegacyNoScalars)(a), THTensor_(nDimensionLegacyNoScalars)(b));
 
-  for(i = 0; i < THTensor_(nDimension)(a); i++)
+  for(i = 0; i < THTensor_(nDimensionLegacyNoScalars)(a); i++)
   {
     if(THTensor_(size)(a, i) != THTensor_(size)(b, i)) {
       THDescBuff ba = THTensor_(sizeDesc)(a);
|
||||
|
||||
if(dimension < 0)
|
||||
{
|
||||
for(i = 0; i < THTensor_(nDimension)(a); i++)
|
||||
for(i = 0; i < THTensor_(nDimensionLegacyNoScalars)(a); i++)
|
||||
{
|
||||
if(THTensor_(size)(a, i) == 3)
|
||||
{
|
||||
@ -505,7 +505,7 @@ void THTensor_(cross)(THTensor *r_, THTensor *a, THTensor *b, int dimension)
|
||||
}
|
||||
}
|
||||
|
||||
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(a), 3, "dimension %d out of range",
|
||||
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyNoScalars)(a), 3, "dimension %d out of range",
|
||||
dimension + TH_INDEX_BASE);
|
||||
THArgCheck(THTensor_(size)(a, dimension) == 3, 3, "dimension %d does not have size 3",
|
||||
dimension + TH_INDEX_BASE);
|
||||
@@ -560,9 +560,9 @@ void THTensor_(diag)(THTensor *r_, THTensor *t, int k)
 #ifndef USE_TH_SIZE_ZERO_DIM
   AT_ASSERT(!t->is_empty())
 #endif
-  THArgCheck(THTensor_(nDimension)(t) == 1 || THTensor_(nDimension)(t) == 2, 1, "matrix or a vector expected");
+  THArgCheck(THTensor_(nDimensionLegacyNoScalars)(t) == 1 || THTensor_(nDimensionLegacyNoScalars)(t) == 2, 1, "matrix or a vector expected");
 
-  if(THTensor_(nDimension)(t) == 1)
+  if(THTensor_(nDimensionLegacyNoScalars)(t) == 1)
   {
     real *t_data = THTensor_(data)(t);
     int64_t t_stride_0 = THTensor_(stride)(t, 0);
@@ -900,7 +900,7 @@ static void THTensor_(quicksortdescend)(real *arr, int64_t *idx, int64_t element
 
 void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimension, int descendingOrder)
 {
-  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "invalid dimension %d",
+  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyNoScalars)(t), 2, "invalid dimension %d",
      dimension + TH_INDEX_BASE);
 
   THTensor_(resizeAs)(rt_, t);
@@ -1031,7 +1031,7 @@ static void THTensor_(quickselect)(real *arr, int64_t *idx, int64_t k, int64_t e
 
 real THTensor_(medianall)(THTensor *tensor)
 {
-  THArgCheck(tensor->_dim() > 0, 1, "tensor must have one dimension");
+  THArgCheck(THTensor_nDimensionLegacyAll(tensor) > 0, 1, "tensor must have one dimension");
 
   real theMedian;
   ptrdiff_t numel;
@@ -1063,9 +1063,9 @@ void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
   int64_t *tempi__data;
   int64_t t_size_dim;
 
-  THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 3, "dimension out of range");
+  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 3, "dimension out of range");
 
-  int in_dims = THTensor_(_nDimension)(t);
+  int in_dims = THTensor_(nDimensionLegacyAll)(t);
   THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
   THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
   dim = THTensor_(newSizeOf)(t);
@@ -1131,10 +1131,10 @@ void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t,
   int64_t *tempi__data;
   int64_t t_size_dim;
 
-  THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 3, "dimension out of range");
+  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 3, "dimension out of range");
   THArgCheck(k > 0 && k <= t->size(dimension), 2, "selected index out of range");
 
-  int in_dims = THTensor_(_nDimension)(t);
+  int in_dims = THTensor_(nDimensionLegacyAll)(t);
   THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
   THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
   dim = THTensor_(newSizeOf)(t);
@@ -1176,7 +1176,7 @@ void THTensor_(median)(THTensor *values_, THLongTensor *indices_, THTensor *t, i
 {
   int64_t t_size_dim, k;
 
-  THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 3, "dimension out of range");
+  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 3, "dimension out of range");
 
   t_size_dim = THTensor_(size)(t, dimension);
   k = (t_size_dim-1) >> 1; /* take middle or one-before-middle element */
@@ -1187,9 +1187,9 @@ void THTensor_(median)(THTensor *values_, THLongTensor *indices_, THTensor *t, i
 void THTensor_(topk)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int64_t k, int dim, int dir, int sorted)
 {
 #ifndef USE_TH_SIZE_ZERO_DIM
-  int numDims = THTensor_(_nDimension)(t);
+  int numDims = THTensor_(nDimensionLegacyAll)(t);
 #else
-  int numDims = THTensor_(nDimension)(t);
+  int numDims = THTensor_(nDimensionLegacyNoScalars)(t);
 #endif
   THArgCheck(dim >= 0 && dim < numDims, 3, "dim not in range");
 
@@ -1267,7 +1267,7 @@ void THTensor_(tril)(THTensor *r_, THTensor *t, int64_t k)
   real *t_data, *r__data;
   int64_t r, c;
 
-  THArgCheck(THTensor_(_nDimension)(t) == 2, 1, "expected a matrix");
+  THArgCheck(THTensor_(nDimensionLegacyAll)(t) == 2, 1, "expected a matrix");
 
   THTensor_(resizeAs)(r_, t);
 
@@ -1298,7 +1298,7 @@ void THTensor_(triu)(THTensor *r_, THTensor *t, int64_t k)
   real *t_data, *r__data;
   int64_t r, c;
 
-  THArgCheck(THTensor_(_nDimension)(t) == 2, 1, "expected a matrix");
+  THArgCheck(THTensor_(nDimensionLegacyAll)(t) == 2, 1, "expected a matrix");
 
   THTensor_(resizeAs)(r_, t);
 
@@ -1659,10 +1659,10 @@ void THTensor_(logicalAnd)(THTensor *r_, THTensor *t, int dimension, int keepdim
 {
   THLongStorage *dim;
 
-  THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "dimension %d out of range",
+  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 2, "dimension %d out of range",
      dimension + TH_INDEX_BASE);
 
-  THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim);
+  THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimensionLegacyAll)(t), dimension, keepdim);
   dim = THTensor_(newSizeOf)(t);
   THLongStorage_set(dim, dimension, 1);
   THTensor_(resize)(r_, dim, NULL);
@@ -1680,7 +1680,7 @@ void THTensor_(logicalAnd)(THTensor *r_, THTensor *t, int dimension, int keepdim
   if(r_Contig && (tp != rp)){
     ptrdiff_t iter = 0;
     ptrdiff_t r_Size = THTensor_(nElement)(r_);
-    int r_Dim = r_->_dim();
+    int r_Dim = THTensor_nDimensionLegacyAll(r_);
     #pragma omp parallel for if ( r_Size > TH_OMP_OVERHEAD_THRESHOLD)
     for (iter = 0; iter < r_Size; iter++) {
       int j;
@@ -1739,10 +1739,10 @@ void THTensor_(logicalAny)(THTensor *r_, THTensor *t, int dimension, int keepdim
 {
   THLongStorage *dim;
 
-  THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "dimension %d out of range",
+  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 2, "dimension %d out of range",
      dimension + TH_INDEX_BASE);
 
-  THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim);
+  THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimensionLegacyAll)(t), dimension, keepdim);
   dim = THTensor_(newSizeOf)(t);
   THLongStorage_set(dim, dimension, 1);
   THTensor_(resize)(r_, dim, NULL);
@@ -1760,7 +1760,7 @@ void THTensor_(logicalAny)(THTensor *r_, THTensor *t, int dimension, int keepdim
   if(r_Contig && (tp != rp)){
     ptrdiff_t iter = 0;
     ptrdiff_t r_Size = THTensor_(nElement)(r_);
-    int r_Dim = r_->_dim();
+    int r_Dim = THTensor_nDimensionLegacyAll(r_);
     #pragma omp parallel for if ( r_Size > TH_OMP_OVERHEAD_THRESHOLD)
     for (iter = 0; iter < r_Size; iter++) {
       int j;
@@ -1883,7 +1883,7 @@ void THTensor_(lerp)(THTensor *r_, THTensor *a, THTensor *b, real weight)
 
 void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension, int keepdim)
 {
-  THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "invalid dimension %d",
+  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 2, "invalid dimension %d",
      dimension + TH_INDEX_BASE);
 
   THTensor_(sum)(r_, t, dimension, keepdim);
@@ -1894,10 +1894,10 @@ void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int biased, int ke
 {
   THLongStorage *dim;
 
-  THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 3, "invalid dimension %d",
+  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 3, "invalid dimension %d",
      dimension + TH_INDEX_BASE);
 
-  THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim);
+  THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimensionLegacyAll)(t), dimension, keepdim);
   dim = THTensor_(newSizeOf)(t);
   THLongStorage_set(dim, dimension, 1);
   THTensor_(resize)(r_, dim, NULL);
@ -1938,10 +1938,10 @@ void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int biased, int ke
|
||||
{
|
||||
THLongStorage *dim;
|
||||
|
||||
THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 3, "invalid dimension %d",
|
||||
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 3, "invalid dimension %d",
|
||||
dimension + TH_INDEX_BASE);
|
||||
|
||||
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim);
|
||||
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimensionLegacyAll)(t), dimension, keepdim);
|
||||
dim = THTensor_(newSizeOf)(t);
|
||||
THLongStorage_set(dim, dimension, 1);
|
||||
THTensor_(resize)(r_, dim, NULL);
|
||||
@ -1982,10 +1982,10 @@ void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension, int k
|
||||
{
|
||||
THLongStorage *dim;
|
||||
|
||||
THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 3, "invalid dimension %d",
|
||||
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 3, "invalid dimension %d",
|
||||
dimension + TH_INDEX_BASE);
|
||||
|
||||
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim);
|
||||
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimensionLegacyAll)(t), dimension, keepdim);
|
||||
dim = THTensor_(newSizeOf)(t);
|
||||
THLongStorage_set(dim, dimension, 1);
|
||||
THTensor_(resize)(r_, dim, NULL);
|
||||
@ -2054,11 +2054,11 @@ void THTensor_(renorm)(THTensor *res, THTensor *src, real value, int dimension,
|
||||
{
|
||||
THTensor *rowR, *rowS;
|
||||
|
||||
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(src), 3, "invalid dimension %d",
|
||||
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyNoScalars)(src), 3, "invalid dimension %d",
|
||||
dimension + TH_INDEX_BASE);
|
||||
THArgCheck(value > 0, 2, "non-positive-norm not supported");
|
||||
THArgCheck(THTensor_(nDimension)(src) > 1, 1, "need at least 2 dimensions, got %d dimensions",
|
||||
THTensor_(nDimension)(src));
|
||||
THArgCheck(THTensor_(nDimensionLegacyNoScalars)(src) > 1, 1, "need at least 2 dimensions, got %d dimensions",
|
||||
THTensor_(nDimensionLegacyNoScalars)(src));
|
||||
|
||||
rowR = THTensor_(new)();
|
||||
rowS = THTensor_(new)();
|
||||
@ -2208,10 +2208,10 @@ void THTensor_(histc)(THTensor *hist, THTensor *tensor, int64_t nbins, real minv
|
||||
|
||||
void THTensor_(bhistc)(THTensor *hist, THTensor *tensor, int64_t nbins, real minvalue, real maxvalue)
|
||||
{
|
||||
THArgCheck(THTensor_(_nDimension)(tensor) < 3, 2, "invalid dimension %d, the input must be a 2d tensor", THTensor_(_nDimension)(tensor));
|
||||
THArgCheck(THTensor_(nDimensionLegacyAll)(tensor) < 3, 2, "invalid dimension %d, the input must be a 2d tensor", THTensor_(nDimensionLegacyAll)(tensor));
|
||||
|
||||
int dimension = 1;
|
||||
THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(tensor), 2, "invalid dimension %d",
|
||||
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(tensor), 2, "invalid dimension %d",
|
||||
dimension + TH_INDEX_BASE);
|
||||
|
||||
real minval;
|
||||
|
@ -357,7 +357,7 @@ void THTensor_(multinomialAliasDraw)(THLongTensor *self, THGenerator *_generator
void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTensor *prob_dist, int n_sample, int with_replacement)
{
std::lock_guard<std::mutex> lock(_generator->mutex);
int64_t start_dim = THTensor_(_nDimension)(prob_dist);
int64_t start_dim = THTensor_(nDimensionLegacyAll)(prob_dist);
int64_t n_dist;
int64_t n_categories;
THDoubleTensor* cum_dist;
@ -190,11 +190,11 @@ bool THC_pointwiseApply1(THCState* state,
TensorTypeA* a,
const Op& op,
TensorArgType aType = ReadWrite) {
if (THCTensor__nDimension(state, a) > MAX_CUTORCH_DIMS) {
if (THCTensor_nDimensionLegacyAll(state, a) > MAX_CUTORCH_DIMS) {
return false;
}

if (THCTensor__nDimension(state, a) == 0) {
if (THCTensor_nDimensionLegacyAll(state, a) == 0) {
// Zero-dim tensor; do nothing
return true;
}
@ -333,12 +333,12 @@ bool THC_pointwiseApply2(THCState* state,
return false;
}

if (THCTensor__nDimension(state, a) > MAX_CUTORCH_DIMS ||
THCTensor__nDimension(state, b) > MAX_CUTORCH_DIMS) {
if (THCTensor_nDimensionLegacyAll(state, a) > MAX_CUTORCH_DIMS ||
THCTensor_nDimensionLegacyAll(state, b) > MAX_CUTORCH_DIMS) {
return false;
}

if (THCTensor__nDimension(state, a) == 0) {
if (THCTensor_nDimensionLegacyAll(state, a) == 0) {
// Zero-dim tensor; do nothing
return true;
}
@ -527,13 +527,13 @@ bool THC_pointwiseApply3(THCState* state,
return false;
}

if (THCTensor__nDimension(state, a) > MAX_CUTORCH_DIMS ||
THCTensor__nDimension(state, b) > MAX_CUTORCH_DIMS ||
THCTensor__nDimension(state, c) > MAX_CUTORCH_DIMS) {
if (THCTensor_nDimensionLegacyAll(state, a) > MAX_CUTORCH_DIMS ||
THCTensor_nDimensionLegacyAll(state, b) > MAX_CUTORCH_DIMS ||
THCTensor_nDimensionLegacyAll(state, c) > MAX_CUTORCH_DIMS) {
return false;
}

if (THCTensor__nDimension(state, a) == 0) {
if (THCTensor_nDimensionLegacyAll(state, a) == 0) {
// Zero-dim tensor; do nothing
return true;
}
@ -95,7 +95,7 @@ template <typename T, int NewDim,
typename IndexT, template <typename U> class PtrTraits>
THCDeviceTensor<T, NewDim, IndexT, PtrTraits>
toDeviceTensorCast(THCState* state, THCudaTensor* t) {
switch (THCudaTensor__nDimension(state, t)) {
switch (THCudaTensor_nDimensionLegacyAll(state, t)) {
SWITCH_UNROLL_CUDA_CAST_FACTORY(1);
SWITCH_UNROLL_CUDA_CAST_FACTORY(2);
SWITCH_UNROLL_CUDA_CAST_FACTORY(3);
@ -48,7 +48,7 @@ template <typename T, int Dim,
typename IndexT, template <typename U> class PtrTraits>
THCDeviceTensor<T, Dim, IndexT, PtrTraits>
toDeviceTensor(THCState* state, THCTensor* t) {
if (Dim != THCTensor__nDimension(state, t)) {
if (Dim != THCTensor_nDimensionLegacyAll(state, t)) {
THError("THCudaTensor dimension mismatch");
}
// Determine the maximum offset into the tensor achievable; `IndexT`
@ -414,12 +414,12 @@ bool THC_reduceDim(THCState* state,
int64_t reductionStride = THCTensor_stride(state, in, dim);
ptrdiff_t outElements = inElements / reductionSize;

if (THCTensor__nDimension(state, out) > MAX_CUTORCH_DIMS ||
THCTensor__nDimension(state, in) > MAX_CUTORCH_DIMS) {
if (THCTensor_nDimensionLegacyAll(state, out) > MAX_CUTORCH_DIMS ||
THCTensor_nDimensionLegacyAll(state, in) > MAX_CUTORCH_DIMS) {
return false;
}

if (THCTensor__nDimension(state, in) == 0) {
if (THCTensor_nDimensionLegacyAll(state, in) == 0) {
// Zero-dim tensor; do nothing
return true;
}
@ -483,7 +483,7 @@ bool THC_reduceDim(THCState* state,

// Preserve noncontiguities by unsqueezing out if necessary
THCTensor_preserveReduceDimSemantics(
state, out, THCTensor__nDimension(state, in), dim, keepdim);
state, out, THCTensor_nDimensionLegacyAll(state, in), dim, keepdim);

// Resize out
THLongStorage* sizes = THCTensor_newSizeOf(state, in);
@ -232,11 +232,11 @@ bool THC_reduceAll(THCState* state,
int outOnDevice) {
ptrdiff_t inElements = THCTensor_nElement(state, in);

if (THCTensor__nDimension(state, in) > MAX_CUTORCH_DIMS) {
if (THCTensor_nDimensionLegacyAll(state, in) > MAX_CUTORCH_DIMS) {
return false;
}

if (THCTensor__nDimension(state, in) == 0) {
if (THCTensor_nDimensionLegacyAll(state, in) == 0) {
// Zero-dim tensor; do nothing
*out = init;
return true;
@ -7,7 +7,7 @@
#define MAX_GRID_SIZE 65535LL

void THCCheckTensorDims(THCState* state, THCudaTensor* tensor, int arg) {
int64_t dims = THCudaTensor__nDimension(state, tensor);
int64_t dims = THCudaTensor_nDimensionLegacyAll(state, tensor);
THArgCheck(dims <= MAX_CUTORCH_DIMS, arg, CUTORCH_DIM_WARNING);
}

@ -9,12 +9,12 @@

#include "THCTensorInfo.cuh"

int THCTensor_nDimension(THCState *state, const THCTensor *self) {
int THCTensor_nDimensionLegacyNoScalars(THCState *state, const THCTensor *self) {
return self->dim();
}

int THCTensor__nDimension(THCState *state, const THCTensor *self) {
return self->_dim();
int THCTensor_nDimensionLegacyAll(THCState *state, const THCTensor *self) {
return THTensor_nDimensionLegacyAll(self);
}

int64_t THCTensor_size(THCState *state, const THCTensor *self, int dim) {
@ -281,13 +281,13 @@ bool THCTensor_allContiguous(THCState *state, THCTensor **inputs, int numInputs)
}

ptrdiff_t THCTensor_nElement(THCState *state, const THCTensor *self) {
if(self->_dim() == 0)
if(THTensor_nDimensionLegacyAll(self) == 0)
return 0;
else
{
ptrdiff_t nElement = 1;
int d;
for(d = 0; d < self->_dim(); d++)
for(d = 0; d < THTensor_nDimensionLegacyAll(self); d++)
nElement *= self->size(d);
return nElement;
}
@ -326,7 +326,7 @@ bool THCTensor_canUse32BitIndexMath(THCState* state, const THCTensor* t, ptrdiff
ptrdiff_t offset = 0;
ptrdiff_t linearId = elements - 1;

for (int i = THCTensor__nDimension(state, t) - 1; i >= 0; --i) {
for (int i = THCTensor_nDimensionLegacyAll(state, t) - 1; i >= 0; --i) {
ptrdiff_t curDimIndex =
linearId % THCTensor_size(state, t, i);
ptrdiff_t curDimOffset = curDimIndex *
@ -361,7 +361,7 @@ bool THCTensor_all32BitIndexable(THCState* state, THCTensor** inputs, int numInp
/* the contiguity guarantees of the resize semantics. */
void THCTensor_preserveReduceDimSemantics(THCState *state, THCTensor *tensor,
int in_dims, int64_t dimension, int keepdim) {
int out_dims = THCTensor__nDimension(state, tensor);
int out_dims = THCTensor_nDimensionLegacyAll(state, tensor);
if (out_dims > 0 && !keepdim && out_dims == in_dims - 1) {
THCTensor_unsqueeze1d(state, tensor, tensor, dimension);
}
@ -402,7 +402,7 @@ bool THCTensor_maybeOverlappingIndices(THCState* state, const THCTensor* t) {
/* Extract size/stride arrays; only consider size >1 dims. */
SizeAndStride info[MAX_CUTORCH_DIMS];

int dims = THCTensor__nDimension(state, t);
int dims = THCTensor_nDimensionLegacyAll(state, t);
int nonSize1Dims = 0;
for (int i = 0; i < dims; ++i) {
int64_t size = THCTensor_size(state, t, i);
@ -10,9 +10,9 @@
#include <atomic>
#include <ATen/ATen.h>

// See [NOTE: _dim() vs dim()]; _nDimension corresponds to _dim(), nDimension corresponds to dim().
THC_API int THCTensor_nDimension(THCState *state, const THCTensor *self);
THC_API int THCTensor__nDimension(THCState *state, const THCTensor *self);
// See [NOTE: nDimension vs nDimensionLegacyNoScalars vs nDimensionLegacyAll]
THC_API int THCTensor_nDimensionLegacyNoScalars(THCState *state, const THCTensor *self);
THC_API int THCTensor_nDimensionLegacyAll(THCState *state, const THCTensor *self);

THC_API int64_t THCTensor_size(THCState *state, const THCTensor *self, int dim);
THC_API int64_t THCTensor_stride(THCState *state, const THCTensor *self, int dim);
@ -32,7 +32,7 @@ void THC_copyTensor(THCState* state, THCTensor* dst, THCTensor* src) {
THCTensor_nElement(state, src),
2, "sizes do not match");

if (THCTensor__nDimension(state, dst) == 0) {
if (THCTensor_nDimensionLegacyAll(state, dst) == 0) {
// Zero-dim tensor; copy nothing
return;
}
@ -296,7 +296,7 @@ __global__ void THCTensor_kernel_varOuterDim(T *tgt, T *src_, unsigned num_orows

template<typename TensorTypeK, typename T, typename AccT, bool apply_sqrt>
__host__ void THCTensor_varOuterDim(THCState *state, TensorTypeK *tgt, TensorTypeK *src, int64_t dimension, int flag) {
unsigned ndim = THCTensor__nDimension(state, src);
unsigned ndim = THCTensor_nDimensionLegacyAll(state, src);
// Treat all outer dimensions (i.e. dim < dimension) as one.
unsigned num_orows = 1;
for (int64_t dim = 0; dim < dimension; dim++) {
@ -442,7 +442,7 @@ __global__ void THCTensor_kernel_varInnermostDim(T *tgt, T *src_, unsigned num_r

template<typename TensorTypeK, typename T, typename AccT, bool apply_sqrt>
__host__ void THCTensor_varInnermostDim(THCState *state, TensorTypeK *tgt, TensorTypeK *src, int flag) {
unsigned ndim = THCTensor__nDimension(state, src);
unsigned ndim = THCTensor_nDimensionLegacyAll(state, src);
// Treat all outer dimensions as a single dimension.
unsigned num_rows = 1;
for (unsigned dim = 0; dim < ndim - 1; dim++) {
@ -515,7 +515,7 @@ THC_transformReduceOuterDimIndex(THCState *state,
int64_t rdim,
const thrust::pair<ScalarTypeK, ScalarTypeIndex>& init,
BinaryFunction binary_op) {
unsigned ndim = THCTensor__nDimension(state, src);
unsigned ndim = THCTensor_nDimensionLegacyAll(state, src);
unsigned num_orows = 1;
for (int64_t dim = 0; dim < rdim; dim++) {
num_orows *= THCTensor_size(state, src, dim);
@ -618,7 +618,7 @@ THC_transformReduceInnermostDimIndex(THCState *state,
TensorTypeK *src,
const thrust::pair<ScalarTypeK, ScalarTypeIndex>& init,
BinaryFunction binary_op) {
unsigned ndim = THCTensor__nDimension(state, src);
unsigned ndim = THCTensor_nDimensionLegacyAll(state, src);
unsigned num_rows = 1;
for (unsigned dim = 0; dim < ndim - 1; dim++) {
num_rows *= THCTensor_size(state, src, dim);
@ -654,13 +654,13 @@ THC_reduceDimIndex(THCState *state,
BinaryFunction binary_op)
{
THArgCheck(dimension >= 0 &&
dimension < THCTensor__nDimension(state, src),
dimension < THCTensor_nDimensionLegacyAll(state, src),
3, "dimension out of range");


// Unsqueeze tgt1_/tgt_2 if necessary so that their contiguity traits
// are preserved if they are the same size as the correct reduction output.
int src_dims = THCTensor__nDimension(state, src);
int src_dims = THCTensor_nDimensionLegacyAll(state, src);
THCTensor_preserveReduceDimSemantics(
state, tgt1_, src_dims, dimension, keepdim);
THCTensor_preserveReduceDimSemantics(
@ -676,7 +676,7 @@ THC_reduceDimIndex(THCState *state,
TensorTypeIndex *tgt2 = (TensorTypeIndex*)THCTensor_newContiguous<ScalarTypeIndex>(state, tgt2_);
src = (TensorTypeK*)THCTensor_newContiguous<ScalarTypeK>(state, src);

if (dimension == THCTensor__nDimension(state, src) - 1) {
if (dimension == THCTensor_nDimensionLegacyAll(state, src) - 1) {
THC_transformReduceInnermostDimIndex(state, tgt1, tgt2, src, init, binary_op);
} else {
THC_transformReduceOuterDimIndex(state, tgt1, tgt2, src, dimension, init, binary_op);
@ -3,7 +3,7 @@
void THCudaLongTensor_fillSliceWithIndex(THCState* state,
THCudaLongTensor* t,
int dim) {
int64_t dims = THCudaLongTensor_nDimension(state, t);
int64_t dims = THCudaLongTensor_nDimensionLegacyNoScalars(state, t);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);

ptrdiff_t inElements = THCudaLongTensor_nElement(state, t);
@ -60,7 +60,7 @@ getTensorInfo(THCState* state, TensorType* t) {
IndexType sz[MAX_CUTORCH_DIMS];
IndexType st[MAX_CUTORCH_DIMS];

int dims = THCTensor_nDimension(state, t);
int dims = THCTensor_nDimensionLegacyNoScalars(state, t);
for (int i = 0; i < dims; ++i) {
sz[i] = THCTensor_size(state, t, i);
st[i] = THCTensor_stride(state, t, i);
@ -13,14 +13,14 @@ ptrdiff_t THCTensor_(storageOffset)(THCState *state, const THCTensor *self)
return self->storage_offset();
}

int THCTensor_(nDimension)(THCState *state, const THCTensor *self)
int THCTensor_(nDimensionLegacyNoScalars)(THCState *state, const THCTensor *self)
{
return THCTensor_nDimension(state, self);
return THCTensor_nDimensionLegacyNoScalars(state, self);
}

int THCTensor_(_nDimension)(THCState *state, const THCTensor *self)
int THCTensor_(nDimensionLegacyAll)(THCState *state, const THCTensor *self)
{
return THCTensor__nDimension(state, self);
return THCTensor_nDimensionLegacyAll(state, self);
}

int64_t THCTensor_(size)(THCState *state, const THCTensor *self, int dim)
@ -236,7 +236,7 @@ THCTensor *THCTensor_(newView)(THCState *state, THCTensor *tensor, THLongStorage
// Collapses the first two dimensions of a tensor.
// Assumes the input tensor is contiguous.
THCTensor *THCTensor_(newFoldBatchDim)(THCState *state, THCTensor *input) {
int in_dims = THCTensor_(_nDimension)(state, input);
int in_dims = THCTensor_(nDimensionLegacyAll)(state, input);
THArgCheck(in_dims >= 2, 1, "Tensor needs to have at least two dimensions");
THArgCheck(THCTensor_(isContiguous)(state, input), 1,
"Tensor must be contiguous");
@ -391,7 +391,7 @@ void THCTensor_(select)(THCState *state, THCTensor *self, THCTensor *src, int di
src = self;

#ifndef USE_TH_SIZE_ZERO_DIM
THArgCheck(src->_dim() > 1, 1, "cannot select on a vector");
THArgCheck(THTensor_nDimensionLegacyAll(src) > 1, 1, "cannot select on a vector");
#else
#ifndef USE_TH_SCALAR
THArgCheck(src->dim() > 1, 1, "cannot select on a vector");
@ -21,9 +21,9 @@ typedef struct THCTensor THCTensor;
THC_API THCStorage* THCTensor_(storage)(THCState *state, const THCTensor *self);
THC_API ptrdiff_t THCTensor_(storageOffset)(THCState *state, const THCTensor *self);

// See [NOTE: _dim() vs dim()]; _nDimension corresponds to _dim(), nDimension corresponds to dim().
THC_API int THCTensor_(nDimension)(THCState *state, const THCTensor *self);
THC_API int THCTensor_(_nDimension)(THCState *state, const THCTensor *self);
// See [NOTE: nDimension vs nDimensionLegacyNoScalars vs nDimensionLegacyAll]
THC_API int THCTensor_(nDimensionLegacyNoScalars)(THCState *state, const THCTensor *self);
THC_API int THCTensor_(nDimensionLegacyAll)(THCState *state, const THCTensor *self);

THC_API int64_t THCTensor_(size)(THCState *state, const THCTensor *self, int dim);
THC_API int64_t THCTensor_(stride)(THCState *state, const THCTensor *self, int dim);
@ -9,10 +9,10 @@ static ptrdiff_t THCTensor_(getSliceSize)(THCState *state, THCTensor *dst,
THCudaLongTensor *index,
THCTensor *src)
{
int dstDims = THCTensor_(nDimension)(state, dst);
int srcDims = (src == nullptr) ? dstDims : THCTensor_(nDimension)(state, src);
int dstDims = THCTensor_(nDimensionLegacyNoScalars)(state, dst);
int srcDims = (src == nullptr) ? dstDims : THCTensor_(nDimensionLegacyNoScalars)(state, src);

THArgCheck(THCudaLongTensor_nDimension(state, index) == 1, 4,
THArgCheck(THCudaLongTensor_nDimensionLegacyNoScalars(state, index) == 1, 4,
"expecting vector of indices");
THArgCheck(dim >= 0 && dim < dstDims, 2, "Indexing dim is out of bounds");

@ -97,11 +97,11 @@ void THCTensor_(indexCopy)(THCState *state, THCTensor *dst, int dim, THCudaLongT
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, dst, src));
THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, indices));

int dims = THCTensor_(nDimension)(state, dst);
int dims = THCTensor_(nDimensionLegacyNoScalars)(state, dst);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
dims = THCTensor_(nDimension)(state, src);
dims = THCTensor_(nDimensionLegacyNoScalars)(state, src);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 5, CUTORCH_DIM_WARNING);
dims = THCudaLongTensor_nDimension(state, indices);
dims = THCudaLongTensor_nDimensionLegacyNoScalars(state, indices);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING);

// The `src` is partitioned into two parts:
@ -222,9 +222,9 @@ void THCTensor_(take)(THCState *state, THCTensor *dst, THCTensor *src, THCudaLon
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, dst, src));
THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index));

THArgCheck(THCTensor_(nDimension)(state, src) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
THArgCheck(THCTensor_(nDimension)(state, dst) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
THArgCheck(THCudaLongTensor_nDimension(state, index) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, src) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, dst) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
THArgCheck(THCudaLongTensor_nDimensionLegacyNoScalars(state, index) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
THArgCheck(!(THCTensor_(numel)(state, src) == 0 && THCudaLongTensor_numel(state, index) != 0), 2,
"tried to take from an empty tensor");

@ -255,9 +255,9 @@ void THCTensor_(put)(THCState *state, THCTensor *dst, THCudaLongTensor *index, T
THArgCheck(THCTensor_(nElement)(state, src) == numIndices,
3, "src should have the same number of elements as index");

THArgCheck(THCTensor_(nDimension)(state, dst) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
THArgCheck(THCTensor_(nDimension)(state, src) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
THArgCheck(THCudaLongTensor_nDimension(state, index) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, dst) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, src) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
THArgCheck(THCudaLongTensor_nDimensionLegacyNoScalars(state, index) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);

if (numIndices == 0) {
return;
@ -286,11 +286,11 @@ void THCTensor_(indexAdd)(THCState *state, THCTensor *dst, int dim, THCudaLongTe
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, dst, src));
THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, indices));

int dims = THCTensor_(nDimension)(state, dst);
int dims = THCTensor_(nDimensionLegacyNoScalars)(state, dst);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
dims = THCTensor_(nDimension)(state, src);
dims = THCTensor_(nDimensionLegacyNoScalars)(state, src);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 5, CUTORCH_DIM_WARNING);
dims = THCudaLongTensor_nDimension(state, indices);
dims = THCudaLongTensor_nDimensionLegacyNoScalars(state, indices);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING);

// The `src` is partitioned into two parts:
@ -409,9 +409,9 @@ void THCTensor_(indexFill)(THCState *state, THCTensor *dst, int dim, THCudaLongT
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, dst));
THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, indices));
int dims = THCTensor_(nDimension)(state, dst);
int dims = THCTensor_(nDimensionLegacyNoScalars)(state, dst);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
dims = THCudaLongTensor_nDimension(state, indices);
dims = THCudaLongTensor_nDimensionLegacyNoScalars(state, indices);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING);

// The `src` is partitioned into two parts:
@ -518,19 +518,19 @@ void THCTensor_(indexSelect)(THCState *state, THCTensor *dst, THCTensor *src, in
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, dst, src, indices));

int dims = THCTensor_(nDimension)(state, dst);
int dims = THCTensor_(nDimensionLegacyNoScalars)(state, dst);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
dims = THCTensor_(nDimension)(state, src);
dims = THCTensor_(nDimensionLegacyNoScalars)(state, src);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 3, CUTORCH_DIM_WARNING);
dims = THCudaLongTensor_nDimension(state, indices);
dims = THCudaLongTensor_nDimensionLegacyNoScalars(state, indices);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 5, CUTORCH_DIM_WARNING);

ptrdiff_t numIndices = THCudaLongTensor_nElement(state, indices);

int srcDims = THCTensor_(nDimension)(state, src);
int srcDims = THCTensor_(nDimensionLegacyNoScalars)(state, src);
cudaStream_t stream = THCState_getCurrentStream(state);

THArgCheck(THCudaLongTensor_nDimension(state, indices) <= 1, 3,
THArgCheck(THCudaLongTensor_nDimensionLegacyNoScalars(state, indices) <= 1, 3,
"Index is supposed to be an empty tensor or a vector");
THArgCheck(dim < srcDims, 4, "Indexing dim is out of bounds");
THArgCheck(srcDims > 0, 2, "Source tensor is empty");
@ -272,7 +272,7 @@ void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
self = THCTensor_(newContiguous)(state, self);
thrust::device_ptr<real> self_data(THCTensor_(data)(state, self));

int num_dim = THCTensor_(nDimension)(state, self);
int num_dim = THCTensor_(nDimensionLegacyNoScalars)(state, self);
int64_t N = THCTensor_(nElement)(state, self);

THCudaLongTensor_resize2d(state, tensor, N, num_dim);
@ -329,7 +329,7 @@ void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,

void THCTensor_(diag)(THCState *state, THCTensor *self_, THCTensor *src_, int64_t k){
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src_));
int nDimension = THCTensor_(nDimension)(state, src_);
int nDimension = THCTensor_(nDimensionLegacyNoScalars)(state, src_);
#ifndef USE_TH_SIZE_ZERO_DIM
AT_ASSERT(!src_->is_empty());
#endif
@ -392,7 +392,7 @@ void THCTensor_(eye)(THCState *state, THCTensor *self_, int64_t n, int64_t m)

accreal THCTensor_(trace)(THCState *state, THCTensor *src_) {
THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, src_));
THArgCheck((src_->_dim() == 2), 1, "expected a matrix");
THArgCheck((THTensor_nDimensionLegacyAll(src_) == 2), 1, "expected a matrix");
THCTensor *diag = THCTensor_(new)(state);
THCTensor_(diag)(state, diag, src_, 0);
accreal trace = THCTensor_(sumall)(state, diag);
@ -409,9 +409,9 @@ THCTensor_(addbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
real alpha, THCTensor *batch1, THCTensor *batch2) {
#if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
THCAssertSameGPU(THCTensor_(checkGPU)(state, 4, result, t, batch1, batch2));
THArgCheck(THCTensor_(nDimension)(state, t) == 2, 4, "expected 2D tensor");
THArgCheck(THCTensor_(nDimension)(state, batch1) == 3, 6, "expected 3D tensor");
THArgCheck(THCTensor_(nDimension)(state, batch2) == 3, 7, "expected 3D tensor");
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, t) == 2, 4, "expected 2D tensor");
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, batch1) == 3, 6, "expected 3D tensor");
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, batch2) == 3, 7, "expected 3D tensor");

int64_t batchnum = THCTensor_(size)(state, batch1, 0);
int64_t m1d1 = THCTensor_(size)(state, batch1, 1);
@ -474,9 +474,9 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
real alpha, THCTensor *batch1, THCTensor *batch2) {
#if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
THCAssertSameGPU(THCTensor_(checkGPU)(state, 4, result, t, batch1, batch2));
THArgCheck(THCTensor_(nDimension)(state, t) == 3, 4, "expected 3D tensor");
THArgCheck(THCTensor_(nDimension)(state, batch1) == 3, 6, "expected 3D tensor");
THArgCheck(THCTensor_(nDimension)(state, batch2) == 3, 7, "expected 3D tensor");
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, t) == 3, 4, "expected 3D tensor");
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, batch1) == 3, 6, "expected 3D tensor");
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, batch2) == 3, 7, "expected 3D tensor");
THArgCheck(THCTensor_(size)(state, t, 0) == THCTensor_(size)(state, batch1, 0), 6,
"equal number of batches expected");
THArgCheck(THCTensor_(size)(state, t, 0) == THCTensor_(size)(state, batch2, 0), 7,
@ -740,7 +740,7 @@ THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTens
{
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
THAssert(THCTensor_(checkGPU)(state, 2, ra_, a));
THArgCheck(THCTensor_(nDimension)(state, a) == 3, 3, "expected 3D tensor");
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, a) == 3, 3, "expected 3D tensor");
THArgCheck(THCTensor_(size)(state, a, 1) ==
THCTensor_(size)(state, a, 2), 3, "matrices must be square");

@ -845,9 +845,9 @@ THC_API void THCTensor_(btrisolve)(THCState *state, THCTensor *rb_, THCTensor *b
{
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
THAssert(THCTensor_(checkGPU)(state, 3, rb_, atf, b));
THArgCheck(THCTensor_(_nDimension)(state, atf) == 3, 3, "expected 3D tensor");
THArgCheck(THCTensor_(_nDimension)(state, b) == 3 ||
THCTensor_(_nDimension)(state, b) == 2, 4, "expected 2D or 3D tensor");
THArgCheck(THCTensor_(nDimensionLegacyAll)(state, atf) == 3, 3, "expected 3D tensor");
THArgCheck(THCTensor_(nDimensionLegacyAll)(state, b) == 3 ||
THCTensor_(nDimensionLegacyAll)(state, b) == 2, 4, "expected 2D or 3D tensor");
THArgCheck(THCTensor_(size)(state, atf, 0) ==
THCTensor_(size)(state, b, 0), 3, "number of batches must be equal");
THArgCheck(THCTensor_(size)(state, atf, 1) ==
@ -862,7 +862,7 @@ THC_API void THCTensor_(btrisolve)(THCState *state, THCTensor *rb_, THCTensor *b


int n = atf->size(1);
int nrhs = rb_->_dim() > 2 ? rb_->size(2) : 1;
int nrhs = THTensor_nDimensionLegacyAll(rb_) > 2 ? rb_->size(2) : 1;
THCTensor *atf_;
THCTensor *rb__;
int lda, ldb;
@ -887,7 +887,7 @@ THC_API void THCTensor_(btrisolve)(THCState *state, THCTensor *rb_, THCTensor *b
// correct ordering of B
if (rb_->stride(1) == 1) {
// column ordered
if (rb_->_dim() == 2 || rb_->size(2) == 1) {
if (THTensor_nDimensionLegacyAll(rb_) == 2 || rb_->size(2) == 1) {
ldb = n;
} else {
ldb = rb_->stride(2);
@ -895,7 +895,7 @@ THC_API void THCTensor_(btrisolve)(THCState *state, THCTensor *rb_, THCTensor *b
rb__ = rb_;
} else {
// make column ordered
if (rb_->_dim() > 2) {
if (THTensor_nDimensionLegacyAll(rb_) > 2) {
THCTensor *transp_r_ = THCTensor_(newTranspose)(state, rb_, 1, 2);
rb__ = THCTensor_(newClone)(state, transp_r_);
THCTensor_(free)(state, transp_r_);
@ -114,9 +114,9 @@ THCTensor_(cross)(THCState *state, THCTensor *self, THCTensor *x, THCTensor *y,
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self, x, y));

int i;
int nd = THCTensor_(nDimension)(state, x);
int nd = THCTensor_(nDimensionLegacyNoScalars)(state, x);
ptrdiff_t nelem = THCTensor_(nElement)(state, x);
THArgCheck(nd == THCTensor_(nDimension)(state, y), 1, "tensors must have same number of dimensions");
THArgCheck(nd == THCTensor_(nDimensionLegacyNoScalars)(state, y), 1, "tensors must have same number of dimensions");
for (i = 0; i < nd; i++) {
THArgCheck(THCTensor_(size)(state, x, i) == THCTensor_(size)(state, y, i), 1, "dimension %i of x and y does not match", i);
if (dimension < 0 && THCTensor_(size)(state, x, i) == 3) {
@ -63,9 +63,9 @@ THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, real value,
THCTensor *data = THCTensor_(newClone)(state, src_);
int64_t numel = THCTensor_(nElement)(state, data);

THArgCheck(dimension >= 0 && dimension < THCTensor_(nDimension)(state, src), 3, "invalid dimension");
THArgCheck(dimension >= 0 && dimension < THCTensor_(nDimensionLegacyNoScalars)(state, src), 3, "invalid dimension");
THArgCheck(THCNumerics<real>::gt(value, scalar_cast<real>(0)), 2, "non-positive-norm not supported");
THArgCheck(THCTensor_(nDimension)(state, src) > 1, 1, "need at least 2 dimensions");
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, src) > 1, 1, "need at least 2 dimensions");

if (numel > 0) {
ptrdiff_t size = numel / data->size(0);
@ -94,7 +94,7 @@ THCTensor_(std)(THCState *state, THCTensor *self_, THCTensor *src, int dimension
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));

THCTensor_preserveReduceDimSemantics(
state, self_, THCTensor_(_nDimension)(state, src), dimension, keepdim);
state, self_, THCTensor_(nDimensionLegacyAll)(state, src), dimension, keepdim);
THLongStorage *dim = THCTensor_(newSizeOf)(state, src);
THLongStorage_set(dim, dimension, 1);
THCTensor_(resize)(state, self_, dim, NULL);
@ -103,7 +103,7 @@ THCTensor_(std)(THCState *state, THCTensor *self_, THCTensor *src, int dimension
THCTensor *self = THCTensor_(newContiguous)(state, self_);
src = THCTensor_(newContiguous)(state, src);

if (dimension == THCTensor_(_nDimension)(state, src) - 1) {
if (dimension == THCTensor_(nDimensionLegacyAll)(state, src) - 1) {
THCTensor_varInnermostDim<THCTensor, real, accreal, true>(state, self, src, biased);
} else {
THCTensor_varOuterDim<THCTensor, real, accreal, true>(state, self, src, dimension, biased);
@ -123,7 +123,7 @@ THCTensor_(var)(THCState *state, THCTensor *self_, THCTensor *src, int dimension
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));

THCTensor_preserveReduceDimSemantics(
state, self_, THCTensor_(_nDimension)(state, src), dimension, keepdim);
state, self_, THCTensor_(nDimensionLegacyAll)(state, src), dimension, keepdim);
THLongStorage *dim = THCTensor_(newSizeOf)(state, src);
THLongStorage_set(dim, dimension, 1);
THCTensor_(resize)(state, self_, dim, NULL);
@ -132,7 +132,7 @@ THCTensor_(var)(THCState *state, THCTensor *self_, THCTensor *src, int dimension
THCTensor *self = THCTensor_(newContiguous)(state, self_);
src = THCTensor_(newContiguous)(state, src);

if (dimension == THCTensor_(_nDimension)(state, src) - 1) {
if (dimension == THCTensor_(nDimensionLegacyAll)(state, src) - 1) {
THCTensor_varInnermostDim<THCTensor, real, accreal, false>(state, self, src, biased);
} else {
THCTensor_varOuterDim<THCTensor, real, accreal, false>(state, self, src, dimension, biased);
@ -28,7 +28,7 @@ __host__ void THCTensor_(scanOuterDim)(THCState *state, THCTensor *tgt,
THCTensor *src, int dimension,
real init, BinaryOp binary_op)
{
unsigned ndim = THCTensor_(_nDimension)(state, src);
unsigned ndim = THCTensor_(nDimensionLegacyAll)(state, src);
// Treat all outer dimensions (i.e. dim < dimension) as one.
unsigned num_orows = 1;
for (int dim = 0; dim < dimension; dim++) {
@ -57,7 +57,7 @@ __host__ void THCTensor_(scanInnermostDim)(THCState *state, THCTensor *tgt,
THCTensor *src, real init,
BinaryFunction binary_op)
{
unsigned ndim = THCTensor_(_nDimension)(state, src);
unsigned ndim = THCTensor_(nDimensionLegacyAll)(state, src);
// Treat all outer dimensions as a single dimension.
unsigned num_rows = 1;
for (unsigned dim = 0; dim < ndim - 1; dim++) {
@ -79,7 +79,7 @@ void THCTensor_(scanDim)(THCState *state, THCTensor *self_, THCTensor *src,
int dimension, real init, BinaryFunction binary_op)
{
// "init" must be the identity element for binary_op
int ndim = THCTensor_(nDimension)(state, src);
int ndim = THCTensor_(nDimensionLegacyNoScalars)(state, src);
THArgCheck(dimension >= 0 && dimension < ndim, 3, "dimension %d out of range",
dimension + TH_INDEX_BASE);
@ -20,7 +20,7 @@ THC_API void THCTensor_(calculateMode)(THCState *state,
data += THLongStorage_data(position)[i] * THCTensor_(stride)(state, input, i);
}

int64_t nElement = THCTensor_(size)(state, input, THCTensor_(_nDimension)(state, input) - 1);
int64_t nElement = THCTensor_(size)(state, input, THCTensor_(nDimensionLegacyAll)(state, input) - 1);
THCThrustAllocator thrustAlloc(state);

// Wrap input data, sortBuffer, in Thrust device vectors
@ -137,7 +137,7 @@ THC_API void THCTensor_(dimApplyMode)(THCState *state,
int dimension,
THLongStorage *position,
int curDim) {
int64_t ndim = THCTensor_(_nDimension)(state, input);
int64_t ndim = THCTensor_(nDimensionLegacyAll)(state, input);

// Because we have transposed the Tensor, the data for the dimension we are mode'ing along
// is always in the innermost dimension
@ -172,7 +172,7 @@ THC_API void THCTensor_(mode)(THCState *state,
THAssert(THCTensor_(checkGPU)(state, 1, values));

// Verify they are asking for a valid dimension
ndim = THCTensor_(_nDimension)(state, input);
ndim = THCTensor_(nDimensionLegacyAll)(state, input);
THArgCheck(dimension >= 0 && dimension < ndim, 4, "Dimension of out bounds");

sliceSize = THCTensor_(size)(state, input, dimension);
@ -109,7 +109,7 @@ THC_API void THCTensor_(cauchy)(THCState* state, THCTensor *self_, double median

void THCTensor_(renormRows)(struct THCState* state,
THCTensor* t) {
THAssert(THCTensor_(_nDimension)(state, t) == 2);
THAssert(THCTensor_(nDimensionLegacyAll)(state, t) == 2);
int64_t rows = THCTensor_(size)(state, t, 0);
int64_t cols = THCTensor_(size)(state, t, 1);

@ -137,7 +137,7 @@ THC_API void THCTensor_(multinomial)(struct THCState *state,
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self, prob_dist));
THCGenerator* gen = THCRandom_getGenerator(state);

int inputSize = THCTensor_(_nDimension)(state, prob_dist);
int inputSize = THCTensor_(nDimensionLegacyAll)(state, prob_dist);
THArgCheck(inputSize > 0 && inputSize <= 2, 2,
"prob_dist must be 1 or 2 dim");
@ -12,25 +12,25 @@ void THCTensor_(gather)(THCState* state, THCTensor *tensor,
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, tensor, src));
THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index));

THArgCheck(THCudaLongTensor_nDimension(state, index) == THCTensor_(nDimension)(state, src), 4,
THArgCheck(THCudaLongTensor_nDimensionLegacyNoScalars(state, index) == THCTensor_(nDimensionLegacyNoScalars)(state, src), 4,
"Index tensor must have same dimensions as input tensor");
THLongStorage *indexSize = THCudaLongTensor_newSizeOf(state, index);
THArgCheck(THCTensor_(isSize)(state, tensor, indexSize), 4,
"Index tensor must have the same size as output tensor.");
THLongStorage_free(indexSize);
THArgCheck(dim >= 0 && dim < THCTensor_(nDimension)(state, tensor), 3,
THArgCheck(dim >= 0 && dim < THCTensor_(nDimensionLegacyNoScalars)(state, tensor), 3,
"Index dimension is out of bounds");
THArgCheck(THCTensor_(nDimension)(state, src) == THCTensor_(nDimension)(state, tensor), 2,
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, src) == THCTensor_(nDimensionLegacyNoScalars)(state, tensor), 2,
"Input tensor must have same dimensions as output tensor");

for (int d = 0; d < THCTensor_(nDimension)(state, tensor); d++) {
for (int d = 0; d < THCTensor_(nDimensionLegacyNoScalars)(state, tensor); d++) {
if (d != dim) {
THArgCheck(THCTensor_(size)(state, tensor, d) == THCTensor_(size)(state, src, d), 2,
"Input tensor must have same size as output tensor apart from the specified dimension");
}
}

THArgCheck(THCTensor_(nDimension)(state, tensor) <= MAX_CUTORCH_DIMS,
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, tensor) <= MAX_CUTORCH_DIMS,
1, CUTORCH_DIM_WARNING);


@ -109,14 +109,14 @@ void THCTensor_(scatter)(THCState* state, THCTensor *tensor, int dim, THCudaLong
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, tensor, src));
THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index));

THArgCheck(dim >= 0 && dim < THCTensor_(nDimension)(state, tensor), 2,
THArgCheck(dim >= 0 && dim < THCTensor_(nDimensionLegacyNoScalars)(state, tensor), 2,
"Index dimension is out of bounds");
THArgCheck(THCudaLongTensor_nDimension(state, index) == THCTensor_(nDimension)(state, src), 3,
THArgCheck(THCudaLongTensor_nDimensionLegacyNoScalars(state, index) == THCTensor_(nDimensionLegacyNoScalars)(state, src), 3,
"Index tensor must have same dimensions as input tensor");
THArgCheck(THCTensor_(nDimension)(state, src) == THCTensor_(nDimension)(state, tensor), 4,
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, src) == THCTensor_(nDimensionLegacyNoScalars)(state, tensor), 4,
"Input tensor must have same dimensions as output tensor");

for (int d = 0; d < THCTensor_(nDimension)(state, tensor); d++) {
for (int d = 0; d < THCTensor_(nDimensionLegacyNoScalars)(state, tensor); d++) {
int64_t indexSizeD = THCudaLongTensor_size(state, index, d);
if (d != dim) {
THArgCheck(indexSizeD <= THCTensor_(size)(state, tensor, d), 3,
@ -128,7 +128,7 @@ void THCTensor_(scatter)(THCState* state, THCTensor *tensor, int dim, THCudaLong
THCudaLongTensor_sizeDesc(state, index).str, THCTensor_(sizeDesc)(state, src).str);
}

THArgCheck(THCTensor_(nDimension)(state, tensor) <= MAX_CUTORCH_DIMS,
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, tensor) <= MAX_CUTORCH_DIMS,
1, CUTORCH_DIM_WARNING);

const ptrdiff_t totalElements = THCudaLongTensor_nElement(state, index);
@ -201,14 +201,14 @@ void THCTensor_(scatterAdd)(THCState* state, THCTensor *tensor, int dim, THCudaL
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, tensor, src));
THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index));

THArgCheck(dim >= 0 && dim < THCTensor_(nDimension)(state, tensor), 2,
THArgCheck(dim >= 0 && dim < THCTensor_(nDimensionLegacyNoScalars)(state, tensor), 2,
"Index dimension is out of bounds");
THArgCheck(THCudaLongTensor_nDimension(state, index) == THCTensor_(nDimension)(state, src), 3,
THArgCheck(THCudaLongTensor_nDimensionLegacyNoScalars(state, index) == THCTensor_(nDimensionLegacyNoScalars)(state, src), 3,
"Index tensor must have same dimensions as input tensor");
THArgCheck(THCTensor_(nDimension)(state, src) == THCTensor_(nDimension)(state, tensor), 4,
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, src) == THCTensor_(nDimensionLegacyNoScalars)(state, tensor), 4,
"Input tensor must have same dimensions as output tensor");

for (int d = 0; d < THCTensor_(nDimension)(state, tensor); d++) {
for (int d = 0; d < THCTensor_(nDimensionLegacyNoScalars)(state, tensor); d++) {
int64_t indexSizeD = THCudaLongTensor_size(state, index, d);
if (d != dim) {
THArgCheck(indexSizeD <= THCTensor_(size)(state, tensor, d), 3,
@ -220,7 +220,7 @@ void THCTensor_(scatterAdd)(THCState* state, THCTensor *tensor, int dim, THCudaL
THCudaLongTensor_sizeDesc(state, index).str, THCTensor_(sizeDesc)(state, src).str);
}

THArgCheck(THCTensor_(nDimension)(state, tensor) <= MAX_CUTORCH_DIMS,
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, tensor) <= MAX_CUTORCH_DIMS,
1, CUTORCH_DIM_WARNING);

const ptrdiff_t totalElements = THCudaLongTensor_nElement(state, index);
@ -295,13 +295,13 @@ THCTensor_(scatterFill)(THCState* state, THCTensor *tensor,
THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, tensor));
THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index));

THArgCheck(dim >= 0 && dim < THCTensor_(nDimension)(state, tensor), 2,
THArgCheck(dim >= 0 && dim < THCTensor_(nDimensionLegacyNoScalars)(state, tensor), 2,
"Index dimension is out of bounds");
THArgCheck(THCudaLongTensor_nDimension(state, index) ==
THCTensor_(nDimension)(state, tensor), 3,
THArgCheck(THCudaLongTensor_nDimensionLegacyNoScalars(state, index) ==
THCTensor_(nDimensionLegacyNoScalars)(state, tensor), 3,
"Index tensor must have same dimensions as output tensor");

for (int d = 0; d < THCTensor_(nDimension)(state, tensor); d++) {
for (int d = 0; d < THCTensor_(nDimensionLegacyNoScalars)(state, tensor); d++) {
if (d != dim) {
THArgCheck(THCTensor_(size)(state, tensor, d) ==
THCudaLongTensor_size(state, index, d), 4,
@ -309,7 +309,7 @@ THCTensor_(scatterFill)(THCState* state, THCTensor *tensor,
}
}

THArgCheck(THCTensor_(nDimension)(state, tensor) <= MAX_CUTORCH_DIMS,
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, tensor) <= MAX_CUTORCH_DIMS,
1, CUTORCH_DIM_WARNING);

const ptrdiff_t totalElements = THCudaLongTensor_nElement(state, index);
@ -13,9 +13,9 @@ THC_API void THCTensor_(sortKeyValueInplace)(THCState* state,
THArgCheck(THCTensor_(isSize)(state, key, valueSize), 2,
"Key tensor must have same size as value tensor");
THLongStorage_free(valueSize);
int dims = THCudaLongTensor_nDimension(state, value);
int dims = THCudaLongTensor_nDimensionLegacyNoScalars(state, value);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 3, CUTORCH_DIM_WARNING);
dims = THCTensor_(nDimension)(state, key);
dims = THCTensor_(nDimensionLegacyNoScalars)(state, key);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);

ptrdiff_t inElements = THCTensor_(nElement)(state, key);
@ -158,7 +158,7 @@ void THCTensor_(sortViaThrust)(THCState* state,
THCudaLongTensor* indices,
THCTensor* input,
int dim, bool dir) {
int nDims = THCTensor_(_nDimension)(state, input);
int nDims = THCTensor_(nDimensionLegacyAll)(state, input);

ptrdiff_t totalElements = THCTensor_(nElement)(state, input);
int64_t sliceSize = THCTensor_(size)(state, input, dim);
@ -283,11 +283,11 @@ THC_API void THCTensor_(sort)(THCState* state,
int dim, int order) {
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, sorted, input));
THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, indices));
int64_t dims = THCTensor_(nDimension)(state, sorted);
int64_t dims = THCTensor_(nDimensionLegacyNoScalars)(state, sorted);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
dims = THCTensor_(nDimension)(state, input);
dims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING);
dims = THCudaLongTensor_nDimension(state, indices);
dims = THCudaLongTensor_nDimensionLegacyNoScalars(state, indices);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 3, CUTORCH_DIM_WARNING);

// Make sure sufficient output space is allocated
@ -9,10 +9,10 @@ THC_API void THCTensor_(topk)(THCState* state,
int64_t k, int dim, int dir, int sorted) {
THAssert(topK != NULL && indices != NULL && input_ != NULL);
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, topK, indices, input_));
THArgCheck(THCTensor_(nDimension)(state, topK) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
int64_t dims = THCudaLongTensor_nDimension(state, indices);
THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, topK) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
int64_t dims = THCudaLongTensor_nDimensionLegacyNoScalars(state, indices);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 3, CUTORCH_DIM_WARNING);
int numDims = THCTensor_(nDimension)(state, input_);
int numDims = THCTensor_(nDimensionLegacyNoScalars)(state, input_);
THArgCheck(numDims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING);

THArgCheck(dim >= 0 && dim < numDims, 6, "dim not in range");
@ -61,7 +61,7 @@ inline int GET_BLOCKS(const int N)
}

#define THCUNN_check_dim_size(STATE, T, DIM, DIM_SIZE, SIZE) \
if (THCTensor_(nDimension)(STATE, T) != DIM || \
if (THCTensor_(nDimensionLegacyNoScalars)(STATE, T) != DIM || \
THCTensor_(size)(STATE, T, DIM_SIZE) != SIZE) { \
THCDescBuff s1 = THCTensor_(sizeDesc)(state, T); \
THError("Need " #T " of dimension %d and " #T ".size[%d] == %d" \
@ -69,7 +69,7 @@ inline int GET_BLOCKS(const int N)
}

#define THCUNN_check_dim_size_indices(STATE, T, DIM, DIM_SIZE, SIZE) \
if (THCIndexTensor_(nDimension)(STATE, T) != DIM || \
if (THCIndexTensor_(nDimensionLegacyNoScalars)(STATE, T) != DIM || \
THCIndexTensor_(size)(STATE, T, DIM_SIZE) != SIZE) { \
THCDescBuff s1 = THCIndexTensor_(sizeDesc)(state, T); \
THError("Need " #T " of dimension %d and " #T ".size[%d] == %d" \
@ -11,7 +11,7 @@ static THCDeviceTensor<real, Dim> THNN_(devicetensor)(THCState *state, THCTensor
return THCDeviceTensor<real, Dim>();
}

int inDim = THCTensor__nDimension(state, t);
int inDim = THCTensor_nDimensionLegacyAll(state, t);
if (inDim == Dim) {
return toDeviceTensor<real, Dim>(state, t);
}
@ -11,11 +11,11 @@ void THNN_(ClassNLLCriterion_updateOutput)(
THCTensor *weights,
THCTensor *total_weight,
int64_t ignore_index) {
if (THCIndexTensor_(nDimension)(state, target) > 1) {
if (THCIndexTensor_(nDimensionLegacyNoScalars)(state, target) > 1) {
THError("multi-target not supported");
}

int n_dims = THCTensor_(nDimension)(state, input);
int n_dims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
int n_classes = THCTensor_(size)(state, input, n_dims - 1);
ignore_index -= TH_INDEX_BASE;

@ -80,7 +80,7 @@ void THNN_(ClassNLLCriterion_updateOutput)(
real *output_data = THCTensor_(data)(state, output);
real *total_weight_data = THCTensor_(data)(state, total_weight);

if (THCTensor_(nDimension)(state, input) == 1) {
if (THCTensor_(nDimensionLegacyNoScalars)(state, input) == 1) {
cunn_ClassNLLCriterion_updateOutput_kernel1<real>
<<<1, 1, 0, THCState_getCurrentStream(state)>>>(
output_data,
@ -93,7 +93,7 @@ void THNN_(ClassNLLCriterion_updateOutput)(
ignore_index
);

} else if (THCTensor_(nDimension)(state, input) == 2) {
} else if (THCTensor_(nDimensionLegacyNoScalars)(state, input) == 2) {
cunn_ClassNLLCriterion_updateOutput_kernel<real, accreal>
<<<1, NTHREADS, 0, THCState_getCurrentStream(state)>>>(
output_data,
@ -127,11 +127,11 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
THCTensor *weights,
THCTensor *total_weight,
int64_t ignore_index) {
if (THCIndexTensor_(nDimension)(state, target) > 1) {
if (THCIndexTensor_(nDimensionLegacyNoScalars)(state, target) > 1) {
THError("multi-target not supported");
}

int n_dims = THCTensor_(nDimension)(state, input);
int n_dims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
int n_classes = THCTensor_(size)(state, input, n_dims - 1);

THCTensor_(resizeAs)(state, gradInput, input);
@ -197,7 +197,7 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
THCIndex_t *target_data = THCIndexTensor_(data)(state, target);
real *total_weight_data = THCTensor_(data)(state, total_weight);

if (THCTensor_(nDimension)(state, input) == 1) {
if (THCTensor_(nDimensionLegacyNoScalars)(state, input) == 1) {
cunn_ClassNLLCriterion_updateGradInput_kernel1<real>
<<<1, 1, 0, THCState_getCurrentStream(state)>>>(
gradInput_data,
@ -17,7 +17,7 @@ static inline void THNN_(Col2Im_shapeCheck)(
THArgCheck(dW > 0 && dH > 0, 8,
"dilation should be greater than zero, but got dH: %d dW: %d", dH, dW);

int64_t ndim = THCTensor_(nDimension)(state, input);
int64_t ndim = THCTensor_(nDimensionLegacyNoScalars)(state, input);
THCUNN_argCheck(state, !input->is_empty() && (ndim == 2 || ndim == 3), 2, input,
"Expected non-empty 2D or 3D input tensor, but got input of shape %s");
@ -15,7 +15,7 @@
// [batch dim][feature dim][opt dim 1][opt dim 2]
THCDeviceTensor<real, 4>
THNN_(FeatureLPPooling_upcast)(THCState* state, THCTensor* t, bool batchMode) {
int inputDim = THCTensor_(_nDimension)(state, t);
int inputDim = THCTensor_(nDimensionLegacyAll)(state, t);

if (inputDim == 1) {
// [feature dim]
@ -58,7 +58,7 @@ THNN_(FeatureLPPooling_resizeForOutput)(THCState* state,
bool batchMode,
int width,
int stride) {
int inputDim = THCTensor_(_nDimension)(state, input);
int inputDim = THCTensor_(nDimensionLegacyAll)(state, input);
THAssert(inputDim >= 1 && inputDim <= 4);

int64_t outSize =
@ -109,7 +109,7 @@ void
THNN_(FeatureLPPooling_resize)(THCState* state,
THCTensor* toResize,
THCTensor* src) {
int inputDim = THCTensor_(_nDimension)(state, src);
int inputDim = THCTensor_(nDimensionLegacyAll)(state, src);
THAssert(inputDim >= 1 && inputDim <= 4);

if (inputDim == 1) {
@ -149,7 +149,7 @@ void THNN_(FeatureLPPooling_updateOutput)(THCState* state,
bool batchMode) {
THCUNN_assertSameGPU(state, 2, inputTH, outputTH);

int inputDim = THCTensor_(_nDimension)(state, inputTH);
int inputDim = THCTensor_(nDimensionLegacyAll)(state, inputTH);

if (batchMode) {
THArgCheck(inputDim >= 2 && inputDim <= 4, 2,
@ -207,7 +207,7 @@ void THNN_(FeatureLPPooling_updateGradInput)(THCState* state,
"input tensor must fit into 32-bit index math");
THCUNN_assertSameGPU(state, 4, gradOutputTH, inputTH, outputTH, gradInputTH);

int inputDim = THCTensor_(_nDimension)(state, inputTH);
int inputDim = THCTensor_(nDimensionLegacyAll)(state, inputTH);

if (batchMode) {
THArgCheck(inputDim >= 2 && inputDim <= 4, 2,
@ -18,10 +18,10 @@ void THNN_(FusedRNNAssertSizes)(THCState *state, int factor, int count, ...)
THCTensor_(nElement)(state, hidden),
3, "Input and Hidden tensor sizes should be the same.");

THAssertMsg(THCTensor__nDimension(state, input) <= MAX_CUTORCH_DIMS,
THAssertMsg(THCTensor_nDimensionLegacyAll(state, input) <= MAX_CUTORCH_DIMS,
"Tensor dimension is too large.");

THAssertMsg(THCTensor__nDimension(state, hidden) <= MAX_CUTORCH_DIMS,
THAssertMsg(THCTensor_nDimensionLegacyAll(state, hidden) <= MAX_CUTORCH_DIMS,
"Tensor dimension is too large.");

for (int arg=2; arg < count; ++arg){
@ -29,7 +29,7 @@ void THNN_(FusedRNNAssertSizes)(THCState *state, int factor, int count, ...)
THArgCheck(THCTensor_(nElement)(state, input) ==
THCTensor_(nElement)(state, tens)*factor,
3, "A pointwise tensor was not the right size, should have 1/%u the elements of input/hidden tensor.", arg, factor);
THAssertMsg(THCTensor__nDimension(state, tens) <= MAX_CUTORCH_DIMS,
THAssertMsg(THCTensor_nDimensionLegacyAll(state, tens) <= MAX_CUTORCH_DIMS,
"Tensor dimension is too large.");
}

@ -42,13 +42,13 @@ int THNN_(minIndexType)(THCState *state, int count, ...)
va_start(list, count);

THCTensor* tens = va_arg(list, THCTensor*);
int startDim = THCTensor__nDimension(state, tens);
int startDim = THCTensor_nDimensionLegacyAll(state, tens);
bool canCollapse = THCTensor_(isContiguous)(state,tens);

for (int arg=1; arg < count; ++arg){
tens = va_arg(list, THCTensor*);
canCollapse = canCollapse && THCTensor_(isContiguous)(state, tens);
if(THCTensor__nDimension(state, tens) != startDim){
if(THCTensor_nDimensionLegacyAll(state, tens) != startDim){
va_end(list);
return -1;
}
@ -18,7 +18,7 @@ static inline void THNN_(Im2Col_shapeCheck)(
|
||||
THArgCheck(sW > 0 && sH > 0, 10,
|
||||
"stride should be greater than zero, but got sH: %d sW: %d", sH, sW);
|
||||
|
||||
int64_t ndim = THCTensor_(nDimension)(state, input);
|
||||
int64_t ndim = THCTensor_(nDimensionLegacyNoScalars)(state, input);
|
||||
THCUNN_argCheck(state, !input->is_empty() && (ndim == 3 || ndim == 4), 2, input,
|
||||
"Expected non-empty 3D or 4D input tensor, but got input of shape %s");
|
||||
|
||||
|
@ -6,8 +6,8 @@ static bool THNN_(checkKeysValues)(THCState *state, THCudaLongTensor* keys,
THCTensor* values)
{
return THCudaLongTensor_size(state, keys, 0) == THCTensor_(nElement)(state, values)
&& THCTensor_(_nDimension)(state, values) == 1
&& THCudaLongTensor__nDimension(state, keys) == 1;
&& THCTensor_(nDimensionLegacyAll)(state, values) == 1
&& THCudaLongTensor_nDimensionLegacyAll(state, keys) == 1;
}

void THNN_(IndexLinear_updateOutput)(
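The checkKeysValues guard above encodes IndexLinear's input contract: keys and values are parallel 1D tensors, one index per value, so both must be vectors of the same length. A minimal stand-alone sketch of the same invariant, using a hypothetical plain-C stand-in type rather than the TH API:

```c
#include <stdbool.h>
#include <stdio.h>

/* Stand-in for the tensor metadata the check needs; not the TH API. */
typedef struct { int ndim; long size0; long nelem; } Tensor;

/* Same invariant as checkKeysValues above: keys and values must be
   parallel 1D tensors of equal length. */
static bool check_keys_values(const Tensor *keys, const Tensor *values) {
  return keys->size0 == values->nelem
      && values->ndim == 1
      && keys->ndim == 1;
}

int main(void) {
  Tensor keys   = { 1, 4, 4 };
  Tensor values = { 1, 4, 4 };
  printf("ok=%d\n", check_keys_values(&keys, &values));  /* ok=1 */
  return 0;
}
```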
@ -22,8 +22,8 @@ void THNN_(LookupTable_accGradParameters)(
THError("Tensors must be contiguous");
}

int nDim = THCIndexTensor_(_nDimension)(state, input);
if (THCIndexTensor_(_nDimension)(state, input) != 1 && THCIndexTensor_(_nDimension)(state, input) != 2) {
int nDim = THCIndexTensor_(nDimensionLegacyAll)(state, input);
if (THCIndexTensor_(nDimensionLegacyAll)(state, input) != 1 && THCIndexTensor_(nDimensionLegacyAll)(state, input) != 2) {
THCDescBuff s1 = THCIndexTensor_(sizeDesc)(state, input);
THError("input must be a vector or matrix, but is of shape: %s", s1.str);
}
@ -170,7 +170,7 @@ void THNN_(LookupTable_renorm)(
THError("Tensors must be contiguous");
}

if (THCIndexTensor_(_nDimension)(state, idx) != 1) {
if (THCIndexTensor_(nDimensionLegacyAll)(state, idx) != 1) {
THError("idx must be a vector");
}

@ -88,8 +88,8 @@ void THNN_(LookupTableBag_accGradParameters)(
bag_size_data = THCIndexTensor_(data)(state, bag_size);
}

int nDim = THCIndexTensor_(_nDimension)(state, input);
if (THCIndexTensor_(_nDimension)(state, input) != 1 && THCIndexTensor_(_nDimension)(state, input) != 2) {
int nDim = THCIndexTensor_(nDimensionLegacyAll)(state, input);
if (THCIndexTensor_(nDimensionLegacyAll)(state, input) != 1 && THCIndexTensor_(nDimensionLegacyAll)(state, input) != 2) {
THCDescBuff s1 = THCIndexTensor_(sizeDesc)(state, input);
THError("input must be a vector or matrix, but is of shape: %s", s1.str);
}

@ -20,7 +20,7 @@ void THNN_(PReLU_updateOutput)(
}
else
{
int ndim = THCTensor_(_nDimension)(state, input);
int ndim = THCTensor_(nDimensionLegacyAll)(state, input);
input = THCTensor_(newContiguous)(state, input);

int n = THCTensor_(nElement)(state, input);
@ -64,7 +64,7 @@ void THNN_(PReLU_updateGradInput)(
}
else
{
int ndim = THCTensor_(_nDimension)(state, input);
int ndim = THCTensor_(nDimensionLegacyAll)(state, input);
input = THCTensor_(newContiguous)(state, input);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);

@ -119,7 +119,7 @@ void THNN_(PReLU_accGradParameters)(
}
else
{
int ndim = THCTensor_(_nDimension)(state, input);
int ndim = THCTensor_(nDimensionLegacyAll)(state, input);

if (ndim == 1)
{

@ -4,17 +4,17 @@

static bool THNN_(checkInput)(THCTensor* t)
{
return !t->is_empty() && t->_dim() == 2 && t->size(1) == 3;
return !t->is_empty() && THTensor_nDimensionLegacyAll(t) == 2 && t->size(1) == 3;
}

static bool THNN_(checkSize2D)(THCTensor* t, int64_t size0, int64_t size1)
{
return !t->is_empty() && t->_dim() == 2 && t->size(0) == size0 && t->size(1) == size1;
return !t->is_empty() && THTensor_nDimensionLegacyAll(t) == 2 && t->size(0) == size0 && t->size(1) == size1;
}

static bool THNN_(checkSize1D)(THCTensor* t, int64_t size0)
{
return !t->is_empty() && t->_dim() == 1 && t->size(0) == size0;
return !t->is_empty() && THTensor_nDimensionLegacyAll(t) == 1 && t->size(0) == size0;
}

static inline void THNN_(copyCudaFloatingType)(THCState *state, THCudaIntTensor *buf, THCTensor *t) {
@ -41,7 +41,7 @@ void THNN_(SparseLinear_updateOutput)(
int64_t inDim = THCTensor_(size)(state, weight, 1);

THArgCheck(THNN_(checkInput)(input), 2, "input size must be nnz x 3");
AT_CHECK(!output->is_empty() && THCTensor_(nDimension)(state, output) == 2,
AT_CHECK(!output->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, output) == 2,
"output must be batchsize x outputsize, got size: ", output->sizes());
THArgCheck(THNN_(checkSize1D)(bias, outDim), 5, "bias size wrong");

@ -8,10 +8,10 @@ void THNN_(SpatialClassNLLCriterion_shapeCheck)(
THCIndexTensor *target,
THCTensor *weights)
{
AT_CHECK(!target->is_empty() && THCIndexTensor_(nDimension)(state, target) == 3, 1,
AT_CHECK(!target->is_empty() && THCIndexTensor_(nDimensionLegacyNoScalars)(state, target) == 3, 1,
"only batches of spatial targets supported (non-empty 3D tensors)" \
" but got targets of size: : ", target->sizes());
AT_CHECK(!input->is_empty() && THCTensor_(nDimension)(state, input) == 4, 2,
AT_CHECK(!input->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, input) == 4, 2,
"only batches of spatial inputs supported (non-empty 4D tensors), " \
"but got input of size: ", input->sizes());
if (THCTensor_(size)(state, input, 0) != THCIndexTensor_(size)(state, target, 0) ||
@ -33,7 +33,7 @@ static void THNN_(SpatialClassNLLCriterion_gradOutput_no_reduce_shapeCheck)(
THCTensor *gradOutput,
THCIndexTensor *target)
{
AT_CHECK(!gradOutput->is_empty() && THCTensor_(nDimension)(state, gradOutput) == 3, 2,
AT_CHECK(!gradOutput->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, gradOutput) == 3, 2,
"Expected non-empty dimension 3 but got gradOutput of size: ", gradOutput->sizes());
if (THCTensor_(size)(state, gradOutput, 0) != THCIndexTensor_(size)(state, target, 0) ||
THCTensor_(size)(state, gradOutput, 1) != THCIndexTensor_(size)(state, target, 1) ||

@ -16,8 +16,8 @@ void THNN_(SpatialDepthwiseConvolution_updateOutput)(
THCUNN_assertSameGPU(state, 3, input, output, weight);

// Only handle 4D Input Tensors for now
THAssert(!input->is_empty() && THCTensor_(nDimension)(state, input) == 4);
THAssert(!weight->is_empty() && THCTensor_(nDimension)(state, weight) == 4);
THAssert(!input->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, input) == 4);
THAssert(!weight->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, weight) == 4);

// We assume that the input and weight Tensors are shaped properly by
// the caller, so we verify that here to some extent
@ -107,9 +107,9 @@ void THNN_(SpatialDepthwiseConvolution_updateGradInput)(
THCUNN_assertSameGPU(state, 3, gradOutput, gradInput, weight);

// Only handle 4D Input Tensors for now
THAssert(!input->is_empty() && THCTensor_(nDimension)(state, input) == 4);
THAssert(!weight->is_empty() && THCTensor_(nDimension)(state, weight) == 4);
THAssert(!gradOutput->is_empty() && THCTensor_(nDimension)(state, gradOutput) == 4);
THAssert(!input->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, input) == 4);
THAssert(!weight->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, weight) == 4);
THAssert(!gradOutput->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, gradOutput) == 4);

// Minimal shape checking, as above
// Same # of elements in batch
@ -204,9 +204,9 @@ void THNN_(SpatialDepthwiseConvolution_accGradParameters)(
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradWeight);

// Only handle 4D Input Tensors for now
THAssert(!input->is_empty() && THCTensor_(nDimension)(state, input) == 4);
THAssert(!gradOutput->is_empty() && THCTensor_(nDimension)(state, gradOutput) == 4);
THAssert(!gradWeight->is_empty() && THCTensor_(nDimension)(state, gradWeight) == 4);
THAssert(!input->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, input) == 4);
THAssert(!gradOutput->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, gradOutput) == 4);
THAssert(!gradWeight->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, gradWeight) == 4);

// Minimal shape checking as above
// Same # of elements in batch

@ -180,7 +180,7 @@ void THNN_(SpatialDilatedMaxPooling_updateGradInput)(
int64_t nInputCols, nInputRows, nInputPlane, batchSize;
int64_t nOutputCols, nOutputRows;

if (input->_dim() == 3) {
if (THTensor_nDimensionLegacyAll(input) == 3) {
nInputCols = input->size(2);
nInputRows = input->size(1);
nInputPlane = input->size(0);

@ -16,7 +16,7 @@ void THNN_(SpatialFractionalMaxPooling_updateOutput)(
int dimw = 2;
int64_t numBatch = 1;

int numInputDims = THCTensor_(nDimension)(state, input);
int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
THCUNN_argCheck(state, !input->is_empty() && (numInputDims == 3 || numInputDims == 4), 2, input,
"non-empty 3D or 4D (batch mode) tensor expected for input, but got: %s");
@ -106,7 +106,7 @@ void THNN_(SpatialFractionalMaxPooling_updateGradInput)(
int dimh = 1;
int dimw = 2;

int64_t numInputDims = THCTensor_(nDimension)(state, input);
int64_t numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
if (numInputDims == 4) {
dimh++;
dimw++;

@ -7,9 +7,9 @@ static inline void THNN_(SpatialGridSamplerBilinear_shapeCheck)(
THCTensor *input,
THCTensor *grid,
THCTensor *gradOutput) {
THCUNN_argCheck(state, !input->is_empty() && THCTensor_(nDimension)(state, input) == 4, 2, input,
THCUNN_argCheck(state, !input->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, input) == 4, 2, input,
"non-empty 4D input tensor expected but got: %s");
THCUNN_argCheck(state, !grid->is_empty() && THCTensor_(nDimension)(state, grid) == 4, 2, grid,
THCUNN_argCheck(state, !grid->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, grid) == 4, 2, grid,
"4D grid tensor expected but got: %s");

int64_t nbatch = THCTensor_(size)(state, input, 0);

@ -15,7 +15,7 @@ void THNN_(SpatialReflectionPadding_updateOutput)(THCState *state,
int dimw = 2;
int numBatch = 1;

int numInputDims = THCTensor_(nDimension)(state, input);
int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
THCUNN_argCheck(state, !input->is_empty() && (numInputDims == 3 || numInputDims == 4), 2, input,
"non-empty 3D or 4D (batch mode) tensor expected for input, but got: %s")
@ -91,7 +91,7 @@ void THNN_(SpatialReflectionPadding_updateGradInput)(
int dimh = 1;
int dimw = 2;

int numInputDims = THCTensor_(nDimension)(state, input);
int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
if (numInputDims == 4) {
planeDim++;
dimh++;
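The padding and pooling call sites in this stretch all share the same batch-dimension bookkeeping: a 3D input is one sample, a 4D input is a batch, and in batch mode every spatial index shifts up by one. A minimal sketch of that idiom, using a hypothetical plain-C stand-in for the tensor rather than the TH API:

```c
#include <assert.h>
#include <stdio.h>

/* Stand-in for a tensor: just its number of dimensions and sizes. */
typedef struct { int ndim; long size[5]; } Tensor;

/* Mirrors the dim-shift idiom used by the spatial padding kernels:
   3D input = one sample, 4D input = batch of samples. */
static void spatial_dims(const Tensor *input,
                         int *planeDim, int *dimh, int *dimw,
                         long *numBatch) {
  *planeDim = 0; *dimh = 1; *dimw = 2; *numBatch = 1;
  assert(input->ndim == 3 || input->ndim == 4);
  if (input->ndim == 4) {            /* batch mode: shift every index */
    *numBatch = input->size[0];
    (*planeDim)++; (*dimh)++; (*dimw)++;
  }
}

int main(void) {
  Tensor batched = { 4, {8, 3, 32, 32, 0} };
  int p, h, w; long n;
  spatial_dims(&batched, &p, &h, &w, &n);
  printf("batch=%ld planes=%ld h=%ld w=%ld\n",
         n, batched.size[p], batched.size[h], batched.size[w]);
  return 0;
}
```

The volumetric kernels below use the same trick with a `fiveDimensionalInput` flag and an extra depth index.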
@ -16,7 +16,7 @@ void THNN_(SpatialReplicationPadding_updateOutput)(
int dimw = 2;
int numBatch = 1;

int numInputDims = THCTensor_(nDimension)(state, input);
int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
THCUNN_argCheck(state, !input->is_empty() && (numInputDims == 3 || numInputDims == 4), 2, input,
"non-empty 3D or 4D (batch mode) tensor expected for input, but got: %s")
@ -81,7 +81,7 @@ void THNN_(SpatialReplicationPadding_updateGradInput)(
int dimh = 1;
int dimw = 2;

int numInputDims = THCTensor_(nDimension)(state, input);
int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
if (numInputDims == 4) {
planeDim++;
dimh++;

@ -16,7 +16,7 @@ static inline void THNN_(SpatialUpSamplingNearest_shapeCheck)
" but got input (H: %d, W: %d) output (H: %d, W: %d)",
inputHeight, inputWidth, outputHeight, outputWidth);
if (input != NULL) {
THCUNN_argCheck(state, input->_dim() == 4, 2, input,
THCUNN_argCheck(state, THTensor_nDimensionLegacyAll(input) == 4, 2, input,
"4D input tensor expected but got: %s");
}

@ -13,7 +13,7 @@ void THNN_(TemporalReflectionPadding_updateOutput)(THCState *state,
int dimw = 1;
int numBatch = 1;

int numInputDims = THCTensor_(nDimension)(state, input);
int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
THCUNN_argCheck(state, !input->is_empty() && (numInputDims == 2 || numInputDims == 3), 2, input,
"non-empty 2D or 3D (batch mode) tensor expected for input, but got: %s")
@ -79,7 +79,7 @@ void THNN_(TemporalReflectionPadding_updateGradInput)(
int planeDim = 0;
int dimw = 1;

int numInputDims = THCTensor_(nDimension)(state, input);
int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
if (numInputDims == 3) {
planeDim++;
dimw++;

@ -14,7 +14,7 @@ void THNN_(TemporalReplicationPadding_updateOutput)(
int dimw = 1;
int numBatch = 1;

int numInputDims = THCTensor_(nDimension)(state, input);
int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
THCUNN_argCheck(state, !input->is_empty() && (numInputDims == 2 || numInputDims == 3), 2, input,
"2D or 3D (batch mode) tensor expected for input, but got: %s")
@ -74,7 +74,7 @@ void THNN_(TemporalReplicationPadding_updateGradInput)(
int planeDim = 0;
int dimw = 1;

int numInputDims = THCTensor_(nDimension)(state, input);
int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
if (numInputDims == 3) {
planeDim++;
dimw++;

@ -15,7 +15,7 @@ static inline void THNN_(TemporalUpSamplingNearest_shapeCheck)
" but got input (W: %d) output (W: %d)",
inputWidth, outputWidth);
if (input != NULL) {
THCUNN_argCheck(state, input->_dim() == 3, 2, input,
THCUNN_argCheck(state, THTensor_nDimensionLegacyAll(input) == 3, 2, input,
"3D input tensor expected but got: %s");
}

@ -30,7 +30,7 @@ static inline void THNN_(VolumetricAveragePooling_shapeCheck)(
dimw++;
}

if (!input->is_empty() && THCTensor_(nDimension)(state, input) == 4)
if (!input->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, input) == 4)
{
THArgCheck(input->size(dimw) >= kW && input->size(dimh) >= kH
&& input->size(dimt) >= kT, 2,
@ -45,7 +45,7 @@ static inline void THNN_(VolumetricAveragePooling_shapeCheck)(
inputHeight = THCTensor_(size)(state, input, 2);
inputWidth = THCTensor_(size)(state, input, 3);
}
else if (!input->is_empty() && THCTensor_(nDimension)(state, input) == 5)
else if (!input->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, input) == 5)
{
THArgCheck(input->size(dimw) >= kW && input->size(dimh) >= kH
&& input->size(dimt) >= kT, 2,
@ -128,7 +128,7 @@ void THNN_(VolumetricAveragePooling_updateOutput)(
int dimh = 2;
int dimw = 3;

int fiveDimensionalInput = THCTensor_(nDimension)(state, input) == 5;
int fiveDimensionalInput = THCTensor_(nDimensionLegacyNoScalars)(state, input) == 5;
if (fiveDimensionalInput)
{
dimt++;
@ -284,7 +284,7 @@ void THNN_(VolumetricAveragePooling_updateGradInput)(
int outputHeight;
int outputWidth;

int fiveDimensionalInput = THCTensor_(nDimension)(state, input) == 5;
int fiveDimensionalInput = THCTensor_(nDimensionLegacyNoScalars)(state, input) == 5;
if (!fiveDimensionalInput) /* 4D */
{
batchSize = 1;

@ -51,7 +51,7 @@ static inline void THNN_(VolumetricDilatedMaxPooling_shapeCheck)(
dimw++;
}

if (THCTensor_(nDimension)(state, input) == 4)
if (THCTensor_(nDimensionLegacyNoScalars)(state, input) == 4)
{
/* sizes */
inputSlices = THCTensor_(size)(state, input, 0);
@ -59,7 +59,7 @@ static inline void THNN_(VolumetricDilatedMaxPooling_shapeCheck)(
inputHeight = THCTensor_(size)(state, input, 2);
inputWidth = THCTensor_(size)(state, input, 3);
}
else if (THCTensor_(nDimension)(state, input) == 5)
else if (THCTensor_(nDimensionLegacyNoScalars)(state, input) == 5)
{
/* sizes */
inputSlices = THCTensor_(size)(state, input, 1);
@ -142,7 +142,7 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
int dimh = 2;
int dimw = 3;

int fiveDimensionalInput = THCTensor_(nDimension)(state, input) == 5;
int fiveDimensionalInput = THCTensor_(nDimensionLegacyNoScalars)(state, input) == 5;

if (fiveDimensionalInput)
{
@ -157,7 +157,7 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
dT, dW, dH, padT, padW, padH,
dilationT, dilationW, dilationH, ceilMode);

if (THCTensor_(nDimension)(state, input) == 4)
if (THCTensor_(nDimensionLegacyNoScalars)(state, input) == 4)
{
/* sizes */
batchSize = 1;
@ -316,7 +316,7 @@ void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
int outputTime, outputHeight, outputWidth;
int inputTime, inputHeight, inputWidth;

int fiveDimensionalInput = THCTensor_(nDimension)(state, input) == 5;
int fiveDimensionalInput = THCTensor_(nDimensionLegacyNoScalars)(state, input) == 5;

THCUNN_assertSameGPU(state, 4, input, indices, gradOutput, gradInput);
THNN_(VolumetricDilatedMaxPooling_shapeCheck)(

@ -17,7 +17,7 @@ void THNN_(VolumetricFractionalMaxPooling_updateOutput)(
int dimt = 3;
int64_t numBatch = 1;

int64_t numInputDims = THCTensor_(nDimension)(state, input);
int64_t numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
THCUNN_argCheck(state, !input->is_empty() && (numInputDims == 4 || numInputDims == 5), 2, input,
"non-empty 4D or 5D (batch mode) tensor expected for input, but got: %s");
@ -113,7 +113,7 @@ void THNN_(VolumetricFractionalMaxPooling_updateGradInput)(
int dimw = 2;
int dimt = 3;

int64_t numInputDims = THCTensor_(nDimension)(state, input);
int64_t numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
if (numInputDims == 5) {
dimh++;
dimw++;

@ -7,9 +7,9 @@ static inline void THNN_(VolumetricGridSamplerBilinear_shapeCheck)(
THCTensor *input,
THCTensor *grid,
THCTensor *gradOutput) {
THCUNN_argCheck(state, !input->is_empty() && THCTensor_(nDimension)(state, input) == 5, 2, input,
THCUNN_argCheck(state, !input->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, input) == 5, 2, input,
"non-empty 5D input tensor expected but got: %s");
THCUNN_argCheck(state, !grid->is_empty() && THCTensor_(nDimension)(state, grid) == 5, 2, grid,
THCUNN_argCheck(state, !grid->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, grid) == 5, 2, grid,
"non-empty 5D grid tensor expected but got: %s");

int64_t nbatch = THCTensor_(size)(state, input, 0);

@ -24,11 +24,11 @@ static inline void THNN_(VolumetricMaxUnpooling_shapeCheck)(
"stride should be greater than zero, but got dT: %d dH: %d dW: %d",
dT, dH, dW);

if (THCTensor_(nDimension)(state, input) == 4)
if (THCTensor_(nDimensionLegacyNoScalars)(state, input) == 4)
{
inputSlices = THCTensor_(size)(state, input, 0);
}
else if (THCTensor_(nDimension)(state, input) == 5)
else if (THCTensor_(nDimensionLegacyNoScalars)(state, input) == 5)
{
inputSlices = THCTensor_(size)(state, input, 1);
}
@ -83,8 +83,8 @@ void THNN_(VolumetricMaxUnpooling_updateOutput)(
dT, dW, dH, padT, padW, padH);
THCUNN_assertSameGPU(state, 3, input, indices, output);

int fiveDimensionalInput = THCTensor_(nDimension)(state, input) == 5;
if (THCTensor_(nDimension)(state, input) == 4)
int fiveDimensionalInput = THCTensor_(nDimensionLegacyNoScalars)(state, input) == 5;
if (THCTensor_(nDimensionLegacyNoScalars)(state, input) == 4)
{
/* sizes */
batchSize = 1;
@ -192,7 +192,7 @@ void THNN_(VolumetricMaxUnpooling_updateGradInput)(
dT, dW, dH, padT, padW, padH);
THCUNN_assertSameGPU(state, 4, input, indices, gradOutput, gradInput);

int fiveDimensionalInput = THCTensor_(nDimension)(state, input) == 5;
int fiveDimensionalInput = THCTensor_(nDimensionLegacyNoScalars)(state, input) == 5;
if (!fiveDimensionalInput) /* 4D */
{
batchSize = 1;

@ -11,7 +11,7 @@ static inline void THNN_(VolumetricReplicationPadding_shapeCheck)(
int pfront, int pback) {
THArgCheck(THCTensor_canUse32BitIndexMath(state, input), 2,
"input tensor must fit into 32-bit index math");
int numInputDims = THCTensor_(nDimension)(state, input);
int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);

THCUNN_argCheck(state, !input->is_empty() && (numInputDims == 4 || numInputDims == 5), 2, input,
"non-empty 4D or 5D (batch mode) tensor expected for input, but got: %s");
@ -75,7 +75,7 @@ void THNN_(VolumetricReplicationPadding_updateOutput)(
int dimw = 3;
int numBatch = 1;

int numInputDims = THCTensor_(nDimension)(state, input);
int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);

if (numInputDims == 5) {
numBatch = THCTensor_(size)(state, input, 0);
@ -137,7 +137,7 @@ void THNN_(VolumetricReplicationPadding_updateGradInput)(
int dimh = 2;
int dimw = 3;

int numInputDims = THCTensor_(nDimension)(state, input);
int numInputDims = THCTensor_(nDimensionLegacyNoScalars)(state, input);
if (numInputDims == 5) {
planeDim++;
dimd++;

@ -16,7 +16,7 @@ static inline void THNN_(VolumetricUpSamplingNearest_shapeCheck)
" but got input (D: %d, H: %d, W: %d) output (D: %d, H: %d, W: %d)",
inputDepth, inputHeight, inputWidth, outputDepth, outputHeight, outputWidth);
if (input != NULL) {
THCUNN_argCheck(state, input->_dim() == 5, 2, input,
THCUNN_argCheck(state, THTensor_nDimensionLegacyAll(input) == 5, 2, input,
"5D input tensor expected but got: %s");
}
@ -13,14 +13,14 @@ void THNN_(ClassNLLCriterion_updateOutput)(
int64_t ignore_index)
{
THTensor_(resize1d)(total_weight, 1);
int n_dims = THTensor_(_nDimension)(input);
int n_dims = THTensor_(nDimensionLegacyAll)(input);
int n_classes = THTensor_(size)(input, n_dims - 1);
ignore_index -= TH_INDEX_BASE;

if (THIndexTensor_(_nDimension)(target) > 1) {
if (THIndexTensor_(nDimensionLegacyAll)(target) > 1) {
THError("multi-target not supported");
}
if (THTensor_(_nDimension)(input) > 2) {
if (THTensor_(nDimensionLegacyAll)(input) > 2) {
THError("input tensor should be 1D or 2D");
}
if (weights && THTensor_(nElement)(weights) != n_classes) {
@ -73,14 +73,14 @@ void THNN_(ClassNLLCriterion_updateOutput)(

output_data[0] = total_weight_data[0] = 0.0;

if (THTensor_(_nDimension)(input) == 1) {
if (THTensor_(nDimensionLegacyAll)(input) == 1) {
int cur_target = target_data[0] - TH_INDEX_BASE;
if (cur_target != ignore_index) {
THAssert(cur_target >= 0 && cur_target < n_classes);
total_weight_data[0] = weights ? weights_data[cur_target] : 1.0f;
output_data[0] = -input_data[cur_target] * total_weight_data[0];
}
} else if (THTensor_(_nDimension)(input) == 2) {
} else if (THTensor_(nDimensionLegacyAll)(input) == 2) {
int batch_size = THTensor_(size)(input, 0);
THAssert(THIndexTensor_(size)(target, 0) == batch_size);

@ -124,7 +124,7 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
THTensor_(resizeAs)(gradInput, input);
THTensor_(zero)(gradInput);

int n_dims = THTensor_(_nDimension)(input);
int n_dims = THTensor_(nDimensionLegacyAll)(input);
int n_classes = THTensor_(size)(input, n_dims - 1);
ignore_index -= TH_INDEX_BASE;

@ -132,11 +132,11 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
THError("gradInput must be contiguous");
}

if (THIndexTensor_(_nDimension)(target) > 1) {
if (THIndexTensor_(nDimensionLegacyAll)(target) > 1) {
THError("multi-target not supported");
}

if (THTensor_(_nDimension)(input) > 2) {
if (THTensor_(nDimensionLegacyAll)(input) > 2) {
THError("input tensor should be 1D or 2D");
}

@ -177,7 +177,7 @@ void THNN_(ClassNLLCriterion_updateGradInput)(

real gradOutput_value = THTensor_(get1d)(gradOutput, 0);

if (THTensor_(_nDimension)(input) == 1) {
if (THTensor_(nDimensionLegacyAll)(input) == 1) {
int cur_target = target_data[0] - TH_INDEX_BASE;
if (cur_target != ignore_index) {
THAssert(cur_target >= 0 && cur_target < n_classes);
@ -187,7 +187,7 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
gradInput_data[cur_target] *= gradOutput_value;
}

} else if (THTensor_(_nDimension)(input) == 2) {
} else if (THTensor_(nDimensionLegacyAll)(input) == 2) {
int batch_size = THTensor_(size)(input, 0);
THAssert(THIndexTensor_(size)(target, 0) == batch_size);
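In the 1D branch above, the loss is just the negated log-probability of the target class, scaled by its class weight when weights are given. A stand-alone sketch of that arithmetic (plain C, illustrative names; TH_INDEX_BASE and ignore_index handling reduced to a bounds check):

```c
#include <stdio.h>

/* Mirrors the 1D branch above: loss = -input[target] * weight[target],
   with total_weight tracking the weight actually applied. */
static void nll_1d(const float *input, const float *weights,
                   int target, int n_classes,
                   float *output, float *total_weight) {
  *output = 0.0f;
  *total_weight = 0.0f;
  if (target >= 0 && target < n_classes) {
    *total_weight = weights ? weights[target] : 1.0f;
    *output = -input[target] * *total_weight;
  }
}

int main(void) {
  float log_probs[3] = { -0.2f, -1.9f, -2.3f };
  float out, tw;
  nll_1d(log_probs, NULL, 0, 3, &out, &tw);
  printf("loss=%f total_weight=%f\n", out, tw);  /* loss=0.2 */
  return 0;
}
```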
@ -124,7 +124,7 @@ static inline void THNN_(Col2Im_shapeCheck)(
THArgCheck(dW > 0 && dH > 0, 8,
"dilation should be greater than zero, but got dH: %d dW: %d", dH, dW);

int64_t ndim = THTensor_(nDimension)(input);
int64_t ndim = THTensor_(nDimensionLegacyNoScalars)(input);
THNN_ARGCHECK(!input->is_empty() && (ndim == 2 || ndim == 3), 2, input,
"Expected non-empty 2D or 3D input tensor, but got input of shape %s");

@ -39,7 +39,7 @@ static inline size_t flpOutputSize(FEATURE_LP_SIZE_TYPE inputSize,

FeatureLPPoolingSizes
THNN_(FeatureLPPooling_upcastCPU)(THTensor* t, bool batchMode) {
int dim = THTensor_(_nDimension)(t);
int dim = THTensor_(nDimensionLegacyAll)(t);

// Upcast to [batch dim][feature dim][opt dim 1][opt dim 2]
FeatureLPPoolingSizes s;
@ -99,7 +99,7 @@ THNN_(FeatureLPPooling_resizeForOutputCPU)(THTensor* toResize,
bool batchMode,
int width,
int stride) {
int inputDim = THTensor_(_nDimension)(input);
int inputDim = THTensor_(nDimensionLegacyAll)(input);
THAssert(inputDim >= 1 && inputDim <= 4);

int64_t outSize =
@ -147,7 +147,7 @@ THNN_(FeatureLPPooling_resizeForOutputCPU)(THTensor* toResize,
void
THNN_(FeatureLPPooling_resizeCPU)(THTensor* toResize,
THTensor* src) {
int inputDim = THTensor_(_nDimension)(src);
int inputDim = THTensor_(nDimensionLegacyAll)(src);
THAssert(inputDim >= 1 && inputDim <= 4);

if (inputDim == 1) {
@ -183,7 +183,7 @@ THNN_(FeatureLPPooling_updateOutput)(
int width,
int stride,
bool batchMode) {
int inputDim = THTensor_(_nDimension)(input);
int inputDim = THTensor_(nDimensionLegacyAll)(input);

if (batchMode) {
THArgCheck(inputDim >= 2 && inputDim <= 4, 2,
@ -261,7 +261,7 @@ THNN_(FeatureLPPooling_updateGradInput)(
int width,
int stride,
bool batchMode) {
int inputDim = THTensor_(_nDimension)(input);
int inputDim = THTensor_(nDimensionLegacyAll)(input);

if (batchMode) {
THArgCheck(inputDim >= 2 && inputDim <= 4, 3,

@ -17,7 +17,7 @@ void THNN_(HardTanh_updateOutput)(
else
THTensor_(resizeAs)(output, input);

if (input->_dim() == 1 || !THTensor_(isContiguous)(input) || !THTensor_(isContiguous)(output))
if (THTensor_nDimensionLegacyAll(input) == 1 || !THTensor_(isContiguous)(input) || !THTensor_(isContiguous)(output))
{
if (inplace)
{
@ -88,7 +88,7 @@ void THNN_(HardTanh_updateGradInput)(
else
THTensor_(resizeAs)(gradInput, input);

if (input->_dim() == 1 ||
if (THTensor_nDimensionLegacyAll(input) == 1 ||
!THTensor_(isContiguous)(input) ||
!THTensor_(isContiguous)(gradOutput) ||
!THTensor_(isContiguous)(gradInput))

@ -16,7 +16,7 @@ static inline void THNN_(Im2Col_shapeCheck)(
THArgCheck(sW > 0 && sH > 0, 10,
"stride should be greater than zero, but got sH: %d sW: %d", sH, sW);

int64_t ndim = THTensor_(nDimension)(input);
int64_t ndim = THTensor_(nDimensionLegacyNoScalars)(input);
THNN_ARGCHECK(!input->is_empty() && (ndim == 3 || ndim == 4), 2, input,
"Expected non-empty 3D or 4D input tensor, but got input of shape %s");

@ -24,8 +24,8 @@
static bool THNN_(checkKeysValues)(THLongTensor* keys, THTensor* values)
{
return THLongTensor_size(keys, 0) == THTensor_(nElement)(values)
&& THTensor_(_nDimension)(values) == 1
&& THLongTensor__nDimension(keys) == 1;
&& THTensor_(nDimensionLegacyAll)(values) == 1
&& THLongTensor_nDimensionLegacyAll(keys) == 1;
}

void THNN_(IndexLinear_updateOutput)(

@ -23,7 +23,7 @@ void THNN_(Linear_updateOutput)(
THTensor *bias,
THTensor *addBuffer)
{
int64_t dim = THTensor_(_nDimension)(input);
int64_t dim = THTensor_(nDimensionLegacyAll)(input);
if (dim == 1) {
THTensor_(resize1d)(output,THTensor_(size)(weight,0));
if (bias) {
@ -66,7 +66,7 @@ void THNN_(Linear_updateGradInput)(
THTensor_(zero)(gradInput);
}

int64_t dim = THTensor_(_nDimension)(input);
int64_t dim = THTensor_(nDimensionLegacyAll)(input);
if (dim == 1) {
THTensor *tweight = THTensor_(new)();
THTensor_(transpose)(tweight,weight,0,1);
@ -92,7 +92,7 @@ void THNN_(Linear_accGradParameters)(
accreal scale_)
{
real scale = TH_CONVERT_ACCREAL_TO_REAL(scale_);
int64_t dim = THTensor_(_nDimension)(input);
int64_t dim = THTensor_(nDimensionLegacyAll)(input);
if (dim == 1) {
THTensor_(addr)(gradWeight,1,gradWeight,scale,gradOutput,input);
if (bias) {
@ -48,7 +48,7 @@ void THNN_(LookupTable_accGradParameters)(
THError("gradWeight must be contiguous");
if (!THIndexTensor_(isContiguous)(input))
THError("input must be contiguous");
if (input->is_empty() || (THIndexTensor_(nDimension)(input) != 1 && THIndexTensor_(nDimension)(input) != 2)) {
if (input->is_empty() || (THIndexTensor_(nDimensionLegacyNoScalars)(input) != 1 && THIndexTensor_(nDimensionLegacyNoScalars)(input) != 2)) {
THDescBuff s1 = THIndexTensor_(sizeDesc)(input);
THError("input must be a non-empty vector or matrix, but is of shape: %s", s1.str);
}
@ -173,7 +173,7 @@ void THNN_(LookupTable_renorm)(
THError("weight must be contiguous");
if (!THIndexTensor_(isContiguous)(idx))
THError("input must be contiguous");
if (idx->is_empty() || THIndexTensor_(nDimension)(idx) != 1)
if (idx->is_empty() || THIndexTensor_(nDimensionLegacyNoScalars)(idx) != 1)
THError("idx must be a non-empty vector");
if (normType <= 0)
THError("non-positive-norm not supported");

@ -25,7 +25,7 @@ void THNN_(PReLU_updateOutput)(
input = THTensor_(newContiguous)(input);
int64_t bs = 1, ks = 1;
{
int64_t input_ndim = THTensor_(_nDimension)(input);
int64_t input_ndim = THTensor_(nDimensionLegacyAll)(input);
if (input->size(input_ndim > 1) != nOutputPlane)
THError("Wrong number of input planes. Expected %d but got %d.", nOutputPlane, input->size(input_ndim > 1));

@ -90,7 +90,7 @@ void THNN_(PReLU_updateGradInput)(

int64_t bs = 1, ks = 1;
{
int64_t input_ndim = THTensor_(_nDimension)(input);
int64_t input_ndim = THTensor_(nDimensionLegacyAll)(input);
if (input->size(input_ndim > 1) != nOutputPlane)
THError("Wrong number of input planes. Expected %d but got %d.", nOutputPlane, input->size(input_ndim > 1));

@ -161,7 +161,7 @@ void THNN_(PReLU_accGradParameters)(
weight = THTensor_(newContiguous)(weight);
int64_t bs = 1, ks = 1;
{
int64_t input_ndim = THTensor_(_nDimension)(input);
int64_t input_ndim = THTensor_(nDimensionLegacyAll)(input);
if (input->size(input_ndim > 1) != nOutputPlane)
THError("Wrong number of input planes. Expected %d but got %d.", nOutputPlane, input->size(input_ndim > 1));
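The `input->size(input_ndim > 1)` expression in the PReLU hunks is a compact channel-dimension lookup: the comparison yields 0 for 1D inputs and 1 otherwise, which is exactly the index of the feature/plane axis. A tiny stand-alone sketch of the same trick (plain C, hypothetical names):

```c
#include <stdio.h>

/* The channel axis is 0 for vectors and 1 once a batch dim is present;
   (ndim > 1) evaluates to exactly that 0/1 index. */
static long channel_size(const long *size, long ndim) {
  return size[ndim > 1];
}

int main(void) {
  long vec[1]   = {16};         /* 1D: channels at index 0 */
  long batch[3] = {8, 16, 32};  /* 3D: channels at index 1 */
  printf("%ld %ld\n", channel_size(vec, 1), channel_size(batch, 3));
  return 0;
}
```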
@ -3,13 +3,13 @@
#else

#define INITIAL_CHECK \
THArgCheck(THIndexTensor_(_nDimension)(target) == 3, 3, \
THArgCheck(THIndexTensor_(nDimensionLegacyAll)(target) == 3, 3, \
"only batches of spatial targets supported (3D tensors)" \
" but got targets of dimension: %d", \
THIndexTensor_(_nDimension)(target)); \
THArgCheck(THTensor_(_nDimension)(input) == 4, 2, \
THIndexTensor_(nDimensionLegacyAll)(target)); \
THArgCheck(THTensor_(nDimensionLegacyAll)(input) == 4, 2, \
"only batches of spatial inputs supported (4D tensors), " \
"but got input of dimension: %d", THTensor_(_nDimension)(input)); \
"but got input of dimension: %d", THTensor_(nDimensionLegacyAll)(input)); \
if (weights && THTensor_(nElement)(weights) != THTensor_(size)(input, 1)) { \
THError("weight tensor should be defined either for all or no classes"); \
} \
@ -28,10 +28,10 @@
}

#define GRADOUTPUT_SHAPE_CHECK \
THArgCheck(THTensor_(_nDimension)(gradOutput) == 3, 3, \
THArgCheck(THTensor_(nDimensionLegacyAll)(gradOutput) == 3, 3, \
"gradOutput must have same dimension as target (3)" \
" but got dimension: %d", \
THTensor_(_nDimension)(gradOutput)); \
THTensor_(nDimensionLegacyAll)(gradOutput)); \
{ \
int64_t gradOutput0 = THTensor_(size)(gradOutput, 0); \
int64_t gradOutput1 = THTensor_(size)(gradOutput, 1); \

@ -102,7 +102,7 @@ void THNN_(SpatialFractionalMaxPooling_updateOutput)(
int heightDim = 1;
int widthDim = 2;

int64_t numInputDims = THTensor_(nDimension)(input);
int64_t numInputDims = THTensor_(nDimensionLegacyNoScalars)(input);
THNN_ARGCHECK(!input->is_empty() && (numInputDims == 3 || numInputDims == 4), 2, input,
"non-empty 3D or 4D (batch mode) tensor expected for input, but got: %s");
@ -202,7 +202,7 @@ void THNN_(SpatialFractionalMaxPooling_updateGradInput)(
int heightDim = 1;
int widthDim = 2;

int64_t numInputDims = THTensor_(nDimension)(input);
int64_t numInputDims = THTensor_(nDimensionLegacyNoScalars)(input);
if (numInputDims == 4) {
numBatch = THTensor_(size)(input, 0);
planeDim = 1;

@ -15,7 +15,7 @@ static inline void THNN_(SpatialUpSamplingNearest_shapeCheck)
" but got input (H: %d, W: %d) output (H: %d, W: %d)",
inputHeight, inputWidth, outputHeight, outputWidth);
if (input != NULL) {
THNN_ARGCHECK(input->_dim() == 4, 2, input,
THNN_ARGCHECK(THTensor_nDimensionLegacyAll(input) == 4, 2, input,
"4D input tensor expected but got: %s");
}

@ -22,7 +22,7 @@ void THNN_(Sqrt_updateGradInput)(
THNN_CHECK_SHAPE(output, gradOutput);
THTensor_(resizeAs)(gradInput, input);

if (output->_dim() == 1 ||
if (THTensor_nDimensionLegacyAll(output) == 1 ||
!THTensor_(isContiguous)(output) ||
!THTensor_(isContiguous)(gradOutput) ||
!THTensor_(isContiguous)(gradInput))

@ -9,7 +9,7 @@ void THNN_(Square_updateOutput)(
{
THTensor_(resizeAs)(output, input);

if (input->_dim() == 1 || !THTensor_(isContiguous)(input) || !THTensor_(isContiguous)(output))
if (THTensor_nDimensionLegacyAll(input) == 1 || !THTensor_(isContiguous)(input) || !THTensor_(isContiguous)(output))
{
TH_TENSOR_APPLY2(real, output, real, input,
*output_data = (*input_data) * (*input_data);
@ -35,7 +35,7 @@ void THNN_(Square_updateGradInput)(
THNN_CHECK_SHAPE(input, gradOutput);
THTensor_(resizeAs)(gradInput, input);

if (input->_dim() == 1 ||
if (THTensor_nDimensionLegacyAll(input) == 1 ||
!THTensor_(isContiguous)(input) ||
!THTensor_(isContiguous)(gradOutput) ||
!THTensor_(isContiguous)(gradInput))

@ -19,7 +19,7 @@ void THNN_(Tanh_updateGradInput)(
THNN_CHECK_SHAPE(output, gradOutput);
THTensor_(resizeAs)(gradInput, output);

if (output->_dim() == 1 ||
if (THTensor_nDimensionLegacyAll(output) == 1 ||
!THTensor_(isContiguous)(output) ||
!THTensor_(isContiguous)(gradOutput) ||
!THTensor_(isContiguous)(gradInput))

@ -13,7 +13,7 @@ static inline void THNN_(TemporalUpSamplingNearest_shapeCheck)
" but got input (W: %d) output (W: %d)",
inputWidth, outputWidth);
if (input != NULL) {
THNN_ARGCHECK(input->_dim() == 3, 2, input,
THNN_ARGCHECK(THTensor_nDimensionLegacyAll(input) == 3, 2, input,
"3D input tensor expected but got: %s");
}

@ -114,7 +114,7 @@ void THNN_(VolumetricFractionalMaxPooling_updateOutput)(
int widthDim = 2;
int timeDim = 3;

int64_t numInputDims = THTensor_(nDimension)(input);
int64_t numInputDims = THTensor_(nDimensionLegacyNoScalars)(input);
THNN_ARGCHECK(!input->is_empty() && (numInputDims == 4 || numInputDims == 5), 2, input,
"non-empty 4D or 5D (batch mode) tensor expected for input, but got: %s");
@ -224,7 +224,7 @@ void THNN_(VolumetricFractionalMaxPooling_updateGradInput)(
int widthDim = 2;
int timeDim = 3;

int64_t numInputDims = THTensor_(nDimension)(input);
int64_t numInputDims = THTensor_(nDimensionLegacyNoScalars)(input);
if (numInputDims == 5) {
numBatch = THTensor_(size)(input, 0);
planeDim = 1;

@ -15,7 +15,7 @@ static inline void THNN_(VolumetricUpSamplingNearest_shapeCheck)
" but got input (D: %d, H: %d, W: %d) output (D: %d, H: %d, W: %d)",
inputDepth, inputHeight, inputWidth, outputDepth, outputHeight, outputWidth);
if (input != NULL) {
THNN_ARGCHECK(input->_dim() == 5, 2, input,
THNN_ARGCHECK(THTensor_nDimensionLegacyAll(input) == 5, 2, input,
"5D input tensor expected but got: %s");
}

@ -44,7 +44,7 @@
}

#define THNN_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE) \
if (THTensor_(nDimension)(T) != DIM || \
if (THTensor_(nDimensionLegacyNoScalars)(T) != DIM || \
THTensor_(size)(T, DIM_SIZE) != SIZE) { \
THDescBuff s1 = THTensor_(sizeDesc)(T); \
THError("Need " #T " of dimension %d and " #T ".size[%d] == %d" \
@ -52,7 +52,7 @@
}

#define THNN_CHECK_DIM_SIZE_INDICES(T, DIM, DIM_SIZE, SIZE) \
if (THIndexTensor_(nDimension)(T) != DIM || \
if (THIndexTensor_(nDimensionLegacyNoScalars)(T) != DIM || \
THIndexTensor_(size)(T, DIM_SIZE) != SIZE) { \
THDescBuff s1 = THIndexTensor_(sizeDesc)(T); \
THError("Need " #T " of dimension %d and " #T ".size[%d] == %d" \
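These THNN_CHECK_DIM_SIZE macros fold a dimensionality check and one size check into a single guard that names the offending argument via the `#` stringizer. A self-contained sketch of the same pattern with plain-C stand-ins (illustrative names, not the TH API):

```c
#include <stdio.h>
#include <stdlib.h>

/* Hypothetical stand-in for a tensor's metadata; not the TH API. */
typedef struct { int ndim; long size[4]; } Tensor;

/* Same shape-guard pattern as THNN_CHECK_DIM_SIZE: one macro checks
   both the dimensionality and a single size constraint, and reports
   the offending argument by name. */
#define CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE)                       \
  if ((T)->ndim != (DIM) || (T)->size[(DIM_SIZE)] != (SIZE)) {       \
    fprintf(stderr, "Need " #T " of dimension %d and " #T            \
            ".size[%d] == %ld, got dimension %d\n",                  \
            (DIM), (DIM_SIZE), (long)(SIZE), (T)->ndim);             \
    exit(1);                                                         \
  }

int main(void) {
  Tensor weight = { 2, {10, 5, 0, 0} };
  CHECK_DIM_SIZE(&weight, 2, 1, 5);  /* passes: 2D with size[1] == 5 */
  puts("shape check passed");
  return 0;
}
```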
@ -5,9 +5,9 @@ from string import Template
class AssertNDim(CWrapPlugin):

PRE_CODE_TEMPLATE = Template(
"""if(THTensor_(nDimension)(LIBRARY_STATE ${arg_op}) != ${dim_value}) {
"""if(THTensor_(nDimensionLegacyNoScalars)(LIBRARY_STATE ${arg_op}) != ${dim_value}) {
THError("Expected argument %s to have %d dimension(s), but has %d",
"${op}", ${dim_value}, THTensor_(nDimension)(LIBRARY_STATE ${arg_op}));
"${op}", ${dim_value}, THTensor_(nDimensionLegacyNoScalars)(LIBRARY_STATE ${arg_op}));
}
""")
@ -37,19 +37,19 @@ ptrdiff_t THDTensor_(storageOffset)(const THDTensor *self) {
return self->storageOffset;
}

int THDTensor_(nDimension)(const THDTensor *self) {
int THDTensor_(nDimensionLegacyNoScalars)(const THDTensor *self) {
return self->nDimension;
}

int64_t THDTensor_(size)(const THDTensor *self, int dim) {
THArgCheck((dim >= 0) && (dim < self->nDimension), 2, "dimension %d out of range of %dD tensor",
dim+1, THDTensor_(nDimension)(self));
dim+1, THDTensor_(nDimensionLegacyNoScalars)(self));
return self->size[dim];
}

int64_t THDTensor_(stride)(const THDTensor *self, int dim) {
THArgCheck((dim >= 0) && (dim < self->nDimension), 2, "dimension %d out of range of %dD tensor", dim+1,
THDTensor_(nDimension)(self));
THDTensor_(nDimensionLegacyNoScalars)(self));
return self->stride[dim];
}

@ -23,7 +23,7 @@ THD_API THDDescBuff THDTensor_(sizeDesc)(const THDTensor *tensor);
/**** access methods ****/
THD_API THDStorage* THDTensor_(storage)(const THDTensor *self);
THD_API ptrdiff_t THDTensor_(storageOffset)(const THDTensor *self);
THD_API int THDTensor_(nDimension)(const THDTensor *self);
THD_API int THDTensor_(nDimensionLegacyNoScalars)(const THDTensor *self);
THD_API int64_t THDTensor_(size)(const THDTensor *self, int dim);
THD_API int64_t THDTensor_(stride)(const THDTensor *self, int dim);
THD_API THLongStorage *THDTensor_(newSizeOf)(THDTensor *self);

@ -8,7 +8,7 @@ using namespace master;

void THDTensor_(gather)(THDTensor *self, THDTensor *src, int dim, THDLongTensor *index) {
THArgCheck(dim < self->nDimension, 2, "Index dimension is out of bounds");
THArgCheck(THDLongTensor_nDimension(index) == self->nDimension, 3,
THArgCheck(THDLongTensor_nDimensionLegacyNoScalars(index) == self->nDimension, 3,
"Index tensor must have same dimensions as output tensor");
THArgCheck(src->nDimension == self->nDimension, 4,
"Input tensor must have same dimensions as output tensor");
@ -27,7 +27,7 @@ void THDTensor_(gather)(THDTensor *self, THDTensor *src, int dim, THDLongTensor

void THDTensor_(scatter)(THDTensor *self, int dim, THDLongTensor *index, THDTensor *src) {
THArgCheck(dim < self->nDimension, 2, "Index dimension is out of bounds");
THArgCheck(THDLongTensor_nDimension(index) == self->nDimension, 3,
THArgCheck(THDLongTensor_nDimensionLegacyNoScalars(index) == self->nDimension, 3,
"Index tensor must have same dimensions as output tensor");
THArgCheck(src->nDimension == self->nDimension, 4,
"Input tensor must have same dimensions as output tensor");
@ -46,7 +46,7 @@ void THDTensor_(scatter)(THDTensor *self, int dim, THDLongTensor *index, THDTens

void THDTensor_(scatterFill)(THDTensor *self, int dim, THDLongTensor *index, real val) {
THArgCheck(dim < self->nDimension, 2, "Index dimension is out of bounds");
THArgCheck(THDLongTensor_nDimension(index) == self->nDimension, 3,
THArgCheck(THDLongTensor_nDimensionLegacyNoScalars(index) == self->nDimension, 3,
"Index tensor must have same dimensions as output tensor");

masterCommandChannel->sendMessage(
@ -315,10 +315,10 @@ ptrdiff_t THDTensor_(numel)(THDTensor *t) {
}

void THDTensor_(diag)(THDTensor *r_, THDTensor *t, int k) {
THArgCheck(THDTensor_(nDimension)(t) == 1 || THDTensor_(nDimension)(t) == 2,
THArgCheck(THDTensor_(nDimensionLegacyNoScalars)(t) == 1 || THDTensor_(nDimensionLegacyNoScalars)(t) == 2,
1, "matrix or a vector expected");

if (THDTensor_(nDimension)(t) == 1) {
if (THDTensor_(nDimensionLegacyNoScalars)(t) == 1) {
int64_t t_size = THDTensor_(size)(t, 0);
int64_t sz = t_size + (k >= 0 ? k : -k);

@ -388,7 +388,7 @@ void THDTensor_(reshape)(THDTensor *r_, THDTensor *t, THLongStorage *size) {
void THDTensor_(sort)(THDTensor *rt_, THDLongTensor *ri_,
THDTensor *t, int dimension,
int descendingOrder) {
THArgCheck(dimension >= 0 && dimension < THDTensor_(nDimension)(t),
THArgCheck(dimension >= 0 && dimension < THDTensor_(nDimensionLegacyNoScalars)(t),
2, "invalid dimension %d", dimension + TH_INDEX_BASE);

THDTensor_(resizeAs)(rt_, t);
@ -409,7 +409,7 @@ void THDTensor_(sort)(THDTensor *rt_, THDLongTensor *ri_,
void THDTensor_(topk)(THDTensor *rt_, THDLongTensor *ri_,
THDTensor *t, int64_t k, int dim,
int dir, int sorted) {
int numDims = THDTensor_(nDimension)(t);
int numDims = THDTensor_(nDimensionLegacyNoScalars)(t);
THArgCheck(dim >= 0 && dim < numDims, 3, "dim not in range");

int64_t sliceSize = THDTensor_(size)(t, dim);
@ -428,7 +428,7 @@ void THDTensor_(topk)(THDTensor *rt_, THDLongTensor *ri_,
}

void THDTensor_(tril)(THDTensor *r_, THDTensor *t, int64_t k) {
THArgCheck(THDTensor_(nDimension)(t) == 2, 1, "expected a matrix");
THArgCheck(THDTensor_(nDimensionLegacyNoScalars)(t) == 2, 1, "expected a matrix");

THDTensor_(resizeAs)(r_, t);

@ -439,7 +439,7 @@ void THDTensor_(tril)(THDTensor *r_, THDTensor *t, int64_t k) {
}

void THDTensor_(triu)(THDTensor *r_, THDTensor *t, int64_t k) {
THArgCheck(THDTensor_(nDimension)(t) == 2, 1, "expected a matrix");
THArgCheck(THDTensor_(nDimensionLegacyNoScalars)(t) == 2, 1, "expected a matrix");

THDTensor_(resizeAs)(r_, t);

@ -650,7 +650,7 @@ void THDTensor_(lerp)(THDTensor *r_, THDTensor *a, THDTensor *b, real weight) {
}

void THDTensor_(mean)(THDTensor *r_, THDTensor *t, int dimension, int keepdim) {
THArgCheck(dimension >= 0 && dimension < THDTensor_(nDimension)(t), 2,
THArgCheck(dimension >= 0 && dimension < THDTensor_(nDimensionLegacyNoScalars)(t), 2,
"invalid dimension %d", dimension + TH_INDEX_BASE);

THLongStorage *dim = THDTensor_(newSizeOf)(t);
@ -669,7 +669,7 @@ void THDTensor_(mean)(THDTensor *r_, THDTensor *t, int dimension, int keepdim) {
}

void THDTensor_(std)(THDTensor *r_, THDTensor *t, int dimension, int biased, int keepdim) {
THArgCheck(dimension >= 0 && dimension < THDTensor_(nDimension)(t), 3,
THArgCheck(dimension >= 0 && dimension < THDTensor_(nDimensionLegacyNoScalars)(t), 3,
"invalid dimension %d", dimension + TH_INDEX_BASE);

THLongStorage *dim = THDTensor_(newSizeOf)(t);
@ -688,7 +688,7 @@ void THDTensor_(std)(THDTensor *r_, THDTensor *t, int dimension, int biased, int
}

void THDTensor_(var)(THDTensor *r_, THDTensor *t, int dimension, int biased, int keepdim) {
THArgCheck(dimension >= 0 && dimension < THDTensor_(nDimension)(t), 3,
THArgCheck(dimension >= 0 && dimension < THDTensor_(nDimensionLegacyNoScalars)(t), 3,
"invalid dimension %d", dimension + TH_INDEX_BASE);

THLongStorage *dim = THDTensor_(newSizeOf)(t);
@ -707,7 +707,7 @@ void THDTensor_(var)(THDTensor *r_, THDTensor *t, int dimension, int biased, int
}

void THDTensor_(norm)(THDTensor *r_, THDTensor *t, real value, int dimension, int keepdim) {
THArgCheck(dimension >= 0 && dimension < THDTensor_(nDimension)(t), 3,
THArgCheck(dimension >= 0 && dimension < THDTensor_(nDimensionLegacyNoScalars)(t), 3,
"invalid dimension %d", dimension + TH_INDEX_BASE);

THLongStorage *dim = THDTensor_(newSizeOf)(t);
@ -736,12 +736,12 @@ accreal THDTensor_(normall)(THDTensor *tensor, real value) {

void THDTensor_(renorm)(THDTensor *res, THDTensor *src, real value,
int dimension, real maxnorm) {
THArgCheck(dimension >= 0 && dimension < THDTensor_(nDimension)(src), 3,
THArgCheck(dimension >= 0 && dimension < THDTensor_(nDimensionLegacyNoScalars)(src), 3,
"invalid dimension %d", dimension + TH_INDEX_BASE);
THArgCheck(value > 0, 2, "non-positive-norm not supported");
THArgCheck(THDTensor_(nDimension)(src) > 1, 1,
THArgCheck(THDTensor_(nDimensionLegacyNoScalars)(src) > 1, 1,
"need at least 2 dimensions, got %d dimensions",
THDTensor_(nDimension)(src));
THDTensor_(nDimensionLegacyNoScalars)(src));

THDTensor_(resizeAs)(res, src);

@ -846,12 +846,12 @@ void THDTensor_(histc)(THDTensor *hist, THDTensor *tensor, int64_t nbins,

void THDTensor_(bhistc)(THDTensor *hist, THDTensor *tensor, int64_t nbins,
real minvalue, real maxvalue) {
THArgCheck(THDTensor_(nDimension)(tensor) < 3, 2,
THArgCheck(THDTensor_(nDimensionLegacyNoScalars)(tensor) < 3, 2,
"invalid dimension %d, the input must be a 2d tensor",
THDTensor_(nDimension)(tensor));
THDTensor_(nDimensionLegacyNoScalars)(tensor));

int dimension = 1;
THArgCheck(dimension >= 0 && dimension < THDTensor_(nDimension)(tensor), 2,
THArgCheck(dimension >= 0 && dimension < THDTensor_(nDimensionLegacyNoScalars)(tensor), 2,
"invalid dimension %d", dimension + TH_INDEX_BASE);

THDTensor_(resize2d)(hist, tensor->size[0], nbins);

@ -88,7 +88,7 @@ void THDTensor_(logNormal)(THDTensor *self, THDGenerator *_generator, double mea
void THDTensor_(multinomial)(THDLongTensor *self, THDGenerator *_generator,
THDTensor *prob_dist, int n_sample,
int with_replacement) {
int start_dim = THDTensor_(nDimension)(prob_dist);
int start_dim = THDTensor_(nDimensionLegacyNoScalars)(prob_dist);
if (start_dim == 1) {
THDTensor_(resize2d)(prob_dist, 1, THDTensor_(size)(prob_dist, 0));
}