mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Allow both Variables and Tensors in c10 kernel interface (#20816)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20816 Previously, the c10 dispatcher expected ops to be called with Variables and unwrapped them to Tensors before calling into the kernel. The kernel was expected to return Tensors that were re-wrapped into Variables before passing them on into the system. However, that doesn't work with kernels that call other operators. One recent example was a kernel that returned the result of `torch::ones()` as output. Now, with this diff, the c10 dispatcher still passes Tensors to the kernel and Variables back into the system, but it accepts ops to be called with both Tensors or Variables and kernels are also allowed to return either. After https://github.com/pytorch/pytorch/pull/17072 , we should be able to get rid of the whole wrapping/unwrapping logic. Reviewed By: hl475 Differential Revision: D15453963 fbshipit-source-id: 7602b7f2bc43e8ceb8a8c0e97aafcc53d4c47b6c
This commit is contained in:
committed by
Facebook Github Bot
parent
9ea009fe8b
commit
98928f4d79
@ -11,7 +11,11 @@ at::Tensor unwrap_tensor(at::Tensor&& tensor) {
|
||||
if (tensor.requires_grad()) {
|
||||
throw std::runtime_error("Autograd not yet supported for c10 ops.");
|
||||
}
|
||||
return torch::autograd::Variable(std::move(tensor)).data();
|
||||
if (tensor.is_variable()) {
|
||||
return torch::autograd::Variable(std::move(tensor)).data();
|
||||
} else {
|
||||
return std::move(tensor);
|
||||
}
|
||||
}
|
||||
|
||||
IValue unwrap(IValue&& ivalue) {
|
||||
@ -39,7 +43,11 @@ IValue unwrap(IValue&& ivalue) {
|
||||
}
|
||||
|
||||
at::Tensor wrap_tensor(at::Tensor&& tensor) {
|
||||
return torch::autograd::make_variable(tensor);
|
||||
if (tensor.is_variable()) {
|
||||
return std::move(tensor);
|
||||
} else {
|
||||
return torch::autograd::make_variable(std::move(tensor));
|
||||
}
|
||||
}
|
||||
|
||||
IValue wrap(IValue&& ivalue) {
|
||||
|
Reference in New Issue
Block a user