Scoped extension building for C++ backed custom ops tests (#136695)

FIXES #125579 #131103 #133197 #133283 #134738 #135369 #135685

Tests that create C++ extensions can be flaky in CI because of library namespace conflicts and test-ordering effects. Building each extension in a temporary directory keeps the tests isolated from one another.

An alternative would be to compile these extensions as part of the regular build process, so that failures surface as build-time errors instead of test-time flakiness.
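For a rough illustration of the temp-dir approach (a sketch only; the extension and function names below are made up, and the actual helper this PR adds is the scoped_load_inline decorator shown in the last hunk), torch.utils.cpp_extension.load_inline accepts a build_directory argument, so each test can build into its own throwaway directory:

# Sketch: build an inline extension under a throwaway directory so its
# artifacts cannot collide with those of other tests sharing the default
# build root.
import tempfile

import torch
from torch.utils import cpp_extension

cpp_source = """
#include <torch/extension.h>
at::Tensor twice(const at::Tensor& x) {
  return x + x;
}
"""

with tempfile.TemporaryDirectory() as tmp_dir:
    module = cpp_extension.load_inline(
        name="isolated_ext",      # illustrative extension name
        cpp_sources=cpp_source,
        functions="twice",
        build_directory=tmp_dir,  # keep build artifacts scoped to this test
        verbose=True,
    )
    x = torch.randn(3)
    assert torch.equal(module.twice(x), x + x)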

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136695
Approved by: https://github.com/zou3519
Author: Simon Fan
Date: 2024-10-25 13:19:47 -07:00
Committed by: PyTorch MergeBot
Parent: 10e2840ce3
Commit: 99608ceed6
5 changed files with 73 additions and 34 deletions

View File

@ -79,6 +79,7 @@ from torch.testing._internal.common_methods_invocations import (
from torch.testing._internal.common_utils import (
freeze_rng_state,
IS_FBCODE,
scoped_load_inline,
set_default_dtype,
skipIfNNModuleInlined,
skipIfWindows,
@ -321,16 +322,17 @@ class MiscTests(torch._inductor.test_case.TestCase):
res_compiled = add_fn(2, 3, torch.tensor(0.0))
self.assertEqual(res, res_compiled)
@scoped_load_inline
@skipIfNNModuleInlined("fails internal CI")
@unittest.skipIf(IS_FBCODE, "inline cpp_extension doesn't work in fbcode")
def test_cpp_extension_recommends_custom_ops(self):
def test_cpp_extension_recommends_custom_ops(self, load_inline):
cpp_source = """
#include <torch/extension.h>
at::Tensor foobar(const at::Tensor& x) {
return x.clone();
}
"""
module = torch.utils.cpp_extension.load_inline(
module = load_inline(
name="mylib",
cpp_sources=cpp_source,
functions="foobar",
@ -362,7 +364,7 @@ class MiscTests(torch._inductor.test_case.TestCase):
return x.clone();
}
"""
module2 = torch.utils.cpp_extension.load_inline(
module2 = load_inline(
name="mylib2",
cpp_sources=cpp_source,
functions="baz",

View File

@ -26,7 +26,7 @@ from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.utils import counters
from torch._inductor import config as inductor_config
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import skipIfWindows
from torch.testing._internal.common_utils import scoped_load_inline, skipIfWindows
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU
from torch.testing._internal.logging_utils import logs_to_string
@ -1586,7 +1586,8 @@ main()
f, compiler_fn=compiler_fn_with_op_check, compile_fn=False
)
def test_non_traceable_autograd_cpp_node(self):
@scoped_load_inline
def test_non_traceable_autograd_cpp_node(self, load_inline):
cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
static constexpr bool is_traceable = false;
@ -1613,7 +1614,7 @@ TORCH_LIBRARY(test_non_traceable_autograd_cpp_node, m) {
}
"""
module = torch.utils.cpp_extension.load_inline(
module = load_inline(
name="test_non_traceable_autograd_cpp_node",
cpp_sources=cpp_source,
functions="custom_op_backed_by_autograd_fn",
@ -1634,8 +1635,8 @@ TORCH_LIBRARY(test_non_traceable_autograd_cpp_node, m) {
), compiled_autograd.enable(compiler_fn):
fn()
@unittest.skip("Flaky, cache from test ordering affects test. #135369")
def test_autograd_cpp_node(self):
@scoped_load_inline
def test_autograd_cpp_node(self, load_inline):
cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
static constexpr bool is_traceable = true;
@ -1662,7 +1663,7 @@ TORCH_LIBRARY(test_autograd_cpp_node, m) {
}
"""
module = torch.utils.cpp_extension.load_inline(
module = load_inline(
name="test_autograd_cpp_node",
cpp_sources=cpp_source,
functions="custom_op_backed_by_autograd_fn",
@ -1682,7 +1683,8 @@ TORCH_LIBRARY(test_autograd_cpp_node, m) {
# compiles for 10 (static) and 100 (dynamic)
self.check_output_and_recompiles(fn, 2)
def test_autograd_cpp_node_id(self):
@scoped_load_inline
def test_autograd_cpp_node_id(self, load_inline):
cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
static constexpr bool is_traceable = true;
@ -1730,7 +1732,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_id, m) {
}
"""
module = torch.utils.cpp_extension.load_inline(
module = load_inline(
name="test_autograd_cpp_node_id",
cpp_sources=cpp_source,
functions="custom_op_backed_by_autograd_fn",
@ -1773,7 +1775,8 @@ TORCH_LIBRARY(test_autograd_cpp_node_id, m) {
self.check_output_and_recompiles(different_autograd_fn, 2)
def test_autograd_cpp_node_saved(self):
@scoped_load_inline
def test_autograd_cpp_node_saved(self, load_inline):
cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
static constexpr bool is_traceable = true;
@ -1827,7 +1830,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved, m) {
}
"""
module = torch.utils.cpp_extension.load_inline(
module = load_inline(
name="test_autograd_cpp_node_saved",
cpp_sources=cpp_source,
functions="custom_op_backed_by_autograd_fn",
@ -1848,7 +1851,8 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved, m) {
self.check_output_and_recompiles(fn, 2)
def test_autograd_cpp_node_saved_dynamic(self):
@scoped_load_inline
def test_autograd_cpp_node_saved_dynamic(self, load_inline):
cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
static constexpr bool is_traceable = true;
@ -1884,7 +1888,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_dynamic, m) {
}
"""
module = torch.utils.cpp_extension.load_inline(
module = load_inline(
name="test_autograd_cpp_node_saved_dynamic",
cpp_sources=cpp_source,
functions="custom_op_backed_by_autograd_fn",
@ -1904,7 +1908,8 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_dynamic, m) {
# compiles for 10 (static) and 100 (dynamic)
self.check_output_and_recompiles(fn, 2)
def test_autograd_cpp_node_saved_int(self):
@scoped_load_inline
def test_autograd_cpp_node_saved_int(self, load_inline):
cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
static constexpr bool is_traceable = true;
@ -1943,7 +1948,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_int, m) {
}
"""
module = torch.utils.cpp_extension.load_inline(
module = load_inline(
name="test_autograd_cpp_node_saved_int",
cpp_sources=cpp_source,
functions="custom_op_backed_by_autograd_fn",
@ -1962,7 +1967,8 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_int, m) {
self.check_output_and_recompiles(fn, 1)
def test_autograd_cpp_node_saved_float(self):
@scoped_load_inline
def test_autograd_cpp_node_saved_float(self, load_inline):
cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
static constexpr bool is_traceable = true;
@ -2001,7 +2007,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_float, m) {
}
"""
module = torch.utils.cpp_extension.load_inline(
module = load_inline(
name="test_autograd_cpp_node_saved_float",
cpp_sources=cpp_source,
functions="custom_op_backed_by_autograd_fn",
@ -2021,7 +2027,8 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_float, m) {
# compiled autograd and dynamo both support symfloat, but not backend
self.check_output_and_recompiles(fn, [1, 3])
def test_autograd_cpp_node_data_dependent(self):
@scoped_load_inline
def test_autograd_cpp_node_data_dependent(self, load_inline):
cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
static constexpr bool is_traceable = true;
@ -2092,7 +2099,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_data_dependent, m) {
}
"""
module = torch.utils.cpp_extension.load_inline(
module = load_inline(
name="test_autograd_cpp_node_data_dependent",
cpp_sources=cpp_source,
functions="custom_op_backed_by_autograd_fn",
@ -2332,8 +2339,9 @@ main()
# Must skip since we do not know if the cpu scalar will be used only in ATen/prim ops.
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
@scoped_load_inline
@unittest.skipIf(not HAS_CUDA, "requires cuda")
def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self):
def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline):
cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
static constexpr bool is_traceable = true;
@ -2371,7 +2379,7 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
}
"""
module = torch.utils.cpp_extension.load_inline(
module = load_inline(
name="test_cudagraphs_cpu_scalar_used_in_cpp_custom_op",
cpp_sources=cpp_source,
functions="custom_op_backed_by_autograd_fn",

View File

@ -69,6 +69,7 @@ from torch.testing._internal.common_utils import (
IS_WINDOWS,
parametrize,
run_tests,
scoped_load_inline,
set_warn_always_context,
skipIfMps,
skipIfNoLapack,
@ -85,7 +86,6 @@ from torch.utils.checkpoint import (
CheckpointPolicy,
create_selective_checkpoint_contexts,
)
from torch.utils.cpp_extension import load_inline
from torch.utils.flop_counter import FlopCounterMode
@ -9854,7 +9854,8 @@ for shape in [(1,), ()]:
out = x * y
out.sum().backward()
def test_multi_grad_all_hooks(self):
@scoped_load_inline
def test_multi_grad_all_hooks(self, load_inline):
t1 = torch.rand(2, requires_grad=True)
t2 = torch.rand(2, requires_grad=True)
t3 = torch.rand(2, requires_grad=True)
@ -9899,19 +9900,19 @@ torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) {
return CustomOpAutogradFunction::apply(x);
}
TORCH_LIBRARY(test_autograd_cpp_node, m) {
TORCH_LIBRARY(test_multigrad_all_hooks, m) {
m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
}
"""
module = load_inline(
name="test_autograd_cpp_node",
name="test_multigrad_all_hooks",
cpp_sources=cpp_source,
functions="custom_op_backed_by_autograd_fn",
verbose=True,
)
t4 = torch.ops.test_autograd_cpp_node.custom_op_backed_by_autograd_fn(t4)
t4 = torch.ops.test_multigrad_all_hooks.custom_op_backed_by_autograd_fn(t4)
res = [None] * 4
count = [0]

View File

@ -33,6 +33,7 @@ from torch.testing._internal.common_utils import (
IS_WINDOWS,
parametrize,
run_tests,
scoped_load_inline,
skipIfTorchDynamo,
subtest,
TestCase,
@ -2088,7 +2089,8 @@ dynamic shape operator: _torch_testing.numpy_nonzero.default
with self.assertRaisesRegex(RuntimeError, "Expected one of cpu, cuda"):
torch.library.impl("blah::blah", "somethingsomething")
def test_autograd_function_backed_op(self):
@scoped_load_inline
def test_autograd_function_backed_op(self, load_inline):
cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
static constexpr bool is_traceable = true;
@ -2110,13 +2112,13 @@ torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x) {
return CustomOpAutogradFunction::apply(x);
}
TORCH_LIBRARY(mylib, m) {
TORCH_LIBRARY(test_autograd_function_backed_op, m) {
m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
}
"""
module = torch.utils.cpp_extension.load_inline(
name="mylib",
module = load_inline(
name="test_autograd_function_backed_op",
cpp_sources=cpp_source,
functions="custom_op_backed_by_autograd_fn",
verbose=True,
@ -2124,7 +2126,11 @@ TORCH_LIBRARY(mylib, m) {
x = torch.ones(2, 2, requires_grad=True)
temp = x.clone().detach()
out = torch.ops.mylib.custom_op_backed_by_autograd_fn(x)
out = (
torch.ops.test_autograd_function_backed_op.custom_op_backed_by_autograd_fn(
x
)
)
loss = out.sum()
loss.backward()
self.assertEqual(x.grad, temp)

View File

@ -98,6 +98,7 @@ from torch.testing._comparison import not_close_error_metas
from torch.testing._internal.common_dtype import get_all_dtypes
from torch.utils._import_utils import _check_module_exists
import torch.utils._pytree as pytree
from torch.utils import cpp_extension
try:
import pytest
has_pytest = True
@ -5379,7 +5380,7 @@ def remove_cpp_extensions_build_root():
"""
Removes the default root folder under which extensions are built.
"""
default_build_root = torch.utils.cpp_extension.get_default_build_root()
default_build_root = cpp_extension.get_default_build_root()
if os.path.exists(default_build_root):
if IS_WINDOWS:
# rmtree returns permission error: [WinError 5] Access is denied
@ -5387,3 +5388,24 @@ def remove_cpp_extensions_build_root():
subprocess.run(["rm", "-rf", default_build_root], stdout=subprocess.PIPE)
else:
shutil.rmtree(default_build_root, ignore_errors=True)
# Decorator to provide a helper to load inline extensions to a temp directory
def scoped_load_inline(func):
@wraps(func)
def wrapper(*args, **kwargs):
def load_inline(*args, **kwargs):
if IS_WINDOWS:
# TODO(xmfan): even using TemporaryDirectoryName will result in permission error
return cpp_extension.load_inline(*args, **kwargs)
assert "build_directory" not in kwargs
with TemporaryDirectoryName() as temp_dir_name:
if kwargs.get("verbose", False):
print(f'Using temporary extension directory {temp_dir_name}...', file=sys.stderr)
kwargs["build_directory"] = temp_dir_name
return cpp_extension.load_inline(*args, **kwargs)
return func(*args, load_inline=load_inline, **kwargs)
return wrapper
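For reference, a test adopting the new decorator ends up looking roughly like the sketch below: scoped_load_inline injects a load_inline callable that forwards to cpp_extension.load_inline with build_directory pointed at a temporary directory (except on Windows, as noted in the helper above). The class, test, and extension names here are illustrative.

import torch
from torch.testing._internal.common_utils import (
    run_tests,
    scoped_load_inline,
    TestCase,
)


class ScopedExtensionExample(TestCase):  # illustrative test class
    @scoped_load_inline
    def test_inline_op_builds_in_temp_dir(self, load_inline):  # illustrative test name
        cpp_source = """
        #include <torch/extension.h>
        at::Tensor identity(const at::Tensor& x) {
          return x.clone();
        }
        """
        # The injected load_inline builds under a fresh temporary directory,
        # so this extension cannot collide with one built by another test.
        module = load_inline(
            name="scoped_extension_example",  # illustrative extension name
            cpp_sources=cpp_source,
            functions="identity",
            verbose=True,
        )
        x = torch.randn(4)
        self.assertEqual(module.identity(x), x)


if __name__ == "__main__":
    run_tests()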