pytorch/torch/csrc/utils/python_symnode.h
Edward Z. Yang 1ff52225f1 Unify SymIntNode and SymFloatNode into SymNode (#87817)
This refactor was prompted by challenges handling mixed int/float
operations in C++. A previous version of this patch
(https://github.com/pytorch/pytorch/pull/87722/) added overloads for each
permutation of int/float and was unwieldy. This PR takes a different
approach.

The general outline of the patch is to combine the C++ types SymIntNode
and SymFloatNode into a single type, SymNode. This type is type-erased:
C++ code no longer knows statically whether it holds an int or a float,
and must test with the is_int()/is_float() virtual methods. This has a
number of knock-on effects.
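A minimal sketch of what type erasure means for mixed int/float arithmetic. The names NodeBase, ConstIntNode, ConstFloatNode and add are hypothetical and are not the c10::SymNodeImpl API; the point is only that a single node type answers is_int()/is_float() at runtime, so one code path can handle every int/float combination instead of one overload per permutation.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <stdexcept>

// Hypothetical type-erased node: one interface for both ints and floats,
// so the static type no longer says which kind of value is inside.
class NodeBase {
 public:
  virtual ~NodeBase() = default;
  virtual bool is_int() const = 0;
  virtual bool is_float() const = 0;
  virtual int64_t int_value() const = 0;   // only valid when is_int()
  virtual double float_value() const = 0;  // valid for both kinds
};

using Node = std::shared_ptr<NodeBase>;

class ConstIntNode : public NodeBase {
 public:
  explicit ConstIntNode(int64_t v) : v_(v) {}
  bool is_int() const override { return true; }
  bool is_float() const override { return false; }
  int64_t int_value() const override { return v_; }
  double float_value() const override { return static_cast<double>(v_); }

 private:
  int64_t v_;
};

class ConstFloatNode : public NodeBase {
 public:
  explicit ConstFloatNode(double v) : v_(v) {}
  bool is_int() const override { return false; }
  bool is_float() const override { return true; }
  int64_t int_value() const override {
    throw std::logic_error("int_value() called on a float node");
  }
  double float_value() const override { return v_; }

 private:
  double v_;
};

// One function covers int+int, int+float, float+int and float+float by
// testing the kind at runtime, instead of one overload per permutation.
// Promotion follows Python: any float operand makes the result a float.
Node add(const Node& a, const Node& b) {
  if (a->is_int() && b->is_int()) {
    return std::make_shared<ConstIntNode>(a->int_value() + b->int_value());
  }
  return std::make_shared<ConstFloatNode>(a->float_value() + b->float_value());
}
```

The cost of this design is that errors such as asking a float for an int value surface at runtime rather than at compile time.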

- We no longer have C++ classes to bind to Python. Instead, we take an
  entirely new approach to our Python API, where we have SymInt/SymFloat
  classes defined entirely in Python, which hold a SymNode (which corresponds
  to the C++ SymNode). However, SymNode is not pybind11-bound; instead,
  it lives as-is in Python, and is wrapped into a C++ SymNode using
  PythonSymNode when it crosses into C++. This implies a userland rename.

  In principle, it is also possible for the canonical implementation of SymNode
  to be written in C++, and then bound to Python with pybind11 (we have
  this code, although it is commented out). However, I did not do this,
  as we currently have no C++ implementations of SymNode.

  Because we do return SymInt/SymFloat from C++ bindings, the C++ binding
  code needs to know how to find these classes.  Currently, this is done
  just by manually importing torch and getting the attributes.

- Because SymInt/SymFloat are simple Python wrappers, __sym_dispatch__ now
  takes SymInt/SymFloat, rather than SymNode, bringing it in line with how
  __torch_dispatch__ works.
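The arrangement in the first bullet, where a user-facing class holds a node and a foreign (Python) implementation is wrapped in an adapter when it crosses into C++, can be sketched in plain C++. DynamicObject stands in for a Python object and its by-name method lookup; it, DynamicNodeAdapter and SymIntLike are hypothetical names, and the sketch ignores the GIL and reference counting that the real PythonSymNodeImpl deals with.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical stand-in for a Python object: methods are looked up by name
// at call time, much like getattr. Results are strings to keep it short.
struct DynamicObject {
  std::map<std::string, std::function<std::string()>> methods;

  std::string call(const std::string& name) const {
    auto it = methods.find(name);
    if (it == methods.end()) {
      throw std::runtime_error("no attribute: " + name);
    }
    return it->second();
  }
};

// Hypothetical C++-side node interface (a tiny subset of a real one).
class NodeBase {
 public:
  virtual ~NodeBase() = default;
  virtual bool is_int() = 0;
  virtual std::string str() = 0;
};

// Adapter in the spirit of PythonSymNodeImpl: the dynamic object remains the
// source of truth, and every virtual call is forwarded to it by name.
class DynamicNodeAdapter : public NodeBase {
 public:
  explicit DynamicNodeAdapter(std::shared_ptr<DynamicObject> obj)
      : obj_(std::move(obj)) {}
  bool is_int() override { return obj_->call("is_int") == "True"; }
  std::string str() override { return obj_->call("str"); }

 private:
  std::shared_ptr<DynamicObject> obj_;
};

// User-facing class: it holds a node rather than being one.
class SymIntLike {
 public:
  explicit SymIntLike(std::shared_ptr<NodeBase> node)
      : node_(std::move(node)) {}
  bool is_int() const { return node_->is_int(); }
  std::string str() const { return node_->str(); }

 private:
  std::shared_ptr<NodeBase> node_;
};
```

Because the adapter only forwards, the C++ side never needs its own copy of the node's state, which is the property the header's comment on PythonSymNodeImpl describes.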

Some miscellaneous improvements:

- SymInt now has a constructor that takes SymNode. Note that this
  constructor is ambiguous if you pass in a subclass of SymNode,
  so an explicit conversion to SymNode is necessary. This means
  toSymFloat/toSymInt are gone. It is also a mild optimization, as
  rvalue references now work automatically.

- We uniformly use the caster for c10::SymInt/SymFloat, rather than
  going the long way via SymIntNode/SymFloatNode.

- Removed some unnecessary toSymInt/toSymFloat calls in normalize_*
  functions; I'm pretty sure they didn't do anything.

- guard_int is now a free function, since a plain int has no guard_int
  method to call. A free function can handle both int and SymInt
  inputs.

- We clean up the magic method definition code for SymInt/SymFloat/SymNode.
  ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets
  plain methods; this is to help avoid confusion between the two types.
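Three of the points above (a user class constructed from a node, operators living only on the user class, and a guard_int that accepts both plain and symbolic ints) can be sketched together in plain C++. This is a C++ analogue under stated assumptions: the real guard_int and magic methods are in Python, and IntNodeBase, ConstNode and SymIntLike are hypothetical names. The constructor ambiguity mentioned for c10::SymInt is not reproduced here, since this sketch has only one constructor.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>

// Hypothetical node interface: plain, named methods only, no operators.
class IntNodeBase {
 public:
  virtual ~IntNodeBase() = default;
  virtual std::shared_ptr<IntNodeBase> add(
      const std::shared_ptr<IntNodeBase>& other) = 0;
  virtual int64_t guard_int() = 0;  // a real node would record a guard here
};

class ConstNode : public IntNodeBase {
 public:
  explicit ConstNode(int64_t v) : v_(v) {}
  std::shared_ptr<IntNodeBase> add(
      const std::shared_ptr<IntNodeBase>& other) override {
    return std::make_shared<ConstNode>(v_ + other->guard_int());
  }
  int64_t guard_int() override { return v_; }

 private:
  int64_t v_;
};

// User-facing class: the only place operators ("magic methods") live.
class SymIntLike {
 public:
  // Taking the node by value lets callers pass an rvalue that is moved in.
  explicit SymIntLike(std::shared_ptr<IntNodeBase> node)
      : node_(std::move(node)) {}
  SymIntLike operator+(const SymIntLike& other) const {
    return SymIntLike(node_->add(other.node_));
  }
  const std::shared_ptr<IntNodeBase>& node() const { return node_; }

 private:
  std::shared_ptr<IntNodeBase> node_;
};

// Free-function guard_int: works for plain ints (which have no methods)
// and for symbolic ints, so callers need not check which one they hold.
inline int64_t guard_int(int64_t v) { return v; }
inline int64_t guard_int(const SymIntLike& s) { return s.node()->guard_int(); }
```

Keeping operators off the node type means an expression like `a + b` can only be written on the user class, which is the confusion the last bullet aims to avoid.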

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87817
Approved by: https://github.com/albanD, https://github.com/anjali411
2022-10-27 20:56:02 +00:00


#pragma once

#include <c10/core/SafePyObject.h>
#include <c10/core/SymNodeImpl.h>

#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/pybind.h>

namespace torch {

TORCH_PYTHON_API py::handle get_symint_class();
TORCH_PYTHON_API py::handle get_symfloat_class();

// NB: These functions must not be called too early, otherwise torch is not
// set up yet. An alternate design is to have torch "register" the classes
// with us.
inline bool is_symint(py::handle obj) {
  return py::isinstance(obj, get_symint_class());
}

inline bool is_symfloat(py::handle obj) {
  return py::isinstance(obj, get_symfloat_class());
}

namespace impl {

// This c10::SymNodeImpl simply backends to a Python object that
// implements the API. The Python object is the source of truth;
// this is just an adapter so C++ calls can get to the object.
class PythonSymNodeImpl : public c10::SymNodeImpl {
 public:
  PythonSymNodeImpl(py::object pyobj) : c10::SymNodeImpl() {
    pyobj_ = std::make_shared<c10::SafePyObject>(
        pyobj.release().ptr(), getPyInterpreter());
  }

  c10::SymNode wrap_int(int64_t num) override {
    py::gil_scoped_acquire acquire;
    auto r = getPyObj().attr("wrap_int")(num);
    return c10::make_intrusive<PythonSymNodeImpl>(r);
  }

  c10::SymNode wrap_float(double num) override {
    py::gil_scoped_acquire acquire;
    auto r = getPyObj().attr("wrap_float")(num);
    return c10::make_intrusive<PythonSymNodeImpl>(r);
  }

  bool bool_() override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("bool_")().is(py::handle(Py_True));
  }

  bool is_int() override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("is_int")().is(py::handle(Py_True));
  }

  bool is_float() override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("is_float")().is(py::handle(Py_True));
  }

  int64_t guard_int(const char* file, int64_t line) override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("guard_int")(file, line).cast<int64_t>();
  }

  double guard_float(const char* file, int64_t line) override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("guard_float")(file, line).cast<double>();
  }

  int64_t int_() override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("int_")().cast<int64_t>();
  }

  std::string str() override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("str")().cast<std::string>();
  }

  c10::SymNode dispatch_common_(const char* fname, const c10::SymNode& other) {
    auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get());
    TORCH_CHECK(pother);
    py::gil_scoped_acquire acquire;
    auto r = getPyObj().attr(fname)(pother->getPyObj());
    return c10::make_intrusive<PythonSymNodeImpl>(r);
  }

  c10::SymNode dispatch_common_(const char* fname) {
    py::gil_scoped_acquire acquire;
    auto r = getPyObj().attr(fname)();
    return c10::make_intrusive<PythonSymNodeImpl>(r);
  }

  c10::SymNode add(const c10::SymNode& other) override {
    return dispatch_common_(__FUNCTION__, other);
  }

  c10::SymNode sub(const c10::SymNode& other) override {
    return dispatch_common_(__FUNCTION__, other);
  }

  c10::SymNode mul(const c10::SymNode& other) override {
    return dispatch_common_(__FUNCTION__, other);
  }

  c10::SymNode truediv(const c10::SymNode& other) override {
    return dispatch_common_(__FUNCTION__, other);
  }

  c10::SymNode pow(const c10::SymNode& other) override {
    return dispatch_common_(__FUNCTION__, other);
  }

  c10::SymNode floordiv(const c10::SymNode& other) override {
    return dispatch_common_(__FUNCTION__, other);
  }

  c10::SymNode mod(const c10::SymNode& other) override {
    return dispatch_common_(__FUNCTION__, other);
  }

  c10::SymNode eq(const c10::SymNode& other) override {
    return dispatch_common_(__FUNCTION__, other);
  }

  c10::SymNode gt(const c10::SymNode& other) override {
    return dispatch_common_(__FUNCTION__, other);
  }

  c10::SymNode lt(const c10::SymNode& other) override {
    return dispatch_common_(__FUNCTION__, other);
  }

  c10::SymNode le(const c10::SymNode& other) override {
    return dispatch_common_(__FUNCTION__, other);
  }

  c10::SymNode ge(const c10::SymNode& other) override {
    return dispatch_common_(__FUNCTION__, other);
  }

  c10::SymNode min(const c10::SymNode& other) override {
    return dispatch_common_(__FUNCTION__, other);
  }

  c10::SymNode max(const c10::SymNode& other) override {
    return dispatch_common_(__FUNCTION__, other);
  }

  c10::SymNode ceil() override {
    return dispatch_common_(__FUNCTION__);
  }

  c10::SymNode floor() override {
    return dispatch_common_(__FUNCTION__);
  }

  c10::SymNode neg() override {
    return dispatch_common_(__FUNCTION__);
  }

  c10::SymNode clone() override {
    return dispatch_common_(__FUNCTION__);
  }

  c10::SymNode sym_int() override {
    return dispatch_common_(__FUNCTION__);
  }

  c10::SymNode sym_float() override {
    return dispatch_common_(__FUNCTION__);
  }

  py::handle getPyObj() {
    return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
  }

  std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
};

} // namespace impl
} // namespace torch
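The arithmetic methods above all funnel into dispatch_common_, which uses the calling method's own name to pick the Python attribute and requires the other operand to be the same adapter type. A plain C++ sketch of that pattern, with hypothetical names (Backend, ForwardingNode) and a name-keyed map standing in for Python attribute lookup; it omits the GIL. It uses the standard `__func__` rather than `__FUNCTION__`; both hold the bare method name (for example "add") on GCC and Clang, though the standard leaves the exact string implementation-defined.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

class NodeBase {
 public:
  virtual ~NodeBase() = default;
  virtual std::shared_ptr<NodeBase> add(const std::shared_ptr<NodeBase>& other) = 0;
  virtual std::shared_ptr<NodeBase> mul(const std::shared_ptr<NodeBase>& other) = 0;
};

// Hypothetical dynamic backend: binary operations looked up by name.
struct Backend {
  std::map<std::string, std::function<int64_t(int64_t, int64_t)>> binary_methods;
};

class ForwardingNode : public NodeBase {
 public:
  ForwardingNode(std::shared_ptr<Backend> backend, int64_t value)
      : backend_(std::move(backend)), value_(value) {}

  // Each override passes its own name, so adding an operation is one line.
  std::shared_ptr<NodeBase> add(const std::shared_ptr<NodeBase>& other) override {
    return dispatch_common_(__func__, other);
  }
  std::shared_ptr<NodeBase> mul(const std::shared_ptr<NodeBase>& other) override {
    return dispatch_common_(__func__, other);
  }
  int64_t value() const { return value_; }

 private:
  std::shared_ptr<NodeBase> dispatch_common_(
      const char* fname, const std::shared_ptr<NodeBase>& other) {
    // Like the TORCH_CHECK above: both operands must be this adapter type.
    auto* pother = dynamic_cast<ForwardingNode*>(other.get());
    if (pother == nullptr) {
      throw std::invalid_argument("operands are different node types");
    }
    // Look up the operation by name and wrap the result in a new node.
    const auto& fn = backend_->binary_methods.at(fname);
    return std::make_shared<ForwardingNode>(backend_, fn(value_, pother->value_));
  }

  std::shared_ptr<Backend> backend_;
  int64_t value_;
};
```

Tying the dispatched name to the C++ method name keeps the two sides in sync by construction, at the price of depending on what the compiler puts in the predefined function-name identifier.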