Enable CPP/CUDAExtension with py_limited_api for python agnosticism (#138088)

This is getting tested with ao, but this PR now also adds a real test.

## What does this PR do?

We want to allow custom PyTorch extensions to build one wheel for multiple Python versions, in other words, to achieve python agnosticism. It turns out that setuptools/Python already provides a way to do this! Namely, if the user promises to use only the Python limited API in their extension, they can pass `py_limited_api` to their Extension class and to the bdist_wheel command (with a minimum Python version) in order to build one wheel that suffices across multiple Python versions.
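For example, the plain-setuptools side of this workflow looks roughly like the sketch below (the package name, sources, and the 3.9 floor are hypothetical choices for illustration):

```python
# Minimal sketch of the existing setuptools/bdist_wheel opt-in; nothing here is
# PyTorch-specific. Package name, sources, and version floor are hypothetical.
from setuptools import Extension, setup

setup(
    name="my_extension",
    ext_modules=[
        Extension(
            "my_extension._C",
            sources=["csrc/my_extension.c"],
            py_limited_api=True,  # promise: only the Python limited API is used
            # restrict the C source to the stable ABI of CPython 3.9+
            define_macros=[("Py_LIMITED_API", "0x03090000")],
        )
    ],
    # tag the single built wheel as cp39-abi3 rather than per-Python-version
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
```

Running `python setup.py bdist_wheel` then produces one `cp39-abi3` wheel that can be installed on CPython 3.9 and newer.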

Sounds lovely! Why don't people already do that with PyTorch? Well, two things. First, this workflow is hardly documented (even searching specifically for python agnosticism does not reveal many answers), so I'd expect that people simply don't know about it. Second, even if they did, _PyTorch_ custom extensions would still not work, because we always link torch_python, which does not abide by py_limited_api rules.

So this is where this PR comes in! We respect it when the user specifies py_limited_api and skip linking torch_python under that condition, allowing users to opt into the functionality described above.
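Concretely, after this PR a custom extension's setup.py can opt in through `torch.utils.cpp_extension` as sketched below (names and sources here are hypothetical; the real, tested example added by this PR is in the diffs further down):

```python
# Sketch of the PyTorch-side opt-in this PR enables; names are hypothetical.
from setuptools import setup

from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name="my_torch_ext",
    ext_modules=[
        CppExtension(
            "my_torch_ext._C",
            sources=["csrc/ops.cpp"],
            py_limited_api=True,  # with this PR: torch_python is not linked
        )
    ],
    cmdclass={"build_ext": BuildExtension},
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
```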

## How do I know this PR works?

I manually tested my silly little ultra_norm locally (with `import python_agnostic`) and wrote a test case for the extension showing that:
- torch_python does not show up in the ldd tree
- no Py- symbols show up as undefined

It may be a little confusing that our test case is actually python-free (cleaner than python-agnostic), but it is sufficient (though not necessary) for showing that this change works.
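For reference, those two checks boil down to something like this sketch (Linux only; the path to the built `_C.so` is hypothetical):

```python
# Sketch of the ldd/nm checks described above; the .so path is hypothetical.
import subprocess

so_file = "python_agnostic/_C.so"

# 1) libtorch_python.so must not appear among the shared-library dependencies
ldd_out = subprocess.check_output(["ldd", so_file]).decode()
assert "torch_python" not in ldd_out, "extension still links torch_python"

# 2) the extension must not reference CPython symbols ("Py..." by convention)
nm_out = subprocess.check_output(["nm", "-u", so_file]).decode()
assert "Py" not in nm_out, "extension depends on CPython symbols"
```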

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138088
Approved by: https://github.com/ezyang, https://github.com/albanD
Jane Xu
2024-12-11 06:55:47 -08:00
committed by PyTorch MergeBot
parent fb02b40d27
commit be27dbf2b8
8 changed files with 233 additions and 18 deletions

View File

@@ -48,7 +48,7 @@ if __name__ == "__main__":
        name=PACKAGE_NAME,
        version=version,
        author="PyTorch Core Team",
-       description="Example for PyTorch out of tree regitration",
+       description="Example for PyTorch out of tree registration",
        packages=find_packages(exclude=("test",)),
        package_data={PACKAGE_NAME: ["*.dll", "*.dylib", "*.so"]},
        install_requires=[

View File

@@ -0,0 +1,26 @@
from pathlib import Path

import torch

so_files = list(Path(__file__).parent.glob("_C*.so"))
assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
torch.ops.load_library(so_files[0])

from . import ops

# ----------------------------------------------------------------------------- #
# We've reached the end of what is normal in __init__ files.
# The following is used to assert the ultra_norm op is properly loaded and
# calculates correct results upon import of this extension.

inputs = [
    torch.tensor([1.0, 2.0, 3.0], device="cuda"),
    torch.tensor([-4.0, -5.0, -6.0], device="cuda"),
]

assert torch.equal(
    ops.ultra_norm(inputs),
    torch.norm(torch.tensor([1.0, 2.0, 3.0, -4.0, -5.0, -6.0], device="cuda")),
)

View File

@@ -0,0 +1,19 @@
#include <ATen/ops/_foreach_norm_native.h>
#include <ATen/ops/cat_cuda_dispatch.h>
#include <ATen/ops/norm_cuda_dispatch.h>
#include <ATen/ops/unsqueeze.h>
#include <torch/extension.h>

at::Tensor ultra_norm(at::TensorList inputs) {
  auto res = at::native::foreach_tensor_norm_cuda(inputs);
  std::vector<at::Tensor> unsqueezed;
  for (const auto& scalar_tensor : res) {
    unsqueezed.push_back(at::unsqueeze(scalar_tensor, 0));
  }
  auto stacked = at::cuda::cat(unsqueezed);
  return at::cuda::norm(stacked, 2, at::IntArrayRef{}, false);
}

TORCH_LIBRARY_IMPL(python_agnostic, CUDA, m) {
  m.impl("python_agnostic::ultra_norm", &ultra_norm);
}

View File

@@ -0,0 +1,26 @@
from typing import List

import torch
from torch import Tensor

lib = torch.library._scoped_library("python_agnostic", "FRAGMENT")
lib.define("ultra_norm(Tensor[] inputs) -> Tensor")


def ultra_norm(inputs: List[Tensor]) -> Tensor:
    """
    Computes the ultra-L2-norm of a list of tensors via computing the norm of norms.

    Assumes:
    - inputs should not be empty
    - all tensors in inputs should be on the same device and have the same dtype

    Args:
        inputs: list of torch.tensors

    Returns:
        Scalar torch.tensor of shape ()
    """
    return torch.ops.python_agnostic.ultra_norm.default(inputs)

View File

@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import distutils.command.clean
import shutil
from pathlib import Path

from setuptools import setup

from torch.utils.cpp_extension import BuildExtension, CUDAExtension

ROOT_DIR = Path(__file__).parent
CSRC_DIR = ROOT_DIR / "python_agnostic" / "csrc"


class clean(distutils.command.clean.clean):
    def run(self):
        # Run default behavior first
        distutils.command.clean.clean.run(self)

        # Remove extension
        for path in (ROOT_DIR / "python_agnostic").glob("**/*.so"):
            path.unlink()

        # Remove build and dist and egg-info directories
        dirs = [
            ROOT_DIR / "build",
            ROOT_DIR / "dist",
            ROOT_DIR / "python_agnostic.egg-info",
        ]
        for path in dirs:
            if path.exists():
                shutil.rmtree(str(path), ignore_errors=True)


def get_extension():
    extra_compile_args = {
        "cxx": ["-fdiagnostics-color=always"],
    }

    sources = list(CSRC_DIR.glob("**/*.cu"))

    return [
        CUDAExtension(
            "python_agnostic._C",
            sources=sorted(str(s) for s in sources),
            py_limited_api=True,
            extra_compile_args=extra_compile_args,
            extra_link_args=[],
        )
    ]


setup(
    name="python_agnostic",
    version="0.0",
    author="PyTorch Core Team",
    description="Example of python agnostic extension",
    ext_modules=get_extension(),
    cmdclass={
        "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
        "clean": clean,
    },
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)

View File

@@ -1031,18 +1031,23 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
    # Build the test cpp extensions modules
    shell_env = os.environ.copy()
    shell_env["USE_NINJA"] = str(1 if use_ninja else 0)
-   cmd = [sys.executable, "setup.py", "install", "--root", "./install"]
-   return_code = shell(cmd, cwd=cpp_extensions_test_dir, env=shell_env)
+   install_cmd = [sys.executable, "setup.py", "install", "--root", "./install"]
+   wheel_cmd = [sys.executable, "setup.py", "bdist_wheel"]
+   return_code = shell(install_cmd, cwd=cpp_extensions_test_dir, env=shell_env)
    if return_code != 0:
        return return_code
    if sys.platform != "win32":
-       return_code = shell(
-           cmd,
-           cwd=os.path.join(cpp_extensions_test_dir, "no_python_abi_suffix_test"),
-           env=shell_env,
-       )
-       if return_code != 0:
-           return return_code
+       exts_to_build = [(install_cmd, "no_python_abi_suffix_test")]
+       if TEST_CUDA:
+           exts_to_build.append((wheel_cmd, "python_agnostic_extension"))
+       for cmd, extension_dir in exts_to_build:
+           return_code = shell(
+               cmd,
+               cwd=os.path.join(cpp_extensions_test_dir, extension_dir),
+               env=shell_env,
+           )
+           if return_code != 0:
+               return return_code

    # "install" the test modules and run tests
    python_path = os.environ.get("PYTHONPATH", "")

View File

@@ -2,8 +2,11 @@

import os
import re
+import subprocess
+import sys
import unittest
from itertools import repeat
+from pathlib import Path
from typing import get_args, get_origin, Union

import torch

@@ -13,6 +16,7 @@ import torch.utils.cpp_extension
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
+   shell,
    skipIfTorchDynamo,
    xfailIfTorchDynamo,
)
@@ -164,6 +168,48 @@ class TestCppExtensionAOT(common.TestCase):
        test = cuda_dlink.add(a, b)
        self.assertEqual(test, ref)

+   @unittest.skipIf(not TEST_CUDA, "python_agnostic is a CUDA extension + needs CUDA")
+   @unittest.skipIf(not common.IS_LINUX, "test requires linux tools ldd and nm")
+   def test_python_agnostic(self):
+       # For this test, run_test.py will call `python setup.py bdist_wheel` in the
+       # cpp_extensions/python_agnostic_extension folder, where the extension and
+       # setup calls specify py_limited_api to `True`. To approximate that the
+       # extension is indeed python agnostic, we test
+       #     a. The extension wheel name contains "cp39-abi3", meaning the wheel
+       #        should be runnable for any Python 3 version after and including 3.9
+       #     b. The produced shared library does not have libtorch_python.so as a
+       #        dependency from the output of "ldd _C.so"
+       #     c. The .so does not need any python related symbols. We approximate
+       #        this by running "nm -u _C.so" and grepping that nothing starts with "Py"
+       dist_root = os.path.join("cpp_extensions", "python_agnostic_extension", "dist")
+       matches = list(Path(dist_root).glob("*.whl"))
+       self.assertEqual(len(matches), 1, msg=str(matches))
+       whl_file = matches[0]
+       self.assertRegex(str(whl_file), r".*python_agnostic-0\.0-cp39-abi3-.*\.whl")
+
+       build_root = os.path.join(
+           "cpp_extensions", "python_agnostic_extension", "build"
+       )
+       matches = list(Path(build_root).glob("**/*.so"))
+       self.assertEqual(len(matches), 1, msg=str(matches))
+       so_file = matches[0]
+       lddtree = subprocess.check_output(["ldd", so_file]).decode("utf-8")
+       self.assertFalse("torch_python" in lddtree)
+
+       missing_symbols = subprocess.check_output(["nm", "-u", so_file]).decode("utf-8")
+       self.assertFalse("Py" in missing_symbols)
+
+       # finally, clean up the folder
+       cmd = [sys.executable, "setup.py", "clean"]
+       return_code = shell(
+           cmd,
+           cwd=os.path.join("cpp_extensions", "python_agnostic_extension"),
+           env=os.environ.copy(),
+       )
+       if return_code != 0:
+           return return_code

@torch.testing._internal.common_utils.markDynamoStrictTest
class TestPybindTypeCasters(common.TestCase):

View File

@@ -812,11 +812,11 @@ class BuildExtension(build_ext):
        output_dir = os.path.abspath(output_dir)

        # Note [Absolute include_dirs]
-       # Convert relative path in self.compiler.include_dirs to absolute path if any,
-       # For ninja build, the build location is not local, the build happens
-       # in a in script created build folder, relative path lost their correctness.
+       # Convert relative path in self.compiler.include_dirs to absolute path if any.
+       # For ninja build, the build location is not local, but instead, the build happens
+       # in a script-created build folder. Thus, relative paths lose their correctness.
        # To be consistent with jit extension, we allow user to enter relative include_dirs
-       # in setuptools.setup, and we convert the relative path to absolute path here
+       # in setuptools.setup, and we convert the relative path to absolute path here.
        convert_to_absolute_paths_inplace(self.compiler.include_dirs)

        _, objects, extra_postargs, pp_opts, _ = \
@@ -964,6 +964,15 @@ def CppExtension(name, sources, *args, **kwargs):
    constructor. Full list arguments can be found at
    https://setuptools.pypa.io/en/latest/userguide/ext_modules.html#extension-api-reference

+   .. note::
+       The PyTorch python API (as provided in libtorch_python) cannot be built
+       with the flag ``py_limited_api=True``. When this flag is passed, it is
+       the user's responsibility in their library to not use APIs from
+       libtorch_python (in particular pytorch/python bindings) and to only use
+       APIs from libtorch (aten objects, operators and the dispatcher). For
+       example, to give access to custom ops from python, the library should
+       register the ops through the dispatcher.
+
    Example:
        >>> # xdoctest: +SKIP
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
@@ -994,7 +1003,9 @@ def CppExtension(name, sources, *args, **kwargs):
    libraries.append('c10')
    libraries.append('torch')
    libraries.append('torch_cpu')
-   libraries.append('torch_python')
+   if not kwargs.get('py_limited_api', False):
+       # torch_python uses more than the python limited api
+       libraries.append('torch_python')
    if IS_WINDOWS and platform.machine().lower() != "arm64":
        libraries.append("sleef")
@@ -1017,6 +1028,15 @@ def CUDAExtension(name, sources, *args, **kwargs):
    constructor. Full list arguments can be found at
    https://setuptools.pypa.io/en/latest/userguide/ext_modules.html#extension-api-reference

+   .. note::
+       The PyTorch python API (as provided in libtorch_python) cannot be built
+       with the flag ``py_limited_api=True``. When this flag is passed, it is
+       the user's responsibility in their library to not use APIs from
+       libtorch_python (in particular pytorch/python bindings) and to only use
+       APIs from libtorch (aten objects, operators and the dispatcher). For
+       example, to give access to custom ops from python, the library should
+       register the ops through the dispatcher.
+
    Example:
        >>> # xdoctest: +SKIP
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
@@ -1041,7 +1061,7 @@ def CUDAExtension(name, sources, *args, **kwargs):
    By default the extension will be compiled to run on all archs of the cards visible during the
    building process of the extension, plus PTX. If down the road a new card is installed the
    extension may need to be recompiled. If a visible card has a compute capability (CC) that's
-   newer than the newest version for which your nvcc can build fully-compiled binaries, Pytorch
+   newer than the newest version for which your nvcc can build fully-compiled binaries, PyTorch
    will make nvcc fall back to building kernels with the newest version of PTX your nvcc does
    support (see below for details on PTX).
@@ -1085,7 +1105,7 @@ def CUDAExtension(name, sources, *args, **kwargs):
    An exception to this rule is "dynamic parallelism" (nested kernel launches) which is not used a lot anymore.
    `Relocatable device code` is less optimized so it needs to be used only on object files that need it.
    Using `-dlto` (Device Link Time Optimization) at the device code compilation step and `dlink` step
-   help reduce the protentional perf degradation of `-rdc`.
+   helps reduce the protentional perf degradation of `-rdc`.
    Note that it needs to be used at both steps to be useful.

    If you have `rdc` objects you need to have an extra `-dlink` (device linking) step before the CPU symbol linking step.
@@ -1114,7 +1134,9 @@ def CUDAExtension(name, sources, *args, **kwargs):
    libraries.append('c10')
    libraries.append('torch')
    libraries.append('torch_cpu')
-   libraries.append('torch_python')
+   if not kwargs.get('py_limited_api', False):
+       # torch_python uses more than the python limited api
+       libraries.append('torch_python')
    if IS_HIP_EXTENSION:
        libraries.append('amdhip64')
        libraries.append('c10_hip')
@@ -1381,6 +1403,10 @@ def _get_pybind11_abi_build_flags():
    # that can cause a hard to debug segfaults.
    # For PyTorch extensions we want to relax those restrictions and pass compiler, stdlib and abi properties
    # captured during PyTorch native library compilation in torch/csrc/Module.cpp
+   #
+   # Note that these flags don't have side effects even if the PyTorch extension does not
+   # require nor use pybind, so we do not do anything differently for them in the py_limited_api
+   # case.

    abi_cflags = []
    for pname in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]: