mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Use TORCH_EXTENSION_NAME macro to avoid mismatched module/extension name (#5277)
* Warn users about mismatched module/extension name * Define TORCH_EXTENSION_NAME macro
This commit is contained in:
committed by
Soumith Chintala
parent
5c93ca258b
commit
22fe542b8e
@ -1,11 +1,20 @@
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include "doubler.h"
|
||||
|
||||
using namespace at;
|
||||
|
||||
Tensor exp_add(Tensor x, Tensor y);
|
||||
|
||||
Tensor tanh_add(Tensor x, Tensor y) {
|
||||
return x.tanh() + y.tanh();
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(jit_extension, m) {
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("tanh_add", &tanh_add, "tanh(x) + tanh(y)");
|
||||
m.def("exp_add", &exp_add, "exp(x) + exp(y)");
|
||||
py::class_<Doubler>(m, "Doubler")
|
||||
.def(py::init<int, int>())
|
||||
.def("forward", &Doubler::forward)
|
||||
.def("get", &Doubler::get);
|
||||
}
|
||||
|
Reference in New Issue
Block a user