mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] made tensors const references - fixes pytorch/functorch#38
This commit is contained in:
@ -23,6 +23,7 @@ class PythonTensor(object):
|
||||
return self.value
|
||||
|
||||
def __torch_function__(self, func, types, args=(), kwargs={}):
|
||||
import pdb; pdb.set_trace()
|
||||
namespace, func_name = func.split("::")
|
||||
func = getattr(getattr(torch.ops, namespace), func_name)
|
||||
outs = kwargs['val']
|
||||
|
@ -25,22 +25,22 @@ void PythonTensorImpl::set_storage_offset(int64_t storage_offset) {
|
||||
TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for PythonTensorImpl");
|
||||
}
|
||||
|
||||
bool isPythonTensor(at::Tensor tensor) {
|
||||
bool isPythonTensor(const at::Tensor& tensor) {
|
||||
return tensor.unsafeGetTensorImpl()->key_set().has(
|
||||
c10::DispatchKey::FuncTorchPython);
|
||||
}
|
||||
PythonTensorImpl* getPythonImpl(at::Tensor tensor) {
|
||||
PythonTensorImpl* getPythonImpl(const at::Tensor& tensor) {
|
||||
return static_cast<PythonTensorImpl*>(tensor.unsafeGetTensorImpl());
|
||||
}
|
||||
|
||||
at::Tensor addPythonKey(const py::object& tensor) {
|
||||
return at::detail::make_tensor<PythonTensorImpl>(tensor);
|
||||
}
|
||||
bool hasPythonKey(at::Tensor tensor) {
|
||||
bool hasPythonKey(const at::Tensor& tensor) {
|
||||
return isPythonTensor(tensor);
|
||||
}
|
||||
|
||||
py::object removePythonKey(at::Tensor tensor) {
|
||||
py::object removePythonKey(const at::Tensor& tensor) {
|
||||
assert(isPythonTensor(tensor));
|
||||
return getPythonImpl(tensor)->value_;
|
||||
}
|
||||
|
@ -50,10 +50,10 @@ struct TORCH_API PythonTensorImpl : public c10::TensorImpl {
|
||||
py::object value_;
|
||||
};
|
||||
|
||||
PythonTensorImpl* getPythonImpl(at::Tensor tensor);
|
||||
PythonTensorImpl* getPythonImpl(const at::Tensor& tensor);
|
||||
|
||||
at::Tensor addPythonKey(const py::object& tensor);
|
||||
bool hasPythonKey(at::Tensor tensor);
|
||||
bool hasPythonKey(const at::Tensor& tensor);
|
||||
|
||||
py::object removePythonKey(at::Tensor tensor);
|
||||
py::object removePythonKey(const at::Tensor& tensor);
|
||||
}}
|
||||
|
Reference in New Issue
Block a user