Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[AOTI] Refine the C shim autogen mechanism (#125589)
Summary: Based on the discussions in https://github.com/pytorch/pytorch/pull/120513, instead of auto-generating C shim fallback ops for thousands of ops, we maintain a list of fallback ops based on torch/_inductor/lowering.py and only generate C shim functions for those ops. At torchgen time, the C shim files are regenerated and the header file contents are compared against the existing C shim headers; if there is any difference, compilation fails with a prompt on how to proceed. This keeps the ABI-compatible C shim layer small enough to maintain in the long run.

Differential Revision: D57004046 (https://our.internmc.facebook.com/intern/diff/D57004046)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125589
Approved by: https://github.com/frank-wei, https://github.com/chenyang78, https://github.com/albanD, https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 0bde9c08ef
Commit: ed48ea9997
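The regenerate-and-compare step described in the summary can be pictured with a small sketch. This is illustrative only: the helper name and signature below are made up, not the actual torchgen implementation; the only real detail it leans on is the `python torchgen/gen.py --update-aoti-c-shim` command mentioned in torchgen/aoti/fallback_ops.py further down.

# Illustrative sketch (hypothetical helper) of the torchgen-time C shim header check.
import os

def check_generated_c_shim(new_header_text: str, checked_in_path: str, update: bool) -> None:
    if update or not os.path.exists(checked_in_path):
        # `python torchgen/gen.py --update-aoti-c-shim` would take this branch.
        with open(checked_in_path, "w") as f:
            f.write(new_header_text)
        return
    with open(checked_in_path) as f:
        existing = f.read()
    if existing != new_header_text:
        raise RuntimeError(
            f"{checked_in_path} no longer matches the generated C shim. "
            "If the fallback op list changed intentionally, rerun "
            "`python torchgen/gen.py --update-aoti-c-shim` and review the ABI impact."
        )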
.gitignore (vendored): 2 changed lines
@@ -87,7 +87,7 @@ torch/csrc/api/include/torch/version.h
 torch/csrc/cudnn/cuDNN.cpp
 torch/csrc/generated
 torch/csrc/generic/TensorMethods.cpp
-torch/csrc/inductor/aoti_torch/generated/*
+torch/csrc/inductor/aoti_torch/generated/*.cpp
 torch/csrc/jit/generated/*
 torch/csrc/jit/fuser/config.h
 torch/csrc/nn/THCUNN.cpp
@@ -78,6 +78,7 @@ exclude_patterns = [
     'aten/src/ATen/native/vulkan/api/vk_mem_alloc.h',
     'c10/util/strong_type.h',
     '**/fb/**',
+    'torch/csrc/inductor/aoti_torch/generated/**',
     'torch/csrc/jit/serialization/mobile_bytecode_generated.h',
     'torch/csrc/utils/pythoncapi_compat.h',
     'aten/src/ATen/dlpack.h',
@@ -71,11 +71,9 @@ test_failures_cpp_wrapper = {

if config.abi_compatible:
    xfail_list = [
        "test_add_complex_cpu",
        "test_bernoulli1_cpu",  # cpp fallback op naming issue
        "test_conv2d_binary_inplace_fusion_failed_cpu",
        "test_conv2d_binary_inplace_fusion_pass_cpu",
        "test_cumsum_cpu",
        "test_dynamic_qlinear_cpu",
        "test_dynamic_qlinear_qat_cpu",
        "test_lstm_packed_change_input_sizes_cpu",
@@ -97,7 +97,6 @@ if TEST_WITH_ROCM:

if config.abi_compatible:
    xfail_list = [
        "test_add_complex_cuda",
        "test_bernoulli1_cuda",  # cpp fallback op naming issue
        "test_conv_backward_cuda",
        "test_profiler_mark_wrapper_call_cuda",
torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h (new file, 112 lines)
@@ -0,0 +1,112 @@
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.

#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>

#ifdef __cplusplus
extern "C" {
#endif

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__adaptive_avg_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__adaptive_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__adaptive_avg_pool3d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__adaptive_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__cdist_backward(AtenTensorHandle grad, AtenTensorHandle x1, AtenTensorHandle x2, double p, AtenTensorHandle cdist, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag_dense_backward(AtenTensorHandle grad, AtenTensorHandle indices, AtenTensorHandle offset2bag, AtenTensorHandle bag_size, AtenTensorHandle maximum_indices, int64_t num_weights, int32_t scale_grad_by_freq, int64_t mode, AtenTensorHandle* per_sample_weights, int64_t padding_idx, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag_forward_only(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag_per_sample_weights_backward(AtenTensorHandle grad, AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, AtenTensorHandle offset2bag, int64_t mode, int64_t padding_idx, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_adaptive_max_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_adaptive_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_adaptive_max_pool3d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_adaptive_max_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_addbmm(AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_addmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_addmv(AtenTensorHandle self, AtenTensorHandle mat, AtenTensorHandle vec, double beta, double alpha, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_angle(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cholesky_inverse(AtenTensorHandle self, int32_t upper, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cholesky_solve(AtenTensorHandle self, AtenTensorHandle input2, int32_t upper, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_convolution(AtenTensorHandle input, AtenTensorHandle weight, AtenTensorHandle* bias, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_convolution_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle weight, const int64_t** bias_sizes, int64_t bias_sizes_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cummax(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cummin(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_gcd(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_geqrf(AtenTensorHandle self, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_grid_sampler_2d_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle grid, int64_t interpolation_mode, int64_t padding_mode, int32_t align_corners, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_histc(AtenTensorHandle self, int64_t bins, double min, double max, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_histogram_bin_ct(AtenTensorHandle self, int64_t bins, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_index_reduce(AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle source, const char* reduce, int32_t include_self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_kthvalue(AtenTensorHandle self, int64_t k, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_logcumsumexp(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_lu_unpack(AtenTensorHandle LU_data, AtenTensorHandle LU_pivots, int32_t unpack_data, int32_t unpack_pivots, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_masked_scatter(AtenTensorHandle self, AtenTensorHandle mask, AtenTensorHandle source, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_masked_scatter_backward(AtenTensorHandle grad_output, AtenTensorHandle mask, const int64_t* sizes, int64_t sizes_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_pool2d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_pool2d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_pool3d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_pool3d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_unpool2d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_unpool3d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_median(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mode(AtenTensorHandle self, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_ormqr(AtenTensorHandle self, AtenTensorHandle input2, AtenTensorHandle input3, int32_t left, int32_t transpose, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_segment_reduce(AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* indices, AtenTensorHandle* offsets, int64_t axis, int32_t unsafe, double* initial, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_topk(AtenTensorHandle self, int64_t k, int64_t dim, int32_t largest, int32_t sorted, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_triangular_solve(AtenTensorHandle self, AtenTensorHandle A, int32_t upper, int32_t transpose, int32_t unitriangular, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_upsample_bicubic2d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_h, double* scales_w, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_upsample_linear1d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_upsample_trilinear3d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_d, double* scales_h, double* scales_w, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_view_dtype(AtenTensorHandle self, int32_t dtype, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_view_as_complex(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_view_as_real(AtenTensorHandle self, AtenTensorHandle* ret0);

#ifdef __cplusplus
} // extern "C"
#endif
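Reading the declarations above, the lowering from ATen schema types to the stable C ABI follows a consistent pattern. The mapping below is an informal summary inferred from these signatures for illustration; it is not the generator's own code and may not cover every case.

# Informal, inferred summary of how ATen schema types appear in the C shim signatures above.
aten_to_c_shim_type = {
    "Tensor": "AtenTensorHandle",
    "Tensor?": "AtenTensorHandle*",                       # optional tensor -> nullable pointer
    "Tensor[]": "const AtenTensorHandle* + int64_t len",
    "int[] / SymInt[]": "const int64_t* + int64_t len",
    "int? / SymInt?": "int64_t*",                         # optional integer -> nullable pointer
    "float": "double",
    "float?": "double*",
    "bool": "int32_t",
    "str": "const char*",
    "ScalarType? / Layout? / MemoryFormat?": "int32_t*",
    "multiple Tensor returns": "AtenTensorHandle* ret0, ret1, ...",
}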
torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h (new file, 119 lines)
@@ -0,0 +1,119 @@
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.

#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>

#ifdef __cplusplus
extern "C" {
#endif

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__adaptive_avg_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__adaptive_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__adaptive_avg_pool3d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__adaptive_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__cdist_backward(AtenTensorHandle grad, AtenTensorHandle x1, AtenTensorHandle x2, double p, AtenTensorHandle cdist, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__cudnn_rnn(AtenTensorHandle input, const AtenTensorHandle* weight, int64_t weight_len_, int64_t weight_stride0, AtenTensorHandle* weight_buf, AtenTensorHandle hx, AtenTensorHandle* cx, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, int32_t batch_first, double dropout, int32_t train, int32_t bidirectional, const int64_t* batch_sizes, int64_t batch_sizes_len_, AtenTensorHandle* dropout_state, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__efficient_attention_backward(AtenTensorHandle grad_out_, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* bias, AtenTensorHandle out, AtenTensorHandle* cu_seqlens_q, AtenTensorHandle* cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, AtenTensorHandle logsumexp, double dropout_p, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, int64_t custom_mask_type, int32_t bias_requires_grad, double* scale, int64_t* num_splits_key, int64_t* window_size, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__efficient_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* bias, AtenTensorHandle* cu_seqlens_q, AtenTensorHandle* cu_seqlens_k, int64_t* max_seqlen_q, int64_t* max_seqlen_k, double dropout_p, int64_t custom_mask_type, int32_t compute_log_sumexp, double* scale, AtenTensorHandle* causal_diagonal, AtenTensorHandle* seqlen_k, int64_t* window_size, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__embedding_bag(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__embedding_bag_dense_backward(AtenTensorHandle grad, AtenTensorHandle indices, AtenTensorHandle offset2bag, AtenTensorHandle bag_size, AtenTensorHandle maximum_indices, int64_t num_weights, int32_t scale_grad_by_freq, int64_t mode, AtenTensorHandle* per_sample_weights, int64_t padding_idx, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__embedding_bag_forward_only(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__embedding_bag_per_sample_weights_backward(AtenTensorHandle grad, AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, AtenTensorHandle offset2bag, int64_t mode, int64_t padding_idx, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* cum_seq_q, AtenTensorHandle* cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_efficient_attention(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, int32_t compute_log_sumexp, double dropout_p, int32_t is_causal, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_efficient_attention_backward(AtenTensorHandle grad_out_, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double dropout_p, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, int32_t is_causal, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle* bias, int32_t* out_dtype, AtenTensorHandle* scale_a, AtenTensorHandle* scale_b, AtenTensorHandle* scale_result, int32_t use_fast_accum, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__thnn_fused_lstm_cell(AtenTensorHandle input_gates, AtenTensorHandle hidden_gates, AtenTensorHandle cx, AtenTensorHandle* input_bias, AtenTensorHandle* hidden_bias, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool3d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_addbmm(AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_addmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_addmv(AtenTensorHandle self, AtenTensorHandle mat, AtenTensorHandle vec, double beta, double alpha, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_angle(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cholesky_inverse(AtenTensorHandle self, int32_t upper, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cholesky_solve(AtenTensorHandle self, AtenTensorHandle input2, int32_t upper, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_convolution(AtenTensorHandle input, AtenTensorHandle weight, AtenTensorHandle* bias, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_convolution_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle weight, const int64_t** bias_sizes, int64_t bias_sizes_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cummax(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cummin(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_gcd(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_geqrf(AtenTensorHandle self, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_grid_sampler_2d_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle grid, int64_t interpolation_mode, int64_t padding_mode, int32_t align_corners, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_histc(AtenTensorHandle self, int64_t bins, double min, double max, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_index_reduce(AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle source, const char* reduce, int32_t include_self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_kthvalue(AtenTensorHandle self, int64_t k, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_logcumsumexp(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_lu_unpack(AtenTensorHandle LU_data, AtenTensorHandle LU_pivots, int32_t unpack_data, int32_t unpack_pivots, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_masked_scatter(AtenTensorHandle self, AtenTensorHandle mask, AtenTensorHandle source, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_masked_scatter_backward(AtenTensorHandle grad_output, AtenTensorHandle mask, const int64_t* sizes, int64_t sizes_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_max_pool2d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_max_pool2d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_max_pool3d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_max_pool3d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_max_unpool2d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_max_unpool3d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_median(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_mode(AtenTensorHandle self, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_ormqr(AtenTensorHandle self, AtenTensorHandle input2, AtenTensorHandle input3, int32_t left, int32_t transpose, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_segment_reduce(AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* indices, AtenTensorHandle* offsets, int64_t axis, int32_t unsafe, double* initial, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_topk(AtenTensorHandle self, int64_t k, int64_t dim, int32_t largest, int32_t sorted, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_triangular_solve(AtenTensorHandle self, AtenTensorHandle A, int32_t upper, int32_t transpose, int32_t unitriangular, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_upsample_bicubic2d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_h, double* scales_w, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_upsample_linear1d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_upsample_trilinear3d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_d, double* scales_h, double* scales_w, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_view_dtype(AtenTensorHandle self, int32_t dtype, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_view_as_complex(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_view_as_real(AtenTensorHandle self, AtenTensorHandle* ret0);

#ifdef __cplusplus
} // extern "C"
#endif
torchgen/aoti/__init__.py (new empty file)
torchgen/aoti/fallback_ops.py (new file, 129 lines)
@ -0,0 +1,129 @@
|
||||
# This list is based on the fallback ops from torch/_inductor/lowering.py
|
||||
# If you add a new op to the list, remember to run `python torchgen/gen.py --update-aoti-c-shim`
|
||||
# to update C shim files.
|
||||
|
||||
inductor_fallback_ops = {
|
||||
"aten._adaptive_avg_pool2d_backward",
|
||||
"aten._adaptive_avg_pool2d",
|
||||
"aten._adaptive_avg_pool3d",
|
||||
"aten._adaptive_avg_pool3d_backward",
|
||||
"aten.adaptive_max_pool2d_backward",
|
||||
"aten.adaptive_max_pool2d",
|
||||
"aten.adaptive_max_pool3d",
|
||||
"aten.adaptive_max_pool3d_backward",
|
||||
"aten.addbmm",
|
||||
"aten._addmm_activation",
|
||||
"aten.addmm.out",
|
||||
"aten.addmv",
|
||||
"aten.angle",
|
||||
"aten.avg_pool2d_backward",
|
||||
"aten.avg_pool2d",
|
||||
"aten.avg_pool3d_backward",
|
||||
"aten.avg_pool3d",
|
||||
"aten.bmm.out",
|
||||
"aten.bucketize.Tensor",
|
||||
"aten.cat",
|
||||
"aten._cdist_backward",
|
||||
"aten._cdist_forward",
|
||||
"aten.cholesky_inverse",
|
||||
"aten.cholesky_solve",
|
||||
"aten.convolution_backward",
|
||||
"aten._cudnn_rnn",
|
||||
"aten._cudnn_rnn_backward",
|
||||
"aten.convolution",
|
||||
"aten.cummax",
|
||||
"aten.cummin",
|
||||
"aten.cumprod",
|
||||
"aten.cumsum",
|
||||
"aten._efficient_attention_backward",
|
||||
"aten._efficient_attention_forward",
|
||||
"aten._efficientzerotensor",
|
||||
"aten._embedding_bag",
|
||||
"aten._embedding_bag_dense_backward",
|
||||
"aten._embedding_bag_forward_only",
|
||||
"aten._embedding_bag_per_sample_weights_backward",
|
||||
"aten.exponential",
|
||||
"aten._fft_c2c",
|
||||
"aten._fft_r2c",
|
||||
"aten._flash_attention_backward",
|
||||
"aten._flash_attention_forward",
|
||||
"aten.fractional_max_pool2d_backward",
|
||||
"aten.fractional_max_pool2d",
|
||||
"aten.fractional_max_pool3d",
|
||||
"aten.fractional_max_pool3d_backward",
|
||||
"aten._fused_moving_avg_obs_fq_helper",
|
||||
"aten._fused_moving_avg_obs_fq_helper_functional",
|
||||
"aten.gcd",
|
||||
"aten.geqrf",
|
||||
"aten.grid_sampler_2d_backward",
|
||||
"aten.histc",
|
||||
"aten.histogram.bin_ct",
|
||||
"aten._histogramdd_bin_edges",
|
||||
"aten._histogramdd_from_bin_cts",
|
||||
"aten.index_reduce",
|
||||
"aten.index.Tensor",
|
||||
"aten.kthvalue",
|
||||
"aten.logcumsumexp",
|
||||
"aten.lu_unpack",
|
||||
"aten.masked_scatter",
|
||||
"aten.masked_scatter_backward",
|
||||
"aten.max_pool2d_with_indices_backward",
|
||||
"aten.max_pool2d_with_indices",
|
||||
"aten.max_pool3d_with_indices",
|
||||
"aten.max_pool3d_with_indices_backward",
|
||||
"aten.max_unpool2d",
|
||||
"aten.max_unpool3d",
|
||||
"aten.median",
|
||||
"aten.mm.out",
|
||||
"aten.mode",
|
||||
"aten.mul.Scalar",
|
||||
"aten.nanmedian",
|
||||
"aten.nonzero",
|
||||
"aten.ormqr",
|
||||
"aten._pdist_backward",
|
||||
"aten._pdist_forward",
|
||||
"aten.pow.Scalar",
|
||||
"aten.pow.Tensor_Scalar",
|
||||
"aten.pow.Tensor_Tensor",
|
||||
"aten.rand",
|
||||
"aten.rand.generator",
|
||||
"aten.randint",
|
||||
"aten.randn",
|
||||
"aten.randn.generator",
|
||||
"aten.randperm",
|
||||
"aten.repeat_interleave.Tensor",
|
||||
"aten.replication_pad1d_backward",
|
||||
"aten.replication_pad2d_backward",
|
||||
"aten.resize_",
|
||||
"aten.resize_as_",
|
||||
"aten._scaled_dot_product_efficient_attention_backward",
|
||||
"aten._scaled_dot_product_efficient_attention",
|
||||
"aten._scaled_dot_product_flash_attention_backward",
|
||||
"aten._scaled_dot_product_flash_attention",
|
||||
"aten._scaled_dot_product_flash_attention_for_cpu_backward",
|
||||
"aten._scaled_dot_product_flash_attention_for_cpu",
|
||||
"aten._scaled_mm",
|
||||
"aten.scatter_reduce.two_out",
|
||||
"aten.scatter.src_out",
|
||||
"aten.searchsorted",
|
||||
"aten._segment_reduce_backward",
|
||||
"aten.segment_reduce",
|
||||
"aten.soft_margin_loss_backward",
|
||||
"aten.sort",
|
||||
"aten.sort.stable",
|
||||
"aten._sparse_coo_tensor_with_dims_and_tensors",
|
||||
"aten._thnn_fused_lstm_cell",
|
||||
"aten.topk",
|
||||
"aten._to_sparse",
|
||||
"aten.to_sparse",
|
||||
"aten.triangular_solve",
|
||||
"aten._trilinear",
|
||||
"aten.uniform",
|
||||
"aten.upsample_bicubic2d_backward",
|
||||
"aten.upsample_linear1d_backward",
|
||||
"aten.upsample_trilinear3d_backward",
|
||||
"aten.view_as_complex",
|
||||
"aten.view_as_real",
|
||||
"aten.view.dtype",
|
||||
"aten.zeros.names",
|
||||
}
|
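The entries above follow the `namespace.name.overload_name` convention, with the trailing overload omitted for the default overload; this is exactly the key that `get_fallback_op_name` (see the torchgen/gen_aoti_c_shim.py changes further down) produces from a NativeFunction. A minimal, standalone sketch of that naming rule, using plain strings instead of the real torchgen types:

# Illustrative stand-in for get_fallback_op_name: build the "namespace.name[.overload]"
# key that must appear in inductor_fallback_ops for a C shim to be generated.
def fallback_name(namespace: str, name: str, overload_name: str = "") -> str:
    return (
        f"{namespace}.{name}.{overload_name}"
        if overload_name
        else f"{namespace}.{name}"
    )

# Entries taken from the list above: an out-variant overload, a Tensor overload,
# and a default (unnamed) overload.
assert fallback_name("aten", "bmm", "out") == "aten.bmm.out"
assert fallback_name("aten", "bucketize", "Tensor") == "aten.bucketize.Tensor"
assert fallback_name("aten", "cumsum") == "aten.cumsum"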
105
torchgen/gen.py
105
torchgen/gen.py
@ -3,6 +3,7 @@ import functools
import json
import os
import pathlib

from collections import defaultdict, namedtuple, OrderedDict
from dataclasses import dataclass, field
from typing import (
@ -27,6 +28,7 @@ import torchgen.api.native as native
import torchgen.api.structured as structured
import torchgen.dest as dest

from torchgen.aoti.fallback_ops import inductor_fallback_ops
from torchgen.api import cpp
from torchgen.api.translate import translate
from torchgen.api.types import (
@ -48,6 +50,7 @@ from torchgen.gen_aoti_c_shim import (
gen_aoti_c_shim,
gen_static_dispatch_backend_call_signature,
get_backend_index_for_aoti,
get_fallback_op_name,
)
from torchgen.gen_functionalization_type import (
gen_functionalization_definition,
@ -2190,6 +2193,7 @@ def gen_source_files(
force_schema_registration: bool,
per_operator_headers: bool,
skip_dispatcher_op_registration: bool,
update_aoti_c_shim: bool,
) -> None:
extra_cuda_headers = """\
#include <c10/cuda/CUDAGuard.h>
@ -2349,53 +2353,99 @@ def gen_source_files(
else:
raise AssertionError(f"unrecognized {dispatch_key} for ufunc")

structured_func_group_dict = {
f"{func_group.functional.namespace}.{func_group.functional.func.name}": func_group
for func_group in structured_native_functions
}
if dispatch_key in (DispatchKey.CPU, DispatchKey.CUDA):
fallbacks = dict()
for func in native_functions:
op_name = get_fallback_op_name(func)
if op_name in inductor_fallback_ops:
fallbacks[op_name] = (
func,
structured_func_group_dict.get(
f"{func.namespace}.{func.func.name.name}", None
),
)
fallback_native_functions = tuple(
value for _, value in sorted(fallbacks.items())
)

def get_header(
f: NativeFunction,
func: NativeFunction,
func_group: Optional[NativeFunctionsGroup],
) -> Optional[str]:
backend_index = get_backend_index_for_aoti(
f, dispatch_key, backend_indices
func, func_group, dispatch_key, backend_indices
)
return (
None
if backend_index is None
else f"#include <ATen/ops/{f.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
else f"#include <ATen/ops/{func.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
)

def headers_for_aoti() -> str:
headers = []
for g in grouped_native_functions:
if isinstance(g, NativeFunctionsGroup):
for f in g.functions():
# some variants are registered in the backend, but some are registered as CompositeExplicitAutograd
header = get_header(f)
if header is not None:
headers.append(header)
else:
header = get_header(g)
if header is not None:
headers.append(header)
for func, func_group in fallback_native_functions:
header = get_header(func, func_group)
if header is not None:
headers.append(header)
return "\n".join(sorted(set(headers)))

extra_headers = (
extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
)

aoti_fm.write(
f"c_shim_{dispatch_key.lower()}.h",
lambda: gen_aoti_c_shim(
native_functions,
dispatch_key,
backend_indices,
header=True,
includes="",
),
# header files were checked in for ABI-compatibility checking
header_file_name = f"c_shim_{dispatch_key.lower()}.h"
new_header = gen_aoti_c_shim(
fallback_native_functions,
dispatch_key,
backend_indices,
header=True,
includes="",
)
if update_aoti_c_shim:
aoti_fm.write(
header_file_name,
lambda: new_header,
)
else:
try:
with open(
os.path.join(aoti_fm.install_dir, header_file_name)
) as old_file:
old_header = old_file.read()
assert (
old_header == new_header
), """

WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This
indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
Only in a limited number of situations, this is allowed:

1. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py.
If that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to update the existing
C shim header files.

2. You added a new default argument to an existing fallback op. This is clearly a BC breaking
change in the AOTInductor land. In this case, you need to keep a manual copy of that existing
fallback op in a file, e.g. torch/csrc/inductor/aoti_torch/c/shim.h, bump up the version
number of that fallback op in the newly generated C shim files, and update the cpp wrapper
codegen to generate the correct cpp call for this op. Contact AOTInductor team for assistance.

"""
except FileNotFoundError:
print(
f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found"
)

# cpp files are always generated on-the-fly
aoti_fm.write(
f"c_shim_{dispatch_key.lower()}.cpp",
lambda: gen_aoti_c_shim(
native_functions,
fallback_native_functions,
dispatch_key,
backend_indices,
header=False,
@ -2780,6 +2830,12 @@ def main() -> None:
default=["headers", "sources", "declarations_yaml"],
help="Generate only a subset of files",
)
parser.add_argument(
"--update-aoti-c-shim",
action="store_true",
help="Update AOTInductor C shim after changing torchgen/aoti/fallback_ops.py. "
"WARNING: Do not use this unless you are sure what you are doing!!!",
)

options = parser.parse_args()

@ -2898,6 +2954,7 @@ def main() -> None:
force_schema_registration=options.force_schema_registration,
per_operator_headers=options.per_operator_headers,
skip_dispatcher_op_registration=options.skip_dispatcher_op_registration,
update_aoti_c_shim=options.update_aoti_c_shim,
)

if "headers" in options.generate:
torchgen/gen_aoti_c_shim.py
@ -15,18 +15,12 @@ from torchgen.model import (
FunctionSchema,
ListType,
NativeFunction,
NativeFunctionsGroup,
OptionalType,
Type,
)
from torchgen.utils import mapMaybe


def returns_are_all_tensor(schema: FunctionSchema) -> bool:
return len(schema.returns) != 0 and all(
ret.type.is_tensor_like() for ret in schema.returns
)


base_type_to_c_type = {
BaseTy.Tensor: "AtenTensorHandle",
BaseTy.bool: "int32_t", # Use int to pass bool
@ -301,52 +295,63 @@ def gen_static_dispatch_backend_call(
f: NativeFunction,
backend_index: BackendIndex,
) -> str:
assert backend_index.has_kernel(f)
sig = DispatcherSignature.from_schema(f.func)
cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"


def get_backend_index_for_aoti(
f: NativeFunction,
func: NativeFunction,
func_group: Optional[NativeFunctionsGroup],
dispatch_key: DispatchKey,
backend_indices: Dict[DispatchKey, BackendIndex],
) -> Optional[BackendIndex]:
if "pointwise" in f.tags:
# TODO: No need to generate C shim for Inductor lowered ops.
# Only skip pointwise kernels for now, and we can add more tags later.
return None

backend_index = None
if backend_indices[dispatch_key].has_kernel(f):
if backend_indices[dispatch_key].has_kernel(func) or (
func.structured_delegate is not None
and func_group is not None
and backend_indices[dispatch_key].has_kernel(func_group)
):
backend_index = backend_indices[dispatch_key]
elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(f):
elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(func):
# We need to create C shim wrappers for CompositeExplicitAutograd kernels
backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd]
elif backend_indices[DispatchKey.CompositeExplicitAutogradNonFunctional].has_kernel(
f
func
):
# We need to create C shim wrappers for CompositeExplicitAutogradNonFunctional kernels
backend_index = backend_indices[
DispatchKey.CompositeExplicitAutogradNonFunctional
]

return backend_index


def get_fallback_op_name(func: NativeFunction) -> str:
return (
f"{func.namespace}.{func.func.name.name}.{func.func.name.overload_name}"
if func.func.name.overload_name
else f"{func.namespace}.{func.func.name.name}"
)


def gen_c_shim(
f: NativeFunction,
func: NativeFunction,
func_group: Optional[NativeFunctionsGroup],
dispatch_key: DispatchKey,
backend_indices: Dict[DispatchKey, BackendIndex],
header: bool,
) -> Optional[str]:
backend_index = get_backend_index_for_aoti(f, dispatch_key, backend_indices)
backend_index = get_backend_index_for_aoti(
func, func_group, dispatch_key, backend_indices
)
if backend_index is None:
return None

schema = f.func
schema = func.func
device = dispatch_key.lower()
backend_call = gen_static_dispatch_backend_call(
f,
func,
backend_index,
)

@ -366,34 +371,56 @@ def gen_c_shim(

@dataclass(frozen=True)
class ShimGenerator:
func_group_mapping: Dict[str, Optional[NativeFunctionsGroup]]
dispatch_key: DispatchKey
backend_indices: Dict[DispatchKey, BackendIndex]
header: bool # True to generate .h and False to generate .cpp

@method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]:
result = gen_c_shim(f, self.dispatch_key, self.backend_indices, self.header)
def __call__(
self,
func: NativeFunction,
) -> Optional[str]:
result = gen_c_shim(
func,
self.func_group_mapping.get(get_fallback_op_name(func), None),
self.dispatch_key,
self.backend_indices,
self.header,
)
return result


def gen_aoti_c_shim(
native_functions: Sequence[NativeFunction],
native_functions: Sequence[Tuple[NativeFunction, Optional[NativeFunctionsGroup]]],
dispatch_key: DispatchKey,
backend_indices: Dict[DispatchKey, BackendIndex],
header: bool,
includes: str = "",
) -> str:
func_group_mapping = {
get_fallback_op_name(func): func_group for func, func_group in native_functions
}
body = "\n".join(
list(
mapMaybe(
ShimGenerator(dispatch_key, backend_indices, header),
native_functions,
ShimGenerator(
func_group_mapping, dispatch_key, backend_indices, header
),
[func for func, _ in native_functions],
)
)
)
device = dispatch_key.lower()

warning = (
"// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND."
)

if header:
return f"""
{warning}

#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
@ -407,13 +434,14 @@ extern "C" {{
#ifdef __cplusplus
}} // extern "C"
#endif

"""

else:
device = dispatch_key.lower()
return f"""
#include <torch/csrc/inductor/aoti_torch/utils.h>
{warning}

#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{device}.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/{str(dispatch_key)}Functions.h>
@ -425,6 +453,4 @@ extern "C" {{

using namespace torch::aot_inductor;

{body}

"""
{body}"""
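The net effect for callers of torchgen/gen_aoti_c_shim.py is that gen_aoti_c_shim now consumes (NativeFunction, Optional[NativeFunctionsGroup]) pairs for the curated fallback ops instead of every NativeFunction. A rough usage sketch under that assumption, mirroring how gen_source_files builds fallback_native_functions above; the render_cpu_shim_header driver itself is hypothetical and omits the sorting the real code performs:

from typing import Dict, Optional, Sequence, Tuple

from torchgen.aoti.fallback_ops import inductor_fallback_ops
from torchgen.gen_aoti_c_shim import gen_aoti_c_shim, get_fallback_op_name
from torchgen.model import (
    BackendIndex,
    DispatchKey,
    NativeFunction,
    NativeFunctionsGroup,
)


def render_cpu_shim_header(
    native_functions: Sequence[NativeFunction],
    structured_groups: Sequence[NativeFunctionsGroup],
    backend_indices: Dict[DispatchKey, BackendIndex],
) -> str:
    # Pair each curated fallback op with its structured group, keyed by
    # "namespace.name", the same pairing gen_source_files performs above.
    group_by_name: Dict[str, NativeFunctionsGroup] = {
        f"{g.functional.namespace}.{g.functional.func.name}": g
        for g in structured_groups
    }
    fallback_pairs: Tuple[Tuple[NativeFunction, Optional[NativeFunctionsGroup]], ...] = tuple(
        (f, group_by_name.get(f"{f.namespace}.{f.func.name.name}"))
        for f in native_functions
        if get_fallback_op_name(f) in inductor_fallback_ops
    )
    # header=True renders the .h declarations; header=False renders the .cpp bodies.
    return gen_aoti_c_shim(
        fallback_pairs,
        DispatchKey.CPU,
        backend_indices,
        header=True,
        includes="",
    )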