Files
pytorch/torch/csrc/jit/script/python_sugared_value.cpp
Wanchao Liang e870b11ae6 Revert D15275731: Remote SourceLocation
Differential Revision:
D15275731

Original commit changeset: f4da178c3137

fbshipit-source-id: 830b79735eb2dadc4795b5aae407826bf20ef121
2019-05-09 13:07:11 -07:00

463 lines
16 KiB
C++

#include <torch/csrc/jit/script/python_sugared_value.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/jit/script/module_python.h>
#include <torch/csrc/jit/script/schema_matching.h>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>
namespace torch {
namespace jit {
namespace script {
// Returns the `__name__` of the Python type of `h` as a C++ string.
std::string typeString(py::handle h) {
  const py::str type_name(h.get_type().attr("__name__"));
  return type_name;
}
// Converts `obj` to a script Function if it wraps one; otherwise returns
// nullptr so callers can use the result directly in a condition.
std::shared_ptr<Function> as_function(const py::object& obj) {
  if (!py::isinstance<Function>(obj)) {
    return nullptr;
  }
  return py::cast<std::shared_ptr<Function>>(obj);
}
// Builds a FunctionSchema for the wrapped Python callable.
//
// If torch.jit.annotations.get_signature() yields a signature, its argument
// and return types are used verbatim. Otherwise a default schema is made up:
// every argument is a Tensor, and the return type is None, Tensor, or a tuple
// of Tensors depending on `n_binders`.
//
// n_args:    number of arguments at the call site; used as the arity only
//            when Python introspection cannot provide one.
// n_binders: number of values the call site binds the result to.
FunctionSchema PythonValue::getSchema(
    const size_t n_args,
    const size_t n_binders) {
  auto annotations = py::module::import("torch.jit.annotations");
  auto signature = annotations.attr("get_signature")(self);
  std::vector<Argument> args, rets;

  if (signature.is_none()) {
    // No type annotations available: fall back to a default signature.
    // Prefer the parameter count reported by Python introspection over the
    // call-site count; a mismatch is reported later by the schema matching
    // done in call().
    size_t param_count = n_args;
    auto num_params = annotations.attr("get_num_params")(self);
    if (!num_params.is_none()) {
      param_count = py::cast<size_t>(num_params);
    }

    // Arguments are named after their position and typed as Tensor.
    args.reserve(param_count);
    for (size_t pos = 0; pos < param_count; ++pos) {
      args.push_back(
          Argument(std::to_string(pos), TensorType::get(), {}, {}, false));
    }

    // Zero binders -> None, one binder -> Tensor, several -> tuple of Tensors.
    TypePtr ret_type = TensorType::get();
    if (n_binders == 0) {
      ret_type = NoneType::get();
    } else if (n_binders > 1) {
      std::vector<TypePtr> element_types(n_binders, ret_type);
      ret_type = TupleType::create(std::move(element_types));
    }
    rets.push_back(Argument("0", ret_type, {}, {}, false));
  } else {
    // An explicit signature exists: (argument types, return type).
    auto typed_signature =
        py::cast<std::pair<std::vector<TypePtr>, TypePtr>>(signature);
    std::vector<TypePtr>& arg_types = typed_signature.first;

    args.reserve(arg_types.size());
    size_t pos = 0; // argument names are faked from the position
    for (auto& arg_type : arg_types) {
      args.push_back(
          Argument(std::to_string(pos), std::move(arg_type), {}, {}, false));
      ++pos;
    }
    rets.push_back(
        Argument("0", std::move(typed_signature.second), {}, {}, false));
  }

  return FunctionSchema("", "", std::move(args), std::move(rets));
}
std::shared_ptr<SugaredValue> PythonValue::call(
const SourceRange& loc,
Function& m,
at::ArrayRef<NamedValue> inputs_,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) {
auto inputs = toValues(*m.graph(), inputs_);
auto schema = getSchema(inputs.size(), n_binders);
std::stringstream failure_messages;
c10::optional<MatchedSchema> matched_schema = tryMatchSchema(
schema,
loc,
*m.graph(),
c10::nullopt,
inputs_,
attributes,
failure_messages,
/*conv_tensor_to_num*/ true);
if (!matched_schema)
throw ErrorReport(loc) << failure_messages.str();
// Release the function object so we can wrap it in a PythonOp
py::object func = self;
std::string cconv(inputs.size(), 'd');
Node* new_node = m.graph()->insertNode(
m.graph()->createPythonOp(THPObjectPtr(func.release().ptr()), cconv, {}));
// Mark if function is ignored on export
if (py::cast<bool>(
py::module::import("torch.jit").attr("_try_get_ignored_op")(self))) {
auto python_op = static_cast<PythonOp*>(new_node);
python_op->ignore_on_export = true;
}
new_node->setSourceLocation(std::make_shared<SourceRange>(loc));
for (auto& i : matched_schema->inputs)
new_node->addInput(i);
Value* output =
new_node->addOutput()->setType(matched_schema->return_types.at(0));
return std::make_shared<SimpleValue>(output);
}
// Human-readable description of this value, used in error messages.
std::string PythonValue::kind() const {
  return "python value of type '" + typeString(self) + "'";
}
// A generic Python value can never be unpacked as a tuple, so this always
// throws an ErrorReport at `loc`. The message names the value's kind and,
// for nn.ModuleList / nn.Sequential, appends a hint about __constants__.
// `m` and `size_hint` are unused.
std::vector<std::shared_ptr<SugaredValue>> PythonValue::asTuple(
    const SourceRange& loc,
    Function& m,
    const c10::optional<size_t>& size_hint) {
  // kind() already embeds the Python type name, so no separate lookup of the
  // type string is needed here.
  std::stringstream ss;
  ss << kind() << " cannot be used as a tuple";
  checkForAddToConstantsError(ss);
  throw ErrorReport(loc) << ss.str();
}
// Attribute lookup is not supported on a generic Python value, so this
// always throws an ErrorReport at `loc`. The message names the value's kind
// and, for nn.ModuleList / nn.Sequential, appends a hint about
// __constants__. `m` and `field` are unused.
std::shared_ptr<SugaredValue> PythonValue::attr(
    const SourceRange& loc,
    Function& m,
    const std::string& field) {
  // kind() already embeds the Python type name, so no separate lookup of the
  // type string is needed here.
  std::stringstream ss;
  ss << "attribute lookup is not defined on " << kind();
  checkForAddToConstantsError(ss);
  throw ErrorReport(loc) << ss.str();
}
// Fetches attribute `name` from the wrapped Python object. Any Python error
// raised by the lookup is reported as an ErrorReport at `loc`.
py::object PythonValue::getattr(
    const SourceRange& loc,
    const std::string& name) {
  try {
    return py::getattr(self, name.c_str());
  } catch (const py::error_already_set&) {
    throw ErrorReport(loc) << "object has no attribute " << name;
  }
}
// Appends a "forgot __constants__?" hint to `ss` when the wrapped object is
// an nn.ModuleList or nn.Sequential; otherwise leaves `ss` untouched.
void PythonValue::checkForAddToConstantsError(std::stringstream& ss) {
  const auto nn = py::module::import("torch.nn");
  const bool is_module_container =
      py::isinstance(self, nn.attr("ModuleList")) ||
      py::isinstance(self, nn.attr("Sequential"));
  if (!is_module_container) {
    return;
  }
  ss << ". Did you forget to add it to __constants__? ";
}
std::shared_ptr<SugaredValue> PythonModuleValue::attr(
const SourceRange& loc,
Function& m,
const std::string& field) {
py::object member = getattr(loc, field);
// note: is_constant = true because we consider that global properties
// on modules like math.pi or torch.float to be constants
// eventhough it is possible, though rare, for someone to mutate them
return toSugaredValue(member, m, loc, /*is_constant=*/true);
}
// Converts each element of the wrapped Python tuple into a SugaredValue,
// treating every element as a constant. `size_hint` is unused.
std::vector<std::shared_ptr<SugaredValue>> ConstantPythonTupleValue::asTuple(
    const SourceRange& loc,
    Function& m,
    const c10::optional<size_t>& size_hint) {
  py::tuple elements = self;

  std::vector<std::shared_ptr<SugaredValue>> sugared;
  sugared.reserve(elements.size());
  for (py::handle element : elements) {
    sugared.push_back(toSugaredValue(
        py::reinterpret_borrow<py::object>(element),
        m,
        loc,
        /*is_constant=*/true));
  }
  return sugared;
}
// Materializes the constant Python tuple as a tuple-construct node in the
// graph of `m` and returns that node's output.
Value* ConstantPythonTupleValue::asValue(const SourceRange& loc, Function& m) {
  // Convert all elements to sugared values first, then lower each to a Value.
  const auto items = asTuple(loc, m);

  std::vector<Value*> element_values;
  element_values.reserve(items.size());
  for (const auto& item : items) {
    element_values.push_back(item->asValue(loc, m));
  }

  auto tuple_node = m.graph()->createTuple(element_values);
  return m.graph()->insertNode(tuple_node)->output();
}
std::shared_ptr<SugaredValue> OverloadedMethodValue::call(
const SourceRange& loc,
Function& caller,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) {
std::stringstream err;
std::vector<NamedValue> new_inputs = inputs.vec();
new_inputs.insert(new_inputs.begin(), module_);
for (const std::string& method_name : method_names_) {
auto cls = module_->type()->expect<ClassType>();
std::shared_ptr<Function> fn = cls->getMethod(method_name);
auto match = tryMatchSchema(
fn->getSchema(),
loc,
*caller.graph().get(),
c10::nullopt,
new_inputs,
attributes,
err,
true);
if (match) {
return MethodValue(module_, fn)
.call(loc, caller, inputs, attributes, n_binders);
}
}
throw ErrorReport(loc) << "Could not find any matching overloads\n"
<< err.str();
}
// Resolves `self.<field>` on a script module inside a script method.
//
// Lookup order:
//   1. "training": loaded from a buffer (registered on demand) and cast to
//      bool.
//   2. Submodules registered on the script module.
//   3. Methods, parameters, attributes, and buffers of the script module.
//   4. Overloaded methods listed in the Python module's `_overloads`.
//   5. Attributes of the Python module: parameter-list functions, plain
//      functions, nn.Modules, and names listed in `_constants_set`.
//
// Throws an ErrorReport at `loc` if the Python attribute exists but is not
// usable from script, or if no attribute named `field` exists at all.
std::shared_ptr<SugaredValue> ModuleValue::attr(
    const SourceRange& loc,
    Function& m,
    const std::string& field) {
  // workaround to make self.training work
  // it adds a buffer 'training' to the model if one doesn't exist
  // and then loads that parameter, casting it to bool
  if (field == "training") {
    if (!module_->find_buffer(field)) {
      bool training = py::cast<bool>(py::getattr(py_module_, "training"));
      auto t =
          autograd::make_variable(at::full({}, training ? 1 : 0, at::kLong));
      module_->register_buffer("training", std::move(t));
    }
    Value* the_tensor = m.graph()->insertGetAttr(self_, "training");
    Value* the_bool = m.graph()->insert(prim::Bool, {the_tensor});
    return std::make_shared<SimpleValue>(the_bool);
  }

  if (std::shared_ptr<Module> v = module_->find_module(field)) {
    return std::make_shared<ModuleValue>(
        m.graph()->insertGetAttr(self_, field),
        v,
        py_module_.attr(field.c_str()));
  } else if (auto kind = module_->kind_of(field)) {
    // methods, parameters, attributes, and buffers are all first class
    return SimpleValue(self_).attr(loc, m, field);
  }

  // This can also be a call to a non-script module, or a plain
  // python method. If so return this as a python value.
  py::object overloads =
      py_module_.attr("_overloads").attr("get")(field, py::none());
  if (!overloads.is_none()) {
    return std::make_shared<OverloadedMethodValue>(
        self_, py::cast<std::vector<std::string>>(overloads));
  }

  // Test for existence explicitly. Fetching with a `None` default and then
  // testing the resulting py::object for truthiness does not work: an object
  // holding `None` is still a non-null handle, so a missing attribute would
  // be indistinguishable from one whose value is `None`.
  if (py::hasattr(py_module_, field.c_str())) {
    py::object attr = py::getattr(py_module_, field.c_str());
    if (py::isinstance<py::function>(attr) &&
        py::hasattr(attr, "_parameter_names_fn")) {
      // Fetch the names of the parameters in the list so they're in the
      // right order
      auto fn_self = py::getattr(attr, "__self__");
      auto param_names = py::getattr(attr, "_parameter_names_fn")(fn_self);
      Graph& g = *m.graph();
      // Load every named parameter from the module and pack them in a tuple
      std::vector<Value*> params;
      for (auto name : param_names) {
        params.emplace_back(g.insertGetAttr(self_, py::str(name)));
      }
      auto list = g.insertNode(g.createTuple(params))->output();
      return std::make_shared<ConstantParameterList>(list);
    }
    if (py::isinstance<py::function>(attr) ||
        py::isinstance(attr, py::module::import("torch.nn").attr("Module")) ||
        py_module_.attr("_constants_set").contains(field.c_str())) {
      return toSugaredValue(attr, m, loc, true);
    } else {
      std::string hint = "did you forget to add it __constants__?";
      if (py::isinstance(attr, py::module::import("torch").attr("Tensor"))) {
        hint = "Tensors must be added to a module as a buffer or parameter";
      }
      throw ErrorReport(loc)
          << "attribute '" << field << "' of type '" << typeString(attr)
          << "' is not usable in a script method (" << hint << ")";
    }
  }
  throw ErrorReport(loc) << "module has no attribute '" << field << "'";
}
// Unpacks a constant module list into one SugaredValue per element.
//
// Only Python modules that are instances of torch.jit._ConstModuleList are
// iterable here; anything else defers to SugaredValue::asTuple. Elements
// that are script modules become ModuleValues read off `self_` by name;
// all other elements go through toSugaredValue as non-constants.
std::vector<std::shared_ptr<SugaredValue>> ModuleValue::asTuple(
    const SourceRange& loc,
    Function& m,
    const c10::optional<size_t>& size_hint) {
  const py::object const_list_type =
      py::module::import("torch.jit").attr("_ConstModuleList");
  if (!py::isinstance(py_module_, const_list_type)) {
    return SugaredValue::asTuple(loc, m, size_hint);
  }

  std::vector<std::shared_ptr<SugaredValue>> elements;
  for (py::handle item : py_module_) {
    py::object element = py::reinterpret_borrow<py::object>(item);
    auto script_module = as_module(element);
    if (!script_module) {
      elements.push_back(
          toSugaredValue(element, m, loc, /*is_constant =*/false));
      continue;
    }
    Value* module_value =
        m.graph()->insertGetAttr(self_, script_module->name());
    elements.emplace_back(
        std::make_shared<ModuleValue>(module_value, script_module, element));
  }
  return elements;
}
std::shared_ptr<SugaredValue> BooleanDispatchValue::call(
const SourceRange& loc,
Function& caller,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) {
c10::optional<bool> result;
Graph& graph = *(caller.graph());
auto index = py::cast<size_t>(dispatched_fn_["index"]);
auto arg_name = py::str(dispatched_fn_["arg_name"]);
if (index < inputs.size()) {
// Dispatch flag is in arg list
result = constant_as<bool>(inputs.at(index).value(graph));
} else if (auto i = findInputWithName(arg_name, attributes)) {
// Dispatch flag is in kwargs
result = constant_as<bool>(attributes[*i].value(graph));
} else {
// Didn't find dispatch flag, so use default value
result = py::cast<bool>(dispatched_fn_["default"]);
}
if (!result) {
throw ErrorReport(loc) << "value for boolean dispatch was not constant";
}
std::shared_ptr<SugaredValue> value;
if (*result) {
value = toSugaredValue(dispatched_fn_["if_true"], caller, loc);
} else {
value = toSugaredValue(dispatched_fn_["if_false"], caller, loc);
}
return value->call(loc, caller, inputs, attributes, n_binders);
}
// Converts an arbitrary Python object into the SugaredValue the script
// compiler should use for it.
//
// obj:         the Python object to convert.
// m:           the function whose graph receives any inserted constants.
// loc:         source range used for inserted constants and error reports.
// is_constant: when true, primitive Python values (bool, int, float, str,
//              None, device, layout, dtype) are inserted as graph constants
//              and tuples become ConstantPythonTupleValue.
//
// Objects that match none of the special cases are wrapped in a generic
// PythonValue. Throws an ErrorReport if `obj` is a ScriptModule (which can
// only be used as a submodule of the caller).
std::shared_ptr<SugaredValue> toSugaredValue(
    py::object obj,
    Function& m,
    SourceRange loc,
    bool is_constant) {
  // directly create SimpleValues when possible, because they are first-class
  // and can be re-assigned. Otherwise, this would be invalid:
  // f = python_constant
  // while ...
  //   f = f + 1
  auto& g = *m.graph();
  if (is_constant) {
    // The bool check is listed before the int check.
    if (py::isinstance<py::bool_>(obj)) {
      return toSimple(g.insertConstant(py::cast<bool>(obj), nullptr, loc));
    } else if (py::isinstance<py::int_>(obj)) {
      return toSimple(g.insertConstant(py::cast<int64_t>(obj), nullptr, loc));
    } else if (py::isinstance<py::float_>(obj)) {
      return toSimple(g.insertConstant(py::cast<double>(obj), nullptr, loc));
    } else if (py::isinstance<py::str>(obj)) {
      return toSimple(
          g.insertConstant(py::cast<std::string>(obj), nullptr, loc));
    } else if (obj.is(py::none())) {
      return toSimple(g.insertConstant(IValue(), nullptr, loc));
    } else if (THPDevice_Check(obj.ptr())) {
      auto device = reinterpret_cast<THPDevice*>(obj.ptr());
      // Pass `loc` like every other constant so the node keeps its source
      // location.
      return toSimple(g.insertConstant(device->device, nullptr, loc));
    } else if (THPLayout_Check(obj.ptr())) {
      auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
      const auto v = static_cast<int64_t>(layout->layout);
      return toSimple(g.insertConstant(v, nullptr, loc));
    } else if (THPDtype_Check(obj.ptr())) {
      auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
      const auto v = static_cast<int64_t>(dtype->scalar_type);
      return toSimple(g.insertConstant(v, nullptr, loc));
    } else if (py::isinstance<py::tuple>(obj)) {
      return std::make_shared<ConstantPythonTupleValue>(obj);
    }
  }

  // Every remaining check consults helpers in torch.jit; import it once.
  const py::module torch_jit = py::module::import("torch.jit");

  // If torch.jit._try_get_weak_module returns a non-None object, continue
  // with that object instead.
  auto weak_obj = torch_jit.attr("_try_get_weak_module")(obj);
  if (!weak_obj.is_none()) {
    obj = weak_obj;
  }

  if (auto callee = as_function(obj)) {
    return std::make_shared<FunctionValue>(callee);
  } else if (py::isinstance<py::module>(obj)) {
    return std::make_shared<PythonModuleValue>(obj);
  } else if (obj.ptr() == torch_jit.attr("_fork").ptr()) {
    return std::make_shared<ForkValue>();
  } else if (obj.ptr() == torch_jit.attr("annotate").ptr()) {
    return std::make_shared<AnnotateValue>();
  } else if (as_module(obj)) {
    throw ErrorReport(loc) << "Cannot call a ScriptModule that is not"
                           << " a submodule of the caller";
  }

  py::object builtin_name = torch_jit.attr("_find_builtin")(obj);
  if (!builtin_name.is_none()) {
    return std::make_shared<BuiltinFunction>(
        Symbol::fromQualString(py::str(builtin_name)), c10::nullopt);
  }

  if (py::isinstance<py::function>(obj)) {
    auto compiled_fn = torch_jit.attr("_try_compile_weak_script")(obj);
    if (auto callee = as_function(compiled_fn)) {
      return std::make_shared<FunctionValue>(callee);
    }
  }

  py::object dispatched_fn = torch_jit.attr("_try_get_dispatched_fn")(obj);
  if (!dispatched_fn.is_none()) {
    return std::make_shared<BooleanDispatchValue>(std::move(dispatched_fn));
  }

  // Python classes that the Python compilation unit can find by qualified
  // name become ClassValues.
  py::bool_ isClass = py::module::import("inspect").attr("isclass")(obj);
  if (py::cast<bool>(isClass)) {
    py::str qualifiedName = torch_jit.attr("_qualified_name")(obj);
    auto& pyCu = CompilationUnit::_get_python_cu();
    if (auto classType = pyCu.get_class(c10::QualifiedName(qualifiedName))) {
      return std::make_shared<ClassValue>(classType);
    }
  }

  // Fallback: an opaque Python value.
  return std::make_shared<PythonValue>(obj);
}
} // namespace script
} // namespace jit
} // namespace torch