Scoped extension building for C++ backed custom ops tests (#136695)
Fixes #125579, #131103, #133197, #133283, #134738, #135369, #135685

Tests that build C++ extensions inline can be flaky in CI because of library namespace conflicts and test-ordering effects. Building each extension in its own temporary directory keeps the tests isolated from one another. An alternative would be to build these extensions as part of the regular build, so that failures surface as build-time errors instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136695
Approved by: https://github.com/zou3519
Committed by: PyTorch MergeBot
Parent: 10e2840ce3
Commit: 99608ceed6
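For illustration, a minimal sketch of how a test adopts the change (the TestScopedExtension class and test name are hypothetical; the decorator, the injected load_inline helper, and the build-isolation behaviour come from the hunks below):

    import torch
    from torch.testing._internal.common_utils import run_tests, scoped_load_inline, TestCase

    class TestScopedExtension(TestCase):
        @scoped_load_inline
        def test_builds_in_temp_dir(self, load_inline):
            cpp_source = """
            #include <torch/extension.h>
            at::Tensor foobar(const at::Tensor& x) {
              return x.clone();
            }
            """
            # The injected helper forwards to torch.utils.cpp_extension.load_inline
            # with build_directory pointing at a fresh temporary directory, so the
            # compiled artifact cannot collide with other tests' cached builds.
            module = load_inline(
                name="mylib",
                cpp_sources=cpp_source,
                functions="foobar",
                verbose=True,
            )
            self.assertEqual(module.foobar(torch.ones(3)), torch.ones(3))

    if __name__ == "__main__":
        run_tests()

On Windows the helper still falls back to the shared default build root (see the torch.testing._internal.common_utils hunk at the end of this diff).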
@@ -79,6 +79,7 @@ from torch.testing._internal.common_methods_invocations import (
 from torch.testing._internal.common_utils import (
     freeze_rng_state,
     IS_FBCODE,
+    scoped_load_inline,
     set_default_dtype,
     skipIfNNModuleInlined,
     skipIfWindows,
@@ -321,16 +322,17 @@ class MiscTests(torch._inductor.test_case.TestCase):
         res_compiled = add_fn(2, 3, torch.tensor(0.0))
         self.assertEqual(res, res_compiled)

+    @scoped_load_inline
     @skipIfNNModuleInlined("fails internal CI")
     @unittest.skipIf(IS_FBCODE, "inline cpp_extension doesn't work in fbcode")
-    def test_cpp_extension_recommends_custom_ops(self):
+    def test_cpp_extension_recommends_custom_ops(self, load_inline):
         cpp_source = """
 #include <torch/extension.h>
 at::Tensor foobar(const at::Tensor& x) {
   return x.clone();
 }
 """
-        module = torch.utils.cpp_extension.load_inline(
+        module = load_inline(
             name="mylib",
             cpp_sources=cpp_source,
             functions="foobar",
@@ -362,7 +364,7 @@ class MiscTests(torch._inductor.test_case.TestCase):
   return x.clone();
 }
 """
-        module2 = torch.utils.cpp_extension.load_inline(
+        module2 = load_inline(
             name="mylib2",
             cpp_sources=cpp_source,
             functions="baz",

@@ -26,7 +26,7 @@ from torch._dynamo.device_interface import get_interface_for_device
 from torch._dynamo.utils import counters
 from torch._inductor import config as inductor_config
 from torch._inductor.test_case import run_tests, TestCase
-from torch.testing._internal.common_utils import skipIfWindows
+from torch.testing._internal.common_utils import scoped_load_inline, skipIfWindows
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU
 from torch.testing._internal.logging_utils import logs_to_string

@@ -1586,7 +1586,8 @@ main()
             f, compiler_fn=compiler_fn_with_op_check, compile_fn=False
         )

-    def test_non_traceable_autograd_cpp_node(self):
+    @scoped_load_inline
+    def test_non_traceable_autograd_cpp_node(self, load_inline):
         cpp_source = """
 struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
   static constexpr bool is_traceable = false;
@@ -1613,7 +1614,7 @@ TORCH_LIBRARY(test_non_traceable_autograd_cpp_node, m) {
 }
 """

-        module = torch.utils.cpp_extension.load_inline(
+        module = load_inline(
             name="test_non_traceable_autograd_cpp_node",
             cpp_sources=cpp_source,
             functions="custom_op_backed_by_autograd_fn",
@@ -1634,8 +1635,8 @@ TORCH_LIBRARY(test_non_traceable_autograd_cpp_node, m) {
         ), compiled_autograd.enable(compiler_fn):
             fn()

-    @unittest.skip("Flaky, cache from test ordering affects test. #135369")
-    def test_autograd_cpp_node(self):
+    @scoped_load_inline
+    def test_autograd_cpp_node(self, load_inline):
         cpp_source = """
 struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
   static constexpr bool is_traceable = true;
@@ -1662,7 +1663,7 @@ TORCH_LIBRARY(test_autograd_cpp_node, m) {
 }
 """

-        module = torch.utils.cpp_extension.load_inline(
+        module = load_inline(
             name="test_autograd_cpp_node",
             cpp_sources=cpp_source,
             functions="custom_op_backed_by_autograd_fn",
@@ -1682,7 +1683,8 @@ TORCH_LIBRARY(test_autograd_cpp_node, m) {
         # compiles for 10 (static) and 100 (dynamic)
         self.check_output_and_recompiles(fn, 2)

-    def test_autograd_cpp_node_id(self):
+    @scoped_load_inline
+    def test_autograd_cpp_node_id(self, load_inline):
         cpp_source = """
 struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
   static constexpr bool is_traceable = true;
@@ -1730,7 +1732,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_id, m) {
 }
 """

-        module = torch.utils.cpp_extension.load_inline(
+        module = load_inline(
             name="test_autograd_cpp_node_id",
             cpp_sources=cpp_source,
             functions="custom_op_backed_by_autograd_fn",
@@ -1773,7 +1775,8 @@ TORCH_LIBRARY(test_autograd_cpp_node_id, m) {

         self.check_output_and_recompiles(different_autograd_fn, 2)

-    def test_autograd_cpp_node_saved(self):
+    @scoped_load_inline
+    def test_autograd_cpp_node_saved(self, load_inline):
         cpp_source = """
 struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
   static constexpr bool is_traceable = true;
@@ -1827,7 +1830,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved, m) {
 }
 """

-        module = torch.utils.cpp_extension.load_inline(
+        module = load_inline(
             name="test_autograd_cpp_node_saved",
             cpp_sources=cpp_source,
             functions="custom_op_backed_by_autograd_fn",
@@ -1848,7 +1851,8 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved, m) {

         self.check_output_and_recompiles(fn, 2)

-    def test_autograd_cpp_node_saved_dynamic(self):
+    @scoped_load_inline
+    def test_autograd_cpp_node_saved_dynamic(self, load_inline):
         cpp_source = """
 struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
   static constexpr bool is_traceable = true;
@@ -1884,7 +1888,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_dynamic, m) {
 }
 """

-        module = torch.utils.cpp_extension.load_inline(
+        module = load_inline(
             name="test_autograd_cpp_node_saved_dynamic",
             cpp_sources=cpp_source,
             functions="custom_op_backed_by_autograd_fn",
@@ -1904,7 +1908,8 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_dynamic, m) {
         # compiles for 10 (static) and 100 (dynamic)
         self.check_output_and_recompiles(fn, 2)

-    def test_autograd_cpp_node_saved_int(self):
+    @scoped_load_inline
+    def test_autograd_cpp_node_saved_int(self, load_inline):
         cpp_source = """
 struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
   static constexpr bool is_traceable = true;
@@ -1943,7 +1948,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_int, m) {
 }
 """

-        module = torch.utils.cpp_extension.load_inline(
+        module = load_inline(
             name="test_autograd_cpp_node_saved_int",
             cpp_sources=cpp_source,
             functions="custom_op_backed_by_autograd_fn",
@@ -1962,7 +1967,8 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_int, m) {

         self.check_output_and_recompiles(fn, 1)

-    def test_autograd_cpp_node_saved_float(self):
+    @scoped_load_inline
+    def test_autograd_cpp_node_saved_float(self, load_inline):
         cpp_source = """
 struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
   static constexpr bool is_traceable = true;
@@ -2001,7 +2007,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_float, m) {
 }
 """

-        module = torch.utils.cpp_extension.load_inline(
+        module = load_inline(
             name="test_autograd_cpp_node_saved_float",
             cpp_sources=cpp_source,
             functions="custom_op_backed_by_autograd_fn",
@@ -2021,7 +2027,8 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_float, m) {
         # compiled autograd and dynamo both support symfloat, but not backend
         self.check_output_and_recompiles(fn, [1, 3])

-    def test_autograd_cpp_node_data_dependent(self):
+    @scoped_load_inline
+    def test_autograd_cpp_node_data_dependent(self, load_inline):
         cpp_source = """
 struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
   static constexpr bool is_traceable = true;
@@ -2092,7 +2099,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_data_dependent, m) {
 }
 """

-        module = torch.utils.cpp_extension.load_inline(
+        module = load_inline(
             name="test_autograd_cpp_node_data_dependent",
             cpp_sources=cpp_source,
             functions="custom_op_backed_by_autograd_fn",
@@ -2332,8 +2339,9 @@ main()
         # Must skip since we do not know if the cpu scalar will be used only in ATen/prim ops.
         self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)

+    @scoped_load_inline
     @unittest.skipIf(not HAS_CUDA, "requires cuda")
-    def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self):
+    def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline):
         cpp_source = """
 struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
   static constexpr bool is_traceable = true;
@@ -2371,7 +2379,7 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
 }
 """

-        module = torch.utils.cpp_extension.load_inline(
+        module = load_inline(
             name="test_cudagraphs_cpu_scalar_used_in_cpp_custom_op",
             cpp_sources=cpp_source,
             functions="custom_op_backed_by_autograd_fn",

@@ -69,6 +69,7 @@ from torch.testing._internal.common_utils import (
     IS_WINDOWS,
     parametrize,
     run_tests,
+    scoped_load_inline,
     set_warn_always_context,
     skipIfMps,
     skipIfNoLapack,
@@ -85,7 +86,6 @@ from torch.utils.checkpoint import (
     CheckpointPolicy,
     create_selective_checkpoint_contexts,
 )
-from torch.utils.cpp_extension import load_inline
 from torch.utils.flop_counter import FlopCounterMode


@@ -9854,7 +9854,8 @@ for shape in [(1,), ()]:
         out = x * y
         out.sum().backward()

-    def test_multi_grad_all_hooks(self):
+    @scoped_load_inline
+    def test_multi_grad_all_hooks(self, load_inline):
         t1 = torch.rand(2, requires_grad=True)
         t2 = torch.rand(2, requires_grad=True)
         t3 = torch.rand(2, requires_grad=True)
@@ -9899,19 +9900,19 @@ torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) {
   return CustomOpAutogradFunction::apply(x);
 }

-TORCH_LIBRARY(test_autograd_cpp_node, m) {
+TORCH_LIBRARY(test_multigrad_all_hooks, m) {
   m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
 }
 """

         module = load_inline(
-            name="test_autograd_cpp_node",
+            name="test_multigrad_all_hooks",
             cpp_sources=cpp_source,
             functions="custom_op_backed_by_autograd_fn",
             verbose=True,
         )

-        t4 = torch.ops.test_autograd_cpp_node.custom_op_backed_by_autograd_fn(t4)
+        t4 = torch.ops.test_multigrad_all_hooks.custom_op_backed_by_autograd_fn(t4)

         res = [None] * 4
         count = [0]

@@ -33,6 +33,7 @@ from torch.testing._internal.common_utils import (
     IS_WINDOWS,
     parametrize,
     run_tests,
+    scoped_load_inline,
     skipIfTorchDynamo,
     subtest,
     TestCase,
@@ -2088,7 +2089,8 @@ dynamic shape operator: _torch_testing.numpy_nonzero.default
         with self.assertRaisesRegex(RuntimeError, "Expected one of cpu, cuda"):
             torch.library.impl("blah::blah", "somethingsomething")

-    def test_autograd_function_backed_op(self):
+    @scoped_load_inline
+    def test_autograd_function_backed_op(self, load_inline):
         cpp_source = """
 struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
   static constexpr bool is_traceable = true;
@@ -2110,13 +2112,13 @@ torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x) {
   return CustomOpAutogradFunction::apply(x);
 }

-TORCH_LIBRARY(mylib, m) {
+TORCH_LIBRARY(test_autograd_function_backed_op, m) {
   m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
 }
 """

-        module = torch.utils.cpp_extension.load_inline(
-            name="mylib",
+        module = load_inline(
+            name="test_autograd_function_backed_op",
             cpp_sources=cpp_source,
             functions="custom_op_backed_by_autograd_fn",
             verbose=True,
@@ -2124,7 +2126,11 @@ TORCH_LIBRARY(mylib, m) {

         x = torch.ones(2, 2, requires_grad=True)
         temp = x.clone().detach()
-        out = torch.ops.mylib.custom_op_backed_by_autograd_fn(x)
+        out = (
+            torch.ops.test_autograd_function_backed_op.custom_op_backed_by_autograd_fn(
+                x
+            )
+        )
         loss = out.sum()
         loss.backward()
         self.assertEqual(x.grad, temp)

@@ -98,6 +98,7 @@ from torch.testing._comparison import not_close_error_metas
 from torch.testing._internal.common_dtype import get_all_dtypes
 from torch.utils._import_utils import _check_module_exists
 import torch.utils._pytree as pytree
+from torch.utils import cpp_extension
 try:
     import pytest
     has_pytest = True
@@ -5379,7 +5380,7 @@ def remove_cpp_extensions_build_root():
     """
     Removes the default root folder under which extensions are built.
     """
-    default_build_root = torch.utils.cpp_extension.get_default_build_root()
+    default_build_root = cpp_extension.get_default_build_root()
     if os.path.exists(default_build_root):
         if IS_WINDOWS:
             # rmtree returns permission error: [WinError 5] Access is denied
@@ -5387,3 +5388,24 @@
             subprocess.run(["rm", "-rf", default_build_root], stdout=subprocess.PIPE)
         else:
             shutil.rmtree(default_build_root, ignore_errors=True)
+
+# Decorator to provide a helper to load inline extensions to a temp directory
+def scoped_load_inline(func):
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        def load_inline(*args, **kwargs):
+            if IS_WINDOWS:
+                # TODO(xmfan): even using TemporaryDirectoryName will result in permission error
+                return cpp_extension.load_inline(*args, **kwargs)
+
+            assert "build_directory" not in kwargs
+            with TemporaryDirectoryName() as temp_dir_name:
+                if kwargs.get("verbose", False):
+                    print(f'Using temporary extension directory {temp_dir_name}...', file=sys.stderr)
+                kwargs["build_directory"] = temp_dir_name
+                return cpp_extension.load_inline(*args, **kwargs)
+
+        return func(*args, load_inline=load_inline, **kwargs)
+
+    return wrapper