#pragma once

#include <array>
#include <cstdint>
#include <optional>
#include <string>

#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>

namespace torch::stable {

// We expect this to be the stable version of the empty_like op that takes in
// no kwargs (device, dtype, layout, memory_format). We will add kwargs
// support in the future.
inline torch::stable::Tensor empty_like(const torch::stable::Tensor& self) {
  const auto num_args = 6;
  std::array<StableIValue, num_args> stack{
      from(self),
      from(std::nullopt),
      from(std::nullopt),
      from(std::nullopt),
      from(std::nullopt),
      from(std::nullopt)};
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_call_dispatcher("aten::empty_like", "", stack.data()));
  return to<torch::stable::Tensor>(stack[0]);
}

// We expect this to be the stable version of the fill_.Scalar op
// with identical semantics to the existing fill_.Scalar op.
// A subtle nuance is that `value` is typed as a double, but it is
// actually a Scalar. This is because Scalar.h is currently not
// header-only.
inline torch::stable::Tensor fill_(
    const torch::stable::Tensor& self,
    double value) {
  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_fill__Scalar(self.get(), value));
  return self;
}

// We expect this to be the stable version of the narrow.default op.
// narrow takes in a SymInt for start and length, but these are typed as
// int64_t as SymInt is not yet header-only.
inline torch::stable::Tensor narrow(
    torch::stable::Tensor& self,
    int64_t dim,
    int64_t start,
    int64_t length) {
  AtenTensorHandle ret0 = nullptr;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_aten_narrow(self.get(), dim, start, length, &ret0));
  return torch::stable::Tensor(ret0);
}

// We expect this to be a stable version of the new_empty op that takes in
// only dtype information.
inline torch::stable::Tensor new_empty(
    const torch::stable::Tensor& self,
    torch::headeronly::IntHeaderOnlyArrayRef size,
    std::optional<torch::headeronly::ScalarType> dtype = std::nullopt) {
  int32_t device_type;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));

  int32_t device_index;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_get_device_index(self.get(), &device_index));

  int32_t target_dtype;
  if (dtype.has_value()) {
    target_dtype = to<int32_t>(from(dtype.value()));
  } else {
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
  }

  int32_t layout;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout));

  AtenTensorHandle ret0;
  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_empty(
      self.get(),
      size.data(),
      static_cast<int64_t>(size.size()),
      &target_dtype,
      &layout,
      &device_type,
      device_index,
      nullptr, // pin_memory (nullptr for default)
      &ret0));
  return torch::stable::Tensor(ret0);
}
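// Illustrative usage sketch, not part of the stable API: shows how the ops
// above compose. The helper name `example_ones_head` is hypothetical; it
// assumes the caller already holds a valid torch::stable::Tensor with at
// least one element along dim 0.
inline torch::stable::Tensor example_ones_head(
    const torch::stable::Tensor& self) {
  auto buf = empty_like(self); // same shape/dtype/device as self
  fill_(buf, 1.0); // fill in place with the scalar 1.0
  // Return the first slice along dim 0, per narrow's semantics.
  return narrow(buf, /*dim=*/0, /*start=*/0, /*length=*/1);
}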
// We expect this to be a stable version of the new_zeros op that takes in
// only dtype information.
inline torch::stable::Tensor new_zeros(
    const torch::stable::Tensor& self,
    torch::headeronly::IntHeaderOnlyArrayRef size,
    std::optional<torch::headeronly::ScalarType> dtype = std::nullopt) {
  int32_t device_type;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));

  int32_t device_index;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_get_device_index(self.get(), &device_index));

  int32_t target_dtype;
  if (dtype.has_value()) {
    target_dtype = to<int32_t>(from(dtype.value()));
  } else {
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
  }

  int32_t layout;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout));

  AtenTensorHandle ath;
  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_zeros(
      self.get(),
      size.data(),
      static_cast<int64_t>(size.size()),
      &target_dtype,
      &layout,
      &device_type,
      device_index,
      nullptr, // pin_memory (nullptr for default)
      &ath));
  return torch::stable::Tensor(ath);
}

// We expect this to be the stable version of the pad.default op.
// pad.default takes in a SymInt[] as the pad argument, but pad is typed as
// torch::headeronly::IntHeaderOnlyArrayRef as SymInt is not yet header-only.
inline torch::stable::Tensor pad(
    const torch::stable::Tensor& self,
    torch::headeronly::IntHeaderOnlyArrayRef pad,
    const std::string& mode = "constant",
    double value = 0.0) {
  AtenTensorHandle ret0 = nullptr;
  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_pad(
      self.get(), pad.data(), pad.size(), mode.c_str(), &value, &ret0));
  return torch::stable::Tensor(ret0);
}

// We expect the following two functions to be stable versions of the
// amax.default op with identical semantics to the existing amax.default op.
// If `keepdim` is true, the result will have the same number of dimensions
// as `self`, with each reduced dimension having size 1. Otherwise, the
// result will have correspondingly fewer dimensions, with the reduced
// dimensions removed.

// This overload computes the maximum value of each slice of `self` along a
// single dimension `dim`.
inline torch::stable::Tensor amax(
    const torch::stable::Tensor& self,
    int64_t dim,
    bool keepdim = false) {
  AtenTensorHandle ret = nullptr;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_aten_amax(self.get(), &dim, 1, keepdim, &ret));
  return torch::stable::Tensor(ret);
}

// This overload computes the maximum value of each slice of `self`, reducing
// over all the dimensions in `dims`. The amax.default op takes in a SymInt[]
// as the dims argument, but dims is typed as IntHeaderOnlyArrayRef here
// because SymInt is not yet header-only.
inline torch::stable::Tensor amax(
    const torch::stable::Tensor& self,
    torch::headeronly::IntHeaderOnlyArrayRef dims,
    bool keepdim = false) {
  AtenTensorHandle ret = nullptr;
  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_amax(
      self.get(),
      dims.data(),
      static_cast<int64_t>(dims.size()),
      keepdim,
      &ret));
  return torch::stable::Tensor(ret);
}

// We expect this to be the stable version of the transpose op with identical
// semantics to the existing transpose.int op.
inline torch::stable::Tensor transpose(
    const torch::stable::Tensor& self,
    int64_t dim0,
    int64_t dim1) {
  const auto num_args = 3;
  std::array<StableIValue, num_args> stack{from(self), from(dim0), from(dim1)};
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_call_dispatcher("aten::transpose", "int", stack.data()));
  return to<torch::stable::Tensor>(stack[0]);
}
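// Illustrative usage sketch, not part of the stable API: pads a tensor,
// reduces it with amax, and swaps the two kept dims. The helper name
// `example_pad_and_reduce` is hypothetical; it assumes `self` is at least
// 2-D so that dims 0 and 1 exist.
inline torch::stable::Tensor example_pad_and_reduce(
    const torch::stable::Tensor& self) {
  // Zero-pad the last dim by one element on each side.
  auto padded = pad(self, {1, 1}, "constant", 0.0);
  // Max over dims 0 and 1; the braced list selects the ArrayRef overload.
  auto reduced = amax(padded, {0, 1}, /*keepdim=*/true);
  // Swap the two size-1 dims kept by keepdim.
  return transpose(reduced, 0, 1);
}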
// We expect this to be the stable version of the zero_ op with identical
// semantics to the existing zero_ op (except that it will not be called as
// a tensor method but only as a function, i.e., zero_(t) not t.zero_()).
inline torch::stable::Tensor zero_(torch::stable::Tensor& self) {
  const auto num_args = 1;
  std::array<StableIValue, num_args> stack{from(self)};
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_call_dispatcher("aten::zero_", "", stack.data()));
  return to<torch::stable::Tensor>(stack[0]);
}

// We expect this to be the stable version of the copy_ op with
// identical semantics to the existing copy_ op.
inline torch::stable::Tensor copy_(
    torch::stable::Tensor& self,
    const torch::stable::Tensor& src,
    std::optional<bool> non_blocking = std::nullopt) {
  const auto num_args = 3;
  std::array<StableIValue, num_args> stack{
      from(self), from(src), from(non_blocking.value_or(false))};
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_call_dispatcher("aten::copy_", "", stack.data()));
  return to<torch::stable::Tensor>(stack[0]);
}

// We expect this to be the stable version of the clone op. We will
// add optional memory_format kwarg support in the future.
inline torch::stable::Tensor clone(const torch::stable::Tensor& self) {
  const auto num_args = 2;
  std::array<StableIValue, num_args> stack{from(self), from(std::nullopt)};
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_call_dispatcher("aten::clone", "", stack.data()));
  return to<torch::stable::Tensor>(stack[0]);
}

} // namespace torch::stable
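// Illustrative usage sketch, not part of the stable API: clones a source
// tensor, zeroes a destination in place, and copies the clone into it. The
// helper name `example_clone_zero_copy` is hypothetical; it assumes `dst`
// and `src` have shapes compatible with copy_.
inline torch::stable::Tensor example_clone_zero_copy(
    torch::stable::Tensor& dst,
    const torch::stable::Tensor& src) {
  auto backup = torch::stable::clone(src); // independent copy of src's data
  torch::stable::zero_(dst); // free-function form: zero_(t), not t.zero_()
  torch::stable::copy_(dst, backup); // copy backup's values into dst
  return backup;
}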