Reland symint_numel (#84281)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84281
Approved by: https://github.com/ezyang
This commit is contained in:
Nikolay Korovaiko
2022-08-30 21:53:34 +00:00
committed by PyTorch MergeBot
parent d09486ab23
commit eda217ab67
16 changed files with 118 additions and 23 deletions

View File

@ -266,6 +266,9 @@ c10::Layout concrete_layout_fn(
c10::SymIntArrayRef concrete_sym_strides_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self);
c10::SymInt concrete_sym_numel_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self);
template <const char*, typename... Ts>
void concrete_trace_cuda(const c10::impl::PyInterpreter*, Ts...);
static constexpr char trace_cuda_event_creation_fn_name[] =
@ -298,6 +301,7 @@ class PyInterpreterHolder {
&concrete_sizes_fn,
&concrete_sym_sizes_fn,
&concrete_layout_fn,
&concrete_sym_numel_fn,
&concrete_sym_strides_fn,
c10::impl::GPUTraceFunctionWrapper(
&concrete_trace_cuda<trace_cuda_event_creation_fn_name>,
@ -2497,6 +2501,33 @@ c10::Layout concrete_layout_fn(
return toLayout(out.ptr());
}
c10::SymInt concrete_sym_numel_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self) {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;
auto out = torchDispatchFromTensorImpl(
self,
"sym_numel",
py::module::import("torch")
.attr("ops")
.attr("aten")
.attr("sym_numel")
.attr("default")
.ptr(),
"torch.ops.aten");
if (out == Py_None) {
TORCH_CHECK(
!self->has_symbolic_sizes_strides(),
"Cannot call numel on a tensor with symbolic shapes/strides");
return self->sym_numel_default();
}
return torch::is_symint_node(out)
? out.cast<c10::SymIntNodeImpl*>()->toSymInt()
: c10::SymInt{py::cast<int64_t>(out)};
}
template <const char* func_name, typename... Ts>
void concrete_trace_cuda(const c10::impl::PyInterpreter*, Ts... args) {
pybind11::gil_scoped_acquire gil;