Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[Intel GPU] Support RegisterXPU.cpp codegen and compile for the in-tree XPU structured GEMM OPs. (#139025)
Motivation: the ATen ops for XPU come in two parts: in-tree ops, such as the GEMM-related ops, and out-of-tree ops that live in torch-xpu-ops. For the in-tree part, PyTorch already registers ops through native_functions.yaml and ships convenient codegen capabilities, so we want XPU to take advantage of those benefits as well. At the same time, since AOT Inductor also uses native_functions.yaml to generate C shim wrappers, this mechanism needs to be enabled for XPU too.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139025
Approved by: https://github.com/EikanWang, https://github.com/jansel, https://github.com/desertfire
Committed by: PyTorch MergeBot
Parent: 0b650c360a
Commit: 929a647363
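For context before the diff: once the generated RegisterXPU.cpp registers the in-tree structured GEMM kernels under the XPU dispatch key, ordinary eager calls reach them without any manual TORCH_LIBRARY_IMPL registration. The sketch below is illustrative only; it assumes a PyTorch build configured with USE_XPU and a working XPU device, and the shapes are arbitrary.

# Hedged usage sketch (not part of the diff): assumes USE_XPU=1 at build time
# and an available XPU device.
import torch

a = torch.randn(4, 8, device="xpu")
b = torch.randn(8, 16, device="xpu")
bias = torch.zeros(4, 16, device="xpu")

c = torch.mm(a, b)           # dispatches to the XPU structured mm kernel
d = torch.addmm(bias, a, b)  # dispatches to the XPU structured addmm kernel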
@@ -348,6 +348,7 @@ endif()
 if(USE_XPU)
   list(APPEND ATen_XPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/xpu)
   list(APPEND ATen_XPU_SRCS ${xpu_cpp})
+  list(APPEND ATen_XPU_SRCS ${xpu_generated_sources})
 endif()
 
 list(APPEND ATen_CPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/..)
@@ -1,9 +1,24 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/WrapDimUtilsMulti.h>
 #include <ATen/native/Resize.h>
 #include <torch/library.h>
 #include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
 
-namespace at::native::xpu {
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
+#include <ATen/ops/_addmm_activation_native.h>
+#include <ATen/ops/addmm_native.h>
+#include <ATen/ops/addmv_native.h>
+#include <ATen/ops/baddbmm_native.h>
+#include <ATen/ops/bmm_native.h>
+#include <ATen/ops/empty.h>
+#include <ATen/ops/mm_native.h>
+#endif
+
+namespace at::native {
+namespace xpu {
 
 // result = beta * self + alpha * (mat1 * mat2)
 Tensor& addmm_out(
@@ -425,20 +440,35 @@ Tensor& tensordot_out(
 }
 
 TORCH_LIBRARY_IMPL(aten, XPU, m){
-  m.impl("addmm.out", TORCH_FN(addmm_out));
-  m.impl("_addmm_activation.out", TORCH_FN(_addmm_activation_out));
-  m.impl("mm.out", TORCH_FN(mm_out));
-  m.impl("mm", TORCH_FN(mm));
-  m.impl("baddbmm.out", TORCH_FN(baddbmm_out));
-  m.impl("baddbmm_", TORCH_FN(baddbmm_));
-  m.impl("baddbmm", TORCH_FN(baddbmm));
-  m.impl("addbmm.out", TORCH_FN(addbmm_out));
-  m.impl("addbmm_", TORCH_FN(addbmm_));
-  m.impl("addbmm", TORCH_FN(addbmm));
-  m.impl("bmm.out", TORCH_FN(bmm_out));
-  m.impl("bmm", TORCH_FN(bmm));
-  m.impl("addmv.out", TORCH_FN(addmv_out));
   m.impl("tensordot.out", TORCH_FN(tensordot_out));
 }
+} // namespace xpu
 
-} // namespace at::native::xpu
+TORCH_IMPL_FUNC(addmm_out_xpu)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+  xpu::addmm_out(self, mat1, mat2, beta, alpha, const_cast<Tensor&>(result));
+}
+
+TORCH_IMPL_FUNC(mm_out_xpu)(const Tensor& self, const Tensor& mat2, const Tensor& result) {
+  xpu::mm_out(self, mat2, const_cast<Tensor&>(result));
+}
+
+TORCH_IMPL_FUNC(bmm_out_xpu)(const Tensor& self, const Tensor& batch2, const Tensor &result) {
+  xpu::bmm_out(self, batch2, const_cast<Tensor&>(result));
+}
+
+TORCH_IMPL_FUNC(addmm_activation_out_xpu)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, bool use_gelu, const Tensor& result) {
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+  xpu::_addmm_activation_out(self, mat1, mat2, beta, alpha, use_gelu, const_cast<Tensor&>(result));
+}
+
+TORCH_IMPL_FUNC(baddbmm_out_xpu)(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
+  xpu::baddbmm_out(
+      self, batch1, batch2, beta, alpha, const_cast<Tensor&>(result));
+}
+
+TORCH_IMPL_FUNC(addmv_out_xpu)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
+  xpu::addmv_out(self, mat, vec, beta, alpha, const_cast<Tensor&>(result));
+}
+
+} // namespace at::native
@@ -641,6 +641,7 @@
     CPU: addmv_out_cpu
     CUDA: addmv_out_cuda
     MPS: addmv_out_mps
+    XPU: addmv_out_xpu
     SparseCsrCPU: addmv_out_sparse_compressed
     SparseCsrCUDA: addmv_out_sparse_compressed_cuda
 
@@ -1061,6 +1062,7 @@
     CPU: baddbmm_out_cpu
     CUDA: baddbmm_out_cuda
     MPS: baddbmm_out_mps
+    XPU: baddbmm_out_xpu
     SparseCsrCUDA: baddbmm_out_sparse_csr_cuda
 
 - func: bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
@@ -1358,6 +1360,7 @@
     CPU: bmm_out_cpu
     CUDA: bmm_out_cuda
     MPS: bmm_out_mps
+    XPU: bmm_out_xpu
     SparseCPU: bmm_out_sparse_cpu
     SparseCUDA: bmm_out_sparse_cuda
     SparseCsrCUDA: bmm_out_sparse_csr_cuda
@@ -4130,6 +4133,7 @@
     CPU: mm_out_cpu
     CUDA: mm_out_cuda
     MPS: mm_out_mps
+    XPU: mm_out_xpu
     SparseCPU, SparseCUDA: _sparse_mm_out
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm_out
 
@@ -6993,6 +6997,7 @@
     CPU: addmm_out_cpu
     CUDA: addmm_out_cuda
     MPS: addmm_out_mps
+    XPU: addmm_out_xpu
     SparseCPU: addmm_out_sparse_dense_cpu
     SparseCUDA: addmm_out_sparse_dense_cuda
     SparseCsrCPU: addmm_out_sparse_compressed_cpu
@@ -7021,6 +7026,7 @@
   dispatch:
     CPU: addmm_activation_out_cpu
     CUDA: addmm_activation_out_cuda
+    XPU: addmm_activation_out_xpu
 
 - func: _addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor
   structured_delegate: _addmm_activation.out
@@ -8655,18 +8661,18 @@
 - func: addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
   variants: method
   dispatch:
-    CPU, CUDA: addbmm_
+    CPU, CUDA, XPU: addbmm_
     MPS: addbmm_mps_
 
 - func: addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
   dispatch:
-    CPU, CUDA: addbmm_out
+    CPU, CUDA, XPU: addbmm_out
     MPS: addbmm_out_mps
 
 - func: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
   variants: method, function
   dispatch:
-    CPU, CUDA: addbmm
+    CPU, CUDA, XPU: addbmm
     MPS: addbmm_mps
 
 - func: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)
@@ -16,6 +16,7 @@
 #if defined(CAFFE2_BUILD_MAIN_LIB) || \
     defined(TORCH_CUDA_BUILD_MAIN_LIB) || \
     defined(TORCH_HIP_BUILD_MAIN_LIB) || \
+    defined(TORCH_XPU_BUILD_MAIN_LIB) || \
     defined(TORCH_CUDA_CU_BUILD_MAIN_LIB) || \
     defined(TORCH_CUDA_CPP_BUILD_MAIN_LIB)
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
@@ -332,6 +332,7 @@ set(TORCH_GENERATED_CODE
   ${GENERATED_H_PYTHON}
   ${GENERATED_TESTING_PYTHON}
   ${GENERATED_CXX_TORCH_CUDA}
+  ${GENERATED_CXX_TORCH_XPU}
 )
 
 set(GEN_PER_OPERATOR_FLAG)
@@ -94,6 +94,11 @@ if(INTERN_BUILD_ATEN_OPS)
     set(GEN_MPS_FLAG --mps)
   endif()
 
+  set(GEN_XPU_FLAG)
+  if(USE_XPU)
+    set(GEN_XPU_FLAG --xpu)
+  endif()
+
   set(CUSTOM_BUILD_FLAGS)
   if(INTERN_BUILD_MOBILE)
     if(USE_VULKAN)
@@ -179,6 +184,7 @@ if(INTERN_BUILD_ATEN_OPS)
       ${GEN_PER_OPERATOR_FLAG}
       ${GEN_ROCM_FLAG}
       ${GEN_MPS_FLAG}
+      ${GEN_XPU_FLAG}
       ${CUSTOM_BUILD_FLAGS}
   )
 
@@ -217,22 +223,31 @@ if(INTERN_BUILD_ATEN_OPS)
     include("${CMAKE_BINARY_DIR}/aten/src/ATen/cpu_vec_generated_${gen_type}.cmake")
     include("${CMAKE_BINARY_DIR}/aten/src/ATen/cuda_generated_${gen_type}.cmake")
     include("${CMAKE_BINARY_DIR}/aten/src/ATen/ops_generated_${gen_type}.cmake")
+    if(USE_XPU)
+      include("${CMAKE_BINARY_DIR}/aten/src/ATen/xpu_generated_${gen_type}.cmake")
+    endif()
     message(STATUS "${gen_type} outputs: ${gen_outputs}")
+    set(OUTPUT_LIST
+      ${generated_${gen_type}}
+      ${cuda_generated_${gen_type}}
+      ${core_generated_${gen_type}}
+      ${cpu_vec_generated_${gen_type}}
+      ${ops_generated_${gen_type}}
+      ${CMAKE_BINARY_DIR}/aten/src/ATen/generated_${gen_type}.cmake
+      ${CMAKE_BINARY_DIR}/aten/src/ATen/ops_generated_${gen_type}.cmake
+      ${CMAKE_BINARY_DIR}/aten/src/ATen/core_generated_${gen_type}.cmake
+      ${CMAKE_BINARY_DIR}/aten/src/ATen/cpu_vec_generated_${gen_type}.cmake
+      ${CMAKE_BINARY_DIR}/aten/src/ATen/cuda_generated_${gen_type}.cmake)
+    if(USE_XPU)
+      list(APPEND OUTPUT_LIST
+        ${xpu_generated_${gen_type}}
+        ${CMAKE_BINARY_DIR}/aten/src/ATen/xpu_generated_${gen_type}.cmake
+      )
+    endif()
 
     add_custom_command(
       COMMENT "Generating ATen ${gen_type}"
-      OUTPUT
-        ${generated_${gen_type}}
-        ${cuda_generated_${gen_type}}
-        ${core_generated_${gen_type}}
-        ${cpu_vec_generated_${gen_type}}
-        ${ops_generated_${gen_type}}
-        ${CMAKE_BINARY_DIR}/aten/src/ATen/generated_${gen_type}.cmake
-        ${CMAKE_BINARY_DIR}/aten/src/ATen/ops_generated_${gen_type}.cmake
-        ${CMAKE_BINARY_DIR}/aten/src/ATen/core_generated_${gen_type}.cmake
-        ${CMAKE_BINARY_DIR}/aten/src/ATen/cpu_vec_generated_${gen_type}.cmake
-        ${CMAKE_BINARY_DIR}/aten/src/ATen/cuda_generated_${gen_type}.cmake
+      OUTPUT ${OUTPUT_LIST}
       COMMAND ${GEN_COMMAND_${gen_type}}
       DEPENDS ${all_python} ${${gen_type}_templates}
         ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/native_functions.yaml
@@ -260,6 +275,16 @@ if(INTERN_BUILD_ATEN_OPS)
     target_compile_definitions(ATEN_CUDA_FILES_GEN_LIB INTERFACE AT_PER_OPERATOR_HEADERS)
   endif()
 
+  if(USE_XPU)
+    add_custom_target(ATEN_XPU_FILES_GEN_TARGET DEPENDS
+        ${xpu_generated_headers} ${xpu_generated_sources})
+    add_library(ATEN_XPU_FILES_GEN_LIB INTERFACE)
+    add_dependencies(ATEN_XPU_FILES_GEN_LIB ATEN_XPU_FILES_GEN_TARGET)
+
+    if(USE_PER_OPERATOR_HEADERS)
+      target_compile_definitions(ATEN_XPU_FILES_GEN_LIB INTERFACE AT_PER_OPERATOR_HEADERS)
+    endif()
+  endif()
   # Handle source files that need to be compiled multiple times for
   # different vectorization options
   file(GLOB cpu_kernel_cpp_in "${PROJECT_SOURCE_DIR}/aten/src/ATen/native/cpu/*.cpp" "${PROJECT_SOURCE_DIR}/aten/src/ATen/native/quantized/cpu/kernels/*.cpp")
@@ -342,7 +342,7 @@ inductor_expected_failures_single_sample["xpu"] = {
     "cholesky_solve": {f64},
     "cholesky_inverse": {f64},
     # could not create a primitive
-    "addbmm": {f16, f32, f64},
+    "addbmm": {f64},
    "addmm": {f16, f32, f64},
     "addmv": {f32, f64},
     # could not create a primitive descriptor for
@@ -53,6 +53,7 @@ from torchgen.model import (
     BackendMetadata,
     BaseOperatorName,
     DEFAULT_KERNEL_NAMESPACE,
+    dispatch_device_map,
     DispatchKey,
     FRAGMENT_NAMESPACES,
     FunctionSchema,
@@ -143,6 +144,25 @@ _GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {}
 _GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {}
 
 
+def file_manager_from_dispatch_key(
+    dispatch_key: DispatchKey,
+    device_fms: dict[str, FileManager],
+    default_fm: FileManager,
+) -> FileManager:
+    fm = device_fms.get(
+        next(
+            (
+                device
+                for check, device in dispatch_device_map.items()
+                if check(dispatch_key)
+            ),
+            "",
+        ),
+        default_fm,
+    )
+    return fm
+
+
 def parse_native_yaml_struct(
     es: object,
     valid_tags: set[str],
@@ -1716,7 +1736,7 @@ def gen_aggregated_headers(
     selector: SelectiveBuilder,
     backend_indices: dict[DispatchKey, BackendIndex],
     cpu_fm: FileManager,
-    cuda_fm: FileManager,
+    device_fms: dict[str, FileManager],
     functions_keys: set[DispatchKey],
     dispatch_keys: Sequence[DispatchKey],
     rocm: bool,
@@ -1796,7 +1816,7 @@ def gen_aggregated_headers(
     )
 
     for dispatch_key in dispatch_keys:
-        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
+        fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
         if dispatch_key in functions_keys:
             inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"
 
@@ -1836,7 +1856,7 @@ def gen_per_operator_headers(
     selector: SelectiveBuilder,
     backend_indices: dict[DispatchKey, BackendIndex],
     cpu_fm: FileManager,
-    cuda_fm: FileManager,
+    device_fms: dict[str, FileManager],
     ops_fm: FileManager,
     functions_keys: set[DispatchKey],
     dispatch_keys: Sequence[DispatchKey],
@@ -1984,7 +2004,7 @@ def gen_per_operator_headers(
         },
     )
 
-    fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
+    fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
     inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"
 
     fm.write_with_template(
@@ -2033,7 +2053,7 @@ def gen_headers(
     backend_indices: dict[DispatchKey, BackendIndex],
     core_fm: FileManager,
     cpu_fm: FileManager,
-    cuda_fm: FileManager,
+    device_fms: dict[str, FileManager],
     ops_fm: FileManager,
     dispatch_keys: Sequence[DispatchKey],
     functions_keys: set[DispatchKey],
@@ -2048,7 +2068,7 @@ def gen_headers(
         selector=selector,
         backend_indices=backend_indices,
         cpu_fm=cpu_fm,
-        cuda_fm=cuda_fm,
+        device_fms=device_fms,
         ops_fm=ops_fm,
         dispatch_keys=dispatch_keys,
         functions_keys=functions_keys,
@@ -2063,7 +2083,7 @@ def gen_headers(
         selector=selector,
         backend_indices=backend_indices,
         cpu_fm=cpu_fm,
-        cuda_fm=cuda_fm,
+        device_fms=device_fms,
         dispatch_keys=dispatch_keys,
         functions_keys=functions_keys,
         rocm=rocm,
@@ -2171,9 +2191,9 @@ def gen_source_files(
     backend_indices: dict[DispatchKey, BackendIndex],
     aoti_fm: FileManager,
     core_fm: FileManager,
-    cpu_fm: FileManager,
     cpu_vec_fm: FileManager,
-    cuda_fm: FileManager,
+    cpu_fm: FileManager,
+    device_fms: dict[str, FileManager],
     dispatch_keys: Sequence[DispatchKey],
     functions_keys: set[DispatchKey],
     rocm: bool,
@@ -2195,8 +2215,7 @@ def gen_source_files(
 #include <ATen/hip/HIPContext.h>"""
 
     for dispatch_key in dispatch_keys:
-        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
-
+        fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
         if per_operator_headers:
 
             def operator_headers() -> list[str]:
@@ -2752,6 +2771,12 @@ def main() -> None:
         action="store_true",
         help="Generate MPS registration code when set",
     )
+    parser.add_argument(
+        "--xpu",
+        action="store_true",
+        help="Generate XPU registration code when set",
+    )
+
     # TODO: --op-registration-whitelist will be removed when all call-sites
     # for gen.py are moved over to using the operator YAML file for mobile
     # custom build.
@@ -2833,6 +2858,19 @@ def main() -> None:
         if DispatchKey.MPS in dispatch_keys:
             del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)]
 
+    xpu_in_whitelist = (
+        options.backend_whitelist and str(DispatchKey.XPU) in options.backend_whitelist
+    )
+    # Only generate RegisterXPU.cpp when "--xpu" is passed to torchgen/gen.py.
+    # Before this change, torchgen always generated RegisterXPU.cpp for the out-of-tree
+    # torch-xpu-ops native_functions.yaml, which uses --backend_whitelist=XPU without "--xpu".
+    # After this change lands, we will add --xpu in torch-xpu-ops and remove the check of "xpu_in_whitelist".
+    if (not options.xpu) and (not xpu_in_whitelist):
+        ignore_keys.add(DispatchKey.XPU)
+
+        if DispatchKey.XPU in dispatch_keys:
+            del dispatch_keys[dispatch_keys.index(DispatchKey.XPU)]
+
     parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
     valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
     native_functions, backend_indices = (
@@ -2877,6 +2915,9 @@ def main() -> None:
     cuda_fm = make_file_manager(options=options)
     ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)
     aoti_fm = make_file_manager(options=options, install_dir=aoti_install_dir)
+    device_fms = {"cuda": cuda_fm}
+    if options.xpu:
+        device_fms["xpu"] = make_file_manager(options=options)
 
     # Only a limited set of dispatch keys get CPUFunctions.h headers generated
     # for them; this is the set
@@ -2892,6 +2933,9 @@ def main() -> None:
     if options.mps:
         functions_keys.add(DispatchKey.MPS)
 
+    if options.xpu:
+        functions_keys.add(DispatchKey.XPU)
+
     if options.backend_whitelist:
         dispatch_keys = [
             k
@@ -2921,9 +2965,9 @@ def main() -> None:
         backend_indices=backend_indices,
         aoti_fm=aoti_fm,
         core_fm=core_fm,
-        cpu_fm=cpu_fm,
         cpu_vec_fm=cpu_vec_fm,
-        cuda_fm=cuda_fm,
+        cpu_fm=cpu_fm,
+        device_fms=device_fms,
         dispatch_keys=dispatch_keys,
         functions_keys=functions_keys,
         rocm=options.rocm,
@@ -2944,7 +2988,7 @@ def main() -> None:
         backend_indices=backend_indices,
         core_fm=core_fm,
         cpu_fm=cpu_fm,
-        cuda_fm=cuda_fm,
+        device_fms=device_fms,
         ops_fm=ops_fm,
         dispatch_keys=dispatch_keys,
         functions_keys=functions_keys,
@@ -2964,9 +3008,8 @@ def main() -> None:
         (cpu_fm, ""),
         (cpu_vec_fm, "cpu_vec_"),
         (core_fm, "core_"),
-        (cuda_fm, "cuda_"),
         (ops_fm, "ops_"),
-    ]:
+    ] + [(device_fm, f"{device}_") for device, device_fm in device_fms.items()]:
         varname = prefix + depfile_stem
         path = depfile_path.parent / (prefix + depfile_name)
         fm.write_outputs(varname, str(path))
@@ -346,6 +346,9 @@ def is_ufunc_dispatch_key(dk: DispatchKey) -> bool:
     return dk in UFUNC_DISPATCH_KEYS
 
 
+dispatch_device_map = {is_cuda_dispatch_key: "cuda", is_xpu_dispatch_key: "xpu"}
+
+
 # This is oddly named ScalarType and not DType for symmetry with C++
 class ScalarType(Enum):
     Byte = auto()
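Taken together, the dispatch_device_map added in the hunk above and the file_manager_from_dispatch_key helper introduced in the earlier torchgen hunk route each dispatch key to its per-device FileManager, falling back to the CPU manager when no device-specific entry exists. The sketch below is a minimal standalone illustration of that lookup, using plain strings as stand-ins for torchgen's FileManager objects and simplified key predicates; it is not the real torchgen API.

# Simplified sketch of the dispatch-key -> file-manager routing (stand-in types).
def is_cuda_key(key: str) -> bool:
    return key.startswith("CUDA")

def is_xpu_key(key: str) -> bool:
    return key.startswith("XPU")

dispatch_device_map = {is_cuda_key: "cuda", is_xpu_key: "xpu"}

def file_manager_for(key: str, device_fms: dict, default_fm):
    # The first matching predicate picks the device name; "" misses the dict,
    # so every other key falls back to the default (CPU) manager.
    device = next((d for check, d in dispatch_device_map.items() if check(key)), "")
    return device_fms.get(device, default_fm)

# Example: only CUDA and XPU get dedicated managers; everything else goes to CPU.
fms = {"cuda": "cuda_fm", "xpu": "xpu_fm"}
assert file_manager_for("CUDA", fms, "cpu_fm") == "cuda_fm"
assert file_manager_for("XPU", fms, "cpu_fm") == "xpu_fm"
assert file_manager_for("CPU", fms, "cpu_fm") == "cpu_fm"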