// Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00).
// Fixes https://github.com/pytorch/pytorch/issues/102720
// Pull Request resolved: https://github.com/pytorch/pytorch/pull/103957
// Approved by: https://github.com/zou3519
// File: 24 lines, 531 B, C++.
#include <torch/extension.h>
#include <torch/torch.h>
using namespace torch::autograd;
/// Custom autograd function implementing the identity map f(x) = x.
///
/// NOTE(review): forward() hands back the very tensor it received instead of
/// a clone. This looks deliberate, because this file accompanies the fix for
/// pytorch issue 102720. Do not "fix" it by cloning without confirming.
class Identity : public Function<Identity> {
 public:
  /// Forward pass: returns `input` unchanged. The autograd context is not
  /// needed because nothing has to be saved for backward.
  static torch::Tensor forward(AutogradContext* /*ctx*/, torch::Tensor input) {
    return input;
  }

  /// Backward pass: the Jacobian of the identity is I, so the incoming
  /// gradient passes straight through.
  /// forward() produces a single output, so grad_outputs is expected to hold
  /// exactly one gradient. Only element 0 is read.
  static tensor_list backward(AutogradContext* /*ctx*/, tensor_list grad_outputs) {
    const auto& upstream = grad_outputs[0];
    return tensor_list{upstream};
  }
};
/// Runs `input` through the Identity autograd function.
///
/// @param input  tensor to pass through.
/// @return the tensor produced by Identity::apply. Identity's backward
///         forwards the upstream gradient unchanged.
torch::Tensor identity(torch::Tensor input) {
  torch::Tensor output = Identity::apply(input);
  return output;
}
// Python bindings. The module name is supplied at build time through
// TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // `identity` decays to a function pointer. The third argument is the
  // docstring that Python's help() shows for this function.
  m.def("identity", identity, "identity");
}