Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12180

I had to fix a lot of call sites, because a lot of places assume that
you can actually get a const vector&, and if the internal representation
of sizes in a tensor is NOT a vector, it's not possible to fulfill
this API contract.

Framework changes:
- I deleted TensorImpl::dims(); caffe2::Tensor::dims() just forwards to
  sizes() now.
- De-templatized SetDims; it is now an explicit list of ArrayRef and
  variadic overloads. This makes implicit conversions work again,
  so I don't need to explicitly list the std::vector cases too.
  - As a knock-on effect, this causes Reset() to accept at::IntList as
    well as const std::vector<int64_t>&
- Edited the variadic overloads of SetDims to all forward to the
  underlying arbitrary-dim implementation, reducing code duplication.
  (It's probably marginally less efficient in the new world.)
- Replaced the Tensor constructor accepting const std::vector<int64_t>&
  with one accepting at::IntList
- Made MKLTensor accept ArrayRef along with vector in its constructor and
  Reset (unfortunately, no implicit conversions here, since it's
  templated on index type.)
- There are a few other places, like cudnn, where I changed functions
  that previously took const std::vector<int64_t>& to take at::IntList
  instead.

Classification of call site changes:
- 'const std::vector<int64_t>& x_dims = x.dims()' ==>
  'at::IntList x_dims = x.dims()'
- 'std::vector<int64_t> x_dims = x.dims()' ==>
  'std::vector<int64_t> x_dims = x.dims().vec()' (we need a copy!)
  Usually this is because we're about to modify the vector to compute
  some new dimension. However, it also very commonly occurs in the
  form 'x_dims_ = x.dims()', because we frequently cache sizes in
  operators.
- Instead of constructing std::vector<int64_t>{blah, blah}, construct an
  at::IntList directly

ArrayRef changes:
- Added cbegin()/cend() iterators; they operate the same as
  begin()/end() because everything on ArrayRef is const.
- Moved operator<< into ArrayRef.h, so that it's always available when
  working with ArrayRef. I also templated it, so it now works on an
  ArrayRef of any type.
- Added an operator== overload for ArrayRef, and also added variants to
  permit comparison of ArrayRef with std::vector, a very common
  operation. (The non-templated version of operator== can get these
  automatically via implicit conversion, but with templates C++ refuses
  to do any implicit conversions.)

I'm planning to audit all dims() call sites to make sure they don't
expect 'auto x = t.dims()' to give you an x whose lifetime can validly
outlive the tensor.

I opted not to do a dims() to sizes() rename, because dims() also matches
the protobuf accessor. Bad news!

Reviewed By: jerryzh168

Differential Revision: D10111759

fbshipit-source-id: a2a81dc4b92c22ad4b3b8ef4077a7e97b6479452
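Illustration of the call site classification above. This is a rough sketch, not code from this PR: IntListView and ToyTensor are hypothetical stand-ins for at::IntList and caffe2::Tensor, written so the example compiles without ATen. It assumes only what the message states, namely that dims() now returns a non-owning view and that .vec() produces an owning copy.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for at::IntList: a non-owning pointer + length view.
// This is NOT the real ATen class, only enough of it to show the pattern.
class IntListView {
 public:
  IntListView(const std::vector<int64_t>& v)
      : data_(v.data()), size_(v.size()) {}
  const int64_t* begin() const { return data_; }
  const int64_t* end() const { return data_ + size_; }
  std::size_t size() const { return size_; }
  int64_t operator[](std::size_t i) const {
    assert(i < size_);
    return data_[i];
  }
  // Mirrors the .vec() call in the message: returns an owning copy.
  std::vector<int64_t> vec() const {
    return std::vector<int64_t>(data_, data_ + size_);
  }

 private:
  const int64_t* data_;
  std::size_t size_;
};

// Hypothetical tensor whose dims() hands out a view instead of a
// const std::vector<int64_t>&.
class ToyTensor {
 public:
  explicit ToyTensor(std::vector<int64_t> sizes) : sizes_(std::move(sizes)) {}
  IntListView dims() const { return IntListView(sizes_); }
  void Resize(std::vector<int64_t> sizes) { sizes_ = std::move(sizes); }

 private:
  std::vector<int64_t> sizes_;
};

// Read-only call site: holding the view by value is enough.
int64_t numel(const ToyTensor& t) {
  IntListView d = t.dims();
  int64_t n = 1;
  for (int64_t s : d) n *= s;
  return n;
}

// Mutating call site: take an owning copy with vec() before modifying.
std::vector<int64_t> dims_with_leading_batch(const ToyTensor& t,
                                             int64_t batch) {
  std::vector<int64_t> out = t.dims().vec();
  out.insert(out.begin(), batch);
  return out;
}
```

An operator that caches sizes ('x_dims_ = x.dims()') falls under the second pattern: the cached member should be an owning vector filled via vec(). In this sketch, a view kept across Resize() would point at storage that may have been replaced, which is the kind of lifetime assumption the planned audit is meant to catch.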