Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[NVFUSER] refactor nvfuser build (#89621)
This PR is the first step towards refactoring the nvfuser build so that the codegen becomes a standalone library.

Contents of this PR:
1. The nvfuser code base has been moved to `./nvfuser`, from `./torch/csrc/jit/codegen/cuda/`, except for the registration code for integration (interface.h/interface.cpp).
2. The build system is split so that nvfuser generates its own `.so` files. Currently these are:
   - `libnvfuser_codegen.so`, which contains the integration, codegen and runtime system of nvfuser
   - `nvfuser.so`, which is nvfuser's python API via pybind. The Python frontend is now exposed via `nvfuser._C.XXX` instead of `torch._C._nvfuser`.
3. The nvfuser cpp tests are now compiled into `nvfuser_tests`.
4. cmake is refactored so that:
   - nvfuser now has its own `CMakeLists.txt`, which is under `torch/csrc/jit/codegen/cuda/`.
   - nvfuser backend code is no longer compiled inside `libtorch_cuda_xxx`.
   - nvfuser is added as a subdirectory in `./CMakeLists.txt` at the very end, after torch is built.
   - Since nvfuser depends on torch, nvfuser is registered at runtime via dlopen (`at::DynamicLibrary`). This avoids a circular dependency in cmake, which would be a nightmare to handle. For details, see `torch/csrc/jit/codegen/cuda/interface.cpp::LoadingNvfuserLibrary`.

Future work scoped for follow-up PRs:
- The nvfuser codegen still depends on torch; we need to refactor that out so we can move nvfuser into a submodule and stop relying on dlopen to load the library. @malfet
- Since nvfuser moved into a cmake build, the bazel build for nvfuser is effectively disabled. This could impact internal workloads at Meta, so we need to put that support back. cc'ing @vors

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89621
Approved by: https://github.com/davidberard98
Committed by: PyTorch MergeBot
Parent: 0a57a20c02
Commit: c11b301bcd
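The dlopen-based registration described in the commit message is what breaks the cmake-level cycle: torch never links against nvfuser, it only loads `libnvfuser_codegen.so` lazily at runtime. Below is a minimal sketch of that pattern; the struct shape and the static-initializer trigger mirror the description above, not the literal contents of `interface.cpp::LoadingNvfuserLibrary`.

```cpp
// Hedged sketch of dlopen-style lazy registration via at::DynamicLibrary.
#include <ATen/DynamicLibrary.h>
#include <c10/util/Exception.h>

#include <memory>

namespace {

struct LoadingNvfuserLibrary {
  LoadingNvfuserLibrary() {
#ifdef USE_CUDA
    try {
      // Resolving the .so at runtime means torch has no link-time
      // dependency on nvfuser; loading it runs nvfuser's own
      // registration code via its static initializers.
      nvfuser_lib_ =
          std::make_unique<at::DynamicLibrary>("libnvfuser_codegen.so");
    } catch (const c10::Error&) {
      // Library not found: nvfuser simply stays unregistered.
      TORCH_WARN_ONCE("Loading nvfuser library failed.");
    }
#endif // USE_CUDA
  }
  std::unique_ptr<at::DynamicLibrary> nvfuser_lib_;
};

// Static instance triggers the load when the hosting library is loaded.
static LoadingNvfuserLibrary loading_nvfuser_library_;

} // namespace
```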
@@ -473,6 +473,7 @@ test_libtorch() {
    ln -sf "$TORCH_LIB_DIR"/libshm* "$TORCH_BIN_DIR"
    ln -sf "$TORCH_LIB_DIR"/libtorch* "$TORCH_BIN_DIR"
    ln -sf "$TORCH_LIB_DIR"/libtbb* "$TORCH_BIN_DIR"
+   ln -sf "$TORCH_LIB_DIR"/libnvfuser* "$TORCH_BIN_DIR"

    # Start background download
    python tools/download_mnist.py --quiet -d test/cpp/api/mnist &
@@ -490,6 +491,7 @@ test_libtorch() {

  if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then
    "$TORCH_BIN_DIR"/test_jit --gtest_output=xml:$TEST_REPORTS_DIR/test_jit.xml
+   "$TORCH_BIN_DIR"/nvfuser_tests --gtest_output=xml:$TEST_REPORTS_DIR/nvfuser_tests.xml
  else
    "$TORCH_BIN_DIR"/test_jit --gtest_filter='-*CUDA' --gtest_output=xml:$TEST_REPORTS_DIR/test_jit.xml
  fi
@@ -138,7 +138,12 @@ python -c "import os, glob; os.system('python -mpip install --no-index --no-deps
if "%BUILD_ENVIRONMENT%"=="" (
  echo NOTE: To run `import torch`, please make sure to activate the conda environment by running `call %CONDA_PARENT_DIR%\Miniconda3\Scripts\activate.bat %CONDA_PARENT_DIR%\Miniconda3` in Command Prompt before running Git Bash.
) else (
-  7z a %TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torch %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torchgen %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\functorch && copy /Y "%TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z" "%PYTORCH_FINAL_PACKAGE_DIR%\"
+  if "%USE_CUDA%"=="1" (
+    7z a %TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torch %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torchgen %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\functorch %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\nvfuser && copy /Y "%TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z" "%PYTORCH_FINAL_PACKAGE_DIR%\"
+  ) else (
+    7z a %TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torch %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torchgen %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\functorch && copy /Y "%TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z" "%PYTORCH_FINAL_PACKAGE_DIR%\"
+  )

if errorlevel 1 exit /b
if not errorlevel 0 exit /b
BUILD.bazel (20 lines changed)
@@ -1573,25 +1573,7 @@ cc_library(
)

# torch
-py_binary(
-    name = "stringify_file",
-    srcs = ["torch/csrc/jit/codegen/cuda/tools/stringify_file.py"],
-)
-
-generated_nvfuser_hdrs = ["generated_" + hdr for hdr in libtorch_nvfuser_generated_headers]
-
-[
-    genrule(
-        name = name,
-        srcs = [src],
-        outs = ["nvfuser_resources/{}".format(hdr)],
-        cmd = "$(location :stringify_file) -i $< -o $@",
-        tools = [":stringify_file"],
-    )
-    for name, src, hdr in zip(generated_nvfuser_hdrs, libtorch_nvfuser_runtime_sources, libtorch_nvfuser_generated_headers)
-]
-
-torch_cuda_headers = glob(["torch/csrc/cuda/*.h"]) + generated_nvfuser_hdrs
+torch_cuda_headers = glob(["torch/csrc/cuda/*.h"])

cc_library(
    name = "torch_headers",
@@ -183,6 +183,9 @@ option(USE_TSAN "Use Thread Sanitizer" OFF)
option(USE_CUDA "Use CUDA" ON)
cmake_dependent_option(
    BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON "USE_CUDA AND LINUX AND BUILD_PYTHON" OFF)
+cmake_dependent_option(
+    BUILD_NVFUSER "Build NVFUSER" ON
+    "USE_CUDA OR USE_ROCM" OFF)
option(USE_FAST_NVCC "Use parallel NVCC build" OFF)
cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF)
option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF)
@@ -1156,6 +1159,14 @@ if(BUILD_JNI)
  add_subdirectory(android/pytorch_android)
endif()

+if(NOT USE_CUDA AND NOT USE_ROCM)
+  set(BUILD_NVFUSER OFF CACHE BOOL "BUILD nvfuser" FORCE)
+endif()
+
+if(BUILD_NVFUSER)
+  add_subdirectory(third_party/nvfuser)
+endif()
+
include(cmake/Summary.cmake)
caffe2_print_configuration_summary()
@@ -921,7 +921,7 @@ void codegenOutputQuery(

// TODO: another copy paste from jit, refactor so it's usable from both
// TODO: try making the CUcontext thread local to see if that improves performance - why is this slow?
-void __inline__ initializeCudaContext() {
+void initializeCudaContext() {
  // lazily construct context if non-existing yet;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  CUcontext pctx = nullptr;
@@ -1656,5 +1656,4 @@ void launch_jitted_pwise_function(
      nullptr));
}
-

} // at::cuda::jit
@@ -198,4 +198,6 @@ inline std::string typeName(ScalarType t) {
}
#undef TYPE_NAME_CASE

+TORCH_CUDA_CPP_API void initializeCudaContext();
+
}}} // namespace at::cuda::jit
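With `initializeCudaContext()` now exported from `jit_utils.h`, nvfuser's executor can drop its private copy (removed later in this diff) and call the shared helper instead. A hedged usage sketch, with an illustrative caller name that is not part of this commit:

```cpp
// Illustrative caller: code about to drive the CUDA driver API first
// ensures a primary context exists on the current device.
#include <ATen/native/cuda/jit_utils.h>

void launchJittedKernel(/* kernel arguments elided */) {
  // Lazily creates the CUDA context if none is current; a no-op otherwise.
  at::cuda::jit::initializeCudaContext();
  // ...cuModuleLoadData / cuLaunchKernel would follow here...
}
```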
@@ -18,34 +18,34 @@ GENERATED_LAZY_TS_CPP = [

# NVFuser runtime library
libtorch_nvfuser_runtime_sources = [
-    "torch/csrc/jit/codegen/cuda/runtime/array.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/array_rocm.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/bf16_support.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/bf16_support_rocm.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/block_reduction.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/block_sync_atomic.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/block_sync_default.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/block_sync_default_rocm.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/broadcast.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/fp16_support.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/fused_reduction.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/fused_welford_helper.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/fused_welford_impl.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/grid_broadcast.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/helpers.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/index_utils.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/memory.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/swizzle.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/tensor.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/tuple.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/type_traits.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/warp.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/warp_rocm.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/welford.cu",
+    "third_party/nvfuser/runtime/array.cu",
+    "third_party/nvfuser/runtime/array_rocm.cu",
+    "third_party/nvfuser/runtime/bf16_support.cu",
+    "third_party/nvfuser/runtime/bf16_support_rocm.cu",
+    "third_party/nvfuser/runtime/block_reduction.cu",
+    "third_party/nvfuser/runtime/block_sync_atomic.cu",
+    "third_party/nvfuser/runtime/block_sync_default.cu",
+    "third_party/nvfuser/runtime/block_sync_default_rocm.cu",
+    "third_party/nvfuser/runtime/broadcast.cu",
+    "third_party/nvfuser/runtime/fp16_support.cu",
+    "third_party/nvfuser/runtime/fused_reduction.cu",
+    "third_party/nvfuser/runtime/fused_welford_helper.cu",
+    "third_party/nvfuser/runtime/fused_welford_impl.cu",
+    "third_party/nvfuser/runtime/grid_broadcast.cu",
+    "third_party/nvfuser/runtime/grid_reduction.cu",
+    "third_party/nvfuser/runtime/grid_sync.cu",
+    "third_party/nvfuser/runtime/helpers.cu",
+    "third_party/nvfuser/runtime/index_utils.cu",
+    "third_party/nvfuser/runtime/memory.cu",
+    "third_party/nvfuser/runtime/random_numbers.cu",
+    "third_party/nvfuser/runtime/swizzle.cu",
+    "third_party/nvfuser/runtime/tensor.cu",
+    "third_party/nvfuser/runtime/tensorcore.cu",
+    "third_party/nvfuser/runtime/tuple.cu",
+    "third_party/nvfuser/runtime/type_traits.cu",
+    "third_party/nvfuser/runtime/warp.cu",
+    "third_party/nvfuser/runtime/warp_rocm.cu",
+    "third_party/nvfuser/runtime/welford.cu",
    "aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh",
    "aten/src/ATen/cuda/detail/UnpackRaw.cuh",
]
@@ -677,107 +677,6 @@ libtorch_cuda_core_sources = [
    "torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp",
    "torch/csrc/profiler/stubs/cuda.cpp",
    "torch/csrc/autograd/functions/comm.cpp",
-    "torch/csrc/jit/codegen/cuda/arith.cpp",
-    "torch/csrc/jit/codegen/cuda/compute_at.cpp",
-    "torch/csrc/jit/codegen/cuda/inlining.cpp",
-    "torch/csrc/jit/codegen/cuda/compute_at_map.cpp",
-    "torch/csrc/jit/codegen/cuda/codegen.cpp",
-    "torch/csrc/jit/codegen/cuda/contiguity.cpp",
-    "torch/csrc/jit/codegen/cuda/dispatch.cpp",
-    "torch/csrc/jit/codegen/cuda/expr_evaluator.cpp",
-    "torch/csrc/jit/codegen/cuda/executor.cpp",
-    "torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp",
-    "torch/csrc/jit/codegen/cuda/executor_launch_params.cpp",
-    "torch/csrc/jit/codegen/cuda/evaluator_common.cpp",
-    "torch/csrc/jit/codegen/cuda/executor_utils.cpp",
-    "torch/csrc/jit/codegen/cuda/fusion.cpp",
-    "torch/csrc/jit/codegen/cuda/graph_fuser.cpp",
-    "torch/csrc/jit/codegen/cuda/grouped_reduction.cpp",
-    "torch/csrc/jit/codegen/cuda/index_compute.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_index_compute.cpp",
-    "torch/csrc/jit/codegen/cuda/instrumentation.cpp",
-    "torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp",
-    "torch/csrc/jit/codegen/cuda/ir_builder.cpp",
-    "torch/csrc/jit/codegen/cuda/ir_cloner.cpp",
-    "torch/csrc/jit/codegen/cuda/ir_container.cpp",
-    "torch/csrc/jit/codegen/cuda/ir_graphviz.cpp",
-    "torch/csrc/jit/codegen/cuda/ir_nodes.cpp",
-    "torch/csrc/jit/codegen/cuda/ir_iostream.cpp",
-    "torch/csrc/jit/codegen/cuda/ir_utils.cpp",
-    "torch/csrc/jit/codegen/cuda/iter_visitor.cpp",
-    "torch/csrc/jit/codegen/cuda/kernel.cpp",
-    "torch/csrc/jit/codegen/cuda/kernel_cache.cpp",
-    "torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp",
-    "torch/csrc/jit/codegen/cuda/kernel_ir.cpp",
-    "torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_allocation.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_fused_reduction.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_index.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_instrument.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_loops.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_magic_zero.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_predicate.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_replace_size.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_shift.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_sync_information.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_trivial_reductions.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_unroll.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_utils.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_validation.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp",
-    "torch/csrc/jit/codegen/cuda/lower2device.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp",
-    "torch/csrc/jit/codegen/cuda/manager.cpp",
-    "torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp",
-    "torch/csrc/jit/codegen/cuda/mutator.cpp",
-    "torch/csrc/jit/codegen/cuda/non_divisible_split.cpp",
-    "torch/csrc/jit/codegen/cuda/ops/alias.cpp",
-    "torch/csrc/jit/codegen/cuda/ops/composite.cpp",
-    "torch/csrc/jit/codegen/cuda/ops/normalization.cpp",
-    "torch/csrc/jit/codegen/cuda/parallel_dimension_map.cpp",
-    "torch/csrc/jit/codegen/cuda/parallel_type_bitmap.cpp",
-    "torch/csrc/jit/codegen/cuda/parser.cpp",
-    "torch/csrc/jit/codegen/cuda/partial_split_map.cpp",
-    "torch/csrc/jit/codegen/cuda/partition.cpp",
-    "torch/csrc/jit/codegen/cuda/predicate_compute.cpp",
-    "torch/csrc/jit/codegen/cuda/python_frontend/fusion_cache.cpp",
-    "torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.cpp",
-    "torch/csrc/jit/codegen/cuda/python_frontend/fusion_interface.cpp",
-    "torch/csrc/jit/codegen/cuda/register_interface.cpp",
-    "torch/csrc/jit/codegen/cuda/root_domain_map.cpp",
-    "torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp",
-    "torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.cpp",
-    "torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp",
-    "torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp",
-    "torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp",
-    "torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp",
-    "torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp",
-    "torch/csrc/jit/codegen/cuda/scheduler/registry.cpp",
-    "torch/csrc/jit/codegen/cuda/scheduler/utils.cpp",
-    "torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp",
-    "torch/csrc/jit/codegen/cuda/type_inference.cpp",
-    "torch/csrc/jit/codegen/cuda/type_promotion.cpp",
-    "torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp",
-    "torch/csrc/jit/codegen/cuda/tensor_view.cpp",
-    "torch/csrc/jit/codegen/cuda/transform_iter.cpp",
-    "torch/csrc/jit/codegen/cuda/transform_replay.cpp",
-    "torch/csrc/jit/codegen/cuda/transform_rfactor.cpp",
-    "torch/csrc/jit/codegen/cuda/transform_view.cpp",
-    "torch/csrc/jit/codegen/cuda/type.cpp",
-    "torch/csrc/jit/codegen/cuda/utils.cpp",
-    "torch/csrc/jit/codegen/cuda/mma_type.cpp",
-    "torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp",
    "torch/csrc/jit/passes/frozen_conv_add_relu_fusion_cuda.cpp",
    "torch/csrc/jit/tensorexpr/cuda_codegen.cpp",
    "torch/csrc/jit/runtime/register_cuda_ops.cpp",
@@ -923,7 +822,6 @@ libtorch_python_core_sources = [
    "torch/csrc/dynamo/init.cpp",
    "torch/csrc/functorch/init.cpp",
    "torch/csrc/jit/backends/backend_init.cpp",
-    "torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp",
    "torch/csrc/jit/python/init.cpp",
    "torch/csrc/jit/passes/onnx.cpp",
    "torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp",
@@ -657,6 +657,7 @@ if(USE_CUDA)
      PROPERTIES COMPILE_DEFINITIONS "NVRTC_SHORTHASH=${CUDA_NVRTC_SHORTHASH}"
    )
    set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/passes/frozen_conv_add_relu_fusion.cpp PROPERTIES COMPILE_FLAGS "-DUSE_CUDA=1")
+   set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/interface.cpp PROPERTIES COMPILE_FLAGS "-DUSE_CUDA=1")
  endif()

  if(BUILD_ONEDNN_GRAPH)
@@ -978,10 +979,6 @@ elseif(USE_CUDA)
    endif()
  endif()

-if(USE_CUDA OR USE_ROCM)
-  include(${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/nvfuser.cmake)
-endif()
-
if(NOT MSVC AND USE_XNNPACK)
  TARGET_LINK_LIBRARIES(torch_cpu PRIVATE fxdiv)
endif()
@@ -120,6 +120,7 @@ function(caffe2_print_configuration_summary)
  if(${USE_ROCM})
    message(STATUS " ROCM_VERSION : ${ROCM_VERSION}")
  endif()
+ message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}")
  message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}")
  message(STATUS " USE_FBGEMM : ${USE_FBGEMM}")
  message(STATUS " USE_FAKELOWP : ${USE_FAKELOWP}")
nvfuser/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from . import _C
setup.py (29 lines changed)
@@ -547,6 +547,11 @@ class build_ext(setuptools.command.build_ext.build_ext):
        else:
            report('-- Not using ITT')

+        if cmake_cache_vars['BUILD_NVFUSER']:
+            report('-- Building nvfuser')
+        else:
+            report('-- Not Building nvfuser')
+
        # Do not use clang to compile extensions if `-fstack-clash-protection` is defined
        # in system CFLAGS
        c_flags = str(os.getenv('CFLAGS', ''))
@@ -636,6 +641,22 @@ class build_ext(setuptools.command.build_ext.build_ext):
                    os.makedirs(dst_dir)
                self.copy_file(src, dst)

+        # Copy nvfuser extension
+        for i, ext in enumerate(self.extensions):
+            if ext.name != "nvfuser._C":
+                continue
+            fullname = self.get_ext_fullname(ext.name)
+            filename = self.get_ext_filename(fullname)
+            fileext = os.path.splitext(filename)[1]
+            src = os.path.join(os.path.dirname(filename), "nvfuser" + fileext)
+            dst = os.path.join(os.path.realpath(self.build_lib), filename)
+            if os.path.exists(src):
+                report("Copying {} from {} to {}".format(ext.name, src, dst))
+                dst_dir = os.path.dirname(dst)
+                if not os.path.exists(dst_dir):
+                    os.makedirs(dst_dir)
+                self.copy_file(src, dst)
+
        setuptools.command.build_ext.build_ext.build_extensions(self)
@@ -894,6 +915,8 @@ def configure_extension_build():
        excludes.extend(['caffe2', 'caffe2.*'])
    if not cmake_cache_vars['BUILD_FUNCTORCH']:
        excludes.extend(['functorch', 'functorch.*'])
+   if not cmake_cache_vars['BUILD_NVFUSER']:
+       excludes.extend(['nvfuser', 'nvfuser.*'])
    packages = find_packages(exclude=excludes)
    C = Extension("torch._C",
                  libraries=main_libraries,
@@ -940,6 +963,12 @@ def configure_extension_build():
                name=str('functorch._C'),
                sources=[]),
        )
+   if cmake_cache_vars['BUILD_NVFUSER']:
+       extensions.append(
+           Extension(
+               name=str('nvfuser._C'),
+               sources=[]),
+       )

    cmdclass = {
        'bdist_wheel': wheel_concatenate,
@@ -95,23 +95,6 @@ set(JIT_TEST_SRCS
  ${JIT_TEST_ROOT}/test_flatbuffer.cpp
)

-if(USE_CUDA)
-  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_definition.cpp)
-  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_cache.cpp)
-  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_record.cpp)
-  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu1.cpp)
-  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu2.cpp)
-  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu3.cpp)
-  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_tensor_factories.cpp)
-  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp)
-  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp)
-  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp)
-  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_view.cpp)
-  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp)
-  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_rng.cu)
-  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_utils.cpp)
-endif()

add_executable(test_jit
  ${TORCH_ROOT}/test/cpp/common/main.cpp
  ${JIT_TEST_SRCS}
@@ -322,6 +322,24 @@ ALLOW_LIST = [
    ("aten::_fused_sdp_choice", datetime.date(2023, 3, 15)),
    ("aten::_flash_attention_forward", datetime.date(2023, 3, 15)),
    ("mkldnn::_convolution_pointwise.binary", datetime.date(2022, 12, 15)),
+    ("prim::CudaFusionIvalGuard", datetime.date(2023, 2, 1)),
+    ("prim::CudaFusionGuard", datetime.date(2023, 2, 1)),
+    ("prim::CudaFusionGroup", datetime.date(2023, 2, 1)),
+    ("prim::CudaFusionViewGuard", datetime.date(2023, 2, 1)),
+    ("prim::CudaFusionSizeEq", datetime.date(2023, 2, 1)),
+    ("prim::transpose_copy.int", datetime.date(2023, 2, 1)),
+    ("prim::expand_as_copy", datetime.date(2023, 2, 1)),
+    ("prim::squeeze_copy", datetime.date(2023, 2, 1)),
+    ("prim::squeeze_copy.dim", datetime.date(2023, 2, 1)),
+    ("prim::unsqueeze_copy", datetime.date(2023, 2, 1)),
+    ("prim::expand_copy", datetime.date(2023, 2, 1)),
+    ("prim::flatten_copy", datetime.date(2023, 2, 1)),
+    ("prim::add_optional", datetime.date(2023, 2, 1)),
+    ("prim::reshape_copy", datetime.date(2023, 2, 1)),
+    ("prim::permute_copy", datetime.date(2023, 2, 1)),
+    ("prim::infer_unsqueeze_size", datetime.date(2023, 2, 1)),
+    ("prim::t_copy", datetime.date(2023, 2, 1)),
+    ("prim::view_copy", datetime.date(2023, 2, 1)),
]

ALLOW_LIST_COMPILED = [
@@ -9,9 +9,11 @@ from torch.testing._internal.jit_utils import RUN_CUDA
import torch._refs as refs
import torch._prims as prims

-# Will only create the _nvfuser module if CUDA is available
-if hasattr(torch._C, "_nvfuser"):
-    from torch._C._nvfuser import Fusion, FusionCache, FusionDefinition, DataType
+# Will only create the nvfuser module if CUDA is available
+try:
+    from nvfuser._C import Fusion, FusionCache, FusionDefinition, DataType
+except ImportError:
+    pass

RUN_NVFUSER = RUN_CUDA and not TEST_WITH_ROCM
@@ -145,7 +145,7 @@ class TestPrims(TestCase):
        # This test is to ensure that when the nvfuser implementation exists it is used
        # Assuming one-to-one mapping between prims and nvfuser implementations
        # This test is not intended to test the correctness of the nvfuser implementation
-       from torch._C._nvfuser import FusionDefinition as fd
+       from nvfuser._C import FusionDefinition as fd

        prim_nvfuser_ops = set(torch._prims.__all__).intersection(dir(fd.ops))
        ops_without_nvfuser_impl = {
third_party/nvfuser/CMakeLists.txt (new file, 323 lines, vendored)
@@ -0,0 +1,323 @@
if(NOT BUILD_NVFUSER)
  return()
endif()

cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(nvfuser)

if(NOT USE_ROCM)
  set(TORCHLIB_FLAVOR torch_cuda)
else()
  set(TORCHLIB_FLAVOR torch_hip)
endif()

# --- project

file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/nvfuser")

set(NVFUSER_ROOT ${PROJECT_SOURCE_DIR})
set(NVFUSER_SRCS_DIR "${NVFUSER_ROOT}/csrc")
set(TORCH_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../..")
set(TORCH_INSTALL_LIB_DIR ${TORCH_ROOT}/torch/lib)

# --- build nvfuser_codegen library

set(NVFUSER_SRCS)
set(NVFUSER_CODEGEN ${PROJECT_NAME}_codegen)
list(APPEND NVFUSER_SRCS
  ${NVFUSER_SRCS_DIR}/arith.cpp
  ${NVFUSER_SRCS_DIR}/compute_at.cpp
  ${NVFUSER_SRCS_DIR}/inlining.cpp
  ${NVFUSER_SRCS_DIR}/compute_at_map.cpp
  ${NVFUSER_SRCS_DIR}/codegen.cpp
  ${NVFUSER_SRCS_DIR}/contiguity.cpp
  ${NVFUSER_SRCS_DIR}/dispatch.cpp
  ${NVFUSER_SRCS_DIR}/expr_evaluator.cpp
  ${NVFUSER_SRCS_DIR}/kernel_expr_evaluator.cpp
  ${NVFUSER_SRCS_DIR}/executor.cpp
  ${NVFUSER_SRCS_DIR}/executor_kernel_arg.cpp
  ${NVFUSER_SRCS_DIR}/executor_launch_params.cpp
  ${NVFUSER_SRCS_DIR}/evaluator_common.cpp
  ${NVFUSER_SRCS_DIR}/executor_utils.cpp
  ${NVFUSER_SRCS_DIR}/fusion.cpp
  ${NVFUSER_SRCS_DIR}/graph_fuser.cpp
  ${NVFUSER_SRCS_DIR}/grouped_reduction.cpp
  ${NVFUSER_SRCS_DIR}/index_compute.cpp
  ${NVFUSER_SRCS_DIR}/lower_index_compute.cpp
  ${NVFUSER_SRCS_DIR}/instrumentation.cpp
  ${NVFUSER_SRCS_DIR}/ir_base_nodes.cpp
  ${NVFUSER_SRCS_DIR}/ir_builder.cpp
  ${NVFUSER_SRCS_DIR}/ir_cloner.cpp
  ${NVFUSER_SRCS_DIR}/ir_container.cpp
  ${NVFUSER_SRCS_DIR}/ir_graphviz.cpp
  ${NVFUSER_SRCS_DIR}/ir_nodes.cpp
  ${NVFUSER_SRCS_DIR}/ir_iostream.cpp
  ${NVFUSER_SRCS_DIR}/ir_utils.cpp
  ${NVFUSER_SRCS_DIR}/iter_visitor.cpp
  ${NVFUSER_SRCS_DIR}/kernel.cpp
  ${NVFUSER_SRCS_DIR}/kernel_cache.cpp
  ${NVFUSER_SRCS_DIR}/kernel_ir.cpp
  ${NVFUSER_SRCS_DIR}/kernel_ir_dispatch.cpp
  ${NVFUSER_SRCS_DIR}/lower_alias_memory.cpp
  ${NVFUSER_SRCS_DIR}/lower_allocation.cpp
  ${NVFUSER_SRCS_DIR}/lower_double_buffer.cpp
  ${NVFUSER_SRCS_DIR}/lower_divisible_split.cpp
  ${NVFUSER_SRCS_DIR}/lower_expr_sort.cpp
  ${NVFUSER_SRCS_DIR}/lower_fused_reduction.cpp
  ${NVFUSER_SRCS_DIR}/lower_fusion_simplifier.cpp
  ${NVFUSER_SRCS_DIR}/lower_index.cpp
  ${NVFUSER_SRCS_DIR}/lower_index_hoist.cpp
  ${NVFUSER_SRCS_DIR}/lower_insert_syncs.cpp
  ${NVFUSER_SRCS_DIR}/lower_instrument.cpp
  ${NVFUSER_SRCS_DIR}/lower_loops.cpp
  ${NVFUSER_SRCS_DIR}/lower_magic_zero.cpp
  ${NVFUSER_SRCS_DIR}/lower_misaligned_vectorization.cpp
  ${NVFUSER_SRCS_DIR}/lower_predicate.cpp
  ${NVFUSER_SRCS_DIR}/lower_predicate_elimination.cpp
  ${NVFUSER_SRCS_DIR}/lower_replace_size.cpp
  ${NVFUSER_SRCS_DIR}/lower_shift.cpp
  ${NVFUSER_SRCS_DIR}/lower_sync_information.cpp
  ${NVFUSER_SRCS_DIR}/lower_thread_predicate.cpp
  ${NVFUSER_SRCS_DIR}/lower_trivial_broadcast.cpp
  ${NVFUSER_SRCS_DIR}/lower_trivial_reductions.cpp
  ${NVFUSER_SRCS_DIR}/lower_unroll.cpp
  ${NVFUSER_SRCS_DIR}/lower_utils.cpp
  ${NVFUSER_SRCS_DIR}/lower_validation.cpp
  ${NVFUSER_SRCS_DIR}/lower_warp_reduce.cpp
  ${NVFUSER_SRCS_DIR}/lower2device.cpp
  ${NVFUSER_SRCS_DIR}/lower_bank_conflict.cpp
  ${NVFUSER_SRCS_DIR}/manager.cpp
  ${NVFUSER_SRCS_DIR}/maxinfo_propagator.cpp
  ${NVFUSER_SRCS_DIR}/mutator.cpp
  ${NVFUSER_SRCS_DIR}/non_divisible_split.cpp
  ${NVFUSER_SRCS_DIR}/ops/alias.cpp
  ${NVFUSER_SRCS_DIR}/ops/composite.cpp
  ${NVFUSER_SRCS_DIR}/ops/normalization.cpp
  ${NVFUSER_SRCS_DIR}/parallel_dimension_map.cpp
  ${NVFUSER_SRCS_DIR}/parallel_type_bitmap.cpp
  ${NVFUSER_SRCS_DIR}/parser.cpp
  ${NVFUSER_SRCS_DIR}/partial_split_map.cpp
  ${NVFUSER_SRCS_DIR}/partition.cpp
  ${NVFUSER_SRCS_DIR}/predicate_compute.cpp
  ${NVFUSER_SRCS_DIR}/python_frontend/fusion_cache.cpp
  ${NVFUSER_SRCS_DIR}/python_frontend/fusion_definition.cpp
  ${NVFUSER_SRCS_DIR}/python_frontend/fusion_interface.cpp
  ${NVFUSER_SRCS_DIR}/register_interface.cpp
  ${NVFUSER_SRCS_DIR}/root_domain_map.cpp
  ${NVFUSER_SRCS_DIR}/scheduler/pointwise.cpp
  ${NVFUSER_SRCS_DIR}/scheduler/pointwise_utils.cpp
  ${NVFUSER_SRCS_DIR}/scheduler/transpose.cpp
  ${NVFUSER_SRCS_DIR}/scheduler/normalization.cpp
  ${NVFUSER_SRCS_DIR}/scheduler/reduction.cpp
  ${NVFUSER_SRCS_DIR}/scheduler/matmul.cpp
  ${NVFUSER_SRCS_DIR}/scheduler/reduction_utils.cpp
  ${NVFUSER_SRCS_DIR}/scheduler/registry.cpp
  ${NVFUSER_SRCS_DIR}/scheduler/utils.cpp
  ${NVFUSER_SRCS_DIR}/scheduler/vectorize_helper.cpp
  ${NVFUSER_SRCS_DIR}/type_inference.cpp
  ${NVFUSER_SRCS_DIR}/type_promotion.cpp
  ${NVFUSER_SRCS_DIR}/fusion_segmenter.cpp
  ${NVFUSER_SRCS_DIR}/tensor_view.cpp
  ${NVFUSER_SRCS_DIR}/transform_iter.cpp
  ${NVFUSER_SRCS_DIR}/transform_replay.cpp
  ${NVFUSER_SRCS_DIR}/transform_rfactor.cpp
  ${NVFUSER_SRCS_DIR}/transform_view.cpp
  ${NVFUSER_SRCS_DIR}/type.cpp
  ${NVFUSER_SRCS_DIR}/utils.cpp
  ${NVFUSER_SRCS_DIR}/mma_type.cpp
  ${NVFUSER_SRCS_DIR}/scheduler/mma_utils.cpp
)

add_library(${NVFUSER_CODEGEN} SHARED ${NVFUSER_SRCS})

if(NOT USE_ROCM)
  target_compile_options(${NVFUSER_CODEGEN} PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB")
  # NB: This must be target_compile_definitions, not target_compile_options,
  # as the latter is not respected by nvcc
  target_compile_definitions(${NVFUSER_CODEGEN} PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB")
else()
  target_compile_options(${NVFUSER_CODEGEN} PRIVATE "-DTORCH_HIP_BUILD_MAIN_LIB")
  target_compile_definitions(${NVFUSER_CODEGEN} PRIVATE "-DTORCH_HIP_BUILD_MAIN_LIB")
  target_compile_definitions(${NVFUSER_CODEGEN} PRIVATE
    USE_ROCM
    __HIP_PLATFORM_HCC__
  )
endif()

target_link_libraries(${NVFUSER_CODEGEN} PRIVATE torch ${TORCHLIB_FLAVOR})
if(NOT USE_ROCM)
  target_link_libraries(${NVFUSER_CODEGEN} PRIVATE ${CUDA_NVRTC_LIB} torch::nvtoolsext)
  target_include_directories(${NVFUSER_CODEGEN} PRIVATE ${CUDA_INCLUDE_DIRS})
else()
  target_link_libraries(${NVFUSER_CODEGEN} PRIVATE ${ROCM_HIPRTC_LIB})
  target_include_directories(${NVFUSER_CODEGEN} PRIVATE ${Caffe2_HIP_INCLUDE})
endif()
if(NOT MSVC)
  target_compile_options(${NVFUSER_CODEGEN} PRIVATE -Wno-unused-variable)
endif()
target_include_directories(${NVFUSER_CODEGEN}
  PUBLIC $<BUILD_INTERFACE:${NVFUSER_SRCS_DIR}>)
set_property(TARGET ${NVFUSER_CODEGEN} PROPERTY CXX_STANDARD 17)
install(TARGETS ${NVFUSER_CODEGEN} EXPORT NvfuserTargets DESTINATION "${TORCH_INSTALL_LIB_DIR}")

# --- build nvfuser_python library

if(BUILD_PYTHON)
  set(NVFUSER "${PROJECT_NAME}")
  #find_package(pybind11 REQUIRED)

  set(NVFUSER_PYTHON_SRCS)
  list(APPEND NVFUSER_PYTHON_SRCS
    ${NVFUSER_SRCS_DIR}/python_frontend/python_bindings.cpp
    ${NVFUSER_SRCS_DIR}/python_frontend/python_bindings_extension.cpp
  )

  add_library(${NVFUSER} MODULE ${NVFUSER_PYTHON_SRCS})
  if(NOT USE_ROCM)
    target_compile_options(${NVFUSER} PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB")
    # NB: This must be target_compile_definitions, not target_compile_options,
    # as the latter is not respected by nvcc
    target_compile_definitions(${NVFUSER} PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB")
    target_link_libraries(${NVFUSER} PRIVATE torch::nvtoolsext)
  else()
    target_compile_options(${NVFUSER} PRIVATE "-DTORCH_HIP_BUILD_MAIN_LIB")
    target_compile_definitions(${NVFUSER} PRIVATE "-DTORCH_HIP_BUILD_MAIN_LIB")
    target_compile_definitions(${NVFUSER} PRIVATE
      USE_ROCM
      __HIP_PLATFORM_HCC__
    )
    target_include_directories(${NVFUSER_CODEGEN} PRIVATE ${Caffe2_HIP_INCLUDE})
  endif()

  target_link_libraries(${NVFUSER} PRIVATE ${NVFUSER_CODEGEN})
  target_link_libraries(${NVFUSER} PRIVATE torch torch_python ${TORCHLIB_FLAVOR})
  target_link_libraries(${NVFUSER} PRIVATE pybind::pybind11)
  target_include_directories(${NVFUSER} PRIVATE ${TORCH_ROOT})
  target_compile_definitions(${NVFUSER} PRIVATE EXTENSION_NAME=_C)
  target_compile_options(${NVFUSER} PRIVATE ${TORCH_PYTHON_COMPILE_OPTIONS})

  # avoid using Python3_add_library, copied from functorch
  set_target_properties(${NVFUSER} PROPERTIES PREFIX "" DEBUG_POSTFIX "")
  if(NOT MSVC)
    target_compile_options(${NVFUSER} PRIVATE -Wno-unused-variable)
    set_target_properties(${NVFUSER} PROPERTIES SUFFIX ".so")
  else()
    set_target_properties(${NVFUSER} PROPERTIES SUFFIX ".pyd")
  endif()

  set_target_properties(${NVFUSER} PROPERTIES LIBRARY_OUTPUT_DIRECTORY
    ${CMAKE_BINARY_DIR}/nvfuser)
  set_target_properties(${NVFUSER} PROPERTIES INSTALL_RPATH "${_rpath_portable_origin}/../torch/lib")

  if(TORCH_PYTHON_LINK_FLAGS AND NOT TORCH_PYTHON_LINK_FLAGS STREQUAL "")
    message(STATUS "somehow this is happening")
    set_target_properties(${NVFUSER} PROPERTIES LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS})
  endif()
  install(TARGETS ${NVFUSER} EXPORT NvfuserTargets DESTINATION ${TORCH_ROOT}/nvfuser/)
endif()
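The `EXTENSION_NAME=_C` definition above is what lets the module target land as `nvfuser/_C.so` (or `.pyd` on Windows) and be imported as `nvfuser._C`. A minimal pybind11 entry point of the shape this build expects might look like the sketch below; the docstring and placeholder attribute are illustrative assumptions, not the actual contents of `python_frontend/python_bindings_extension.cpp`.

```cpp
// Hedged sketch of a pybind11 entry point compatible with the CMake above.
// EXTENSION_NAME expands to _C, so Python imports this as nvfuser._C.
#include <pybind11/pybind11.h>

namespace py = pybind11;

PYBIND11_MODULE(EXTENSION_NAME, m) {
  // The real module registers Fusion, FusionCache, FusionDefinition and
  // DataType here; this placeholder only demonstrates the import path.
  m.doc() = "nvfuser python frontend (illustrative stub)";
  m.attr("cuda_enabled") = true;  // illustrative attribute, not real API
}
```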
|
||||
# --- generate runtime files
|
||||
|
||||
# The list of NVFUSER runtime files
|
||||
list(APPEND NVFUSER_RUNTIME_FILES
|
||||
${NVFUSER_ROOT}/runtime/array.cu
|
||||
${NVFUSER_ROOT}/runtime/block_reduction.cu
|
||||
${NVFUSER_ROOT}/runtime/block_sync_atomic.cu
|
||||
${NVFUSER_ROOT}/runtime/block_sync_default.cu
|
||||
${NVFUSER_ROOT}/runtime/broadcast.cu
|
||||
${NVFUSER_ROOT}/runtime/fp16_support.cu
|
||||
${NVFUSER_ROOT}/runtime/fused_reduction.cu
|
||||
${NVFUSER_ROOT}/runtime/fused_welford_helper.cu
|
||||
${NVFUSER_ROOT}/runtime/fused_welford_impl.cu
|
||||
${NVFUSER_ROOT}/runtime/bf16_support.cu
|
||||
${NVFUSER_ROOT}/runtime/grid_broadcast.cu
|
||||
${NVFUSER_ROOT}/runtime/grid_reduction.cu
|
||||
${NVFUSER_ROOT}/runtime/grid_sync.cu
|
||||
${NVFUSER_ROOT}/runtime/helpers.cu
|
||||
${NVFUSER_ROOT}/runtime/index_utils.cu
|
||||
${NVFUSER_ROOT}/runtime/random_numbers.cu
|
||||
${NVFUSER_ROOT}/runtime/swizzle.cu
|
||||
${NVFUSER_ROOT}/runtime/tensor.cu
|
||||
${NVFUSER_ROOT}/runtime/tuple.cu
|
||||
${NVFUSER_ROOT}/runtime/type_traits.cu
|
||||
${NVFUSER_ROOT}/runtime/welford.cu
|
||||
${NVFUSER_ROOT}/runtime/warp.cu
|
||||
${NVFUSER_ROOT}/runtime/tensorcore.cu
|
||||
${NVFUSER_ROOT}/runtime/memory.cu
|
||||
${TORCH_ROOT}/aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh
|
||||
${TORCH_ROOT}/aten/src/ATen/cuda/detail/UnpackRaw.cuh
|
||||
)
|
||||
|
||||
if(USE_ROCM)
|
||||
list(APPEND NVFUSER_RUNTIME_FILES
|
||||
${NVFUSER_ROOT}/runtime/array_rocm.cu
|
||||
${NVFUSER_ROOT}/runtime/bf16_support_rocm.cu
|
||||
${NVFUSER_ROOT}/runtime/block_sync_default_rocm.cu
|
||||
${NVFUSER_ROOT}/runtime/warp_rocm.cu
|
||||
)
|
||||
endif()
|
||||
|
||||
file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/include/nvfuser_resources")
|
||||
|
||||
# "stringify" NVFUSER runtime sources
|
||||
# (generate C++ header files embedding the original input as a string literal)
|
||||
set(NVFUSER_STRINGIFY_TOOL "${NVFUSER_ROOT}/tools/stringify_file.py")
|
||||
foreach(src ${NVFUSER_RUNTIME_FILES})
|
||||
get_filename_component(filename ${src} NAME_WE)
|
||||
set(dst "${CMAKE_BINARY_DIR}/include/nvfuser_resources/${filename}.h")
|
||||
add_custom_command(
|
||||
COMMENT "Stringify NVFUSER runtime source file"
|
||||
OUTPUT ${dst}
|
||||
DEPENDS ${src} "${NVFUSER_STRINGIFY_TOOL}"
|
||||
COMMAND ${PYTHON_EXECUTABLE} ${NVFUSER_STRINGIFY_TOOL} -i ${src} -o ${dst}
|
||||
)
|
||||
add_custom_target(nvfuser_rt_${filename} DEPENDS ${dst})
|
||||
add_dependencies(${NVFUSER_CODEGEN} nvfuser_rt_${filename})
|
||||
|
||||
# also generate the resource headers during the configuration step
|
||||
# (so tools like clang-tidy can run w/o requiring a real build)
|
||||
execute_process(COMMAND
|
||||
${PYTHON_EXECUTABLE} ${NVFUSER_STRINGIFY_TOOL} -i ${src} -o ${dst})
|
||||
endforeach()
|
||||
|
||||
target_include_directories(${NVFUSER_CODEGEN} PRIVATE "${CMAKE_BINARY_DIR}/include")
|
||||
|
||||
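For context, the stringify step above turns each runtime `.cu` file into a header whose only content is the original source text as a string constant, so `libnvfuser_codegen` can hand that source to NVRTC when compiling kernels at runtime. A hedged sketch of the plausible shape of one generated header; the exact namespace and variable naming are assumptions:

```cpp
// Hypothetical shape of a generated header, e.g.
// ${CMAKE_BINARY_DIR}/include/nvfuser_resources/welford.h
#pragma once

namespace nvfuser_resources {

// The whole runtime/welford.cu source, embedded as a compile-time string
// that gets concatenated into the kernel source handed to NVRTC.
constexpr const char* welford_cu = R"RAW(
// ...verbatim contents of runtime/welford.cu...
)RAW";

} // namespace nvfuser_resources
```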
# -- build tests

if(USE_CUDA)
  set(NVFUSER_TESTS "${PROJECT_NAME}_tests")
  set(JIT_TEST_SRCS)
  list(APPEND JIT_TEST_SRCS ${NVFUSER_SRCS_DIR}/python_frontend/test/test_nvfuser_fusion_definition.cpp)
  list(APPEND JIT_TEST_SRCS ${NVFUSER_SRCS_DIR}/python_frontend/test/test_nvfuser_fusion_cache.cpp)
  list(APPEND JIT_TEST_SRCS ${NVFUSER_SRCS_DIR}/python_frontend/test/test_nvfuser_fusion_record.cpp)
  list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu1.cpp)
  list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu2.cpp)
  list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu3.cpp)
  list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_tensor_factories.cpp)
  list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_fused_reduction.cpp)
  list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_shift.cpp)
  list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_tensorcore.cpp)
  list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_view.cpp)
  list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_transpose.cpp)
  list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_rng.cu)
  list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_utils.cpp)

  add_executable(${NVFUSER_TESTS}
    ${TORCH_ROOT}/test/cpp/common/main.cpp
    ${TORCH_ROOT}/test/cpp/jit/test_utils.cpp
    ${JIT_TEST_SRCS})

  target_compile_definitions(${NVFUSER_TESTS} PRIVATE USE_GTEST)
  if(NOT USE_ROCM)
    target_compile_definitions(${NVFUSER_TESTS} PRIVATE USE_CUDA)
  else()
    target_compile_definitions(${NVFUSER_TESTS} PRIVATE USE_ROCM)
  endif()
  target_include_directories(${NVFUSER_TESTS} PRIVATE "${NVFUSER_ROOT}" "${TORCH_ROOT}/torch/csrc/api/include/")
  target_link_libraries(${NVFUSER_TESTS} PRIVATE ${NVFUSER_CODEGEN} torch ${TORCHLIB_FLAVOR} gtest_main gmock_main)
  if(NOT MSVC)
    target_compile_options(${NVFUSER_TESTS} PRIVATE -Wno-unused-variable)
  endif()

  install(TARGETS ${NVFUSER_TESTS} DESTINATION bin)
endif()
@ -1,15 +1,15 @@
|
||||
#include <torch/csrc/jit/codegen/cuda/arith.h>
|
||||
#include <arith.h>
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/type.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/type_promotion.h>
|
||||
#include <ir_all_nodes.h>
|
||||
#include <ir_builder.h>
|
||||
#include <ir_iostream.h>
|
||||
#include <ir_utils.h>
|
||||
#include <type.h>
|
||||
#include <type_promotion.h>
|
||||
#include <cfloat>
|
||||
|
||||
namespace torch {
|
||||
@ -2171,7 +2171,7 @@ TensorView* gather(
|
||||
return out_tv;
|
||||
}
|
||||
|
||||
TORCH_CUDA_CU_API TensorView* viewAsScalar(TensorView* inp) {
|
||||
TensorView* viewAsScalar(TensorView* inp) {
|
||||
auto inp_type = inp->getDataType().value();
|
||||
TORCH_CHECK(
|
||||
isVectorType(inp_type),
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/type.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/type_promotion.h>
|
||||
#include <ir_interface_nodes.h>
|
||||
#include <type.h>
|
||||
#include <type_promotion.h>
|
||||
|
||||
class Val;
|
||||
|
@ -1,12 +1,12 @@
|
||||
#include <torch/csrc/jit/codegen/cuda/codegen.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/type.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/utils.h>
|
||||
#include <codegen.h>
|
||||
#include <expr_evaluator.h>
|
||||
#include <instrumentation.h>
|
||||
#include <kernel_expr_evaluator.h>
|
||||
#include <kernel_ir.h>
|
||||
#include <kernel_ir_dispatch.h>
|
||||
#include <scheduler/mma_utils.h>
|
||||
#include <type.h>
|
||||
#include <utils.h>
|
||||
|
||||
#include <array>
|
||||
#include <cmath>
|
@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/kernel.h>
|
||||
#include <kernel.h>
|
||||
|
||||
#include <string>
|
||||
|
@ -1,11 +1,11 @@
|
||||
#include <torch/csrc/jit/codegen/cuda/compute_at.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
|
||||
#include <compute_at.h>
|
||||
#include <instrumentation.h>
|
||||
#include <ir_all_nodes.h>
|
||||
#include <ir_iostream.h>
|
||||
#include <ir_utils.h>
|
||||
#include <lower_utils.h>
|
||||
#include <root_domain_map.h>
|
||||
#include <transform_iter.h>
|
||||
|
||||
#include <c10/util/irange.h>
|
||||
|
@ -1,8 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/codegen/cuda/inlining.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
|
||||
#include <inlining.h>
|
||||
#include <root_domain_map.h>
|
||||
#include <transform_replay.h>
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
#include <c10/util/Exception.h>
|
@ -1,10 +1,10 @@
|
||||
#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
|
||||
#include <compute_at_map.h>
|
||||
|
||||
#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
|
||||
#include <disjoint_set.h>
|
||||
#include <ir_utils.h>
|
||||
#include <lower2device.h>
|
||||
#include <root_domain_map.h>
|
||||
#include <transform_iter.h>
|
||||
|
||||
#include <tuple>
|
||||
|
||||
@ -431,7 +431,7 @@ void IterDomainGraph::build(Fusion* fusion) {
|
||||
// might not be a compute at leaf domain of `p_tv`, but it actually
|
||||
// has an equivalent compute at leaf domain. For that case, we map
|
||||
// the equivalent compute at leaf domain.
|
||||
for (int i = 0; i < p_tv->getComputeAtPosition(); i++) {
|
||||
for (unsigned int i = 0; i < p_tv->getComputeAtPosition(); i++) {
|
||||
auto id = p_tv->axis(i);
|
||||
if (permissive_disjoint_sets.permissiveAreMapped(p_id, id)) {
|
||||
loop_nodes_.mapEntries(c_id, id);
|
@ -1,9 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
|
||||
#include <disjoint_set.h>
|
||||
#include <ir_all_nodes.h>
|
||||
#include <kernel_ir.h>
|
||||
#include <lower_trivial_reductions.h>
|
||||
|
||||
#include <deque>
|
||||
#include <unordered_map>
|
@ -1,8 +1,8 @@
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
|
||||
#include <ir_utils.h>
|
||||
#include <iter_visitor.h>
|
||||
#include <lower2device.h>
|
||||
|
||||
#include <torch/csrc/jit/codegen/cuda/contiguity.h>
|
||||
#include <contiguity.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
@ -135,7 +135,7 @@ void OrderedIdInformation::handle(Merge* merge) {
|
||||
// Update maps
|
||||
// Find the position inner would have to have to be considered ordered
|
||||
auto pos_after_outer = outer_pos + 1;
|
||||
for (; pos_after_outer < active_ids_.size(); pos_after_outer++) {
|
||||
for (; pos_after_outer < int64_t(active_ids_.size()); pos_after_outer++) {
|
||||
if (active_ids_[pos_after_outer] == nullptr) {
|
||||
// Can't be considered ordered after a nullptr
|
||||
break;
|
@ -2,11 +2,11 @@
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
|
||||
#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h>
|
||||
#include <compute_at_map.h>
|
||||
#include <disjoint_set.h>
|
||||
#include <ir_all_nodes.h>
|
||||
#include <lower_shift.h>
|
||||
#include <lower_trivial_broadcast.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
@ -9,7 +9,7 @@
|
||||
#include <vector>
|
||||
|
||||
// For printing of the set when using a Statement as the type for the set
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
|
||||
#include <ir_base_nodes.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
@ -1,8 +1,8 @@
|
||||
#include <torch/csrc/jit/codegen/cuda/fusion.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/type.h>
|
||||
#include <fusion.h>
|
||||
#include <ir_all_nodes.h>
|
||||
#include <type.h>
|
||||
|
||||
#include <torch/csrc/jit/codegen/cuda/dispatch.h>
|
||||
#include <dispatch.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
@ -3,7 +3,7 @@
|
||||
#include <c10/macros/Export.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
#include <torch/csrc/jit/codegen/cuda/utils.h>
|
||||
#include <utils.h>
|
||||
|
||||
#include <unordered_map>
|
||||
|
Binary file not shown.
Before Width: | Height: | Size: 94 KiB After Width: | Height: | Size: 94 KiB |
@ -1,10 +1,10 @@
|
||||
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
|
||||
#include <expr_evaluator.h>
|
||||
#include <instrumentation.h>
|
||||
#include <ir_utils.h>
|
||||
#include <kernel_expr_evaluator.h>
|
||||
#include <lower2device.h>
|
||||
|
||||
#include <torch/csrc/jit/codegen/cuda/evaluator_common.h>
|
||||
#include <evaluator_common.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
@ -1,10 +1,10 @@
|
||||
#pragma once
|
||||
#include <torch/csrc/jit/codegen/cuda/dynamic_type.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/executor_kernel_arg.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/fusion.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
|
||||
#include <dynamic_type.h>
|
||||
#include <executor_kernel_arg.h>
|
||||
#include <executor_launch_params.h>
|
||||
#include <fusion.h>
|
||||
#include <ir_all_nodes.h>
|
||||
#include <lower2device.h>
|
||||
|
||||
#include <c10/core/DeviceType.h>
|
||||
|
@ -1,21 +1,22 @@
|
||||
|
||||
#include <torch/csrc/jit/codegen/cuda/executor.h>
|
||||
#include <executor.h>
|
||||
|
||||
#include <torch/csrc/jit/codegen/cuda/codegen.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/executor_kernel_arg.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower_bank_conflict.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/utils.h>
|
||||
#include <codegen.h>
|
||||
#include <executor_kernel_arg.h>
|
||||
#include <executor_utils.h>
|
||||
#include <instrumentation.h>
|
||||
#include <ir_all_nodes.h>
|
||||
#include <ir_utils.h>
|
||||
#include <iter_visitor.h>
|
||||
#include <kernel_ir.h>
|
||||
#include <lower_bank_conflict.h>
|
||||
#include <utils.h>
|
||||
|
||||
#include <ATen/core/LegacyTypeDispatch.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/llvm_jit_strings.h>
|
||||
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
|
||||
#include <ATen/native/cuda/jit_utils.h>
|
||||
#include <c10/core/DeviceGuard.h>
|
||||
#include <c10/cuda/CUDAFunctions.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
@ -877,7 +878,7 @@ KernelArgumentHolder FusionExecutor::inferOutputSizes(
|
||||
executor_entry = &executor_entry_lookup_[*opt_code];
|
||||
}
|
||||
|
||||
executor_utils::initializeCudaContext();
|
||||
at::cuda::jit::initializeCudaContext();
|
||||
TORCH_INTERNAL_ASSERT(lowered_);
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
@ -975,7 +976,7 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
|
||||
|
||||
c10::DeviceGuard dg(options_.device);
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
executor_utils::initializeCudaContext();
|
||||
at::cuda::jit::initializeCudaContext();
|
||||
TORCH_INTERNAL_ASSERT(lowered_);
|
||||
launch_params_ = LaunchParams();
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
@ -1258,7 +1259,7 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
|
||||
|
||||
if (execute_kernel_) {
|
||||
if (maybe_available_dynamic_smem_.has_value() &&
|
||||
launch_params_.smem() > maybe_available_dynamic_smem_.value()) {
|
||||
size_t(launch_params_.smem()) > maybe_available_dynamic_smem_.value()) {
|
||||
#ifndef USE_ROCM
|
||||
// Increase limit of dynamic shared memory if needed.
|
||||
AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuFuncSetAttribute(
|
@ -1,13 +1,13 @@
|
||||
#pragma once
|
||||
#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/fusion.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_printer.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/utils.h>
|
||||
#include <executor_launch_params.h>
|
||||
#include <executor_utils.h>
|
||||
#include <fusion.h>
|
||||
#include <ir_all_nodes.h>
|
||||
#include <ir_cloner.h>
|
||||
#include <ir_printer.h>
|
||||
#include <kernel_expr_evaluator.h>
|
||||
#include <lower2device.h>
|
||||
#include <utils.h>
|
||||
|
||||
#include <c10/core/DeviceType.h>
|
||||
|
||||
@ -261,7 +261,7 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable {
|
||||
// See:
|
||||
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-7-x
|
||||
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-8-x
|
||||
const int max_static_smem_ = 48 << 10;
|
||||
const uint64_t max_static_smem_ = 48 << 10;
|
||||
int warp_size_ = 0;
|
||||
executor_utils::NvrtcFunction compiled_kernel_;
|
||||
|
@ -1,9 +1,9 @@
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
// Extract size and strides
|
||||
#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
|
||||
#include <kernel_cache.h>
|
||||
|
||||
#include <torch/csrc/jit/codegen/cuda/executor_kernel_arg.h>
|
||||
#include <executor_kernel_arg.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
@ -3,7 +3,7 @@
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/type.h>
|
||||
#include <type.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <array>
|
||||
|
@ -1,4 +1,4 @@
|
||||
#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
|
||||
#include <executor_launch_params.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
@ -1,5 +1,5 @@
|
||||
#pragma once
|
||||
#include <torch/csrc/jit/codegen/cuda/type.h>
|
||||
#include <type.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
@ -1,15 +1,16 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
||||
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
|
||||
#include <ATen/native/cuda/jit_utils.h>
|
||||
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
#include <torch/csrc/jit/codegen/cuda/contiguity.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
|
||||
#include <contiguity.h>
|
||||
#include <executor_utils.h>
|
||||
#include <instrumentation.h>
|
||||
#include <ir_all_nodes.h>
|
||||
#include <ir_iostream.h>
|
||||
#include <ir_utils.h>
|
||||
#include <torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h>
|
||||
#include <torch/csrc/jit/resource_guard.h>
|
||||
|
||||
@ -926,18 +927,6 @@ ExpressionEvaluator bindFusionInputs(
|
||||
return expr_eval;
|
||||
}
|
||||
|
||||
void initializeCudaContext() {
|
||||
// lazily construct context if non-existing yet;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
CUcontext pctx = nullptr;
|
||||
AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx));
|
||||
if (!pctx) {
|
||||
std::unique_lock<std::mutex> cudaFreeMutexLock(
|
||||
*(c10::cuda::getFreeMutex()));
|
||||
C10_CUDA_CHECK(cudaFree(nullptr));
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Dump PTX or CUBIN to a file
|
||||
@ -979,7 +968,7 @@ std::pair<NvrtcFunction, std::string> nvrtcCompile(
|
||||
"NVFuser Compile: arch check disabled, should not compile any kernel");
|
||||
}
|
||||
|
||||
initializeCudaContext();
|
||||
at::cuda::jit::initializeCudaContext();
|
||||
|
||||
std::stringstream ptxas_log;
|
||||
|
@@ -9,13 +9,13 @@

 #include <torch/csrc/jit/ir/ir.h>

-#include <torch/csrc/jit/codegen/cuda/executor_kernel_arg.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
+#include <executor_kernel_arg.h>
+#include <expr_evaluator.h>
+#include <fusion.h>
+#include <ir_all_nodes.h>
+#include <kernel.h>
+#include <kernel_expr_evaluator.h>
+#include <lower2device.h>

 #include <string>
 #include <vector>
@@ -54,8 +54,6 @@ struct NvrtcFunction {
   CUfunction function = CUfunction();
 };

-void initializeCudaContext();
-
 // Returns executable function and the ptxas log from compilation
 std::pair<NvrtcFunction, std::string> nvrtcCompile(
     const std::string& code,
@@ -1,10 +1,10 @@

-#include <torch/csrc/jit/codegen/cuda/evaluator_common.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
+#include <evaluator_common.h>
+#include <expr_evaluator.h>
+#include <fusion.h>
+#include <instrumentation.h>
+#include <ir_all_nodes.h>
+#include <ir_iostream.h>

 #include <iostream>

@@ -1,9 +1,9 @@
 #pragma once

 #include <c10/macros/Export.h>
-#include <torch/csrc/jit/codegen/cuda/dynamic_type.h>
-#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
+#include <dynamic_type.h>
+#include <ir_interface_nodes.h>
+#include <iter_visitor.h>

 #include <c10/util/Optional.h>
@@ -1,17 +1,17 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/codegen.h>
-#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
-#include <torch/csrc/jit/codegen/cuda/ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/kernel.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_bank_conflict.h>
+#include <arith.h>
+#include <codegen.h>
+#include <disjoint_set.h>
+#include <fusion.h>
+#include <fusion_segmenter.h>
+#include <instrumentation.h>
+#include <ir_all_nodes.h>
+#include <ir_cloner.h>
+#include <ir_printer.h>
+#include <ir_utils.h>
+#include <iter_visitor.h>
+#include <kernel.h>
+#include <lower2device.h>
+#include <lower_bank_conflict.h>

 namespace torch {
 namespace jit {
@@ -4,9 +4,9 @@
 #include <c10/macros/Export.h>
 #include <c10/util/Exception.h>

-#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_container.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
+#include <ir_base_nodes.h>
+#include <ir_container.h>
+#include <iter_visitor.h>

 #include <unordered_map>
 #include <unordered_set>
@@ -1,13 +1,13 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
-#include <torch/csrc/jit/codegen/cuda/ir_graphviz.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/debug_utils.h>
+#include <arith.h>
+#include <fusion.h>
+#include <fusion_segmenter.h>
+#include <instrumentation.h>
+#include <ir_all_nodes.h>
+#include <ir_cloner.h>
+#include <ir_graphviz.h>
+#include <ir_iostream.h>
+#include <ir_utils.h>
+#include <scheduler/debug_utils.h>

 #include <sstream>

@@ -1,11 +1,11 @@
 #pragma once

-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
-#include <torch/csrc/jit/codegen/cuda/utils.h>
+#include <fusion.h>
+#include <ir_base_nodes.h>
+#include <kernel_cache.h>
+#include <scheduler/all_schedulers.h>
+#include <scheduler/registry.h>
+#include <utils.h>

 #include <deque>
 #include <list>
@@ -2,12 +2,11 @@

 #include <c10/util/Exception.h>
 #include <c10/util/irange.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/interface.h>
-#include <torch/csrc/jit/codegen/cuda/parser.h>
-#include <torch/csrc/jit/codegen/cuda/partition.h>
-#include <torch/csrc/jit/codegen/cuda/transform_view.h>
-#include <torch/csrc/jit/codegen/cuda/utils.h>
+#include <instrumentation.h>
+#include <parser.h>
+#include <partition.h>
+#include <transform_view.h>
+#include <utils.h>
 #include <torch/csrc/jit/frontend/ir_emitter.h>
 #include <torch/csrc/jit/ir/alias_analysis.h>
 #include <torch/csrc/jit/jit_log.h>
@@ -1,9 +1,9 @@
-#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
+#include <ir_builder.h>
+#include <ir_utils.h>
+#include <root_domain_map.h>
+#include <transform_iter.h>

-#include <torch/csrc/jit/codegen/cuda/grouped_reduction.h>
+#include <grouped_reduction.h>

 namespace torch {
 namespace jit {
@@ -1,6 +1,6 @@
 #pragma once

-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <ir_all_nodes.h>

 namespace torch {
 namespace jit {
@@ -1,26 +1,26 @@
-#include <torch/csrc/jit/codegen/cuda/index_compute.h>
+#include <index_compute.h>

 #include <c10/util/Exception.h>
 #include <c10/util/irange.h>
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/contiguity.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_double_buffer.h>
-#include <torch/csrc/jit/codegen/cuda/lower_index_compute.h>
-#include <torch/csrc/jit/codegen/cuda/lower_magic_zero.h>
-#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
-#include <torch/csrc/jit/codegen/cuda/lower_unroll.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower_validation.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
-#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
+#include <arith.h>
+#include <contiguity.h>
+#include <expr_evaluator.h>
+#include <instrumentation.h>
+#include <ir_all_nodes.h>
+#include <ir_iostream.h>
+#include <ir_utils.h>
+#include <kernel_expr_evaluator.h>
+#include <lower2device.h>
+#include <lower_double_buffer.h>
+#include <lower_index_compute.h>
+#include <lower_magic_zero.h>
+#include <lower_shift.h>
+#include <lower_unroll.h>
+#include <lower_utils.h>
+#include <lower_validation.h>
+#include <root_domain_map.h>
+#include <transform_iter.h>
+#include <transform_replay.h>

 namespace torch {
 namespace jit {
@@ -1,7 +1,7 @@
 #pragma once

-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
+#include <iter_visitor.h>
+#include <root_domain_map.h>

 #include <unordered_map>
 #include <unordered_set>
@@ -1,7 +1,7 @@
-#include <torch/csrc/jit/codegen/cuda/inlining.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
+#include <inlining.h>
+#include <ir_utils.h>
+#include <root_domain_map.h>
+#include <transform_iter.h>

 #include <utility>
@@ -210,7 +210,7 @@ FindMappedPositions::FindMappedPositions(
     reference_pos += int64_t(reference->nDims()) + 1;
   }
   TORCH_CHECK(
-      reference_pos >= 0 && reference_pos <= reference->nDims(),
+      reference_pos >= 0 && reference_pos <= int64_t(reference->nDims()),
       "Invalid axis received ",
       reference_pos,
       " but should be > -",
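The hunk above fixes a signed/unsigned mismatch: reference_pos is a signed position (negative values count from the end), while nDims() returns an unsigned size, so the upper bound is cast to int64_t. A minimal sketch of the normalization-plus-check idiom; normalizeAxis is a hypothetical helper, not an nvfuser function:

    #include <cassert>
    #include <cstdint>

    // A negative position counts from the end, and the unsigned dimension
    // count is cast to int64_t so the comparison is signed on both sides.
    int64_t normalizeAxis(int64_t pos, size_t ndims) {
      if (pos < 0) {
        pos += int64_t(ndims) + 1; // e.g. -1 maps to ndims
      }
      assert(pos >= 0 && pos <= int64_t(ndims));
      return pos;
    }

    int main() {
      assert(normalizeAxis(-1, 3) == 3);
      assert(normalizeAxis(2, 3) == 2);
      return 0;
    }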
@@ -1,8 +1,8 @@
 #pragma once

-#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/maxinfo_propagator.h>
-#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
+#include <ir_interface_nodes.h>
+#include <maxinfo_propagator.h>
+#include <transform_replay.h>

 #include <memory>
 #include <unordered_set>
@@ -1,4 +1,4 @@
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
+#include <instrumentation.h>

 #include <c10/macros/Export.h>

@@ -1,6 +1,6 @@
 #pragma once

-#include <torch/csrc/jit/codegen/cuda/utils.h>
+#include <utils.h>

 #include <nvToolsExt.h>
 third_party/nvfuser/csrc/ir_all_nodes.h (new vendored file, 8 lines)
@@ -0,0 +1,8 @@
+#pragma once
+
+#include <ir_base_nodes.h>
+#include <ir_interface_nodes.h>
+#include <ir_internal_nodes.h>
+
+// TODO: remove this once the Kernel IR split is complete
+#include <kernel_ir.h>
@@ -1,14 +1,14 @@
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
-#include <torch/csrc/jit/codegen/cuda/ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/kernel.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/mutator.h>
+#include <dispatch.h>
+#include <expr_evaluator.h>
+#include <fusion.h>
+#include <ir_all_nodes.h>
+#include <ir_builder.h>
+#include <ir_cloner.h>
+#include <ir_printer.h>
+#include <kernel.h>
+#include <kernel_ir.h>
+#include <kernel_ir_dispatch.h>
+#include <mutator.h>

 #include <torch/csrc/jit/ir/ir.h>
@@ -5,8 +5,8 @@
 #include <c10/util/Exception.h>
 #include <c10/util/Optional.h>

-#include <torch/csrc/jit/codegen/cuda/type.h>
-#include <torch/csrc/jit/codegen/cuda/utils.h>
+#include <type.h>
+#include <utils.h>

 #include <cstdint>
 #include <iostream>
@@ -1,7 +1,7 @@
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
-#include <torch/csrc/jit/codegen/cuda/kernel.h>
+#include <fusion.h>
+#include <ir_builder.h>
+#include <ir_cloner.h>
+#include <kernel.h>

 namespace torch {
 namespace jit {
@@ -1,8 +1,8 @@
 #pragma once

-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_container.h>
+#include <fusion.h>
+#include <ir_all_nodes.h>
+#include <ir_container.h>

 namespace torch {
 namespace jit {
@@ -1,8 +1,8 @@
-#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
+#include <ir_cloner.h>

-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
+#include <fusion.h>
+#include <ir_all_nodes.h>
+#include <ir_builder.h>

 namespace torch {
 namespace jit {
@@ -1,8 +1,8 @@
 #pragma once

 #include <c10/macros/Export.h>
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
+#include <dispatch.h>
+#include <ir_builder.h>

 #include <unordered_map>
 #include <vector>
@@ -1,7 +1,7 @@
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
-#include <torch/csrc/jit/codegen/cuda/ir_container.h>
+#include <instrumentation.h>
+#include <ir_builder.h>
+#include <ir_cloner.h>
+#include <ir_container.h>

 namespace torch {
 namespace jit {
@@ -2,8 +2,8 @@

 #include <c10/macros/Export.h>

-#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/utils.h>
+#include <ir_base_nodes.h>
+#include <utils.h>

 #include <deque>
 #include <unordered_map>
@@ -1,9 +1,9 @@
-#include <torch/csrc/jit/codegen/cuda/ir_graphviz.h>
+#include <ir_graphviz.h>

-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/type.h>
+#include <fusion.h>
+#include <ir_all_nodes.h>
+#include <ir_builder.h>
+#include <type.h>

 #include <fstream>
@@ -1,7 +1,7 @@
 #pragma once

 #include <c10/macros/Export.h>
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
+#include <dispatch.h>

 #include <sstream>
 #include <string>
@@ -2,10 +2,10 @@

 #include <c10/macros/Export.h>

-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_internal_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/mma_type.h>
+#include <fusion.h>
+#include <ir_base_nodes.h>
+#include <ir_internal_nodes.h>
+#include <mma_type.h>

 #include <torch/csrc/jit/ir/ir.h>
@@ -2,10 +2,10 @@

 #include <c10/macros/Export.h>

-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/mma_type.h>
-#include <torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h>
+#include <fusion.h>
+#include <ir_base_nodes.h>
+#include <mma_type.h>
+#include <parallel_type_bitmap.h>

 //! Nodes in here should generally not be used by users. They should be behind
 //! the scenes and users shouldn't have to be aware of what they do to use the
@@ -1,12 +1,12 @@
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_printer.h>
+#include <ir_iostream.h>
+#include <ir_printer.h>

-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
+#include <fusion.h>
+#include <instrumentation.h>
+#include <ir_all_nodes.h>
+#include <ir_utils.h>
+#include <kernel.h>
+#include <lower_utils.h>

 #include <c10/util/irange.h>
@@ -2,7 +2,7 @@

 #include <c10/macros/Export.h>

-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
+#include <dispatch.h>

 #include <c10/util/irange.h>
@@ -1,16 +1,16 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
-#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
-#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
-#include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>
-#include <torch/csrc/jit/codegen/cuda/transform_view.h>
+#include <arith.h>
+#include <disjoint_set.h>
+#include <ir_cloner.h>
+#include <ir_interface_nodes.h>
+#include <ir_iostream.h>
+#include <ir_utils.h>
+#include <kernel.h>
+#include <kernel_ir.h>
+#include <lower2device.h>
+#include <root_domain_map.h>
+#include <transform_iter.h>
+#include <transform_rfactor.h>
+#include <transform_view.h>

 #include <c10/util/irange.h>
@@ -2560,17 +2560,19 @@ TensorDomain* TensorDomain::flatten(int64_t start_dim, int64_t end_dim) {
     end_dim += inp_domain.size();
   }
   TORCH_CHECK(
-      start_dim >= 0 && start_dim < inp_domain.size(),
+      start_dim >= 0 && start_dim < int64_t(inp_domain.size()),
       "Invalid start_dim ",
       start_dim);
   TORCH_CHECK(
-      end_dim >= 0 && end_dim < inp_domain.size(), "Invalid end_dim ", end_dim);
+      end_dim >= 0 && end_dim < int64_t(inp_domain.size()),
+      "Invalid end_dim ",
+      end_dim);
   TORCH_CHECK(start_dim <= end_dim, "start_dim must be <= end_dim");

   std::vector<IterDomain*> new_root_domain;
   new_root_domain.reserve(inp_domain.size());
   for (auto i : c10::irange(inp_domain.size())) {
-    bool is_rfactor_dim = i >= start_dim && i <= end_dim;
+    bool is_rfactor_dim = i >= size_t(start_dim) && i <= size_t(end_dim);
     auto inp_id = inp_domain[i];
     auto out_id = IterDomainBuilder(inp_id)
                       .is_rfactor_domain(is_rfactor_dim)
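This hunk applies the same sign hygiene inside flatten(): the bound checks cast the unsigned container size to int64_t, and once start_dim and end_dim are known to be non-negative, the loop casts them to size_t so they compare cleanly against the unsigned loop index. A small self-contained sketch of the loop-side cast, with hypothetical values:

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    int main() {
      const int64_t start_dim = 1;
      const int64_t end_dim = 2; // both already validated as non-negative
      for (size_t i = 0; i < 4; ++i) {
        // Casting the validated signed bounds to size_t keeps the comparison
        // against the unsigned loop index warning-free and well-defined.
        bool in_flatten_range = i >= size_t(start_dim) && i <= size_t(end_dim);
        assert(in_flatten_range == (i == 1 || i == 2));
      }
      return 0;
    }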
@@ -2,8 +2,8 @@

 #include <c10/macros/Export.h>

-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
+#include <ir_iostream.h>
+#include <iter_visitor.h>

 #include <iostream>

@@ -1,9 +1,9 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
+#include <arith.h>
+#include <fusion.h>
+#include <ir_builder.h>
+#include <ir_iostream.h>
+#include <ir_utils.h>
+#include <lower_utils.h>

 #include <set>
@@ -569,7 +569,7 @@ std::vector<T*> uniqueEntries(const std::vector<T*>& tv_deuqe) {
 } // namespace

 // Return immediate producers of val
-TORCH_CUDA_CU_API std::vector<Val*> producerValsOf(Val* val) {
+std::vector<Val*> producerValsOf(Val* val) {
   if (val->definition() == nullptr) {
     return {};
   }
@@ -578,7 +578,7 @@ TORCH_CUDA_CU_API std::vector<Val*> producerValsOf(Val* val) {
 }

 // Return immediate consumers of val
-TORCH_CUDA_CU_API std::vector<Val*> consumerValsOf(Val* val) {
+std::vector<Val*> consumerValsOf(Val* val) {
   std::vector<Val*> consumer_vals;
   for (auto use_expr : val->uses()) {
     auto outputs = use_expr->outputs();
@@ -588,7 +588,7 @@ TORCH_CUDA_CU_API std::vector<Val*> consumerValsOf(Val* val) {
 }

 // Return immediate siblings of val
-TORCH_CUDA_CU_API std::vector<Val*> siblingValsOf(Val* val) {
+std::vector<Val*> siblingValsOf(Val* val) {
   std::vector<Val*> sibling_vals;
   auto def = val->definition();
   if (def != nullptr) {
@@ -604,8 +604,7 @@ TORCH_CUDA_CU_API std::vector<Val*> siblingValsOf(Val* val) {
 }

 // Return immediate producers of val
-TORCH_CUDA_CU_API std::vector<Val*> producerValsOf(
-    const std::vector<Val*>& vals) {
+std::vector<Val*> producerValsOf(const std::vector<Val*>& vals) {
   std::vector<Val*> all_producer_vals;
   for (auto val : vals) {
     auto producer_vals = producerValsOf(val);
@@ -617,8 +616,7 @@ TORCH_CUDA_CU_API std::vector<Val*> producerValsOf(
 }

 // Return immediate consumers of val
-TORCH_CUDA_CU_API std::vector<Val*> consumerValsOf(
-    const std::vector<Val*>& vals) {
+std::vector<Val*> consumerValsOf(const std::vector<Val*>& vals) {
   std::vector<Val*> all_consumer_vals;
   for (auto val : vals) {
     auto consumer_vals = consumerValsOf(val);
@@ -641,7 +639,7 @@ std::vector<TensorView*> consumerTvsOf(TensorView* tv) {
   return {consumer_tvs.begin(), consumer_tvs.end()};
 }

-TORCH_CUDA_CU_API std::vector<TensorView*> siblingTvsOf(TensorView* tv) {
+std::vector<TensorView*> siblingTvsOf(TensorView* tv) {
   auto sibling_vals = siblingValsOf(tv);
   auto sibling_tvs = ir_utils::filterByType<TensorView>(sibling_vals);
   return {sibling_tvs.begin(), sibling_tvs.end()};
@@ -879,7 +877,7 @@ bool isReductionTvOp(const Expr* expr) {
   return ir_utils::isTvOp(expr) && isReductionOp(expr);
 }

-TORCH_CUDA_CU_API std::vector<ViewOp*> getViewOps(Fusion* fusion) {
+std::vector<ViewOp*> getViewOps(Fusion* fusion) {
   auto all_exprs = fusion->exprs();

   auto all_view_ops = ir_utils::filterByType<ViewOp>(all_exprs);
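The hunks above remove TORCH_CUDA_CU_API from out-of-line definitions; the export annotation stays on the declarations in the header, where the compiler picks it up for the definitions too. A minimal sketch of that declaration/definition split, with DEMO_API as a simplified stand-in for the real macro:

    // export_demo.h: the declaration carries the visibility annotation.
    #if defined(_WIN32)
    #define DEMO_API __declspec(dllexport)
    #else
    #define DEMO_API __attribute__((visibility("default")))
    #endif

    DEMO_API int answer();

    // export_demo.cpp: the definition inherits the attribute from the
    // declaration, so repeating DEMO_API here (as the old code did) is
    // redundant rather than wrong.
    int answer() {
      return 42;
    }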
@@ -1,7 +1,7 @@
 #pragma once

-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/type.h>
+#include <ir_all_nodes.h>
+#include <type.h>

 #include <iterator>
 #include <unordered_map>
@@ -1,10 +1,10 @@
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
+#include <iter_visitor.h>

-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/type.h>
+#include <fusion.h>
+#include <ir_all_nodes.h>
+#include <ir_iostream.h>
+#include <ir_utils.h>
+#include <type.h>

 namespace torch {
 namespace jit {
@@ -2,8 +2,8 @@

 #include <c10/macros/Export.h>

-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/type.h>
+#include <dispatch.h>
+#include <type.h>

 #include <deque>
 #include <unordered_set>
@@ -1,9 +1,9 @@
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/kernel.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
+#include <instrumentation.h>
+#include <ir_iostream.h>
+#include <kernel.h>
+#include <kernel_expr_evaluator.h>
+#include <kernel_ir_dispatch.h>
+#include <lower2device.h>

 #include <ATen/cuda/CUDAContext.h>
@@ -2,14 +2,14 @@

 #include <c10/macros/Export.h>

-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/lower_sync_information.h>
-#include <torch/csrc/jit/codegen/cuda/lower_warp_reduce.h>
-#include <torch/csrc/jit/codegen/cuda/parallel_dimension_map.h>
-#include <torch/csrc/jit/codegen/cuda/utils.h>
-#include <torch/csrc/jit/codegen/cuda/vectorization_info.h>
+#include <fusion.h>
+#include <ir_base_nodes.h>
+#include <ir_builder.h>
+#include <lower_sync_information.h>
+#include <lower_warp_reduce.h>
+#include <parallel_dimension_map.h>
+#include <utils.h>
+#include <vectorization_info.h>

 #include <memory>
 #include <unordered_map>
@@ -1,10 +1,10 @@
-#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
+#include <kernel_cache.h>

-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/parser.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/debug_utils.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
+#include <instrumentation.h>
+#include <ir_utils.h>
+#include <parser.h>
+#include <scheduler/debug_utils.h>
+#include <scheduler/registry.h>
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/runtime/graph_executor.h>
@@ -209,7 +209,7 @@ std::vector<at::Tensor> FusionExecutorCache::runFusionWithInputs(
   // permute output tensor returned by kernel execution. See Part_3 in Note [
   // Permutation support in nvfuser ]
   for (const auto& pair : fusion_->getPermutationOutputMap()) {
-    if (pair.first < outputs.size()) {
+    if (size_t(pair.first) < outputs.size()) {
       outputs[pair.first] = outputs[pair.first].permute(pair.second);
     }
  }
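The guarded permute above restores each kernel output's original dimension order from a per-output permutation map, with a size_t cast keeping the index comparison sign-correct. A minimal sketch with a plain shape vector standing in for at::Tensor and a hypothetical map:

    #include <cassert>
    #include <cstdint>
    #include <map>
    #include <vector>

    using Dims = std::vector<int64_t>;

    // Reorder a shape: dimension i of the result comes from perm[i].
    Dims permute(const Dims& shape, const std::vector<int64_t>& perm) {
      Dims out(shape.size());
      for (size_t i = 0; i < perm.size(); ++i) {
        out[i] = shape[perm[i]];
      }
      return out;
    }

    int main() {
      std::vector<Dims> outputs = {{2, 3, 4}};  // one kernel output's shape
      std::map<int64_t, std::vector<int64_t>> perm_map = {{0, {2, 0, 1}}};
      for (const auto& pair : perm_map) {
        if (size_t(pair.first) < outputs.size()) {  // the sign-safe guard
          outputs[pair.first] = permute(outputs[pair.first], pair.second);
        }
      }
      assert((outputs[0] == Dims{4, 2, 3}));
      return 0;
    }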
@@ -1,11 +1,11 @@
 #pragma once

-#include <torch/csrc/jit/codegen/cuda/evaluator_common.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
+#include <evaluator_common.h>
+#include <executor.h>
+#include <fusion.h>
+#include <fusion_segmenter.h>
+#include <scheduler/all_schedulers.h>
+#include <scheduler/registry.h>

 #include <c10/macros/Export.h>
 #include <c10/util/ArrayRef.h>
@@ -1,6 +1,6 @@

-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
+#include <instrumentation.h>
+#include <kernel_expr_evaluator.h>

 #include <iostream>
@@ -3,10 +3,10 @@

 #include <c10/macros/Export.h>

-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/dynamic_type.h>
-#include <torch/csrc/jit/codegen/cuda/evaluator_common.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
+#include <dispatch.h>
+#include <dynamic_type.h>
+#include <evaluator_common.h>
+#include <kernel_ir.h>

 #include <c10/util/Optional.h>
@@ -1,10 +1,10 @@
-#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/type.h>
+#include <ir_builder.h>
+#include <kernel.h>
+#include <kernel_expr_evaluator.h>
+#include <kernel_ir.h>
+#include <lower2device.h>
+#include <lower_utils.h>
+#include <type.h>

 #include <iostream>
@@ -1,10 +1,10 @@
 #pragma once

-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h>
-#include <torch/csrc/jit/codegen/cuda/type.h>
-#include <torch/csrc/jit/codegen/cuda/utils.h>
+#include <ir_all_nodes.h>
+#include <ir_base_nodes.h>
+#include <parallel_type_bitmap.h>
+#include <type.h>
+#include <utils.h>

 #include <c10/macros/Export.h>
 #include <c10/util/Optional.h>
@@ -1,5 +1,5 @@
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h>
+#include <kernel_ir.h>
+#include <kernel_ir_dispatch.h>

 namespace torch {
 namespace jit {
@@ -1,6 +1,6 @@
 #pragma once

-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
+#include <dispatch.h>

 namespace torch {
 namespace jit {
@@ -1,31 +1,31 @@
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
+#include <lower2device.h>

 #include <ATen/cuda/CUDAContext.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower_alias_memory.h>
-#include <torch/csrc/jit/codegen/cuda/lower_allocation.h>
-#include <torch/csrc/jit/codegen/cuda/lower_divisible_split.h>
-#include <torch/csrc/jit/codegen/cuda/lower_double_buffer.h>
-#include <torch/csrc/jit/codegen/cuda/lower_expr_sort.h>
-#include <torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.h>
-#include <torch/csrc/jit/codegen/cuda/lower_index.h>
-#include <torch/csrc/jit/codegen/cuda/lower_insert_syncs.h>
-#include <torch/csrc/jit/codegen/cuda/lower_instrument.h>
-#include <torch/csrc/jit/codegen/cuda/lower_loops.h>
-#include <torch/csrc/jit/codegen/cuda/lower_magic_zero.h>
-#include <torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h>
-#include <torch/csrc/jit/codegen/cuda/lower_predicate.h>
-#include <torch/csrc/jit/codegen/cuda/lower_replace_size.h>
-#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
-#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
-#include <torch/csrc/jit/codegen/cuda/lower_unroll.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower_validation.h>
-#include <torch/csrc/jit/codegen/cuda/lower_warp_reduce.h>
+#include <expr_evaluator.h>
+#include <fusion.h>
+#include <instrumentation.h>
+#include <ir_iostream.h>
+#include <ir_utils.h>
+#include <lower_alias_memory.h>
+#include <lower_allocation.h>
+#include <lower_divisible_split.h>
+#include <lower_double_buffer.h>
+#include <lower_expr_sort.h>
+#include <lower_fusion_simplifier.h>
+#include <lower_index.h>
+#include <lower_insert_syncs.h>
+#include <lower_instrument.h>
+#include <lower_loops.h>
+#include <lower_magic_zero.h>
+#include <lower_misaligned_vectorization.h>
+#include <lower_predicate.h>
+#include <lower_replace_size.h>
+#include <lower_shift.h>
+#include <lower_trivial_reductions.h>
+#include <lower_unroll.h>
+#include <lower_utils.h>
+#include <lower_validation.h>
+#include <lower_warp_reduce.h>

 #include <list>
 #include <unordered_map>
@@ -2,27 +2,27 @@

 #include <c10/macros/Export.h>

-#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/lower_allocation.h>
-#include <torch/csrc/jit/codegen/cuda/lower_double_buffer.h>
-#include <torch/csrc/jit/codegen/cuda/lower_fused_reduction.h>
-#include <torch/csrc/jit/codegen/cuda/lower_index_hoist.h>
-#include <torch/csrc/jit/codegen/cuda/lower_predicate.h>
-#include <torch/csrc/jit/codegen/cuda/lower_predicate_elimination.h>
-#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
-#include <torch/csrc/jit/codegen/cuda/lower_sync_information.h>
-#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
-#include <torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h>
-#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
-#include <torch/csrc/jit/codegen/cuda/lower_warp_reduce.h>
-#include <torch/csrc/jit/codegen/cuda/non_divisible_split.h>
-#include <torch/csrc/jit/codegen/cuda/parallel_dimension_map.h>
-#include <torch/csrc/jit/codegen/cuda/partial_split_map.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-#include <torch/csrc/jit/codegen/cuda/vectorization_info.h>
+#include <compute_at_map.h>
+#include <ir_all_nodes.h>
+#include <kernel.h>
+#include <kernel_ir.h>
+#include <lower_allocation.h>
+#include <lower_double_buffer.h>
+#include <lower_fused_reduction.h>
+#include <lower_index_hoist.h>
+#include <lower_predicate.h>
+#include <lower_predicate_elimination.h>
+#include <lower_shift.h>
+#include <lower_sync_information.h>
+#include <lower_thread_predicate.h>
+#include <lower_trivial_broadcast.h>
+#include <lower_trivial_reductions.h>
+#include <lower_warp_reduce.h>
+#include <non_divisible_split.h>
+#include <parallel_dimension_map.h>
+#include <partial_split_map.h>
+#include <root_domain_map.h>
+#include <vectorization_info.h>

 #include <memory>
 #include <ostream>
@@ -1,12 +1,12 @@
-#include <torch/csrc/jit/codegen/cuda/lower_alias_memory.h>
+#include <lower_alias_memory.h>

-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
+#include <instrumentation.h>
+#include <ir_iostream.h>
+#include <ir_utils.h>
+#include <kernel_expr_evaluator.h>
+#include <kernel_ir.h>
+#include <lower2device.h>
+#include <lower_utils.h>

 #include <sstream>
 #include <unordered_map>
@@ -2,8 +2,8 @@

 #include <c10/macros/Export.h>

-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <dispatch.h>
+#include <ir_all_nodes.h>

 #include <vector>
@@ -1,10 +1,10 @@
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_allocation.h>
+#include <instrumentation.h>
+#include <ir_iostream.h>
+#include <kernel_expr_evaluator.h>
+#include <kernel_ir.h>
+#include <kernel_ir_dispatch.h>
+#include <lower2device.h>
+#include <lower_allocation.h>

 #include <unordered_set>
@@ -2,8 +2,8 @@

 #include <c10/macros/Export.h>

-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
+#include <ir_all_nodes.h>
+#include <kernel_ir.h>

 #include <vector>
@@ -1,10 +1,10 @@
-#include <torch/csrc/jit/codegen/cuda/lower_bank_conflict.h>
+#include <lower_bank_conflict.h>

-#include <torch/csrc/jit/codegen/cuda/dynamic_type.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/type.h>
+#include <dynamic_type.h>
+#include <expr_evaluator.h>
+#include <kernel_ir.h>
+#include <kernel_ir_dispatch.h>
+#include <type.h>

 #include <unordered_set>
@@ -1,9 +1,9 @@
 #pragma once

-#include <torch/csrc/jit/codegen/cuda/dynamic_type.h>
-#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
-#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel.h>
+#include <dynamic_type.h>
+#include <executor_launch_params.h>
+#include <ir_base_nodes.h>
+#include <kernel.h>

 #include <unordered_map>
 #include <utility>
@@ -1,8 +1,8 @@

-#include <torch/csrc/jit/codegen/cuda/lower_divisible_split.h>
+#include <lower_divisible_split.h>

-#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
+#include <disjoint_set.h>
+#include <ir_utils.h>

 #include <unordered_set>
Some files were not shown because too many files have changed in this diff.