mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[jit] kill script namespace (#34515)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34515 Once upon a time we thought this was necessary. In reality it is not, so removing it. For backcompat, our public interface (defined in `api/`) still has typedefs to the old `script::` names. There was only one collision: `Pass` as a `Stmt` and `Pass` as a graph transform. I renamed one of them. Test Plan: Imported from OSS Differential Revision: D20353503 Pulled By: suo fbshipit-source-id: 48bb911ce75120a8c9e0c6fb65262ef775dfba93
This commit is contained in:
committed by
Facebook GitHub Bot
parent
cf8b728255
commit
c235be42dd
@ -13,7 +13,7 @@ int main(int argc, char* argv[]) {
|
|||||||
std::ifstream ifs(input_file_path);
|
std::ifstream ifs(input_file_path);
|
||||||
std::stringstream buffer;
|
std::stringstream buffer;
|
||||||
buffer << ifs.rdbuf();
|
buffer << ifs.rdbuf();
|
||||||
torch::jit::script::Module m("TestModule");
|
torch::jit::Module m("TestModule");
|
||||||
|
|
||||||
m.define(buffer.str());
|
m.define(buffer.str());
|
||||||
m.save(output_file_path);
|
m.save(output_file_path);
|
||||||
|
@ -58,7 +58,7 @@ class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
|
|||||||
class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||||
private:
|
private:
|
||||||
friend HybridBase;
|
friend HybridBase;
|
||||||
torch::jit::script::Module module_;
|
torch::jit::Module module_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
constexpr static auto kJavaDescriptor = "Lorg/pytorch/NativePeer;";
|
constexpr static auto kJavaDescriptor = "Lorg/pytorch/NativePeer;";
|
||||||
|
@ -5,7 +5,7 @@ package org.pytorch;
|
|||||||
import com.facebook.soloader.nativeloader.NativeLoader;
|
import com.facebook.soloader.nativeloader.NativeLoader;
|
||||||
import com.facebook.soloader.nativeloader.SystemDelegate;
|
import com.facebook.soloader.nativeloader.SystemDelegate;
|
||||||
|
|
||||||
/** Java wrapper for torch::jit::script::Module. */
|
/** Java wrapper for torch::jit::Module. */
|
||||||
public class Module {
|
public class Module {
|
||||||
|
|
||||||
private INativePeer mNativePeer;
|
private INativePeer mNativePeer;
|
||||||
@ -14,7 +14,7 @@ public class Module {
|
|||||||
* Loads a serialized TorchScript module from the specified path on the disk.
|
* Loads a serialized TorchScript module from the specified path on the disk.
|
||||||
*
|
*
|
||||||
* @param modelPath path to file that contains the serialized TorchScript module.
|
* @param modelPath path to file that contains the serialized TorchScript module.
|
||||||
* @return new {@link org.pytorch.Module} object which owns torch::jit::script::Module.
|
* @return new {@link org.pytorch.Module} object which owns torch::jit::Module.
|
||||||
*/
|
*/
|
||||||
public static Module load(final String modelPath) {
|
public static Module load(final String modelPath) {
|
||||||
if (!NativeLoader.isInitialized()) {
|
if (!NativeLoader.isInitialized()) {
|
||||||
@ -49,7 +49,7 @@ public class Module {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Explicitly destroys the native torch::jit::script::Module. Calling this method is not required,
|
* Explicitly destroys the native torch::jit::Module. Calling this method is not required,
|
||||||
* as the native object will be destroyed when this object is garbage-collected. However, the
|
* as the native object will be destroyed when this object is garbage-collected. However, the
|
||||||
* timing of garbage collection is not guaranteed, so proactively calling {@code destroy} can free
|
* timing of garbage collection is not guaranteed, so proactively calling {@code destroy} can free
|
||||||
* memory more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
|
* memory more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
|
||||||
|
@ -22,7 +22,7 @@ TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph);
|
|||||||
|
|
||||||
// A Function is a pure Graph with no implicit `self` object bound.
|
// A Function is a pure Graph with no implicit `self` object bound.
|
||||||
// It contains schema information, and the executor that manages the
|
// It contains schema information, and the executor that manages the
|
||||||
// execution of the function. script::Method is a wrapper around a
|
// execution of the function. Method is a wrapper around a
|
||||||
// underlying Function that also provides a `self` object.
|
// underlying Function that also provides a `self` object.
|
||||||
struct TORCH_API Function {
|
struct TORCH_API Function {
|
||||||
virtual bool isGraphFunction() const = 0;
|
virtual bool isGraphFunction() const = 0;
|
||||||
|
@ -378,7 +378,7 @@ std::vector<std::pair<IValue, IValue>> iterationOrder(const c10::Dict<IValue, IV
|
|||||||
}
|
}
|
||||||
|
|
||||||
StrongTypePtr::StrongTypePtr(
|
StrongTypePtr::StrongTypePtr(
|
||||||
std::shared_ptr<torch::jit::script::CompilationUnit> cu,
|
std::shared_ptr<torch::jit::CompilationUnit> cu,
|
||||||
std::shared_ptr<Type> type) {
|
std::shared_ptr<Type> type) {
|
||||||
cu_ = std::move(cu);
|
cu_ = std::move(cu);
|
||||||
type_ = type;
|
type_ = type;
|
||||||
|
@ -11,10 +11,8 @@ namespace jit {
|
|||||||
class CustomClassHolder : public c10::intrusive_ptr_target {};
|
class CustomClassHolder : public c10::intrusive_ptr_target {};
|
||||||
|
|
||||||
struct Function;
|
struct Function;
|
||||||
namespace script {
|
|
||||||
struct CompilationUnit;
|
struct CompilationUnit;
|
||||||
struct Module;
|
struct Module;
|
||||||
}
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
namespace c10 {
|
namespace c10 {
|
||||||
@ -356,7 +354,7 @@ struct CAFFE2_API IValue final {
|
|||||||
c10::intrusive_ptr<ivalue::Object> toObject() const & ;
|
c10::intrusive_ptr<ivalue::Object> toObject() const & ;
|
||||||
const ivalue::Object& toObjectRef() const;
|
const ivalue::Object& toObjectRef() const;
|
||||||
|
|
||||||
torch::jit::script::Module toModule() const;
|
torch::jit::Module toModule() const;
|
||||||
bool isModule() const;
|
bool isModule() const;
|
||||||
|
|
||||||
// PyObject
|
// PyObject
|
||||||
@ -692,10 +690,10 @@ private:
|
|||||||
// guaranteed to stay alive as long as we hold this object.
|
// guaranteed to stay alive as long as we hold this object.
|
||||||
struct TORCH_API StrongTypePtr {
|
struct TORCH_API StrongTypePtr {
|
||||||
StrongTypePtr(
|
StrongTypePtr(
|
||||||
std::shared_ptr<torch::jit::script::CompilationUnit> cu,
|
std::shared_ptr<torch::jit::CompilationUnit> cu,
|
||||||
std::shared_ptr<Type> type);
|
std::shared_ptr<Type> type);
|
||||||
|
|
||||||
std::shared_ptr<torch::jit::script::CompilationUnit> cu_;
|
std::shared_ptr<torch::jit::CompilationUnit> cu_;
|
||||||
std::shared_ptr<Type> type_;
|
std::shared_ptr<Type> type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -15,9 +15,7 @@
|
|||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
struct Function;
|
struct Function;
|
||||||
namespace script {
|
|
||||||
struct CompilationUnit;
|
struct CompilationUnit;
|
||||||
}
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
namespace c10 {
|
namespace c10 {
|
||||||
@ -406,7 +404,7 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
|
|||||||
}
|
}
|
||||||
std::shared_ptr<ClassType> type() const;
|
std::shared_ptr<ClassType> type() const;
|
||||||
|
|
||||||
std::shared_ptr<torch::jit::script::CompilationUnit> compilation_unit() {
|
std::shared_ptr<torch::jit::CompilationUnit> compilation_unit() {
|
||||||
return type_.cu_;
|
return type_.cu_;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -868,7 +866,7 @@ IValue from_(c10::intrusive_ptr<T> x, std::false_type) {
|
|||||||
auto res = getCustomClassType<inputType>();
|
auto res = getCustomClassType<inputType>();
|
||||||
auto retObject = ivalue::Object::create(
|
auto retObject = ivalue::Object::create(
|
||||||
StrongTypePtr(
|
StrongTypePtr(
|
||||||
std::shared_ptr<torch::jit::script::CompilationUnit>(),
|
std::shared_ptr<torch::jit::CompilationUnit>(),
|
||||||
std::move(res)),
|
std::move(res)),
|
||||||
1);
|
1);
|
||||||
auto objPtr = c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(std::move(x));
|
auto objPtr = c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(std::move(x));
|
||||||
|
@ -17,9 +17,7 @@
|
|||||||
struct ClassType;
|
struct ClassType;
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
struct CompilationUnit;
|
struct CompilationUnit;
|
||||||
}
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
||||||
@ -1491,7 +1489,7 @@ CAFFE2_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type);
|
|||||||
|
|
||||||
struct ClassType;
|
struct ClassType;
|
||||||
using ClassTypePtr = std::shared_ptr<ClassType>;
|
using ClassTypePtr = std::shared_ptr<ClassType>;
|
||||||
using ::torch::jit::script::CompilationUnit;
|
using ::torch::jit::CompilationUnit;
|
||||||
|
|
||||||
// This represents a class in TorchScript.
|
// This represents a class in TorchScript.
|
||||||
struct CAFFE2_API ClassType : public NamedType {
|
struct CAFFE2_API ClassType : public NamedType {
|
||||||
@ -1801,7 +1799,7 @@ struct CAFFE2_API ClassType : public NamedType {
|
|||||||
|
|
||||||
struct InterfaceType;
|
struct InterfaceType;
|
||||||
using InterfaceTypePtr = std::shared_ptr<InterfaceType>;
|
using InterfaceTypePtr = std::shared_ptr<InterfaceType>;
|
||||||
using ::torch::jit::script::CompilationUnit;
|
using ::torch::jit::CompilationUnit;
|
||||||
|
|
||||||
// Interfaces are a list of abstract methods that a class might meet.
|
// Interfaces are a list of abstract methods that a class might meet.
|
||||||
// If a class provides those methods, it implicitly meets the interface.
|
// If a class provides those methods, it implicitly meets the interface.
|
||||||
|
@ -24,7 +24,7 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
void dump_opnames(const script::Module& m, std::unordered_set<std::string>& opnames) {
|
void dump_opnames(const Module& m, std::unordered_set<std::string>& opnames) {
|
||||||
auto methods = m.get_methods();
|
auto methods = m.get_methods();
|
||||||
for (const auto& method : methods) {
|
for (const auto& method : methods) {
|
||||||
const auto& func = method.function();
|
const auto& func = method.function();
|
||||||
|
@ -7,8 +7,8 @@
|
|||||||
namespace caffe2 {
|
namespace caffe2 {
|
||||||
|
|
||||||
using torch::jit::IValue;
|
using torch::jit::IValue;
|
||||||
using torch::jit::script::Method;
|
using torch::jit::Method;
|
||||||
using torch::jit::script::Module;
|
using torch::jit::Module;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ScriptModuleSerializer : public BlobSerializerBase {
|
class ScriptModuleSerializer : public BlobSerializerBase {
|
||||||
@ -31,7 +31,7 @@ class ScriptModuleSerializer : public BlobSerializerBase {
|
|||||||
// the more efficient serialization version (if we ever get to that point)
|
// the more efficient serialization version (if we ever get to that point)
|
||||||
BlobProto blob_proto;
|
BlobProto blob_proto;
|
||||||
blob_proto.set_name(name);
|
blob_proto.set_name(name);
|
||||||
blob_proto.set_type("torch::jit::script::Module");
|
blob_proto.set_type("torch::jit::Module");
|
||||||
blob_proto.set_content(ss.str());
|
blob_proto.set_content(ss.str());
|
||||||
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
|
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
|
||||||
}
|
}
|
||||||
@ -134,7 +134,7 @@ REGISTER_BLOB_SERIALIZER(
|
|||||||
// NB: the first argument to REGISTER_BLOB_DESERIALIZER macro doesn't really
|
// NB: the first argument to REGISTER_BLOB_DESERIALIZER macro doesn't really
|
||||||
// need to be a real type, it just get converted to string
|
// need to be a real type, it just get converted to string
|
||||||
REGISTER_BLOB_DESERIALIZER(
|
REGISTER_BLOB_DESERIALIZER(
|
||||||
torch::jit::script::Module,
|
torch::jit::Module,
|
||||||
ScriptModuleDeserializer);
|
ScriptModuleDeserializer);
|
||||||
|
|
||||||
OPERATOR_SCHEMA(ScriptModule)
|
OPERATOR_SCHEMA(ScriptModule)
|
||||||
|
@ -94,12 +94,12 @@ REGISTER_BLOB_FETCHER((TypeMeta::Id<string>()), StringFetcher);
|
|||||||
class ScriptModuleFetcher : public BlobFetcherBase {
|
class ScriptModuleFetcher : public BlobFetcherBase {
|
||||||
public:
|
public:
|
||||||
pybind11::object Fetch(const Blob& blob) override {
|
pybind11::object Fetch(const Blob& blob) override {
|
||||||
return py::cast(*blob.Get<std::unique_ptr<torch::jit::script::Module>>());
|
return py::cast(*blob.Get<std::unique_ptr<torch::jit::Module>>());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_BLOB_FETCHER(
|
REGISTER_BLOB_FETCHER(
|
||||||
(TypeMeta::Id<std::unique_ptr<torch::jit::script::Module>>()),
|
(TypeMeta::Id<std::unique_ptr<torch::jit::Module>>()),
|
||||||
caffe2::python::ScriptModuleFetcher);
|
caffe2::python::ScriptModuleFetcher);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -247,9 +247,9 @@ bool feedBlob(
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
#ifdef FBCODE_CAFFE2
|
#ifdef FBCODE_CAFFE2
|
||||||
if (auto module = torch::jit::script::as_module(arg)) {
|
if (auto module = torch::jit::as_module(arg)) {
|
||||||
blob->GetMutable<std::unique_ptr<torch::jit::script::Module>>()->reset(
|
blob->GetMutable<std::unique_ptr<torch::jit::Module>>()->reset(
|
||||||
new torch::jit::script::Module(*module));
|
new torch::jit::Module(*module));
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -213,7 +213,7 @@ Disable JIT for Debugging
|
|||||||
Python. Since TorchScript (scripting and tracing) are disabled with this flag,
|
Python. Since TorchScript (scripting and tracing) are disabled with this flag,
|
||||||
you can use tools like ``pdb`` to debug the model code.
|
you can use tools like ``pdb`` to debug the model code.
|
||||||
|
|
||||||
Given an example script::
|
Given an example
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def scripted_fn(x : torch.Tensor):
|
def scripted_fn(x : torch.Tensor):
|
||||||
|
@ -120,8 +120,8 @@ usage might look like:
|
|||||||
|
|
||||||
.. code-block:: cpp
|
.. code-block:: cpp
|
||||||
|
|
||||||
SetExportModuleExtraFilesHook([](const script::Module&) {
|
SetExportModuleExtraFilesHook([](const Module&) {
|
||||||
script::ExtraFilesMap files;
|
ExtraFilesMap files;
|
||||||
files["producer_info.json"] = "{\"user\": \"" + getenv("USER") + "\"}";
|
files["producer_info.json"] = "{\"user\": \"" + getenv("USER") + "\"}";
|
||||||
return files;
|
return files;
|
||||||
});
|
});
|
||||||
|
@ -8,7 +8,7 @@ Module
|
|||||||
|
|
||||||
.. java:type:: public class Module
|
.. java:type:: public class Module
|
||||||
|
|
||||||
Java wrapper for torch::jit::script::Module.
|
Java wrapper for torch::jit::Module.
|
||||||
|
|
||||||
Methods
|
Methods
|
||||||
-------
|
-------
|
||||||
@ -18,7 +18,7 @@ destroy
|
|||||||
.. java:method:: public void destroy()
|
.. java:method:: public void destroy()
|
||||||
:outertype: Module
|
:outertype: Module
|
||||||
|
|
||||||
Explicitly destroys the native torch::jit::script::Module. Calling this method is not required, as the native object will be destroyed when this object is garbage-collected. However, the timing of garbage collection is not guaranteed, so proactively calling \ ``destroy``\ can free memory more quickly. See \ :java:ref:`com.facebook.jni.HybridData.resetNative`\ .
|
Explicitly destroys the native torch::jit::Module. Calling this method is not required, as the native object will be destroyed when this object is garbage-collected. However, the timing of garbage collection is not guaranteed, so proactively calling \ ``destroy``\ can free memory more quickly. See \ :java:ref:`com.facebook.jni.HybridData.resetNative`\ .
|
||||||
|
|
||||||
forward
|
forward
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
@ -40,7 +40,7 @@ load
|
|||||||
Loads a serialized TorchScript module from the specified path on the disk.
|
Loads a serialized TorchScript module from the specified path on the disk.
|
||||||
|
|
||||||
:param modelPath: path to file that contains the serialized TorchScript module.
|
:param modelPath: path to file that contains the serialized TorchScript module.
|
||||||
:return: new \ :java:ref:`org.pytorch.Module`\ object which owns torch::jit::script::Module.
|
:return: new \ :java:ref:`org.pytorch.Module`\ object which owns torch::jit::Module.
|
||||||
|
|
||||||
runMethod
|
runMethod
|
||||||
^^^^^^^^^
|
^^^^^^^^^
|
||||||
|
@ -7,7 +7,7 @@
|
|||||||
@end
|
@end
|
||||||
|
|
||||||
@implementation TestAppTests {
|
@implementation TestAppTests {
|
||||||
torch::jit::script::Module _module;
|
torch::jit::Module _module;
|
||||||
}
|
}
|
||||||
|
|
||||||
+ (void)setUp {
|
+ (void)setUp {
|
||||||
|
@ -395,7 +395,7 @@ void testAliasAnalysis() {
|
|||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
std::unordered_map<std::string, Value*> vmap;
|
std::unordered_map<std::string, Value*> vmap;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph():
|
graph():
|
||||||
%opt : Tensor? = prim::Constant()
|
%opt : Tensor? = prim::Constant()
|
||||||
@ -415,7 +415,7 @@ void testAliasAnalysis() {
|
|||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
std::unordered_map<std::string, Value*> vmap;
|
std::unordered_map<std::string, Value*> vmap;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%x : Tensor):
|
graph(%x : Tensor):
|
||||||
%3 : int = prim::Constant[value=1]()
|
%3 : int = prim::Constant[value=1]()
|
||||||
@ -491,7 +491,7 @@ void testWriteTracking() {
|
|||||||
}
|
}
|
||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%x: Tensor):
|
graph(%x: Tensor):
|
||||||
%b : Tensor = aten::relu_(%x)
|
%b : Tensor = aten::relu_(%x)
|
||||||
@ -505,7 +505,7 @@ void testWriteTracking() {
|
|||||||
}
|
}
|
||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%x: Tensor, %y : Tensor):
|
graph(%x: Tensor, %y : Tensor):
|
||||||
%b : Tensor = aten::mul(%x, %y)
|
%b : Tensor = aten::mul(%x, %y)
|
||||||
@ -520,7 +520,7 @@ void testWriteTracking() {
|
|||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
std::unordered_map<std::string, Value*> vmap;
|
std::unordered_map<std::string, Value*> vmap;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%x: Tensor, %y : Tensor):
|
graph(%x: Tensor, %y : Tensor):
|
||||||
%c1 : int = prim::Constant[value=1]()
|
%c1 : int = prim::Constant[value=1]()
|
||||||
@ -540,7 +540,7 @@ void testContainerAliasing() {
|
|||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
std::unordered_map<std::string, Value*> vmap;
|
std::unordered_map<std::string, Value*> vmap;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%inp: Tensor[]):
|
graph(%inp: Tensor[]):
|
||||||
%x : str = prim::Constant[value="a"]()
|
%x : str = prim::Constant[value="a"]()
|
||||||
@ -574,7 +574,7 @@ void testContainerAliasing() {
|
|||||||
|
|
||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph():
|
graph():
|
||||||
%x : str = prim::Constant[value="a"]()
|
%x : str = prim::Constant[value="a"]()
|
||||||
@ -601,7 +601,7 @@ void testContainerAliasing() {
|
|||||||
// Test input aliasing
|
// Test input aliasing
|
||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%x: Tensor, %y: Tensor):
|
graph(%x: Tensor, %y: Tensor):
|
||||||
%a : (Tensor) = prim::TupleConstruct(%x)
|
%a : (Tensor) = prim::TupleConstruct(%x)
|
||||||
@ -622,7 +622,7 @@ void testContainerAliasing() {
|
|||||||
// Test tuple that doesn't come from construct
|
// Test tuple that doesn't come from construct
|
||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%x : int,
|
graph(%x : int,
|
||||||
%y : Tensor,
|
%y : Tensor,
|
||||||
@ -654,7 +654,7 @@ graph(%x : int,
|
|||||||
// test nested types
|
// test nested types
|
||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph():
|
graph():
|
||||||
%a : Tensor = prim::MakeTestTensor()
|
%a : Tensor = prim::MakeTestTensor()
|
||||||
@ -681,7 +681,7 @@ graph():
|
|||||||
// simple example
|
// simple example
|
||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph():
|
graph():
|
||||||
%0 : Tensor = prim::Constant()
|
%0 : Tensor = prim::Constant()
|
||||||
@ -711,7 +711,7 @@ graph():
|
|||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
std::unordered_map<std::string, Value*> vmap;
|
std::unordered_map<std::string, Value*> vmap;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph():
|
graph():
|
||||||
%x : str = prim::Constant[value="a"]()
|
%x : str = prim::Constant[value="a"]()
|
||||||
@ -737,7 +737,7 @@ graph():
|
|||||||
// Test list container aliasing
|
// Test list container aliasing
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
std::unordered_map<std::string, Value*> vmap;
|
std::unordered_map<std::string, Value*> vmap;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph():
|
graph():
|
||||||
%0 : int = prim::Constant[value=2]()
|
%0 : int = prim::Constant[value=2]()
|
||||||
@ -779,7 +779,7 @@ graph():
|
|||||||
|
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
std::unordered_map<std::string, Value*> vmap;
|
std::unordered_map<std::string, Value*> vmap;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph():
|
graph():
|
||||||
%0 : int = prim::Constant[value=2]()
|
%0 : int = prim::Constant[value=2]()
|
||||||
@ -810,7 +810,7 @@ graph():
|
|||||||
// print across it.
|
// print across it.
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
std::unordered_map<std::string, Value*> vmap;
|
std::unordered_map<std::string, Value*> vmap;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph():
|
graph():
|
||||||
%35 : int = prim::Constant[value=1]()
|
%35 : int = prim::Constant[value=1]()
|
||||||
@ -849,7 +849,7 @@ graph():
|
|||||||
// print across it.
|
// print across it.
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
std::unordered_map<std::string, Value*> vmap;
|
std::unordered_map<std::string, Value*> vmap;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph():
|
graph():
|
||||||
%38 : int = prim::Constant[value=1]()
|
%38 : int = prim::Constant[value=1]()
|
||||||
@ -935,7 +935,7 @@ void testWildcards() {
|
|||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
std::unordered_map<std::string, Value*> vmap;
|
std::unordered_map<std::string, Value*> vmap;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%ten_list : Tensor[], %int_list : int[], %opt_ten_list : Tensor[]?):
|
graph(%ten_list : Tensor[], %int_list : int[], %opt_ten_list : Tensor[]?):
|
||||||
%ten : Tensor = prim::Constant()
|
%ten : Tensor = prim::Constant()
|
||||||
@ -971,7 +971,7 @@ void testWildcards() {
|
|||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
std::unordered_map<std::string, Value*> vmap;
|
std::unordered_map<std::string, Value*> vmap;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%ten_list : Tensor[], %ten_opt_list : Tensor?[]):
|
graph(%ten_list : Tensor[], %ten_opt_list : Tensor?[]):
|
||||||
%ten : Tensor = prim::Constant()
|
%ten : Tensor = prim::Constant()
|
||||||
@ -992,7 +992,7 @@ void testWildcards() {
|
|||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
std::unordered_map<std::string, Value*> vmap;
|
std::unordered_map<std::string, Value*> vmap;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%float_3D : Float(*, *, *), %float_2D : Float(*, *)):
|
graph(%float_3D : Float(*, *, *), %float_2D : Float(*, *)):
|
||||||
return ()
|
return ()
|
||||||
@ -1006,7 +1006,7 @@ void testWildcards() {
|
|||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
std::unordered_map<std::string, Value*> vmap;
|
std::unordered_map<std::string, Value*> vmap;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%float_3D_list : Float(*, *, *)[], %float_2D_list : Float(*, *)[], %ten: Tensor):
|
graph(%float_3D_list : Float(*, *, *)[], %float_2D_list : Float(*, *)[], %ten: Tensor):
|
||||||
return ()
|
return ()
|
||||||
|
@ -213,7 +213,7 @@ void testDifferentiateWithRequiresGrad() {
|
|||||||
%7 : Tensor = aten::add(%6, %1, %2)
|
%7 : Tensor = aten::add(%6, %1, %2)
|
||||||
return (%4, %7))IR";
|
return (%4, %7))IR";
|
||||||
auto g = std::make_shared<Graph>();
|
auto g = std::make_shared<Graph>();
|
||||||
torch::jit::script::parseIR(graph_string, g.get());
|
torch::jit::parseIR(graph_string, g.get());
|
||||||
|
|
||||||
auto a_var = autograd::make_variable(
|
auto a_var = autograd::make_variable(
|
||||||
at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true);
|
at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true);
|
||||||
|
@ -9,8 +9,6 @@
|
|||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
using namespace torch::jit::script;
|
|
||||||
|
|
||||||
static const auto classSrcs1 = R"JIT(
|
static const auto classSrcs1 = R"JIT(
|
||||||
class FooNestedTest:
|
class FooNestedTest:
|
||||||
def __init__(self, y):
|
def __init__(self, y):
|
||||||
|
@ -4,7 +4,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
using namespace torch::jit::script;
|
|
||||||
const auto testSource = R"JIT(
|
const auto testSource = R"JIT(
|
||||||
class FooTest:
|
class FooTest:
|
||||||
def __init__(self, x):
|
def __init__(self, x):
|
||||||
|
@ -5,8 +5,6 @@
|
|||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
using namespace torch::jit::script;
|
|
||||||
|
|
||||||
void testClassTypeAddRemoveAttr() {
|
void testClassTypeAddRemoveAttr() {
|
||||||
auto cu = std::make_shared<CompilationUnit>();
|
auto cu = std::make_shared<CompilationUnit>();
|
||||||
auto cls = ClassType::create("foo.bar", cu, true);
|
auto cls = ClassType::create("foo.bar", cu, true);
|
||||||
|
@ -14,7 +14,7 @@ namespace jit {
|
|||||||
void testConstantPooling() {
|
void testConstantPooling() {
|
||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph():
|
graph():
|
||||||
%8 : int = prim::Constant[value=1]()
|
%8 : int = prim::Constant[value=1]()
|
||||||
@ -29,7 +29,7 @@ graph():
|
|||||||
}
|
}
|
||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%cond : Tensor):
|
graph(%cond : Tensor):
|
||||||
%a : str = prim::Constant[value="bcd"]()
|
%a : str = prim::Constant[value="bcd"]()
|
||||||
@ -53,7 +53,7 @@ graph(%cond : Tensor):
|
|||||||
}
|
}
|
||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph():
|
graph():
|
||||||
%2 : int = prim::Constant[value=2]()
|
%2 : int = prim::Constant[value=2]()
|
||||||
|
@ -145,7 +145,7 @@ void testCustomOperatorAliasing() {
|
|||||||
|
|
||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%x: Tensor, %y: Tensor):
|
graph(%x: Tensor, %y: Tensor):
|
||||||
%ret : Tensor = foo::aliasing(%x, %y)
|
%ret : Tensor = foo::aliasing(%x, %y)
|
||||||
@ -172,7 +172,7 @@ graph(%x: Tensor, %y: Tensor):
|
|||||||
%ret : Tensor = foo::aliasing(%x, %y)
|
%ret : Tensor = foo::aliasing(%x, %y)
|
||||||
return (%x)
|
return (%x)
|
||||||
)IR";
|
)IR";
|
||||||
script::parseIR(text, graph.get());
|
parseIR(text, graph.get());
|
||||||
EliminateDeadCode(graph);
|
EliminateDeadCode(graph);
|
||||||
|
|
||||||
testing::FileCheck().run(text, *graph);
|
testing::FileCheck().run(text, *graph);
|
||||||
|
@ -43,7 +43,7 @@ graph():
|
|||||||
-> (%50, %tot.3)
|
-> (%50, %tot.3)
|
||||||
return (%tot)
|
return (%tot)
|
||||||
)IR";
|
)IR";
|
||||||
script::parseIR(input, graph.get());
|
parseIR(input, graph.get());
|
||||||
EliminateDeadCode(graph);
|
EliminateDeadCode(graph);
|
||||||
// Check that dead code elimin
|
// Check that dead code elimin
|
||||||
testing::FileCheck().run(input, *graph);
|
testing::FileCheck().run(input, *graph);
|
||||||
|
@ -67,7 +67,7 @@ void testFusion() {
|
|||||||
%2 : Tensor = aten::mul(%0, %1)
|
%2 : Tensor = aten::mul(%0, %1)
|
||||||
return (%2))IR";
|
return (%2))IR";
|
||||||
Graph graph;
|
Graph graph;
|
||||||
torch::jit::script::parseIR(graph_string, &graph);
|
torch::jit::parseIR(graph_string, &graph);
|
||||||
|
|
||||||
auto a = at::rand({3, 4}, at::kCUDA);
|
auto a = at::rand({3, 4}, at::kCUDA);
|
||||||
auto b = at::rand({4, 3}, at::kCUDA).transpose(0, 1);
|
auto b = at::rand({4, 3}, at::kCUDA).transpose(0, 1);
|
||||||
@ -100,7 +100,7 @@ void testFusion() {
|
|||||||
%14 : Tensor = aten::mul(%8, %13)
|
%14 : Tensor = aten::mul(%8, %13)
|
||||||
return (%14, %12))IR";
|
return (%14, %12))IR";
|
||||||
Graph graph;
|
Graph graph;
|
||||||
torch::jit::script::parseIR(graph_string, &graph);
|
torch::jit::parseIR(graph_string, &graph);
|
||||||
|
|
||||||
graph.lint();
|
graph.lint();
|
||||||
|
|
||||||
@ -164,7 +164,7 @@ void testFusion() {
|
|||||||
graph_string2};
|
graph_string2};
|
||||||
for (auto i = decltype(graph_strings.size()){0}; i < graph_strings.size(); ++i) {
|
for (auto i = decltype(graph_strings.size()){0}; i < graph_strings.size(); ++i) {
|
||||||
Graph g;
|
Graph g;
|
||||||
torch::jit::script::parseIR(graph_strings[i], &g);
|
torch::jit::parseIR(graph_strings[i], &g);
|
||||||
|
|
||||||
auto outputs = debugLaunchGraph(g, {a, b});
|
auto outputs = debugLaunchGraph(g, {a, b});
|
||||||
ASSERT_EQ(outputs.size(), 2);
|
ASSERT_EQ(outputs.size(), 2);
|
||||||
@ -187,7 +187,7 @@ void testRegisterFusionCachesKernel() {
|
|||||||
%d0 : Float(2, 3, 4) = aten::mul(%c0, %0)
|
%d0 : Float(2, 3, 4) = aten::mul(%c0, %0)
|
||||||
return (%d0))IR";
|
return (%d0))IR";
|
||||||
auto g0 = std::make_shared<Graph>();
|
auto g0 = std::make_shared<Graph>();
|
||||||
torch::jit::script::parseIR(graph0_string, g0.get());
|
torch::jit::parseIR(graph0_string, g0.get());
|
||||||
|
|
||||||
const auto graph1_string = R"IR(
|
const auto graph1_string = R"IR(
|
||||||
graph(%0 : Float(2, 3, 4),
|
graph(%0 : Float(2, 3, 4),
|
||||||
@ -196,7 +196,7 @@ void testRegisterFusionCachesKernel() {
|
|||||||
%d1 : Float(2, 3, 4) = aten::mul(%c1, %0)
|
%d1 : Float(2, 3, 4) = aten::mul(%c1, %0)
|
||||||
return (%d1))IR";
|
return (%d1))IR";
|
||||||
auto g1 = std::make_shared<Graph>();
|
auto g1 = std::make_shared<Graph>();
|
||||||
torch::jit::script::parseIR(graph1_string, g1.get());
|
torch::jit::parseIR(graph1_string, g1.get());
|
||||||
|
|
||||||
auto getFusionGroup = [](const std::shared_ptr<Graph>& graph) {
|
auto getFusionGroup = [](const std::shared_ptr<Graph>& graph) {
|
||||||
const auto& nodes = graph->nodes();
|
const auto& nodes = graph->nodes();
|
||||||
|
@ -21,7 +21,6 @@ def foo3(x):
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
using namespace script;
|
|
||||||
using namespace testing;
|
using namespace testing;
|
||||||
|
|
||||||
struct InlinerGuard {
|
struct InlinerGuard {
|
||||||
|
@ -11,8 +11,6 @@
|
|||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
using namespace torch::jit::script;
|
|
||||||
|
|
||||||
static const std::vector<std::string> subMethodSrcs = {R"JIT(
|
static const std::vector<std::string> subMethodSrcs = {R"JIT(
|
||||||
def one(self, x: Tensor, y: Tensor) -> Tensor:
|
def one(self, x: Tensor, y: Tensor) -> Tensor:
|
||||||
return x + y + 1
|
return x + y + 1
|
||||||
|
@ -55,7 +55,7 @@ void testBlocks() {
|
|||||||
%12 : int = prim::Constant[value=1]()
|
%12 : int = prim::Constant[value=1]()
|
||||||
%13 : Tensor = aten::add(%5, %3, %12)
|
%13 : Tensor = aten::add(%5, %3, %12)
|
||||||
return (%13))IR";
|
return (%13))IR";
|
||||||
torch::jit::script::parseIR(graph_string, g.get());
|
torch::jit::parseIR(graph_string, g.get());
|
||||||
|
|
||||||
g->lint();
|
g->lint();
|
||||||
testing::FileCheck()
|
testing::FileCheck()
|
||||||
@ -122,7 +122,7 @@ graph(%x : Tensor,
|
|||||||
|
|
||||||
torch::jit::Graph g;
|
torch::jit::Graph g;
|
||||||
std::unordered_map<std::string, torch::jit::Value*> name_to_value;
|
std::unordered_map<std::string, torch::jit::Value*> name_to_value;
|
||||||
torch::jit::script::parseIR(input_str, &g, name_to_value);
|
torch::jit::parseIR(input_str, &g, name_to_value);
|
||||||
|
|
||||||
std::vector<std::string> value_names{"6", "7", "9", "10"};
|
std::vector<std::string> value_names{"6", "7", "9", "10"};
|
||||||
std::unordered_set<std::string> value_names_set(
|
std::unordered_set<std::string> value_names_set(
|
||||||
|
@ -17,7 +17,7 @@ namespace jit {
|
|||||||
*/
|
*/
|
||||||
static void checkRoundtrip(const std::string& s) {
|
static void checkRoundtrip(const std::string& s) {
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
script::parseIR(s, &*graph);
|
parseIR(s, &*graph);
|
||||||
std::ostringstream ss;
|
std::ostringstream ss;
|
||||||
ss << *graph;
|
ss << *graph;
|
||||||
std::string parsed = ss.str();
|
std::string parsed = ss.str();
|
||||||
@ -42,7 +42,7 @@ void testIRParser() {
|
|||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
std::unordered_map<std::string, Value*> vmap;
|
std::unordered_map<std::string, Value*> vmap;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0 : Tensor, %1 : Tensor):
|
graph(%0 : Tensor, %1 : Tensor):
|
||||||
%2 : Tensor = foo::add(%0, %1)
|
%2 : Tensor = foo::add(%0, %1)
|
||||||
@ -117,7 +117,7 @@ graph(%0 : Tensor,
|
|||||||
}
|
}
|
||||||
{
|
{
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%a):
|
graph(%a):
|
||||||
return (%a))IR",
|
return (%a))IR",
|
||||||
@ -127,7 +127,7 @@ graph(%a):
|
|||||||
{
|
{
|
||||||
// Check that parser correctly handles values reusing the same name.
|
// Check that parser correctly handles values reusing the same name.
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%x):
|
graph(%x):
|
||||||
%x = a::a(%x)
|
%x = a::a(%x)
|
||||||
@ -239,7 +239,7 @@ graph(%0 : Tensor,
|
|||||||
# CHECK: return
|
# CHECK: return
|
||||||
return (%a))IR";
|
return (%a))IR";
|
||||||
|
|
||||||
script::parseIR(text, &*graph);
|
parseIR(text, &*graph);
|
||||||
AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(TensorType::get()));
|
AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(TensorType::get()));
|
||||||
torch::jit::testing::FileCheck().run(text, *graph);
|
torch::jit::testing::FileCheck().run(text, *graph);
|
||||||
}
|
}
|
||||||
|
@ -7,8 +7,6 @@
|
|||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
using namespace torch::jit::script;
|
|
||||||
|
|
||||||
void testIValue() {
|
void testIValue() {
|
||||||
c10::List<int64_t> foo({3, 4, 5});
|
c10::List<int64_t> foo({3, 4, 5});
|
||||||
ASSERT_EQ(foo.use_count(), 1);
|
ASSERT_EQ(foo.use_count(), 1);
|
||||||
|
@ -12,7 +12,7 @@ namespace torch {
|
|||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
void testLiteInterpreterUpsampleNearest2d() {
|
void testLiteInterpreterUpsampleNearest2d() {
|
||||||
script::Module m("m");
|
Module m("m");
|
||||||
m.define(R"(
|
m.define(R"(
|
||||||
def forward(self, input: Tensor, scale:float):
|
def forward(self, input: Tensor, scale:float):
|
||||||
return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
|
return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
|
||||||
@ -35,7 +35,7 @@ void testLiteInterpreterUpsampleNearest2d() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void testLiteInterpreterAdd() {
|
void testLiteInterpreterAdd() {
|
||||||
script::Module m("m");
|
Module m("m");
|
||||||
m.register_parameter("foo", torch::ones({}), false);
|
m.register_parameter("foo", torch::ones({}), false);
|
||||||
// TODO: support default param val, which was pushed in
|
// TODO: support default param val, which was pushed in
|
||||||
// function schema's checkAndNormalizeInputs()
|
// function schema's checkAndNormalizeInputs()
|
||||||
@ -75,7 +75,7 @@ void testLiteInterpreterConv() {
|
|||||||
|
|
||||||
std::vector<torch::jit::IValue> inputs;
|
std::vector<torch::jit::IValue> inputs;
|
||||||
|
|
||||||
script::Module m("m");
|
Module m("m");
|
||||||
m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
|
m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
|
||||||
m.register_parameter("bias", torch::ones({20}), false);
|
m.register_parameter("bias", torch::ones({20}), false);
|
||||||
m.define(R"(
|
m.define(R"(
|
||||||
@ -100,7 +100,7 @@ void testLiteInterpreterConv() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void testLiteInterpreterInline() {
|
void testLiteInterpreterInline() {
|
||||||
script::Module m("m");
|
Module m("m");
|
||||||
m.define(R"JIT(
|
m.define(R"JIT(
|
||||||
def foo1(self, x):
|
def foo1(self, x):
|
||||||
return x + 1
|
return x + 1
|
||||||
@ -120,7 +120,7 @@ void testLiteInterpreterInline() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void testLiteInterpreterTuple() {
|
void testLiteInterpreterTuple() {
|
||||||
script::Module m("m");
|
Module m("m");
|
||||||
m.define(R"JIT(
|
m.define(R"JIT(
|
||||||
def foo(self, x):
|
def foo(self, x):
|
||||||
return (1, 2, x + 3)
|
return (1, 2, x + 3)
|
||||||
@ -138,7 +138,7 @@ void testLiteInterpreterTuple() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void testLiteInterpreterPrimOverload() {
|
void testLiteInterpreterPrimOverload() {
|
||||||
script::Module m("m");
|
Module m("m");
|
||||||
m.define(R"JIT(
|
m.define(R"JIT(
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
result = [1, 2]
|
result = [1, 2]
|
||||||
@ -154,7 +154,7 @@ void testLiteInterpreterPrimOverload() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void testLiteInterpreterPrim() {
|
void testLiteInterpreterPrim() {
|
||||||
script::Module m("m");
|
Module m("m");
|
||||||
m.define(R"JIT(
|
m.define(R"JIT(
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return int(x)
|
return int(x)
|
||||||
@ -180,7 +180,7 @@ void testLiteInterpreterPrim() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void testLiteInterpreterLoadOrigJit() {
|
void testLiteInterpreterLoadOrigJit() {
|
||||||
script::Module m("m");
|
Module m("m");
|
||||||
m.register_parameter("foo", torch::ones({}), false);
|
m.register_parameter("foo", torch::ones({}), false);
|
||||||
m.define(R"(
|
m.define(R"(
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -193,7 +193,7 @@ void testLiteInterpreterLoadOrigJit() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void testLiteInterpreterWrongMethodName() {
|
void testLiteInterpreterWrongMethodName() {
|
||||||
script::Module m("m");
|
Module m("m");
|
||||||
m.register_parameter("foo", torch::ones({}), false);
|
m.register_parameter("foo", torch::ones({}), false);
|
||||||
m.define(R"(
|
m.define(R"(
|
||||||
def add(self, x):
|
def add(self, x):
|
||||||
@ -210,7 +210,7 @@ void testLiteInterpreterWrongMethodName() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void testLiteInterpreterParams() {
|
void testLiteInterpreterParams() {
|
||||||
script::Module m("m");
|
Module m("m");
|
||||||
m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
|
m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
|
||||||
m.define(R"(
|
m.define(R"(
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -269,7 +269,7 @@ void testLiteInterpreterParams() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void testLiteInterpreterSetState() {
|
void testLiteInterpreterSetState() {
|
||||||
script::Module m("m");
|
Module m("m");
|
||||||
m.register_parameter("foo", torch::ones({}), false);
|
m.register_parameter("foo", torch::ones({}), false);
|
||||||
m.define(R"(
|
m.define(R"(
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
|
@ -368,7 +368,7 @@ void testCustomFusion() {
|
|||||||
%3 : Tensor = aten::mul(%2, %0)
|
%3 : Tensor = aten::mul(%2, %0)
|
||||||
return (%3))IR";
|
return (%3))IR";
|
||||||
auto g = std::make_shared<Graph>();
|
auto g = std::make_shared<Graph>();
|
||||||
torch::jit::script::parseIR(graph_string, g.get());
|
torch::jit::parseIR(graph_string, g.get());
|
||||||
|
|
||||||
torch::jit::overrideCanFuseOnCPU(true);
|
torch::jit::overrideCanFuseOnCPU(true);
|
||||||
CustomFuseGraph(
|
CustomFuseGraph(
|
||||||
@ -412,7 +412,7 @@ void testCustomFusionNestedBlocks() {
|
|||||||
%9 : Tensor = aten::add(%4, %2, %3)
|
%9 : Tensor = aten::add(%4, %2, %3)
|
||||||
return (%4))IR";
|
return (%4))IR";
|
||||||
auto g = std::make_shared<Graph>();
|
auto g = std::make_shared<Graph>();
|
||||||
torch::jit::script::parseIR(graph_string, g.get());
|
torch::jit::parseIR(graph_string, g.get());
|
||||||
|
|
||||||
CustomFuseGraph(
|
CustomFuseGraph(
|
||||||
g,
|
g,
|
||||||
@ -489,7 +489,7 @@ void testEvalModeForLoadedModule() {
|
|||||||
if (isSandcastle())
|
if (isSandcastle())
|
||||||
return; // The module file to load is not generated in Sandcastle
|
return; // The module file to load is not generated in Sandcastle
|
||||||
std::string module_path = "dropout_model.pt";
|
std::string module_path = "dropout_model.pt";
|
||||||
torch::jit::script::Module module = torch::jit::load(module_path);
|
torch::jit::Module module = torch::jit::load(module_path);
|
||||||
AT_ASSERT(module.attr("dropout").toModule().is_training());
|
AT_ASSERT(module.attr("dropout").toModule().is_training());
|
||||||
module.eval();
|
module.eval();
|
||||||
AT_ASSERT(!module.attr("dropout").toModule().is_training());
|
AT_ASSERT(!module.attr("dropout").toModule().is_training());
|
||||||
@ -973,7 +973,7 @@ void testNoneSchemaMatch() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void testModuleDefine() {
|
void testModuleDefine() {
|
||||||
script::Module m("m");
|
Module m("m");
|
||||||
m.register_parameter("foo", torch::ones({}), false);
|
m.register_parameter("foo", torch::ones({}), false);
|
||||||
m.define(R"(
|
m.define(R"(
|
||||||
def add_it(self, x, b : int = 4):
|
def add_it(self, x, b : int = 4):
|
||||||
@ -984,7 +984,7 @@ void testModuleDefine() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void testModuleConversion() {
|
void testModuleConversion() {
|
||||||
script::Module m("test");
|
Module m("test");
|
||||||
{
|
{
|
||||||
// test cuda to cpu for params and buffers
|
// test cuda to cpu for params and buffers
|
||||||
m.register_parameter("foo", torch::ones({}, at::kCUDA), false);
|
m.register_parameter("foo", torch::ones({}, at::kCUDA), false);
|
||||||
@ -1016,7 +1016,7 @@ RegisterPass p(fakePass);
|
|||||||
|
|
||||||
void testPassManagement() {
|
void testPassManagement() {
|
||||||
std::shared_ptr<Graph> graph = std::make_shared<Graph>();
|
std::shared_ptr<Graph> graph = std::make_shared<Graph>();
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%a):
|
graph(%a):
|
||||||
return (%a))IR",
|
return (%a))IR",
|
||||||
|
@ -5,8 +5,6 @@
|
|||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
using namespace torch::jit::script;
|
|
||||||
|
|
||||||
void testModuleClone() {
|
void testModuleClone() {
|
||||||
auto cu = std::make_shared<CompilationUnit>();
|
auto cu = std::make_shared<CompilationUnit>();
|
||||||
auto parent = ClassType::create("parent", cu, true);
|
auto parent = ClassType::create("parent", cu, true);
|
||||||
|
@ -8,8 +8,6 @@
|
|||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
using namespace script;
|
|
||||||
|
|
||||||
|
|
||||||
void testPeepholeOptimize() {
|
void testPeepholeOptimize() {
|
||||||
// test is / is not none optimization
|
// test is / is not none optimization
|
||||||
|
@ -10,7 +10,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
using namespace script;
|
|
||||||
|
|
||||||
void testSaveExtraFilesHook() {
|
void testSaveExtraFilesHook() {
|
||||||
// no secrets
|
// no secrets
|
||||||
|
@ -23,7 +23,7 @@ void testSchemaMatching() {
|
|||||||
return 0;
|
return 0;
|
||||||
}, c10::AliasAnalysisKind::FROM_SCHEMA),
|
}, c10::AliasAnalysisKind::FROM_SCHEMA),
|
||||||
});
|
});
|
||||||
script::Module m("m");
|
Module m("m");
|
||||||
m.define(R"(
|
m.define(R"(
|
||||||
def test(self):
|
def test(self):
|
||||||
a = (1.0, 2.0)
|
a = (1.0, 2.0)
|
||||||
@ -59,7 +59,7 @@ void testSchemaMatching() {
|
|||||||
return 0;
|
return 0;
|
||||||
}, AliasAnalysisKind::FROM_SCHEMA),
|
}, AliasAnalysisKind::FROM_SCHEMA),
|
||||||
});
|
});
|
||||||
script::Module m("m");
|
Module m("m");
|
||||||
m.define(R"JIT(
|
m.define(R"JIT(
|
||||||
def test(self):
|
def test(self):
|
||||||
a = (1.0, 2.0)
|
a = (1.0, 2.0)
|
||||||
|
@ -7,13 +7,13 @@ namespace jit {
|
|||||||
|
|
||||||
void testTrivial1() {
|
void testTrivial1() {
|
||||||
Graph graph, pattern;
|
Graph graph, pattern;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0):
|
graph(%0):
|
||||||
%a = a::aaa(%0)
|
%a = a::aaa(%0)
|
||||||
return (%a))IR",
|
return (%a))IR",
|
||||||
&graph);
|
&graph);
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0):
|
graph(%0):
|
||||||
%x = a::aaa(%0)
|
%x = a::aaa(%0)
|
||||||
@ -46,7 +46,7 @@ void testTrivial2() {
|
|||||||
|
|
||||||
void testTrivial3() {
|
void testTrivial3() {
|
||||||
Graph graph, pattern;
|
Graph graph, pattern;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0):
|
graph(%0):
|
||||||
%a = a::a(%0)
|
%a = a::a(%0)
|
||||||
@ -54,7 +54,7 @@ graph(%0):
|
|||||||
%c = a::c(%a, %b)
|
%c = a::c(%a, %b)
|
||||||
return (%c))IR",
|
return (%c))IR",
|
||||||
&graph);
|
&graph);
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%a, %b):
|
graph(%a, %b):
|
||||||
%c = a::c(%a, %b)
|
%c = a::c(%a, %b)
|
||||||
@ -92,7 +92,7 @@ void testTrivial4() {
|
|||||||
|
|
||||||
void testLinear1() {
|
void testLinear1() {
|
||||||
Graph graph, pattern;
|
Graph graph, pattern;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0):
|
graph(%0):
|
||||||
%a = a::aaa(%0)
|
%a = a::aaa(%0)
|
||||||
@ -102,7 +102,7 @@ graph(%0):
|
|||||||
%a = a::aaa(%0)
|
%a = a::aaa(%0)
|
||||||
return (%d))IR",
|
return (%d))IR",
|
||||||
&graph);
|
&graph);
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0):
|
graph(%0):
|
||||||
%x = b::bbb(%0)
|
%x = b::bbb(%0)
|
||||||
@ -161,7 +161,7 @@ void testLinear2() {
|
|||||||
*/
|
*/
|
||||||
void testDiamond1() {
|
void testDiamond1() {
|
||||||
Graph graph, pattern1, pattern2;
|
Graph graph, pattern1, pattern2;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0):
|
graph(%0):
|
||||||
%o = o::ooo(%0)
|
%o = o::ooo(%0)
|
||||||
@ -173,7 +173,7 @@ graph(%0):
|
|||||||
return (%e))IR",
|
return (%e))IR",
|
||||||
&graph);
|
&graph);
|
||||||
|
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0):
|
graph(%0):
|
||||||
%a = a::aaa(%0)
|
%a = a::aaa(%0)
|
||||||
@ -185,7 +185,7 @@ graph(%0):
|
|||||||
AT_ASSERT(!findPatternMatches(pattern1, graph).empty());
|
AT_ASSERT(!findPatternMatches(pattern1, graph).empty());
|
||||||
|
|
||||||
// Check that order of nodes inside the diamond does not affect the result
|
// Check that order of nodes inside the diamond does not affect the result
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0):
|
graph(%0):
|
||||||
%a = a::aaa(%0)
|
%a = a::aaa(%0)
|
||||||
@ -247,7 +247,7 @@ void testDiamond2() {
|
|||||||
|
|
||||||
void testXPattern() {
|
void testXPattern() {
|
||||||
Graph graph, pattern;
|
Graph graph, pattern;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0, %1):
|
graph(%0, %1):
|
||||||
%b = b::bbb(%0)
|
%b = b::bbb(%0)
|
||||||
@ -258,7 +258,7 @@ graph(%0, %1):
|
|||||||
%g = g::ggg(%e, %f)
|
%g = g::ggg(%e, %f)
|
||||||
return (%g))IR",
|
return (%g))IR",
|
||||||
&graph);
|
&graph);
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0, %1):
|
graph(%0, %1):
|
||||||
%b = b::bbb(%0)
|
%b = b::bbb(%0)
|
||||||
@ -274,7 +274,7 @@ graph(%0, %1):
|
|||||||
|
|
||||||
void testMultipleMatches() {
|
void testMultipleMatches() {
|
||||||
Graph graph, pattern;
|
Graph graph, pattern;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%t0):
|
graph(%t0):
|
||||||
%t1 = a::aaa(%t0)
|
%t1 = a::aaa(%t0)
|
||||||
@ -283,7 +283,7 @@ graph(%t0):
|
|||||||
%t4 = a::aaa(%t3)
|
%t4 = a::aaa(%t3)
|
||||||
return (%t4))IR",
|
return (%t4))IR",
|
||||||
&graph);
|
&graph);
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%t0):
|
graph(%t0):
|
||||||
%t1 = a::aaa(%t0)
|
%t1 = a::aaa(%t0)
|
||||||
@ -295,7 +295,7 @@ graph(%t0):
|
|||||||
|
|
||||||
void testOverlappingMatches() {
|
void testOverlappingMatches() {
|
||||||
Graph graph, pattern;
|
Graph graph, pattern;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%t0):
|
graph(%t0):
|
||||||
%t1 = a::aaa(%t0)
|
%t1 = a::aaa(%t0)
|
||||||
@ -304,7 +304,7 @@ graph(%t0):
|
|||||||
%t4 = a::aaa(%t3)
|
%t4 = a::aaa(%t3)
|
||||||
return (%t4))IR",
|
return (%t4))IR",
|
||||||
&graph);
|
&graph);
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%t0):
|
graph(%t0):
|
||||||
%t1 = a::aaa(%t0)
|
%t1 = a::aaa(%t0)
|
||||||
@ -317,7 +317,7 @@ graph(%t0):
|
|||||||
|
|
||||||
void testMatchInBasicBlocks1() {
|
void testMatchInBasicBlocks1() {
|
||||||
Graph graph;
|
Graph graph;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%a, %b, %c):
|
graph(%a, %b, %c):
|
||||||
%d = aten::mul(%a, %b)
|
%d = aten::mul(%a, %b)
|
||||||
@ -333,7 +333,7 @@ graph(%a, %b, %c):
|
|||||||
|
|
||||||
// Ensure the matches don't cross basic block boundaries
|
// Ensure the matches don't cross basic block boundaries
|
||||||
Graph pattern0;
|
Graph pattern0;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%x, %y):
|
graph(%x, %y):
|
||||||
%z = aten::mul(%x, %y)
|
%z = aten::mul(%x, %y)
|
||||||
@ -342,7 +342,7 @@ graph(%x, %y):
|
|||||||
AT_ASSERT(findPatternMatches(pattern0, graph).size() == 3);
|
AT_ASSERT(findPatternMatches(pattern0, graph).size() == 3);
|
||||||
|
|
||||||
Graph pattern1;
|
Graph pattern1;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%x, %y):
|
graph(%x, %y):
|
||||||
%z1 = aten::mul(%x, %y)
|
%z1 = aten::mul(%x, %y)
|
||||||
@ -354,7 +354,7 @@ graph(%x, %y):
|
|||||||
|
|
||||||
void testMatchInBasicBlocks2() {
|
void testMatchInBasicBlocks2() {
|
||||||
Graph graph;
|
Graph graph;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%a, %b):
|
graph(%a, %b):
|
||||||
%x = my::mul(%a, %b)
|
%x = my::mul(%a, %b)
|
||||||
@ -367,7 +367,7 @@ graph(%a, %b):
|
|||||||
|
|
||||||
// Check that we can match both mul ops
|
// Check that we can match both mul ops
|
||||||
Graph pattern0;
|
Graph pattern0;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%x, %y):
|
graph(%x, %y):
|
||||||
%z = my::mul(%x, %y)
|
%z = my::mul(%x, %y)
|
||||||
@ -377,7 +377,7 @@ graph(%x, %y):
|
|||||||
|
|
||||||
// Ensure the matches don't cross basic block boundaries
|
// Ensure the matches don't cross basic block boundaries
|
||||||
Graph pattern1;
|
Graph pattern1;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%x, %y):
|
graph(%x, %y):
|
||||||
%u = my::mul(%x, %y)
|
%u = my::mul(%x, %y)
|
||||||
@ -389,7 +389,7 @@ graph(%x, %y):
|
|||||||
|
|
||||||
void testMatchesAttributes() {
|
void testMatchesAttributes() {
|
||||||
Graph graph;
|
Graph graph;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0):
|
graph(%0):
|
||||||
%a = a::a[isattr=[1,2]](%0)
|
%a = a::a[isattr=[1,2]](%0)
|
||||||
@ -400,7 +400,7 @@ graph(%0):
|
|||||||
|
|
||||||
{
|
{
|
||||||
Graph pattern;
|
Graph pattern;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%a, %b):
|
graph(%a, %b):
|
||||||
%c = a::c[myattr="qqq"](%a, %b)
|
%c = a::c[myattr="qqq"](%a, %b)
|
||||||
@ -410,7 +410,7 @@ graph(%a, %b):
|
|||||||
}
|
}
|
||||||
{
|
{
|
||||||
Graph pattern;
|
Graph pattern;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%a, %b):
|
graph(%a, %b):
|
||||||
%c = a::c[myattr="zzz"](%a, %b)
|
%c = a::c[myattr="zzz"](%a, %b)
|
||||||
@ -420,7 +420,7 @@ graph(%a, %b):
|
|||||||
}
|
}
|
||||||
{
|
{
|
||||||
Graph pattern;
|
Graph pattern;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0):
|
graph(%0):
|
||||||
%b = a::b[extraattr=10](%0)
|
%b = a::b[extraattr=10](%0)
|
||||||
@ -430,7 +430,7 @@ graph(%0):
|
|||||||
}
|
}
|
||||||
{
|
{
|
||||||
Graph pattern;
|
Graph pattern;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0):
|
graph(%0):
|
||||||
%b = a::b[intattr=10, floatattr=3.14](%0)
|
%b = a::b[intattr=10, floatattr=3.14](%0)
|
||||||
@ -440,7 +440,7 @@ graph(%0):
|
|||||||
}
|
}
|
||||||
{
|
{
|
||||||
Graph pattern;
|
Graph pattern;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0):
|
graph(%0):
|
||||||
%b = a::b[intattr=10, floatattr=3.14, strattr="rrr"](%0)
|
%b = a::b[intattr=10, floatattr=3.14, strattr="rrr"](%0)
|
||||||
@ -450,7 +450,7 @@ graph(%0):
|
|||||||
}
|
}
|
||||||
{
|
{
|
||||||
Graph pattern;
|
Graph pattern;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0):
|
graph(%0):
|
||||||
%a = a::a[isattr=[1,2]](%0)
|
%a = a::a[isattr=[1,2]](%0)
|
||||||
@ -463,14 +463,14 @@ graph(%0):
|
|||||||
|
|
||||||
void testBadPattern() {
|
void testBadPattern() {
|
||||||
Graph graph, pattern1, pattern2;
|
Graph graph, pattern1, pattern2;
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0):
|
graph(%0):
|
||||||
%a = a::aaa(%0)
|
%a = a::aaa(%0)
|
||||||
return (%a))IR",
|
return (%a))IR",
|
||||||
&graph);
|
&graph);
|
||||||
|
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%x):
|
graph(%x):
|
||||||
%y = my::node_with_subblock()
|
%y = my::node_with_subblock()
|
||||||
@ -481,7 +481,7 @@ graph(%x):
|
|||||||
&pattern1);
|
&pattern1);
|
||||||
ASSERT_ANY_THROW(findPatternMatches(pattern1, graph));
|
ASSERT_ANY_THROW(findPatternMatches(pattern1, graph));
|
||||||
|
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%x):
|
graph(%x):
|
||||||
%y = my::op1(%x)
|
%y = my::op1(%x)
|
||||||
|
@ -11,7 +11,7 @@ using namespace testing;
|
|||||||
void testFilterMatch() {
|
void testFilterMatch() {
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
|
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0):
|
graph(%0):
|
||||||
%a = a::aaa(%0)
|
%a = a::aaa(%0)
|
||||||
@ -27,7 +27,7 @@ graph(%a, %b):
|
|||||||
Graph pattern_graph;
|
Graph pattern_graph;
|
||||||
std::unordered_map<std::string, Value*> vmap;
|
std::unordered_map<std::string, Value*> vmap;
|
||||||
|
|
||||||
script::parseIR(
|
parseIR(
|
||||||
pattern,
|
pattern,
|
||||||
&pattern_graph,
|
&pattern_graph,
|
||||||
vmap);
|
vmap);
|
||||||
@ -55,7 +55,7 @@ graph(%a, %b):
|
|||||||
|
|
||||||
void testFilterNoMatch() {
|
void testFilterNoMatch() {
|
||||||
auto graph = std::make_shared<Graph>();
|
auto graph = std::make_shared<Graph>();
|
||||||
script::parseIR(
|
parseIR(
|
||||||
R"IR(
|
R"IR(
|
||||||
graph(%0):
|
graph(%0):
|
||||||
%a = a::aaa(%0)
|
%a = a::aaa(%0)
|
||||||
@ -71,7 +71,7 @@ graph(%a, %b):
|
|||||||
Graph pattern_graph;
|
Graph pattern_graph;
|
||||||
std::unordered_map<std::string, Value*> vmap;
|
std::unordered_map<std::string, Value*> vmap;
|
||||||
|
|
||||||
script::parseIR(
|
parseIR(
|
||||||
pattern,
|
pattern,
|
||||||
&pattern_graph,
|
&pattern_graph,
|
||||||
vmap);
|
vmap);
|
||||||
|
@ -84,7 +84,7 @@ std::shared_ptr<Graph> build_lstm() {
|
|||||||
%22 : Tensor = aten::mul(%14, %21)
|
%22 : Tensor = aten::mul(%14, %21)
|
||||||
return (%22, %20))IR";
|
return (%22, %20))IR";
|
||||||
auto g = std::make_shared<Graph>();
|
auto g = std::make_shared<Graph>();
|
||||||
torch::jit::script::parseIR(graph_string, g.get());
|
torch::jit::parseIR(graph_string, g.get());
|
||||||
g->lint();
|
g->lint();
|
||||||
|
|
||||||
return g;
|
return g;
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
namespace helpers {
|
namespace helpers {
|
||||||
template <typename Predicate>
|
template <typename Predicate>
|
||||||
void check_all_parameters(
|
void check_all_parameters(
|
||||||
const torch::jit::script::Module& module,
|
const torch::jit::Module& module,
|
||||||
Predicate predicate) {
|
Predicate predicate) {
|
||||||
for (at::Tensor parameter : module.parameters()) {
|
for (at::Tensor parameter : module.parameters()) {
|
||||||
AT_ASSERT(predicate(parameter));
|
AT_ASSERT(predicate(parameter));
|
||||||
@ -79,7 +79,7 @@ void get_autograd_operator_from_registry_and_execute_in_nograd_mode() {
|
|||||||
|
|
||||||
void load_serialized_module_with_custom_op_and_execute(
|
void load_serialized_module_with_custom_op_and_execute(
|
||||||
const std::string& path_to_exported_script_module) {
|
const std::string& path_to_exported_script_module) {
|
||||||
torch::jit::script::Module module =
|
torch::jit::Module module =
|
||||||
torch::jit::load(path_to_exported_script_module);
|
torch::jit::load(path_to_exported_script_module);
|
||||||
std::vector<torch::jit::IValue> inputs;
|
std::vector<torch::jit::IValue> inputs;
|
||||||
inputs.push_back(torch::ones(5));
|
inputs.push_back(torch::ones(5));
|
||||||
@ -90,7 +90,7 @@ void load_serialized_module_with_custom_op_and_execute(
|
|||||||
|
|
||||||
void test_argument_checking_for_serialized_modules(
|
void test_argument_checking_for_serialized_modules(
|
||||||
const std::string& path_to_exported_script_module) {
|
const std::string& path_to_exported_script_module) {
|
||||||
torch::jit::script::Module module =
|
torch::jit::Module module =
|
||||||
torch::jit::load(path_to_exported_script_module);
|
torch::jit::load(path_to_exported_script_module);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@ -124,7 +124,7 @@ void test_argument_checking_for_serialized_modules(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void test_move_to_device(const std::string& path_to_exported_script_module) {
|
void test_move_to_device(const std::string& path_to_exported_script_module) {
|
||||||
torch::jit::script::Module module =
|
torch::jit::Module module =
|
||||||
torch::jit::load(path_to_exported_script_module);
|
torch::jit::load(path_to_exported_script_module);
|
||||||
|
|
||||||
helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
|
helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
|
||||||
@ -145,7 +145,7 @@ void test_move_to_device(const std::string& path_to_exported_script_module) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void test_move_to_dtype(const std::string& path_to_exported_script_module) {
|
void test_move_to_dtype(const std::string& path_to_exported_script_module) {
|
||||||
torch::jit::script::Module module =
|
torch::jit::Module module =
|
||||||
torch::jit::load(path_to_exported_script_module);
|
torch::jit::load(path_to_exported_script_module);
|
||||||
|
|
||||||
module.to(torch::kInt);
|
module.to(torch::kInt);
|
||||||
|
@ -24,7 +24,7 @@ struct MobileCallGuard {
|
|||||||
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
|
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
|
||||||
};
|
};
|
||||||
|
|
||||||
torch::jit::script::Module loadModel(const std::string& path) {
|
torch::jit::Module loadModel(const std::string& path) {
|
||||||
MobileCallGuard guard;
|
MobileCallGuard guard;
|
||||||
auto module = torch::jit::load(path);
|
auto module = torch::jit::load(path);
|
||||||
module.eval();
|
module.eval();
|
||||||
|
@ -138,7 +138,7 @@ TORCH_NN_MODULE_TEST_INIT = Template("""\n
|
|||||||
void ${module_variant_name}_test_init(
|
void ${module_variant_name}_test_init(
|
||||||
const std::string& saved_module_path,
|
const std::string& saved_module_path,
|
||||||
const std::string& device) {
|
const std::string& device) {
|
||||||
torch::jit::script::Module m_init_by_python = torch::jit::load(saved_module_path);
|
torch::jit::Module m_init_by_python = torch::jit::load(saved_module_path);
|
||||||
|
|
||||||
torch::manual_seed(2);
|
torch::manual_seed(2);
|
||||||
${module_qualified_name} m_init_by_cpp${cpp_constructor_args};
|
${module_qualified_name} m_init_by_cpp${cpp_constructor_args};
|
||||||
|
@ -30,7 +30,7 @@ namespace jit {
|
|||||||
/// )JIT");
|
/// )JIT");
|
||||||
/// IValue output = module->run_method("relu_script", a, b);
|
/// IValue output = module->run_method("relu_script", a, b);
|
||||||
/// \endrst
|
/// \endrst
|
||||||
TORCH_API std::shared_ptr<script::CompilationUnit> compile(const std::string& source);
|
TORCH_API std::shared_ptr<CompilationUnit> compile(const std::string& source);
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -39,7 +39,7 @@ namespace torch {
|
|||||||
template <typename Value, typename... SaveToArgs>
|
template <typename Value, typename... SaveToArgs>
|
||||||
void save(const Value& value, SaveToArgs&&... args) {
|
void save(const Value& value, SaveToArgs&&... args) {
|
||||||
serialize::OutputArchive archive(
|
serialize::OutputArchive archive(
|
||||||
std::make_shared<jit::script::CompilationUnit>());
|
std::make_shared<jit::CompilationUnit>());
|
||||||
archive << value;
|
archive << value;
|
||||||
archive.save_to(std::forward<SaveToArgs>(args)...);
|
archive.save_to(std::forward<SaveToArgs>(args)...);
|
||||||
}
|
}
|
||||||
@ -65,7 +65,7 @@ void save(const Value& value, SaveToArgs&&... args) {
|
|||||||
template <typename... SaveToArgs>
|
template <typename... SaveToArgs>
|
||||||
void save(const std::vector<torch::Tensor>& tensor_vec, SaveToArgs&&... args) {
|
void save(const std::vector<torch::Tensor>& tensor_vec, SaveToArgs&&... args) {
|
||||||
serialize::OutputArchive archive(
|
serialize::OutputArchive archive(
|
||||||
std::make_shared<jit::script::CompilationUnit>());
|
std::make_shared<jit::CompilationUnit>());
|
||||||
for (size_t i = 0; i < tensor_vec.size(); i++) {
|
for (size_t i = 0; i < tensor_vec.size(); i++) {
|
||||||
auto& value = tensor_vec[i];
|
auto& value = tensor_vec[i];
|
||||||
archive.write(c10::to_string(i), value);
|
archive.write(c10::to_string(i), value);
|
||||||
|
@ -18,9 +18,7 @@ class Tensor;
|
|||||||
namespace torch {
|
namespace torch {
|
||||||
using at::Tensor;
|
using at::Tensor;
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
struct Module;
|
struct Module;
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
||||||
@ -108,7 +106,7 @@ class TORCH_API InputArchive final {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
jit::script::Module module_;
|
jit::Module module_;
|
||||||
std::string hierarchy_prefix_;
|
std::string hierarchy_prefix_;
|
||||||
};
|
};
|
||||||
} // namespace serialize
|
} // namespace serialize
|
||||||
|
@ -15,9 +15,7 @@ class Tensor;
|
|||||||
namespace torch {
|
namespace torch {
|
||||||
using at::Tensor;
|
using at::Tensor;
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
struct Module;
|
struct Module;
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
||||||
@ -25,8 +23,8 @@ namespace torch {
|
|||||||
namespace serialize {
|
namespace serialize {
|
||||||
class TORCH_API OutputArchive final {
|
class TORCH_API OutputArchive final {
|
||||||
public:
|
public:
|
||||||
explicit OutputArchive(std::shared_ptr<jit::script::CompilationUnit> cu);
|
explicit OutputArchive(std::shared_ptr<jit::CompilationUnit> cu);
|
||||||
explicit OutputArchive() : cu_(std::make_shared<jit::script::CompilationUnit>()), module_("__torch__.Module", cu_) {}
|
explicit OutputArchive() : cu_(std::make_shared<jit::CompilationUnit>()), module_("__torch__.Module", cu_) {}
|
||||||
|
|
||||||
// Move is allowed.
|
// Move is allowed.
|
||||||
OutputArchive(OutputArchive&&) = default;
|
OutputArchive(OutputArchive&&) = default;
|
||||||
@ -36,7 +34,7 @@ class TORCH_API OutputArchive final {
|
|||||||
OutputArchive(OutputArchive&) = delete;
|
OutputArchive(OutputArchive&) = delete;
|
||||||
OutputArchive& operator=(OutputArchive&) = delete;
|
OutputArchive& operator=(OutputArchive&) = delete;
|
||||||
|
|
||||||
std::shared_ptr<jit::script::CompilationUnit> compilation_unit() const {
|
std::shared_ptr<jit::CompilationUnit> compilation_unit() const {
|
||||||
return cu_;
|
return cu_;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,8 +73,8 @@ class TORCH_API OutputArchive final {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<jit::script::CompilationUnit> cu_;
|
std::shared_ptr<jit::CompilationUnit> cu_;
|
||||||
jit::script::Module module_;
|
jit::Module module_;
|
||||||
};
|
};
|
||||||
} // namespace serialize
|
} // namespace serialize
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -9,12 +9,12 @@
|
|||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
std::shared_ptr<script::CompilationUnit> compile(const std::string& source) {
|
std::shared_ptr<CompilationUnit> compile(const std::string& source) {
|
||||||
auto module = std::make_shared<script::CompilationUnit>();
|
auto module = std::make_shared<CompilationUnit>();
|
||||||
module->define(
|
module->define(
|
||||||
c10::nullopt,
|
c10::nullopt,
|
||||||
source,
|
source,
|
||||||
script::nativeResolver(),
|
nativeResolver(),
|
||||||
nullptr);
|
nullptr);
|
||||||
return module;
|
return module;
|
||||||
}
|
}
|
||||||
|
@ -16,7 +16,7 @@
|
|||||||
namespace torch {
|
namespace torch {
|
||||||
namespace serialize {
|
namespace serialize {
|
||||||
|
|
||||||
InputArchive::InputArchive() : module_("Module", std::make_shared<jit::script::CompilationUnit>()) {}
|
InputArchive::InputArchive() : module_("Module", std::make_shared<jit::CompilationUnit>()) {}
|
||||||
|
|
||||||
void InputArchive::read(const std::string& key, c10::IValue& ivalue) {
|
void InputArchive::read(const std::string& key, c10::IValue& ivalue) {
|
||||||
ivalue = module_.attr(key);
|
ivalue = module_.attr(key);
|
||||||
@ -161,7 +161,7 @@ std::vector<std::string> InputArchive::keys() {
|
|||||||
std::vector<std::string> all_keys;
|
std::vector<std::string> all_keys;
|
||||||
all_keys.reserve(module_.named_attributes(/*recurse=*/false).size());
|
all_keys.reserve(module_.named_attributes(/*recurse=*/false).size());
|
||||||
|
|
||||||
for (const torch::jit::script::NameValue& s : module_.named_attributes(/*recurse=*/false)) {
|
for (const torch::jit::NameValue& s : module_.named_attributes(/*recurse=*/false)) {
|
||||||
all_keys.push_back(s.name);
|
all_keys.push_back(s.name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace serialize {
|
namespace serialize {
|
||||||
OutputArchive::OutputArchive(std::shared_ptr<jit::script::CompilationUnit> cu)
|
OutputArchive::OutputArchive(std::shared_ptr<jit::CompilationUnit> cu)
|
||||||
: cu_(std::move(cu)),
|
: cu_(std::move(cu)),
|
||||||
module_("__torch__.Module", cu_, /*shouldMangle=*/true) {}
|
module_("__torch__.Module", cu_, /*shouldMangle=*/true) {}
|
||||||
|
|
||||||
|
@ -53,7 +53,7 @@ TypePtr tryInferTypeWithTypeHint(
|
|||||||
const py::object& type_hint) {
|
const py::object& type_hint) {
|
||||||
// If the py::object to be contained by the RRef is a ScripModule, we enforce
|
// If the py::object to be contained by the RRef is a ScripModule, we enforce
|
||||||
// users to specify its ModuleInterface type.
|
// users to specify its ModuleInterface type.
|
||||||
if (auto module = jit::script::as_module(value)) {
|
if (auto module = jit::as_module(value)) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
!type_hint.is_none(),
|
!type_hint.is_none(),
|
||||||
"The RRef being created contains a ScriptModule, "
|
"The RRef being created contains a ScriptModule, "
|
||||||
|
@ -27,8 +27,8 @@ namespace {
|
|||||||
|
|
||||||
// PythonTypeResolver that inherits from Script::Resolver to
|
// PythonTypeResolver that inherits from Script::Resolver to
|
||||||
// support resolving types together with ScriptTypeParser.
|
// support resolving types together with ScriptTypeParser.
|
||||||
struct PythonTypeResolver : public jit::script::Resolver {
|
struct PythonTypeResolver : public jit::Resolver {
|
||||||
std::shared_ptr<jit::script::SugaredValue> resolveValue(
|
std::shared_ptr<jit::SugaredValue> resolveValue(
|
||||||
const std::string& /* unused */,
|
const std::string& /* unused */,
|
||||||
torch::jit::Function& /* unused */,
|
torch::jit::Function& /* unused */,
|
||||||
const jit::SourceRange& /* unused */) override {
|
const jit::SourceRange& /* unused */) override {
|
||||||
@ -67,7 +67,7 @@ PythonRpcHandler::PythonRpcHandler() {
|
|||||||
pyHandleException_ = getFunction(module, "_handle_exception");
|
pyHandleException_ = getFunction(module, "_handle_exception");
|
||||||
pyGetQualifiedName_ = py::module::import("torch.jit").attr("_qualified_name");
|
pyGetQualifiedName_ = py::module::import("torch.jit").attr("_qualified_name");
|
||||||
jitCompilationUnit_ = torch::jit::get_python_cu();
|
jitCompilationUnit_ = torch::jit::get_python_cu();
|
||||||
typeParser_ = std::make_shared<jit::script::ScriptTypeParser>(
|
typeParser_ = std::make_shared<jit::ScriptTypeParser>(
|
||||||
std::make_shared<PythonTypeResolver>());
|
std::make_shared<PythonTypeResolver>());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,7 +97,7 @@ PythonRpcHandler& PythonRpcHandler::getInstance() {
|
|||||||
return *handler;
|
return *handler;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<torch::jit::script::CompilationUnit> PythonRpcHandler::
|
std::shared_ptr<torch::jit::CompilationUnit> PythonRpcHandler::
|
||||||
jitCompilationUnit() {
|
jitCompilationUnit() {
|
||||||
return jitCompilationUnit_;
|
return jitCompilationUnit_;
|
||||||
}
|
}
|
||||||
|
@ -55,7 +55,7 @@ class PYBIND11_EXPORT PythonRpcHandler {
|
|||||||
// PythonRpcHandler.
|
// PythonRpcHandler.
|
||||||
void cleanup();
|
void cleanup();
|
||||||
|
|
||||||
std::shared_ptr<torch::jit::script::CompilationUnit> jitCompilationUnit();
|
std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit();
|
||||||
|
|
||||||
// Parse the string to recover the jit_type, this is used for RRef python
|
// Parse the string to recover the jit_type, this is used for RRef python
|
||||||
// pickling/unpickling type recovery. The type string inference rule is as
|
// pickling/unpickling type recovery. The type string inference rule is as
|
||||||
@ -97,11 +97,11 @@ class PYBIND11_EXPORT PythonRpcHandler {
|
|||||||
// and imported in C++ (see get_python_cu() in
|
// and imported in C++ (see get_python_cu() in
|
||||||
// csrc/jit/python/pybind_utils.h). We import the compilation unit here only
|
// csrc/jit/python/pybind_utils.h). We import the compilation unit here only
|
||||||
// once for less cost and thread safety.
|
// once for less cost and thread safety.
|
||||||
std::shared_ptr<torch::jit::script::CompilationUnit> jitCompilationUnit_;
|
std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit_;
|
||||||
|
|
||||||
// jit type parser to parse type_str back to TypePtr for RRef type
|
// jit type parser to parse type_str back to TypePtr for RRef type
|
||||||
// recovery when pickling and unpickling RRef
|
// recovery when pickling and unpickling RRef
|
||||||
std::shared_ptr<jit::script::ScriptTypeParser> typeParser_;
|
std::shared_ptr<jit::ScriptTypeParser> typeParser_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace rpc
|
} // namespace rpc
|
||||||
|
@ -24,7 +24,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
struct Def;
|
struct Def;
|
||||||
struct ClassDef;
|
struct ClassDef;
|
||||||
@ -281,22 +280,25 @@ struct TORCH_API CompilationUnit {
|
|||||||
mutable size_t mangleIndex_ = 0;
|
mutable size_t mangleIndex_ = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
|
|
||||||
// An owning pointer to a Function. Just a pair of a raw Function ptr and it's
|
// An owning pointer to a Function. Just a pair of a raw Function ptr and it's
|
||||||
// owning CU. We need this because pybind requires a ref-counted way to refer to
|
// owning CU. We need this because pybind requires a ref-counted way to refer to
|
||||||
// Functions.
|
// Functions.
|
||||||
struct StrongFunctionPtr {
|
struct StrongFunctionPtr {
|
||||||
StrongFunctionPtr(
|
StrongFunctionPtr(
|
||||||
std::shared_ptr<script::CompilationUnit> cu,
|
std::shared_ptr<CompilationUnit> cu,
|
||||||
Function* function)
|
Function* function)
|
||||||
: cu_(std::move(cu)), function_(function) {
|
: cu_(std::move(cu)), function_(function) {
|
||||||
TORCH_INTERNAL_ASSERT(cu_);
|
TORCH_INTERNAL_ASSERT(cu_);
|
||||||
TORCH_INTERNAL_ASSERT(function_);
|
TORCH_INTERNAL_ASSERT(function_);
|
||||||
}
|
}
|
||||||
std::shared_ptr<script::CompilationUnit> cu_;
|
std::shared_ptr<CompilationUnit> cu_;
|
||||||
Function* function_;
|
Function* function_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace script {
|
||||||
|
// We once had a `script::` namespace that was deleted. This is for backcompat
|
||||||
|
// of the public API; new code should not use this type alias.
|
||||||
|
using CompilationUnit = ::torch::jit::CompilationUnit;
|
||||||
|
}
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -6,7 +6,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;
|
using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;
|
||||||
|
|
||||||
@ -61,6 +60,11 @@ struct TORCH_API Method {
|
|||||||
Function* function_;
|
Function* function_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace script
|
namespace script {
|
||||||
|
// We once had a `script::` namespace that was deleted. This is for backcompat
|
||||||
|
// of the public API; new code should not use this type alias.
|
||||||
|
using Method = ::torch::jit::Method;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -11,7 +11,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
static ObjectPtr create_module_object(
|
static ObjectPtr create_module_object(
|
||||||
c10::QualifiedName class_name,
|
c10::QualifiedName class_name,
|
||||||
@ -389,14 +388,13 @@ void Module::dump(
|
|||||||
<< std::endl;
|
<< std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
||||||
namespace c10 {
|
namespace c10 {
|
||||||
|
|
||||||
torch::jit::script::Module IValue::toModule() const {
|
torch::jit::Module IValue::toModule() const {
|
||||||
return torch::jit::script::Module(toObject());
|
return torch::jit::Module(toObject());
|
||||||
}
|
}
|
||||||
bool IValue::isModule() const {
|
bool IValue::isModule() const {
|
||||||
return isObject() && toObjectRef().type()->is_module();
|
return isObject() && toObjectRef().type()->is_module();
|
||||||
|
@ -33,7 +33,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
using ::c10::Argument;
|
using ::c10::Argument;
|
||||||
using ::c10::FunctionSchema;
|
using ::c10::FunctionSchema;
|
||||||
@ -554,6 +553,11 @@ struct NamedPolicy {
|
|||||||
|
|
||||||
TORCH_API bool& getInlineEverythingMode();
|
TORCH_API bool& getInlineEverythingMode();
|
||||||
|
|
||||||
} // namespace script
|
namespace script {
|
||||||
|
// We once had a `script::` namespace that was deleted. This is for backcompat
|
||||||
|
// of the public API; new code should not use this type alias.
|
||||||
|
using Module = ::torch::jit::Module;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
void Module::save(std::ostream& out, const ExtraFilesMap& extra_files) const {
|
void Module::save(std::ostream& out, const ExtraFilesMap& extra_files) const {
|
||||||
ExportModule(*this, out, extra_files, false);
|
ExportModule(*this, out, extra_files, false);
|
||||||
@ -26,6 +25,5 @@ void Module::_save_for_mobile(
|
|||||||
ExportModule(*this, filename, extra_files, true);
|
ExportModule(*this, filename, extra_files, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -7,7 +7,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
Object::Object(
|
Object::Object(
|
||||||
std::shared_ptr<CompilationUnit> cu,
|
std::shared_ptr<CompilationUnit> cu,
|
||||||
@ -35,10 +34,9 @@ void Object::define(const std::string& src, const ResolverPtr& resolver) {
|
|||||||
_ivalue()->compilation_unit()->define(
|
_ivalue()->compilation_unit()->define(
|
||||||
*type()->name(),
|
*type()->name(),
|
||||||
src,
|
src,
|
||||||
resolver ? resolver : script::nativeResolver(),
|
resolver ? resolver : nativeResolver(),
|
||||||
&self);
|
&self);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -6,7 +6,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
struct Resolver;
|
struct Resolver;
|
||||||
using ResolverPtr = std::shared_ptr<Resolver>;
|
using ResolverPtr = std::shared_ptr<Resolver>;
|
||||||
@ -132,6 +131,10 @@ struct TORCH_API Object {
|
|||||||
mutable ObjectPtr _ivalue_;
|
mutable ObjectPtr _ivalue_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace script
|
namespace script {
|
||||||
|
// We once had a `script::` namespace that was deleted. This is for backcompat
|
||||||
|
// of the public API; new code should not use this type alias.
|
||||||
|
using Object = ::torch::jit::Object;
|
||||||
|
}
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -285,7 +285,7 @@ The load process has the following steps:
|
|||||||
|
|
||||||
1. Unpickle `constants.pkl`, which produces a tuple of all tensor constants
|
1. Unpickle `constants.pkl`, which produces a tuple of all tensor constants
|
||||||
referenced in code.
|
referenced in code.
|
||||||
2. Unpickle `data.pkl` into the top-level `script::Module` and return it.
|
2. Unpickle `data.pkl` into the top-level `Module` and return it.
|
||||||
|
|
||||||
The unpickling process consists of a single call to unpickle the module
|
The unpickling process consists of a single call to unpickle the module
|
||||||
object contained in `data.pkl`. The `Unpickler` is given a callback that lets it
|
object contained in `data.pkl`. The `Unpickler` is given a callback that lets it
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
auto scalar_operators_source = CodeTemplate(
|
auto scalar_operators_source = CodeTemplate(
|
||||||
R"SCRIPT(
|
R"SCRIPT(
|
||||||
@ -81,7 +80,7 @@ struct BuiltinFunctionRegistry {
|
|||||||
std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>();
|
std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>();
|
||||||
modules.emplace_back(cu);
|
modules.emplace_back(cu);
|
||||||
cu->define(
|
cu->define(
|
||||||
c10::nullopt, source, script::nativeResolver(), /*self=*/nullptr);
|
c10::nullopt, source, nativeResolver(), /*self=*/nullptr);
|
||||||
for (auto& method : cu->get_functions()) {
|
for (auto& method : cu->get_functions()) {
|
||||||
builtins_by_name_[Symbol::fromQualString(
|
builtins_by_name_[Symbol::fromQualString(
|
||||||
the_namespace + "::" + method->name())]
|
the_namespace + "::" + method->name())]
|
||||||
@ -134,6 +133,5 @@ const std::vector<Function*>& getAllBuiltinFunctionsFor(
|
|||||||
return registry.getAllBuiltinFunctionsFor(name);
|
return registry.getAllBuiltinFunctionsFor(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -5,9 +5,7 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
TORCH_API const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name);
|
TORCH_API const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name);
|
||||||
}
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
ClassTypePtr ConcreteModuleTypeBuilder::createTypeFromThis() const {
|
ClassTypePtr ConcreteModuleTypeBuilder::createTypeFromThis() const {
|
||||||
auto cu = get_python_cu();
|
auto cu = get_python_cu();
|
||||||
@ -303,6 +302,5 @@ ConcreteModuleType::getModulesPy() const {
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -9,7 +9,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
enum class IterableModuleKind { NONE, LIST, DICT };
|
enum class IterableModuleKind { NONE, LIST, DICT };
|
||||||
class ConcreteModuleType;
|
class ConcreteModuleType;
|
||||||
@ -228,6 +227,5 @@ class VISIBILITY_HIDDEN ConcreteModuleType {
|
|||||||
TypePtr jitType_;
|
TypePtr jitType_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -9,7 +9,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
// At the beginning of the pass the Graph has already undergone type checking,
|
// At the beginning of the pass the Graph has already undergone type checking,
|
||||||
// and writes or reads to a variable are emitted as Loads and Stores in the
|
// and writes or reads to a variable are emitted as Loads and Stores in the
|
||||||
@ -328,6 +327,5 @@ void ConvertToSSA(std::shared_ptr<Graph>& graph) {
|
|||||||
TransformExits(graph);
|
TransformExits(graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -8,11 +8,9 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
// Convert a graph with Loads & Stores into SSA form
|
// Convert a graph with Loads & Stores into SSA form
|
||||||
TORCH_API void ConvertToSSA(std::shared_ptr<Graph>& graph);
|
TORCH_API void ConvertToSSA(std::shared_ptr<Graph>& graph);
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
// computes levenshtein edit distance between two words
|
// computes levenshtein edit distance between two words
|
||||||
// returns maxEditDistance + 1 if the edit distance exceeds MaxEditDistance
|
// returns maxEditDistance + 1 if the edit distance exceeds MaxEditDistance
|
||||||
@ -52,6 +51,5 @@ size_t ComputeEditDistance(
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -5,13 +5,11 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
TORCH_API size_t ComputeEditDistance(
|
TORCH_API size_t ComputeEditDistance(
|
||||||
const char* word1,
|
const char* word1,
|
||||||
const char* word2,
|
const char* word2,
|
||||||
size_t maxEditDistance);
|
size_t maxEditDistance);
|
||||||
|
|
||||||
}
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
// Avoid storing objects with destructor in thread_local for mobile build.
|
// Avoid storing objects with destructor in thread_local for mobile build.
|
||||||
#ifndef C10_MOBILE
|
#ifndef C10_MOBILE
|
||||||
@ -72,6 +71,5 @@ const char* ErrorReport::what() const noexcept {
|
|||||||
return the_message.c_str();
|
return the_message.c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
struct Call {
|
struct Call {
|
||||||
std::string fn_name;
|
std::string fn_name;
|
||||||
@ -49,6 +48,5 @@ const ErrorReport& operator<<(const ErrorReport& e, const T& t) {
|
|||||||
return e;
|
return e;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -22,7 +22,6 @@ using at::TypeKind;
|
|||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
namespace script {
|
|
||||||
namespace {
|
namespace {
|
||||||
struct SchemaParser {
|
struct SchemaParser {
|
||||||
SchemaParser(const std::string& str)
|
SchemaParser(const std::string& str)
|
||||||
@ -290,10 +289,9 @@ struct SchemaParser {
|
|||||||
SchemaTypeParser type_parser;
|
SchemaTypeParser type_parser;
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace script
|
|
||||||
|
|
||||||
C10_EXPORT either<OperatorName, FunctionSchema> parseSchemaOrName(const std::string& schemaOrName) {
|
C10_EXPORT either<OperatorName, FunctionSchema> parseSchemaOrName(const std::string& schemaOrName) {
|
||||||
return script::SchemaParser(schemaOrName).parseDeclarations().at(0);
|
return SchemaParser(schemaOrName).parseDeclarations().at(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
C10_EXPORT FunctionSchema parseSchema(const std::string& schema) {
|
C10_EXPORT FunctionSchema parseSchema(const std::string& schema) {
|
||||||
|
@ -30,7 +30,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
using FunctionTable = std::unordered_map<std::string, Function&>;
|
using FunctionTable = std::unordered_map<std::string, Function&>;
|
||||||
using ValueTable = std::unordered_map<std::string, SugaredValuePtr>;
|
using ValueTable = std::unordered_map<std::string, SugaredValuePtr>;
|
||||||
@ -491,7 +490,7 @@ struct Environment {
|
|||||||
if (!retval) {
|
if (!retval) {
|
||||||
if (auto type = resolver->resolveType(ident, range)) {
|
if (auto type = resolver->resolveType(ident, range)) {
|
||||||
if (auto tuple_type = type->cast<TupleType>()) {
|
if (auto tuple_type = type->cast<TupleType>()) {
|
||||||
retval = std::make_shared<script::NamedTupleConstructor>(tuple_type);
|
retval = std::make_shared<NamedTupleConstructor>(tuple_type);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -503,7 +502,7 @@ struct Environment {
|
|||||||
if (!retval) {
|
if (!retval) {
|
||||||
if (auto type = resolver->resolveType(ident, range)) {
|
if (auto type = resolver->resolveType(ident, range)) {
|
||||||
if (auto class_type = type->cast<ClassType>()) {
|
if (auto class_type = type->cast<ClassType>()) {
|
||||||
retval = std::make_shared<script::ClassValue>(class_type);
|
retval = std::make_shared<ClassValue>(class_type);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3598,7 +3597,7 @@ std::vector<Function*> CompilationUnit::define(
|
|||||||
void runCleanupPasses(std::shared_ptr<Graph>& to_clean) {
|
void runCleanupPasses(std::shared_ptr<Graph>& to_clean) {
|
||||||
liftClosures(to_clean);
|
liftClosures(to_clean);
|
||||||
inlineForkedClosures(to_clean);
|
inlineForkedClosures(to_clean);
|
||||||
if (script::getInlineEverythingMode()) {
|
if (getInlineEverythingMode()) {
|
||||||
Inline(*to_clean);
|
Inline(*to_clean);
|
||||||
}
|
}
|
||||||
// remove any uses of tuples that we inserted that are not needed
|
// remove any uses of tuples that we inserted that are not needed
|
||||||
@ -3669,6 +3668,5 @@ void CompilationUnit::define_interface(
|
|||||||
this->register_type(iface);
|
this->register_type(iface);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -12,12 +12,10 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
TORCH_API void runCleanupPasses(std::shared_ptr<Graph>& to_clean);
|
TORCH_API void runCleanupPasses(std::shared_ptr<Graph>& to_clean);
|
||||||
|
|
||||||
TORCH_API bool meaningfulName(const std::string& name);
|
TORCH_API bool meaningfulName(const std::string& name);
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
static const std::unordered_map<int, int> binary_prec = {
|
static const std::unordered_map<int, int> binary_prec = {
|
||||||
{TK_IF, 1},
|
{TK_IF, 1},
|
||||||
@ -101,6 +100,5 @@ C10_EXPORT SharedParserData& sharedParserData() {
|
|||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
// single character tokens are just the character itself '+'
|
// single character tokens are just the character itself '+'
|
||||||
// multi-character tokens need an entry here
|
// multi-character tokens need an entry here
|
||||||
@ -180,7 +179,7 @@ struct CAFFE2_API SharedParserData {
|
|||||||
return false;
|
return false;
|
||||||
const char* startptr = str.c_str() + start;
|
const char* startptr = str.c_str() + start;
|
||||||
char* endptr;
|
char* endptr;
|
||||||
torch::jit::script::strtod_c(startptr, &endptr);
|
torch::jit::strtod_c(startptr, &endptr);
|
||||||
*len = endptr - startptr;
|
*len = endptr - startptr;
|
||||||
return *len > 0;
|
return *len > 0;
|
||||||
}
|
}
|
||||||
@ -515,6 +514,5 @@ struct Lexer {
|
|||||||
std::vector<Token> next_tokens;
|
std::vector<Token> next_tokens;
|
||||||
SharedParserData& shared;
|
SharedParserData& shared;
|
||||||
};
|
};
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
// Simple data structure for containing a type T in nested control blocks
|
// Simple data structure for containing a type T in nested control blocks
|
||||||
// Should only be used after initial compilation where type checking and
|
// Should only be used after initial compilation where type checking and
|
||||||
@ -52,6 +51,5 @@ struct MiniEnvironment {
|
|||||||
std::unordered_map<std::string, T> table;
|
std::unordered_map<std::string, T> table;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
inline bool isCharCount(char c, const std::string& str, size_t start, int len) {
|
inline bool isCharCount(char c, const std::string& str, size_t start, int len) {
|
||||||
// count checks from [start, start + len)
|
// count checks from [start, start + len)
|
||||||
@ -87,6 +86,5 @@ inline std::string parseStringLiteral(
|
|||||||
return ret_str;
|
return ret_str;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -7,7 +7,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
Decl mergeTypesFromTypeComment(
|
Decl mergeTypesFromTypeComment(
|
||||||
const Decl& decl,
|
const Decl& decl,
|
||||||
@ -735,6 +734,5 @@ Expr Parser::parseExp() {
|
|||||||
return pImpl->parseExp();
|
return pImpl->parseExp();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -6,7 +6,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
struct Decl;
|
struct Decl;
|
||||||
struct ParserImpl;
|
struct ParserImpl;
|
||||||
@ -30,6 +29,5 @@ struct TORCH_API Parser {
|
|||||||
std::unique_ptr<ParserImpl> pImpl;
|
std::unique_ptr<ParserImpl> pImpl;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -2,8 +2,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
static const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!&^|~";
|
static const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!&^|~";
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -6,7 +6,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
struct Resolver;
|
struct Resolver;
|
||||||
using ResolverPtr = std::shared_ptr<Resolver>;
|
using ResolverPtr = std::shared_ptr<Resolver>;
|
||||||
@ -65,6 +64,5 @@ struct NativeResolver : public Resolver {
|
|||||||
inline std::shared_ptr<NativeResolver> nativeResolver() {
|
inline std::shared_ptr<NativeResolver> nativeResolver() {
|
||||||
return std::make_shared<NativeResolver>();
|
return std::make_shared<NativeResolver>();
|
||||||
}
|
}
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -6,7 +6,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
static inline TypePtr unwrapOptional(TypePtr opt_type) {
|
static inline TypePtr unwrapOptional(TypePtr opt_type) {
|
||||||
if (auto unwrap_list_type = opt_type->cast<OptionalType>()) {
|
if (auto unwrap_list_type = opt_type->cast<OptionalType>()) {
|
||||||
@ -612,6 +611,5 @@ Value* emitBuiltinCall(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -7,7 +7,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
// try to match a list if inputs and keyword 'attributes' to this schema,
|
// try to match a list if inputs and keyword 'attributes' to this schema,
|
||||||
// if it works return the flat list of positional inputs to the call
|
// if it works return the flat list of positional inputs to the call
|
||||||
@ -60,6 +59,5 @@ TORCH_API Value* tryConvertToType(
|
|||||||
const TypePtr& concrete_type,
|
const TypePtr& concrete_type,
|
||||||
Value* value,
|
Value* value,
|
||||||
bool allow_conversions);
|
bool allow_conversions);
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -31,7 +31,6 @@ using c10::VarType;
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
TypePtr SchemaTypeParser::parseBaseType() {
|
TypePtr SchemaTypeParser::parseBaseType() {
|
||||||
static std::unordered_map<std::string, TypePtr> type_map = {
|
static std::unordered_map<std::string, TypePtr> type_map = {
|
||||||
@ -292,6 +291,5 @@ void SchemaTypeParser::parseList(
|
|||||||
L.expect(end);
|
L.expect(end);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -7,7 +7,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
using TypePtr = c10::TypePtr;
|
using TypePtr = c10::TypePtr;
|
||||||
|
|
||||||
@ -33,6 +32,5 @@ struct CAFFE2_API SchemaTypeParser {
|
|||||||
Lexer& L;
|
Lexer& L;
|
||||||
size_t next_id = 0;
|
size_t next_id = 0;
|
||||||
};
|
};
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -4,7 +4,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
const std::unordered_map<std::string, TypePtr>& string_to_type_lut();
|
const std::unordered_map<std::string, TypePtr>& string_to_type_lut();
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@ -355,6 +354,5 @@ c10::IValue ScriptTypeParser::parseClassConstant(const Assign& assign) {
|
|||||||
return *default_val.begin();
|
return *default_val.begin();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -6,7 +6,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* class ScriptTypeParser
|
* class ScriptTypeParser
|
||||||
@ -45,6 +44,5 @@ class TORCH_API ScriptTypeParser {
|
|||||||
|
|
||||||
ResolverPtr resolver_ = nullptr;
|
ResolverPtr resolver_ = nullptr;
|
||||||
};
|
};
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
using namespace c10;
|
using namespace c10;
|
||||||
const std::unordered_map<std::string, TypePtr>& string_to_type_lut() {
|
const std::unordered_map<std::string, TypePtr>& string_to_type_lut() {
|
||||||
static std::unordered_map<std::string, TypePtr> map = {
|
static std::unordered_map<std::string, TypePtr> map = {
|
||||||
@ -22,6 +21,5 @@ const std::unordered_map<std::string, TypePtr>& string_to_type_lut() {
|
|||||||
{"tuple", AnyTupleType::get()}};
|
{"tuple", AnyTupleType::get()}};
|
||||||
return map;
|
return map;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -76,7 +76,6 @@ double parse_inf_or_nan(const char *p, char **endptr)
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
#ifdef _MSC_VER
|
#ifdef _MSC_VER
|
||||||
C10_EXPORT double strtod_c(const char *nptr, char **endptr)
|
C10_EXPORT double strtod_c(const char *nptr, char **endptr)
|
||||||
@ -265,4 +264,3 @@ C10_EXPORT float strtof_c(const char *nptr, char **endptr)
|
|||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
@ -4,11 +4,9 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
CAFFE2_API double strtod_c(const char *nptr, char **endptr);
|
CAFFE2_API double strtod_c(const char *nptr, char **endptr);
|
||||||
CAFFE2_API float strtof_c(const char *nptr, char **endptr);
|
CAFFE2_API float strtof_c(const char *nptr, char **endptr);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
@ -6,7 +6,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
struct NoneValue : SugaredValue {
|
struct NoneValue : SugaredValue {
|
||||||
NoneValue() = default;
|
NoneValue() = default;
|
||||||
@ -623,6 +622,5 @@ std::shared_ptr<BuiltinFunction> BuiltinFunction::tryCreate(
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -10,7 +10,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
using SugaredValuePtr = std::shared_ptr<SugaredValue>;
|
using SugaredValuePtr = std::shared_ptr<SugaredValue>;
|
||||||
|
|
||||||
@ -365,7 +364,7 @@ struct FunctionValue : public SugaredValue {
|
|||||||
try {
|
try {
|
||||||
callee->ensure_defined();
|
callee->ensure_defined();
|
||||||
} catch (const RecursiveMethodCallError&) {
|
} catch (const RecursiveMethodCallError&) {
|
||||||
throw script::ErrorReport(loc)
|
throw ErrorReport(loc)
|
||||||
<< " function '" << callee->name() << "' is called recursively. "
|
<< " function '" << callee->name() << "' is called recursively. "
|
||||||
<< "Recursive calls are not supported";
|
<< "Recursive calls are not supported";
|
||||||
}
|
}
|
||||||
@ -428,7 +427,7 @@ struct MethodValue : public SugaredValue {
|
|||||||
try {
|
try {
|
||||||
method->ensure_defined();
|
method->ensure_defined();
|
||||||
} catch (const RecursiveMethodCallError&) {
|
} catch (const RecursiveMethodCallError&) {
|
||||||
throw script::ErrorReport(loc)
|
throw ErrorReport(loc)
|
||||||
<< " method '" << method->name() << "' is called recursively. "
|
<< " method '" << method->name() << "' is called recursively. "
|
||||||
<< "Recursive calls are not supported";
|
<< "Recursive calls are not supported";
|
||||||
}
|
}
|
||||||
@ -652,6 +651,5 @@ struct SimpleSelf : public Self {
|
|||||||
private:
|
private:
|
||||||
ClassTypePtr classType_;
|
ClassTypePtr classType_;
|
||||||
};
|
};
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -152,7 +152,7 @@ Value* TracingState::getValue(const IValue& var) {
|
|||||||
|
|
||||||
// Find torchbind classes
|
// Find torchbind classes
|
||||||
if (isCustomClass(var)) {
|
if (isCustomClass(var)) {
|
||||||
auto obj = script::Object(var.toObject());
|
auto obj = Object(var.toObject());
|
||||||
auto qualname = obj.type()->name();
|
auto qualname = obj.type()->name();
|
||||||
auto custom_class_type = getCustomClass(qualname->qualifiedName());
|
auto custom_class_type = getCustomClass(qualname->qualifiedName());
|
||||||
if (custom_class_type) {
|
if (custom_class_type) {
|
||||||
@ -313,14 +313,14 @@ static IValue addInput(const std::shared_ptr<TracingState> & state, const IValue
|
|||||||
static void gatherParametersAndBuffers(
|
static void gatherParametersAndBuffers(
|
||||||
const std::shared_ptr<TracingState>& state,
|
const std::shared_ptr<TracingState>& state,
|
||||||
Value* self_value,
|
Value* self_value,
|
||||||
const script::Module& self,
|
const Module& self,
|
||||||
const std::string& prefix) {
|
const std::string& prefix) {
|
||||||
Graph& g = *self_value->owningGraph();
|
Graph& g = *self_value->owningGraph();
|
||||||
|
|
||||||
state->setValue(self._ivalue(), self_value);
|
state->setValue(self._ivalue(), self_value);
|
||||||
|
|
||||||
auto self_ty = self.type();
|
auto self_ty = self.type();
|
||||||
for (const script::NameValue& s : self.named_attributes(/*recurse=*/false)) {
|
for (const NameValue& s : self.named_attributes(/*recurse=*/false)) {
|
||||||
auto qualname = prefix + "." + s.name;
|
auto qualname = prefix + "." + s.name;
|
||||||
Value* trace_get_attr = g.insertNode(g.create(prim::TracedAttr))
|
Value* trace_get_attr = g.insertNode(g.create(prim::TracedAttr))
|
||||||
->s_(attr::scope, qualname)
|
->s_(attr::scope, qualname)
|
||||||
@ -335,7 +335,7 @@ static void gatherParametersAndBuffers(
|
|||||||
}
|
}
|
||||||
if (self_ty->getAttribute(s.name)->is_module()) {
|
if (self_ty->getAttribute(s.name)->is_module()) {
|
||||||
gatherParametersAndBuffers(
|
gatherParametersAndBuffers(
|
||||||
state, trace_get_attr, script::Module(s.value.toObject()), qualname);
|
state, trace_get_attr, Module(s.value.toObject()), qualname);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -345,7 +345,7 @@ std::pair<std::shared_ptr<TracingState>, Stack> trace(
|
|||||||
const std::function<Stack(Stack)>& traced_fn,
|
const std::function<Stack(Stack)>& traced_fn,
|
||||||
std::function<std::string(const Variable&)> var_name_lookup_fn,
|
std::function<std::string(const Variable&)> var_name_lookup_fn,
|
||||||
bool force_outplace,
|
bool force_outplace,
|
||||||
script::Module* self) {
|
Module* self) {
|
||||||
try {
|
try {
|
||||||
// Start tracing, treating 'inputs' as inputs to the trace, which can be
|
// Start tracing, treating 'inputs' as inputs to the trace, which can be
|
||||||
// varied on subsequent invocations of the trace. Any other variables
|
// varied on subsequent invocations of the trace. Any other variables
|
||||||
@ -387,7 +387,7 @@ std::pair<std::shared_ptr<TracingState>, Stack> trace(
|
|||||||
}
|
}
|
||||||
setTracingState(nullptr);
|
setTracingState(nullptr);
|
||||||
|
|
||||||
if (script::getInlineEverythingMode()) {
|
if (getInlineEverythingMode()) {
|
||||||
Inline(*graph);
|
Inline(*graph);
|
||||||
}
|
}
|
||||||
FixupTraceScopeBlocks(graph, self);
|
FixupTraceScopeBlocks(graph, self);
|
||||||
|
@ -21,10 +21,7 @@ namespace jit {
|
|||||||
struct Node;
|
struct Node;
|
||||||
struct Value;
|
struct Value;
|
||||||
struct Graph;
|
struct Graph;
|
||||||
|
struct Module;
|
||||||
namespace script {
|
|
||||||
struct Module;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace tracer {
|
namespace tracer {
|
||||||
|
|
||||||
@ -210,7 +207,7 @@ TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> trace(
|
|||||||
const std::function<Stack(Stack)>& traced_fn,
|
const std::function<Stack(Stack)>& traced_fn,
|
||||||
std::function<std::string(const Variable&)> var_name_lookup_fn,
|
std::function<std::string(const Variable&)> var_name_lookup_fn,
|
||||||
bool force_outplace = false,
|
bool force_outplace = false,
|
||||||
script::Module* self = nullptr);
|
Module* self = nullptr);
|
||||||
|
|
||||||
TORCH_API void abandon();
|
TORCH_API void abandon();
|
||||||
|
|
||||||
|
@ -11,7 +11,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
// Tree's are used to represent all forms of TC IR, pre- and post- typechecking.
|
// Tree's are used to represent all forms of TC IR, pre- and post- typechecking.
|
||||||
// Rather than have a full class hierarchy for all TC statements,
|
// Rather than have a full class hierarchy for all TC statements,
|
||||||
@ -221,6 +220,5 @@ static inline std::ostream& operator<<(std::ostream& out, const TreeRef& t) {
|
|||||||
return out << pretty_tree(t);
|
return out << pretty_tree(t);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -10,7 +10,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
// TreeView provides a statically-typed way to traverse the tree, which should
|
// TreeView provides a statically-typed way to traverse the tree, which should
|
||||||
@ -813,7 +812,7 @@ struct Const : public Expr {
|
|||||||
// We can't pass in nullptr as the dummy pointer gets dereferenced for
|
// We can't pass in nullptr as the dummy pointer gets dereferenced for
|
||||||
// Android version of strtod_c().
|
// Android version of strtod_c().
|
||||||
char* dummy;
|
char* dummy;
|
||||||
return torch::jit::script::strtod_c(
|
return torch::jit::strtod_c(
|
||||||
subtree(0)->stringValue().c_str(), &dummy);
|
subtree(0)->stringValue().c_str(), &dummy);
|
||||||
}
|
}
|
||||||
const std::string& text() const {
|
const std::string& text() const {
|
||||||
@ -1045,14 +1044,13 @@ struct Delete : public Stmt {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
||||||
namespace std {
|
namespace std {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct iterator_traits<torch::jit::script::ListIterator<T>>
|
struct iterator_traits<torch::jit::ListIterator<T>>
|
||||||
: std::iterator_traits<torch::jit::script::TreeList::const_iterator> {};
|
: std::iterator_traits<torch::jit::TreeList::const_iterator> {};
|
||||||
|
|
||||||
} // namespace std
|
} // namespace std
|
||||||
|
@ -963,7 +963,7 @@ const Operator& Node::getOperator() const {
|
|||||||
if (maybe)
|
if (maybe)
|
||||||
return *maybe;
|
return *maybe;
|
||||||
|
|
||||||
auto er = script::ErrorReport(sourceRange());
|
auto er = ErrorReport(sourceRange());
|
||||||
er << "Schema not found for node. File a bug report.\n";
|
er << "Schema not found for node. File a bug report.\n";
|
||||||
er << "Node: " << *this << "\n";
|
er << "Node: " << *this << "\n";
|
||||||
er << "Input types:";
|
er << "Input types:";
|
||||||
@ -1491,7 +1491,7 @@ Value* Graph::insert(
|
|||||||
at::ArrayRef<NamedValue> args,
|
at::ArrayRef<NamedValue> args,
|
||||||
at::ArrayRef<NamedValue> kwargs,
|
at::ArrayRef<NamedValue> kwargs,
|
||||||
const c10::optional<SourceRange>& range) {
|
const c10::optional<SourceRange>& range) {
|
||||||
return script::emitBuiltinCall(
|
return emitBuiltinCall(
|
||||||
range.value_or(fakeRange()), *this, opname, args, kwargs);
|
range.value_or(fakeRange()), *this, opname, args, kwargs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1721,7 +1721,7 @@ Value* Graph::insertToList(Value* v, TypePtr type) {
|
|||||||
|
|
||||||
Value* Graph::insertFunctionCall(
|
Value* Graph::insertFunctionCall(
|
||||||
Function* callee,
|
Function* callee,
|
||||||
const script::MatchedSchema& matched) {
|
const MatchedSchema& matched) {
|
||||||
std::string func_name = callee->name();
|
std::string func_name = callee->name();
|
||||||
Value* fn_constant = insertNode(create(prim::Constant))
|
Value* fn_constant = insertNode(create(prim::Constant))
|
||||||
->s_(attr::name, func_name)
|
->s_(attr::name, func_name)
|
||||||
@ -1737,7 +1737,7 @@ Value* Graph::insertFunctionCall(
|
|||||||
|
|
||||||
Value* Graph::insertMethodCall(
|
Value* Graph::insertMethodCall(
|
||||||
std::string method_name,
|
std::string method_name,
|
||||||
const script::MatchedSchema& matched) {
|
const MatchedSchema& matched) {
|
||||||
Value* result = insertNode(create(prim::CallMethod, matched.inputs))
|
Value* result = insertNode(create(prim::CallMethod, matched.inputs))
|
||||||
->s_(attr::name, std::move(method_name))
|
->s_(attr::name, std::move(method_name))
|
||||||
->output()
|
->output()
|
||||||
|
@ -73,9 +73,7 @@ using namespace ::c10::aten;
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct Function;
|
struct Function;
|
||||||
namespace script {
|
|
||||||
struct MatchedSchema;
|
struct MatchedSchema;
|
||||||
} // namespace script
|
|
||||||
|
|
||||||
// Graph represents one "function" of computation.
|
// Graph represents one "function" of computation.
|
||||||
// It uses a simple ownership model where the graph owns all the nodes inside
|
// It uses a simple ownership model where the graph owns all the nodes inside
|
||||||
@ -1134,10 +1132,10 @@ struct Graph {
|
|||||||
|
|
||||||
TORCH_API Value* insertFunctionCall(
|
TORCH_API Value* insertFunctionCall(
|
||||||
Function* callee,
|
Function* callee,
|
||||||
const script::MatchedSchema& matched);
|
const MatchedSchema& matched);
|
||||||
TORCH_API Value* insertMethodCall(
|
TORCH_API Value* insertMethodCall(
|
||||||
std::string method_name,
|
std::string method_name,
|
||||||
const script::MatchedSchema& matched);
|
const MatchedSchema& matched);
|
||||||
|
|
||||||
// Note: defined in python_ir.cpp and can be used only in python extension
|
// Note: defined in python_ir.cpp and can be used only in python extension
|
||||||
Node* createPythonOp(
|
Node* createPythonOp(
|
||||||
|
@ -9,7 +9,6 @@
|
|||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
|
||||||
|
|
||||||
struct VarWithType;
|
struct VarWithType;
|
||||||
struct ParsedLiteral;
|
struct ParsedLiteral;
|
||||||
@ -57,7 +56,7 @@ class IRParser {
|
|||||||
|
|
||||||
Value* findValueInVMap(const std::string& name);
|
Value* findValueInVMap(const std::string& name);
|
||||||
|
|
||||||
torch::jit::script::Lexer L;
|
torch::jit::Lexer L;
|
||||||
torch::jit::Graph* g = nullptr;
|
torch::jit::Graph* g = nullptr;
|
||||||
std::unordered_map<std::string, Value*>& vmap;
|
std::unordered_map<std::string, Value*>& vmap;
|
||||||
SchemaTypeParser type_parser;
|
SchemaTypeParser type_parser;
|
||||||
@ -86,7 +85,7 @@ void parseIR(
|
|||||||
const std::string& str,
|
const std::string& str,
|
||||||
torch::jit::Graph* graph,
|
torch::jit::Graph* graph,
|
||||||
std::unordered_map<std::string, Value*>& vmap) {
|
std::unordered_map<std::string, Value*>& vmap) {
|
||||||
torch::jit::script::IRParser p(str, graph, vmap);
|
torch::jit::IRParser p(str, graph, vmap);
|
||||||
p.parse();
|
p.parse();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -473,6 +472,5 @@ Value* IRParser::findValueInVMap(const std::string& name) {
|
|||||||
return vmap.at(name);
|
return vmap.at(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -12,8 +12,6 @@ namespace jit {
|
|||||||
struct Graph;
|
struct Graph;
|
||||||
struct Value;
|
struct Value;
|
||||||
|
|
||||||
namespace script {
|
|
||||||
|
|
||||||
// \brief Parse IR from \p STR constructing the corresponding IR in\ GRAPH.
|
// \brief Parse IR from \p STR constructing the corresponding IR in\ GRAPH.
|
||||||
TORCH_API void parseIR(const std::string& str, torch::jit::Graph* graph);
|
TORCH_API void parseIR(const std::string& str, torch::jit::Graph* graph);
|
||||||
|
|
||||||
@ -27,6 +25,5 @@ TORCH_API void parseIR(
|
|||||||
torch::jit::Graph* graph,
|
torch::jit::Graph* graph,
|
||||||
std::unordered_map<std::string, Value*>& vmap);
|
std::unordered_map<std::string, Value*>& vmap);
|
||||||
|
|
||||||
} // namespace script
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -125,14 +125,14 @@ class BytecodeDeserializer final {
|
|||||||
private:
|
private:
|
||||||
c10::IValue readArchive(const std::string& archive_name,
|
c10::IValue readArchive(const std::string& archive_name,
|
||||||
std::shared_ptr<mobile::CompilationUnit> mcu);
|
std::shared_ptr<mobile::CompilationUnit> mcu);
|
||||||
std::shared_ptr<script::CompilationUnit> compilation_unit_;
|
std::shared_ptr<CompilationUnit> compilation_unit_;
|
||||||
std::unordered_set<std::string> imported_libs_;
|
std::unordered_set<std::string> imported_libs_;
|
||||||
std::unique_ptr<PyTorchStreamReader> reader_;
|
std::unique_ptr<PyTorchStreamReader> reader_;
|
||||||
c10::optional<at::Device> device_;
|
c10::optional<at::Device> device_;
|
||||||
};
|
};
|
||||||
|
|
||||||
BytecodeDeserializer::BytecodeDeserializer(std::unique_ptr<PyTorchStreamReader> reader)
|
BytecodeDeserializer::BytecodeDeserializer(std::unique_ptr<PyTorchStreamReader> reader)
|
||||||
: compilation_unit_(std::make_shared<script::CompilationUnit>()), reader_(std::move(reader)) {}
|
: compilation_unit_(std::make_shared<CompilationUnit>()), reader_(std::move(reader)) {}
|
||||||
|
|
||||||
mobile::Module BytecodeDeserializer::deserialize(c10::optional<at::Device> device) {
|
mobile::Module BytecodeDeserializer::deserialize(c10::optional<at::Device> device) {
|
||||||
device_ = device;
|
device_ = device;
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user