mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Revert "Redo how custom/python_custom methods on TensorImpl work (#84796)
This reverts commit 591b75bf98b92acd4f3d0a1dc934198afeaa6fc1. Manual revert of https://github.com/pytorch/pytorch/pull/84641 Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/84796 Approved by: https://github.com/izaitsevfb
This commit is contained in:
committed by
PyTorch MergeBot
parent
96e4bd9500
commit
ca3b2bfbe3
@ -245,7 +245,6 @@ struct ConcretePyInterpreterVTable final
|
||||
c10::Layout layout(const TensorImpl* self) const override;
|
||||
c10::SymInt sym_numel(const TensorImpl* self) const override;
|
||||
c10::SymIntArrayRef sym_strides(const TensorImpl* self) const override;
|
||||
c10::SymInt sym_storage_offset(const TensorImpl* self) const override;
|
||||
|
||||
void trace_gpu_event_creation(uintptr_t event) const override {
|
||||
concrete_trace_cuda<trace_cuda_event_creation_fn_name>(event);
|
||||
@ -716,14 +715,14 @@ static PyObject* THPVariable_make_subclass(
|
||||
data.set_requires_grad(r.toBool(2));
|
||||
const auto sizes_strides_policy = r.stringViewOptional(3);
|
||||
if (sizes_strides_policy.has_value()) {
|
||||
data.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
|
||||
data.unsafeGetTensorImpl()->set_sizes_strides_policy(
|
||||
parseSizesStridesPolicyArgument(*sizes_strides_policy));
|
||||
}
|
||||
if (r.toBool(4)) {
|
||||
data.unsafeGetTensorImpl()->set_python_custom_device(true);
|
||||
data.unsafeGetTensorImpl()->set_custom_device(true);
|
||||
}
|
||||
if (r.toBool(5)) {
|
||||
data.unsafeGetTensorImpl()->set_python_custom_layout(true);
|
||||
data.unsafeGetTensorImpl()->set_custom_layout(true);
|
||||
}
|
||||
if (!r.isNone(6)) {
|
||||
data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6));
|
||||
@ -805,7 +804,7 @@ static PyObject* THPVariable_make_wrapper_subclass(
|
||||
|
||||
const auto sizes_strides_policy = r.stringViewOptional(10);
|
||||
if (sizes_strides_policy.has_value()) {
|
||||
tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
|
||||
tensor.unsafeGetTensorImpl()->set_sizes_strides_policy(
|
||||
parseSizesStridesPolicyArgument(*sizes_strides_policy));
|
||||
}
|
||||
} else {
|
||||
@ -820,12 +819,17 @@ static PyObject* THPVariable_make_wrapper_subclass(
|
||||
|
||||
auto sym_sizes = r.symintlist(1);
|
||||
auto sym_strides = r.symintlist(2);
|
||||
auto sym_storage_offset = r.toSymIntOptional(3);
|
||||
|
||||
TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
|
||||
|
||||
tensor_impl->set_sizes_and_strides(
|
||||
sym_sizes, sym_strides, sym_storage_offset.value_or(0));
|
||||
// TODO: this should probably be sym_sizes, sym_strides AND offset
|
||||
tensor_impl->set_sym_sizes_and_strides(sym_sizes, sym_strides);
|
||||
|
||||
// TODO: this may need to be symbolic as well
|
||||
auto storage_offset = r.toInt64Optional(3);
|
||||
if (storage_offset) {
|
||||
tensor_impl->set_storage_offset(*storage_offset);
|
||||
}
|
||||
|
||||
const auto sizes_strides_policy = r.stringViewOptional(10);
|
||||
if (sizes_strides_policy.has_value()) {
|
||||
@ -838,10 +842,10 @@ static PyObject* THPVariable_make_wrapper_subclass(
|
||||
tensor.set_requires_grad(r.toBool(9));
|
||||
|
||||
if (r.toBool(11)) {
|
||||
tensor.unsafeGetTensorImpl()->set_python_custom_device(true);
|
||||
tensor.unsafeGetTensorImpl()->set_custom_device(true);
|
||||
}
|
||||
if (r.toBool(12)) {
|
||||
tensor.unsafeGetTensorImpl()->set_python_custom_layout(true);
|
||||
tensor.unsafeGetTensorImpl()->set_custom_layout(true);
|
||||
}
|
||||
|
||||
return THPVariable_NewWithVar(
|
||||
@ -2538,29 +2542,6 @@ c10::SymInt ConcretePyInterpreterVTable::sym_numel(
|
||||
: c10::SymInt{py::cast<int64_t>(out)};
|
||||
}
|
||||
|
||||
c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset(
|
||||
const c10::TensorImpl* self) const {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
at::impl::MaybeSetTLSOnEntryGuard guard;
|
||||
auto out = torchDispatchFromTensorImpl(
|
||||
self,
|
||||
"sym_storage_offset",
|
||||
py::module::import("torch")
|
||||
.attr("ops")
|
||||
.attr("aten")
|
||||
.attr("sym_storage_offset")
|
||||
.attr("default")
|
||||
.ptr(),
|
||||
"torch.ops.aten");
|
||||
|
||||
if (out == Py_None) {
|
||||
return self->sym_storage_offset_default();
|
||||
}
|
||||
return torch::is_symint_node(out)
|
||||
? out.cast<c10::SymIntNodeImpl*>()->toSymInt()
|
||||
: c10::SymInt{py::cast<int64_t>(out)};
|
||||
}
|
||||
|
||||
c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
|
||||
const c10::TensorImpl* self) const {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
|
Reference in New Issue
Block a user