Revert "Redo how custom/python_custom methods on TensorImpl work (#84796)

This reverts commit 591b75bf98b92acd4f3d0a1dc934198afeaa6fc1.

Manual revert of https://github.com/pytorch/pytorch/pull/84641

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84796
Approved by: https://github.com/izaitsevfb

Authored by Eli Uriegas on 2022-09-10 00:18:13 +00:00; committed by PyTorch MergeBot
commit ca3b2bfbe3 (parent 96e4bd9500)
27 changed files with 288 additions and 470 deletions

torch/csrc/autograd/python_variable.cpp

@@ -245,7 +245,6 @@ struct ConcretePyInterpreterVTable final
   c10::Layout layout(const TensorImpl* self) const override;
   c10::SymInt sym_numel(const TensorImpl* self) const override;
   c10::SymIntArrayRef sym_strides(const TensorImpl* self) const override;
-  c10::SymInt sym_storage_offset(const TensorImpl* self) const override;
   void trace_gpu_event_creation(uintptr_t event) const override {
     concrete_trace_cuda<trace_cuda_event_creation_fn_name>(event);
@@ -716,14 +715,14 @@ static PyObject* THPVariable_make_subclass(
   data.set_requires_grad(r.toBool(2));
   const auto sizes_strides_policy = r.stringViewOptional(3);
   if (sizes_strides_policy.has_value()) {
-    data.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
+    data.unsafeGetTensorImpl()->set_sizes_strides_policy(
         parseSizesStridesPolicyArgument(*sizes_strides_policy));
   }
   if (r.toBool(4)) {
-    data.unsafeGetTensorImpl()->set_python_custom_device(true);
+    data.unsafeGetTensorImpl()->set_custom_device(true);
   }
   if (r.toBool(5)) {
-    data.unsafeGetTensorImpl()->set_python_custom_layout(true);
+    data.unsafeGetTensorImpl()->set_custom_layout(true);
   }
   if (!r.isNone(6)) {
     data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6));
@@ -805,7 +804,7 @@ static PyObject* THPVariable_make_wrapper_subclass(
     const auto sizes_strides_policy = r.stringViewOptional(10);
     if (sizes_strides_policy.has_value()) {
-      tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
+      tensor.unsafeGetTensorImpl()->set_sizes_strides_policy(
           parseSizesStridesPolicyArgument(*sizes_strides_policy));
     }
   } else {
@@ -820,12 +819,17 @@ static PyObject* THPVariable_make_wrapper_subclass(
     auto sym_sizes = r.symintlist(1);
     auto sym_strides = r.symintlist(2);
-    auto sym_storage_offset = r.toSymIntOptional(3);
     TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
-    tensor_impl->set_sizes_and_strides(
-        sym_sizes, sym_strides, sym_storage_offset.value_or(0));
+    // TODO: this should probably be sym_sizes, sym_strides AND offset
+    tensor_impl->set_sym_sizes_and_strides(sym_sizes, sym_strides);
+    // TODO: this may need to be symbolic as well
+    auto storage_offset = r.toInt64Optional(3);
+    if (storage_offset) {
+      tensor_impl->set_storage_offset(*storage_offset);
+    }
     const auto sizes_strides_policy = r.stringViewOptional(10);
     if (sizes_strides_policy.has_value()) {
@@ -838,10 +842,10 @@ static PyObject* THPVariable_make_wrapper_subclass(
   tensor.set_requires_grad(r.toBool(9));
   if (r.toBool(11)) {
-    tensor.unsafeGetTensorImpl()->set_python_custom_device(true);
+    tensor.unsafeGetTensorImpl()->set_custom_device(true);
   }
   if (r.toBool(12)) {
-    tensor.unsafeGetTensorImpl()->set_python_custom_layout(true);
+    tensor.unsafeGetTensorImpl()->set_custom_layout(true);
   }
   return THPVariable_NewWithVar(
@@ -2538,29 +2542,6 @@ c10::SymInt ConcretePyInterpreterVTable::sym_numel(
       : c10::SymInt{py::cast<int64_t>(out)};
 }
 
-c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset(
-    const c10::TensorImpl* self) const {
-  pybind11::gil_scoped_acquire gil;
-  at::impl::MaybeSetTLSOnEntryGuard guard;
-  auto out = torchDispatchFromTensorImpl(
-      self,
-      "sym_storage_offset",
-      py::module::import("torch")
-          .attr("ops")
-          .attr("aten")
-          .attr("sym_storage_offset")
-          .attr("default")
-          .ptr(),
-      "torch.ops.aten");
-
-  if (out == Py_None) {
-    return self->sym_storage_offset_default();
-  }
-  return torch::is_symint_node(out)
-      ? out.cast<c10::SymIntNodeImpl*>()->toSymInt()
-      : c10::SymInt{py::cast<int64_t>(out)};
-}
-
 c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
     const c10::TensorImpl* self) const {
   pybind11::gil_scoped_acquire gil;
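
For context, THPVariable_make_wrapper_subclass is the C++ binding behind the private helper torch.Tensor._make_wrapper_subclass: the stringViewOptional(10), toBool(11), and toBool(12) reads above correspond to its dispatch_sizes_strides_policy, dispatch_device, and dispatch_layout arguments. The following minimal Python sketch shows how a __torch_dispatch__ subclass might exercise those paths; the keyword names and the "strides" policy string reflect the private API of this era and are assumptions that may not match other PyTorch versions.

import torch

class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        # Allocate a shell tensor that mirrors elem's metadata but owns no
        # storage. The dispatch_* kwargs map onto the argument slots parsed
        # in THPVariable_make_wrapper_subclass above (keyword names are an
        # assumption for illustration).
        return torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            device=elem.device,
            requires_grad=elem.requires_grad,
            # Route size()/stride() queries through __torch_dispatch__
            dispatch_sizes_strides_policy="strides",
            dispatch_device=True,  # -> set_custom_device(true)
            dispatch_layout=True,  # -> set_custom_layout(true)
        )

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # A real wrapper would unwrap args, call func, and rewrap results;
        # this sketch only marks where that logic goes.
        raise NotImplementedError(f"WrapperTensor does not handle {func}")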