mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 06:11:27 +08:00
[jit] Remove graph() call from abstract Function interface. (#65967)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65967 Graph is an implementation detail. If user wants to get access to the underlying graph, they should be able to explicitly dynamic cast instead. ghstack-source-id: 141659819 Test Plan: no behavior change. Reviewed By: gmagogsfm Differential Revision: D31326153 fbshipit-source-id: a0e984f57c6013494b92a7095bf5bb660035eb84
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7c48b9ee25
commit
b55a2500d2
@ -114,7 +114,7 @@ FunctionSchema PythonValue::getSchema(
|
||||
|
||||
std::shared_ptr<SugaredValue> PythonValue::call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) {
|
||||
@ -168,7 +168,7 @@ std::string PythonValue::kind() const {
|
||||
|
||||
std::vector<std::shared_ptr<SugaredValue>> PythonValue::asTuple(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const c10::optional<size_t>& size_hint) {
|
||||
const std::string type_str = typeString(self);
|
||||
std::stringstream ss;
|
||||
@ -179,7 +179,7 @@ std::vector<std::shared_ptr<SugaredValue>> PythonValue::asTuple(
|
||||
|
||||
std::shared_ptr<SugaredValue> PythonValue::attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
const std::string type_str = typeString(self);
|
||||
std::stringstream ss;
|
||||
@ -208,7 +208,7 @@ void PythonValue::checkForAddToConstantsError(std::stringstream& ss) {
|
||||
|
||||
std::shared_ptr<SugaredValue> PythonModuleValue::attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
py::object member = getattr(loc, field);
|
||||
// note: is_constant = true because we consider that global properties
|
||||
@ -220,7 +220,7 @@ std::shared_ptr<SugaredValue> PythonModuleValue::attr(
|
||||
#if !defined(USE_ROCM)
|
||||
std::shared_ptr<SugaredValue> CUDAPythonModuleValue::attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
// List of all the cuda operators which are supported in JIT
|
||||
const std::unordered_set<std::string> cuda_ops = {
|
||||
@ -259,11 +259,13 @@ std::shared_ptr<SugaredValue> CUDAPythonModuleValue::attr(
|
||||
}
|
||||
#endif
|
||||
|
||||
Value* ModuleValue::asValue(const SourceRange& loc, Function& m) {
|
||||
Value* ModuleValue::asValue(const SourceRange& loc, GraphFunction& m) {
|
||||
return self_;
|
||||
}
|
||||
|
||||
SugaredValuePtr ModuleValue::asTupleValue(const SourceRange& loc, Function& m) {
|
||||
SugaredValuePtr ModuleValue::asTupleValue(
|
||||
const SourceRange& loc,
|
||||
GraphFunction& m) {
|
||||
if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
|
||||
auto dict = getSugaredDict(loc, m);
|
||||
auto mods = dict->getModules();
|
||||
@ -298,7 +300,7 @@ bool ModuleValue::areAllSubmodulesSubtypeOf(
|
||||
|
||||
SugaredValuePtr ModuleValue::getitem(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
Value* idx,
|
||||
TypePtr type_hint) {
|
||||
if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
|
||||
@ -365,7 +367,7 @@ SugaredValuePtr ModuleValue::getitem(
|
||||
|
||||
void checkInterface(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::shared_ptr<ModuleValue>& self,
|
||||
const std::string& field) {
|
||||
if (self->asValue(loc, m)->type()->cast<InterfaceType>()) {
|
||||
@ -377,7 +379,7 @@ void checkInterface(
|
||||
|
||||
void recurseThroughNestedModules(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
std::vector<SugaredValuePtr>& keys,
|
||||
std::vector<SugaredValuePtr>& values,
|
||||
std::shared_ptr<ModuleValue>& self,
|
||||
@ -413,7 +415,7 @@ void recurseThroughNestedModules(
|
||||
|
||||
std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedBufferDict(
|
||||
const SourceRange& loc,
|
||||
Function& m) {
|
||||
GraphFunction& m) {
|
||||
std::vector<std::string> paramNames;
|
||||
std::vector<SugaredValuePtr> values;
|
||||
|
||||
@ -441,7 +443,7 @@ std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedBufferDict(
|
||||
|
||||
std::shared_ptr<SugaredDict> ModuleValue::getSugaredDict(
|
||||
const SourceRange& loc,
|
||||
Function& m) {
|
||||
GraphFunction& m) {
|
||||
std::vector<std::string> submoduleNames;
|
||||
const auto& selfType = concreteType_->getJitType()->expect<ClassType>();
|
||||
for (size_t i = 0; i < selfType->numAttributes(); ++i) {
|
||||
@ -472,7 +474,7 @@ std::shared_ptr<SugaredDict> ModuleValue::getSugaredDict(
|
||||
|
||||
std::shared_ptr<SugaredValue> SugaredDict::attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
// Recursive compilation does not maintain module aliasing,
|
||||
// so we do not add uniqueness checks on
|
||||
@ -508,7 +510,7 @@ std::shared_ptr<SugaredValue> SugaredDict::attr(
|
||||
|
||||
std::shared_ptr<SugaredEnumClass> createSugaredEnumClassFromObj(
|
||||
const py::object& obj,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SourceRange& loc) {
|
||||
auto annotation_type = py::module::import("torch.jit.annotations")
|
||||
.attr("try_ann_to_type")(obj, loc);
|
||||
@ -521,7 +523,7 @@ std::shared_ptr<SugaredEnumClass> createSugaredEnumClassFromObj(
|
||||
// helper function for instantiating a SugaredValue from an IValue
|
||||
std::shared_ptr<SugaredValue> toSugaredValue(
|
||||
const IValue& v,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SourceRange& loc) {
|
||||
if (v.isTuple()) {
|
||||
auto tp = v.toTuple();
|
||||
@ -540,7 +542,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
|
||||
// This method controls how we desugar attribute lookups on ScriptModules
|
||||
std::shared_ptr<SugaredValue> ModuleValue::tryGetAttr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
// 1. Look inside Module object for the field.
|
||||
const auto& selfType_ = concreteType_->getJitType();
|
||||
@ -661,14 +663,14 @@ std::shared_ptr<SugaredValue> ModuleValue::tryGetAttr(
|
||||
|
||||
bool ModuleValue::hasAttr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
return tryGetAttr(loc, m, field) != nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<SugaredValue> ModuleValue::call(
|
||||
const SourceRange& loc,
|
||||
Function& caller,
|
||||
GraphFunction& caller,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) {
|
||||
@ -759,7 +761,7 @@ std::shared_ptr<SugaredValue> ModuleValue::call(
|
||||
// This method controls how we desugar attribute lookups on ScriptModules.
|
||||
std::shared_ptr<SugaredValue> ModuleValue::attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
if (auto attr = tryGetAttr(loc, m, field)) {
|
||||
return attr;
|
||||
@ -788,7 +790,7 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
|
||||
<< " has no attribute '" << field << "' " << hint;
|
||||
}
|
||||
|
||||
SugaredValuePtr ModuleValue::iter(const SourceRange& loc, Function& m) {
|
||||
SugaredValuePtr ModuleValue::iter(const SourceRange& loc, GraphFunction& m) {
|
||||
const auto iterableModuleKind = concreteType_->getIterableModuleKind();
|
||||
if (iterableModuleKind == IterableModuleKind::NONE) {
|
||||
throw ErrorReport(loc)
|
||||
@ -807,7 +809,7 @@ SugaredValuePtr ModuleValue::iter(const SourceRange& loc, Function& m) {
|
||||
|
||||
std::shared_ptr<SugaredValue> PythonClassValue::attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
// Resolve values from the Python object first (e.g. for static methods on
|
||||
// this type, resolve them as functions)
|
||||
@ -824,7 +826,7 @@ std::shared_ptr<SugaredValue> PythonClassValue::attr(
|
||||
|
||||
bool PythonClassValue::hasAttr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
try {
|
||||
py::getattr(py_type_, field.c_str());
|
||||
@ -836,7 +838,7 @@ bool PythonClassValue::hasAttr(
|
||||
|
||||
void ModuleValue::setAttr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field,
|
||||
Value* newValue) {
|
||||
// Forward to SimpleValue::setAttr
|
||||
@ -846,7 +848,7 @@ void ModuleValue::setAttr(
|
||||
|
||||
std::shared_ptr<SugaredValue> BooleanDispatchValue::call(
|
||||
const SourceRange& loc,
|
||||
Function& caller,
|
||||
GraphFunction& caller,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) {
|
||||
@ -888,7 +890,7 @@ std::shared_ptr<SugaredValue> BooleanDispatchValue::call(
|
||||
|
||||
std::shared_ptr<SugaredValue> PythonExceptionValue::call(
|
||||
const SourceRange& loc,
|
||||
Function& caller,
|
||||
GraphFunction& caller,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t /*n_binders*/) {
|
||||
@ -984,7 +986,7 @@ bool isEnumClass(py::object obj) {
|
||||
|
||||
std::shared_ptr<SugaredValue> createSimpleEnumValue(
|
||||
const py::object& obj,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SourceRange& loc) {
|
||||
auto enum_class = obj.attr("__class__");
|
||||
auto enum_type =
|
||||
@ -996,7 +998,7 @@ std::shared_ptr<SugaredValue> createSimpleEnumValue(
|
||||
|
||||
std::shared_ptr<SugaredValue> PythonSliceClass::call(
|
||||
const SourceRange& loc,
|
||||
Function& caller,
|
||||
GraphFunction& caller,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t /*n_binders*/) {
|
||||
@ -1046,7 +1048,7 @@ std::shared_ptr<SugaredValue> PythonSliceClass::call(
|
||||
|
||||
std::shared_ptr<SugaredValue> toSugaredValue(
|
||||
py::object obj,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SourceRange& loc,
|
||||
bool is_constant) {
|
||||
// directly create SimpleValues when possible, because they are first-class
|
||||
|
Reference in New Issue
Block a user