[jit] kill script namespace (#34515)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34515

Once upon a time we thought this was necessary. In reality it is not, so removing it.

For backcompat, our public interface (defined in `api/`) still has typedefs to the old `script::` names.

There was only one collision: `Pass` as a `Stmt` and `Pass` as a graph transform. I renamed one of them.

Test Plan: Imported from OSS

Differential Revision: D20353503

Pulled By: suo

fbshipit-source-id: 48bb911ce75120a8c9e0c6fb65262ef775dfba93
committed by: Facebook GitHub Bot
parent: cf8b728255
commit: c235be42dd
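For reference, the backcompat shim described in the summary follows the pattern visible throughout the diff below. The following is a minimal self-contained sketch, not the actual header: the alias and its comment mirror the ones this commit adds in `api/`, while the toy `Module` body and `main` are purely illustrative.

```cpp
// Sketch of the backcompat pattern: the type now lives directly in
// torch::jit, and a type alias keeps the old torch::jit::script::
// spelling compiling for existing users of the public API.
#include <string>
#include <utility>

namespace torch {
namespace jit {

// Illustrative stand-in for the real torch::jit::Module.
struct Module {
  explicit Module(std::string name) : name_(std::move(name)) {}
  std::string name_;
};

namespace script {
// We once had a `script::` namespace that was deleted. This is for backcompat
// of the public API; new code should not use this type alias.
using Module = ::torch::jit::Module;
} // namespace script

} // namespace jit
} // namespace torch

int main() {
  torch::jit::Module m("m");            // new spelling
  torch::jit::script::Module m2("m2");  // old spelling, same type via the alias
  return 0;
}
```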
@@ -13,7 +13,7 @@ int main(int argc, char* argv[]) {
   std::ifstream ifs(input_file_path);
   std::stringstream buffer;
   buffer << ifs.rdbuf();
-  torch::jit::script::Module m("TestModule");
+  torch::jit::Module m("TestModule");

   m.define(buffer.str());
   m.save(output_file_path);
@@ -58,7 +58,7 @@ class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
 class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
  private:
   friend HybridBase;
-  torch::jit::script::Module module_;
+  torch::jit::Module module_;

  public:
   constexpr static auto kJavaDescriptor = "Lorg/pytorch/NativePeer;";
@@ -5,7 +5,7 @@ package org.pytorch;
 import com.facebook.soloader.nativeloader.NativeLoader;
 import com.facebook.soloader.nativeloader.SystemDelegate;

-/** Java wrapper for torch::jit::script::Module. */
+/** Java wrapper for torch::jit::Module. */
 public class Module {

   private INativePeer mNativePeer;
@@ -14,7 +14,7 @@ public class Module {
    * Loads a serialized TorchScript module from the specified path on the disk.
    *
    * @param modelPath path to file that contains the serialized TorchScript module.
-   * @return new {@link org.pytorch.Module} object which owns torch::jit::script::Module.
+   * @return new {@link org.pytorch.Module} object which owns torch::jit::Module.
    */
   public static Module load(final String modelPath) {
     if (!NativeLoader.isInitialized()) {
@@ -49,7 +49,7 @@ public class Module {
   }

   /**
-   * Explicitly destroys the native torch::jit::script::Module. Calling this method is not required,
+   * Explicitly destroys the native torch::jit::Module. Calling this method is not required,
    * as the native object will be destroyed when this object is garbage-collected. However, the
    * timing of garbage collection is not guaranteed, so proactively calling {@code destroy} can free
    * memory more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
@@ -22,7 +22,7 @@ TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph);

 // A Function is a pure Graph with no implicit `self` object bound.
 // It contains schema information, and the executor that manages the
-// execution of the function. script::Method is a wrapper around a
+// execution of the function. Method is a wrapper around a
 // underlying Function that also provides a `self` object.
 struct TORCH_API Function {
   virtual bool isGraphFunction() const = 0;
@@ -378,7 +378,7 @@ std::vector<std::pair<IValue, IValue>> iterationOrder(const c10::Dict<IValue, IV
 }

 StrongTypePtr::StrongTypePtr(
-    std::shared_ptr<torch::jit::script::CompilationUnit> cu,
+    std::shared_ptr<torch::jit::CompilationUnit> cu,
     std::shared_ptr<Type> type) {
   cu_ = std::move(cu);
   type_ = type;
@@ -11,10 +11,8 @@ namespace jit {
 class CustomClassHolder : public c10::intrusive_ptr_target {};

 struct Function;
-namespace script {
 struct CompilationUnit;
 struct Module;
-}
 } // namespace jit
 } // namespace torch
 namespace c10 {
@@ -356,7 +354,7 @@ struct CAFFE2_API IValue final {
   c10::intrusive_ptr<ivalue::Object> toObject() const & ;
   const ivalue::Object& toObjectRef() const;

-  torch::jit::script::Module toModule() const;
+  torch::jit::Module toModule() const;
   bool isModule() const;

   // PyObject
@@ -692,10 +690,10 @@ private:
 // guaranteed to stay alive as long as we hold this object.
 struct TORCH_API StrongTypePtr {
   StrongTypePtr(
-      std::shared_ptr<torch::jit::script::CompilationUnit> cu,
+      std::shared_ptr<torch::jit::CompilationUnit> cu,
       std::shared_ptr<Type> type);

-  std::shared_ptr<torch::jit::script::CompilationUnit> cu_;
+  std::shared_ptr<torch::jit::CompilationUnit> cu_;
   std::shared_ptr<Type> type_;
 };

@@ -15,9 +15,7 @@
 namespace torch {
 namespace jit {
 struct Function;
-namespace script {
 struct CompilationUnit;
-}
 } // namespace jit
 } // namespace torch
 namespace c10 {
@@ -406,7 +404,7 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
   }
   std::shared_ptr<ClassType> type() const;

-  std::shared_ptr<torch::jit::script::CompilationUnit> compilation_unit() {
+  std::shared_ptr<torch::jit::CompilationUnit> compilation_unit() {
     return type_.cu_;
   }

@@ -868,7 +866,7 @@ IValue from_(c10::intrusive_ptr<T> x, std::false_type) {
   auto res = getCustomClassType<inputType>();
   auto retObject = ivalue::Object::create(
       StrongTypePtr(
-          std::shared_ptr<torch::jit::script::CompilationUnit>(),
+          std::shared_ptr<torch::jit::CompilationUnit>(),
           std::move(res)),
       1);
   auto objPtr = c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(std::move(x));
@@ -17,9 +17,7 @@
 struct ClassType;
 namespace torch {
 namespace jit {
-namespace script {
 struct CompilationUnit;
-}
 } // namespace jit
 } // namespace torch

@@ -1491,7 +1489,7 @@ CAFFE2_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type);

 struct ClassType;
 using ClassTypePtr = std::shared_ptr<ClassType>;
-using ::torch::jit::script::CompilationUnit;
+using ::torch::jit::CompilationUnit;

 // This represents a class in TorchScript.
 struct CAFFE2_API ClassType : public NamedType {
@@ -1801,7 +1799,7 @@ struct CAFFE2_API ClassType : public NamedType {

 struct InterfaceType;
 using InterfaceTypePtr = std::shared_ptr<InterfaceType>;
-using ::torch::jit::script::CompilationUnit;
+using ::torch::jit::CompilationUnit;

 // Interfaces are a list of abstract methods that a class might meet.
 // If a class provides those methods, it implicitly meets the interface.
@@ -24,7 +24,7 @@

 namespace torch {
 namespace jit {
-void dump_opnames(const script::Module& m, std::unordered_set<std::string>& opnames) {
+void dump_opnames(const Module& m, std::unordered_set<std::string>& opnames) {
   auto methods = m.get_methods();
   for (const auto& method : methods) {
     const auto& func = method.function();
@@ -7,8 +7,8 @@
 namespace caffe2 {

 using torch::jit::IValue;
-using torch::jit::script::Method;
-using torch::jit::script::Module;
+using torch::jit::Method;
+using torch::jit::Module;

 namespace {
 class ScriptModuleSerializer : public BlobSerializerBase {
@@ -31,7 +31,7 @@ class ScriptModuleSerializer : public BlobSerializerBase {
     // the more efficient serialization version (if we ever get to that point)
     BlobProto blob_proto;
     blob_proto.set_name(name);
-    blob_proto.set_type("torch::jit::script::Module");
+    blob_proto.set_type("torch::jit::Module");
     blob_proto.set_content(ss.str());
     acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
   }
@@ -134,7 +134,7 @@ REGISTER_BLOB_SERIALIZER(
 // NB: the first argument to REGISTER_BLOB_DESERIALIZER macro doesn't really
 // need to be a real type, it just get converted to string
 REGISTER_BLOB_DESERIALIZER(
-    torch::jit::script::Module,
+    torch::jit::Module,
     ScriptModuleDeserializer);

 OPERATOR_SCHEMA(ScriptModule)
@@ -94,12 +94,12 @@ REGISTER_BLOB_FETCHER((TypeMeta::Id<string>()), StringFetcher);
 class ScriptModuleFetcher : public BlobFetcherBase {
  public:
   pybind11::object Fetch(const Blob& blob) override {
-    return py::cast(*blob.Get<std::unique_ptr<torch::jit::script::Module>>());
+    return py::cast(*blob.Get<std::unique_ptr<torch::jit::Module>>());
   }
 };

 REGISTER_BLOB_FETCHER(
-    (TypeMeta::Id<std::unique_ptr<torch::jit::script::Module>>()),
+    (TypeMeta::Id<std::unique_ptr<torch::jit::Module>>()),
     caffe2::python::ScriptModuleFetcher);
 #endif

@@ -247,9 +247,9 @@ bool feedBlob(
     return true;
   }
 #ifdef FBCODE_CAFFE2
-  if (auto module = torch::jit::script::as_module(arg)) {
-    blob->GetMutable<std::unique_ptr<torch::jit::script::Module>>()->reset(
-        new torch::jit::script::Module(*module));
+  if (auto module = torch::jit::as_module(arg)) {
+    blob->GetMutable<std::unique_ptr<torch::jit::Module>>()->reset(
+        new torch::jit::Module(*module));
     return true;
   }
 #endif
@@ -213,7 +213,7 @@ Disable JIT for Debugging
 Python. Since TorchScript (scripting and tracing) are disabled with this flag,
 you can use tools like ``pdb`` to debug the model code.

-Given an example script::
+Given an example

     @torch.jit.script
     def scripted_fn(x : torch.Tensor):
@@ -120,8 +120,8 @@ usage might look like:

 .. code-block:: cpp

-    SetExportModuleExtraFilesHook([](const script::Module&) {
-        script::ExtraFilesMap files;
+    SetExportModuleExtraFilesHook([](const Module&) {
+        ExtraFilesMap files;
         files["producer_info.json"] = "{\"user\": \"" + getenv("USER") + "\"}";
         return files;
     });
@@ -8,7 +8,7 @@ Module

 .. java:type:: public class Module

-   Java wrapper for torch::jit::script::Module.
+   Java wrapper for torch::jit::Module.

 Methods
 -------
@@ -18,7 +18,7 @@ destroy
 .. java:method:: public void destroy()
    :outertype: Module

-   Explicitly destroys the native torch::jit::script::Module. Calling this method is not required, as the native object will be destroyed when this object is garbage-collected. However, the timing of garbage collection is not guaranteed, so proactively calling \ ``destroy``\ can free memory more quickly. See \ :java:ref:`com.facebook.jni.HybridData.resetNative`\ .
+   Explicitly destroys the native torch::jit::Module. Calling this method is not required, as the native object will be destroyed when this object is garbage-collected. However, the timing of garbage collection is not guaranteed, so proactively calling \ ``destroy``\ can free memory more quickly. See \ :java:ref:`com.facebook.jni.HybridData.resetNative`\ .

 forward
 ^^^^^^^
@@ -40,7 +40,7 @@ load
    Loads a serialized TorchScript module from the specified path on the disk.

    :param modelPath: path to file that contains the serialized TorchScript module.
-   :return: new \ :java:ref:`org.pytorch.Module`\ object which owns torch::jit::script::Module.
+   :return: new \ :java:ref:`org.pytorch.Module`\ object which owns torch::jit::Module.

 runMethod
 ^^^^^^^^^
@@ -7,7 +7,7 @@
 @end

 @implementation TestAppTests {
-  torch::jit::script::Module _module;
+  torch::jit::Module _module;
 }

 + (void)setUp {
@@ -395,7 +395,7 @@ void testAliasAnalysis() {
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
         R"IR(
 graph():
   %opt : Tensor? = prim::Constant()
@@ -415,7 +415,7 @@ void testAliasAnalysis() {
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%x : Tensor):
   %3 : int = prim::Constant[value=1]()
@@ -491,7 +491,7 @@ void testWriteTracking() {
   }
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%x: Tensor):
   %b : Tensor = aten::relu_(%x)
@@ -505,7 +505,7 @@ void testWriteTracking() {
   }
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%x: Tensor, %y : Tensor):
   %b : Tensor = aten::mul(%x, %y)
@@ -520,7 +520,7 @@ void testWriteTracking() {
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%x: Tensor, %y : Tensor):
   %c1 : int = prim::Constant[value=1]()
@@ -540,7 +540,7 @@ void testContainerAliasing() {
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%inp: Tensor[]):
   %x : str = prim::Constant[value="a"]()
@@ -574,7 +574,7 @@ void testContainerAliasing() {

   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
         R"IR(
 graph():
   %x : str = prim::Constant[value="a"]()
@@ -601,7 +601,7 @@ void testContainerAliasing() {
   // Test input aliasing
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%x: Tensor, %y: Tensor):
   %a : (Tensor) = prim::TupleConstruct(%x)
@@ -622,7 +622,7 @@ void testContainerAliasing() {
   // Test tuple that doesn't come from construct
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%x : int,
       %y : Tensor,
@@ -654,7 +654,7 @@ graph(%x : int,
   // test nested types
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
         R"IR(
 graph():
   %a : Tensor = prim::MakeTestTensor()
@@ -681,7 +681,7 @@ graph():
   // simple example
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
         R"IR(
 graph():
   %0 : Tensor = prim::Constant()
@@ -711,7 +711,7 @@ graph():
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
         R"IR(
 graph():
   %x : str = prim::Constant[value="a"]()
@@ -737,7 +737,7 @@ graph():
   // Test list container aliasing
   auto graph = std::make_shared<Graph>();
   std::unordered_map<std::string, Value*> vmap;
-  script::parseIR(
+  parseIR(
       R"IR(
 graph():
   %0 : int = prim::Constant[value=2]()
@@ -779,7 +779,7 @@ graph():

   auto graph = std::make_shared<Graph>();
   std::unordered_map<std::string, Value*> vmap;
-  script::parseIR(
+  parseIR(
       R"IR(
 graph():
   %0 : int = prim::Constant[value=2]()
@@ -810,7 +810,7 @@ graph():
   // print across it.
   auto graph = std::make_shared<Graph>();
   std::unordered_map<std::string, Value*> vmap;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph():
   %35 : int = prim::Constant[value=1]()
@@ -849,7 +849,7 @@ graph():
   // print across it.
   auto graph = std::make_shared<Graph>();
   std::unordered_map<std::string, Value*> vmap;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph():
   %38 : int = prim::Constant[value=1]()
@@ -935,7 +935,7 @@ void testWildcards() {
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%ten_list : Tensor[], %int_list : int[], %opt_ten_list : Tensor[]?):
   %ten : Tensor = prim::Constant()
@@ -971,7 +971,7 @@ void testWildcards() {
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%ten_list : Tensor[], %ten_opt_list : Tensor?[]):
   %ten : Tensor = prim::Constant()
@@ -992,7 +992,7 @@ void testWildcards() {
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%float_3D : Float(*, *, *), %float_2D : Float(*, *)):
   return ()
@@ -1006,7 +1006,7 @@ void testWildcards() {
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%float_3D_list : Float(*, *, *)[], %float_2D_list : Float(*, *)[], %ten: Tensor):
   return ()
@@ -213,7 +213,7 @@ void testDifferentiateWithRequiresGrad() {
     %7 : Tensor = aten::add(%6, %1, %2)
     return (%4, %7))IR";
   auto g = std::make_shared<Graph>();
-  torch::jit::script::parseIR(graph_string, g.get());
+  torch::jit::parseIR(graph_string, g.get());

   auto a_var = autograd::make_variable(
       at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true);
@@ -9,8 +9,6 @@
 namespace torch {
 namespace jit {

-using namespace torch::jit::script;
-
 static const auto classSrcs1 = R"JIT(
 class FooNestedTest:
     def __init__(self, y):
@@ -4,7 +4,6 @@

 namespace torch {
 namespace jit {
-using namespace torch::jit::script;
 const auto testSource = R"JIT(
 class FooTest:
     def __init__(self, x):
@@ -5,8 +5,6 @@
 namespace torch {
 namespace jit {

-using namespace torch::jit::script;
-
 void testClassTypeAddRemoveAttr() {
   auto cu = std::make_shared<CompilationUnit>();
   auto cls = ClassType::create("foo.bar", cu, true);
@@ -14,7 +14,7 @@ namespace jit {
 void testConstantPooling() {
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
         R"IR(
 graph():
   %8 : int = prim::Constant[value=1]()
@@ -29,7 +29,7 @@ graph():
   }
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%cond : Tensor):
   %a : str = prim::Constant[value="bcd"]()
@@ -53,7 +53,7 @@ graph(%cond : Tensor):
   }
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
         R"IR(
 graph():
   %2 : int = prim::Constant[value=2]()
@@ -145,7 +145,7 @@ void testCustomOperatorAliasing() {

   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%x: Tensor, %y: Tensor):
   %ret : Tensor = foo::aliasing(%x, %y)
@@ -172,7 +172,7 @@ graph(%x: Tensor, %y: Tensor):
   %ret : Tensor = foo::aliasing(%x, %y)
   return (%x)
 )IR";
-    script::parseIR(text, graph.get());
+    parseIR(text, graph.get());
     EliminateDeadCode(graph);

     testing::FileCheck().run(text, *graph);
@@ -43,7 +43,7 @@ graph():
       -> (%50, %tot.3)
   return (%tot)
 )IR";
-  script::parseIR(input, graph.get());
+  parseIR(input, graph.get());
   EliminateDeadCode(graph);
   // Check that dead code elimin
   testing::FileCheck().run(input, *graph);
@@ -67,7 +67,7 @@ void testFusion() {
     %2 : Tensor = aten::mul(%0, %1)
     return (%2))IR";
   Graph graph;
-  torch::jit::script::parseIR(graph_string, &graph);
+  torch::jit::parseIR(graph_string, &graph);

   auto a = at::rand({3, 4}, at::kCUDA);
   auto b = at::rand({4, 3}, at::kCUDA).transpose(0, 1);
@@ -100,7 +100,7 @@ void testFusion() {
     %14 : Tensor = aten::mul(%8, %13)
     return (%14, %12))IR";
   Graph graph;
-  torch::jit::script::parseIR(graph_string, &graph);
+  torch::jit::parseIR(graph_string, &graph);

   graph.lint();

@@ -164,7 +164,7 @@ void testFusion() {
       graph_string2};
   for (auto i = decltype(graph_strings.size()){0}; i < graph_strings.size(); ++i) {
     Graph g;
-    torch::jit::script::parseIR(graph_strings[i], &g);
+    torch::jit::parseIR(graph_strings[i], &g);

     auto outputs = debugLaunchGraph(g, {a, b});
     ASSERT_EQ(outputs.size(), 2);
@@ -187,7 +187,7 @@ void testRegisterFusionCachesKernel() {
     %d0 : Float(2, 3, 4) = aten::mul(%c0, %0)
     return (%d0))IR";
   auto g0 = std::make_shared<Graph>();
-  torch::jit::script::parseIR(graph0_string, g0.get());
+  torch::jit::parseIR(graph0_string, g0.get());

   const auto graph1_string = R"IR(
     graph(%0 : Float(2, 3, 4),
@@ -196,7 +196,7 @@ void testRegisterFusionCachesKernel() {
     %d1 : Float(2, 3, 4) = aten::mul(%c1, %0)
     return (%d1))IR";
   auto g1 = std::make_shared<Graph>();
-  torch::jit::script::parseIR(graph1_string, g1.get());
+  torch::jit::parseIR(graph1_string, g1.get());

   auto getFusionGroup = [](const std::shared_ptr<Graph>& graph) {
     const auto& nodes = graph->nodes();
@@ -21,7 +21,6 @@ def foo3(x):

 namespace torch {
 namespace jit {
-using namespace script;
 using namespace testing;

 struct InlinerGuard {
@@ -11,8 +11,6 @@
 namespace torch {
 namespace jit {

-using namespace torch::jit::script;
-
 static const std::vector<std::string> subMethodSrcs = {R"JIT(
   def one(self, x: Tensor, y: Tensor) -> Tensor:
     return x + y + 1
@@ -55,7 +55,7 @@ void testBlocks() {
     %12 : int = prim::Constant[value=1]()
     %13 : Tensor = aten::add(%5, %3, %12)
     return (%13))IR";
-  torch::jit::script::parseIR(graph_string, g.get());
+  torch::jit::parseIR(graph_string, g.get());

   g->lint();
   testing::FileCheck()
@@ -122,7 +122,7 @@ graph(%x : Tensor,

   torch::jit::Graph g;
   std::unordered_map<std::string, torch::jit::Value*> name_to_value;
-  torch::jit::script::parseIR(input_str, &g, name_to_value);
+  torch::jit::parseIR(input_str, &g, name_to_value);

   std::vector<std::string> value_names{"6", "7", "9", "10"};
   std::unordered_set<std::string> value_names_set(
@@ -17,7 +17,7 @@ namespace jit {
  */
 static void checkRoundtrip(const std::string& s) {
   auto graph = std::make_shared<Graph>();
-  script::parseIR(s, &*graph);
+  parseIR(s, &*graph);
   std::ostringstream ss;
   ss << *graph;
   std::string parsed = ss.str();
@@ -42,7 +42,7 @@ void testIRParser() {
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%0 : Tensor, %1 : Tensor):
   %2 : Tensor = foo::add(%0, %1)
@@ -117,7 +117,7 @@ graph(%0 : Tensor,
   }
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%a):
   return (%a))IR",
@@ -127,7 +127,7 @@ graph(%a):
   {
     // Check that parser correctly handles values reusing the same name.
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%x):
   %x = a::a(%x)
@@ -239,7 +239,7 @@ graph(%0 : Tensor,
 # CHECK: return
   return (%a))IR";

-    script::parseIR(text, &*graph);
+    parseIR(text, &*graph);
     AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(TensorType::get()));
     torch::jit::testing::FileCheck().run(text, *graph);
   }
@@ -7,8 +7,6 @@
 namespace torch {
 namespace jit {

-using namespace torch::jit::script;
-
 void testIValue() {
   c10::List<int64_t> foo({3, 4, 5});
   ASSERT_EQ(foo.use_count(), 1);
@@ -12,7 +12,7 @@ namespace torch {
 namespace jit {

 void testLiteInterpreterUpsampleNearest2d() {
-  script::Module m("m");
+  Module m("m");
   m.define(R"(
     def forward(self, input: Tensor, scale:float):
       return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
@@ -35,7 +35,7 @@ void testLiteInterpreterUpsampleNearest2d() {
 }

 void testLiteInterpreterAdd() {
-  script::Module m("m");
+  Module m("m");
   m.register_parameter("foo", torch::ones({}), false);
   // TODO: support default param val, which was pushed in
   // function schema's checkAndNormalizeInputs()
@@ -75,7 +75,7 @@ void testLiteInterpreterConv() {

   std::vector<torch::jit::IValue> inputs;

-  script::Module m("m");
+  Module m("m");
   m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
   m.register_parameter("bias", torch::ones({20}), false);
   m.define(R"(
@@ -100,7 +100,7 @@ void testLiteInterpreterConv() {
 }

 void testLiteInterpreterInline() {
-  script::Module m("m");
+  Module m("m");
   m.define(R"JIT(
   def foo1(self, x):
       return x + 1
@@ -120,7 +120,7 @@ void testLiteInterpreterInline() {
 }

 void testLiteInterpreterTuple() {
-  script::Module m("m");
+  Module m("m");
   m.define(R"JIT(
   def foo(self, x):
       return (1, 2, x + 3)
@@ -138,7 +138,7 @@ void testLiteInterpreterTuple() {
 }

 void testLiteInterpreterPrimOverload() {
-  script::Module m("m");
+  Module m("m");
   m.define(R"JIT(
   def forward(self, x):
       result = [1, 2]
@@ -154,7 +154,7 @@ void testLiteInterpreterPrimOverload() {
 }

 void testLiteInterpreterPrim() {
-  script::Module m("m");
+  Module m("m");
   m.define(R"JIT(
   def forward(self, x):
       return int(x)
@@ -180,7 +180,7 @@ void testLiteInterpreterPrim() {
 }

 void testLiteInterpreterLoadOrigJit() {
-  script::Module m("m");
+  Module m("m");
   m.register_parameter("foo", torch::ones({}), false);
   m.define(R"(
     def forward(self, x):
@@ -193,7 +193,7 @@ void testLiteInterpreterLoadOrigJit() {
 }

 void testLiteInterpreterWrongMethodName() {
-  script::Module m("m");
+  Module m("m");
   m.register_parameter("foo", torch::ones({}), false);
   m.define(R"(
     def add(self, x):
@@ -210,7 +210,7 @@ void testLiteInterpreterWrongMethodName() {
 }

 void testLiteInterpreterParams() {
-  script::Module m("m");
+  Module m("m");
   m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
   m.define(R"(
     def forward(self, x):
@@ -269,7 +269,7 @@ void testLiteInterpreterParams() {
 }

 void testLiteInterpreterSetState() {
-  script::Module m("m");
+  Module m("m");
   m.register_parameter("foo", torch::ones({}), false);
   m.define(R"(
     def __getstate__(self):
@@ -368,7 +368,7 @@ void testCustomFusion() {
     %3 : Tensor = aten::mul(%2, %0)
     return (%3))IR";
   auto g = std::make_shared<Graph>();
-  torch::jit::script::parseIR(graph_string, g.get());
+  torch::jit::parseIR(graph_string, g.get());

   torch::jit::overrideCanFuseOnCPU(true);
   CustomFuseGraph(
@@ -412,7 +412,7 @@ void testCustomFusionNestedBlocks() {
     %9 : Tensor = aten::add(%4, %2, %3)
     return (%4))IR";
   auto g = std::make_shared<Graph>();
-  torch::jit::script::parseIR(graph_string, g.get());
+  torch::jit::parseIR(graph_string, g.get());

   CustomFuseGraph(
       g,
@@ -489,7 +489,7 @@ void testEvalModeForLoadedModule() {
   if (isSandcastle())
     return; // The module file to load is not generated in Sandcastle
   std::string module_path = "dropout_model.pt";
-  torch::jit::script::Module module = torch::jit::load(module_path);
+  torch::jit::Module module = torch::jit::load(module_path);
   AT_ASSERT(module.attr("dropout").toModule().is_training());
   module.eval();
   AT_ASSERT(!module.attr("dropout").toModule().is_training());
@@ -973,7 +973,7 @@ void testNoneSchemaMatch() {
 }

 void testModuleDefine() {
-  script::Module m("m");
+  Module m("m");
   m.register_parameter("foo", torch::ones({}), false);
   m.define(R"(
     def add_it(self, x, b : int = 4):
@@ -984,7 +984,7 @@ void testModuleDefine() {
 }

 void testModuleConversion() {
-  script::Module m("test");
+  Module m("test");
   {
     // test cuda to cpu for params and buffers
     m.register_parameter("foo", torch::ones({}, at::kCUDA), false);
@@ -1016,7 +1016,7 @@ RegisterPass p(fakePass);

 void testPassManagement() {
   std::shared_ptr<Graph> graph = std::make_shared<Graph>();
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%a):
   return (%a))IR",
@@ -5,8 +5,6 @@
 namespace torch {
 namespace jit {

-using namespace torch::jit::script;
-
 void testModuleClone() {
   auto cu = std::make_shared<CompilationUnit>();
   auto parent = ClassType::create("parent", cu, true);
@@ -8,8 +8,6 @@
 namespace torch {
 namespace jit {

-using namespace script;
-

 void testPeepholeOptimize() {
   // test is / is not none optimization
@@ -10,7 +10,6 @@

 namespace torch {
 namespace jit {
-using namespace script;

 void testSaveExtraFilesHook() {
   // no secrets
@@ -23,7 +23,7 @@ void testSchemaMatching() {
             return 0;
           }, c10::AliasAnalysisKind::FROM_SCHEMA),
       });
-  script::Module m("m");
+  Module m("m");
   m.define(R"(
     def test(self):
       a = (1.0, 2.0)
@@ -59,7 +59,7 @@ void testSchemaMatching() {
             return 0;
           }, AliasAnalysisKind::FROM_SCHEMA),
       });
-  script::Module m("m");
+  Module m("m");
   m.define(R"JIT(
     def test(self):
       a = (1.0, 2.0)
@@ -7,13 +7,13 @@ namespace jit {

 void testTrivial1() {
   Graph graph, pattern;
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%0):
   %a = a::aaa(%0)
   return (%a))IR",
       &graph);
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%0):
   %x = a::aaa(%0)
@@ -46,7 +46,7 @@ void testTrivial2() {

 void testTrivial3() {
   Graph graph, pattern;
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%0):
   %a = a::a(%0)
@@ -54,7 +54,7 @@ graph(%0):
   %c = a::c(%a, %b)
   return (%c))IR",
       &graph);
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%a, %b):
   %c = a::c(%a, %b)
@@ -92,7 +92,7 @@ void testTrivial4() {

 void testLinear1() {
   Graph graph, pattern;
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%0):
   %a = a::aaa(%0)
@@ -102,7 +102,7 @@ graph(%0):
   %a = a::aaa(%0)
   return (%d))IR",
       &graph);
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%0):
   %x = b::bbb(%0)
@@ -161,7 +161,7 @@ void testLinear2() {
  */
 void testDiamond1() {
   Graph graph, pattern1, pattern2;
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%0):
   %o = o::ooo(%0)
@@ -173,7 +173,7 @@ graph(%0):
   return (%e))IR",
       &graph);

-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%0):
   %a = a::aaa(%0)
@@ -185,7 +185,7 @@ graph(%0):
   AT_ASSERT(!findPatternMatches(pattern1, graph).empty());

   // Check that order of nodes inside the diamond does not affect the result
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%0):
   %a = a::aaa(%0)
@@ -247,7 +247,7 @@ void testDiamond2() {

 void testXPattern() {
   Graph graph, pattern;
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%0, %1):
   %b = b::bbb(%0)
@@ -258,7 +258,7 @@ graph(%0, %1):
   %g = g::ggg(%e, %f)
   return (%g))IR",
       &graph);
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%0, %1):
   %b = b::bbb(%0)
@@ -274,7 +274,7 @@ graph(%0, %1):

 void testMultipleMatches() {
   Graph graph, pattern;
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%t0):
   %t1 = a::aaa(%t0)
@@ -283,7 +283,7 @@ graph(%t0):
   %t4 = a::aaa(%t3)
   return (%t4))IR",
       &graph);
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%t0):
   %t1 = a::aaa(%t0)
@@ -295,7 +295,7 @@ graph(%t0):

 void testOverlappingMatches() {
   Graph graph, pattern;
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%t0):
   %t1 = a::aaa(%t0)
@@ -304,7 +304,7 @@ graph(%t0):
   %t4 = a::aaa(%t3)
   return (%t4))IR",
       &graph);
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%t0):
   %t1 = a::aaa(%t0)
@@ -317,7 +317,7 @@ graph(%t0):

 void testMatchInBasicBlocks1() {
   Graph graph;
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%a, %b, %c):
   %d = aten::mul(%a, %b)
@@ -333,7 +333,7 @@ graph(%a, %b, %c):

   // Ensure the matches don't cross basic block boundaries
   Graph pattern0;
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%x, %y):
   %z = aten::mul(%x, %y)
@@ -342,7 +342,7 @@ graph(%x, %y):
   AT_ASSERT(findPatternMatches(pattern0, graph).size() == 3);

   Graph pattern1;
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%x, %y):
   %z1 = aten::mul(%x, %y)
@@ -354,7 +354,7 @@ graph(%x, %y):

 void testMatchInBasicBlocks2() {
   Graph graph;
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%a, %b):
   %x = my::mul(%a, %b)
@@ -367,7 +367,7 @@ graph(%a, %b):

   // Check that we can match both mul ops
   Graph pattern0;
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%x, %y):
   %z = my::mul(%x, %y)
@@ -377,7 +377,7 @@ graph(%x, %y):

   // Ensure the matches don't cross basic block boundaries
   Graph pattern1;
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%x, %y):
   %u = my::mul(%x, %y)
@@ -389,7 +389,7 @@ graph(%x, %y):

 void testMatchesAttributes() {
   Graph graph;
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%0):
   %a = a::a[isattr=[1,2]](%0)
@@ -400,7 +400,7 @@ graph(%0):

   {
     Graph pattern;
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%a, %b):
   %c = a::c[myattr="qqq"](%a, %b)
@@ -410,7 +410,7 @@ graph(%a, %b):
   }
   {
     Graph pattern;
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%a, %b):
   %c = a::c[myattr="zzz"](%a, %b)
@@ -420,7 +420,7 @@ graph(%a, %b):
   }
   {
     Graph pattern;
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%0):
   %b = a::b[extraattr=10](%0)
@@ -430,7 +430,7 @@ graph(%0):
   }
   {
     Graph pattern;
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%0):
   %b = a::b[intattr=10, floatattr=3.14](%0)
@@ -440,7 +440,7 @@ graph(%0):
   }
   {
     Graph pattern;
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%0):
   %b = a::b[intattr=10, floatattr=3.14, strattr="rrr"](%0)
@@ -450,7 +450,7 @@ graph(%0):
   }
   {
     Graph pattern;
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%0):
   %a = a::a[isattr=[1,2]](%0)
@@ -463,14 +463,14 @@ graph(%0):

 void testBadPattern() {
   Graph graph, pattern1, pattern2;
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%0):
   %a = a::aaa(%0)
   return (%a))IR",
       &graph);

-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%x):
   %y = my::node_with_subblock()
@@ -481,7 +481,7 @@ graph(%x):
       &pattern1);
   ASSERT_ANY_THROW(findPatternMatches(pattern1, graph));

-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%x):
   %y = my::op1(%x)
@@ -11,7 +11,7 @@ using namespace testing;
 void testFilterMatch() {
   auto graph = std::make_shared<Graph>();

-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%0):
   %a = a::aaa(%0)
@@ -27,7 +27,7 @@ graph(%a, %b):
   Graph pattern_graph;
   std::unordered_map<std::string, Value*> vmap;

-  script::parseIR(
+  parseIR(
       pattern,
       &pattern_graph,
       vmap);
@@ -55,7 +55,7 @@ graph(%a, %b):

 void testFilterNoMatch() {
   auto graph = std::make_shared<Graph>();
-  script::parseIR(
+  parseIR(
       R"IR(
 graph(%0):
   %a = a::aaa(%0)
@@ -71,7 +71,7 @@ graph(%a, %b):
   Graph pattern_graph;
   std::unordered_map<std::string, Value*> vmap;

-  script::parseIR(
+  parseIR(
       pattern,
       &pattern_graph,
       vmap);
@@ -84,7 +84,7 @@ std::shared_ptr<Graph> build_lstm() {
   %22 : Tensor = aten::mul(%14, %21)
   return (%22, %20))IR";
   auto g = std::make_shared<Graph>();
-  torch::jit::script::parseIR(graph_string, g.get());
+  torch::jit::parseIR(graph_string, g.get());
   g->lint();

   return g;
@@ -12,7 +12,7 @@
 namespace helpers {
 template <typename Predicate>
 void check_all_parameters(
-    const torch::jit::script::Module& module,
+    const torch::jit::Module& module,
     Predicate predicate) {
   for (at::Tensor parameter : module.parameters()) {
     AT_ASSERT(predicate(parameter));
@@ -79,7 +79,7 @@ void get_autograd_operator_from_registry_and_execute_in_nograd_mode() {

 void load_serialized_module_with_custom_op_and_execute(
     const std::string& path_to_exported_script_module) {
-  torch::jit::script::Module module =
+  torch::jit::Module module =
       torch::jit::load(path_to_exported_script_module);
   std::vector<torch::jit::IValue> inputs;
   inputs.push_back(torch::ones(5));
@@ -90,7 +90,7 @@ void load_serialized_module_with_custom_op_and_execute(

 void test_argument_checking_for_serialized_modules(
     const std::string& path_to_exported_script_module) {
-  torch::jit::script::Module module =
+  torch::jit::Module module =
       torch::jit::load(path_to_exported_script_module);

   try {
@@ -124,7 +124,7 @@ void test_argument_checking_for_serialized_modules(
 }

 void test_move_to_device(const std::string& path_to_exported_script_module) {
-  torch::jit::script::Module module =
+  torch::jit::Module module =
       torch::jit::load(path_to_exported_script_module);

   helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
@@ -145,7 +145,7 @@ void test_move_to_device(const std::string& path_to_exported_script_module) {
 }

 void test_move_to_dtype(const std::string& path_to_exported_script_module) {
-  torch::jit::script::Module module =
+  torch::jit::Module module =
       torch::jit::load(path_to_exported_script_module);

   module.to(torch::kInt);
@@ -24,7 +24,7 @@ struct MobileCallGuard {
   torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
 };

-torch::jit::script::Module loadModel(const std::string& path) {
+torch::jit::Module loadModel(const std::string& path) {
   MobileCallGuard guard;
   auto module = torch::jit::load(path);
   module.eval();
@@ -138,7 +138,7 @@ TORCH_NN_MODULE_TEST_INIT = Template("""\n
 void ${module_variant_name}_test_init(
     const std::string& saved_module_path,
     const std::string& device) {
-  torch::jit::script::Module m_init_by_python = torch::jit::load(saved_module_path);
+  torch::jit::Module m_init_by_python = torch::jit::load(saved_module_path);

   torch::manual_seed(2);
   ${module_qualified_name} m_init_by_cpp${cpp_constructor_args};
@@ -30,7 +30,7 @@ namespace jit {
 ///   )JIT");
 ///   IValue output = module->run_method("relu_script", a, b);
 /// \endrst
-TORCH_API std::shared_ptr<script::CompilationUnit> compile(const std::string& source);
+TORCH_API std::shared_ptr<CompilationUnit> compile(const std::string& source);

 } // namespace jit
 } // namespace torch
@@ -39,7 +39,7 @@ namespace torch {
 template <typename Value, typename... SaveToArgs>
 void save(const Value& value, SaveToArgs&&... args) {
   serialize::OutputArchive archive(
-      std::make_shared<jit::script::CompilationUnit>());
+      std::make_shared<jit::CompilationUnit>());
   archive << value;
   archive.save_to(std::forward<SaveToArgs>(args)...);
 }
@@ -65,7 +65,7 @@ void save(const Value& value, SaveToArgs&&... args) {
 template <typename... SaveToArgs>
 void save(const std::vector<torch::Tensor>& tensor_vec, SaveToArgs&&... args) {
   serialize::OutputArchive archive(
-      std::make_shared<jit::script::CompilationUnit>());
+      std::make_shared<jit::CompilationUnit>());
   for (size_t i = 0; i < tensor_vec.size(); i++) {
     auto& value = tensor_vec[i];
     archive.write(c10::to_string(i), value);
@@ -18,9 +18,7 @@ class Tensor;
 namespace torch {
 using at::Tensor;
 namespace jit {
-namespace script {
 struct Module;
-} // namespace script
 } // namespace jit
 } // namespace torch

@@ -108,7 +106,7 @@ class TORCH_API InputArchive final {
   }

  private:
-  jit::script::Module module_;
+  jit::Module module_;
   std::string hierarchy_prefix_;
 };
 } // namespace serialize
@@ -15,9 +15,7 @@ class Tensor;
 namespace torch {
 using at::Tensor;
 namespace jit {
-namespace script {
 struct Module;
-} // namespace script
 } // namespace jit
 } // namespace torch

@@ -25,8 +23,8 @@ namespace torch {
 namespace serialize {
 class TORCH_API OutputArchive final {
  public:
-  explicit OutputArchive(std::shared_ptr<jit::script::CompilationUnit> cu);
-  explicit OutputArchive() : cu_(std::make_shared<jit::script::CompilationUnit>()), module_("__torch__.Module", cu_) {}
+  explicit OutputArchive(std::shared_ptr<jit::CompilationUnit> cu);
+  explicit OutputArchive() : cu_(std::make_shared<jit::CompilationUnit>()), module_("__torch__.Module", cu_) {}

   // Move is allowed.
   OutputArchive(OutputArchive&&) = default;
@@ -36,7 +34,7 @@ class TORCH_API OutputArchive final {
   OutputArchive(OutputArchive&) = delete;
   OutputArchive& operator=(OutputArchive&) = delete;

-  std::shared_ptr<jit::script::CompilationUnit> compilation_unit() const {
+  std::shared_ptr<jit::CompilationUnit> compilation_unit() const {
     return cu_;
   }

@@ -75,8 +73,8 @@ class TORCH_API OutputArchive final {
   }

  private:
-  std::shared_ptr<jit::script::CompilationUnit> cu_;
-  jit::script::Module module_;
+  std::shared_ptr<jit::CompilationUnit> cu_;
+  jit::Module module_;
 };
 } // namespace serialize
 } // namespace torch
@@ -9,12 +9,12 @@
 namespace torch {
 namespace jit {

-std::shared_ptr<script::CompilationUnit> compile(const std::string& source) {
-  auto module = std::make_shared<script::CompilationUnit>();
+std::shared_ptr<CompilationUnit> compile(const std::string& source) {
+  auto module = std::make_shared<CompilationUnit>();
   module->define(
       c10::nullopt,
       source,
-      script::nativeResolver(),
+      nativeResolver(),
       nullptr);
   return module;
 }
@@ -16,7 +16,7 @@
 namespace torch {
 namespace serialize {

-InputArchive::InputArchive() : module_("Module", std::make_shared<jit::script::CompilationUnit>()) {}
+InputArchive::InputArchive() : module_("Module", std::make_shared<jit::CompilationUnit>()) {}

 void InputArchive::read(const std::string& key, c10::IValue& ivalue) {
   ivalue = module_.attr(key);
@@ -161,7 +161,7 @@ std::vector<std::string> InputArchive::keys() {
   std::vector<std::string> all_keys;
   all_keys.reserve(module_.named_attributes(/*recurse=*/false).size());

-  for (const torch::jit::script::NameValue& s : module_.named_attributes(/*recurse=*/false)) {
+  for (const torch::jit::NameValue& s : module_.named_attributes(/*recurse=*/false)) {
     all_keys.push_back(s.name);
   }

@@ -14,7 +14,7 @@

 namespace torch {
 namespace serialize {
-OutputArchive::OutputArchive(std::shared_ptr<jit::script::CompilationUnit> cu)
+OutputArchive::OutputArchive(std::shared_ptr<jit::CompilationUnit> cu)
     : cu_(std::move(cu)),
       module_("__torch__.Module", cu_, /*shouldMangle=*/true) {}

@@ -53,7 +53,7 @@ TypePtr tryInferTypeWithTypeHint(
     const py::object& type_hint) {
   // If the py::object to be contained by the RRef is a ScripModule, we enforce
   // users to specify its ModuleInterface type.
-  if (auto module = jit::script::as_module(value)) {
+  if (auto module = jit::as_module(value)) {
     TORCH_CHECK(
         !type_hint.is_none(),
         "The RRef being created contains a ScriptModule, "
@@ -27,8 +27,8 @@ namespace {

 // PythonTypeResolver that inherits from Script::Resolver to
 // support resolving types together with ScriptTypeParser.
-struct PythonTypeResolver : public jit::script::Resolver {
-  std::shared_ptr<jit::script::SugaredValue> resolveValue(
+struct PythonTypeResolver : public jit::Resolver {
+  std::shared_ptr<jit::SugaredValue> resolveValue(
       const std::string& /* unused */,
       torch::jit::Function& /* unused */,
       const jit::SourceRange& /* unused */) override {
@@ -67,7 +67,7 @@ PythonRpcHandler::PythonRpcHandler() {
   pyHandleException_ = getFunction(module, "_handle_exception");
   pyGetQualifiedName_ = py::module::import("torch.jit").attr("_qualified_name");
   jitCompilationUnit_ = torch::jit::get_python_cu();
-  typeParser_ = std::make_shared<jit::script::ScriptTypeParser>(
+  typeParser_ = std::make_shared<jit::ScriptTypeParser>(
       std::make_shared<PythonTypeResolver>());
 }

@@ -97,7 +97,7 @@ PythonRpcHandler& PythonRpcHandler::getInstance() {
   return *handler;
 }

-std::shared_ptr<torch::jit::script::CompilationUnit> PythonRpcHandler::
+std::shared_ptr<torch::jit::CompilationUnit> PythonRpcHandler::
     jitCompilationUnit() {
   return jitCompilationUnit_;
 }
@@ -55,7 +55,7 @@ class PYBIND11_EXPORT PythonRpcHandler {
   // PythonRpcHandler.
   void cleanup();

-  std::shared_ptr<torch::jit::script::CompilationUnit> jitCompilationUnit();
+  std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit();

   // Parse the string to recover the jit_type, this is used for RRef python
   // pickling/unpickling type recovery. The type string inference rule is as
@@ -97,11 +97,11 @@ class PYBIND11_EXPORT PythonRpcHandler {
   // and imported in C++ (see get_python_cu() in
   // csrc/jit/python/pybind_utils.h). We import the compilation unit here only
   // once for less cost and thread safety.
-  std::shared_ptr<torch::jit::script::CompilationUnit> jitCompilationUnit_;
+  std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit_;

   // jit type parser to parse type_str back to TypePtr for RRef type
   // recovery when pickling and unpickling RRef
-  std::shared_ptr<jit::script::ScriptTypeParser> typeParser_;
+  std::shared_ptr<jit::ScriptTypeParser> typeParser_;
 };

 } // namespace rpc
@@ -24,7 +24,6 @@

 namespace torch {
 namespace jit {
-namespace script {

 struct Def;
 struct ClassDef;
@@ -281,22 +280,25 @@ struct TORCH_API CompilationUnit {
   mutable size_t mangleIndex_ = 0;
 };

-} // namespace script
-
 // An owning pointer to a Function. Just a pair of a raw Function ptr and it's
 // owning CU. We need this because pybind requires a ref-counted way to refer to
 // Functions.
 struct StrongFunctionPtr {
   StrongFunctionPtr(
-      std::shared_ptr<script::CompilationUnit> cu,
+      std::shared_ptr<CompilationUnit> cu,
       Function* function)
       : cu_(std::move(cu)), function_(function) {
     TORCH_INTERNAL_ASSERT(cu_);
     TORCH_INTERNAL_ASSERT(function_);
   }
-  std::shared_ptr<script::CompilationUnit> cu_;
+  std::shared_ptr<CompilationUnit> cu_;
   Function* function_;
 };

+namespace script {
+// We once had a `script::` namespace that was deleted. This is for backcompat
+// of the public API; new code should not use this type alias.
+using CompilationUnit = ::torch::jit::CompilationUnit;
+}
 } // namespace jit
 } // namespace torch
@@ -6,7 +6,6 @@

 namespace torch {
 namespace jit {
-namespace script {

 using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;

@@ -61,6 +60,11 @@ struct TORCH_API Method {
   Function* function_;
 };

-} // namespace script
+namespace script {
+// We once had a `script::` namespace that was deleted. This is for backcompat
+// of the public API; new code should not use this type alias.
+using Method = ::torch::jit::Method;
+}

 } // namespace jit
 } // namespace torch
@@ -11,7 +11,6 @@

 namespace torch {
 namespace jit {
-namespace script {

 static ObjectPtr create_module_object(
     c10::QualifiedName class_name,
@@ -389,14 +388,13 @@ void Module::dump(
             << std::endl;
 }

-} // namespace script
 } // namespace jit
 } // namespace torch

 namespace c10 {

-torch::jit::script::Module IValue::toModule() const {
-  return torch::jit::script::Module(toObject());
+torch::jit::Module IValue::toModule() const {
+  return torch::jit::Module(toObject());
 }
 bool IValue::isModule() const {
   return isObject() && toObjectRef().type()->is_module();
@@ -33,7 +33,6 @@

 namespace torch {
 namespace jit {
-namespace script {

 using ::c10::Argument;
 using ::c10::FunctionSchema;
@@ -554,6 +553,11 @@ struct NamedPolicy {

 TORCH_API bool& getInlineEverythingMode();

-} // namespace script
+namespace script {
+// We once had a `script::` namespace that was deleted. This is for backcompat
+// of the public API; new code should not use this type alias.
+using Module = ::torch::jit::Module;
+}

 } // namespace jit
 } // namespace torch
@@ -3,7 +3,6 @@

 namespace torch {
 namespace jit {
-namespace script {

 void Module::save(std::ostream& out, const ExtraFilesMap& extra_files) const {
   ExportModule(*this, out, extra_files, false);
@@ -26,6 +25,5 @@ void Module::_save_for_mobile(
   ExportModule(*this, filename, extra_files, true);
 }

-} // namespace script
 } // namespace jit
 } // namespace torch
@@ -7,7 +7,6 @@

 namespace torch {
 namespace jit {
-namespace script {

 Object::Object(
     std::shared_ptr<CompilationUnit> cu,
@@ -35,10 +34,9 @@ void Object::define(const std::string& src, const ResolverPtr& resolver) {
   _ivalue()->compilation_unit()->define(
       *type()->name(),
       src,
-      resolver ? resolver : script::nativeResolver(),
+      resolver ? resolver : nativeResolver(),
       &self);
 }

-} // namespace script
 } // namespace jit
 } // namespace torch
@@ -6,7 +6,6 @@

 namespace torch {
 namespace jit {
-namespace script {

 struct Resolver;
 using ResolverPtr = std::shared_ptr<Resolver>;
@@ -132,6 +131,10 @@ struct TORCH_API Object {
   mutable ObjectPtr _ivalue_;
 };

-} // namespace script
+namespace script {
+// We once had a `script::` namespace that was deleted. This is for backcompat
+// of the public API; new code should not use this type alias.
+using Object = ::torch::jit::Object;
+}
 } // namespace jit
 } // namespace torch
@@ -285,7 +285,7 @@ The load process has the following steps:

 1. Unpickle `constants.pkl`, which produces a tuple of all tensor constants
    referenced in code.
-2. Unpickle `data.pkl` into the top-level `script::Module` and return it.
+2. Unpickle `data.pkl` into the top-level `Module` and return it.

 The unpickling process consists of a single call to unpickle the module
 object contained in `data.pkl`. The `Unpickler` is given a callback that lets it
@@ -5,7 +5,6 @@

 namespace torch {
 namespace jit {
-namespace script {

 auto scalar_operators_source = CodeTemplate(
     R"SCRIPT(
@@ -81,7 +80,7 @@ struct BuiltinFunctionRegistry {
     std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>();
     modules.emplace_back(cu);
     cu->define(
-        c10::nullopt, source, script::nativeResolver(), /*self=*/nullptr);
+        c10::nullopt, source, nativeResolver(), /*self=*/nullptr);
     for (auto& method : cu->get_functions()) {
       builtins_by_name_[Symbol::fromQualString(
           the_namespace + "::" + method->name())]
@@ -134,6 +133,5 @@ const std::vector<Function*>& getAllBuiltinFunctionsFor(
   return registry.getAllBuiltinFunctionsFor(name);
 }

-} // namespace script
 } // namespace jit
 } // namespace torch
@@ -5,9 +5,7 @@

 namespace torch {
 namespace jit {
-namespace script {

 TORCH_API const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name);
-}
 } // namespace jit
 } // namespace torch
@@ -3,7 +3,6 @@

 namespace torch {
 namespace jit {
-namespace script {

 ClassTypePtr ConcreteModuleTypeBuilder::createTypeFromThis() const {
   auto cu = get_python_cu();
@@ -303,6 +302,5 @@ ConcreteModuleType::getModulesPy() const {
   return ret;
 }

-} // namespace script
 } // namespace jit
 } // namespace torch
@@ -9,7 +9,6 @@

 namespace torch {
 namespace jit {
-namespace script {

 enum class IterableModuleKind { NONE, LIST, DICT };
 class ConcreteModuleType;
@@ -228,6 +227,5 @@ class VISIBILITY_HIDDEN ConcreteModuleType {
   TypePtr jitType_;
 };

-} // namespace script
 } // namespace jit
 } // namespace torch
@ -9,7 +9,6 @@
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace script {
|
||||
|
||||
// At the beginning of the pass the Graph has already undergone type checking,
|
||||
// and writes or reads to a variable are emitted as Loads and Stores in the
|
||||
@ -328,6 +327,5 @@ void ConvertToSSA(std::shared_ptr<Graph>& graph) {
|
||||
TransformExits(graph);
|
||||
}
|
||||
|
||||
} // namespace script
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -8,11 +8,9 @@
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace script {
|
||||
|
||||
// Convert a graph with Loads & Stores into SSA form
|
||||
TORCH_API void ConvertToSSA(std::shared_ptr<Graph>& graph);
|
||||
|
||||
} // namespace script
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -5,7 +5,6 @@
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace script {
|
||||
|
||||
// computes levenshtein edit distance between two words
|
||||
// returns maxEditDistance + 1 if the edit distance exceeds MaxEditDistance
|
||||
@ -52,6 +51,5 @@ size_t ComputeEditDistance(
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace script
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -5,13 +5,11 @@
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace script {
|
||||
|
||||
TORCH_API size_t ComputeEditDistance(
|
||||
const char* word1,
|
||||
const char* word2,
|
||||
size_t maxEditDistance);
|
||||
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
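
A short usage sketch of the declaration above (the include path is an assumption); the cap keeps "did you mean" suggestion lookups cheap:

```cpp
#include <cstddef>
#include <torch/csrc/jit/script/edit_distance.h>  // assumed header location

int main() {
  // "conv2d" and "conv2D" are one substitution apart, so d == 1.
  size_t d = torch::jit::ComputeEditDistance(
      "conv2d", "conv2D", /*maxEditDistance=*/2);
  // Had the true distance exceeded the cap of 2, d would be 3 (cap + 1).
  return d == 1 ? 0 : 1;
}
```
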
@ -5,7 +5,6 @@

namespace torch {
namespace jit {
namespace script {

// Avoid storing objects with destructors in thread_local for mobile build.
#ifndef C10_MOBILE
@ -72,6 +71,5 @@ const char* ErrorReport::what() const noexcept {
  return the_message.c_str();
}

} // namespace script
} // namespace jit
} // namespace torch

@ -5,7 +5,6 @@

namespace torch {
namespace jit {
namespace script {

struct Call {
  std::string fn_name;
@ -49,6 +48,5 @@ const ErrorReport& operator<<(const ErrorReport& e, const T& t) {
  return e;
}

} // namespace script
} // namespace jit
} // namespace torch
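
The streaming `operator<<` above is what enables the throw-and-stream idiom used throughout the frontend. A sketch, with the header path and the surrounding function both assumed:

```cpp
#include <string>
#include <torch/csrc/jit/script/error_report.h>  // assumed header location

void rejectRecursiveCall(
    const torch::jit::SourceRange& loc,
    const std::string& name) {
  // Build the message incrementally; what() later renders it together with
  // the compiler call stack recorded via the Call entries above.
  throw torch::jit::ErrorReport(loc)
      << "function '" << name << "' is called recursively. "
      << "Recursive calls are not supported";
}
```
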
@ -22,7 +22,6 @@ using at::TypeKind;
namespace torch {
namespace jit {

namespace script {
namespace {
struct SchemaParser {
  SchemaParser(const std::string& str)
@ -290,10 +289,9 @@ struct SchemaParser {
  SchemaTypeParser type_parser;
};
} // namespace
} // namespace script

C10_EXPORT either<OperatorName, FunctionSchema> parseSchemaOrName(const std::string& schemaOrName) {
  return script::SchemaParser(schemaOrName).parseDeclarations().at(0);
  return SchemaParser(schemaOrName).parseDeclarations().at(0);
}

C10_EXPORT FunctionSchema parseSchema(const std::string& schema) {
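
A usage sketch for the now-unqualified parser entry point (the include path is an assumption; the string follows the native-function schema syntax):

```cpp
#include <iostream>
#include <torch/csrc/jit/frontend/function_schema_parser.h>  // assumed location

int main() {
  auto schema = torch::jit::parseSchema("aten::relu(Tensor self) -> Tensor");
  // FunctionSchema exposes the parsed pieces.
  std::cout << schema.name() << " has " << schema.arguments().size()
            << " argument(s)\n";
  return 0;
}
```
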
@ -30,7 +30,6 @@

namespace torch {
namespace jit {
namespace script {

using FunctionTable = std::unordered_map<std::string, Function&>;
using ValueTable = std::unordered_map<std::string, SugaredValuePtr>;
@ -491,7 +490,7 @@ struct Environment {
  if (!retval) {
    if (auto type = resolver->resolveType(ident, range)) {
      if (auto tuple_type = type->cast<TupleType>()) {
        retval = std::make_shared<script::NamedTupleConstructor>(tuple_type);
        retval = std::make_shared<NamedTupleConstructor>(tuple_type);
      }
    }
  }
@ -503,7 +502,7 @@ struct Environment {
  if (!retval) {
    if (auto type = resolver->resolveType(ident, range)) {
      if (auto class_type = type->cast<ClassType>()) {
        retval = std::make_shared<script::ClassValue>(class_type);
        retval = std::make_shared<ClassValue>(class_type);
      }
    }
  }
@ -3598,7 +3597,7 @@ std::vector<Function*> CompilationUnit::define(
void runCleanupPasses(std::shared_ptr<Graph>& to_clean) {
  liftClosures(to_clean);
  inlineForkedClosures(to_clean);
  if (script::getInlineEverythingMode()) {
  if (getInlineEverythingMode()) {
    Inline(*to_clean);
  }
  // remove any uses of tuples that we inserted that are not needed
@ -3669,6 +3668,5 @@ void CompilationUnit::define_interface(
  this->register_type(iface);
}

} // namespace script
} // namespace jit
} // namespace torch

@ -12,12 +12,10 @@

namespace torch {
namespace jit {
namespace script {

TORCH_API void runCleanupPasses(std::shared_ptr<Graph>& to_clean);

TORCH_API bool meaningfulName(const std::string& name);

} // namespace script
} // namespace jit
} // namespace torch

@ -8,7 +8,6 @@

namespace torch {
namespace jit {
namespace script {

static const std::unordered_map<int, int> binary_prec = {
    {TK_IF, 1},
@ -101,6 +100,5 @@ C10_EXPORT SharedParserData& sharedParserData() {
  return data;
}

} // namespace script
} // namespace jit
} // namespace torch

@ -16,7 +16,6 @@

namespace torch {
namespace jit {
namespace script {

// single character tokens are just the character itself '+'
// multi-character tokens need an entry here
@ -180,7 +179,7 @@ struct CAFFE2_API SharedParserData {
      return false;
    const char* startptr = str.c_str() + start;
    char* endptr;
    torch::jit::script::strtod_c(startptr, &endptr);
    torch::jit::strtod_c(startptr, &endptr);
    *len = endptr - startptr;
    return *len > 0;
  }
@ -515,6 +514,5 @@ struct Lexer {
  std::vector<Token> next_tokens;
  SharedParserData& shared;
};
} // namespace script
} // namespace jit
} // namespace torch

@ -5,7 +5,6 @@

namespace torch {
namespace jit {
namespace script {

// Simple data structure for containing a type T in nested control blocks
// Should only be used after initial compilation where type checking and
@ -52,6 +51,5 @@ struct MiniEnvironment {
  std::unordered_map<std::string, T> table;
};

} // namespace script
} // namespace jit
} // namespace torch

@ -5,7 +5,6 @@

namespace torch {
namespace jit {
namespace script {

inline bool isCharCount(char c, const std::string& str, size_t start, int len) {
  // count checks from [start, start + len)
@ -87,6 +86,5 @@ inline std::string parseStringLiteral(
  return ret_str;
}

} // namespace script
} // namespace jit
} // namespace torch

@ -7,7 +7,6 @@

namespace torch {
namespace jit {
namespace script {

Decl mergeTypesFromTypeComment(
    const Decl& decl,
@ -735,6 +734,5 @@ Expr Parser::parseExp() {
  return pImpl->parseExp();
}

} // namespace script
} // namespace jit
} // namespace torch

@ -6,7 +6,6 @@

namespace torch {
namespace jit {
namespace script {

struct Decl;
struct ParserImpl;
@ -30,6 +29,5 @@ struct TORCH_API Parser {
  std::unique_ptr<ParserImpl> pImpl;
};

} // namespace script
} // namespace jit
} // namespace torch

@ -2,8 +2,6 @@

namespace torch {
namespace jit {
namespace script {
static const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!&^|~";
} // namespace script
} // namespace jit
} // namespace torch

@ -6,7 +6,6 @@

namespace torch {
namespace jit {
namespace script {

struct Resolver;
using ResolverPtr = std::shared_ptr<Resolver>;
@ -65,6 +64,5 @@ struct NativeResolver : public Resolver {
inline std::shared_ptr<NativeResolver> nativeResolver() {
  return std::make_shared<NativeResolver>();
}
} // namespace script
} // namespace jit
} // namespace torch
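
The `nativeResolver()` factory above pairs with `CompilationUnit::define`, the same pattern the builtin-function registry uses earlier in this diff. A sketch with assumed header paths:

```cpp
#include <memory>
#include <torch/csrc/jit/api/compilation_unit.h>  // assumed header locations
#include <torch/csrc/jit/frontend/resolver.h>

int main() {
  auto cu = std::make_shared<torch::jit::CompilationUnit>();
  // Compile one TorchScript function; NativeResolver resolves free names
  // against the native C++ environment.
  cu->define(
      c10::nullopt,
      "def add_one(x: int) -> int:\n    return x + 1\n",
      torch::jit::nativeResolver(),
      /*self=*/nullptr);
  return cu->get_functions().size() == 1 ? 0 : 1;
}
```
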
@ -6,7 +6,6 @@

namespace torch {
namespace jit {
namespace script {

static inline TypePtr unwrapOptional(TypePtr opt_type) {
  if (auto unwrap_list_type = opt_type->cast<OptionalType>()) {
@ -612,6 +611,5 @@ Value* emitBuiltinCall(
  }
}

} // namespace script
} // namespace jit
} // namespace torch

@ -7,7 +7,6 @@

namespace torch {
namespace jit {
namespace script {

// try to match a list of inputs and keyword 'attributes' to this schema,
// if it works return the flat list of positional inputs to the call
@ -60,6 +59,5 @@ TORCH_API Value* tryConvertToType(
    const TypePtr& concrete_type,
    Value* value,
    bool allow_conversions);
} // namespace script
} // namespace jit
} // namespace torch

@ -31,7 +31,6 @@ using c10::VarType;

namespace torch {
namespace jit {
namespace script {

TypePtr SchemaTypeParser::parseBaseType() {
  static std::unordered_map<std::string, TypePtr> type_map = {
@ -292,6 +291,5 @@ void SchemaTypeParser::parseList(
  L.expect(end);
}

} // namespace script
} // namespace jit
} // namespace torch

@ -7,7 +7,6 @@

namespace torch {
namespace jit {
namespace script {

using TypePtr = c10::TypePtr;

@ -33,6 +32,5 @@ struct CAFFE2_API SchemaTypeParser {
  Lexer& L;
  size_t next_id = 0;
};
} // namespace script
} // namespace jit
} // namespace torch

@ -4,7 +4,6 @@

namespace torch {
namespace jit {
namespace script {
const std::unordered_map<std::string, TypePtr>& string_to_type_lut();
namespace {

@ -355,6 +354,5 @@ c10::IValue ScriptTypeParser::parseClassConstant(const Assign& assign) {
  return *default_val.begin();
}

} // namespace script
} // namespace jit
} // namespace torch

@ -6,7 +6,6 @@

namespace torch {
namespace jit {
namespace script {

/**
 * class ScriptTypeParser
@ -45,6 +44,5 @@ class TORCH_API ScriptTypeParser {

  ResolverPtr resolver_ = nullptr;
};
} // namespace script
} // namespace jit
} // namespace torch

@ -2,7 +2,6 @@

namespace torch {
namespace jit {
namespace script {
using namespace c10;
const std::unordered_map<std::string, TypePtr>& string_to_type_lut() {
  static std::unordered_map<std::string, TypePtr> map = {
@ -22,6 +21,5 @@ const std::unordered_map<std::string, TypePtr>& string_to_type_lut() {
      {"tuple", AnyTupleType::get()}};
  return map;
}
}
} // namespace jit
} // namespace torch
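
A sketch of how a caller consumes the table; mirroring `script_type_parser.cpp` above, the function is forward-declared rather than pulled in from a header (the helper name is hypothetical):

```cpp
#include <string>
#include <unordered_map>
#include <ATen/core/jit_type.h>  // for c10::TypePtr

namespace torch {
namespace jit {
// Forward declaration, as in script_type_parser.cpp above; the definition
// lives in the translation unit shown in this diff.
const std::unordered_map<std::string, c10::TypePtr>& string_to_type_lut();
} // namespace jit
} // namespace torch

// True for annotation names the parser knows, e.g. "tuple" -> AnyTupleType.
bool isKnownAnnotation(const std::string& name) {
  return torch::jit::string_to_type_lut().count(name) != 0;
}
```
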
@ -76,7 +76,6 @@ double parse_inf_or_nan(const char *p, char **endptr)

namespace torch {
namespace jit {
namespace script {

#ifdef _MSC_VER
C10_EXPORT double strtod_c(const char *nptr, char **endptr)
@ -265,4 +264,3 @@ C10_EXPORT float strtof_c(const char *nptr, char **endptr)

}
}
}

@ -4,11 +4,9 @@

namespace torch {
namespace jit {
namespace script {

CAFFE2_API double strtod_c(const char *nptr, char **endptr);
CAFFE2_API float strtof_c(const char *nptr, char **endptr);

}
}
}
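
A usage sketch of the pair declared above (header path assumed; the `_c` suffix suggests locale-independent, C-locale parsing, which is also why the lexer and `Const` use it instead of plain `strtod`):

```cpp
#include <torch/csrc/jit/frontend/strtod.h>  // assumed header location

int main() {
  const char* text = "2.5e2 rest";
  char* end = nullptr;
  // Parses "2.5e2" with '.' as the decimal separator regardless of the
  // process locale; end is left pointing at " rest".
  double v = torch::jit::strtod_c(text, &end);
  return (v == 250.0 && *end == ' ') ? 0 : 1;
}
```
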
@ -6,7 +6,6 @@

namespace torch {
namespace jit {
namespace script {

struct NoneValue : SugaredValue {
  NoneValue() = default;
@ -623,6 +622,5 @@ std::shared_ptr<BuiltinFunction> BuiltinFunction::tryCreate(
  return nullptr;
}

} // namespace script
} // namespace jit
} // namespace torch

@ -10,7 +10,6 @@

namespace torch {
namespace jit {
namespace script {

using SugaredValuePtr = std::shared_ptr<SugaredValue>;

@ -365,7 +364,7 @@ struct FunctionValue : public SugaredValue {
    try {
      callee->ensure_defined();
    } catch (const RecursiveMethodCallError&) {
      throw script::ErrorReport(loc)
      throw ErrorReport(loc)
          << " function '" << callee->name() << "' is called recursively. "
          << "Recursive calls are not supported";
    }
@ -428,7 +427,7 @@ struct MethodValue : public SugaredValue {
    try {
      method->ensure_defined();
    } catch (const RecursiveMethodCallError&) {
      throw script::ErrorReport(loc)
      throw ErrorReport(loc)
          << " method '" << method->name() << "' is called recursively. "
          << "Recursive calls are not supported";
    }
@ -652,6 +651,5 @@ struct SimpleSelf : public Self {
 private:
  ClassTypePtr classType_;
};
} // namespace script
} // namespace jit
} // namespace torch

@ -152,7 +152,7 @@ Value* TracingState::getValue(const IValue& var) {

  // Find torchbind classes
  if (isCustomClass(var)) {
    auto obj = script::Object(var.toObject());
    auto obj = Object(var.toObject());
    auto qualname = obj.type()->name();
    auto custom_class_type = getCustomClass(qualname->qualifiedName());
    if (custom_class_type) {
@ -313,14 +313,14 @@ static IValue addInput(const std::shared_ptr<TracingState> & state, const IValue
static void gatherParametersAndBuffers(
    const std::shared_ptr<TracingState>& state,
    Value* self_value,
    const script::Module& self,
    const Module& self,
    const std::string& prefix) {
  Graph& g = *self_value->owningGraph();

  state->setValue(self._ivalue(), self_value);

  auto self_ty = self.type();
  for (const script::NameValue& s : self.named_attributes(/*recurse=*/false)) {
  for (const NameValue& s : self.named_attributes(/*recurse=*/false)) {
    auto qualname = prefix + "." + s.name;
    Value* trace_get_attr = g.insertNode(g.create(prim::TracedAttr))
                                ->s_(attr::scope, qualname)
@ -335,7 +335,7 @@ static void gatherParametersAndBuffers(
    }
    if (self_ty->getAttribute(s.name)->is_module()) {
      gatherParametersAndBuffers(
          state, trace_get_attr, script::Module(s.value.toObject()), qualname);
          state, trace_get_attr, Module(s.value.toObject()), qualname);
    }
  }
}
@ -345,7 +345,7 @@ std::pair<std::shared_ptr<TracingState>, Stack> trace(
    const std::function<Stack(Stack)>& traced_fn,
    std::function<std::string(const Variable&)> var_name_lookup_fn,
    bool force_outplace,
    script::Module* self) {
    Module* self) {
  try {
    // Start tracing, treating 'inputs' as inputs to the trace, which can be
    // varied on subsequent invocations of the trace. Any other variables
@ -387,7 +387,7 @@ std::pair<std::shared_ptr<TracingState>, Stack> trace(
  }
  setTracingState(nullptr);

  if (script::getInlineEverythingMode()) {
  if (getInlineEverythingMode()) {
    Inline(*graph);
  }
  FixupTraceScopeBlocks(graph, self);

@ -21,10 +21,7 @@ namespace jit {
struct Node;
struct Value;
struct Graph;

namespace script {
struct Module;
}
struct Module;

namespace tracer {

@ -210,7 +207,7 @@ TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> trace(
    const std::function<Stack(Stack)>& traced_fn,
    std::function<std::string(const Variable&)> var_name_lookup_fn,
    bool force_outplace = false,
    script::Module* self = nullptr);
    Module* self = nullptr);

TORCH_API void abandon();

@ -11,7 +11,6 @@

namespace torch {
namespace jit {
namespace script {

// Trees are used to represent all forms of TC IR, pre- and post-typechecking.
// Rather than have a full class hierarchy for all TC statements,
@ -221,6 +220,5 @@ static inline std::ostream& operator<<(std::ostream& out, const TreeRef& t) {
  return out << pretty_tree(t);
}

} // namespace script
} // namespace jit
} // namespace torch

@ -10,7 +10,6 @@

namespace torch {
namespace jit {
namespace script {

// clang-format off
// TreeView provides a statically-typed way to traverse the tree, which should
@ -813,7 +812,7 @@ struct Const : public Expr {
    // We can't pass in nullptr as the dummy pointer gets dereferenced for
    // the Android version of strtod_c().
    char* dummy;
    return torch::jit::script::strtod_c(
    return torch::jit::strtod_c(
        subtree(0)->stringValue().c_str(), &dummy);
  }
  const std::string& text() const {
@ -1045,14 +1044,13 @@ struct Delete : public Stmt {
  }
};

} // namespace script
} // namespace jit
} // namespace torch

namespace std {

template <typename T>
struct iterator_traits<torch::jit::script::ListIterator<T>>
    : std::iterator_traits<torch::jit::script::TreeList::const_iterator> {};
struct iterator_traits<torch::jit::ListIterator<T>>
    : std::iterator_traits<torch::jit::TreeList::const_iterator> {};

} // namespace std

@ -963,7 +963,7 @@ const Operator& Node::getOperator() const {
  if (maybe)
    return *maybe;

  auto er = script::ErrorReport(sourceRange());
  auto er = ErrorReport(sourceRange());
  er << "Schema not found for node. File a bug report.\n";
  er << "Node: " << *this << "\n";
  er << "Input types:";
@ -1491,7 +1491,7 @@ Value* Graph::insert(
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    const c10::optional<SourceRange>& range) {
  return script::emitBuiltinCall(
  return emitBuiltinCall(
      range.value_or(fakeRange()), *this, opname, args, kwargs);
}

@ -1721,7 +1721,7 @@ Value* Graph::insertToList(Value* v, TypePtr type) {

Value* Graph::insertFunctionCall(
    Function* callee,
    const script::MatchedSchema& matched) {
    const MatchedSchema& matched) {
  std::string func_name = callee->name();
  Value* fn_constant = insertNode(create(prim::Constant))
                           ->s_(attr::name, func_name)
@ -1737,7 +1737,7 @@ Value* Graph::insertFunctionCall(

Value* Graph::insertMethodCall(
    std::string method_name,
    const script::MatchedSchema& matched) {
    const MatchedSchema& matched) {
  Value* result = insertNode(create(prim::CallMethod, matched.inputs))
                      ->s_(attr::name, std::move(method_name))
                      ->output()

@ -73,9 +73,7 @@ using namespace ::c10::aten;
}

struct Function;
namespace script {
struct MatchedSchema;
} // namespace script

// Graph represents one "function" of computation.
// It uses a simple ownership model where the graph owns all the nodes inside
@ -1134,10 +1132,10 @@ struct Graph {

  TORCH_API Value* insertFunctionCall(
      Function* callee,
      const script::MatchedSchema& matched);
      const MatchedSchema& matched);
  TORCH_API Value* insertMethodCall(
      std::string method_name,
      const script::MatchedSchema& matched);
      const MatchedSchema& matched);

  // Note: defined in python_ir.cpp and can be used only in python extension
  Node* createPythonOp(
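
For context, `Graph::insert` routes through the schema-aware `emitBuiltinCall` shown earlier, so defaulted arguments (such as `alpha` for `aten::add`) are filled in automatically. A sketch with an assumed header path:

```cpp
#include <iostream>
#include <memory>
#include <torch/csrc/jit/ir/ir.h>  // assumed header location

int main() {
  auto graph = std::make_shared<torch::jit::Graph>();
  auto* x = graph->addInput();  // new inputs default to Tensor type
  auto* y = graph->addInput();
  // Schema matching resolves aten::add(Tensor, Tensor, Scalar alpha=1).
  auto* sum = graph->insert(c10::Symbol::fromQualString("aten::add"), {x, y});
  graph->registerOutput(sum);
  std::cout << *graph;  // print the textual IR
  return 0;
}
```
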
@ -9,7 +9,6 @@

namespace torch {
namespace jit {
namespace script {

struct VarWithType;
struct ParsedLiteral;
@ -57,7 +56,7 @@ class IRParser {

  Value* findValueInVMap(const std::string& name);

  torch::jit::script::Lexer L;
  torch::jit::Lexer L;
  torch::jit::Graph* g = nullptr;
  std::unordered_map<std::string, Value*>& vmap;
  SchemaTypeParser type_parser;
@ -86,7 +85,7 @@ void parseIR(
    const std::string& str,
    torch::jit::Graph* graph,
    std::unordered_map<std::string, Value*>& vmap) {
  torch::jit::script::IRParser p(str, graph, vmap);
  torch::jit::IRParser p(str, graph, vmap);
  p.parse();
}

@ -473,6 +472,5 @@ Value* IRParser::findValueInVMap(const std::string& name) {
  return vmap.at(name);
}

} // namespace script
} // namespace jit
} // namespace torch

@ -12,8 +12,6 @@ namespace jit {
struct Graph;
struct Value;

namespace script {

// \brief Parse IR from \p STR constructing the corresponding IR in \p GRAPH.
TORCH_API void parseIR(const std::string& str, torch::jit::Graph* graph);

@ -27,6 +25,5 @@ TORCH_API void parseIR(
    torch::jit::Graph* graph,
    std::unordered_map<std::string, Value*>& vmap);

} // namespace script
} // namespace jit
} // namespace torch
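
A round-trip sketch for `parseIR` (header paths assumed; the textual format is the same one graph printing emits):

```cpp
#include <memory>
#include <torch/csrc/jit/ir/ir.h>        // assumed header locations
#include <torch/csrc/jit/ir/irparser.h>

int main() {
  auto graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(R"IR(
graph(%a : Tensor, %b : Tensor):
  %c : Tensor = aten::mul(%a, %b)
  return (%c)
)IR", graph.get());
  // Two graph inputs were reconstructed from the text.
  return graph->inputs().size() == 2 ? 0 : 1;
}
```
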
@ -125,14 +125,14 @@ class BytecodeDeserializer final {
 private:
  c10::IValue readArchive(const std::string& archive_name,
      std::shared_ptr<mobile::CompilationUnit> mcu);
  std::shared_ptr<script::CompilationUnit> compilation_unit_;
  std::shared_ptr<CompilationUnit> compilation_unit_;
  std::unordered_set<std::string> imported_libs_;
  std::unique_ptr<PyTorchStreamReader> reader_;
  c10::optional<at::Device> device_;
};

BytecodeDeserializer::BytecodeDeserializer(std::unique_ptr<PyTorchStreamReader> reader)
    : compilation_unit_(std::make_shared<script::CompilationUnit>()), reader_(std::move(reader)) {}
    : compilation_unit_(std::make_shared<CompilationUnit>()), reader_(std::move(reader)) {}

mobile::Module BytecodeDeserializer::deserialize(c10::optional<at::Device> device) {
  device_ = device;
Some files were not shown because too many files have changed in this diff.