Revert "[aoti] Add MPS runner and shim (#153964)"

This reverts commit 918ae5d36188f419a47f3b1315f9fb373035ed66.

Reverted https://github.com/pytorch/pytorch/pull/153964 on behalf of https://github.com/angelayi due to broke frl build ([comment](https://github.com/pytorch/pytorch/pull/153964#issuecomment-2901876832))
This commit is contained in:
PyTorch MergeBot
2025-05-22 16:35:59 +00:00
parent 47a01f3efb
commit a82c8891d5
15 changed files with 1 additions and 275 deletions

View File

@ -275,11 +275,6 @@ if(NOT INTERN_DISABLE_AUTOGRAD AND NOT BUILD_LITE_INTERPRETER)
"${TORCH_SRC_DIR}/csrc/lazy/generated/RegisterLazy.cpp"
)
endif()
if(USE_MPS)
list(APPEND GENERATED_CXX_TORCH
"${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/generated/c_shim_mps.cpp"
)
endif()
endif()
set(GENERATED_H_TORCH
@ -708,8 +703,6 @@ list(APPEND Caffe2_CPU_SRCS ${TORCH_SRCS})
if(USE_MPS)
list(APPEND Caffe2_CPU_SRCS ${Caffe2_MPS_SRCS})
list(APPEND Caffe2_CPU_SRCS ${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/shim_mps.cpp)
list(APPEND Caffe2_CPU_SRCS ${TORCH_SRC_DIR}/csrc/inductor/aoti_runner/model_container_runner_mps.cpp)
if(CAN_COMPILE_METAL)
file(TOUCH ${CMAKE_BINARY_DIR}/aten/src/ATen/metallib_dummy.cpp)
list(APPEND Caffe2_CPU_SRCS ${CMAKE_BINARY_DIR}/aten/src/ATen/metallib_dummy.cpp)

View File

@ -19,7 +19,6 @@ def alloc_tensor_by_stealing_from_void_ptr(
class AOTIModelContainerRunnerCpu: ...
class AOTIModelContainerRunnerCuda: ...
class AOTIModelContainerRunnerXpu: ...
class AOTIModelContainerRunnerMps: ...
# Defined in torch/csrc/inductor/aoti_package/pybind.cpp
class AOTIModelPackageLoader: ...

View File

@ -1,4 +0,0 @@
#pragma once
#include <torch/csrc/inductor/aoti_include/common.h>
#include <torch/csrc/inductor/cpp_wrapper/device_internal/mps.h>

View File

@ -1,39 +0,0 @@
#if defined(__APPLE__)
#include <torch/csrc/inductor/aoti_runner/model_container_runner_mps.h>
namespace torch::inductor {
AOTIModelContainerRunnerMps::AOTIModelContainerRunnerMps(
const std::string& model_so_path,
size_t num_models,
bool run_single_threaded)
: AOTIModelContainerRunner(
model_so_path,
num_models,
"mps",
"",
run_single_threaded) {}
AOTIModelContainerRunnerMps::~AOTIModelContainerRunnerMps() = default;
namespace {
std::unique_ptr<AOTIModelContainerRunner> create_aoti_runner_mps(
const std::string& model_so_path,
size_t num_models,
const std::string& device_str,
const std::string& cubin_dir,
const bool run_single_threaded) {
if (device_str != "mps") {
throw std::runtime_error("Incorrect device passed to aoti_runner_mps");
}
return std::make_unique<AOTIModelContainerRunnerMps>(
model_so_path, num_models, run_single_threaded);
}
} // namespace
static RegisterAOTIModelRunner register_mps_runner(
"mps",
&create_aoti_runner_mps);
} // namespace torch::inductor
#endif

View File

@ -1,18 +0,0 @@
#if !defined(C10_MOBILE) && !defined(ANDROID)
#pragma once
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
namespace torch::inductor {
class TORCH_API AOTIModelContainerRunnerMps : public AOTIModelContainerRunner {
public:
AOTIModelContainerRunnerMps(
const std::string& model_so_path,
size_t num_models = 1,
const bool run_single_threaded = false);
~AOTIModelContainerRunnerMps() override;
};
} // namespace torch::inductor
#endif

View File

@ -5,9 +5,6 @@
#ifdef USE_XPU
#include <torch/csrc/inductor/aoti_runner/model_container_runner_xpu.h>
#endif
#ifdef __APPLE__
#include <torch/csrc/inductor/aoti_runner/model_container_runner_mps.h>
#endif
#include <torch/csrc/inductor/aoti_runner/pybind.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
@ -133,41 +130,6 @@ void initAOTIRunnerBindings(PyObject* module) {
"free_inactive_constant_buffer",
&AOTIModelContainerRunnerXpu::free_inactive_constant_buffer);
#endif
#ifdef __APPLE__
py::class_<AOTIModelContainerRunnerMps>(m, "AOTIModelContainerRunnerMps")
.def(py::init<const std::string&, int>())
.def(
"run",
&AOTIModelContainerRunnerMps::run,
py::arg("inputs"),
py::arg("stream_handle") = nullptr)
.def("get_call_spec", &AOTIModelContainerRunnerMps::get_call_spec)
.def(
"get_constant_names_to_original_fqns",
&AOTIModelContainerRunnerMps::getConstantNamesToOriginalFQNs)
.def(
"get_constant_names_to_dtypes",
&AOTIModelContainerRunnerMps::getConstantNamesToDtypes)
.def(
"extract_constants_map",
&AOTIModelContainerRunnerMps::extract_constants_map)
.def(
"update_constant_buffer",
static_cast<void (AOTIModelContainerRunnerMps::*)(
std::unordered_map<std::string, at::Tensor>&, bool, bool, bool)>(
&AOTIModelContainerRunnerMps::update_constant_buffer),
py::arg("tensor_map"),
py::arg("use_inactive"),
py::arg("validate_full_updates"),
py::arg("user_managed") = false)
.def(
"swap_constant_buffer",
&AOTIModelContainerRunnerMps::swap_constant_buffer)
.def(
"free_inactive_constant_buffer",
&AOTIModelContainerRunnerMps::free_inactive_constant_buffer);
#endif
m.def(

View File

@ -100,7 +100,7 @@ inline void parse_device_str(
const std::string& device_str,
int32_t& device_type,
int32_t& device_idx) {
std::regex re("(cpu|cuda|xpu|mps)(:([0-9]+))?");
std::regex re("(cpu|cuda|xpu)(:([0-9]+))?");
std::smatch sm;
bool matched = std::regex_match(device_str, sm, re);
AOTI_RUNTIME_CHECK(matched, "Invalid device: " + device_str);
@ -112,10 +112,6 @@ inline void parse_device_str(
#ifdef USE_XPU
} else if (sm[1].str() == "xpu") {
device_type = aoti_torch_device_type_xpu();
#endif
#ifdef __APPLE__
} else if (sm[1].str() == "mps") {
device_type = aoti_torch_device_type_mps();
#endif
} else {
AOTI_RUNTIME_CHECK(false, "Invalid device: " + device_str);

View File

@ -106,7 +106,6 @@ AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cpu();
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cuda();
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_meta();
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_xpu();
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_mps();
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_privateuse1();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2();

View File

@ -1,22 +0,0 @@
#ifndef AOTI_TORCH_SHIM_MPS
#define AOTI_TORCH_SHIM_MPS
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#ifdef __cplusplus
extern "C" {
#endif
struct AOTIMetalKernelFunctionOpaque;
using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*;
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg(
AOTIMetalKernelFunctionHandle func,
unsigned idx,
AtenTensorHandle tensor);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // AOTI_TORCH_SHIM_MPS

View File

@ -1,111 +0,0 @@
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details
#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#ifdef __cplusplus
extern "C" {
#endif
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__adaptive_avg_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__adaptive_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__weight_int4pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScaleAndZeros, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_adaptive_max_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_adaptive_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_add_Scalar(AtenTensorHandle self, double other, double alpha, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_add_Tensor(AtenTensorHandle self, AtenTensorHandle other, double alpha, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_addbmm(AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_addmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_addmv(AtenTensorHandle self, AtenTensorHandle mat, AtenTensorHandle vec, double beta, double alpha, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_angle(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_baddbmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cholesky_solve(AtenTensorHandle self, AtenTensorHandle input2, int32_t upper, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_convolution(AtenTensorHandle input, AtenTensorHandle weight, AtenTensorHandle* bias, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_convolution_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle weight, const int64_t** bias_sizes, int64_t bias_sizes_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cummax(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cummin(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_histc(AtenTensorHandle self, int64_t bins, double min, double max, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_histogram_bin_ct(AtenTensorHandle self, int64_t bins, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_index_Tensor(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_index_put(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle values, int32_t accumulate, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_kthvalue(AtenTensorHandle self, int64_t k, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_logcumsumexp(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_lu_unpack(AtenTensorHandle LU_data, AtenTensorHandle LU_pivots, int32_t unpack_data, int32_t unpack_pivots, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_masked_scatter(AtenTensorHandle self, AtenTensorHandle mask, AtenTensorHandle source, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_masked_scatter_backward(AtenTensorHandle grad_output, AtenTensorHandle mask, const int64_t* sizes, int64_t sizes_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_masked_select(AtenTensorHandle self, AtenTensorHandle mask, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool2d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool2d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_median(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mul_Tensor(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_pow_Scalar(double self, AtenTensorHandle exponent, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_pow_Tensor_Scalar(AtenTensorHandle self, double exponent, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_pow_Tensor_Tensor(AtenTensorHandle self, AtenTensorHandle exponent, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randint_generator(int64_t high, const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randint_low(int64_t low, int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randn_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_reshape(AtenTensorHandle self, const int64_t* shape, int64_t shape_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_searchsorted_Scalar(AtenTensorHandle sorted_sequence, double self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_searchsorted_Tensor(AtenTensorHandle sorted_sequence, AtenTensorHandle self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set__source_Tensor(AtenTensorHandle self, AtenTensorHandle source);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_sort_stable(AtenTensorHandle self, int32_t* stable, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_topk(AtenTensorHandle self, int64_t k, int64_t dim, int32_t largest, int32_t sorted, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_triangular_solve(AtenTensorHandle self, AtenTensorHandle A, int32_t upper, int32_t transpose, int32_t unitriangular, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_uniform(AtenTensorHandle self, double from, double to, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_upsample_bicubic2d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_h, double* scales_w, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_upsample_linear1d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_view_dtype(AtenTensorHandle self, int32_t dtype, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_view_as_complex(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_view_as_real(AtenTensorHandle self, AtenTensorHandle* ret0);
#ifdef __cplusplus
} // extern "C"
#endif

View File

@ -129,7 +129,6 @@ AOTI_TORCH_DEVICE_TYPE_IMPL(cpu, CPU)
AOTI_TORCH_DEVICE_TYPE_IMPL(cuda, CUDA)
AOTI_TORCH_DEVICE_TYPE_IMPL(meta, Meta)
AOTI_TORCH_DEVICE_TYPE_IMPL(xpu, XPU)
AOTI_TORCH_DEVICE_TYPE_IMPL(mps, MPS)
AOTI_TORCH_DEVICE_TYPE_IMPL(privateuse1, PrivateUse1)
#undef AOTI_TORCH_DEVICE_TYPE_IMPL

View File

@ -1,19 +0,0 @@
#include <ATen/native/mps/MetalShaderLibrary.h>
#include <torch/csrc/inductor/aoti_torch/c/shim_mps.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
using namespace torch::aot_inductor;
AOTITorchError aoti_torch_mps_set_arg(
AOTIMetalKernelFunctionHandle handle,
unsigned idx,
AtenTensorHandle tensor) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto t = tensor_handle_to_tensor_pointer(tensor);
if (t == nullptr) {
throw std::runtime_error("Tensor is null.");
}
auto func = reinterpret_cast<at::native::mps::MetalKernelFunction*>(handle);
func->setArg(idx, *t);
});
}

View File

@ -1,4 +0,0 @@
#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim_mps.h>
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h>

View File

@ -1,4 +0,0 @@
#pragma once
#include <torch/csrc/inductor/cpp_wrapper/common.h>
#include <torch/csrc/inductor/cpp_wrapper/device_internal/mps.h>

View File

@ -2987,7 +2987,6 @@ def main() -> None:
if options.mps:
functions_keys.add(DispatchKey.MPS)
aoti_backends.add(DispatchKey.MPS)
if options.xpu:
functions_keys.add(DispatchKey.XPU)