mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Reland symint_numel (#84281)
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84281
Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
d09486ab23
commit
eda217ab67
@ -266,6 +266,9 @@ c10::Layout concrete_layout_fn(
|
||||
c10::SymIntArrayRef concrete_sym_strides_fn(
|
||||
const c10::impl::PyInterpreter*,
|
||||
const c10::TensorImpl* self);
|
||||
c10::SymInt concrete_sym_numel_fn(
|
||||
const c10::impl::PyInterpreter*,
|
||||
const c10::TensorImpl* self);
|
||||
template <const char*, typename... Ts>
|
||||
void concrete_trace_cuda(const c10::impl::PyInterpreter*, Ts...);
|
||||
static constexpr char trace_cuda_event_creation_fn_name[] =
|
||||
@ -298,6 +301,7 @@ class PyInterpreterHolder {
|
||||
&concrete_sizes_fn,
|
||||
&concrete_sym_sizes_fn,
|
||||
&concrete_layout_fn,
|
||||
&concrete_sym_numel_fn,
|
||||
&concrete_sym_strides_fn,
|
||||
c10::impl::GPUTraceFunctionWrapper(
|
||||
&concrete_trace_cuda<trace_cuda_event_creation_fn_name>,
|
||||
@ -2497,6 +2501,33 @@ c10::Layout concrete_layout_fn(
|
||||
return toLayout(out.ptr());
|
||||
}
|
||||
|
||||
// Python-dispatch hook for `sym_numel`.
//
// Routes the query to `torch.ops.aten.sym_numel.default` through
// torchDispatchFromTensorImpl and converts the Python result into a
// c10::SymInt:
//   - Py_None result: returns self->sym_numel_default(); a TORCH_CHECK first
//     rejects tensors that report symbolic sizes/strides.
//   - SymIntNode result: converted via SymIntNodeImpl::toSymInt().
//   - any other result: cast to int64_t and wrapped in a SymInt.
//
// The PyInterpreter argument is unused.
c10::SymInt concrete_sym_numel_fn(
    const c10::impl::PyInterpreter*,
    const c10::TensorImpl* self) {
  // Acquire the GIL before touching any Python object below.
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  // Materialise the overload as an owning py::object within a single
  // full-expression so the intermediate attribute lookups stay alive.
  const py::object sym_numel_op = py::module::import("torch")
                                      .attr("ops")
                                      .attr("aten")
                                      .attr("sym_numel")
                                      .attr("default");

  auto out = torchDispatchFromTensorImpl(
      self, "sym_numel", sym_numel_op.ptr(), "torch.ops.aten");

  if (out == Py_None) {
    TORCH_CHECK(
        !self->has_symbolic_sizes_strides(),
        "Cannot call numel on a tensor with symbolic shapes/strides");
    return self->sym_numel_default();
  }

  if (torch::is_symint_node(out)) {
    return out.cast<c10::SymIntNodeImpl*>()->toSymInt();
  }
  return c10::SymInt{py::cast<int64_t>(out)};
}
|
||||
|
||||
template <const char* func_name, typename... Ts>
|
||||
void concrete_trace_cuda(const c10::impl::PyInterpreter*, Ts... args) {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
|
Reference in New Issue
Block a user