From 47894bb16594fc4bd6045d739fba6e63bdf793a8 Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Fri, 23 Jun 2023 11:01:40 +0000
Subject: [PATCH] [functorch] disable C++ Function under functorch transforms
 (#103957)

Fixes https://github.com/pytorch/pytorch/issues/102720

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103957
Approved by: https://github.com/zou3519
---
 aten/src/ATen/FuncTorchTLS.h             |  1 +
 aten/src/ATen/functorch/DynamicLayer.cpp |  6 ++++++
 test/cpp_extensions/identity.cpp         | 23 +++++++++++++++++++++++
 test/test_cpp_extensions_jit.py          | 15 +++++++++++++++
 torch/csrc/autograd/custom_function.h    |  8 ++++++++
 5 files changed, 53 insertions(+)
 create mode 100644 test/cpp_extensions/identity.cpp

diff --git a/aten/src/ATen/FuncTorchTLS.h b/aten/src/ATen/FuncTorchTLS.h
index b8fde728fad2..3f33709e89ba 100644
--- a/aten/src/ATen/FuncTorchTLS.h
+++ b/aten/src/ATen/FuncTorchTLS.h
@@ -29,6 +29,7 @@ struct TORCH_API FuncTorchTLSBase {
   virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0;
 
   virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0;
+  virtual void checkSupportsCppAutogradFunction() const = 0;
   virtual void checkSupportsInplaceRequiresGrad() const = 0;
   virtual void checkSupportsRetainGrad() const = 0;
 };
diff --git a/aten/src/ATen/functorch/DynamicLayer.cpp b/aten/src/ATen/functorch/DynamicLayer.cpp
index e2782e7945fa..2d271a613340 100644
--- a/aten/src/ATen/functorch/DynamicLayer.cpp
+++ b/aten/src/ATen/functorch/DynamicLayer.cpp
@@ -99,6 +99,12 @@ class FuncTorchTLS : public FuncTorchTLSBase {
     return 0;
   }
 
+  void checkSupportsCppAutogradFunction() const override {
+    TORCH_CHECK(
+        dynamicLayerStack.empty(),
+        "cannot use C++ torch::autograd::Function with functorch transforms (vmap, grad, vjp, etc)");
+  }
+
   void checkSupportsInplaceRequiresGrad() const override {
     TORCH_CHECK(dynamicLayerStack.empty() || allow_inplace_requires_grad_,
         "You are attempting to call Tensor.requires_grad_() (or perhaps using ",
diff --git a/test/cpp_extensions/identity.cpp b/test/cpp_extensions/identity.cpp
new file mode 100644
index 000000000000..ebde67762e7b
--- /dev/null
+++ b/test/cpp_extensions/identity.cpp
@@ -0,0 +1,23 @@
+#include <torch/extension.h>
+#include <torch/torch.h>
+
+using namespace torch::autograd;
+
+class Identity : public Function<Identity> {
+ public:
+  static torch::Tensor forward(AutogradContext* ctx, torch::Tensor input) {
+    return input;
+  }
+
+  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
+    return {grad_outputs[0]};
+  }
+};
+
+torch::Tensor identity(torch::Tensor input) {
+  return Identity::apply(input);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("identity", &identity, "identity");
+}
diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py
index 7436182f97bc..65ff29924fc4 100644
--- a/test/test_cpp_extensions_jit.py
+++ b/test/test_cpp_extensions_jit.py
@@ -910,6 +910,21 @@ class TestCppExtensionJIT(common.TestCase):
         for fast_mode in (True, False):
             gradcheck(torch.ops.my.add, [a, b], eps=1e-2, fast_mode=fast_mode)
 
+    def test_custom_functorch_error(self):
+        # Test that a custom C++ Function raises an error under functorch transforms
+        identity_m = torch.utils.cpp_extension.load(
+            name="identity",
+            sources=["cpp_extensions/identity.cpp"],
+        )
+
+        t = torch.randn(3, requires_grad=True)
+
+        msg = r"cannot use C\+\+ torch::autograd::Function with functorch"
+        with self.assertRaisesRegex(RuntimeError, msg):
+            torch.func.vmap(identity_m.identity)(t)
+
+        with self.assertRaisesRegex(RuntimeError, msg):
+            torch.func.grad(identity_m.identity)(t)
 
 if __name__ == "__main__":
     common.run_tests()
diff --git a/torch/csrc/autograd/custom_function.h b/torch/csrc/autograd/custom_function.h
index 98a53b6f4ebc..8d49f7cfacae 100644
--- a/torch/csrc/autograd/custom_function.h
+++ b/torch/csrc/autograd/custom_function.h
@@ -263,6 +263,14 @@ template <class T>
 template <typename X, typename... Args>
 auto Function<T>::apply(Args&&... args)
     -> std::enable_if_t<std::is_same<X, T>::value, forward_t<X, Args...>> {
+  const auto& functorch_tls = at::functorch::functorchTLSAccessor();
+  if (functorch_tls) {
+    // Function support for functorch is handled in Python.
+    // Here we are dealing with a (C++) Function, which is not supported.
+    // Let's raise an error instead of being silently incorrect.
+    functorch_tls->checkSupportsCppAutogradFunction();
+  }
+
   std::shared_ptr<CppNode<T>> node(new CppNode<T>(), deleteNode);
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   variable_list input_vars;
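
For reference, a minimal Python sketch of the user-facing behavior this patch enforces, mirroring the new test_custom_functorch_error test. It assumes the identity extension above is built with torch.utils.cpp_extension.load and that the relative source path is resolved from the test/ directory; plain autograd keeps working, while torch.func transforms now raise instead of silently computing an incorrect result.

    import torch
    import torch.utils.cpp_extension

    # Build the C++ extension added in this patch (path assumes we run from test/).
    identity_m = torch.utils.cpp_extension.load(
        name="identity",
        sources=["cpp_extensions/identity.cpp"],
    )

    t = torch.randn(3, requires_grad=True)

    # Ordinary autograd is unaffected: the C++ Function still works outside functorch.
    identity_m.identity(t).sum().backward()

    # Under a functorch transform, Function::apply now calls
    # checkSupportsCppAutogradFunction() and raises a RuntimeError.
    try:
        torch.func.vmap(identity_m.identity)(t)
    except RuntimeError as err:
        print(err)  # cannot use C++ torch::autograd::Function with functorch transforms ...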