Reland 2 of Merge more symbolic meta kernels and symint changes from branch (#86334) (#86488)

symintify split_with_sizes, dropout, fused_fake_obs_quant. meta for padding_2d ops

add meta_bernoulli_

meta kernel for at::gather

get pytorch_struct to pass: meta for scatter_add, fix backward

symintify split ops
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86488
Approved by: https://github.com/ezyang
This commit is contained in:
albanD
2022-10-10 08:44:51 -04:00
committed by PyTorch MergeBot
parent 55663b7f81
commit 978b46d7c9
17 changed files with 324 additions and 65 deletions

View File

@ -315,6 +315,7 @@ class PythonSymFloatNodeImpl : public c10::SymFloatNodeImpl {
}
SymIntNode ceil() override;
SymIntNode floor() override;
py::handle getPyObj() {
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
@ -342,6 +343,12 @@ SymIntNode PythonSymFloatNodeImpl::ceil() {
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
}
SymIntNode PythonSymFloatNodeImpl::floor() {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("floor")();
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
}
namespace {
using autograd::variable_list;
@ -1573,6 +1580,9 @@ void initJITBindings(PyObject* module) {
.def(
"__ceil__",
[](c10::SymFloatNode a) -> c10::SymIntNode { return a->ceil(); })
.def(
"__floor__",
[](c10::SymFloatNode a) -> c10::SymIntNode { return a->floor(); })
.def(
"get_pyobj",
[](c10::SymFloatNode a) -> py::object {