mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Merge more symbolic meta kernels and symint changes from branch (#86334)
symintify split_with_sizes, dropout, fused_fake_obs_quant. meta for padding_2d ops add meta_bernoulli_ meta kernel for at::gather get pytorch_struct to pass: meta for scatter_add, fix backward symintify split ops Pull Request resolved: https://github.com/pytorch/pytorch/pull/86334 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
3af0eafea6
commit
08e3999fa4
@ -315,6 +315,7 @@ class PythonSymFloatNodeImpl : public c10::SymFloatNodeImpl {
|
||||
}
|
||||
|
||||
SymIntNode ceil() override;
|
||||
SymIntNode floor() override;
|
||||
|
||||
py::handle getPyObj() {
|
||||
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
|
||||
@ -342,6 +343,12 @@ SymIntNode PythonSymFloatNodeImpl::ceil() {
|
||||
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
|
||||
}
|
||||
|
||||
SymIntNode PythonSymFloatNodeImpl::floor() {
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr("floor")();
|
||||
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
using autograd::variable_list;
|
||||
@ -1573,6 +1580,9 @@ void initJITBindings(PyObject* module) {
|
||||
.def(
|
||||
"__ceil__",
|
||||
[](c10::SymFloatNode a) -> c10::SymIntNode { return a->ceil(); })
|
||||
.def(
|
||||
"__floor__",
|
||||
[](c10::SymFloatNode a) -> c10::SymIntNode { return a->floor(); })
|
||||
.def(
|
||||
"get_pyobj",
|
||||
[](c10::SymFloatNode a) -> py::object {
|
||||
|
Reference in New Issue
Block a user