Supports `torch.utils.cpp_extension.load_inline` on Windows with ROCm.
Tested on Windows with gfx1201.
Note that it currently only works when CC and CXX are set to `clang-cl`. This is also needed when building extensions via. `setuptools` due to linker errors when using `cl` directly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162577
Approved by: https://github.com/ezyang
Users can specify the following to get a libtorch_free `.so`.
"aot_inductor.use_libtorch": False,
The following config is only used for torchnative (see https://github.com/meta-pytorch/torchnative/pull/110). It's not intended to be used by executorch. The reason we need it for torchnative is because a lot of the symbol definitions in torchnative repo is only in header files.
"aot_inductor.libtorch_free_header": "/data/users/shangdiy/torchnative/standalone,/data/users/shangdiy/torchnative/" (or their custom headers)
The main motivating use case is for executorch to produce a libtorch free `.so`.
TODO for follow-up PR: this flag should be consolidated with the `compile_standalone` flag.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162655
Approved by: https://github.com/angelayi
With the new change we only log the warning if we're running non distributed code or if we're in rank 0. Unit testing that certain messages get printed on certain ranks only feels kinda jank so test plan is below instead
Test plan
```python
# torchrun --nproc_per_node=2 demo_fix.py
import os
import logging
logging.getLogger('torch.utils.cpp_extension').setLevel(logging.DEBUG)
import torch
if 'RANK' in os.environ:
torch.distributed.init_process_group('nccl')
from torch.utils.cpp_extension import _get_cuda_arch_flags
_get_cuda_arch_flags()
print(f"Rank {os.environ.get('RANK', '0')} done")
```
Logs showing how how `TORCH_CUDA_ARCH_LIST`only shows up once if we explicitly set the the logging level to `logging.DEBUG`. It also improves the debug message to explain what the actual behavior will be
```
(source) [marksaroufim@devgpu005]~% torchrun --nproc_per_node=2 demo_fix.py
W0911 18:30:16.594000 1315439 /home/marksaroufim/pytorch/torch/distributed/run.py:814]
W0911 18:30:16.594000 1315439 /home/marksaroufim/pytorch/torch/distributed/run.py:814] *****************************************
W0911 18:30:16.594000 1315439 /home/marksaroufim/pytorch/torch/distributed/run.py:814] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0911 18:30:16.594000 1315439 /home/marksaroufim/pytorch/torch/distributed/run.py:814] *****************************************
[rank0]:V0911 18:30:18.921000 1316753 pytorch/torch/utils/cpp_extension.py:2444] TORCH_CUDA_ARCH_LIST is not set, using TORCH_CUDA_ARCH_LIST='10.0+PTX' for visible GPU architectures. Set os.environ['TORCH_CUDA_ARCH_LIST'] to override.
Rank 0 done
Rank 1 done
```
But if we just use the default and comment out `logging.getLogger('torch.utils.cpp_extension').setLevel(logging.DEBUG)`
Then we get
```
(source) [marksaroufim@devgpu005]~% torchrun --nproc_per_node=2 demo_fix.py
W0911 18:14:33.926000 690759 /home/marksaroufim/pytorch/torch/distributed/run.py:814]
W0911 18:14:33.926000 690759 /home/marksaroufim/pytorch/torch/distributed/run.py:814] *****************************************
W0911 18:14:33.926000 690759 /home/marksaroufim/pytorch/torch/distributed/run.py:814] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0911 18:14:33.926000 690759 /home/marksaroufim/pytorch/torch/distributed/run.py:814] *****************************************
Rank 0 done
Rank 1 done
(source) [marksaroufim@devgpu005]~%
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162764
Approved by: https://github.com/ezyang, https://github.com/zou3519
This is a similar change to https://github.com/pytorch/pytorch/pull/153986, this time adding flags to the hipcc command under `cpp_extension.py`.
The `-Wno-ignored-attributes` flag in particular avoids about 200MB of warning spam when building torchvision, like these:
```
In file included from D:\b\vision_main\torchvision\csrc\ops\hip\deform_conv2d_kernel.hip:72:
In file included from D:\projects\TheRock\external-builds\pytorch\.venv\Lib\site-packages\torch\include\ATen/ATen.h:13:
In file included from D:\projects\TheRock\external-builds\pytorch\.venv\Lib\site-packages\torch\include\ATen/Functions.h:386:
In file included from D:\projects\TheRock\external-builds\pytorch\.venv\Lib\site-packages\torch\include\ATen/ops/_sparse_softmax.h:21:
D:\projects\TheRock\external-builds\pytorch\.venv\Lib\site-packages\torch\include\ATen/ops/_sparse_softmax_ops.h:18:8: warning: __declspec attribute 'dllimport' is not supported [-Wignored-attributes]
18 | struct TORCH_API _sparse_softmax_int {
| ^~~~~~~~~
D:\projects\TheRock\external-builds\pytorch\.venv\Lib\site-packages\torch\include\torch/headeronly/macros/Export.h💯19: note: expanded from macro 'TORCH_API'
100 | #define TORCH_API C10_IMPORT
| ^~~~~~~~~~
D:\projects\TheRock\external-builds\pytorch\.venv\Lib\site-packages\torch\include\torch/headeronly/macros/Export.h:53:31: note: expanded from macro 'C10_IMPORT'
53 | #define C10_IMPORT __declspec(dllimport)
| ^~~~~~~~~
```
The `-fms-extensions` flag just seems beneficial to include: https://clang.llvm.org/docs/MSVCCompatibility.html.
See also this downstream issue where these changes were tested: https://github.com/ROCm/TheRock/issues/910.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159790
Approved by: https://github.com/jeffdaily
With the legacy driver (nvgpu) used for CUDA 12.9, Thor was operating with SM 10.1.
This changes to SM 11.0 when the newer driver model (OpenRM), which is intended for CUDA 13.0, is introduced.
Thor 10.1 --> 11.0
Spark 12.1
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156176
Approved by: https://github.com/ezyang
Previeusly, if users want to let pytorch determine the cuda arch when jit loading cuda extensions, they should left environment variable `TORCH_CUDA_ARCH_LIST` empty, but which will raise an warning. This commit add an option to set `TORCH_CUDA_ARCH_LIST=native`, to tell pytorch users want to use native cuda arch intentionally.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156923
Approved by: https://github.com/ezyang
When CC and CXX compiler is set to clang, and clang was compiled with libc++, compilation of torchvision fails with:
```
File "/usr/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 585, in build_extensions
compiler_name, compiler_version = self._check_abi()
^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 1034, in _check_abi
_, version = get_compiler_abi_compatibility_and_version(compiler)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 449, in get_compiler_abi_compatibility_and_version
if tuple(map(int, version)) >= minimum_required_version:
^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: invalid literal for int() with base 10: '7+libcxx'
```
Compiler identification is a valid semantic version:
```
$ clang -dumpfullversion -dumpversion
20.1.7+libcxx
```
After adjusting parser of version, clang is able to compile extensions successfully.
Fixes#157665
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157666
Approved by: https://github.com/msaroufim
Fixes https://github.com/pytorch/pytorch/issues/156815
As far as testing goes
* I tried to use cuobjdump but that was kinda goofy bccd9393a5 the problem was that the name of the cubin will have a single gencode always
* Another idea was to read stderr and check that the right amount of gencodes is there 0beadc01b3 this helped a lot to convince me locally that this test works, the test passed on my dev gpu but was failing in CI and I suspect it's because of a bad interaction with subprocesses
* Last approach was to have a simpler unit test to check which flags get added by default, this is not as comprehensive as the previous ideas but it works and is fast so will opt for this since I'm convinced testing is working per my own experiments and customers
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156850
Approved by: https://github.com/malfet
Commit fixes AOT compilation in sycl cpp extension which got accidentally dropped on aca2c99a652 (fallback to JIT compilation had happened). Commit also fixes override logic for default sycl targets allowing flexibility to specify targets externally. Further, commit extends test coverage to cover such a case and fixes issue in the test where consequent tests executed same (first) compiled extension due to name conflicts.
Fixes: #156249
Fixes: aca2c99a652 ("xpu: get xpu arch flags at runtime in cpp_extensions (#152192)")
CC: @pengxin99, @guangyey
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156364
Approved by: https://github.com/ezyang
Today `cpp_extensions` makes heavy use of printing to stderr, this makes our life harder in KernelBot where we typically rely on stderr to only surface real errors but instead today cpp_extensions leverages stderr for updates that could be qualified as INFO, WARNING, ERROR
Now instead we'll recommend users of our cpp extension system to do something like
```python
import logging
cpp_ext_logger = logging.getLogger("torch.utils.cpp_extension")
cpp_ext_logger.setLevel(logging.WARNING)
```
While this dramatically reduces log spew, it can be viewed as a BC breaking change if people were relying on certain strings being present in stdout or stderr
Considering different teams might want to silence errors differently, this PR proposes replacing all `print()` statements with `logging` statements with the same heuristics that the python logging module recommends
1. DEBUG: For things like detailed compilation steps or reading filepaths - by default gets logged on stdout
2. INFO: Build progress - by default gets logged on stdout
3. WARNING: Surfacing issues that might cause bad performance or slow compilation times - by default gets logged on stdout
4. ERROR: Problems that prevent proper functioning - by default gets logged on stdout
Note that warnings.warn is a different library and is not hooked up to the python logging module by default
So the goal of this PR is to make it possible for teams to set the logging that is most appropriate to them. One annoying thing is logger throws ruff errors if you try to use it in conjunction with f strings or .format so have to use old school %s
An unrelated improvement I'd be happy to push to a seperate PR is adding support for "native" in `TORCH_CUDA_ARCH_LIST` which would just pick the ARCH for the current device
An example of what's in stderr today
```
Using /root/.cache/torch_extensions/py311_cu124 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py311_cu124/grayscale/build.ninja...
/usr/local/lib/python3.11/site-packages/torch/utils/cpp_extension.py:2059: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
Building extension module grayscale...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module grayscale...
/usr/local/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py:679: UserWarning: Graph break due to unsupported builtin grayscale.PyCapsule.grayscale. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.
torch._dynamo.utils.warn_once(msg)
```
Whereas after this PR users can do
`python benchmark_load_inline.py > >(tee stdout.txt) 2> >(tee stderr.txt >&2)`
```python
import os
import sys
from pathlib import Path
import shutil
import tempfile
import torch
from torch.utils.cpp_extension import load_inline
import logging
cpp_ext_logger = logging.getLogger("torch.utils.cpp_extension")
cpp_ext_logger.setLevel(logging.WARNING)
os.environ["TORCH_CUDA_ARCH_LIST"] = "native"
cpp_code = """
torch::Tensor to_gray(torch::Tensor input);
"""
cuda_kernel_code = """
torch::Tensor to_gray(torch::Tensor input) {
auto output = torch::epty({input.size(0), input.size(1)}, input.options());
return output ;
}
"""
# Avoid caching results
with tempfile.TemporaryDirectory() as build_dir:
cuda_module = load_inline(
name="to_gray_cuda",
cpp_sources=cpp_code,
cuda_sources=cuda_kernel_code,
functions=["to_gray"],
with_cuda=True,
verbose=True,
extra_cflags=["-std=c++17"], # "-ftime-report", "-H"],
extra_cuda_cflags=["-arch=sm_89"],
build_directory=build_dir,
)
```
## New logs
### On failure
Which gives a much more reasonable stdout
```
[1/3] /usr/local/cuda-12.8/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -DTORCH_EXTENSION_NAME=to_gray_cuda -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1016\" -isystem /home/marksaroufim/pytorch/torch/include -isystem /home/marksaroufim/pytorch/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-12.8/include -isystem /usr/local/cuda/targets/x86_64-linux/include -isystem /home/marksaroufim/.conda/envs/nv/include/python3.10 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -arch=sm_89 -std=c++17 -c /tmp/tmpbg_xzv0r/cuda.cu -o cuda.cuda.o
FAILED: cuda.cuda.o
/usr/local/cuda-12.8/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -DTORCH_EXTENSION_NAME=to_gray_cuda -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1016\" -isystem /home/marksaroufim/pytorch/torch/include -isystem /home/marksaroufim/pytorch/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-12.8/include -isystem /usr/local/cuda/targets/x86_64-linux/include -isystem /home/marksaroufim/.conda/envs/nv/include/python3.10 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -arch=sm_89 -std=c++17 -c /tmp/tmpbg_xzv0r/cuda.cu -o cuda.cuda.o
/tmp/tmpbg_xzv0r/cuda.cu(6): error: namespace "torch" has no member "epty"
auto output = torch::epty({input.size(0), input.size(1)}, input.options());
^
1 error detected in the compilation of "/tmp/tmpbg_xzv0r/cuda.cu".
[2/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=to_gray_cuda -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1016\" -isystem /home/marksaroufim/pytorch/torch/include -isystem /home/marksaroufim/pytorch/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-12.8/include -isystem /usr/local/cuda/targets/x86_64-linux/include -isystem /home/marksaroufim/.conda/envs/nv/include/python3.10 -fPIC -std=c++17 -std=c++17 -c /tmp/tmpbg_xzv0r/main.cpp -o main.o
ninja: build stopped: subcommand failed.
```
And stderr
```
Traceback (most recent call last):
File "/home/marksaroufim/pytorch/torch/utils/cpp_extension.py", line 2874, in _run_ninja_build
subprocess.run(
File "/home/marksaroufim/.conda/envs/nv/lib/python3.10/subprocess.py", line 526, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/marksaroufim/load_inline_slow/benchmark_load_inline.py", line 30, in <module>
cuda_module = load_inline(
File "/home/marksaroufim/pytorch/torch/utils/cpp_extension.py", line 2261, in load_inline
return _jit_compile(
File "/home/marksaroufim/pytorch/torch/utils/cpp_extension.py", line 2367, in _jit_compile
_write_ninja_file_and_build_library(
File "/home/marksaroufim/pytorch/torch/utils/cpp_extension.py", line 2528, in _write_ninja_file_and_build_library
_run_ninja_build(
File "/home/marksaroufim/pytorch/torch/utils/cpp_extension.py", line 2892, in _run_ninja_build
raise RuntimeError(message) from e
RuntimeError: Error building extension 'to_gray_cuda'
```
### On success
stdout
```
[1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=to_gray_cuda -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1016\" -isystem /home/marksaroufim/pytorch/torch/include -isystem /home/marksaroufim/pytorch/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-12.8/include -isystem /usr/local/cuda/targets/x86_64-linux/include -isystem /home/marksaroufim/.conda/envs/nv/include/python3.10 -fPIC -std=c++17 -std=c++17 -c /tmp/tmpxv_ovlrf/main.cpp -o main.o
[2/3] /usr/local/cuda-12.8/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -DTORCH_EXTENSION_NAME=to_gray_cuda -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1016\" -isystem /home/marksaroufim/pytorch/torch/include -isystem /home/marksaroufim/pytorch/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-12.8/include -isystem /usr/local/cuda/targets/x86_64-linux/include -isystem /home/marksaroufim/.conda/envs/nv/include/python3.10 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -arch=sm_89 -std=c++17 -c /tmp/tmpxv_ovlrf/cuda.cu -o cuda.cuda.o
[3/3] c++ main.o cuda.cuda.o -shared -L/home/marksaroufim/pytorch/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda-12.8/lib64 -lcudart -o to_gray_cuda.so
```
And an empty stderr as expected
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152260
Approved by: https://github.com/albanD
In this approach, we are catching any lane within a wave that is doing fastatomics to the same destination address and computing the sum on the CU. This is leading to 3x improvement in scatter_add performance and 2x improvement in index_select.
scatter_add performance on MI300x:
dtype|Baseline (before optimizations)|opportunistic fastatomics
-------|----------------------------------|----------------------------------
f32|1.389425039|0.430447996
fp16|2.195472956|0.779729486
bf16|2.194051027|0.784599513
Using the following reproducer
```
import torch
import triton
def main():
dtype = torch.float32
dim = 1305301
a = torch.rand(100, device="cuda", dtype=dtype)
index = torch.randint(0, 100, (dim,), device="cuda")
src = torch.rand(dim, device="cuda", dtype=dtype)
print("=" * 20)
print(
triton.testing.do_bench(
lambda: a.scatter_add(0, index, src),
return_mode="median",
)
)
print("=" * 20)
if __name__ == "__main__":
main()
```
co-authored by: @amd-hhashemi
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146264
Approved by: https://github.com/jeffdaily, https://github.com/mxz297
Co-authored-by: Hashem Hashemi <hashem.hashemi@amd.com>
Here are the following modifications made to cpp_extension.py- 1) Changed compiler flag to use --version.
2) Added a feature to convert alpha-numeric string to numeric string for the version string returned by compiler. This was the source of error as the parser was failing on parsing alpha-numeric version string.
Build with following pytorch extensions- Apex, TorchVision, TorchAudio & DeepSpeed.
Unit tested with following pytorch extensions- Apex, TorchVision.
(cherry picked from commit c873aeac35851a7d5000eb7f24561d3f56c2ffbd)
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150451
Approved by: https://github.com/jeffdaily
Fixes:
- https://github.com/pytorch/pytorch/issues/101138
**Description**
The PR enhances error handling in `_check_cuda_version` by verifying the existence of the `nvcc` executable before invoking `subprocess.check_output`. If `nvcc` is missing, a `FileNotFoundError` is raised with a clear message, guiding users to check their CUDA installation and path configuration.
**Testing**
Manually tested with and without `nvcc` present in the expected path.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148671
Approved by: https://github.com/malfet
In the kernelBot leaderboard we support people competing with custom cuda extensions via `load_inline()`, however even on toy kernels this can result in cold starts of up to 90s - this feature is primarily responsible for us having to double our timeout values
I performed an investigation here https://github.com/msaroufim/load_inline_slow and the primary cause was that torch/extension.h and torch/types.h add in about 5,000 header files https://github.com/msaroufim/load_inline_slow/blob/main/header-analysis
So we introduce a mode `no_implicit_headers` which forces users to be explicit about exactly what they want to add. There's a proper test meant to be used in a CLI and a pytest test that's not terribly helpful
Then there's still an open question around what's the most minimal example implementation we can provide. For the baseline kernel we're showing here, it takes about 1 min to compile
1. There's using TensorBase.h (finicky to get right but can get compilation times down to 7s)
2. Just using Tensor.h (down to 15s)
3. Using Shim.h (did not try yet since the syntax is verbose relative to cuda)
This is my take so far https://gist.github.com/msaroufim/079a8d08ffebd0f91a1c2247eb0ce9e0 for a minimal implementation at 15s but @malfet has a simpler one at only 5s
There's more things I'd like to try moving forward like nvrtc and fancier compilation flags. Typical advice around using precompiled headers does not apply to us because we are mostly interested in cold starts where we tear down the machine after running a kernel
Also in a future PR I'd like to fix issue I've noticed with load_inline
1. It needs a force recompilation mode, I was using this quite a bit myself
2. The cache does not take into account changes in environment so the best way to force a recompilation is to change some string in the file
3. Instead of relying on pybind, can we use TORCH_LIBRARY instead
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149480
Approved by: https://github.com/malfet
- Updated HIP flags for Windows (removed non Windows flags on Windows case, added runtime library)
- Set hipcc call for Windows case
- Removed CUDA flags (not used in ROCm) on Windows
- Updated Windows compiler (added case when using ROCm on Windows)
- Fixed path issue in hipify_python
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147382
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Fixes https://github.com/ROCm/hip/issues/3764.
Fixes and improvements to CUDA->HIP flag conversion for CPP extensions
- Log flag conversion for debugging purposes.
- Fix cases where it should not touch the -I flags or cases where CUDA appears more than once by replacing only the first instance.
- Fix case where nvcc key may not exist
- Fix case where hipify should ignore flag values and only touch the flag itself
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149245
Approved by: https://github.com/jeffdaily
Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>
This commit just aligns description of `py_limited_api` feature in SyclExtension with CPP/CUDA. We've missed this change on doing SyclExtension due to parallel work on the changes. For CPP/CUDA change was done in 515e55e6927ad5f57ec222d7779712630341acf3.
CC: @gujinghui @EikanWang @fengyuan14 @guangyey @jgong5
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147988
Approved by: https://github.com/janeyx99, https://github.com/guangyey
This patch adds support for sycl kernels build via `torch.utils.cpp_extension.load`, `torch.utils.cpp_extension.load_inline` and (new) `class SyclExtension` APIs. Files having `.sycl` extension are considered to have sycl kernels and are compiled with `icpx` (dpc++ sycl compiler from Intel). Files with other extensions, `.cpp`, `.cu`, are handled as before. API supports building sycl along with other file types into single extension.
Note that `.sycl` file extension is a PyTorch convention for files containing sycl code which I propose to adopt. We did follow up with compiler team to introduce such file extension in the compiler, but they are opposed to this. At the same time discussion around sycl file extension and adding sycl language support into such tools as cmake is ongoing. Eventually cmake also considers to introduce some file extension convention for sycl. I hope we can further influence cmake and compiler communities to broader adopt `.sycl` file extension.
By default SYCL kernels are compiled for all Intel GPU devices for which pytorch native aten SYCL kernels are compiled. At the moment `pvc,xe-lpg`. This behavior can be overridden by setting `TORCH_XPU_ARCH_LIST` environment variables to the comma separated list of desired devices to compile for.
Fixes: #132944
CC: @gujinghui @EikanWang @fengyuan14 @guangyey @jgong5
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132945
Approved by: https://github.com/albanD, https://github.com/guangyey, https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This patch adds support for sycl kernels build via `torch.utils.cpp_extension.load`, `torch.utils.cpp_extension.load_inline` and (new) `class SyclExtension` APIs. Files having `.sycl` extension are considered to have sycl kernels and are compiled with `icpx` (dpc++ sycl compiler from Intel). Files with other extensions, `.cpp`, `.cu`, are handled as before. API supports building sycl along with other file types into single extension.
Note that `.sycl` file extension is a PyTorch convention for files containing sycl code which I propose to adopt. We did follow up with compiler team to introduce such file extension in the compiler, but they are opposed to this. At the same time discussion around sycl file extension and adding sycl language support into such tools as cmake is ongoing. Eventually cmake also considers to introduce some file extension convention for sycl. I hope we can further influence cmake and compiler communities to broader adopt `.sycl` file extension.
By default SYCL kernels are compiled for all Intel GPU devices for which pytorch native aten SYCL kernels are compiled. At the moment `pvc,xe-lpg`. This behavior can be overridden by setting `TORCH_XPU_ARCH_LIST` environment variables to the comma separated list of desired devices to compile for.
Fixes: #132944
CC: @gujinghui @EikanWang @fengyuan14 @guangyey @jgong5
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132945
Approved by: https://github.com/albanD, https://github.com/guangyey
This could be BC breaking, because there was a period of time when we use py_limited_api=True but don't enforce the flag, and now that we will start enforcing the flag, people's custom extensions may fail to build.
This is strictly still better behavior, as it is sketchy to claim CPython agnosticism without the flag, but calling this out as potential people yelling at us. Ways to mitigate this risk + reasons this may not be too big a deal:
- People haven't known about py_limited_api for extensions much due to lack of docs from python so usage is low right now
- My current tutorial is in store to make new users of py_limited_api pass this flag, so it'd be a noop for them.
Test plan:
* Locally i'm confident as I tried rebuilding ao with this change and it reliably failed (cuz importing torch/extension.h is a nono)
* Unit test wise, the normal python_agnostic one I added should work
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145764
Approved by: https://github.com/ezyang, https://github.com/zou3519, https://github.com/albanD
While working on conda-forge integration, I needed to look at the way the include paths are calculated, and noticed an avoidable duplication between `torch/utils/cpp_extension.py` and `torch/_inductor/cpp_builder.py`. The latter already imports the former anyway, so simply reuse the same function.
Furthermore, remove long-obsolete include-paths. AFAICT, the `/TH` headers have not existed since pytorch 1.11.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145480
Approved by: https://github.com/ezyang