[functorch] disable C++ Function under functorch transforms (#103957)

Fixes https://github.com/pytorch/pytorch/issues/102720

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103957
Approved by: https://github.com/zou3519
Author: kshitij12345
Date: 2023-06-23 11:01:40 +00:00
Committed by: PyTorch MergeBot
Parent: ec24f1e4cc
Commit: 47894bb165
5 changed files with 53 additions and 0 deletions
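For context, a minimal Python sketch (not part of the PR) of the user-visible effect; identity_ext is a stand-in name for an extension built from the identity.cpp test file added below, and the error text matches the TORCH_CHECK message introduced here:

import torch

t = torch.randn(3, requires_grad=True)

# Regular eager-mode autograd through the C++ Function is unaffected.
identity_ext.identity(t).sum().backward()

# Under a functorch transform the C++ Function now raises
# instead of silently returning a possibly incorrect result.
torch.func.grad(lambda a: identity_ext.identity(a).sum())(t)
# RuntimeError: cannot use C++ torch::autograd::Function with functorch transforms (vmap, grad, vjp, etc)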


@@ -29,6 +29,7 @@ struct TORCH_API FuncTorchTLSBase {
   virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0;
   virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0;
+  virtual void checkSupportsCppAutogradFunction() const = 0;
   virtual void checkSupportsInplaceRequiresGrad() const = 0;
   virtual void checkSupportsRetainGrad() const = 0;
 };


@@ -99,6 +99,12 @@ class FuncTorchTLS : public FuncTorchTLSBase {
     return 0;
   }
+  void checkSupportsCppAutogradFunction() const override {
+    TORCH_CHECK(
+        dynamicLayerStack.empty(),
+        "cannot use C++ torch::autograd::Function with functorch transforms (vmap, grad, vjp, etc)");
+  }
   void checkSupportsInplaceRequiresGrad() const override {
     TORCH_CHECK(dynamicLayerStack.empty() || allow_inplace_requires_grad_,
         "You are attempting to call Tensor.requires_grad_() (or perhaps using ",


@@ -0,0 +1,23 @@
+#include <torch/extension.h>
+#include <torch/torch.h>
+
+using namespace torch::autograd;
+
+class Identity : public Function<Identity> {
+ public:
+  static torch::Tensor forward(AutogradContext* ctx, torch::Tensor input) {
+    return input;
+  }
+
+  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
+    return {grad_outputs[0]};
+  }
+};
+
+torch::Tensor identity(torch::Tensor input) {
+  return Identity::apply(input);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("identity", &identity, "identity");
+}
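As a usage sketch (illustrative, not part of this PR), the same toy Function can also be built without a separate file via torch.utils.cpp_extension.load_inline, which prepends <torch/extension.h> and generates bindings for the listed functions; the module name identity_inline is arbitrary:

import torch
from torch.utils import cpp_extension

# C++ source mirroring cpp_extensions/identity.cpp, minus the PYBIND11_MODULE
# block (load_inline generates the bindings for the names in `functions`).
cpp_src = """
using namespace torch::autograd;

class Identity : public Function<Identity> {
 public:
  static torch::Tensor forward(AutogradContext* ctx, torch::Tensor input) {
    return input;
  }
  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
    return {grad_outputs[0]};
  }
};

torch::Tensor identity(torch::Tensor input) {
  return Identity::apply(input);
}
"""

identity_ext = cpp_extension.load_inline(
    name="identity_inline",
    cpp_sources=cpp_src,
    functions=["identity"],
)

t = torch.randn(3, requires_grad=True)
identity_ext.identity(t).sum().backward()  # plain eager autograd still works
print(t.grad)  # tensor([1., 1., 1.])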


@@ -910,6 +910,21 @@ class TestCppExtensionJIT(common.TestCase):
         for fast_mode in (True, False):
             gradcheck(torch.ops.my.add, [a, b], eps=1e-2, fast_mode=fast_mode)
+
+    def test_custom_functorch_error(self):
+        # Test that a custom C++ Function raises an error under functorch transforms
+        identity_m = torch.utils.cpp_extension.load(
+            name="identity",
+            sources=["cpp_extensions/identity.cpp"],
+        )
+        t = torch.randn(3, requires_grad=True)
+        msg = r"cannot use C\+\+ torch::autograd::Function with functorch"
+        with self.assertRaisesRegex(RuntimeError, msg):
+            torch.func.vmap(identity_m.identity)(t)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            torch.func.grad(identity_m.identity)(t)
+
 if __name__ == "__main__":
     common.run_tests()


@@ -263,6 +263,14 @@ template <class T>
 template <typename X, typename... Args>
 auto Function<T>::apply(Args&&... args)
     -> std::enable_if_t<std::is_same<X, T>::value, forward_t<X, Args...>> {
+  const auto& functorch_tls = at::functorch::functorchTLSAccessor();
+  if (functorch_tls) {
+    // Function support for functorch is handled in Python.
+    // Here we are dealing with a (C++) Function, which is not supported.
+    // Let's raise an error instead of being silently incorrect.
+    functorch_tls->checkSupportsCppAutogradFunction();
+  }
+
   std::shared_ptr<CppNode<T>> node(new CppNode<T>(), deleteNode);
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   variable_list input_vars;
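As the inline comment notes, functorch support for custom autograd Functions lives on the Python side. A minimal sketch (not from this PR) of that supported path, using a Python torch.autograd.Function with setup_context and the automatically generated vmap rule:

import torch

class MyIdentity(torch.autograd.Function):
    # Opt in to the auto-generated vmap rule so torch.func.vmap works.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x.clone()

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing needs to be saved for backward

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

t = torch.randn(3)
print(torch.func.grad(lambda a: MyIdentity.apply(a).sum())(t))  # tensor of ones
print(torch.func.vmap(MyIdentity.apply)(t))                     # same values as t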