[jit] Remove graph() call from abstract Function interface. (#65967)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65967

Graph is an implementation detail. If user wants to get access to the
underlying graph, they should be able to explicitly dynamic cast instead.
ghstack-source-id: 141659819

Test Plan: no behavior change.

Reviewed By: gmagogsfm

Differential Revision: D31326153

fbshipit-source-id: a0e984f57c6013494b92a7095bf5bb660035eb84
This commit is contained in:
Zhengxu Chen
2021-10-27 11:52:48 -07:00
committed by Facebook GitHub Bot
parent 7c48b9ee25
commit b55a2500d2
43 changed files with 324 additions and 261 deletions

View File

@ -114,7 +114,7 @@ FunctionSchema PythonValue::getSchema(
std::shared_ptr<SugaredValue> PythonValue::call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) {
@ -168,7 +168,7 @@ std::string PythonValue::kind() const {
std::vector<std::shared_ptr<SugaredValue>> PythonValue::asTuple(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const c10::optional<size_t>& size_hint) {
const std::string type_str = typeString(self);
std::stringstream ss;
@ -179,7 +179,7 @@ std::vector<std::shared_ptr<SugaredValue>> PythonValue::asTuple(
std::shared_ptr<SugaredValue> PythonValue::attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
const std::string type_str = typeString(self);
std::stringstream ss;
@ -208,7 +208,7 @@ void PythonValue::checkForAddToConstantsError(std::stringstream& ss) {
std::shared_ptr<SugaredValue> PythonModuleValue::attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
py::object member = getattr(loc, field);
// note: is_constant = true because we consider that global properties
@ -220,7 +220,7 @@ std::shared_ptr<SugaredValue> PythonModuleValue::attr(
#if !defined(USE_ROCM)
std::shared_ptr<SugaredValue> CUDAPythonModuleValue::attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
// List of all the cuda operators which are supported in JIT
const std::unordered_set<std::string> cuda_ops = {
@ -259,11 +259,13 @@ std::shared_ptr<SugaredValue> CUDAPythonModuleValue::attr(
}
#endif
Value* ModuleValue::asValue(const SourceRange& loc, Function& m) {
Value* ModuleValue::asValue(const SourceRange& loc, GraphFunction& m) {
return self_;
}
SugaredValuePtr ModuleValue::asTupleValue(const SourceRange& loc, Function& m) {
SugaredValuePtr ModuleValue::asTupleValue(
const SourceRange& loc,
GraphFunction& m) {
if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
auto dict = getSugaredDict(loc, m);
auto mods = dict->getModules();
@ -298,7 +300,7 @@ bool ModuleValue::areAllSubmodulesSubtypeOf(
SugaredValuePtr ModuleValue::getitem(
const SourceRange& loc,
Function& m,
GraphFunction& m,
Value* idx,
TypePtr type_hint) {
if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
@ -365,7 +367,7 @@ SugaredValuePtr ModuleValue::getitem(
void checkInterface(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::shared_ptr<ModuleValue>& self,
const std::string& field) {
if (self->asValue(loc, m)->type()->cast<InterfaceType>()) {
@ -377,7 +379,7 @@ void checkInterface(
void recurseThroughNestedModules(
const SourceRange& loc,
Function& m,
GraphFunction& m,
std::vector<SugaredValuePtr>& keys,
std::vector<SugaredValuePtr>& values,
std::shared_ptr<ModuleValue>& self,
@ -413,7 +415,7 @@ void recurseThroughNestedModules(
std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedBufferDict(
const SourceRange& loc,
Function& m) {
GraphFunction& m) {
std::vector<std::string> paramNames;
std::vector<SugaredValuePtr> values;
@ -441,7 +443,7 @@ std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedBufferDict(
std::shared_ptr<SugaredDict> ModuleValue::getSugaredDict(
const SourceRange& loc,
Function& m) {
GraphFunction& m) {
std::vector<std::string> submoduleNames;
const auto& selfType = concreteType_->getJitType()->expect<ClassType>();
for (size_t i = 0; i < selfType->numAttributes(); ++i) {
@ -472,7 +474,7 @@ std::shared_ptr<SugaredDict> ModuleValue::getSugaredDict(
std::shared_ptr<SugaredValue> SugaredDict::attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
// Recursive compilation does not maintain module aliasing,
// so we do not add uniqueness checks on
@ -508,7 +510,7 @@ std::shared_ptr<SugaredValue> SugaredDict::attr(
std::shared_ptr<SugaredEnumClass> createSugaredEnumClassFromObj(
const py::object& obj,
Function& m,
GraphFunction& m,
const SourceRange& loc) {
auto annotation_type = py::module::import("torch.jit.annotations")
.attr("try_ann_to_type")(obj, loc);
@ -521,7 +523,7 @@ std::shared_ptr<SugaredEnumClass> createSugaredEnumClassFromObj(
// helper function for instantiating a SugaredValue from an IValue
std::shared_ptr<SugaredValue> toSugaredValue(
const IValue& v,
Function& m,
GraphFunction& m,
const SourceRange& loc) {
if (v.isTuple()) {
auto tp = v.toTuple();
@ -540,7 +542,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
// This method controls how we desugar attribute lookups on ScriptModules
std::shared_ptr<SugaredValue> ModuleValue::tryGetAttr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
// 1. Look inside Module object for the field.
const auto& selfType_ = concreteType_->getJitType();
@ -661,14 +663,14 @@ std::shared_ptr<SugaredValue> ModuleValue::tryGetAttr(
bool ModuleValue::hasAttr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
return tryGetAttr(loc, m, field) != nullptr;
}
std::shared_ptr<SugaredValue> ModuleValue::call(
const SourceRange& loc,
Function& caller,
GraphFunction& caller,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) {
@ -759,7 +761,7 @@ std::shared_ptr<SugaredValue> ModuleValue::call(
// This method controls how we desugar attribute lookups on ScriptModules.
std::shared_ptr<SugaredValue> ModuleValue::attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
if (auto attr = tryGetAttr(loc, m, field)) {
return attr;
@ -788,7 +790,7 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
<< " has no attribute '" << field << "' " << hint;
}
SugaredValuePtr ModuleValue::iter(const SourceRange& loc, Function& m) {
SugaredValuePtr ModuleValue::iter(const SourceRange& loc, GraphFunction& m) {
const auto iterableModuleKind = concreteType_->getIterableModuleKind();
if (iterableModuleKind == IterableModuleKind::NONE) {
throw ErrorReport(loc)
@ -807,7 +809,7 @@ SugaredValuePtr ModuleValue::iter(const SourceRange& loc, Function& m) {
std::shared_ptr<SugaredValue> PythonClassValue::attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
// Resolve values from the Python object first (e.g. for static methods on
// this type, resolve them as functions)
@ -824,7 +826,7 @@ std::shared_ptr<SugaredValue> PythonClassValue::attr(
bool PythonClassValue::hasAttr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
try {
py::getattr(py_type_, field.c_str());
@ -836,7 +838,7 @@ bool PythonClassValue::hasAttr(
void ModuleValue::setAttr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field,
Value* newValue) {
// Forward to SimpleValue::setAttr
@ -846,7 +848,7 @@ void ModuleValue::setAttr(
std::shared_ptr<SugaredValue> BooleanDispatchValue::call(
const SourceRange& loc,
Function& caller,
GraphFunction& caller,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) {
@ -888,7 +890,7 @@ std::shared_ptr<SugaredValue> BooleanDispatchValue::call(
std::shared_ptr<SugaredValue> PythonExceptionValue::call(
const SourceRange& loc,
Function& caller,
GraphFunction& caller,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t /*n_binders*/) {
@ -984,7 +986,7 @@ bool isEnumClass(py::object obj) {
std::shared_ptr<SugaredValue> createSimpleEnumValue(
const py::object& obj,
Function& m,
GraphFunction& m,
const SourceRange& loc) {
auto enum_class = obj.attr("__class__");
auto enum_type =
@ -996,7 +998,7 @@ std::shared_ptr<SugaredValue> createSimpleEnumValue(
std::shared_ptr<SugaredValue> PythonSliceClass::call(
const SourceRange& loc,
Function& caller,
GraphFunction& caller,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t /*n_binders*/) {
@ -1046,7 +1048,7 @@ std::shared_ptr<SugaredValue> PythonSliceClass::call(
std::shared_ptr<SugaredValue> toSugaredValue(
py::object obj,
Function& m,
GraphFunction& m,
const SourceRange& loc,
bool is_constant) {
// directly create SimpleValues when possible, because they are first-class