Make caffe2::Tensor::dims() return an IntList instead of a const vector& (#12180)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12180

I had to fix a lot of call sites, because many places assume that
you can actually get a const vector&; if the internal representation
of sizes in a tensor is NOT a vector, that API contract cannot be
fulfilled.
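
To make the contract problem concrete, here is a minimal sketch (hypothetical
tensor structs, not the real caffe2 classes; the include paths are assumed):

    #include <cstdint>
    #include <vector>
    #include <ATen/core/ArrayRef.h>    // assumed location of at::ArrayRef / at::IntList
    #include <ATen/core/SmallVector.h> // assumed location of at::SmallVector

    struct VectorBacked {
      std::vector<int64_t> sizes_;
      // Fine: sizes_ really is a std::vector, so we can hand out a reference.
      const std::vector<int64_t>& dims() const { return sizes_; }
    };

    struct SmallVectorBacked {
      at::SmallVector<int64_t, 5> sizes_;
      // No std::vector exists here, so a const std::vector<int64_t>& accessor
      // could only be implemented by materializing a copy somewhere.  An
      // at::IntList (ArrayRef) is a non-owning view over any contiguous storage:
      at::IntList dims() const { return sizes_; }
    };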

Framework changes:
- I deleted TensorImpl::dims(); caffe2::Tensor::dims() just forwards to
  sizes() now.
- De-templatized SetDims; it is now an explicit list of ArrayRef and
  variadic overloads (see the sketch after this list).  This makes implicit
  conversions work again, so I don't need to explicitly list the std::vector
  cases too.
  - As a knock-on effect, this causes Reset() to accept at::IntList as well as
    const std::vector<int64_t>&.
- Edited variadic overloads of SetDims to all forward to the underlying
  arbitrary-dim implementation, reducing code duplication. (It's probably
  marginally less efficient in the new world.)
- Replace Tensor constructor accepting const std::vector<int64_t>& with at::IntList
- Make MKLMemory accept ArrayRef along with vector in its constructor and
  Reset() (unfortunately, no implicit conversions here, since it's templated
  on the index type).
- There are a few other places, like cudnn, where I changed functions
  that previously took const std::vector<int64_t>& to take at::IntList
  instead.
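
A rough sketch of the SetDims de-templatization mentioned above (simplified,
free-standing signatures rather than the actual TensorImpl members):

    #include <cstdint>
    #include <type_traits>
    #include <vector>
    #include <ATen/core/ArrayRef.h>  // assumed include path

    // Before (simplified): templated on the element type and taking std::vector,
    // so every other sizes representation needed its own explicit overload.
    template <
        typename T,
        typename = typename std::enable_if<std::is_integral<T>::value>::type>
    bool SetDims(const std::vector<T>& src);

    // After (simplified): a small, explicit set of non-template ArrayRef
    // overloads.  Because the parameters are concrete types, ordinary implicit
    // conversions apply, so a std::vector<int64_t> (or a braced list) binds to
    // at::IntList with no separate std::vector overloads; the variadic forms
    // just forward to the ArrayRef implementation.
    bool SetDims(at::ArrayRef<int64_t> s);
    bool SetDims(at::ArrayRef<int> s);
    bool SetDims(at::ArrayRef<size_t> s);
    inline bool SetDims(int64_t d0) { return SetDims(at::IntList{d0}); }
    inline bool SetDims(int64_t d0, int64_t d1) { return SetDims(at::IntList{d0, d1}); }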

Classification of call site changes:
- 'const std::vector<int64_t>& x_dims = x.dims()' ==>
  'at::IntList x_dims = x.dims()'
- 'std::vector<int64_t> x_dims = x.dims()' ==>
  'std::vector<int64_t> x_dims = x.dims().vec()' (we need a copy!)
  Usually this is because we're about to mutate the vector to compute some
  new dimension.  It also very commonly occurs in the form 'x_dims_ = x.dims()',
  because we frequently cache sizes in operators.
- Instead of constructing std::vector<int64_t>{blah, blah}, construct an
  at::IntList directly (all three rewrites are sketched below).
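
Concretely, the three rewrites above look roughly like this (x is any
caffe2::Tensor; foo, batch_size, N, C, H and W are hypothetical surrounding
code):

    // 1. Read-only use of the dims: a non-owning view is enough, no copy.
    //    before: const std::vector<int64_t>& x_dims = x.dims();
    at::IntList x_dims = x.dims();

    // 2. The dims are about to be mutated (or cached in a member): take a copy.
    //    before: std::vector<int64_t> shape = x.dims();
    std::vector<int64_t> shape = x.dims().vec();
    shape[0] = batch_size;  // e.g. compute a new output shape

    // 3. Fixed dimension lists for APIs that now take at::IntList:
    //    before: foo(std::vector<int64_t>{N, C, H, W});
    foo(at::IntList{N, C, H, W});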

ArrayRef changes:
- Added cbegin()/cend() iterators; they behave the same as begin()/end(),
  because everything on ArrayRef is const.
- Moved operator<< into ArrayRef.h, so that it's always available when
  working with ArrayRef.  I also templated it, so it now works on an
  ArrayRef of any type.
- Added an operator== overload for ArrayRef, plus variants that permit
  comparing an ArrayRef with a std::vector, a very common operation.
  (The non-templated version of operator== could get these automatically
  via implicit conversion, but template argument deduction refuses to
  consider implicit conversions.)  A usage sketch follows this list.
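
A usage sketch of the new ArrayRef helpers (the include path is assumed; the
rest is standard library):

    #include <cstdint>
    #include <iostream>
    #include <vector>
    #include <ATen/core/ArrayRef.h>  // assumed location of at::ArrayRef / at::IntList

    int main() {
      std::vector<int64_t> dims{2, 3, 4};
      at::IntList view(dims);

      // operator== / operator!= now work between ArrayRef and std::vector,
      // in either order.
      bool same = (view == dims) && (dims == view);

      // operator<< lives next to ArrayRef and is templated on the element type.
      std::cout << view << (same ? " matches" : " differs") << "\n";  // [2, 3, 4] matches

      // cbegin()/cend() behave exactly like begin()/end(), since ArrayRef only
      // hands out const iterators anyway.
      int64_t numel = 1;
      for (auto it = view.cbegin(); it != view.cend(); ++it) {
        numel *= *it;
      }
      std::cout << numel << "\n";  // 24
      return 0;
    }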

I'm planning to audit all dims() call sites to make sure they don't
expect 'auto x = t.dims()' to give you an x whose lifetime can validly
outlive the tensor.
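
The failure mode that audit is looking for, roughly (a hypothetical misuse,
not code from this change; assumed include path):

    #include <caffe2/core/tensor.h>  // caffe2::Tensor, at::IntList
    #include <vector>

    void example(caffe2::Tensor& t) {
      // d is just a pointer + length into t's size storage, not a copy.
      at::IntList d = t.dims();
      t.Resize(10, 20);  // may reallocate the size storage...
      // ...so reading d here (or after t has been destroyed) can touch stale
      // memory.  If the dims need to survive a Resize or outlive the tensor,
      // copy them instead:
      std::vector<int64_t> safe = t.dims().vec();
    }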

I opted not to do a dims()-to-sizes() rename, because dims() also matches
the protobuf accessor.  Bad news!

Reviewed By: jerryzh168

Differential Revision: D10111759

fbshipit-source-id: a2a81dc4b92c22ad4b3b8ef4077a7e97b6479452
Edward Yang
2018-10-05 15:45:03 -07:00
committed by Facebook Github Bot
parent f9fb37ca79
commit 54d9823d00
110 changed files with 362 additions and 299 deletions

View File

@ -108,6 +108,15 @@ class ArrayRef final {
return Data + Length;
}
// These are actually the same as iterator, since ArrayRef only
// gives you const iterators.
constexpr const_iterator cbegin() const {
return Data;
}
constexpr const_iterator cend() const {
return Data + Length;
}
constexpr reverse_iterator rbegin() const {
return reverse_iterator(end());
}
@ -209,4 +218,53 @@ class ArrayRef final {
/// @}
};
template <typename T>
std::ostream& operator<<(std::ostream & out, ArrayRef<T> list) {
int i = 0;
out << "[";
for(auto e : list) {
if (i++ > 0)
out << ", ";
out << e;
}
out << "]";
return out;
}
// WARNING: Template instantiation will NOT be willing to do an implicit
// conversions to get you to an at::ArrayRef, which is why we need so
// many overloads.
template <typename T>
bool operator==(at::ArrayRef<T> a1, at::ArrayRef<T> a2) {
return a1.equals(a2);
}
template <typename T>
bool operator!=(at::ArrayRef<T> a1, at::ArrayRef<T> a2) {
return !a1.equals(a2);
}
template <typename T>
bool operator==(std::vector<T> a1, at::ArrayRef<T> a2) {
return at::ArrayRef<T>(a1).equals(a2);
}
template <typename T>
bool operator!=(std::vector<T> a1, at::ArrayRef<T> a2) {
return !at::ArrayRef<T>(a1).equals(a2);
}
template <typename T>
bool operator==(at::ArrayRef<T> a1, std::vector<T> a2) {
return a1.equals(at::ArrayRef<T>(a2));
}
template <typename T>
bool operator!=(at::ArrayRef<T> a1, std::vector<T> a2) {
return !a1.equals(at::ArrayRef<T>(a2));
}
using IntList = ArrayRef<int64_t>;
} // namespace at

View File

@ -28,18 +28,6 @@ private:
std::ios saved;
};
std::ostream& operator<<(std::ostream & out, IntList list) {
int i = 0;
out << "[";
for(auto e : list) {
if (i++ > 0)
out << ", ";
out << e;
}
out << "]";
return out;
}
std::ostream& operator<<(std::ostream & out, Backend b) {
return out << toString(b);
}

View File

@ -8,7 +8,6 @@
namespace at {
CAFFE2_API std::ostream& operator<<(std::ostream& out, IntList list);
CAFFE2_API std::ostream& operator<<(std::ostream& out, Backend b);
CAFFE2_API std::ostream& operator<<(std::ostream& out, const Type& t);
CAFFE2_API std::ostream& print(

View File

@ -189,7 +189,6 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
}
class Tensor;
typedef ArrayRef<int64_t> IntList;
typedef ArrayRef<Tensor> TensorList;
inline std::ostream& operator<<(

View File

@ -1015,6 +1015,19 @@ inline size_t capacity_in_bytes(const SmallVector<T, N>& X) {
return X.capacity_in_bytes();
}
template <typename T, unsigned N>
std::ostream& operator<<(std::ostream & out, const SmallVector<T, N>& list) {
int i = 0;
out << "[";
for(auto e : list) {
if (i++ > 0)
out << ", ";
out << e;
}
out << "]";
return out;
}
} // end namespace at
namespace std {

View File

@ -447,7 +447,7 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
data_type_ = caffe2::TypeMeta();
return;
}
Resize(src.dims());
Resize(src.sizes());
if (numel() > 0) {
if (data_type_.copy()) {
CAFFE_ENFORCE(
@ -810,21 +810,11 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
return static_cast<T*>(raw_mutable_data(caffe2::TypeMeta::Make<T>()));
}
/**
* Returns the dimensions of the tensor as a vector.
*/
inline const std::vector<int64_t>& dims() const {
// TODO: This method will no longer work if we change the
// internal representation of dims(). That's BAD. Let's get
// people to stop using this.
return sizes_;
}
private:
template <
typename T,
typename = typename std::enable_if<std::is_integral<T>::value>::type>
bool SetDims(const std::vector<T>& src) {
bool SetDimsTemplate(at::ArrayRef<T> src) {
auto old_numel = numel_;
sizes_.resize(src.size());
int64_t new_numel = 1;
@ -837,58 +827,36 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
return numel_ != old_numel;
}
bool SetDims() {
auto old_numel = numel_;
sizes_.resize(0);
update_to_contiguous_strides();
numel_ = 1;
return numel_ != old_numel;
bool SetDims(at::ArrayRef<int64_t> s) {
return SetDimsTemplate(s);
}
bool SetDims(at::ArrayRef<int> s) {
return SetDimsTemplate(s);
}
bool SetDims(at::ArrayRef<size_t> s) {
return SetDimsTemplate(s);
}
bool SetDims() {
return SetDims(at::IntList{});
}
// TODO(jiayq): maybe rewrite the following functions with initializer list.
// NVCC does not play well with initializer lists last time, but worth
// another shot.
bool SetDims(const int64_t d0) {
auto old_numel = numel_;
sizes_.resize(1);
sizes_[0] = d0;
update_to_contiguous_strides();
numel_ = d0;
return numel_ != old_numel;
return SetDims(at::IntList{d0});
}
bool SetDims(const int64_t d0, const int64_t d1) {
auto old_numel = numel_;
sizes_.resize(2);
sizes_[0] = d0;
sizes_[1] = d1;
update_to_contiguous_strides();
numel_ = d0 * d1;
return numel_ != old_numel;
return SetDims(at::IntList{d0, d1});
}
bool SetDims(const int64_t d0, const int64_t d1, const int64_t d2) {
auto old_numel = numel_;
sizes_.resize(3);
sizes_[0] = d0;
sizes_[1] = d1;
sizes_[2] = d2;
update_to_contiguous_strides();
numel_ = d0 * d1 * d2;
return numel_ != old_numel;
return SetDims(at::IntList{d0, d1, d2});
}
bool
SetDims(const int64_t d0, const int64_t d1, const int64_t d2, const int64_t d3) {
auto old_numel = numel_;
sizes_.resize(4);
sizes_[0] = d0;
sizes_[1] = d1;
sizes_[2] = d2;
sizes_[3] = d3;
update_to_contiguous_strides();
numel_ = d0 * d1 * d2 * d3;
return numel_ != old_numel;
bool SetDims(const int64_t d0, const int64_t d1, const int64_t d2, const int64_t d3) {
return SetDims(at::IntList{d0, d1, d2, d3});
}
inline void update_to_contiguous_strides() {

View File

@ -161,7 +161,7 @@ bool TensorRTOp::RunOnDevice() {
size_t N = 0;
for (int i = 0; i < InputSize(); ++i) {
const auto& input_tensor = Input(i);
const auto& tensor_dims = input_tensor.dims();
const auto tensor_dims = input_tensor.dims();
CAFFE_ENFORCE(!tensor_dims.empty(), "Input tensor cannot be empty");
if (i == 0) {
N = tensor_dims.front();
@ -198,7 +198,7 @@ bool TensorRTOp::RunOnDevice() {
// input, check input dimensions
const auto& input_tensor = Input(input_idx++);
const float* input_data = input_tensor.data<float>();
const auto& tensor_dims = input_tensor.dims();
const auto tensor_dims = input_tensor.dims();
auto chw = CheckDims(dims, tensor_dims);
bindings.push_back((void*)(input_data + offset * chw));
} else {

View File

@ -47,14 +47,14 @@ class C10_EXPORT QTensor {
* Explained here: https://arxiv.org/abs/1606.06160
*/
explicit QTensor(
const std::vector<int>& dims,
at::ArrayRef<int> dims,
const unsigned char precision,
const bool signbit = false)
: precision_(precision), signed_(signbit) {
Resize(dims);
}
void Resize(std::vector<int> dim_source) {
void Resize(at::ArrayRef<int> dim_source) {
if (dims_ != dim_source) {
size_t source_size = std::accumulate(
dim_source.begin(), dim_source.end(), 1, std::multiplies<int>());
@ -62,7 +62,7 @@ class C10_EXPORT QTensor {
data_ptr_.clear();
capacity_ = 0;
}
dims_ = dim_source;
dims_ = dim_source.vec();
size_ = source_size;
}
}
@ -145,7 +145,7 @@ class C10_EXPORT QTensor {
return precision_;
}
inline const vector<int>& dims() const {
inline at::ArrayRef<int> dims() const {
return dims_;
}

View File

@ -86,7 +86,7 @@ vector<int64_t> GetTensorInfo(
CHECK(tc->unsafeGetTensorImpl()->storage().unsafeGetStorageImpl());
*capacity = tc->storage().capacity();
tc->ExtractDeviceOption(device);
return tc->dims();
return tc->dims().vec();
}
// since we only have one tensor, probably need to remove this at some point?

View File

@ -44,8 +44,7 @@ class CAFFE2_API Tensor final {
* Note that the actual data allocation is not going to be carried out until
* the first time mutable_data() is called.
*/
explicit Tensor(const vector<int64_t>& dims, DeviceType type)
: Tensor(Storage(type)) {
explicit Tensor(at::IntList dims, DeviceType type) : Tensor(Storage(type)) {
// TODO: here, we create a Storage
// and immediately discard it in Resize() since
// reset_tensor will be true and FreeMemory will be called,
@ -342,8 +341,8 @@ class CAFFE2_API Tensor final {
return impl_->numel() * itemsize();
}
inline const vector<int64_t>& dims() const {
return impl_.get()->dims();
inline at::IntList dims() const {
return impl_.get()->sizes();
}
inline int64_t size_from_dim(int k) const {

View File

@ -216,7 +216,7 @@ class MaxPoolRTCOp final : public ConvPoolOpBase<CUDAContext> {
stride_w(),
pad_t(),
pad_l());
input_dims_ = X.dims();
input_dims_ = X.dims().vec();
}
// Carry out the pooling computation.
func_.Launch(CAFFE_GET_BLOCKS(Y->size()), 1, 1, CAFFE_CUDA_NUM_THREADS,
@ -269,7 +269,7 @@ class MaxPoolGradientRTCOp final : public ConvPoolOpBase<CUDAContext> {
stride_w(),
pad_t(),
pad_l());
input_dims_ = X.dims();
input_dims_ = X.dims().vec();
}
func_.Launch(CAFFE_GET_BLOCKS(X.size()), 1, 1, CAFFE_CUDA_NUM_THREADS, 1, 1,
0, context_.cuda_stream(),

View File

@ -125,7 +125,7 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator {
"IDEEP fallback op currently does not support non-TensorCPU "
"output type who needs copying.");
const auto& src = local_output_blobs_[i]->template Get<TensorCPU>();
auto src_dims = src.dims();
auto src_dims = src.dims().vec();
if (src.template IsType<float>() &&
src.dims().size() != 0 && src.size_from_dim(0) != 0 &&
base_op_->type() != "Python") {

View File

@ -37,26 +37,26 @@
error); \
} while (0)
#define CHECK_INPUT_FILTER_DIMS(X, filter, condition) \
do { \
if (cached_input_dims_ != X.dims() || \
cached_filter_dims_ != filter.dims()) { \
cached_input_dims_ = X.dims(); \
cached_filter_dims_ = filter.dims(); \
condition = true; \
} else { \
condition = false; \
} \
#define CHECK_INPUT_FILTER_DIMS(X, filter, condition) \
do { \
if (at::IntList(cached_input_dims_) != X.dims() || \
at::IntList(cached_filter_dims_) != filter.dims()) { \
cached_input_dims_ = X.dims().vec(); \
cached_filter_dims_ = filter.dims().vec(); \
condition = true; \
} else { \
condition = false; \
} \
} while (0)
#define CHECK_INPUT_DIMS(X, condition) \
do { \
if (cached_input_dims_ != X.dims()) { \
cached_input_dims_ = X.dims(); \
condition = true; \
} else { \
condition = false; \
} \
#define CHECK_INPUT_DIMS(X, condition) \
do { \
if (at::IntList(cached_input_dims_) != X.dims()) { \
cached_input_dims_ = X.dims().vec(); \
condition = true; \
} else { \
condition = false; \
} \
} while (0)
// All caffe2 mkl related headers

View File

@ -45,7 +45,7 @@ class MKLConcatOp final : public MKLOperator<T> {
bool dims_changed = (input_size_cache_.size() != nInputs);
for (int i = 0; i < nInputs && !dims_changed; ++i) {
dims_changed = (input_size_cache_[i] != Input(i).dims());
dims_changed = (at::IntList(input_size_cache_[i]) != Input(i).dims());
}
if (dims_changed || c10::FLAGS_caffe2_mkl_memonger_in_use) {
@ -68,11 +68,11 @@ class MKLConcatOp final : public MKLOperator<T> {
" has dimension mismatch at axis ",
j);
}
input_size_cache_[i] = Xi.dims();
input_size_cache_[i] = Xi.dims().vec();
output_channels += Xi.dim32(canonical_axis);
input_layouts[i] = Xi.layout();
}
cached_output_dims_ = X0.dims();
cached_output_dims_ = X0.dims().vec();
cached_output_dims_[canonical_axis] = output_channels;
primitive_.Reset(

View File

@ -37,7 +37,7 @@ class MKLConvOp final : public ConvPoolOpBase<MKLContext> {
math::Set<T, CPUContext>(
M, 0.0, cpu_zero_bias.template mutable_data<float>(), &ctx);
zero_bias_.reset(new MKLMemory<T>(std::vector<int64_t>{M}));
zero_bias_.reset(new MKLMemory<T>(at::IntList{M}));
zero_bias_->CopyFrom(cpu_zero_bias);
}
const auto& bias = InputSize() == 2
@ -130,11 +130,11 @@ class MKLConvOp final : public ConvPoolOpBase<MKLContext> {
if (group_ > 1) {
// Explicitly reformat the buffer.
MKLMemory<float> group_filter(
std::vector<int64_t>{int64_t(group_),
int64_t(filter.dim32(0) / group_),
int64_t(filter.dim32(1)),
int64_t(filter.dim32(2)),
int64_t(filter.dim32(3))},
at::IntList{int64_t(group_),
int64_t(filter.dim32(0) / group_),
int64_t(filter.dim32(1)),
int64_t(filter.dim32(2)),
int64_t(filter.dim32(3))},
nullptr,
dnnResourceFilter,
/*share_memory_if_possible=*/true);

View File

@ -47,10 +47,10 @@ class ConvMKLDNNOp final : public ConvPoolOpBase<CPUContext> {
// Pre-allocate Y so we can potentially share memory if applicable.
Y->mutable_data<T>();
if (cached_input_dims_ != X.dims() ||
cached_filter_dims_ != filter.dims()) {
cached_input_dims_ = X.dims();
cached_filter_dims_ = filter.dims();
if (at::IntList(cached_input_dims_) != X.dims() ||
at::IntList(cached_filter_dims_) != filter.dims()) {
cached_input_dims_ = X.dims().vec();
cached_filter_dims_ = filter.dims().vec();
// In order to create an internal layout, let's use convolution as
// primitive.
size_t dimension = 4;

View File

@ -35,9 +35,9 @@ class MKLSumOp final : public MKLOperator<T> {
X0.layout(),
coefficients_.data());
if (Y != &X0) {
Y->Reset(X0.dims(), primitive_, dnnResourceDst);
Y->Reset(X0.dims().vec(), primitive_, dnnResourceDst);
}
buffer_.Reset(X0.dims(), primitive_, dnnResourceDst, true);
buffer_.Reset(X0.dims().vec(), primitive_, dnnResourceDst, true);
}
input_views_.resize(this->InputSize());
for (auto i = 0; i < this->InputSize(); ++i) {

View File

@ -30,7 +30,7 @@ class MKLFullyConnectedOp final : public MKLOperator<T> {
const int N = filter.dim32(0);
CAFFE_ENFORCE(N == bias.dim32(0));
auto Y_shape = X.dims();
auto Y_shape = X.dims().vec();
Y_shape[1] = N;
Y_shape.resize(2);

View File

@ -85,7 +85,7 @@ class PackedFCOp final : public Operator<CPUContext> {
CAFFE_ENFORCE_EQ(packed_matrix->n_, N);
// Do we want to check the other flags as well?
Y_shape_cache_ = X.dims();
Y_shape_cache_ = X.dims().vec();
// This is an invariant of canonical_axis, so we can DCHECK.
DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
Y_shape_cache_.resize(canonical_axis + 1);

View File

@ -27,7 +27,7 @@ static vector<int64_t> GetMKLTensorInfo(
*capacity = tc->size() * sizeof(T);
device->set_device_type(PROTO_MKLDNN);
device->set_cuda_gpu_id(0);
return tc->dims();
return tc->dims().vec();
}
template <typename T>

View File

@ -168,13 +168,24 @@ class C10_EXPORT MKLMemory {
// storage.
template <typename IndexType>
explicit MKLMemory(
const vector<IndexType>& dims,
at::ArrayRef<IndexType> dims,
const dnnPrimitive_t primitive = nullptr,
const dnnResourceType_t type = dnnResourceNumber,
bool share_mem_if_possible = false) {
Reset(dims, primitive, type, share_mem_if_possible);
}
// Initialize an MKLMemory, with the given dimension assuming a C-contiguous
// storage.
template <typename IndexType>
explicit MKLMemory(
const std::vector<IndexType>& dims,
const dnnPrimitive_t primitive = nullptr,
const dnnResourceType_t type = dnnResourceNumber,
bool share_mem_if_possible = false) {
Reset(at::ArrayRef<IndexType>(dims), primitive, type, share_mem_if_possible);
}
// Initialize an MKLMemory with the given size, strides, dnn
// primitive and type.
void Reset(
@ -213,7 +224,18 @@ class C10_EXPORT MKLMemory {
// storage.
template <typename IndexType>
void Reset(
const vector<IndexType>& dims,
const std::vector<IndexType>& dims,
const dnnPrimitive_t primitive = nullptr,
const dnnResourceType_t type = dnnResourceNumber,
bool share_mem_if_possible = false) {
Reset(at::ArrayRef<IndexType>(dims), primitive, type, share_mem_if_possible);
}
// Initialize an MKLMemory, with the given dimension assuming a C-contiguous
// storage.
template <typename IndexType>
void Reset(
at::ArrayRef<IndexType> dims,
const dnnPrimitive_t primitive = nullptr,
const dnnResourceType_t type = dnnResourceNumber,
bool share_mem_if_possible = false) {
@ -313,7 +335,7 @@ class C10_EXPORT MKLMemory {
void CopyFrom(const TensorCPU& tensor) {
CAFFE_ENFORCE_EQ(
tensor.dims(),
dims_,
at::IntList(dims_),
"Dims does not match the expected dims of the resource.");
CopyFrom(tensor.template data<T>());
}
@ -321,7 +343,7 @@ class C10_EXPORT MKLMemory {
void CopyFrom(const MKLMemory<T>& other) {
CAFFE_ENFORCE_EQ(
other.dims(),
dims_,
at::IntList(dims_),
"Dims does not match the expected dims of the resource.");
if (share_mem_if_possible_ && dnnLayoutCompare<T>(other.layout_, layout_)) {
@ -456,7 +478,7 @@ class C10_EXPORT MKLMemory {
return buffer_.get();
}
inline const vector<int64_t>& dims() const {
inline at::IntList dims() const {
return dims_;
}

View File

@ -69,5 +69,5 @@ void Caffe2IOSPredictor::run(const Tensor& inData, Tensor& outData, std::string&
}
caffe2::TensorCPU* output = &output_vec.front();
outData.data = output->mutable_data<uint8_t>();
outData.dims = output->dims();
outData.dims = output->dims().vec();
}

View File

@ -311,7 +311,7 @@ utils::ConstTensorView<T> GetSubTensorView(
auto st_idx = ComputeStartIndex(tensor, start_dims);
auto ptr = tensor.data<T>() + st_idx;
auto& input_dims = tensor.dims();
auto input_dims = tensor.dims();
std::vector<int> ret_dims(input_dims.begin() + 1, input_dims.end());
utils::ConstTensorView<T> ret(ptr, ret_dims);

View File

@ -15,7 +15,7 @@ void uniformQuantize2b1b(const TensorCPU& X,
const auto N = X.size_to_dim(X.ndim() - 1);
auto C = X.size() / N;
const auto QC = divRoundUp(C, 8);
auto XQs = X.dims();
auto XQs = X.dims().vec();
XQs[X.ndim() - 1] = QC;
CAFFE_ENFORCE_EQ(XQ.size(), k2b1bXBits);
for (auto i = 0; i < k2b1bXBits; ++i) {
@ -137,7 +137,7 @@ void signQuantize(const TensorCPU& X, TensorCPU* XQ) {
const auto N = X.size_to_dim(X.ndim() - 1);
auto C = X.size() / N;
const auto QC = divRoundUp(C, 8);
auto XQs = X.dims();
auto XQs = X.dims().vec();
XQs[X.ndim() - 1] = QC;
XQ->Resize(XQs);
const float* Xdata = X.data<float>();

View File

@ -98,7 +98,7 @@ void uniformQuantize2b1bNeon(QConvState* state,
const size_t C = X.dim32(X.ndim() - 1);
const size_t N = X.size() / C;
const size_t QC = divRoundUp(C, 8);
auto XQs = X.dims();
auto XQs = X.dims().vec();
XQs[X.ndim() - 1] = QC;
CAFFE_ENFORCE_EQ(XQ.size(), k2b1bXBits);
for (auto i = 0; i < k2b1bXBits; ++i) {

View File

@ -98,7 +98,7 @@ class MPIAllgatherOp final : public Operator<Context> {
MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm();
auto& input = Input(1);
auto* output = Output(0);
vector<int64_t> output_dims = input.dims();
vector<int64_t> output_dims = input.dims().vec();
output_dims[0] *= OperatorBase::Input<MPICommonWorldWrapper>(0).size();
output->Resize(output_dims);
MPI_CHECK(MPI_Allgather(

View File

@ -37,9 +37,9 @@ class BatchMatMulOp final : public Operator<Context> {
auto* Y = Output(0);
auto ndims_A = A.ndim();
auto dims_A = A.dims();
auto dims_A = A.dims().vec();
auto ndims_B = B.ndim();
auto dims_B = B.dims();
auto dims_B = B.dims().vec();
auto noBroadcastErrorMsg = [](size_t dim1, size_t dim2) {
std::stringstream ss;

View File

@ -44,7 +44,7 @@ class BatchMatMulOpGPUTest : public testing::Test {
ASSERT_NE(nullptr, Y_blob);
const auto& Y = Y_blob->Get<Tensor>();
Tensor Y_cpu(Y, CPU);
const auto& Y_dims = Y_cpu.dims();
const auto Y_dims = Y_cpu.dims();
ASSERT_EQ(dims.size(), Y_dims.size());
for (std::size_t i = 0; i < dims.size(); ++i) {
ASSERT_EQ(dims[i], Y_dims[i]);

View File

@ -37,7 +37,7 @@ class BatchMatMulOpTest : public testing::Test {
const Blob* Y_blob = ws_.GetBlob("Y");
ASSERT_NE(nullptr, Y_blob);
const auto& Y = Y_blob->Get<TensorCPU>();
const auto& Y_dims = Y.dims();
const auto Y_dims = Y.dims();
ASSERT_EQ(dims.size(), Y_dims.size());
for (std::size_t i = 0; i < dims.size(); ++i) {
ASSERT_EQ(dims[i], Y_dims[i]);

View File

@ -80,7 +80,7 @@ bool BatchDenseToSparseOp<T, Context>::RunOnDevice() {
CAFFE_ENFORCE_EQ(batch_size, dense.dim(0));
dense_last_dim_ = dense.dim(1);
vector<int64_t> output_shape = indices.dims();
vector<int64_t> output_shape = indices.dims().vec();
output->Resize(output_shape);
T* output_data = output->template mutable_data<T>();

View File

@ -138,7 +138,7 @@ bool BBoxTransformOp<float, CPUContext>::RunOnDevice() {
}
}
CAFFE_ENFORCE_EQ(iminfo_in.dims(), (vector<int64_t>{batch_size, 3}));
CAFFE_ENFORCE_EQ(iminfo_in.dims(), (at::IntList{batch_size, 3}));
Eigen::Map<const ERArrXXf> iminfo(
iminfo_in.data<float>(), iminfo_in.dim(0), iminfo_in.dim(1));

View File

@ -76,7 +76,7 @@ class BooleanMaskOp<CUDAContext> final : public Operator<CUDAContext> {
context_.CopyToCPU(1, numOfOutputData, &numOfOutput);
indices_.Resize(numOfOutput);
std::vector<int64_t> dims = src.dims();
std::vector<int64_t> dims = src.dims().vec();
dims[0] = numOfOutput;
dest->Resize(dims);
auto* destData = (uint8_t*)dest->raw_mutable_data(src.meta());

View File

@ -161,7 +161,7 @@ bool SplitOp<Context>::RunOnDevice() {
input_channels,
"Sum of split dimensions do not match: should be ",
input_channels);
vector<int64_t> output_dims(input.dims());
vector<int64_t> output_dims(input.dims().vec());
int before = 1, after = 1;
for (int i = 0; i < canonical_axis; ++i) {
before *= input.dim32(i);
@ -215,7 +215,7 @@ bool SplitByLengthsOp<Context>::RunOnDevice() {
input_channels,
"Sum of split dimensions do not match: should be ",
input_channels);
vector<int64_t> output_dims(input.dims());
vector<int64_t> output_dims(input.dims().vec());
int before = input.size_to_dim(canonical_axis);
int after = input.size_from_dim(canonical_axis + 1);
size_t input_offset = 0;
@ -263,7 +263,7 @@ bool ConcatOp<Context>::RunOnDevice() {
}
int before = 1, after = 1;
vector<int64_t> output_dims(input_zero.dims());
vector<int64_t> output_dims(input_zero.dims().vec());
for (int i = 0; i < input_zero.ndim(); ++i) {
if (i == canonical_axis && !add_axis_) {
continue;

View File

@ -16,8 +16,8 @@ class AlgorithmsCache {
// combination of tensor dimensions & compute data type.
//
TAlgorithm getAlgorithm(
const std::vector<int64_t>& tensorDimensions1,
const std::vector<int64_t>& tensorDimensions2,
at::IntList tensorDimensions1,
at::IntList tensorDimensions2,
int algorithmFlags, // Differentiate between algorithms with different
// parameters in a generic way
std::function<TAlgorithm()> generatingFunc);
@ -28,8 +28,8 @@ class AlgorithmsCache {
template <typename TAlgorithm>
TAlgorithm AlgorithmsCache<TAlgorithm>::getAlgorithm(
const std::vector<int64_t>& tensorDimensions1,
const std::vector<int64_t>& tensorDimensions2,
at::IntList tensorDimensions1,
at::IntList tensorDimensions2,
int algorithmFlags,
std::function<TAlgorithm()> generatingFunc) {
int64_t seed = 0;

View File

@ -580,12 +580,12 @@ bool CudnnConvOp::DoRunWithType() {
if (input_changed || filter_changed) {
VLOG(1) << "Changing the cudnn descriptor configurations.";
if (input_changed) {
cudnn_input_dims_ = X.dims();
cudnn_input_dims_ = X.dims().vec();
SetTensorNdDescriptorWithGroup<T_X>(
X.ndim(), bottom_desc_, N, C, H, W, D);
}
if (filter_changed) {
cudnn_filter_dims_ = filter.dims();
cudnn_filter_dims_ = filter.dims().vec();
if (kernel_.size() == 2) {
#if CUDNN_VERSION_MIN(7, 0, 0)
const int MM = M;
@ -936,12 +936,12 @@ bool CudnnConvGradientOp::DoRunWithType() {
if (input_changed || filter_changed) {
VLOG(1) << "Changing the cudnn descriptor configurations.";
if (input_changed) {
cudnn_input_dims_ = X.dims();
cudnn_input_dims_ = X.dims().vec();
SetTensorNdDescriptorWithGroup<T_X>(
X.ndim(), bottom_desc_, N, C, H, W, D);
}
if (filter_changed) {
cudnn_filter_dims_ = filter.dims();
cudnn_filter_dims_ = filter.dims().vec();
if (kernel_.size() == 2) {
#if CUDNN_VERSION_MIN(7, 0, 0)
const int MM = M;

View File

@ -246,7 +246,7 @@ class ConvPoolOpBase : public Operator<Context> {
// Helper function that is also called from OperatorSchema. Modified
// kernel parameters and output output_dims and channel_first.
static inline void InferOutputSize(
vector<int64_t> input_dims,
at::IntList input_dims,
int /*output_channel*/,
StorageOrder order,
bool global_pooling,

View File

@ -191,7 +191,7 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
if (input_changed || filter_changed) {
VLOG(1) << "Changing the cudnn descriptor configurations.";
if (input_changed) {
cudnn_input_dims_ = X.dims();
cudnn_input_dims_ = X.dims().vec();
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
bottom_desc_,
GetCudnnTensorFormat(order_),
@ -202,7 +202,7 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
W));
}
if (filter_changed) {
cudnn_filter_dims_ = filter.dims();
cudnn_filter_dims_ = filter.dims().vec();
CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
filter_desc_,
cudnnTypeWrapper<T>::type,
@ -421,7 +421,7 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
if (input_changed || filter_changed) {
VLOG(1) << "Changing the cudnn descriptor configurations.";
if (input_changed) {
cudnn_input_dims_ = X.dims();
cudnn_input_dims_ = X.dims().vec();
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
bottom_desc_,
GetCudnnTensorFormat(order_),
@ -432,7 +432,7 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
W));
}
if (filter_changed) {
cudnn_filter_dims_ = filter.dims();
cudnn_filter_dims_ = filter.dims().vec();
CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
filter_desc_,
cudnnTypeWrapper<T>::type,

View File

@ -258,7 +258,7 @@ template <>
bool MakeTwoClassOp<float, CPUContext>::RunOnDevice() {
auto& X = Input(0);
auto* Y = Output(0);
auto shape = X.dims();
auto shape = X.dims().vec();
shape.push_back(2);
int64_t N = X.size();
Y->Resize(shape);
@ -277,7 +277,7 @@ template <>
bool MakeTwoClassGradientOp<float, CPUContext>::RunOnDevice() {
auto& dY = Input(0);
auto* dX = Output(0);
auto shape = dY.dims();
auto shape = dY.dims().vec();
CAFFE_ENFORCE_GE(shape.size(), 1);
CAFFE_ENFORCE_EQ(shape.back(), 2);
shape.pop_back();

View File

@ -114,7 +114,7 @@ template <>
bool MakeTwoClassOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
auto* Y = Output(0);
auto shape = X.dims();
auto shape = X.dims().vec();
shape.push_back(2);
CAFFE_ENFORCE_LT(X.size(), std::numeric_limits<int>::max() / 2);
Y->Resize(shape);
@ -132,7 +132,7 @@ template <>
bool MakeTwoClassGradientOp<float, CUDAContext>::RunOnDevice() {
auto& dY = Input(0);
auto* dX = Output(0);
auto shape = dY.dims();
auto shape = dY.dims().vec();
CAFFE_ENFORCE_GE(shape.size(), 1);
CAFFE_ENFORCE_EQ(shape.back(), 2);
shape.pop_back();

View File

@ -5,7 +5,7 @@ namespace caffe2 {
namespace {
const float* getTensorDataPtr(const Tensor& tensor, int t, int n) {
const auto& dims = tensor.dims();
const auto dims = tensor.dims();
CAFFE_ENFORCE_EQ(dims.size(), 3);
int offset = (t * dims[1] + n) * dims[2];
CAFFE_ENFORCE_LT(offset, tensor.size());
@ -23,7 +23,7 @@ bool CTCBeamSearchDecoderOp<CPUContext>::RunOnDevice() {
// shape: sum over all decoded_length
auto* values = Output(VALUES);
const auto& inputs_dims = inputs.dims();
const auto inputs_dims = inputs.dims();
int32_t max_activation_length = inputs_dims[0];
int32_t batch_size = inputs_dims[1];
int32_t alphabet_size = inputs_dims[2];

View File

@ -5,7 +5,7 @@ namespace caffe2 {
namespace {
const float* getTensorDataPtr(const Tensor& tensor, int t, int n) {
const auto& dims = tensor.dims();
const auto dims = tensor.dims();
CAFFE_ENFORCE_EQ(dims.size(), 3);
int offset = (t * dims[1] + n) * dims[2];
CAFFE_ENFORCE_LT(offset, tensor.size());
@ -23,7 +23,7 @@ bool CTCGreedyDecoderOp<CPUContext>::RunOnDevice() {
// [total_decoded_output]
auto* values = Output(VALUES);
const auto& inputs_dims = inputs.dims();
const auto inputs_dims = inputs.dims();
int32_t max_time_step = inputs_dims[0];
int32_t batch_size = inputs_dims[1];
int32_t num_classes = inputs_dims[2];

View File

@ -156,7 +156,7 @@ void TreeWalker::advance() {
}
std::vector<int64_t> TreeWalker::fieldDim(int fieldId) const {
auto tensorDim = input(fieldId).dims();
auto tensorDim = input(fieldId).dims().vec();
tensorDim[0] = sizes_[lengthIdx(fieldId)];
return tensorDim;
}
@ -427,7 +427,7 @@ class UnPackRecordsOp : public Operator<CPUContext> {
CAFFE_ENFORCE_EQ(numTensors, OutputSize());
for (int i = 0; i < numTensors; ++i) {
outputDims[i] = inputZero->at(i).dims();
outputDims[i] = inputZero->at(i).dims().vec();
outputDims[i][0] = 0;
metas[i] = &inputZero->at(i).meta();
}
@ -441,7 +441,7 @@ class UnPackRecordsOp : public Operator<CPUContext> {
CAFFE_ENFORCE_EQ(numTensors, OutputSize());
for (int i = 0; i < numTensors; ++i) {
const auto& input = Input(i + 1);
outputDims[i] = input.dims();
outputDims[i] = input.dims().vec();
outputDims[i][0] = 0;
metas[i] = &input.meta();
}
@ -508,7 +508,7 @@ class ReadNextBatchOp : public Operator<CPUContext> {
auto offset = offsets[lengthIdx];
auto& in = Input(i + 1);
auto innerSize = in.size_from_dim(1);
outDim = in.dims();
outDim = in.dims().vec();
outDim[0] = size;
auto* out = Output(i);
out->Resize(outDim);
@ -674,7 +674,7 @@ class ReadRandomBatchOp : public Operator<CPUContext> {
auto& offsetsmat = Input(2);
CAFFE_ENFORCE(InputSize() == cursor->it.fields().size() + 3);
auto idxvec = idxblob.template data<int64_t>();
auto& offsetdim = offsetsmat.dims();
auto offsetdim = offsetsmat.dims();
// gather data
std::vector<int64_t> outDim;
int64_t idx;
@ -697,7 +697,7 @@ class ReadRandomBatchOp : public Operator<CPUContext> {
for (int i = 0; i < cursor->it.fields().size(); ++i) {
auto lengthIdx = cursor->it.fields()[i].lengthFieldId + 1;
auto& in = Input(i + 3);
outDim = in.dims();
outDim = in.dims().vec();
outDim.at(0) = 0;
auto idxbegin = idx;
for (int j = 0; j < batchSize_; ++j) {
@ -883,7 +883,7 @@ class ConcatTensorVectorOp final : public Operator<Context> {
auto* tensor = Output(TENSOR);
CAFFE_ENFORCE(!tensorVector->empty());
vector<int64_t> outputDims(tensorVector->at(0).dims());
vector<int64_t> outputDims(tensorVector->at(0).dims().vec());
CAFFE_ENFORCE(outputDims.size() > 0);
for (int i = 1; i < tensorVector->size(); i++) {
// the tensor shapes are the same except for the first dimension

View File

@ -304,8 +304,8 @@ template <typename DType, typename Context>
void DeformConvOpBase<DType, Context>::DeformableIm2col(
const DType* data_im,
const DType* data_offset,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
at::IntList im_shape,
at::IntList col_shape,
DType* data_col) {
CHECK_LT(2, CAFFE_CUDA_NUM_THREADS);
CAFFE_ENFORCE_EQ(pad_t(), pad_b());
@ -430,8 +430,8 @@ template <typename DType, typename Context>
void DeformConvOpBase<DType, Context>::DeformableCol2im(
const DType* data_col,
const DType* data_offset,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
at::IntList im_shape,
at::IntList col_shape,
DType* grad_im) {
CAFFE_ENFORCE_EQ(pad_t(), pad_b());
CAFFE_ENFORCE_EQ(pad_l(), pad_r());
@ -577,8 +577,8 @@ void DeformConvOpBase<DType, Context>::DeformableCol2imCoord(
const DType* data_col,
const DType* data_im,
const DType* data_offset,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
at::IntList im_shape,
at::IntList col_shape,
DType* grad_offset) {
CAFFE_ENFORCE_EQ(pad_t(), pad_b());
CAFFE_ENFORCE_EQ(pad_l(), pad_r());

View File

@ -24,21 +24,21 @@ class DeformConvOpBase : public ConvPoolOpBase<Context> {
void DeformableIm2col(
const T* data_im,
const T* data_offset,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
at::IntList im_shape,
at::IntList col_shape,
T* data_col);
void DeformableCol2im(
const T* data_col,
const T* data_offset,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
at::IntList im_shape,
at::IntList col_shape,
T* grad_im);
void DeformableCol2imCoord(
const T* data_col,
const T* data_im,
const T* data_offset,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
at::IntList im_shape,
at::IntList col_shape,
T* grad_offset);
protected:

View File

@ -151,7 +151,7 @@ bool CuDNNDropoutOp::DoRunWithType() {
if (X.dims() != cudnn_input_dims_ && !is_test_) {
CAFFE_ENFORCE(scratch_blob_);
Tensor* states = BlobGetMutableTensor(scratch_blob_, CUDA);
cudnn_input_dims_ = X.dims();
cudnn_input_dims_ = X.dims().vec();
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
data_desc_,
GetCudnnTensorFormat(StorageOrder::NCHW),
@ -244,7 +244,7 @@ bool CuDNNDropoutGradientOp::DoRunWithType() {
}
if (dY.dims() != cudnn_input_dims_) {
cudnn_input_dims_ = dY.dims();
cudnn_input_dims_ = dY.dims().vec();
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
data_desc_,
GetCudnnTensorFormat(StorageOrder::NCHW),

View File

@ -31,7 +31,7 @@ class ExpandDimsOp : public Operator<Context> {
return true;
}
auto newDims = input.dims();
auto newDims = input.dims().vec();
CAFFE_ENFORCE_GE(
input.dims().size() + dims_.size(),
dims_.back() + 1,
@ -85,7 +85,7 @@ class SqueezeOp : public Operator<Context> {
}
static std::vector<int> ComputeDims(
std::vector<int64_t> inputDims,
at::IntList inputDims,
std::vector<int> dims) {
int j = 0;
std::vector<int> newDims;

View File

@ -24,9 +24,9 @@ void batch_matmul_op_cpu_impl(
using Engine = caffe2::DefaultEngine;
auto ndims_A = A.ndim();
auto dims_A = A.dims();
auto dims_A = A.dims().vec();
auto ndims_B = B.ndim();
auto dims_B = B.dims();
auto dims_B = B.dims().vec();
auto noBroadcastErrorMsg = [](size_t dim1, size_t dim2) {
std::stringstream ss;

View File

@ -35,7 +35,7 @@ void concat_op_cpu_impl(
}
int before = 1, after = 1;
vector<int64_t> output_dims(inputs[0]->dims());
vector<int64_t> output_dims(inputs[0]->dims().vec());
for (int i = 0; i < inputs[0]->ndim(); ++i) {
if (i == canonical_axis && !add_axis) {
continue;

View File

@ -34,7 +34,7 @@ void expand_dims_op_cpu_impl(
return;
}
auto newDims = input.dims();
auto newDims = input.dims().vec();
CAFFE_ENFORCE_GE(
input.dims().size() + state->dims.size(),
state->dims.back() + 1,

View File

@ -56,7 +56,7 @@ void fc_op_cpu_impl(
CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
CAFFE_ENFORCE(N == b.size(), dimErrorString());
cache->Y_shape_cache_ = X.dims();
cache->Y_shape_cache_ = X.dims().vec();
// This is an invariant of canonical_axis, so we can DCHECK.
DCHECK_LE(canonical_axis + 1, cache->Y_shape_cache_.size());
cache->Y_shape_cache_.resize(canonical_axis + 1);

View File

@ -124,7 +124,7 @@ void uniform_fill_op_cpu_impl(
min = *inputs[1]->template data<float>();
max = *inputs[2]->template data<float>();
if (min > max) {
auto shape = output->dims();
auto shape = output->dims().vec();
shape[0] = 0;
output->Resize(shape);
output->template mutable_data<float>();

View File

@ -25,7 +25,7 @@ void sparse_lengths_sum_op_cpu_impl(
const int64_t M = lengthsInput.dim(0);
const int64_t indices_size = indicesInput.size();
auto shape = dataInput.dims();
auto shape = dataInput.dims().vec();
shape[0] = M;
output->Resize(shape);
T* out_data = output->template mutable_data<T>();

View File

@ -114,7 +114,7 @@ class UniformFillOp final : public FillerOp<Context> {
min = *Input(1).template data<T>();
max = *Input(2).template data<T>();
if (min > max) {
auto shape = output->dims();
auto shape = output->dims().vec();
shape[0] = 0;
output->Resize(shape);
output->template mutable_data<T>();

View File

@ -31,7 +31,7 @@ bool FlexibleTopKOp<T, Context>::RunOnDevice() {
// get flatten shape of input
CAFFE_ENFORCE_GT(input.ndim(), 0);
vector<int64_t> input_dims = input.dims();
vector<int64_t> input_dims = input.dims().vec();
vector<int64_t> linear_shape = {
size_to_dim_(input_dims.size() - 1, input_dims), input_dims.back()};
CAFFE_ENFORCE_EQ(
@ -107,7 +107,7 @@ bool FlexibleTopKGradientOp<T, Context>::RunOnDevice() {
// Resize output tensors to be as orignial_input size and initialized with 0
CAFFE_ENFORCE_GT(original_input.ndim(), 0);
vector<int64_t> original_dims = original_input.dims();
vector<int64_t> original_dims = original_input.dims().vec();
output->Resize(original_dims);
T* output_data = output->template mutable_data<T>();
math::Set<T, Context>(

View File

@ -69,7 +69,7 @@ class FullyConnectedOp final : public Operator<Context> {
CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
CAFFE_ENFORCE(N == b.size(), dimErrorString());
Y_shape_cache_ = X.dims();
Y_shape_cache_ = X.dims().vec();
// This is an invariant of canonical_axis, so we can DCHECK.
DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
Y_shape_cache_.resize(canonical_axis + 1);

View File

@ -34,7 +34,7 @@ bool GatherOp<CUDAContext>::DoRunWithType() {
auto* output = Output(0);
CAFFE_ENFORCE_GE(data.ndim(), 1, "DATA should be at least 1-D");
auto shape = indices.dims();
auto shape = indices.dims().vec();
shape.insert(shape.end(), data.dims().begin() + 1, data.dims().end());
output->Resize(shape);

View File

@ -26,7 +26,7 @@ class GatherOp : public Operator<Context> {
auto* output = Output(0);
CAFFE_ENFORCE_GE(data.ndim(), 1, "DATA should be at least 1-D");
auto shape = indices.dims();
auto shape = indices.dims().vec();
shape.insert(shape.end(), data.dims().begin() + 1, data.dims().end());
output->Resize(shape);

View File

@ -41,7 +41,7 @@ utils::ConstTensorView<T> GetSubTensorView(
auto st_idx = ComputeStartIndex(tensor, start_dims);
auto ptr = tensor.data<T>() + st_idx;
auto& input_dims = tensor.dims();
auto input_dims = tensor.dims();
std::vector<int> ret_dims(input_dims.begin() + 1, input_dims.end());
utils::ConstTensorView<T> ret(ptr, ret_dims);
@ -241,7 +241,7 @@ bool GenerateProposalsOp<CPUContext>::RunOnDevice() {
// bbox_deltas: (num_images, A * box_dim, H, W)
CAFFE_ENFORCE_EQ(
bbox_deltas.dims(),
(vector<int64_t>{num_images, box_dim * A, height, width}));
(at::ArrayRef<int64_t>{num_images, box_dim * A, height, width}));
// im_info_tensor: (num_images, 3), format [height, width, scale; ...]
CAFFE_ENFORCE_EQ(im_info_tensor.dims(), (vector<int64_t>{num_images, 3}));

View File

@ -252,13 +252,13 @@ bool MIOPENConvOp::DoRunWithType() {
if (input_changed || weight_changed) {
VLOG(1) << "Changing MIOpen descriptor configurations.";
if (input_changed) {
mio_input_dims_ = X.dims();
mio_input_dims_ = X.dims().vec();
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W));
}
if (weight_changed) {
mio_weight_dims_ = Weight.dims();
mio_weight_dims_ = Weight.dims().vec();
MIOPEN_ENFORCE(miopenInitConvolutionDescriptor(
conv_desc_,
mode_,
@ -443,13 +443,13 @@ bool MIOPENConvGradientOp::DoRunWithType() {
if (input_changed || weight_changed) {
VLOG(1) << "Changing MIOpen descriptor configurations.";
if (input_changed) {
mio_input_dims_ = X.dims();
mio_input_dims_ = X.dims().vec();
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W));
}
if (weight_changed) {
mio_weight_dims_ = Weight.dims();
mio_weight_dims_ = Weight.dims().vec();
MIOPEN_ENFORCE(miopenInitConvolutionDescriptor(
conv_desc_,
mode_,

View File

@ -124,7 +124,7 @@ bool MIOPEN_LRNOP::DoRunWithType() {
// Reshape tensor descriptors if necessary
if (X.dims() != miopen_input_dims_) {
VLOG(1) << "Setting descriptors";
miopen_input_dims_ = X.dims();
miopen_input_dims_ = X.dims().vec();
int C = 1, H = 1, W = 1;
// Normal 4-dimensional tensors for images.
C = X.dim32(1);
@ -173,7 +173,7 @@ bool MIOPENLRNGradientOp::DoRunWithType() {
if (dY.dims() != miopen_input_dims_) {
VLOG(1) << "Setting descriptors";
miopen_input_dims_ = dY.dims();
miopen_input_dims_ = dY.dims().vec();
int C = 1, H = 1, W = 1;
// Normal 4-dimensional tensors for images.
C = dY.dim32(1);

View File

@ -54,7 +54,7 @@ class MIOPENReluOp final : public Operator<HIPContext> {
// See if we need to reshape.
if (X.dims() != miopen_input_dims_) {
VLOG(1) << "Setting descriptors.";
miopen_input_dims_ = X.dims();
miopen_input_dims_ = X.dims().vec();
int C = 1, H = 1, W = 1;
if (X.ndim() == 4) {
// Normal 4-dimensional tensors for images.
@ -144,7 +144,7 @@ class MIOPENReluGradientOp final : public Operator<HIPContext> {
// See if we need to reshape.
if (Y.dims() != miopen_input_dims_) {
VLOG(1) << "Setting descriptors.";
miopen_input_dims_ = Y.dims();
miopen_input_dims_ = Y.dims().vec();
int C = 1, H = 1, W = 1;
if (Y.ndim() == 4) {
// Normal 4-dimensional tensors for images.

View File

@ -51,7 +51,7 @@ class MIOpenSoftmaxOp final : public Operator<HIPContext> {
if (dims_ != X.dims()) {
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
desc_, miopenTypeWrapper<T>::type, N, D, 1, 1));
dims_ = X.dims();
dims_ = X.dims().vec();
}
MIOPEN_ENFORCE(miopenSoftmaxForward(
miopen_wrapper_.inline_miopen_handle(),
@ -110,7 +110,7 @@ class MIOpenSoftmaxGradientOp final : public Operator<HIPContext> {
if (dims_ != Y.dims()) {
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
desc_, miopenTypeWrapper<T>::type, N, D, 1, 1));
dims_ = Y.dims();
dims_ = Y.dims().vec();
}
MIOPEN_ENFORCE(miopenSoftmaxBackward(
miopen_wrapper_.inline_miopen_handle(),

View File

@ -139,7 +139,7 @@ bool MIOpenSpatialBNOp::DoRunWithType() {
// See if we need to reshape.
if (N > 0 && X.dims() != miopen_input_dims_) {
VLOG(1) << "Setting descriptors.";
miopen_input_dims_ = X.dims();
miopen_input_dims_ = X.dims().vec();
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
data_desc_, miopenTypeWrapper<T>::type, N, C, H, W));
@ -273,7 +273,7 @@ bool MIOpenSpatialBNGradientOp::DoRunWithType() {
CAFFE_ENFORCE_EQ(scale.dim32(0), C);
// See if we need to reshape.
if (N > 0 && X.dims() != miopen_input_dims_) {
miopen_input_dims_ = X.dims();
miopen_input_dims_ = X.dims().vec();
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
data_desc_, miopenTypeWrapper<T>::type, N, C, H, W));

View File

@ -19,7 +19,7 @@ bool IntegralImageOp<float, CPUContext>::RunOnDevice() {
auto* Y = Output(0);
CAFFE_ENFORCE_EQ(X.ndim(), 4, "Only supports 4D tensors for the momement");
vector<int64_t> out_shape(X.dims());
vector<int64_t> out_shape(X.dims().vec());
out_shape[2] += 1; // H + 1 output size
out_shape[3] += 1; // W + 1 output size
Y->Resize(out_shape);

View File

@ -124,7 +124,7 @@ bool IntegralImageOp<float, CUDAContext>::RunOnDevice() {
// Input is (N, C, H, W)
// Output is (N, C, H + 1, W + 1)
vector<int64_t> out_shape(X.dims());
vector<int64_t> out_shape(X.dims().vec());
out_shape[2] += 1; // H + 1 output size
out_shape[3] += 1; // W + 1 output size
Y->Resize(out_shape);
@ -172,7 +172,7 @@ bool IntegralImageGradientOp<float, CUDAContext>::RunOnDevice() {
// Row pass reduces shape of dY from (N, C, H + 1, W + 1)
// to (N, C, H + 1, W)
// Col pass reduces shape to (N, C, H, W)
vector<int64_t> row_pass_shape(dY.dims());
vector<int64_t> row_pass_shape(dY.dims().vec());
row_pass_shape[3] -= 1;
row_pass_buffer_.Resize(row_pass_shape);
const int chans = row_pass_buffer_.dim32(1);

View File

@ -60,7 +60,7 @@ class LastNWindowCollectorOp : public Operator<Context> {
}
if (!output_initialized) {
auto dims = input.dims();
auto dims = input.dims().vec();
dims[0] = 0;
output->Resize(dims);
// pass meta to output

View File

@ -106,7 +106,7 @@ bool LayerNormOp<CUDAContext>::DoRunWithType<float>() {
segs.begin(),
std::bind1st(std::multiplies<int>(), right));
seg_indices_.Resize(vector<size_t>{segs.size()});
seg_indices_.Resize(at::IntList{static_cast<int64_t>(segs.size())});
context_.CopyBytesFromCPU(
sizeof(int) * segs.size(),
static_cast<void*>(segs.data()),
@ -261,7 +261,7 @@ bool LayerNormGradientOp<CUDAContext>::DoRunWithType<float>() {
stats_dims.push_back(1);
dmean_.Resize(stats_dims);
dstdev_.Resize(stats_dims);
gscratch_.Resize(std::vector<size_t>{left, right});
gscratch_.Resize(at::IntList{static_cast<int64_t>(left), static_cast<int64_t>(right)});
std::vector<int> segs(left + 1);
std::iota(segs.begin(), segs.end(), 0);

View File

@ -46,7 +46,7 @@ class LengthsPadOp : public Operator<Context> {
CAFFE_ENFORCE_EQ(total_length, data.dim(0));
auto shape = data.dims();
auto shape = data.dims().vec();
shape[0] = lengths_size * target_length_;
output->Resize(shape);

View File

@ -53,7 +53,7 @@ class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
const int64_t indices_size = indicesInput.size();
auto* output = Output(0);
auto shape = dataInput.dims();
auto shape = dataInput.dims().vec();
shape[0] = M;
output->Resize(shape);
T* out_data = output->template mutable_data<T>();

View File

@ -57,7 +57,7 @@ class SparseLengths8BitsRowwiseOp : public Operator<Context> {
int64_t dataToReduceSize = indicesInput.dim(0);
const int* lengths = lengthsInput.template data<int>();
vector<int64_t> shape = dataInput.dims();
vector<int64_t> shape = dataInput.dims().vec();
shape[0] = outputSize;
output->Resize(shape);
const float* w = nullptr;

View File

@ -25,7 +25,7 @@ bool LengthsTileOp<CPUContext>::RunOnDevice() {
math::Sum<int32_t, CPUContext>(
lengths_size, lengths_data, &total_length, &cpuContext);
auto shape = data.dims();
auto shape = data.dims().vec();
shape[0] = total_length;
output->Resize(shape);

View File

@ -38,7 +38,7 @@ bool LengthsTileOp<CUDAContext>::RunOnDevice() {
math::Sum<int32_t, CPUContext>(
lengths_size, lengths_data, &total_length, &cpuContext);
auto shape = data.dims();
auto shape = data.dims().vec();
shape[0] = total_length;
output->Resize(shape);

View File

@ -99,7 +99,7 @@ bool CuDNNLRNOp::DoRunWithType() {
// Reshape tensor descriptors if necessary
if (X.dims() != cudnn_input_dims_) {
VLOG(1) << "Setting descriptors";
cudnn_input_dims_ = X.dims();
cudnn_input_dims_ = X.dims().vec();
int C = 1, H = 1, W = 1;
// Normal 4-dimensional tensors for images.
C = X.dim32(1);
@ -155,7 +155,7 @@ bool CuDNNLRNGradientOp::DoRunWithType() {
if (dY.dims() != cudnn_input_dims_) {
VLOG(1) << "Setting descriptors";
cudnn_input_dims_ = dY.dims();
cudnn_input_dims_ = dY.dims().vec();
int C = 1, H = 1, W = 1;
// Normal 4-dimensional tensors for images.
C = dY.dim32(1);

View File

@ -39,7 +39,7 @@ class NumpyTileOp : public Operator<Context> {
// output tensor.
Tensor *src = &buffer, *dst = output;
src->CopyFrom(input);
vector<int64_t> output_dims(input.dims());
vector<int64_t> output_dims(input.dims().vec());
for (size_t i = 0; i < repeats.size(); ++i) {
if (repeats_data[i] == 1) {
continue;

View File

@ -159,13 +159,13 @@ class ONNXWhileOp final : public Operator<Context> {
->template Get<Tensor>();
auto* scan_output_target = Output(i + num_loop_carried_deps);
if (itr == 0) {
auto dims = scan_output.dims();
auto dims = scan_output.dims().vec();
scan_outputs_sizes.push_back(dims);
dims.insert(dims.begin(), 1);
scan_output_target->Resize(dims);
scan_output_target->CopyFrom(scan_output);
} else {
auto dims = scan_output.dims();
auto dims = scan_output.dims().vec();
CAFFE_ENFORCE_EQ(
dims,
scan_outputs_sizes[i],

View File

@ -36,7 +36,7 @@ void BlobToTensorDescriptor(
}
// Set dims
const auto& shape = cpu_tensor.dims();
const auto shape = cpu_tensor.dims();
desc->dimensions = shape.size();
shapes->emplace_back(shape.cbegin(), shape.cend());
desc->shape = shapes->back().data();
@ -76,7 +76,7 @@ template <>
bool OnnxifiOp<float, CPUContext>::RunOnDevice() {
for (unsigned i = 0U; i < InputSize(); ++i) {
const auto& input_tensor = Input(i);
const auto& tensor_dims = input_tensor.dims();
const auto tensor_dims = input_tensor.dims();
auto& tensor_descriptor = input_desc_.at(i);
tensor_descriptor.tag = ONNXIFI_TAG_TENSOR_DESCRIPTOR_V1;
tensor_descriptor.dataType = ONNXIFI_DATATYPE_FLOAT32;

View File

@ -50,7 +50,7 @@ bool PackSegmentsOp<CPUContext>::DoRunWithType2() {
" is equal to the first data dimension ",
data.dim(0));
auto shape = data.dims(); // Shape of output is batch_size x max_len x ...
auto shape = data.dims().vec(); // Shape of output is batch_size x max_len x ...
shape[0] = max_length;
shape.insert(shape.begin(), lengths.size());
output->Resize(shape);
@ -129,7 +129,7 @@ bool UnpackSegmentsOp<CPUContext>::DoRunWithType2() {
int64_t total_l = std::accumulate(l, l + lengths.dim(0), (int64_t)0);
auto shape = data.dims();
auto shape = data.dims().vec();
CAFFE_ENFORCE_EQ(
shape[0], lengths.dim(0), "LENGTH should match DATA in dimension 0");
shape.erase(shape.begin());

View File

@ -208,7 +208,7 @@ bool PackSegmentsOp<CUDAContext>::DoRunWithType2() {
lengths_ptr, num_seq, dev_buffer_, dev_lengths_prefix_sum_, context_);
// create output tensor
auto shape = data.dims(); // Shape of out is batch_size x max_len x ...
auto shape = data.dims().vec(); // Shape of out is batch_size x max_len x ...
shape[0] = max_length;
shape.insert(shape.begin(), lengths.size());
out->Resize(shape);
@ -290,7 +290,7 @@ bool UnpackSegmentsOp<CUDAContext>::DoRunWithType2() {
context_);
// create output tensor
auto shape = data.dims();
auto shape = data.dims().vec();
CAFFE_ENFORCE_EQ(
shape[0], lengths.dim(0), "LENGTH should match DATA in dimension 0");
shape.erase(shape.begin());

View File

@ -41,7 +41,7 @@ class GatherByKeyOp : public Operator<CPUContext> {
const auto& in0Shape = Input(1).dims();
CAFFE_ENFORCE_GE(in0Shape.size(), 1);
vector<int64_t> outShape(keysShape);
vector<int64_t> outShape(keysShape.vec());
outShape.insert(outShape.end(), in0Shape.begin() + 1, in0Shape.end());
CAFFE_ENFORCE_GE(outShape.size(), 1);

View File

@ -220,7 +220,7 @@ class CuDNNPoolOp : public ConvPoolOpBase<CUDAContext> {
if (cudnn_input_dims_ != X.dims()) {
// Dimensions changed; we will need to re-initialize things.
VLOG(1) << "Changing the cudnn descriptor configurations.";
cudnn_input_dims_ = X.dims();
cudnn_input_dims_ = X.dims().vec();
setTensorDescriptor<T>(X.ndim(), order_, N, C, H, W, D, bottom_desc_);
setTensorDescriptor<T>(
Y->ndim(), order_, N, C, H_out, W_out, D_out, top_desc_);
@ -423,7 +423,7 @@ class CuDNNPoolGradientOp : public ConvPoolOpBase<CUDAContext> {
if (cudnn_input_dims_ != X.dims()) {
// Dimensions changed; we will need to re-initialize things.
VLOG(1) << "Changing the cudnn descriptor configurations.";
cudnn_input_dims_ = X.dims();
cudnn_input_dims_ = X.dims().vec();
setTensorDescriptor<T>(X.ndim(), order_, N, C, H, W, D, bottom_desc_);
setTensorDescriptor<T>(
Y.ndim(), order_, N, C, H_out, W_out, D_out, top_desc_);

View File

@ -335,7 +335,7 @@ class BaseReducer {
explicit Meta(bool first = true) : first_dim(first) {}
void computeMeta(const std::vector<int64_t>& dims, int skip_dims) {
void computeMeta(at::IntList dims, int skip_dims) {
first_dim ? block_shape.assign(dims.begin() + skip_dims, dims.end())
: block_shape.assign(dims.begin(), dims.end() - skip_dims);
block_size = first_dim ? size_from_dim_(skip_dims, dims)
@ -344,7 +344,7 @@ class BaseReducer {
void observeInput(int input, const Tensor& value, int skip_dims) {
DCHECK_EQ(0, input);
auto& dims = value.dims();
auto dims = value.dims();
computeMeta(dims, skip_dims);
}
@ -395,7 +395,7 @@ class BaseReducerGradient {
Meta(const Tensor& out_grad, int skip_dims, bool first_dim = true)
: first_dim(first_dim) {
auto& dims = out_grad.dims();
auto dims = out_grad.dims();
first_dim ? block_shape.assign(dims.begin() + skip_dims, dims.end())
: block_shape.assign(dims.begin(), dims.end() - skip_dims);
block_size = first_dim

View File

@ -53,7 +53,7 @@ class RemoveDataBlocksOp final : public Operator<Context> {
indices_size = ind_vec.size();
auto* output = Output(0);
auto shape = data.dims();
auto shape = data.dims().vec();
shape[0] -= indices_size;
output->Resize(shape);
char* out_ptr = (char*)output->raw_mutable_data(data.meta());

View File

@ -44,7 +44,7 @@ class ReservoirSamplingOp final : public Operator<Context> {
if (!output_initialized) {
// IMPORTANT: Force the output to have the right type before reserving,
// so that the output gets the right capacity
auto dims = input.dims();
auto dims = input.dims().vec();
dims[0] = 0;
output->Resize(dims);
output->raw_mutable_data(input.meta());

View File

@ -99,7 +99,7 @@ bool ResizeNearestGradientOp<float, CPUContext>::RunOnDevice() {
const auto& X = Input(1);
auto* dX = Output(0);
const auto& inputDims = dY.dims();
const auto inputDims = dY.dims();
CAFFE_ENFORCE_EQ(4, inputDims.size());
const int batch_size = dY.dim32(0),
num_channels = dY.dim32(1),

View File

@ -75,7 +75,7 @@ bool ResizeNearestOp<float, CUDAContext>::RunOnDevice() {
const auto& X = Input(0);
auto* Y = Output(0);
const auto& inputDims = X.dims();
const auto inputDims = X.dims();
CAFFE_ENFORCE_EQ(4, inputDims.size());
const int batch_size = X.dim32(0), num_channels = X.dim32(1),
input_height = X.dim32(2), input_width = X.dim32(3);
@ -109,7 +109,7 @@ bool ResizeNearestGradientOp<float, CUDAContext>::RunOnDevice() {
const auto& X = Input(1);
auto* dX = Output(0);
const auto& inputDims = dY.dims();
const auto inputDims = dY.dims();
CAFFE_ENFORCE_EQ(4, inputDims.size());
const int batch_size = dY.dim32(0), num_channels = dY.dim32(1),
input_height = dY.dim32(2), input_width = dY.dim32(3);

View File

@ -59,12 +59,12 @@ void ReversePackedSegsOp<CUDAContext>::DoRunWithLengthType() {
CAFFE_ENFORCE(lengths.ndim() == 1, "LENGTH should be 1-D");
auto* output = Output(0);
const auto& shape = data.dims();
const auto shape = data.dims();
output->Resize(shape);
const auto& max_length = data.dims()[0];
const auto& batch_size = data.dims()[1];
const auto& block_size = data.dims()[2];
const auto max_length = data.dims()[0];
const auto batch_size = data.dims()[1];
const auto block_size = data.dims()[2];
CAFFE_ENFORCE(
lengths.dims()[0] == batch_size,
"lenths size should be"

View File

@ -43,12 +43,12 @@ class ReversePackedSegsOp final : public Operator<Context> {
CAFFE_ENFORCE(lengths.ndim() == 1, "LENGTH should be 1-D");
auto* output = Output(0);
const auto& shape = data.dims();
const auto shape = data.dims();
output->Resize(shape);
const auto& max_length = data.dims()[0];
const auto& batch_size = data.dims()[1];
const auto& block_size = data.dims()[2];
const auto max_length = data.dims()[0];
const auto batch_size = data.dims()[1];
const auto block_size = data.dims()[2];
CAFFE_ENFORCE(
lengths.dims()[0] == batch_size,
"lenths size should be"

View File

@ -178,7 +178,7 @@ bool RecurrentOp<T>::RunOnDevice() {
Output(OUTPUT),
Output(HIDDEN_OUTPUT),
Output(CELL_OUTPUT));
cachedInputDims_ = Input(INPUT).dims();
cachedInputDims_ = Input(INPUT).dims().vec();
}
// Validation checks
@ -266,7 +266,7 @@ bool RecurrentGradientOp<T>::RunOnDevice() {
const int seqLength = Input(INPUT).dim32(0);
if (Input(INPUT).dims() != cachedInputDims_) {
initialize(Input(INPUT));
cachedInputDims_ = Input(INPUT).dims();
cachedInputDims_ = Input(INPUT).dims().vec();
}
MIOPEN_ENFORCE(miopenGetRNNTrainingReserveSize(
miopen_wrapper_.inline_miopen_handle(),

View File

@ -76,7 +76,7 @@ void applyOffsetAlias(
auto* dst =
BlobGetMutableTensor(ws->GetBlob(oc.dst), Context::GetDeviceType());
auto timestep = src->size() / src->dim(0);
auto dims = src->dims();
auto dims = src->dims().vec();
const int32_t startDstTimestep =
oc.offset >= 0 ? oc.offset : src->dim(0) + oc.offset;
const int32_t numDstTimesteps = src->dim(0) - startDstTimestep;
@ -905,7 +905,7 @@ class RNNApplyLinkOp : public Operator<Context> {
const int64_t externalTimestepSize = external.size() / external.dim(0);
auto* externalData = external_out->template mutable_data<T>() +
(t + offset_) * externalTimestepSize;
auto internalDims = external_out->dims();
auto internalDims = external_out->dims().vec();
internalDims[0] = window_;
internal_out->Resize(internalDims);

View File

@ -226,7 +226,7 @@ bool RecurrentOp<T>::RunOnDevice() {
Output(OUTPUT),
Output(HIDDEN_OUTPUT),
Output(CELL_OUTPUT));
cachedInputDims_ = Input(INPUT).dims();
cachedInputDims_ = Input(INPUT).dims().vec();
}
// Validation checks
@ -314,7 +314,7 @@ bool RecurrentGradientOp<T>::RunOnDevice() {
const int seqLength = Input(INPUT).dim32(0);
if (Input(INPUT).dims() != cachedInputDims_) {
initialize(Input(INPUT), Output(DROPOUT_STATES));
cachedInputDims_ = Input(INPUT).dims();
cachedInputDims_ = Input(INPUT).dims().vec();
}
CUDNN_ENFORCE(cudnnGetRNNTrainingReserveSize(
cudnn_wrapper_.inline_cudnn_handle(),

View File

@ -72,7 +72,7 @@ class AbstractSortedSegmentRangeOp : public Operator<Context> {
const SIndex* s_ids = segment_ids.template data<SIndex>();
const SIndex K = N > 0 ? s_ids[N - 1] + 1 : 0;
auto shape = dataInput.dims();
auto shape = dataInput.dims().vec();
shape[0] = K;
output->Resize(shape);
@ -142,7 +142,7 @@ class AbstractSortedSegmentRangeGradientOp : public Operator<Context> {
const T* d_in = data_in.template data<T>();
const T* d_out = data_out.template data<T>();
auto shape = segment_grads.dims();
auto shape = segment_grads.dims().vec();
shape[0] = N;
data_grads->Resize(shape);

View File

@ -439,7 +439,7 @@ class CUDASparseLengthsSumOp : public Operator<CUDAContext> {
const int64_t outputSize = lengthsInput.dim(0);
const int len_length = outputSize;
auto shape = dataInput.dims();
auto shape = dataInput.dims().vec();
shape[0] = outputSize;
output->Resize(shape);
T* out_data = output->template mutable_data<T>();
@ -560,7 +560,7 @@ class CUDASparseLengthsMeanOp : public Operator<CUDAContext> {
const int64_t outputSize = lengthsInput.dim(0);
const int len_length = outputSize;
auto shape = dataInput.dims();
auto shape = dataInput.dims().vec();
shape[0] = outputSize;
output->Resize(shape);
T* out_data = output->template mutable_data<T>();
@ -682,7 +682,7 @@ class CUDASparseLengthsMaxOp : public Operator<CUDAContext> {
const int64_t outputSize = lengthsInput.dim(0);
int len_length = outputSize;
auto shape = dataInput.dims();
auto shape = dataInput.dims().vec();
shape[0] = outputSize;
output->Resize(shape);
@ -816,7 +816,7 @@ class CUDASparseLengthsWeightedSumOp : public Operator<CUDAContext> {
const int64_t outputSize = lengthsInput.dim(0);
const int len_length = outputSize;
auto shape = dataInput.dims();
auto shape = dataInput.dims().vec();
shape[0] = outputSize;
output->Resize(shape);
T* out_data = output->template mutable_data<T>();
@ -944,7 +944,7 @@ class CUDAUnsortedSegmentSumOp : public Operator<CUDAContext> {
if (segment_ids.size() == 0 || data.size() == 0) {
// Special handling for empty input
auto dims = data.dims();
auto dims = data.dims().vec();
if (dims.size() > 0) {
dims[0] = 0;
}
@ -993,7 +993,7 @@ class CUDAUnsortedSegmentSumOp : public Operator<CUDAContext> {
sizeof(SIndex), K_tensor_.template data<SIndex>(), &K);
context_.FinishDeviceComputation();
auto dims = data.dims();
auto dims = data.dims().vec();
dims[0] = K + 1;
output->Resize(dims);
@ -1096,7 +1096,7 @@ class SortedSegmentRangeMeanOp : public Operator<Context> {
int M = input.dim32(0);
int N = input.size_from_dim(1);
auto* output = Output(0);
auto dims = input.dims();
auto dims = input.dims().vec();
SIndex K = 0;
context_.CopyBytesToCPU(
sizeof(SIndex),
@ -1307,7 +1307,7 @@ class CUDASparseLengthsSumGradientWithIndicesOp : public Operator<CUDAContext> {
CAFFE_ENFORCE(segmentGradsInput.ndim() > 0);
CAFFE_ENFORCE(len_length == segmentGradsInput.dim(0));
auto shape = segmentGradsInput.dims();
auto shape = segmentGradsInput.dims().vec();
int output_0dim = indicesInput.dim(0);
shape[0] = output_0dim;
dataGradsOutput->Resize(shape);
@ -1386,7 +1386,7 @@ class CUDASparseLengthsMeanGradientWithIndicesOp
CAFFE_ENFORCE(segmentGradsInput.ndim() > 0);
CAFFE_ENFORCE(len_length == segmentGradsInput.dim(0));
auto shape = segmentGradsInput.dims();
auto shape = segmentGradsInput.dims().vec();
int output_0dim = indicesInput.dim(0);
shape[0] = output_0dim;
dataGradsOutput->Resize(shape);
@ -1467,7 +1467,7 @@ class CUDASparseLengthsWeightedSumGradientWithIndicesOp
CAFFE_ENFORCE(segmentGradsInput.ndim() > 0);
CAFFE_ENFORCE(len_length == segmentGradsInput.dim(0));
auto shape = segmentGradsInput.dims();
auto shape = segmentGradsInput.dims().vec();
int output_0dim = indicesInput.dim(0);
shape[0] = output_0dim;
dataGradsOutput->Resize(shape);
@ -1615,7 +1615,7 @@ class CUDALengthsMaxWithMainInputAndForwardOutputGradientOp
auto* prefix_sum_length_data =
inclusive_scan_length_buffer_.template data<int>();
auto shape = dataInput.dims();
auto shape = dataInput.dims().vec();
dataGradsOutput->Resize(shape);
const T* in_data = segmentGradsInput.template data<T>();
@ -1701,7 +1701,7 @@ class CUDASparseLengthsIndicesInGradientWeightedSumWithMainInputGradientOp
CAFFE_ENFORCE(segmentGradsInput.ndim() > 0);
CAFFE_ENFORCE(len_length == segmentGradsInput.dim(0));
auto shape = segmentGradsInput.dims();
auto shape = segmentGradsInput.dims().vec();
int output_0dim = indicesInput.dim(0);
shape[0] = output_0dim;
dataGradsOutput->Resize(shape);

View File
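
The empty-input branch in the segment-reduction hunks above keeps the same .vec() copy even when nothing will be reduced, because the copied shape can then be edited defensively: only the leading dimension is zeroed, and only if it exists. A small sketch of that branch; the helper name is illustrative:

// Illustrative only; mirrors the special handling for empty input above.
#include <vector>
#include "caffe2/core/tensor.h"

std::vector<int64_t> empty_output_dims(const caffe2::Tensor& data) {
  std::vector<int64_t> dims = data.dims().vec();  // owned copy; may itself be empty for a 0-dim tensor
  if (!dims.empty()) {
    dims[0] = 0;  // zero rows, but keep the trailing dimensions
  }
  return dims;    // suitable for output->Resize(dims)
}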

@ -68,7 +68,7 @@ bool RemovePaddingOp<CPUContext>::DoRunWithType() {
auto* out = Output(0);
{
auto out_dims = in.dims();
auto out_dims = in.dims().vec();
out_dims[0] -= pad_width * lengths_size;
out->Resize(std::move(out_dims));
}
@ -196,7 +196,7 @@ bool PadEmptySamplesOp<CPUContext>::RunOnDevice() {
const auto block_size = features.size_from_dim(1);
auto* out_features = Output(1 + k);
auto outDim = features.dims();
auto outDim = features.dims().vec();
outDim.at(0) += needPadding;
out_features->Resize(outDim);
auto dst =

View File

@ -250,7 +250,7 @@ bool RemovePaddingOp<CUDAContext>::DoRunWithType() {
auto* out = Output(0);
{
auto out_dims = in.dims();
auto out_dims = in.dims().vec();
out_dims[0] -= (startPaddingWidth_ + endPaddingWidth_) * lengths_size;
out->Resize(std::move(out_dims));
}

View File

@ -202,7 +202,7 @@ class AddPaddingOp final : public Operator<Context> {
auto* out = Output(0);
{
auto out_dims = in.dims();
auto out_dims = in.dims().vec();
out_dims[0] += (startPaddingWidth_ + endPaddingWidth_) * lengths_size;
out->Resize(std::move(out_dims));
}

View File

@ -38,7 +38,7 @@ class SinusoidPositionEncodingOp : public Operator<Context> {
CAFFE_ENFORCE_EQ(positions.ndim(), 2, "POSITIONS should be a 2-D tensor");
auto shape = positions.dims();
auto shape = positions.dims().vec();
shape.push_back(embedding_size_);
output->Resize(shape);

View File
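
In the SinusoidPositionEncoding hunk above the copy is needed for a different reason: an IntList has a fixed length and no push_back, so appending the embedding dimension requires an owned vector first. A short sketch under the same API assumptions; the helper name is illustrative:

// Illustrative only; mirrors the SinusoidPositionEncoding hunk above.
#include <vector>
#include "caffe2/core/tensor.h"

std::vector<int64_t> shape_with_embedding(const caffe2::Tensor& positions,
                                          int64_t embedding_size) {
  std::vector<int64_t> shape = positions.dims().vec();  // copy, because we grow it below
  shape.push_back(embedding_size);                      // appending is not possible on an IntList
  return shape;                                         // ready for output->Resize(shape)
}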

@ -48,7 +48,7 @@ class CuDNNSoftmaxOp final : public Operator<CUDAContext> {
D,
1,
1));
dims_ = X.dims();
dims_ = X.dims().vec();
}
CUDNN_ENFORCE(cudnnSoftmaxForward(
cudnn_wrapper_.inline_cudnn_handle(),
@ -112,7 +112,7 @@ class CuDNNSoftmaxGradientOp final : public Operator<CUDAContext> {
D,
1,
1));
dims_ = Y.dims();
dims_ = Y.dims().vec();
}
CUDNN_ENFORCE(cudnnSoftmaxBackward(
cudnn_wrapper_.inline_cudnn_handle(),

View File

@ -45,7 +45,7 @@ namespace caffe2 {
const int output_first_dim =
GetOutputFirstDim(sparse_indices_vec, sparse_indices_len);
auto shape = sparse_values.dims();
auto shape = sparse_values.dims().vec();
shape[0] = output_first_dim;
auto* output = Output(0);
output->Resize(shape);

View File

@ -75,7 +75,7 @@ class SparseToDenseOp final : public Operator<Context> {
const int output_first_dim =
GetOutputFirstDim(sparse_indices_vec, sparse_indices_len);
auto shape = sparse_values.dims();
auto shape = sparse_values.dims().vec();
shape[0] = output_first_dim;
auto* output = Output(0);
output->Resize(shape);

View File

@ -72,7 +72,7 @@ class TileOp : public Operator<Context> {
const auto axis = input.canonical_axis_index(axis_);
// reshape output to be input tiled along the axis
vector<int64_t> output_dims(input.dims());
vector<int64_t> output_dims(input.dims().vec());
output_dims[axis_] = output_dims[axis_] * tiles_;
output->Resize(output_dims);
@ -187,7 +187,7 @@ class TileGradientOp : public Operator<Context> {
const auto axis = input.canonical_axis_index(axis_);
// reshape output to be input "untiled" along the axis
vector<int64_t> output_dims(input.dims());
vector<int64_t> output_dims(input.dims().vec());
output_dims[axis_] = output_dims[axis_] / tiles_;
output->Resize(output_dims);
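
The TileOp hunks show one last variant: std::vector<int64_t> has no constructor taking an IntList, so constructing the vector straight from dims() no longer compiles and the copy goes through vec() (an iterator-pair constructor would also work). A small sketch assuming the same API as above; the function name is illustrative:

// Illustrative only; both spellings yield an owned, mutable copy of the shape.
#include <vector>
#include "caffe2/core/tensor.h"

std::vector<int64_t> tiled_dims(const caffe2::Tensor& input, int axis, int64_t tiles) {
  std::vector<int64_t> output_dims = input.dims().vec();
  // Equivalent: std::vector<int64_t> output_dims(input.dims().begin(), input.dims().end());
  output_dims[axis] *= tiles;  // mutate the owned copy, as TileOp does above
  return output_dims;
}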

Some files were not shown because too many files have changed in this diff.