mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[functorch] disable C++ Function under functorch transforms (#103957)
Fixes https://github.com/pytorch/pytorch/issues/102720 Pull Request resolved: https://github.com/pytorch/pytorch/pull/103957 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
ec24f1e4cc
commit
47894bb165
@ -29,6 +29,7 @@ struct TORCH_API FuncTorchTLSBase {
|
||||
// Returns a deep copy of this TLS object (pure virtual; implemented by
// the concrete functorch TLS class).
virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0;

virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0;

// NOTE(review): added by this commit (#103957). Implementations raise when
// a C++ torch::autograd::Function is applied while functorch transforms
// are active — functorch's autograd.Function support lives on the Python
// side only, so the C++ path must error instead of silently misbehaving.
virtual void checkSupportsCppAutogradFunction() const = 0;

virtual void checkSupportsInplaceRequiresGrad() const = 0;

virtual void checkSupportsRetainGrad() const = 0;
|
||||
};
|
||||
|
@ -99,6 +99,12 @@ class FuncTorchTLS : public FuncTorchTLSBase {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Errors out if any functorch transform (vmap, grad, vjp, ...) is active,
// i.e. the dynamic layer stack is non-empty. C++ autograd Functions have
// no functorch support (that is handled in Python), so failing loudly
// here avoids silently-incorrect gradients.
void checkSupportsCppAutogradFunction() const override {
  TORCH_CHECK(
      dynamicLayerStack.empty(),
      "cannot use C++ torch::autograd::Function with functorch transforms (vmap, grad, vjp, etc)");
}
|
||||
|
||||
void checkSupportsInplaceRequiresGrad() const override {
|
||||
TORCH_CHECK(dynamicLayerStack.empty() || allow_inplace_requires_grad_,
|
||||
"You are attempting to call Tensor.requires_grad_() (or perhaps using ",
|
||||
|
23
test/cpp_extensions/identity.cpp
Normal file
23
test/cpp_extensions/identity.cpp
Normal file
@ -0,0 +1,23 @@
|
||||
#include <torch/extension.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
using namespace torch::autograd;
|
||||
|
||||
// Minimal custom C++ autograd Function: the forward pass is the identity
// map, and the backward pass forwards the incoming gradient unchanged.
// Loaded by the Python test to check that C++ Functions raise under
// functorch transforms.
class Identity : public Function<Identity> {
 public:
  // Forward: return the input tensor as-is; nothing is saved on `ctx`.
  static torch::Tensor forward(AutogradContext* ctx, torch::Tensor x) {
    return x;
  }

  // Backward: d(identity)/dx is 1, so pass the single output gradient
  // straight through as the input gradient.
  static tensor_list backward(AutogradContext* ctx, tensor_list grads) {
    return {grads[0]};
  }
};
|
||||
|
||||
// Thin free-function wrapper so the custom Function can be bound to
// Python via pybind11 (PYBIND11_MODULE below binds this symbol).
torch::Tensor identity(torch::Tensor input) {
  auto result = Identity::apply(input);
  return result;
}
|
||||
|
||||
// Expose `identity` to Python; TORCH_EXTENSION_NAME is supplied by
// torch.utils.cpp_extension at build time (the test loads it as "identity").
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("identity", &identity, "identity");
}
|
@ -910,6 +910,21 @@ class TestCppExtensionJIT(common.TestCase):
|
||||
for fast_mode in (True, False):
|
||||
gradcheck(torch.ops.my.add, [a, b], eps=1e-2, fast_mode=fast_mode)
|
||||
|
||||
def test_custom_functorch_error(self):
    """JIT-compile a C++ autograd::Function and check that applying it
    under functorch transforms (vmap, grad) raises a RuntimeError."""
    # Build the extension from the identity.cpp fixture added in this
    # commit; path is relative to the test's working directory.
    identity_m = torch.utils.cpp_extension.load(
        name="identity",
        sources=["cpp_extensions/identity.cpp"],
    )

    t = torch.randn(3, requires_grad=True)

    # Must match the TORCH_CHECK message in checkSupportsCppAutogradFunction
    # (regex: '+' is escaped).
    msg = r"cannot use C\+\+ torch::autograd::Function with functorch"
    with self.assertRaisesRegex(RuntimeError, msg):
        torch.func.vmap(identity_m.identity)(t)

    with self.assertRaisesRegex(RuntimeError, msg):
        torch.func.grad(identity_m.identity)(t)
|
||||
|
||||
if __name__ == "__main__":
    # Run this file's tests when executed directly.
    common.run_tests()
|
||||
|
@ -263,6 +263,14 @@ template <class T>
|
||||
template <typename X, typename... Args>
|
||||
auto Function<T>::apply(Args&&... args)
|
||||
-> std::enable_if_t<std::is_same<X, T>::value, forward_t<X, Args...>> {
|
||||
const auto& functorch_tls = at::functorch::functorchTLSAccessor();
|
||||
if (functorch_tls) {
|
||||
// Function support for functorch is handled in Python.
|
||||
// Here we are dealing with a (C++) Function, which is not supported.
|
||||
// Let's raise an error instead of being silently incorrect.
|
||||
functorch_tls->checkSupportsCppAutogradFunction();
|
||||
}
|
||||
|
||||
std::shared_ptr<CppNode<T>> node(new CppNode<T>(), deleteNode);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
variable_list input_vars;
|
||||
|
Reference in New Issue
Block a user