diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index eda688ad6e1d..784dd2927fba 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -753,8 +753,89 @@ Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_
   return newTensor._coalesced_(self.is_coalesced());
 }
 
+Tensor& narrow_copy_dense_out(
+    Tensor& output, const Tensor& self, int64_t dim, int64_t start, int64_t length
+) {
+  if (self.is_cuda()) {
+    return output.copy_(self.narrow(dim, start, length));
+  }
+  TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
+  TORCH_CHECK(self.dtype() == output.dtype());
+
+  Tensor self_contig = self.contiguous();
+  const auto self_sizes = self_contig.sizes();
+
+  // wrap dim if negative and do bound check
+  if (dim < 0) {
+    dim = at::maybe_wrap_dim(dim, self_sizes.size());
+  } else {
+    TORCH_CHECK(dim < self_sizes.size());
+  }
+
+  // wrap start and do bound check
+  const auto cur_size = self_sizes[dim];
+  if (start != cur_size && start < 0) { // start being the end is valid, but
+                                        // not a valid dim specification.
+    start = at::maybe_wrap_dim(start, cur_size);
+  }
+  TORCH_CHECK(
+      length >= 0 && start <= cur_size - length,
+      "start (",
+      start,
+      ") + length (",
+      length,
+      ") exceeds dimension size (",
+      cur_size,
+      ").");
+
+  // resize output
+  auto output_sizes = self_sizes.vec();
+  output_sizes[dim] = length;
+  at::native::resize_(output, output_sizes);
+
+  const int64_t unit = c10::size_from_dim_(dim + 1, self_sizes);
+  const int64_t num_blocks = c10::size_to_dim_(dim, self_sizes);
+
+  const auto itemsize = self_contig.dtype().itemsize();
+  size_t src_nbytes = itemsize * self_contig.numel();
+  size_t dst_nbytes = itemsize * output.numel();
+
+  size_t src_block_size = unit * self_sizes[dim];
+  size_t dst_block_size = unit * length;
+
+  if (num_blocks == 0 || dst_block_size == 0) {
+    return output;
+  }
+
+  char* src_bytes = static_cast<char*>(self_contig.data_ptr());
+  char* dst_bytes = static_cast<char*>(output.data_ptr());
+
+  size_t src_block_size_bytes = itemsize * src_block_size;
+  size_t dst_block_size_bytes = itemsize * dst_block_size;
+  size_t src_offset = unit * start;
+
+  char* src_offset_bytes = src_bytes + itemsize * src_offset;
+  char* dst_offset_bytes = dst_bytes;
+
+  for (size_t i = 0; i < num_blocks; ++i) {
+    char* local_src_offset_bytes = src_offset_bytes + i * src_block_size_bytes;
+    char* local_dst_offset_bytes = dst_offset_bytes + i * dst_block_size_bytes;
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+        static_cast<void*>(local_src_offset_bytes + dst_block_size_bytes) <=
+        static_cast<void*>(src_bytes + src_nbytes));
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+        static_cast<void*>(local_dst_offset_bytes + dst_block_size_bytes) <=
+        static_cast<void*>(dst_bytes + dst_nbytes));
+
+    memcpy(
+        local_dst_offset_bytes, local_src_offset_bytes, dst_block_size_bytes);
+  }
+  return output;
+}
+
 Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t length){
-  return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous);
+  auto output = at::empty_like(self);
+  return narrow_copy_dense_out(output, self, dim, start, length);
 }
 
 Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 715fdccc9691..8885e06e9ef6 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3065,11 +3065,15 @@
 
 - func: narrow_copy(Tensor self, int dim, int start, int length) -> Tensor
   use_c10_dispatcher: full
-  variants: method
+  variants: function, method
   dispatch:
     CPU, CUDA: narrow_copy_dense
     SparseCPU, SparseCUDA: narrow_copy_sparse
 
+- func: narrow_copy.out(Tensor self, int dim, int start, int length, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: narrow_copy_dense_out
+
 - func: narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)
   use_c10_dispatcher: full
   variants: function, method
diff --git a/aten/src/ATen/test/native_test.cpp b/aten/src/ATen/test/native_test.cpp
index b32a0b081042..4c53fd6b6620 100644
--- a/aten/src/ATen/test/native_test.cpp
+++ b/aten/src/ATen/test/native_test.cpp
@@ -64,6 +64,16 @@ void TestStack(TensorOptions T, Tensor& t) {
   }
 }
 
+void TestNarrow(TensorOptions T, Tensor& t) {
+  auto x = rand({5, 8, 3});
+  for (int64_t dim = 0; dim < 3; ++dim) {
+    const int64_t start = 1, length = 2;
+    auto y_ref = x.narrow(dim, start, length);
+    auto y_test = at::native::narrow_copy_dense(x, dim, start, length);
+    ASSERT_EQUAL(y_ref, y_test);
+  }
+}
+
 // size / stride
 void TestSize(TensorOptions T, Tensor& t) {
   auto scalar = randn({}, T);
@@ -199,6 +209,7 @@ void test(TensorOptions T, TensorOptions AccT) {
   TestSplit(T, t);
   TestChunk(T, t);
   TestStack(T, t);
+  TestNarrow(T, t);
   TestSize(T, t);
   TestMatmul(T, t, AccT);
   TestStandardGammaGrad(T, t);
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index 11fb5dae2d6c..4d38b8b0a97d 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -31,7 +31,6 @@ bool canRunNatively(Node* n) {
   // In alphabetical order
   const static std::unordered_set<std::string> native_nodes{
       "aten::flatten",
-      "aten::narrow",
       "aten::reshape",
       "aten::slice",
       "aten::transpose",
@@ -303,6 +302,29 @@ REGISTER_OPERATOR_FUNCTOR(aten::clone, aten_clone, [](Node* n) -> SROperator {
   };
 });
 
+// The out variant takes precedence over native
+REGISTER_OPERATOR_FUNCTOR(aten::narrow, aten_narrow, [](Node* n) -> SROperator {
+  return [](const ProcessedNode* p_node, std::vector<IValue>& reg) {
+    auto self = p_node->Input(0, reg).toTensor(); // self
+    auto dim = p_node->Input(1, reg).toInt(); // dim
+    int64_t start = 0;
+    if (p_node->Input(2, reg).isScalar()) {
+      start = p_node->Input(2, reg).toInt();
+    } else {
+      auto t = p_node->Input(2, reg).toTensor();
+      start = t.item<int64_t>();
+    }
+    auto length = p_node->Input(3, reg).toInt(); // length
+
+    if (p_node->Output(0, reg).isNone()) {
+      p_node->Output(0, reg) = create_empty_from(self);
+    }
+    auto output = p_node->Output(0, reg).toTensor();
+    output.resize_({0});
+    at::native::narrow_copy_dense_out(output, self, dim, start, length);
+  };
+});
+
 std::function<void(const ProcessedNode*, std::vector<IValue>&)> getOutOfPlaceOperation(Node* n) {
   auto op_name = n->kind().toQualString();
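
Not part of the patch: a minimal C++ sketch of how the newly exposed variants could be exercised, mirroring the TestNarrow case above. The include set and the visibility of at::native::narrow_copy_dense_out outside of ATen internals are assumptions.

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>

int main() {
  at::Tensor x = at::rand({5, 8, 3});
  // Reference result: narrow() returns a view; clone() materializes it.
  at::Tensor expected = x.narrow(/*dim=*/1, /*start=*/1, /*length=*/2).clone();
  // Function variant enabled by "variants: function, method" in native_functions.yaml.
  at::Tensor y = at::narrow_copy(x, 1, 1, 2);
  // Out variant added by this diff; narrow_copy_dense_out resizes `out` to {5, 2, 3}.
  at::Tensor out = at::empty({0}, x.options());
  at::native::narrow_copy_dense_out(out, x, 1, 1, 2);
  TORCH_CHECK(at::equal(y, expected) && at::equal(out, expected));
  return 0;
}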