mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Intel GPU] Support RegisterXPU.cpp codegen and compile for the in-tree XPU structured GEMM OPs. (#139025)
[Intel GPU] Support RegisterXPU.cpp codegen and compile for the in-tree XPU structured GEMM ops. Motivation: There are two groups of ATen ops for XPU: in-tree ops, such as the GEMM-related ops, and out-of-tree ops in torch-xpu-ops. For the in-tree part, since PyTorch uses native_functions.yaml registration and is equipped with convenient codegen capabilities, we want to take advantage of these benefits as well. At the same time, since AOT Inductor also uses native_functions.yaml to generate C shim wrappers, we also need to enable this mechanism for XPU. Pull Request resolved: https://github.com/pytorch/pytorch/pull/139025 Approved by: https://github.com/EikanWang, https://github.com/jansel, https://github.com/desertfire
This commit is contained in:
committed by
PyTorch MergeBot
parent
0b650c360a
commit
929a647363
@ -348,6 +348,7 @@ endif()
|
||||
if(USE_XPU)
|
||||
list(APPEND ATen_XPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/xpu)
|
||||
list(APPEND ATen_XPU_SRCS ${xpu_cpp})
|
||||
list(APPEND ATen_XPU_SRCS ${xpu_generated_sources})
|
||||
endif()
|
||||
|
||||
list(APPEND ATen_CPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/..)
|
||||
|
@ -1,9 +1,24 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/WrapDimUtilsMulti.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <torch/library.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
|
||||
namespace at::native::xpu {
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_addmm_activation_native.h>
|
||||
#include <ATen/ops/addmm_native.h>
|
||||
#include <ATen/ops/addmv_native.h>
|
||||
#include <ATen/ops/baddbmm_native.h>
|
||||
#include <ATen/ops/bmm_native.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#endif
|
||||
|
||||
namespace at::native {
|
||||
namespace xpu {
|
||||
|
||||
// result = beta * self + alpha * (mat1 * mat2)
|
||||
Tensor& addmm_out(
|
||||
@ -425,20 +440,35 @@ Tensor& tensordot_out(
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, XPU, m){
|
||||
m.impl("addmm.out", TORCH_FN(addmm_out));
|
||||
m.impl("_addmm_activation.out", TORCH_FN(_addmm_activation_out));
|
||||
m.impl("mm.out", TORCH_FN(mm_out));
|
||||
m.impl("mm", TORCH_FN(mm));
|
||||
m.impl("baddbmm.out", TORCH_FN(baddbmm_out));
|
||||
m.impl("baddbmm_", TORCH_FN(baddbmm_));
|
||||
m.impl("baddbmm", TORCH_FN(baddbmm));
|
||||
m.impl("addbmm.out", TORCH_FN(addbmm_out));
|
||||
m.impl("addbmm_", TORCH_FN(addbmm_));
|
||||
m.impl("addbmm", TORCH_FN(addbmm));
|
||||
m.impl("bmm.out", TORCH_FN(bmm_out));
|
||||
m.impl("bmm", TORCH_FN(bmm));
|
||||
m.impl("addmv.out", TORCH_FN(addmv_out));
|
||||
m.impl("tensordot.out", TORCH_FN(tensordot_out));
|
||||
}
|
||||
} // namespace xpu
|
||||
|
||||
} // namespace at::native::xpu
|
||||
// XPU entry point `addmm_out_xpu`: hands every argument straight to
// xpu::addmm_out and discards its return value.
TORCH_IMPL_FUNC(addmm_out_xpu)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
  // `result` is received as a const reference, but the helper is invoked with a
  // mutable Tensor&, hence the const_cast.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  Tensor& out = const_cast<Tensor&>(result);
  xpu::addmm_out(self, mat1, mat2, beta, alpha, out);
}
|
||||
|
||||
// XPU entry point `mm_out_xpu`: hands every argument straight to xpu::mm_out
// and discards its return value.
TORCH_IMPL_FUNC(mm_out_xpu)(const Tensor& self, const Tensor& mat2, const Tensor& result) {
  // `result` is received as a const reference, but the helper is invoked with a
  // mutable Tensor&, hence the const_cast.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  Tensor& out = const_cast<Tensor&>(result);
  xpu::mm_out(self, mat2, out);
}
|
||||
|
||||
// XPU entry point `bmm_out_xpu`: hands every argument straight to xpu::bmm_out
// and discards its return value.
TORCH_IMPL_FUNC(bmm_out_xpu)(const Tensor& self, const Tensor& batch2, const Tensor &result) {
  // `result` is received as a const reference, but the helper is invoked with a
  // mutable Tensor&, hence the const_cast.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  Tensor& out = const_cast<Tensor&>(result);
  xpu::bmm_out(self, batch2, out);
}
|
||||
|
||||
// XPU entry point `addmm_activation_out_xpu`: hands every argument, including
// the `use_gelu` flag, straight to xpu::_addmm_activation_out and discards its
// return value.
TORCH_IMPL_FUNC(addmm_activation_out_xpu)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, bool use_gelu, const Tensor& result) {
  // `result` is received as a const reference, but the helper is invoked with a
  // mutable Tensor&, hence the const_cast.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  Tensor& out = const_cast<Tensor&>(result);
  xpu::_addmm_activation_out(self, mat1, mat2, beta, alpha, use_gelu, out);
}
|
||||
|
||||
// XPU entry point `baddbmm_out_xpu`: hands every argument straight to
// xpu::baddbmm_out and discards its return value.
TORCH_IMPL_FUNC(baddbmm_out_xpu)(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
  // `result` is received as a const reference, but the helper is invoked with a
  // mutable Tensor&, hence the const_cast.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  Tensor& out = const_cast<Tensor&>(result);
  xpu::baddbmm_out(self, batch1, batch2, beta, alpha, out);
}
|
||||
|
||||
// XPU entry point `addmv_out_xpu`: hands every argument straight to
// xpu::addmv_out and discards its return value.
TORCH_IMPL_FUNC(addmv_out_xpu)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
  // `result` is received as a const reference, but the helper is invoked with a
  // mutable Tensor&, hence the const_cast.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  Tensor& out = const_cast<Tensor&>(result);
  xpu::addmv_out(self, mat, vec, beta, alpha, out);
}
|
||||
|
||||
} // namespace at::native
|
@ -641,6 +641,7 @@
|
||||
CPU: addmv_out_cpu
|
||||
CUDA: addmv_out_cuda
|
||||
MPS: addmv_out_mps
|
||||
XPU: addmv_out_xpu
|
||||
SparseCsrCPU: addmv_out_sparse_compressed
|
||||
SparseCsrCUDA: addmv_out_sparse_compressed_cuda
|
||||
|
||||
@ -1061,6 +1062,7 @@
|
||||
CPU: baddbmm_out_cpu
|
||||
CUDA: baddbmm_out_cuda
|
||||
MPS: baddbmm_out_mps
|
||||
XPU: baddbmm_out_xpu
|
||||
SparseCsrCUDA: baddbmm_out_sparse_csr_cuda
|
||||
|
||||
- func: bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||
@ -1358,6 +1360,7 @@
|
||||
CPU: bmm_out_cpu
|
||||
CUDA: bmm_out_cuda
|
||||
MPS: bmm_out_mps
|
||||
XPU: bmm_out_xpu
|
||||
SparseCPU: bmm_out_sparse_cpu
|
||||
SparseCUDA: bmm_out_sparse_cuda
|
||||
SparseCsrCUDA: bmm_out_sparse_csr_cuda
|
||||
@ -4130,6 +4133,7 @@
|
||||
CPU: mm_out_cpu
|
||||
CUDA: mm_out_cuda
|
||||
MPS: mm_out_mps
|
||||
XPU: mm_out_xpu
|
||||
SparseCPU, SparseCUDA: _sparse_mm_out
|
||||
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm_out
|
||||
|
||||
@ -6993,6 +6997,7 @@
|
||||
CPU: addmm_out_cpu
|
||||
CUDA: addmm_out_cuda
|
||||
MPS: addmm_out_mps
|
||||
XPU: addmm_out_xpu
|
||||
SparseCPU: addmm_out_sparse_dense_cpu
|
||||
SparseCUDA: addmm_out_sparse_dense_cuda
|
||||
SparseCsrCPU: addmm_out_sparse_compressed_cpu
|
||||
@ -7021,6 +7026,7 @@
|
||||
dispatch:
|
||||
CPU: addmm_activation_out_cpu
|
||||
CUDA: addmm_activation_out_cuda
|
||||
XPU: addmm_activation_out_xpu
|
||||
|
||||
- func: _addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor
|
||||
structured_delegate: _addmm_activation.out
|
||||
@ -8655,18 +8661,18 @@
|
||||
- func: addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
|
||||
variants: method
|
||||
dispatch:
|
||||
CPU, CUDA: addbmm_
|
||||
CPU, CUDA, XPU: addbmm_
|
||||
MPS: addbmm_mps_
|
||||
|
||||
- func: addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU, CUDA: addbmm_out
|
||||
CPU, CUDA, XPU: addbmm_out
|
||||
MPS: addbmm_out_mps
|
||||
|
||||
- func: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
|
||||
variants: method, function
|
||||
dispatch:
|
||||
CPU, CUDA: addbmm
|
||||
CPU, CUDA, XPU: addbmm
|
||||
MPS: addbmm_mps
|
||||
|
||||
- func: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)
|
||||
|
@ -16,6 +16,7 @@
|
||||
#if defined(CAFFE2_BUILD_MAIN_LIB) || \
|
||||
defined(TORCH_CUDA_BUILD_MAIN_LIB) || \
|
||||
defined(TORCH_HIP_BUILD_MAIN_LIB) || \
|
||||
defined(TORCH_XPU_BUILD_MAIN_LIB) || \
|
||||
defined(TORCH_CUDA_CU_BUILD_MAIN_LIB) || \
|
||||
defined(TORCH_CUDA_CPP_BUILD_MAIN_LIB)
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
|
@ -332,6 +332,7 @@ set(TORCH_GENERATED_CODE
|
||||
${GENERATED_H_PYTHON}
|
||||
${GENERATED_TESTING_PYTHON}
|
||||
${GENERATED_CXX_TORCH_CUDA}
|
||||
${GENERATED_CXX_TORCH_XPU}
|
||||
)
|
||||
|
||||
set(GEN_PER_OPERATOR_FLAG)
|
||||
|
@ -94,6 +94,11 @@ if(INTERN_BUILD_ATEN_OPS)
|
||||
set(GEN_MPS_FLAG --mps)
|
||||
endif()
|
||||
|
||||
set(GEN_XPU_FLAG)
|
||||
if(USE_XPU)
|
||||
set(GEN_XPU_FLAG --xpu)
|
||||
endif()
|
||||
|
||||
set(CUSTOM_BUILD_FLAGS)
|
||||
if(INTERN_BUILD_MOBILE)
|
||||
if(USE_VULKAN)
|
||||
@ -179,6 +184,7 @@ if(INTERN_BUILD_ATEN_OPS)
|
||||
${GEN_PER_OPERATOR_FLAG}
|
||||
${GEN_ROCM_FLAG}
|
||||
${GEN_MPS_FLAG}
|
||||
${GEN_XPU_FLAG}
|
||||
${CUSTOM_BUILD_FLAGS}
|
||||
)
|
||||
|
||||
@ -217,22 +223,31 @@ if(INTERN_BUILD_ATEN_OPS)
|
||||
include("${CMAKE_BINARY_DIR}/aten/src/ATen/cpu_vec_generated_${gen_type}.cmake")
|
||||
include("${CMAKE_BINARY_DIR}/aten/src/ATen/cuda_generated_${gen_type}.cmake")
|
||||
include("${CMAKE_BINARY_DIR}/aten/src/ATen/ops_generated_${gen_type}.cmake")
|
||||
|
||||
if(USE_XPU)
|
||||
include("${CMAKE_BINARY_DIR}/aten/src/ATen/xpu_generated_${gen_type}.cmake")
|
||||
endif()
|
||||
message(STATUS "${gen_type} outputs: ${gen_outputs}")
|
||||
set(OUTPUT_LIST
|
||||
${generated_${gen_type}}
|
||||
${cuda_generated_${gen_type}}
|
||||
${core_generated_${gen_type}}
|
||||
${cpu_vec_generated_${gen_type}}
|
||||
${ops_generated_${gen_type}}
|
||||
${CMAKE_BINARY_DIR}/aten/src/ATen/generated_${gen_type}.cmake
|
||||
${CMAKE_BINARY_DIR}/aten/src/ATen/ops_generated_${gen_type}.cmake
|
||||
${CMAKE_BINARY_DIR}/aten/src/ATen/core_generated_${gen_type}.cmake
|
||||
${CMAKE_BINARY_DIR}/aten/src/ATen/cpu_vec_generated_${gen_type}.cmake
|
||||
${CMAKE_BINARY_DIR}/aten/src/ATen/cuda_generated_${gen_type}.cmake)
|
||||
if(USE_XPU)
|
||||
list(APPEND OUTPUT_LIST
|
||||
${xpu_generated_${gen_type}}
|
||||
${CMAKE_BINARY_DIR}/aten/src/ATen/xpu_generated_${gen_type}.cmake
|
||||
)
|
||||
endif()
|
||||
|
||||
add_custom_command(
|
||||
COMMENT "Generating ATen ${gen_type}"
|
||||
OUTPUT
|
||||
${generated_${gen_type}}
|
||||
${cuda_generated_${gen_type}}
|
||||
${core_generated_${gen_type}}
|
||||
${cpu_vec_generated_${gen_type}}
|
||||
${ops_generated_${gen_type}}
|
||||
${CMAKE_BINARY_DIR}/aten/src/ATen/generated_${gen_type}.cmake
|
||||
${CMAKE_BINARY_DIR}/aten/src/ATen/ops_generated_${gen_type}.cmake
|
||||
${CMAKE_BINARY_DIR}/aten/src/ATen/core_generated_${gen_type}.cmake
|
||||
${CMAKE_BINARY_DIR}/aten/src/ATen/cpu_vec_generated_${gen_type}.cmake
|
||||
${CMAKE_BINARY_DIR}/aten/src/ATen/cuda_generated_${gen_type}.cmake
|
||||
OUTPUT ${OUTPUT_LIST}
|
||||
COMMAND ${GEN_COMMAND_${gen_type}}
|
||||
DEPENDS ${all_python} ${${gen_type}_templates}
|
||||
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/native_functions.yaml
|
||||
@ -260,6 +275,16 @@ if(INTERN_BUILD_ATEN_OPS)
|
||||
target_compile_definitions(ATEN_CUDA_FILES_GEN_LIB INTERFACE AT_PER_OPERATOR_HEADERS)
|
||||
endif()
|
||||
|
||||
if(USE_XPU)
|
||||
add_custom_target(ATEN_XPU_FILES_GEN_TARGET DEPENDS
|
||||
${xpu_generated_headers} ${xpu_generated_sources})
|
||||
add_library(ATEN_XPU_FILES_GEN_LIB INTERFACE)
|
||||
add_dependencies(ATEN_XPU_FILES_GEN_LIB ATEN_XPU_FILES_GEN_TARGET)
|
||||
|
||||
if(USE_PER_OPERATOR_HEADERS)
|
||||
target_compile_definitions(ATEN_XPU_FILES_GEN_LIB INTERFACE AT_PER_OPERATOR_HEADERS)
|
||||
endif()
|
||||
endif()
|
||||
# Handle source files that need to be compiled multiple times for
|
||||
# different vectorization options
|
||||
file(GLOB cpu_kernel_cpp_in "${PROJECT_SOURCE_DIR}/aten/src/ATen/native/cpu/*.cpp" "${PROJECT_SOURCE_DIR}/aten/src/ATen/native/quantized/cpu/kernels/*.cpp")
|
||||
|
@ -342,7 +342,7 @@ inductor_expected_failures_single_sample["xpu"] = {
|
||||
"cholesky_solve": {f64},
|
||||
"cholesky_inverse": {f64},
|
||||
# could not create a primitive
|
||||
"addbmm": {f16, f32, f64},
|
||||
"addbmm": {f64},
|
||||
"addmm": {f16, f32, f64},
|
||||
"addmv": {f32, f64},
|
||||
# could not create a primitive descriptor for
|
||||
|
@ -53,6 +53,7 @@ from torchgen.model import (
|
||||
BackendMetadata,
|
||||
BaseOperatorName,
|
||||
DEFAULT_KERNEL_NAMESPACE,
|
||||
dispatch_device_map,
|
||||
DispatchKey,
|
||||
FRAGMENT_NAMESPACES,
|
||||
FunctionSchema,
|
||||
@ -143,6 +144,25 @@ _GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {}
|
||||
_GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {}
|
||||
|
||||
|
||||
def file_manager_from_dispatch_key(
    dispatch_key: DispatchKey,
    device_fms: dict[str, FileManager],
    default_fm: FileManager,
) -> FileManager:
    """Select the FileManager that output for ``dispatch_key`` should go to.

    The predicates in ``dispatch_device_map`` are tried in iteration order; the
    device name paired with the first predicate that accepts ``dispatch_key``
    is looked up in ``device_fms``.

    Args:
        dispatch_key: The dispatch key being generated.
        device_fms: Per-device file managers, keyed by device name.
        default_fm: Fallback file manager.

    Returns:
        ``device_fms[device]`` for the first matching device. ``default_fm`` is
        returned when that device has no entry in ``device_fms``, or when no
        predicate matches and ``device_fms`` has no entry under the empty
        string (the key used for "no device matched").
    """
    matched_device = ""
    for accepts, device_name in dispatch_device_map.items():
        if accepts(dispatch_key):
            # Only the first matching predicate counts.
            matched_device = device_name
            break
    return device_fms.get(matched_device, default_fm)
|
||||
|
||||
|
||||
def parse_native_yaml_struct(
|
||||
es: object,
|
||||
valid_tags: set[str],
|
||||
@ -1716,7 +1736,7 @@ def gen_aggregated_headers(
|
||||
selector: SelectiveBuilder,
|
||||
backend_indices: dict[DispatchKey, BackendIndex],
|
||||
cpu_fm: FileManager,
|
||||
cuda_fm: FileManager,
|
||||
device_fms: dict[str, FileManager],
|
||||
functions_keys: set[DispatchKey],
|
||||
dispatch_keys: Sequence[DispatchKey],
|
||||
rocm: bool,
|
||||
@ -1796,7 +1816,7 @@ def gen_aggregated_headers(
|
||||
)
|
||||
|
||||
for dispatch_key in dispatch_keys:
|
||||
fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
|
||||
fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
|
||||
if dispatch_key in functions_keys:
|
||||
inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"
|
||||
|
||||
@ -1836,7 +1856,7 @@ def gen_per_operator_headers(
|
||||
selector: SelectiveBuilder,
|
||||
backend_indices: dict[DispatchKey, BackendIndex],
|
||||
cpu_fm: FileManager,
|
||||
cuda_fm: FileManager,
|
||||
device_fms: dict[str, FileManager],
|
||||
ops_fm: FileManager,
|
||||
functions_keys: set[DispatchKey],
|
||||
dispatch_keys: Sequence[DispatchKey],
|
||||
@ -1984,7 +2004,7 @@ def gen_per_operator_headers(
|
||||
},
|
||||
)
|
||||
|
||||
fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
|
||||
fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
|
||||
inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"
|
||||
|
||||
fm.write_with_template(
|
||||
@ -2033,7 +2053,7 @@ def gen_headers(
|
||||
backend_indices: dict[DispatchKey, BackendIndex],
|
||||
core_fm: FileManager,
|
||||
cpu_fm: FileManager,
|
||||
cuda_fm: FileManager,
|
||||
device_fms: dict[str, FileManager],
|
||||
ops_fm: FileManager,
|
||||
dispatch_keys: Sequence[DispatchKey],
|
||||
functions_keys: set[DispatchKey],
|
||||
@ -2048,7 +2068,7 @@ def gen_headers(
|
||||
selector=selector,
|
||||
backend_indices=backend_indices,
|
||||
cpu_fm=cpu_fm,
|
||||
cuda_fm=cuda_fm,
|
||||
device_fms=device_fms,
|
||||
ops_fm=ops_fm,
|
||||
dispatch_keys=dispatch_keys,
|
||||
functions_keys=functions_keys,
|
||||
@ -2063,7 +2083,7 @@ def gen_headers(
|
||||
selector=selector,
|
||||
backend_indices=backend_indices,
|
||||
cpu_fm=cpu_fm,
|
||||
cuda_fm=cuda_fm,
|
||||
device_fms=device_fms,
|
||||
dispatch_keys=dispatch_keys,
|
||||
functions_keys=functions_keys,
|
||||
rocm=rocm,
|
||||
@ -2171,9 +2191,9 @@ def gen_source_files(
|
||||
backend_indices: dict[DispatchKey, BackendIndex],
|
||||
aoti_fm: FileManager,
|
||||
core_fm: FileManager,
|
||||
cpu_fm: FileManager,
|
||||
cpu_vec_fm: FileManager,
|
||||
cuda_fm: FileManager,
|
||||
cpu_fm: FileManager,
|
||||
device_fms: dict[str, FileManager],
|
||||
dispatch_keys: Sequence[DispatchKey],
|
||||
functions_keys: set[DispatchKey],
|
||||
rocm: bool,
|
||||
@ -2195,8 +2215,7 @@ def gen_source_files(
|
||||
#include <ATen/hip/HIPContext.h>"""
|
||||
|
||||
for dispatch_key in dispatch_keys:
|
||||
fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
|
||||
|
||||
fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
|
||||
if per_operator_headers:
|
||||
|
||||
def operator_headers() -> list[str]:
|
||||
@ -2752,6 +2771,12 @@ def main() -> None:
|
||||
action="store_true",
|
||||
help="Generate MPS registration code when set",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--xpu",
|
||||
action="store_true",
|
||||
help="Generate XPU registration code when set",
|
||||
)
|
||||
|
||||
# TODO: --op-registration-whitelist will be removed when all call-sites
|
||||
# for gen.py are moved over to using the operator YAML file for mobile
|
||||
# custom build.
|
||||
@ -2833,6 +2858,19 @@ def main() -> None:
|
||||
if DispatchKey.MPS in dispatch_keys:
|
||||
del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)]
|
||||
|
||||
xpu_in_whitelist = (
|
||||
options.backend_whitelist and str(DispatchKey.XPU) in options.backend_whitelist
|
||||
)
|
||||
# Only generate RegisterXPU.cpp when "--xpu" is passed to torchgen/gen.py.
|
||||
# Before this change, torchgen always generates RegisterXPU.cpp for out-of-tree
|
||||
# torch-xpu-ops native_functions.yaml which use --backend_whitelist=XPU and without "--xpu".
|
||||
# After this change is landed, we will add --xpu in torch-xpu-ops and remove the check of "xpu_in_whitelist".
|
||||
if (not options.xpu) and (not xpu_in_whitelist):
|
||||
ignore_keys.add(DispatchKey.XPU)
|
||||
|
||||
if DispatchKey.XPU in dispatch_keys:
|
||||
del dispatch_keys[dispatch_keys.index(DispatchKey.XPU)]
|
||||
|
||||
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
|
||||
valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
|
||||
native_functions, backend_indices = (
|
||||
@ -2877,6 +2915,9 @@ def main() -> None:
|
||||
cuda_fm = make_file_manager(options=options)
|
||||
ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)
|
||||
aoti_fm = make_file_manager(options=options, install_dir=aoti_install_dir)
|
||||
device_fms = {"cuda": cuda_fm}
|
||||
if options.xpu:
|
||||
device_fms["xpu"] = make_file_manager(options=options)
|
||||
|
||||
# Only a limited set of dispatch keys get CPUFunctions.h headers generated
|
||||
# for them; this is the set
|
||||
@ -2892,6 +2933,9 @@ def main() -> None:
|
||||
if options.mps:
|
||||
functions_keys.add(DispatchKey.MPS)
|
||||
|
||||
if options.xpu:
|
||||
functions_keys.add(DispatchKey.XPU)
|
||||
|
||||
if options.backend_whitelist:
|
||||
dispatch_keys = [
|
||||
k
|
||||
@ -2921,9 +2965,9 @@ def main() -> None:
|
||||
backend_indices=backend_indices,
|
||||
aoti_fm=aoti_fm,
|
||||
core_fm=core_fm,
|
||||
cpu_fm=cpu_fm,
|
||||
cpu_vec_fm=cpu_vec_fm,
|
||||
cuda_fm=cuda_fm,
|
||||
cpu_fm=cpu_fm,
|
||||
device_fms=device_fms,
|
||||
dispatch_keys=dispatch_keys,
|
||||
functions_keys=functions_keys,
|
||||
rocm=options.rocm,
|
||||
@ -2944,7 +2988,7 @@ def main() -> None:
|
||||
backend_indices=backend_indices,
|
||||
core_fm=core_fm,
|
||||
cpu_fm=cpu_fm,
|
||||
cuda_fm=cuda_fm,
|
||||
device_fms=device_fms,
|
||||
ops_fm=ops_fm,
|
||||
dispatch_keys=dispatch_keys,
|
||||
functions_keys=functions_keys,
|
||||
@ -2964,9 +3008,8 @@ def main() -> None:
|
||||
(cpu_fm, ""),
|
||||
(cpu_vec_fm, "cpu_vec_"),
|
||||
(core_fm, "core_"),
|
||||
(cuda_fm, "cuda_"),
|
||||
(ops_fm, "ops_"),
|
||||
]:
|
||||
] + [(device_fm, f"{device}_") for device, device_fm in device_fms.items()]:
|
||||
varname = prefix + depfile_stem
|
||||
path = depfile_path.parent / (prefix + depfile_name)
|
||||
fm.write_outputs(varname, str(path))
|
||||
|
@ -346,6 +346,9 @@ def is_ufunc_dispatch_key(dk: DispatchKey) -> bool:
|
||||
return dk in UFUNC_DISPATCH_KEYS
|
||||
|
||||
|
||||
dispatch_device_map = {is_cuda_dispatch_key: "cuda", is_xpu_dispatch_key: "xpu"}
|
||||
|
||||
|
||||
# This is oddly named ScalarType and not DType for symmetry with C++
|
||||
class ScalarType(Enum):
|
||||
Byte = auto()
|
||||
|
Reference in New Issue
Block a user