Merge more symbolic meta kernels and symint changes from branch (#86334)

symintify split_with_sizes, dropout, fused_fake_obs_quant. meta for padding_2d ops

add meta_bernoulli_

meta kernel for at::gather

get pytorch_struct to pass: meta for scatter_add, fix backward

symintify split ops
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86334
Approved by: https://github.com/ezyang
This commit is contained in:
Brian Hirsh
2022-10-06 13:25:05 -07:00
committed by PyTorch MergeBot
parent 3af0eafea6
commit 08e3999fa4
17 changed files with 324 additions and 65 deletions

View File

@ -315,6 +315,7 @@ class PythonSymFloatNodeImpl : public c10::SymFloatNodeImpl {
}
SymIntNode ceil() override;
SymIntNode floor() override;
py::handle getPyObj() {
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
@ -342,6 +343,12 @@ SymIntNode PythonSymFloatNodeImpl::ceil() {
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
}
SymIntNode PythonSymFloatNodeImpl::floor() {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("floor")();
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
}
namespace {
using autograd::variable_list;
@ -1573,6 +1580,9 @@ void initJITBindings(PyObject* module) {
.def(
"__ceil__",
[](c10::SymFloatNode a) -> c10::SymIntNode { return a->ceil(); })
.def(
"__floor__",
[](c10::SymFloatNode a) -> c10::SymIntNode { return a->floor(); })
.def(
"get_pyobj",
[](c10::SymFloatNode a) -> py::object {