[Intel GPU] Support RegisterXPU.cpp codegen and compile for the in-tree XPU structured GEMM OPs. (#139025)

[Intel GPU] Support RegisterXPU.cpp codegen and compile for the in-tree XPU structured GEMM ops.

Motivation: There are two parts of aten ops for XPU, one is in-tree ops like GEMM related OPs and the other is out-off-tree ops in torch-xpu-ops. For the in-tree part,since Pytorch uses native_functions.yaml registration and is equipped with convenient codegen capabilities, we want to take advantage of these benefits as well.
At the same time, since AOT Inductor also uses native_functions.yaml to generate c shim wrappers, we also need to enable this mechanism for XPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139025
Approved by: https://github.com/EikanWang, https://github.com/jansel, https://github.com/desertfire
This commit is contained in:
xinan.lin
2024-11-08 18:04:34 -08:00
committed by PyTorch MergeBot
parent 0b650c360a
commit 929a647363
9 changed files with 157 additions and 47 deletions

View File

@ -348,6 +348,7 @@ endif()
if(USE_XPU) if(USE_XPU)
list(APPEND ATen_XPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/xpu) list(APPEND ATen_XPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/xpu)
list(APPEND ATen_XPU_SRCS ${xpu_cpp}) list(APPEND ATen_XPU_SRCS ${xpu_cpp})
list(APPEND ATen_XPU_SRCS ${xpu_generated_sources})
endif() endif()
list(APPEND ATen_CPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/..) list(APPEND ATen_CPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/..)

View File

@ -1,9 +1,24 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/WrapDimUtilsMulti.h> #include <ATen/WrapDimUtilsMulti.h>
#include <ATen/native/Resize.h> #include <ATen/native/Resize.h>
#include <torch/library.h> #include <torch/library.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h> #include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#ifndef AT_PER_OPERATOR_HEADERS
namespace at::native::xpu { #include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/mm_native.h>
#endif
namespace at::native {
namespace xpu {
// result = beta * self + alpha * (mat1 * mat2) // result = beta * self + alpha * (mat1 * mat2)
Tensor& addmm_out( Tensor& addmm_out(
@ -425,20 +440,35 @@ Tensor& tensordot_out(
} }
TORCH_LIBRARY_IMPL(aten, XPU, m){ TORCH_LIBRARY_IMPL(aten, XPU, m){
m.impl("addmm.out", TORCH_FN(addmm_out));
m.impl("_addmm_activation.out", TORCH_FN(_addmm_activation_out));
m.impl("mm.out", TORCH_FN(mm_out));
m.impl("mm", TORCH_FN(mm));
m.impl("baddbmm.out", TORCH_FN(baddbmm_out));
m.impl("baddbmm_", TORCH_FN(baddbmm_));
m.impl("baddbmm", TORCH_FN(baddbmm));
m.impl("addbmm.out", TORCH_FN(addbmm_out));
m.impl("addbmm_", TORCH_FN(addbmm_));
m.impl("addbmm", TORCH_FN(addbmm));
m.impl("bmm.out", TORCH_FN(bmm_out));
m.impl("bmm", TORCH_FN(bmm));
m.impl("addmv.out", TORCH_FN(addmv_out));
m.impl("tensordot.out", TORCH_FN(tensordot_out)); m.impl("tensordot.out", TORCH_FN(tensordot_out));
} }
} // namespace xpu
} // namespace at::native::xpu TORCH_IMPL_FUNC(addmm_out_xpu)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
xpu::addmm_out(self, mat1, mat2, beta, alpha, const_cast<Tensor&>(result));
}
TORCH_IMPL_FUNC(mm_out_xpu)(const Tensor& self, const Tensor& mat2, const Tensor& result) {
xpu::mm_out(self, mat2, const_cast<Tensor&>(result));
}
TORCH_IMPL_FUNC(bmm_out_xpu)(const Tensor& self, const Tensor& batch2, const Tensor &result) {
xpu::bmm_out(self, batch2, const_cast<Tensor&>(result));
}
TORCH_IMPL_FUNC(addmm_activation_out_xpu)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, bool use_gelu, const Tensor& result) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
xpu::_addmm_activation_out(self, mat1, mat2, beta, alpha, use_gelu, const_cast<Tensor&>(result));
}
TORCH_IMPL_FUNC(baddbmm_out_xpu)(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
xpu::baddbmm_out(
self, batch1, batch2, beta, alpha, const_cast<Tensor&>(result));
}
TORCH_IMPL_FUNC(addmv_out_xpu)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
xpu::addmv_out(self, mat, vec, beta, alpha, const_cast<Tensor&>(result));
}
} // namespace at::native

View File

@ -641,6 +641,7 @@
CPU: addmv_out_cpu CPU: addmv_out_cpu
CUDA: addmv_out_cuda CUDA: addmv_out_cuda
MPS: addmv_out_mps MPS: addmv_out_mps
XPU: addmv_out_xpu
SparseCsrCPU: addmv_out_sparse_compressed SparseCsrCPU: addmv_out_sparse_compressed
SparseCsrCUDA: addmv_out_sparse_compressed_cuda SparseCsrCUDA: addmv_out_sparse_compressed_cuda
@ -1061,6 +1062,7 @@
CPU: baddbmm_out_cpu CPU: baddbmm_out_cpu
CUDA: baddbmm_out_cuda CUDA: baddbmm_out_cuda
MPS: baddbmm_out_mps MPS: baddbmm_out_mps
XPU: baddbmm_out_xpu
SparseCsrCUDA: baddbmm_out_sparse_csr_cuda SparseCsrCUDA: baddbmm_out_sparse_csr_cuda
- func: bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - func: bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
@ -1358,6 +1360,7 @@
CPU: bmm_out_cpu CPU: bmm_out_cpu
CUDA: bmm_out_cuda CUDA: bmm_out_cuda
MPS: bmm_out_mps MPS: bmm_out_mps
XPU: bmm_out_xpu
SparseCPU: bmm_out_sparse_cpu SparseCPU: bmm_out_sparse_cpu
SparseCUDA: bmm_out_sparse_cuda SparseCUDA: bmm_out_sparse_cuda
SparseCsrCUDA: bmm_out_sparse_csr_cuda SparseCsrCUDA: bmm_out_sparse_csr_cuda
@ -4130,6 +4133,7 @@
CPU: mm_out_cpu CPU: mm_out_cpu
CUDA: mm_out_cuda CUDA: mm_out_cuda
MPS: mm_out_mps MPS: mm_out_mps
XPU: mm_out_xpu
SparseCPU, SparseCUDA: _sparse_mm_out SparseCPU, SparseCUDA: _sparse_mm_out
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm_out SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm_out
@ -6993,6 +6997,7 @@
CPU: addmm_out_cpu CPU: addmm_out_cpu
CUDA: addmm_out_cuda CUDA: addmm_out_cuda
MPS: addmm_out_mps MPS: addmm_out_mps
XPU: addmm_out_xpu
SparseCPU: addmm_out_sparse_dense_cpu SparseCPU: addmm_out_sparse_dense_cpu
SparseCUDA: addmm_out_sparse_dense_cuda SparseCUDA: addmm_out_sparse_dense_cuda
SparseCsrCPU: addmm_out_sparse_compressed_cpu SparseCsrCPU: addmm_out_sparse_compressed_cpu
@ -7021,6 +7026,7 @@
dispatch: dispatch:
CPU: addmm_activation_out_cpu CPU: addmm_activation_out_cpu
CUDA: addmm_activation_out_cuda CUDA: addmm_activation_out_cuda
XPU: addmm_activation_out_xpu
- func: _addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor - func: _addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor
structured_delegate: _addmm_activation.out structured_delegate: _addmm_activation.out
@ -8655,18 +8661,18 @@
- func: addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) - func: addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
variants: method variants: method
dispatch: dispatch:
CPU, CUDA: addbmm_ CPU, CUDA, XPU: addbmm_
MPS: addbmm_mps_ MPS: addbmm_mps_
- func: addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) - func: addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
dispatch: dispatch:
CPU, CUDA: addbmm_out CPU, CUDA, XPU: addbmm_out
MPS: addbmm_out_mps MPS: addbmm_out_mps
- func: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor - func: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
variants: method, function variants: method, function
dispatch: dispatch:
CPU, CUDA: addbmm CPU, CUDA, XPU: addbmm
MPS: addbmm_mps MPS: addbmm_mps
- func: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!) - func: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)

View File

@ -16,6 +16,7 @@
#if defined(CAFFE2_BUILD_MAIN_LIB) || \ #if defined(CAFFE2_BUILD_MAIN_LIB) || \
defined(TORCH_CUDA_BUILD_MAIN_LIB) || \ defined(TORCH_CUDA_BUILD_MAIN_LIB) || \
defined(TORCH_HIP_BUILD_MAIN_LIB) || \ defined(TORCH_HIP_BUILD_MAIN_LIB) || \
defined(TORCH_XPU_BUILD_MAIN_LIB) || \
defined(TORCH_CUDA_CU_BUILD_MAIN_LIB) || \ defined(TORCH_CUDA_CU_BUILD_MAIN_LIB) || \
defined(TORCH_CUDA_CPP_BUILD_MAIN_LIB) defined(TORCH_CUDA_CPP_BUILD_MAIN_LIB)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #define TORCH_ASSERT_ONLY_METHOD_OPERATORS

View File

@ -332,6 +332,7 @@ set(TORCH_GENERATED_CODE
${GENERATED_H_PYTHON} ${GENERATED_H_PYTHON}
${GENERATED_TESTING_PYTHON} ${GENERATED_TESTING_PYTHON}
${GENERATED_CXX_TORCH_CUDA} ${GENERATED_CXX_TORCH_CUDA}
${GENERATED_CXX_TORCH_XPU}
) )
set(GEN_PER_OPERATOR_FLAG) set(GEN_PER_OPERATOR_FLAG)

View File

@ -94,6 +94,11 @@ if(INTERN_BUILD_ATEN_OPS)
set(GEN_MPS_FLAG --mps) set(GEN_MPS_FLAG --mps)
endif() endif()
set(GEN_XPU_FLAG)
if(USE_XPU)
set(GEN_XPU_FLAG --xpu)
endif()
set(CUSTOM_BUILD_FLAGS) set(CUSTOM_BUILD_FLAGS)
if(INTERN_BUILD_MOBILE) if(INTERN_BUILD_MOBILE)
if(USE_VULKAN) if(USE_VULKAN)
@ -179,6 +184,7 @@ if(INTERN_BUILD_ATEN_OPS)
${GEN_PER_OPERATOR_FLAG} ${GEN_PER_OPERATOR_FLAG}
${GEN_ROCM_FLAG} ${GEN_ROCM_FLAG}
${GEN_MPS_FLAG} ${GEN_MPS_FLAG}
${GEN_XPU_FLAG}
${CUSTOM_BUILD_FLAGS} ${CUSTOM_BUILD_FLAGS}
) )
@ -217,22 +223,31 @@ if(INTERN_BUILD_ATEN_OPS)
include("${CMAKE_BINARY_DIR}/aten/src/ATen/cpu_vec_generated_${gen_type}.cmake") include("${CMAKE_BINARY_DIR}/aten/src/ATen/cpu_vec_generated_${gen_type}.cmake")
include("${CMAKE_BINARY_DIR}/aten/src/ATen/cuda_generated_${gen_type}.cmake") include("${CMAKE_BINARY_DIR}/aten/src/ATen/cuda_generated_${gen_type}.cmake")
include("${CMAKE_BINARY_DIR}/aten/src/ATen/ops_generated_${gen_type}.cmake") include("${CMAKE_BINARY_DIR}/aten/src/ATen/ops_generated_${gen_type}.cmake")
if(USE_XPU)
include("${CMAKE_BINARY_DIR}/aten/src/ATen/xpu_generated_${gen_type}.cmake")
endif()
message(STATUS "${gen_type} outputs: ${gen_outputs}") message(STATUS "${gen_type} outputs: ${gen_outputs}")
set(OUTPUT_LIST
${generated_${gen_type}}
${cuda_generated_${gen_type}}
${core_generated_${gen_type}}
${cpu_vec_generated_${gen_type}}
${ops_generated_${gen_type}}
${CMAKE_BINARY_DIR}/aten/src/ATen/generated_${gen_type}.cmake
${CMAKE_BINARY_DIR}/aten/src/ATen/ops_generated_${gen_type}.cmake
${CMAKE_BINARY_DIR}/aten/src/ATen/core_generated_${gen_type}.cmake
${CMAKE_BINARY_DIR}/aten/src/ATen/cpu_vec_generated_${gen_type}.cmake
${CMAKE_BINARY_DIR}/aten/src/ATen/cuda_generated_${gen_type}.cmake)
if(USE_XPU)
list(APPEND OUTPUT_LIST
${xpu_generated_${gen_type}}
${CMAKE_BINARY_DIR}/aten/src/ATen/xpu_generated_${gen_type}.cmake
)
endif()
add_custom_command( add_custom_command(
COMMENT "Generating ATen ${gen_type}" COMMENT "Generating ATen ${gen_type}"
OUTPUT OUTPUT ${OUTPUT_LIST}
${generated_${gen_type}}
${cuda_generated_${gen_type}}
${core_generated_${gen_type}}
${cpu_vec_generated_${gen_type}}
${ops_generated_${gen_type}}
${CMAKE_BINARY_DIR}/aten/src/ATen/generated_${gen_type}.cmake
${CMAKE_BINARY_DIR}/aten/src/ATen/ops_generated_${gen_type}.cmake
${CMAKE_BINARY_DIR}/aten/src/ATen/core_generated_${gen_type}.cmake
${CMAKE_BINARY_DIR}/aten/src/ATen/cpu_vec_generated_${gen_type}.cmake
${CMAKE_BINARY_DIR}/aten/src/ATen/cuda_generated_${gen_type}.cmake
COMMAND ${GEN_COMMAND_${gen_type}} COMMAND ${GEN_COMMAND_${gen_type}}
DEPENDS ${all_python} ${${gen_type}_templates} DEPENDS ${all_python} ${${gen_type}_templates}
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/native_functions.yaml ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/native_functions.yaml
@ -260,6 +275,16 @@ if(INTERN_BUILD_ATEN_OPS)
target_compile_definitions(ATEN_CUDA_FILES_GEN_LIB INTERFACE AT_PER_OPERATOR_HEADERS) target_compile_definitions(ATEN_CUDA_FILES_GEN_LIB INTERFACE AT_PER_OPERATOR_HEADERS)
endif() endif()
if(USE_XPU)
add_custom_target(ATEN_XPU_FILES_GEN_TARGET DEPENDS
${xpu_generated_headers} ${xpu_generated_sources})
add_library(ATEN_XPU_FILES_GEN_LIB INTERFACE)
add_dependencies(ATEN_XPU_FILES_GEN_LIB ATEN_XPU_FILES_GEN_TARGET)
if(USE_PER_OPERATOR_HEADERS)
target_compile_definitions(ATEN_XPU_FILES_GEN_LIB INTERFACE AT_PER_OPERATOR_HEADERS)
endif()
endif()
# Handle source files that need to be compiled multiple times for # Handle source files that need to be compiled multiple times for
# different vectorization options # different vectorization options
file(GLOB cpu_kernel_cpp_in "${PROJECT_SOURCE_DIR}/aten/src/ATen/native/cpu/*.cpp" "${PROJECT_SOURCE_DIR}/aten/src/ATen/native/quantized/cpu/kernels/*.cpp") file(GLOB cpu_kernel_cpp_in "${PROJECT_SOURCE_DIR}/aten/src/ATen/native/cpu/*.cpp" "${PROJECT_SOURCE_DIR}/aten/src/ATen/native/quantized/cpu/kernels/*.cpp")

View File

@ -342,7 +342,7 @@ inductor_expected_failures_single_sample["xpu"] = {
"cholesky_solve": {f64}, "cholesky_solve": {f64},
"cholesky_inverse": {f64}, "cholesky_inverse": {f64},
# could not create a primitive # could not create a primitive
"addbmm": {f16, f32, f64}, "addbmm": {f64},
"addmm": {f16, f32, f64}, "addmm": {f16, f32, f64},
"addmv": {f32, f64}, "addmv": {f32, f64},
# could not create a primitive descriptor for # could not create a primitive descriptor for

View File

@ -53,6 +53,7 @@ from torchgen.model import (
BackendMetadata, BackendMetadata,
BaseOperatorName, BaseOperatorName,
DEFAULT_KERNEL_NAMESPACE, DEFAULT_KERNEL_NAMESPACE,
dispatch_device_map,
DispatchKey, DispatchKey,
FRAGMENT_NAMESPACES, FRAGMENT_NAMESPACES,
FunctionSchema, FunctionSchema,
@ -143,6 +144,25 @@ _GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {}
_GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {} _GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {}
def file_manager_from_dispatch_key(
dispatch_key: DispatchKey,
device_fms: dict[str, FileManager],
default_fm: FileManager,
) -> FileManager:
fm = device_fms.get(
next(
(
device
for check, device in dispatch_device_map.items()
if check(dispatch_key)
),
"",
),
default_fm,
)
return fm
def parse_native_yaml_struct( def parse_native_yaml_struct(
es: object, es: object,
valid_tags: set[str], valid_tags: set[str],
@ -1716,7 +1736,7 @@ def gen_aggregated_headers(
selector: SelectiveBuilder, selector: SelectiveBuilder,
backend_indices: dict[DispatchKey, BackendIndex], backend_indices: dict[DispatchKey, BackendIndex],
cpu_fm: FileManager, cpu_fm: FileManager,
cuda_fm: FileManager, device_fms: dict[str, FileManager],
functions_keys: set[DispatchKey], functions_keys: set[DispatchKey],
dispatch_keys: Sequence[DispatchKey], dispatch_keys: Sequence[DispatchKey],
rocm: bool, rocm: bool,
@ -1796,7 +1816,7 @@ def gen_aggregated_headers(
) )
for dispatch_key in dispatch_keys: for dispatch_key in dispatch_keys:
fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
if dispatch_key in functions_keys: if dispatch_key in functions_keys:
inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>" inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"
@ -1836,7 +1856,7 @@ def gen_per_operator_headers(
selector: SelectiveBuilder, selector: SelectiveBuilder,
backend_indices: dict[DispatchKey, BackendIndex], backend_indices: dict[DispatchKey, BackendIndex],
cpu_fm: FileManager, cpu_fm: FileManager,
cuda_fm: FileManager, device_fms: dict[str, FileManager],
ops_fm: FileManager, ops_fm: FileManager,
functions_keys: set[DispatchKey], functions_keys: set[DispatchKey],
dispatch_keys: Sequence[DispatchKey], dispatch_keys: Sequence[DispatchKey],
@ -1984,7 +2004,7 @@ def gen_per_operator_headers(
}, },
) )
fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>" inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"
fm.write_with_template( fm.write_with_template(
@ -2033,7 +2053,7 @@ def gen_headers(
backend_indices: dict[DispatchKey, BackendIndex], backend_indices: dict[DispatchKey, BackendIndex],
core_fm: FileManager, core_fm: FileManager,
cpu_fm: FileManager, cpu_fm: FileManager,
cuda_fm: FileManager, device_fms: dict[str, FileManager],
ops_fm: FileManager, ops_fm: FileManager,
dispatch_keys: Sequence[DispatchKey], dispatch_keys: Sequence[DispatchKey],
functions_keys: set[DispatchKey], functions_keys: set[DispatchKey],
@ -2048,7 +2068,7 @@ def gen_headers(
selector=selector, selector=selector,
backend_indices=backend_indices, backend_indices=backend_indices,
cpu_fm=cpu_fm, cpu_fm=cpu_fm,
cuda_fm=cuda_fm, device_fms=device_fms,
ops_fm=ops_fm, ops_fm=ops_fm,
dispatch_keys=dispatch_keys, dispatch_keys=dispatch_keys,
functions_keys=functions_keys, functions_keys=functions_keys,
@ -2063,7 +2083,7 @@ def gen_headers(
selector=selector, selector=selector,
backend_indices=backend_indices, backend_indices=backend_indices,
cpu_fm=cpu_fm, cpu_fm=cpu_fm,
cuda_fm=cuda_fm, device_fms=device_fms,
dispatch_keys=dispatch_keys, dispatch_keys=dispatch_keys,
functions_keys=functions_keys, functions_keys=functions_keys,
rocm=rocm, rocm=rocm,
@ -2171,9 +2191,9 @@ def gen_source_files(
backend_indices: dict[DispatchKey, BackendIndex], backend_indices: dict[DispatchKey, BackendIndex],
aoti_fm: FileManager, aoti_fm: FileManager,
core_fm: FileManager, core_fm: FileManager,
cpu_fm: FileManager,
cpu_vec_fm: FileManager, cpu_vec_fm: FileManager,
cuda_fm: FileManager, cpu_fm: FileManager,
device_fms: dict[str, FileManager],
dispatch_keys: Sequence[DispatchKey], dispatch_keys: Sequence[DispatchKey],
functions_keys: set[DispatchKey], functions_keys: set[DispatchKey],
rocm: bool, rocm: bool,
@ -2195,8 +2215,7 @@ def gen_source_files(
#include <ATen/hip/HIPContext.h>""" #include <ATen/hip/HIPContext.h>"""
for dispatch_key in dispatch_keys: for dispatch_key in dispatch_keys:
fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
if per_operator_headers: if per_operator_headers:
def operator_headers() -> list[str]: def operator_headers() -> list[str]:
@ -2752,6 +2771,12 @@ def main() -> None:
action="store_true", action="store_true",
help="Generate MPS registration code when set", help="Generate MPS registration code when set",
) )
parser.add_argument(
"--xpu",
action="store_true",
help="Generate XPU registration code when set",
)
# TODO: --op-registration-whitelist will be removed when all call-sites # TODO: --op-registration-whitelist will be removed when all call-sites
# for gen.py are moved over to using the operator YAML file for mobile # for gen.py are moved over to using the operator YAML file for mobile
# custom build. # custom build.
@ -2833,6 +2858,19 @@ def main() -> None:
if DispatchKey.MPS in dispatch_keys: if DispatchKey.MPS in dispatch_keys:
del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)] del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)]
xpu_in_whitelist = (
options.backend_whitelist and str(DispatchKey.XPU) in options.backend_whitelist
)
# Only generate RegisterXPU.cpp when there is "--xpu" with torhgen/gen.py
# Before this change, torchgen always generates RegisterXPU.cpp for out-of-tree
# torch-xpu-ops native_functions.yaml which use --backend_whitelist=XPU and without "--xpu".
# After this change is landed, we will add --xpu in torch-xpu-ops and remove the check of "xpu_in_whitelist".
if (not options.xpu) and (not xpu_in_whitelist):
ignore_keys.add(DispatchKey.XPU)
if DispatchKey.XPU in dispatch_keys:
del dispatch_keys[dispatch_keys.index(DispatchKey.XPU)]
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys) parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path] valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
native_functions, backend_indices = ( native_functions, backend_indices = (
@ -2877,6 +2915,9 @@ def main() -> None:
cuda_fm = make_file_manager(options=options) cuda_fm = make_file_manager(options=options)
ops_fm = make_file_manager(options=options, install_dir=ops_install_dir) ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)
aoti_fm = make_file_manager(options=options, install_dir=aoti_install_dir) aoti_fm = make_file_manager(options=options, install_dir=aoti_install_dir)
device_fms = {"cuda": cuda_fm}
if options.xpu:
device_fms["xpu"] = make_file_manager(options=options)
# Only a limited set of dispatch keys get CPUFunctions.h headers generated # Only a limited set of dispatch keys get CPUFunctions.h headers generated
# for them; this is the set # for them; this is the set
@ -2892,6 +2933,9 @@ def main() -> None:
if options.mps: if options.mps:
functions_keys.add(DispatchKey.MPS) functions_keys.add(DispatchKey.MPS)
if options.xpu:
functions_keys.add(DispatchKey.XPU)
if options.backend_whitelist: if options.backend_whitelist:
dispatch_keys = [ dispatch_keys = [
k k
@ -2921,9 +2965,9 @@ def main() -> None:
backend_indices=backend_indices, backend_indices=backend_indices,
aoti_fm=aoti_fm, aoti_fm=aoti_fm,
core_fm=core_fm, core_fm=core_fm,
cpu_fm=cpu_fm,
cpu_vec_fm=cpu_vec_fm, cpu_vec_fm=cpu_vec_fm,
cuda_fm=cuda_fm, cpu_fm=cpu_fm,
device_fms=device_fms,
dispatch_keys=dispatch_keys, dispatch_keys=dispatch_keys,
functions_keys=functions_keys, functions_keys=functions_keys,
rocm=options.rocm, rocm=options.rocm,
@ -2944,7 +2988,7 @@ def main() -> None:
backend_indices=backend_indices, backend_indices=backend_indices,
core_fm=core_fm, core_fm=core_fm,
cpu_fm=cpu_fm, cpu_fm=cpu_fm,
cuda_fm=cuda_fm, device_fms=device_fms,
ops_fm=ops_fm, ops_fm=ops_fm,
dispatch_keys=dispatch_keys, dispatch_keys=dispatch_keys,
functions_keys=functions_keys, functions_keys=functions_keys,
@ -2964,9 +3008,8 @@ def main() -> None:
(cpu_fm, ""), (cpu_fm, ""),
(cpu_vec_fm, "cpu_vec_"), (cpu_vec_fm, "cpu_vec_"),
(core_fm, "core_"), (core_fm, "core_"),
(cuda_fm, "cuda_"),
(ops_fm, "ops_"), (ops_fm, "ops_"),
]: ] + [(device_fm, f"{device}_") for device, device_fm in device_fms.items()]:
varname = prefix + depfile_stem varname = prefix + depfile_stem
path = depfile_path.parent / (prefix + depfile_name) path = depfile_path.parent / (prefix + depfile_name)
fm.write_outputs(varname, str(path)) fm.write_outputs(varname, str(path))

View File

@ -346,6 +346,9 @@ def is_ufunc_dispatch_key(dk: DispatchKey) -> bool:
return dk in UFUNC_DISPATCH_KEYS return dk in UFUNC_DISPATCH_KEYS
dispatch_device_map = {is_cuda_dispatch_key: "cuda", is_xpu_dispatch_key: "xpu"}
# This is oddly named ScalarType and not DType for symmetry with C++ # This is oddly named ScalarType and not DType for symmetry with C++
class ScalarType(Enum): class ScalarType(Enum):
Byte = auto() Byte = auto()