[functorch] made tensors const references - fixes pytorch/functorch#38

This commit is contained in:
Horace He
2021-05-21 02:32:48 -07:00
committed by Jon Janzen
parent ea85c20c35
commit 0e3a2b2d5c
3 changed files with 8 additions and 7 deletions

View File

@ -23,6 +23,7 @@ class PythonTensor(object):
return self.value
def __torch_function__(self, func, types, args=(), kwargs={}):
import pdb; pdb.set_trace()
namespace, func_name = func.split("::")
func = getattr(getattr(torch.ops, namespace), func_name)
outs = kwargs['val']

View File

@ -25,22 +25,22 @@ void PythonTensorImpl::set_storage_offset(int64_t storage_offset) {
TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for PythonTensorImpl");
}
bool isPythonTensor(at::Tensor tensor) {
bool isPythonTensor(const at::Tensor& tensor) {
return tensor.unsafeGetTensorImpl()->key_set().has(
c10::DispatchKey::FuncTorchPython);
}
PythonTensorImpl* getPythonImpl(at::Tensor tensor) {
PythonTensorImpl* getPythonImpl(const at::Tensor& tensor) {
return static_cast<PythonTensorImpl*>(tensor.unsafeGetTensorImpl());
}
at::Tensor addPythonKey(const py::object& tensor) {
return at::detail::make_tensor<PythonTensorImpl>(tensor);
}
bool hasPythonKey(at::Tensor tensor) {
bool hasPythonKey(const at::Tensor& tensor) {
return isPythonTensor(tensor);
}
py::object removePythonKey(at::Tensor tensor) {
py::object removePythonKey(const at::Tensor& tensor) {
assert(isPythonTensor(tensor));
return getPythonImpl(tensor)->value_;
}

View File

@ -50,10 +50,10 @@ struct TORCH_API PythonTensorImpl : public c10::TensorImpl {
py::object value_;
};
PythonTensorImpl* getPythonImpl(at::Tensor tensor);
PythonTensorImpl* getPythonImpl(const at::Tensor& tensor);
at::Tensor addPythonKey(const py::object& tensor);
bool hasPythonKey(at::Tensor tensor);
bool hasPythonKey(const at::Tensor& tensor);
py::object removePythonKey(at::Tensor tensor);
py::object removePythonKey(const at::Tensor& tensor);
}}