[jit] kill script namespace (#34515)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34515

Once upon a time we thought the separate `script::` namespace was necessary.
In reality it is not, so this change removes it.

For backcompat, our public interface (defined in `api/`) still has
typedefs to the old `script::` names.
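
A rough sketch of what that shim looks like (illustrative only; the exact
set of aliased names is whatever `api/` exports):

    namespace torch {
    namespace jit {
    namespace script {
    // Old `script::` spellings stay valid as aliases of the new names,
    // so existing user code such as `torch::jit::script::Module` keeps
    // compiling against the renamed types.
    using Module = ::torch::jit::Module;
    using Method = ::torch::jit::Method;
    using CompilationUnit = ::torch::jit::CompilationUnit;
    } // namespace script
    } // namespace jit
    } // namespace torch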

There was only one collision: `Pass` as a `Stmt` and `Pass` as a graph
transform. I renamed one of them.
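
For context, the two colliding meanings were (a sketch; the post-rename
spelling `GraphPass` for the graph-transform type is an assumption):

    #include <functional>
    #include <memory>

    struct Graph; // stand-in for torch::jit::Graph

    // (1) `Pass` the AST node: the Stmt representing a Python `pass`
    //     statement in TorchScript's tree views.
    // (2) `Pass` the graph transform: a callback over a Graph, as used
    //     with RegisterPass; shown here under the assumed new name.
    using GraphPass = std::function<void(std::shared_ptr<Graph>&)>;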

Test Plan: Imported from OSS

Differential Revision: D20353503

Pulled By: suo

fbshipit-source-id: 48bb911ce75120a8c9e0c6fb65262ef775dfba93
Author: Michael Suo
Date: 2020-03-11 23:29:34 -07:00
Committed by: Facebook GitHub Bot
Parent: cf8b728255
Commit: c235be42dd
152 changed files with 559 additions and 694 deletions

View File

@@ -13,7 +13,7 @@ int main(int argc, char* argv[]) {
   std::ifstream ifs(input_file_path);
   std::stringstream buffer;
   buffer << ifs.rdbuf();
-  torch::jit::script::Module m("TestModule");
+  torch::jit::Module m("TestModule");
   m.define(buffer.str());
   m.save(output_file_path);

View File

@@ -58,7 +58,7 @@ class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
 class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
  private:
   friend HybridBase;
-  torch::jit::script::Module module_;
+  torch::jit::Module module_;
  public:
   constexpr static auto kJavaDescriptor = "Lorg/pytorch/NativePeer;";

View File

@@ -5,7 +5,7 @@ package org.pytorch;
 import com.facebook.soloader.nativeloader.NativeLoader;
 import com.facebook.soloader.nativeloader.SystemDelegate;
-/** Java wrapper for torch::jit::script::Module. */
+/** Java wrapper for torch::jit::Module. */
 public class Module {
   private INativePeer mNativePeer;
@@ -14,7 +14,7 @@ public class Module {
    * Loads a serialized TorchScript module from the specified path on the disk.
    *
    * @param modelPath path to file that contains the serialized TorchScript module.
-   * @return new {@link org.pytorch.Module} object which owns torch::jit::script::Module.
+   * @return new {@link org.pytorch.Module} object which owns torch::jit::Module.
    */
   public static Module load(final String modelPath) {
     if (!NativeLoader.isInitialized()) {
@@ -49,7 +49,7 @@ public class Module {
   }
   /**
-   * Explicitly destroys the native torch::jit::script::Module. Calling this method is not required,
+   * Explicitly destroys the native torch::jit::Module. Calling this method is not required,
    * as the native object will be destroyed when this object is garbage-collected. However, the
    * timing of garbage collection is not guaranteed, so proactively calling {@code destroy} can free
    * memory more quickly. See {@link com.facebook.jni.HybridData#resetNative}.

View File

@@ -22,7 +22,7 @@ TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph);
 // A Function is a pure Graph with no implicit `self` object bound.
 // It contains schema information, and the executor that manages the
-// execution of the function. script::Method is a wrapper around a
+// execution of the function. Method is a wrapper around a
 // underlying Function that also provides a `self` object.
 struct TORCH_API Function {
   virtual bool isGraphFunction() const = 0;

View File

@@ -378,7 +378,7 @@ std::vector<std::pair<IValue, IValue>> iterationOrder(const c10::Dict<IValue, IV
 }
 StrongTypePtr::StrongTypePtr(
-    std::shared_ptr<torch::jit::script::CompilationUnit> cu,
+    std::shared_ptr<torch::jit::CompilationUnit> cu,
     std::shared_ptr<Type> type) {
   cu_ = std::move(cu);
   type_ = type;

View File

@@ -11,10 +11,8 @@ namespace jit {
 class CustomClassHolder : public c10::intrusive_ptr_target {};
 struct Function;
-namespace script {
 struct CompilationUnit;
 struct Module;
-}
 } // namespace jit
 } // namespace torch
 namespace c10 {
@@ -356,7 +354,7 @@ struct CAFFE2_API IValue final {
   c10::intrusive_ptr<ivalue::Object> toObject() const & ;
   const ivalue::Object& toObjectRef() const;
-  torch::jit::script::Module toModule() const;
+  torch::jit::Module toModule() const;
   bool isModule() const;
   // PyObject
@@ -692,10 +690,10 @@ private:
 // guaranteed to stay alive as long as we hold this object.
 struct TORCH_API StrongTypePtr {
   StrongTypePtr(
-      std::shared_ptr<torch::jit::script::CompilationUnit> cu,
+      std::shared_ptr<torch::jit::CompilationUnit> cu,
       std::shared_ptr<Type> type);
-  std::shared_ptr<torch::jit::script::CompilationUnit> cu_;
+  std::shared_ptr<torch::jit::CompilationUnit> cu_;
   std::shared_ptr<Type> type_;
 };

View File

@@ -15,9 +15,7 @@
 namespace torch {
 namespace jit {
 struct Function;
-namespace script {
 struct CompilationUnit;
-}
 } // namespace jit
 } // namespace torch
 namespace c10 {
@@ -406,7 +404,7 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
   }
   std::shared_ptr<ClassType> type() const;
-  std::shared_ptr<torch::jit::script::CompilationUnit> compilation_unit() {
+  std::shared_ptr<torch::jit::CompilationUnit> compilation_unit() {
     return type_.cu_;
   }
@@ -868,7 +866,7 @@ IValue from_(c10::intrusive_ptr<T> x, std::false_type) {
   auto res = getCustomClassType<inputType>();
   auto retObject = ivalue::Object::create(
       StrongTypePtr(
-          std::shared_ptr<torch::jit::script::CompilationUnit>(),
+          std::shared_ptr<torch::jit::CompilationUnit>(),
           std::move(res)),
       1);
   auto objPtr = c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(std::move(x));

View File

@@ -17,9 +17,7 @@
 struct ClassType;
 namespace torch {
 namespace jit {
-namespace script {
 struct CompilationUnit;
-}
 } // namespace jit
 } // namespace torch
@@ -1491,7 +1489,7 @@ CAFFE2_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type);
 struct ClassType;
 using ClassTypePtr = std::shared_ptr<ClassType>;
-using ::torch::jit::script::CompilationUnit;
+using ::torch::jit::CompilationUnit;
 // This represents a class in TorchScript.
 struct CAFFE2_API ClassType : public NamedType {
@@ -1801,7 +1799,7 @@ struct CAFFE2_API ClassType : public NamedType {
 struct InterfaceType;
 using InterfaceTypePtr = std::shared_ptr<InterfaceType>;
-using ::torch::jit::script::CompilationUnit;
+using ::torch::jit::CompilationUnit;
 // Interfaces are a list of abstract methods that a class might meet.
 // If a class provides those methods, it implicitly meets the interface.

View File

@@ -24,7 +24,7 @@
 namespace torch {
 namespace jit {
-void dump_opnames(const script::Module& m, std::unordered_set<std::string>& opnames) {
+void dump_opnames(const Module& m, std::unordered_set<std::string>& opnames) {
   auto methods = m.get_methods();
   for (const auto& method : methods) {
     const auto& func = method.function();

View File

@@ -7,8 +7,8 @@
 namespace caffe2 {
 using torch::jit::IValue;
-using torch::jit::script::Method;
-using torch::jit::script::Module;
+using torch::jit::Method;
+using torch::jit::Module;
 namespace {
 class ScriptModuleSerializer : public BlobSerializerBase {
@@ -31,7 +31,7 @@ class ScriptModuleSerializer : public BlobSerializerBase {
     // the more efficient serialization version (if we ever get to that point)
     BlobProto blob_proto;
     blob_proto.set_name(name);
-    blob_proto.set_type("torch::jit::script::Module");
+    blob_proto.set_type("torch::jit::Module");
     blob_proto.set_content(ss.str());
     acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
   }
@@ -134,7 +134,7 @@ REGISTER_BLOB_SERIALIZER(
 // NB: the first argument to REGISTER_BLOB_DESERIALIZER macro doesn't really
 // need to be a real type, it just get converted to string
 REGISTER_BLOB_DESERIALIZER(
-    torch::jit::script::Module,
+    torch::jit::Module,
     ScriptModuleDeserializer);
 OPERATOR_SCHEMA(ScriptModule)

View File

@@ -94,12 +94,12 @@ REGISTER_BLOB_FETCHER((TypeMeta::Id<string>()), StringFetcher);
 class ScriptModuleFetcher : public BlobFetcherBase {
  public:
   pybind11::object Fetch(const Blob& blob) override {
-    return py::cast(*blob.Get<std::unique_ptr<torch::jit::script::Module>>());
+    return py::cast(*blob.Get<std::unique_ptr<torch::jit::Module>>());
   }
 };
 REGISTER_BLOB_FETCHER(
-    (TypeMeta::Id<std::unique_ptr<torch::jit::script::Module>>()),
+    (TypeMeta::Id<std::unique_ptr<torch::jit::Module>>()),
     caffe2::python::ScriptModuleFetcher);
 #endif
@@ -247,9 +247,9 @@ bool feedBlob(
     return true;
   }
 #ifdef FBCODE_CAFFE2
-  if (auto module = torch::jit::script::as_module(arg)) {
-    blob->GetMutable<std::unique_ptr<torch::jit::script::Module>>()->reset(
-        new torch::jit::script::Module(*module));
+  if (auto module = torch::jit::as_module(arg)) {
+    blob->GetMutable<std::unique_ptr<torch::jit::Module>>()->reset(
+        new torch::jit::Module(*module));
     return true;
   }
 #endif

View File

@@ -213,7 +213,7 @@ Disable JIT for Debugging
 Python. Since TorchScript (scripting and tracing) are disabled with this flag,
 you can use tools like ``pdb`` to debug the model code.
-Given an example script::
+Given an example
     @torch.jit.script
     def scripted_fn(x : torch.Tensor):

View File

@@ -120,8 +120,8 @@ usage might look like:
 .. code-block:: cpp
-  SetExportModuleExtraFilesHook([](const script::Module&) {
-    script::ExtraFilesMap files;
+  SetExportModuleExtraFilesHook([](const Module&) {
+    ExtraFilesMap files;
     files["producer_info.json"] = "{\"user\": \"" + getenv("USER") + "\"}";
     return files;
   });

View File

@@ -8,7 +8,7 @@ Module
 .. java:type:: public class Module
-   Java wrapper for torch::jit::script::Module.
+   Java wrapper for torch::jit::Module.
 Methods
 -------
@@ -18,7 +18,7 @@ destroy
 .. java:method:: public void destroy()
    :outertype: Module
-   Explicitly destroys the native torch::jit::script::Module. Calling this method is not required, as the native object will be destroyed when this object is garbage-collected. However, the timing of garbage collection is not guaranteed, so proactively calling \ ``destroy``\ can free memory more quickly. See \ :java:ref:`com.facebook.jni.HybridData.resetNative`\ .
+   Explicitly destroys the native torch::jit::Module. Calling this method is not required, as the native object will be destroyed when this object is garbage-collected. However, the timing of garbage collection is not guaranteed, so proactively calling \ ``destroy``\ can free memory more quickly. See \ :java:ref:`com.facebook.jni.HybridData.resetNative`\ .
 forward
 ^^^^^^^
@@ -40,7 +40,7 @@ load
 Loads a serialized TorchScript module from the specified path on the disk.
 :param modelPath: path to file that contains the serialized TorchScript module.
-:return: new \ :java:ref:`org.pytorch.Module`\ object which owns torch::jit::script::Module.
+:return: new \ :java:ref:`org.pytorch.Module`\ object which owns torch::jit::Module.
 runMethod
 ^^^^^^^^^

View File

@@ -7,7 +7,7 @@
 @end
 @implementation TestAppTests {
-  torch::jit::script::Module _module;
+  torch::jit::Module _module;
 }
 + (void)setUp {

View File

@@ -395,7 +395,7 @@ void testAliasAnalysis() {
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
         R"IR(
 graph():
   %opt : Tensor? = prim::Constant()
@@ -415,7 +415,7 @@ void testAliasAnalysis() {
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
         R"IR(
 graph(%x : Tensor):
   %3 : int = prim::Constant[value=1]()
@@ -491,7 +491,7 @@ void testWriteTracking() {
   }
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%x: Tensor):
   %b : Tensor = aten::relu_(%x)
@@ -505,7 +505,7 @@ void testWriteTracking() {
   }
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%x: Tensor, %y : Tensor):
   %b : Tensor = aten::mul(%x, %y)
@@ -520,7 +520,7 @@ void testWriteTracking() {
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%x: Tensor, %y : Tensor):
   %c1 : int = prim::Constant[value=1]()
@@ -540,7 +540,7 @@ void testContainerAliasing() {
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%inp: Tensor[]):
   %x : str = prim::Constant[value="a"]()
@@ -574,7 +574,7 @@ void testContainerAliasing() {
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
        R"IR(
 graph():
   %x : str = prim::Constant[value="a"]()
@@ -601,7 +601,7 @@ void testContainerAliasing() {
   // Test input aliasing
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%x: Tensor, %y: Tensor):
   %a : (Tensor) = prim::TupleConstruct(%x)
@@ -622,7 +622,7 @@ void testContainerAliasing() {
   // Test tuple that doesn't come from construct
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%x : int,
       %y : Tensor,
@@ -654,7 +654,7 @@ graph(%x : int,
   // test nested types
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
        R"IR(
 graph():
   %a : Tensor = prim::MakeTestTensor()
@@ -681,7 +681,7 @@ graph():
   // simple example
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
        R"IR(
 graph():
   %0 : Tensor = prim::Constant()
@@ -711,7 +711,7 @@ graph():
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
        R"IR(
 graph():
   %x : str = prim::Constant[value="a"]()
@@ -737,7 +737,7 @@ graph():
   // Test list container aliasing
   auto graph = std::make_shared<Graph>();
   std::unordered_map<std::string, Value*> vmap;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph():
   %0 : int = prim::Constant[value=2]()
@@ -779,7 +779,7 @@ graph():
   auto graph = std::make_shared<Graph>();
   std::unordered_map<std::string, Value*> vmap;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph():
   %0 : int = prim::Constant[value=2]()
@@ -810,7 +810,7 @@ graph():
   // print across it.
   auto graph = std::make_shared<Graph>();
   std::unordered_map<std::string, Value*> vmap;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph():
   %35 : int = prim::Constant[value=1]()
@@ -849,7 +849,7 @@ graph():
   // print across it.
   auto graph = std::make_shared<Graph>();
   std::unordered_map<std::string, Value*> vmap;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph():
   %38 : int = prim::Constant[value=1]()
@@ -935,7 +935,7 @@ void testWildcards() {
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%ten_list : Tensor[], %int_list : int[], %opt_ten_list : Tensor[]?):
   %ten : Tensor = prim::Constant()
@@ -971,7 +971,7 @@ void testWildcards() {
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%ten_list : Tensor[], %ten_opt_list : Tensor?[]):
   %ten : Tensor = prim::Constant()
@@ -992,7 +992,7 @@ void testWildcards() {
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%float_3D : Float(*, *, *), %float_2D : Float(*, *)):
   return ()
@@ -1006,7 +1006,7 @@ void testWildcards() {
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%float_3D_list : Float(*, *, *)[], %float_2D_list : Float(*, *)[], %ten: Tensor):
   return ()

View File

@@ -213,7 +213,7 @@ void testDifferentiateWithRequiresGrad() {
     %7 : Tensor = aten::add(%6, %1, %2)
     return (%4, %7))IR";
   auto g = std::make_shared<Graph>();
-  torch::jit::script::parseIR(graph_string, g.get());
+  torch::jit::parseIR(graph_string, g.get());
   auto a_var = autograd::make_variable(
       at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true);

View File

@@ -9,8 +9,6 @@
 namespace torch {
 namespace jit {
-using namespace torch::jit::script;
 static const auto classSrcs1 = R"JIT(
 class FooNestedTest:
   def __init__(self, y):

View File

@@ -4,7 +4,6 @@
 namespace torch {
 namespace jit {
-using namespace torch::jit::script;
 const auto testSource = R"JIT(
 class FooTest:
   def __init__(self, x):

View File

@@ -5,8 +5,6 @@
 namespace torch {
 namespace jit {
-using namespace torch::jit::script;
 void testClassTypeAddRemoveAttr() {
   auto cu = std::make_shared<CompilationUnit>();
   auto cls = ClassType::create("foo.bar", cu, true);

View File

@@ -14,7 +14,7 @@ namespace jit {
 void testConstantPooling() {
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
        R"IR(
 graph():
   %8 : int = prim::Constant[value=1]()
@@ -29,7 +29,7 @@ graph():
   }
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%cond : Tensor):
   %a : str = prim::Constant[value="bcd"]()
@@ -53,7 +53,7 @@ graph(%cond : Tensor):
   }
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
        R"IR(
 graph():
   %2 : int = prim::Constant[value=2]()

View File

@@ -145,7 +145,7 @@ void testCustomOperatorAliasing() {
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%x: Tensor, %y: Tensor):
   %ret : Tensor = foo::aliasing(%x, %y)
@@ -172,7 +172,7 @@ graph(%x: Tensor, %y: Tensor):
   %ret : Tensor = foo::aliasing(%x, %y)
   return (%x)
 )IR";
-  script::parseIR(text, graph.get());
+  parseIR(text, graph.get());
   EliminateDeadCode(graph);
   testing::FileCheck().run(text, *graph);

View File

@@ -43,7 +43,7 @@ graph():
     -> (%50, %tot.3)
   return (%tot)
 )IR";
-  script::parseIR(input, graph.get());
+  parseIR(input, graph.get());
   EliminateDeadCode(graph);
   // Check that dead code elimin
   testing::FileCheck().run(input, *graph);

View File

@@ -67,7 +67,7 @@ void testFusion() {
     %2 : Tensor = aten::mul(%0, %1)
     return (%2))IR";
   Graph graph;
-  torch::jit::script::parseIR(graph_string, &graph);
+  torch::jit::parseIR(graph_string, &graph);
   auto a = at::rand({3, 4}, at::kCUDA);
   auto b = at::rand({4, 3}, at::kCUDA).transpose(0, 1);
@@ -100,7 +100,7 @@ void testFusion() {
     %14 : Tensor = aten::mul(%8, %13)
     return (%14, %12))IR";
   Graph graph;
-  torch::jit::script::parseIR(graph_string, &graph);
+  torch::jit::parseIR(graph_string, &graph);
   graph.lint();
@@ -164,7 +164,7 @@ void testFusion() {
       graph_string2};
   for (auto i = decltype(graph_strings.size()){0}; i < graph_strings.size(); ++i) {
     Graph g;
-    torch::jit::script::parseIR(graph_strings[i], &g);
+    torch::jit::parseIR(graph_strings[i], &g);
     auto outputs = debugLaunchGraph(g, {a, b});
     ASSERT_EQ(outputs.size(), 2);
@@ -187,7 +187,7 @@ void testRegisterFusionCachesKernel() {
     %d0 : Float(2, 3, 4) = aten::mul(%c0, %0)
     return (%d0))IR";
   auto g0 = std::make_shared<Graph>();
-  torch::jit::script::parseIR(graph0_string, g0.get());
+  torch::jit::parseIR(graph0_string, g0.get());
   const auto graph1_string = R"IR(
     graph(%0 : Float(2, 3, 4),
@@ -196,7 +196,7 @@ void testRegisterFusionCachesKernel() {
       %d1 : Float(2, 3, 4) = aten::mul(%c1, %0)
       return (%d1))IR";
   auto g1 = std::make_shared<Graph>();
-  torch::jit::script::parseIR(graph1_string, g1.get());
+  torch::jit::parseIR(graph1_string, g1.get());
   auto getFusionGroup = [](const std::shared_ptr<Graph>& graph) {
     const auto& nodes = graph->nodes();

View File

@@ -21,7 +21,6 @@ def foo3(x):
 namespace torch {
 namespace jit {
-using namespace script;
 using namespace testing;
 struct InlinerGuard {

View File

@@ -11,8 +11,6 @@
 namespace torch {
 namespace jit {
-using namespace torch::jit::script;
 static const std::vector<std::string> subMethodSrcs = {R"JIT(
 def one(self, x: Tensor, y: Tensor) -> Tensor:
     return x + y + 1

View File

@@ -55,7 +55,7 @@ void testBlocks() {
     %12 : int = prim::Constant[value=1]()
     %13 : Tensor = aten::add(%5, %3, %12)
     return (%13))IR";
-  torch::jit::script::parseIR(graph_string, g.get());
+  torch::jit::parseIR(graph_string, g.get());
   g->lint();
   testing::FileCheck()
@@ -122,7 +122,7 @@ graph(%x : Tensor,
   torch::jit::Graph g;
   std::unordered_map<std::string, torch::jit::Value*> name_to_value;
-  torch::jit::script::parseIR(input_str, &g, name_to_value);
+  torch::jit::parseIR(input_str, &g, name_to_value);
   std::vector<std::string> value_names{"6", "7", "9", "10"};
   std::unordered_set<std::string> value_names_set(

View File

@@ -17,7 +17,7 @@ namespace jit {
  */
 static void checkRoundtrip(const std::string& s) {
   auto graph = std::make_shared<Graph>();
-  script::parseIR(s, &*graph);
+  parseIR(s, &*graph);
   std::ostringstream ss;
   ss << *graph;
   std::string parsed = ss.str();
@@ -42,7 +42,7 @@ void testIRParser() {
   {
     auto graph = std::make_shared<Graph>();
     std::unordered_map<std::string, Value*> vmap;
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%0 : Tensor, %1 : Tensor):
   %2 : Tensor = foo::add(%0, %1)
@@ -117,7 +117,7 @@ graph(%0 : Tensor,
   }
   {
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%a):
   return (%a))IR",
@@ -127,7 +127,7 @@ graph(%a):
   {
     // Check that parser correctly handles values reusing the same name.
     auto graph = std::make_shared<Graph>();
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%x):
   %x = a::a(%x)
@@ -239,7 +239,7 @@ graph(%0 : Tensor,
 # CHECK: return
   return (%a))IR";
-  script::parseIR(text, &*graph);
+  parseIR(text, &*graph);
   AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(TensorType::get()));
   torch::jit::testing::FileCheck().run(text, *graph);
 }

View File

@@ -7,8 +7,6 @@
 namespace torch {
 namespace jit {
-using namespace torch::jit::script;
 void testIValue() {
   c10::List<int64_t> foo({3, 4, 5});
   ASSERT_EQ(foo.use_count(), 1);

View File

@@ -12,7 +12,7 @@ namespace torch {
 namespace jit {
 void testLiteInterpreterUpsampleNearest2d() {
-  script::Module m("m");
+  Module m("m");
   m.define(R"(
     def forward(self, input: Tensor, scale:float):
       return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
@@ -35,7 +35,7 @@ void testLiteInterpreterUpsampleNearest2d() {
 }
 void testLiteInterpreterAdd() {
-  script::Module m("m");
+  Module m("m");
   m.register_parameter("foo", torch::ones({}), false);
   // TODO: support default param val, which was pushed in
   // function schema's checkAndNormalizeInputs()
@@ -75,7 +75,7 @@ void testLiteInterpreterConv() {
   std::vector<torch::jit::IValue> inputs;
-  script::Module m("m");
+  Module m("m");
   m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
   m.register_parameter("bias", torch::ones({20}), false);
   m.define(R"(
@@ -100,7 +100,7 @@ void testLiteInterpreterConv() {
 }
 void testLiteInterpreterInline() {
-  script::Module m("m");
+  Module m("m");
   m.define(R"JIT(
   def foo1(self, x):
       return x + 1
@@ -120,7 +120,7 @@ void testLiteInterpreterInline() {
 }
 void testLiteInterpreterTuple() {
-  script::Module m("m");
+  Module m("m");
   m.define(R"JIT(
   def foo(self, x):
       return (1, 2, x + 3)
@@ -138,7 +138,7 @@ void testLiteInterpreterTuple() {
 }
 void testLiteInterpreterPrimOverload() {
-  script::Module m("m");
+  Module m("m");
   m.define(R"JIT(
   def forward(self, x):
       result = [1, 2]
@@ -154,7 +154,7 @@ void testLiteInterpreterPrimOverload() {
 }
 void testLiteInterpreterPrim() {
-  script::Module m("m");
+  Module m("m");
   m.define(R"JIT(
   def forward(self, x):
       return int(x)
@@ -180,7 +180,7 @@ void testLiteInterpreterPrim() {
 }
 void testLiteInterpreterLoadOrigJit() {
-  script::Module m("m");
+  Module m("m");
   m.register_parameter("foo", torch::ones({}), false);
   m.define(R"(
     def forward(self, x):
@@ -193,7 +193,7 @@ void testLiteInterpreterLoadOrigJit() {
 }
 void testLiteInterpreterWrongMethodName() {
-  script::Module m("m");
+  Module m("m");
   m.register_parameter("foo", torch::ones({}), false);
   m.define(R"(
     def add(self, x):
@@ -210,7 +210,7 @@ void testLiteInterpreterWrongMethodName() {
 }
 void testLiteInterpreterParams() {
-  script::Module m("m");
+  Module m("m");
   m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
   m.define(R"(
     def forward(self, x):
@@ -269,7 +269,7 @@ void testLiteInterpreterParams() {
 }
 void testLiteInterpreterSetState() {
-  script::Module m("m");
+  Module m("m");
   m.register_parameter("foo", torch::ones({}), false);
   m.define(R"(
     def __getstate__(self):

View File

@@ -368,7 +368,7 @@ void testCustomFusion() {
     %3 : Tensor = aten::mul(%2, %0)
     return (%3))IR";
   auto g = std::make_shared<Graph>();
-  torch::jit::script::parseIR(graph_string, g.get());
+  torch::jit::parseIR(graph_string, g.get());
   torch::jit::overrideCanFuseOnCPU(true);
   CustomFuseGraph(
@@ -412,7 +412,7 @@ void testCustomFusionNestedBlocks() {
     %9 : Tensor = aten::add(%4, %2, %3)
     return (%4))IR";
   auto g = std::make_shared<Graph>();
-  torch::jit::script::parseIR(graph_string, g.get());
+  torch::jit::parseIR(graph_string, g.get());
   CustomFuseGraph(
       g,
@@ -489,7 +489,7 @@ void testEvalModeForLoadedModule() {
   if (isSandcastle())
     return; // The module file to load is not generated in Sandcastle
   std::string module_path = "dropout_model.pt";
-  torch::jit::script::Module module = torch::jit::load(module_path);
+  torch::jit::Module module = torch::jit::load(module_path);
   AT_ASSERT(module.attr("dropout").toModule().is_training());
   module.eval();
   AT_ASSERT(!module.attr("dropout").toModule().is_training());
@@ -973,7 +973,7 @@ void testNoneSchemaMatch() {
 }
 void testModuleDefine() {
-  script::Module m("m");
+  Module m("m");
   m.register_parameter("foo", torch::ones({}), false);
   m.define(R"(
     def add_it(self, x, b : int = 4):
@@ -984,7 +984,7 @@ void testModuleDefine() {
 }
 void testModuleConversion() {
-  script::Module m("test");
+  Module m("test");
   {
     // test cuda to cpu for params and buffers
     m.register_parameter("foo", torch::ones({}, at::kCUDA), false);
@@ -1016,7 +1016,7 @@ RegisterPass p(fakePass);
 void testPassManagement() {
   std::shared_ptr<Graph> graph = std::make_shared<Graph>();
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%a):
   return (%a))IR",

View File

@@ -5,8 +5,6 @@
 namespace torch {
 namespace jit {
-using namespace torch::jit::script;
 void testModuleClone() {
   auto cu = std::make_shared<CompilationUnit>();
   auto parent = ClassType::create("parent", cu, true);

View File

@@ -8,8 +8,6 @@
 namespace torch {
 namespace jit {
-using namespace script;
 void testPeepholeOptimize() {
   // test is / is not none optimization

View File

@@ -10,7 +10,6 @@
 namespace torch {
 namespace jit {
-using namespace script;
 void testSaveExtraFilesHook() {
   // no secrets

View File

@@ -23,7 +23,7 @@ void testSchemaMatching() {
         return 0;
       }, c10::AliasAnalysisKind::FROM_SCHEMA),
   });
-  script::Module m("m");
+  Module m("m");
   m.define(R"(
     def test(self):
       a = (1.0, 2.0)
@@ -59,7 +59,7 @@ void testSchemaMatching() {
         return 0;
       }, AliasAnalysisKind::FROM_SCHEMA),
   });
-  script::Module m("m");
+  Module m("m");
   m.define(R"JIT(
     def test(self):
       a = (1.0, 2.0)

View File

@@ -7,13 +7,13 @@ namespace jit {
 void testTrivial1() {
   Graph graph, pattern;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%0):
   %a = a::aaa(%0)
   return (%a))IR",
      &graph);
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%0):
   %x = a::aaa(%0)
@@ -46,7 +46,7 @@ void testTrivial2() {
 void testTrivial3() {
   Graph graph, pattern;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%0):
   %a = a::a(%0)
@@ -54,7 +54,7 @@ graph(%0):
   %c = a::c(%a, %b)
   return (%c))IR",
      &graph);
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%a, %b):
   %c = a::c(%a, %b)
@@ -92,7 +92,7 @@ void testTrivial4() {
 void testLinear1() {
   Graph graph, pattern;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%0):
   %a = a::aaa(%0)
@@ -102,7 +102,7 @@ graph(%0):
   %a = a::aaa(%0)
   return (%d))IR",
      &graph);
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%0):
   %x = b::bbb(%0)
@@ -161,7 +161,7 @@ void testLinear2() {
  */
 void testDiamond1() {
   Graph graph, pattern1, pattern2;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%0):
   %o = o::ooo(%0)
@@ -173,7 +173,7 @@ graph(%0):
   return (%e))IR",
      &graph);
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%0):
   %a = a::aaa(%0)
@@ -185,7 +185,7 @@ graph(%0):
   AT_ASSERT(!findPatternMatches(pattern1, graph).empty());
   // Check that order of nodes inside the diamond does not affect the result
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%0):
   %a = a::aaa(%0)
@@ -247,7 +247,7 @@ void testDiamond2() {
 void testXPattern() {
   Graph graph, pattern;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%0, %1):
   %b = b::bbb(%0)
@@ -258,7 +258,7 @@ graph(%0, %1):
   %g = g::ggg(%e, %f)
   return (%g))IR",
      &graph);
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%0, %1):
   %b = b::bbb(%0)
@@ -274,7 +274,7 @@ graph(%0, %1):
 void testMultipleMatches() {
   Graph graph, pattern;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%t0):
   %t1 = a::aaa(%t0)
@@ -283,7 +283,7 @@ graph(%t0):
   %t4 = a::aaa(%t3)
   return (%t4))IR",
      &graph);
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%t0):
   %t1 = a::aaa(%t0)
@@ -295,7 +295,7 @@ graph(%t0):
 void testOverlappingMatches() {
   Graph graph, pattern;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%t0):
   %t1 = a::aaa(%t0)
@@ -304,7 +304,7 @@ graph(%t0):
   %t4 = a::aaa(%t3)
   return (%t4))IR",
      &graph);
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%t0):
   %t1 = a::aaa(%t0)
@@ -317,7 +317,7 @@ graph(%t0):
 void testMatchInBasicBlocks1() {
   Graph graph;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%a, %b, %c):
   %d = aten::mul(%a, %b)
@@ -333,7 +333,7 @@ graph(%a, %b, %c):
   // Ensure the matches don't cross basic block boundaries
   Graph pattern0;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%x, %y):
   %z = aten::mul(%x, %y)
@@ -342,7 +342,7 @@ graph(%x, %y):
   AT_ASSERT(findPatternMatches(pattern0, graph).size() == 3);
   Graph pattern1;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%x, %y):
   %z1 = aten::mul(%x, %y)
@@ -354,7 +354,7 @@ graph(%x, %y):
 void testMatchInBasicBlocks2() {
   Graph graph;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%a, %b):
   %x = my::mul(%a, %b)
@@ -367,7 +367,7 @@ graph(%a, %b):
   // Check that we can match both mul ops
   Graph pattern0;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%x, %y):
   %z = my::mul(%x, %y)
@@ -377,7 +377,7 @@ graph(%x, %y):
   // Ensure the matches don't cross basic block boundaries
   Graph pattern1;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%x, %y):
   %u = my::mul(%x, %y)
@@ -389,7 +389,7 @@ graph(%x, %y):
 void testMatchesAttributes() {
   Graph graph;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%0):
   %a = a::a[isattr=[1,2]](%0)
@@ -400,7 +400,7 @@ graph(%0):
   {
     Graph pattern;
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%a, %b):
   %c = a::c[myattr="qqq"](%a, %b)
@@ -410,7 +410,7 @@ graph(%a, %b):
   }
   {
     Graph pattern;
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%a, %b):
   %c = a::c[myattr="zzz"](%a, %b)
@@ -420,7 +420,7 @@ graph(%a, %b):
   }
   {
     Graph pattern;
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%0):
   %b = a::b[extraattr=10](%0)
@@ -430,7 +430,7 @@ graph(%0):
   }
   {
     Graph pattern;
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%0):
   %b = a::b[intattr=10, floatattr=3.14](%0)
@@ -440,7 +440,7 @@ graph(%0):
   }
   {
     Graph pattern;
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%0):
   %b = a::b[intattr=10, floatattr=3.14, strattr="rrr"](%0)
@@ -450,7 +450,7 @@ graph(%0):
   }
   {
     Graph pattern;
-    script::parseIR(
+    parseIR(
        R"IR(
 graph(%0):
   %a = a::a[isattr=[1,2]](%0)
@@ -463,14 +463,14 @@ graph(%0):
 void testBadPattern() {
   Graph graph, pattern1, pattern2;
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%0):
   %a = a::aaa(%0)
   return (%a))IR",
      &graph);
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%x):
   %y = my::node_with_subblock()
@@ -481,7 +481,7 @@ graph(%x):
      &pattern1);
   ASSERT_ANY_THROW(findPatternMatches(pattern1, graph));
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%x):
   %y = my::op1(%x)

View File

@@ -11,7 +11,7 @@ using namespace testing;
 void testFilterMatch() {
   auto graph = std::make_shared<Graph>();
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%0):
   %a = a::aaa(%0)
@@ -27,7 +27,7 @@ graph(%a, %b):
   Graph pattern_graph;
   std::unordered_map<std::string, Value*> vmap;
-  script::parseIR(
+  parseIR(
      pattern,
      &pattern_graph,
      vmap);
@@ -55,7 +55,7 @@ graph(%a, %b):
 void testFilterNoMatch() {
   auto graph = std::make_shared<Graph>();
-  script::parseIR(
+  parseIR(
      R"IR(
 graph(%0):
   %a = a::aaa(%0)
@@ -71,7 +71,7 @@ graph(%a, %b):
   Graph pattern_graph;
   std::unordered_map<std::string, Value*> vmap;
-  script::parseIR(
+  parseIR(
      pattern,
      &pattern_graph,
      vmap);

View File

@@ -84,7 +84,7 @@ std::shared_ptr<Graph> build_lstm() {
     %22 : Tensor = aten::mul(%14, %21)
     return (%22, %20))IR";
   auto g = std::make_shared<Graph>();
-  torch::jit::script::parseIR(graph_string, g.get());
+  torch::jit::parseIR(graph_string, g.get());
   g->lint();
   return g;

View File

@@ -12,7 +12,7 @@
 namespace helpers {
 template <typename Predicate>
 void check_all_parameters(
-    const torch::jit::script::Module& module,
+    const torch::jit::Module& module,
     Predicate predicate) {
   for (at::Tensor parameter : module.parameters()) {
     AT_ASSERT(predicate(parameter));
@@ -79,7 +79,7 @@ void get_autograd_operator_from_registry_and_execute_in_nograd_mode() {
 void load_serialized_module_with_custom_op_and_execute(
     const std::string& path_to_exported_script_module) {
-  torch::jit::script::Module module =
+  torch::jit::Module module =
       torch::jit::load(path_to_exported_script_module);
   std::vector<torch::jit::IValue> inputs;
   inputs.push_back(torch::ones(5));
@@ -90,7 +90,7 @@ void load_serialized_module_with_custom_op_and_execute(
 void test_argument_checking_for_serialized_modules(
     const std::string& path_to_exported_script_module) {
-  torch::jit::script::Module module =
+  torch::jit::Module module =
       torch::jit::load(path_to_exported_script_module);
   try {
@@ -124,7 +124,7 @@ void test_argument_checking_for_serialized_modules(
 }
 void test_move_to_device(const std::string& path_to_exported_script_module) {
-  torch::jit::script::Module module =
+  torch::jit::Module module =
       torch::jit::load(path_to_exported_script_module);
   helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
@@ -145,7 +145,7 @@ void test_move_to_device(const std::string& path_to_exported_script_module) {
 }
 void test_move_to_dtype(const std::string& path_to_exported_script_module) {
-  torch::jit::script::Module module =
+  torch::jit::Module module =
       torch::jit::load(path_to_exported_script_module);
   module.to(torch::kInt);

View File

@@ -24,7 +24,7 @@ struct MobileCallGuard {
   torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
 };
-torch::jit::script::Module loadModel(const std::string& path) {
+torch::jit::Module loadModel(const std::string& path) {
   MobileCallGuard guard;
   auto module = torch::jit::load(path);
   module.eval();

View File

@@ -138,7 +138,7 @@ TORCH_NN_MODULE_TEST_INIT = Template("""\n
 void ${module_variant_name}_test_init(
     const std::string& saved_module_path,
     const std::string& device) {
-  torch::jit::script::Module m_init_by_python = torch::jit::load(saved_module_path);
+  torch::jit::Module m_init_by_python = torch::jit::load(saved_module_path);
   torch::manual_seed(2);
   ${module_qualified_name} m_init_by_cpp${cpp_constructor_args};

View File

@@ -30,7 +30,7 @@ namespace jit {
 /// )JIT");
 /// IValue output = module->run_method("relu_script", a, b);
 /// \endrst
-TORCH_API std::shared_ptr<script::CompilationUnit> compile(const std::string& source);
+TORCH_API std::shared_ptr<CompilationUnit> compile(const std::string& source);
 } // namespace jit
 } // namespace torch

View File

@@ -39,7 +39,7 @@ namespace torch {
 template <typename Value, typename... SaveToArgs>
 void save(const Value& value, SaveToArgs&&... args) {
   serialize::OutputArchive archive(
-      std::make_shared<jit::script::CompilationUnit>());
+      std::make_shared<jit::CompilationUnit>());
   archive << value;
   archive.save_to(std::forward<SaveToArgs>(args)...);
 }
@@ -65,7 +65,7 @@ void save(const Value& value, SaveToArgs&&... args) {
 template <typename... SaveToArgs>
 void save(const std::vector<torch::Tensor>& tensor_vec, SaveToArgs&&... args) {
   serialize::OutputArchive archive(
-      std::make_shared<jit::script::CompilationUnit>());
+      std::make_shared<jit::CompilationUnit>());
   for (size_t i = 0; i < tensor_vec.size(); i++) {
     auto& value = tensor_vec[i];
     archive.write(c10::to_string(i), value);

View File

@@ -18,9 +18,7 @@ class Tensor;
 namespace torch {
 using at::Tensor;
 namespace jit {
-namespace script {
 struct Module;
-} // namespace script
 } // namespace jit
 } // namespace torch
@@ -108,7 +106,7 @@ class TORCH_API InputArchive final {
   }
  private:
-  jit::script::Module module_;
+  jit::Module module_;
   std::string hierarchy_prefix_;
 };
 } // namespace serialize

View File

@@ -15,9 +15,7 @@ class Tensor;
 namespace torch {
 using at::Tensor;
 namespace jit {
-namespace script {
 struct Module;
-} // namespace script
 } // namespace jit
 } // namespace torch
@@ -25,8 +23,8 @@ namespace torch {
 namespace serialize {
 class TORCH_API OutputArchive final {
  public:
-  explicit OutputArchive(std::shared_ptr<jit::script::CompilationUnit> cu);
-  explicit OutputArchive() : cu_(std::make_shared<jit::script::CompilationUnit>()), module_("__torch__.Module", cu_) {}
+  explicit OutputArchive(std::shared_ptr<jit::CompilationUnit> cu);
+  explicit OutputArchive() : cu_(std::make_shared<jit::CompilationUnit>()), module_("__torch__.Module", cu_) {}
   // Move is allowed.
   OutputArchive(OutputArchive&&) = default;
@@ -36,7 +34,7 @@ class TORCH_API OutputArchive final {
   OutputArchive(OutputArchive&) = delete;
   OutputArchive& operator=(OutputArchive&) = delete;
-  std::shared_ptr<jit::script::CompilationUnit> compilation_unit() const {
+  std::shared_ptr<jit::CompilationUnit> compilation_unit() const {
     return cu_;
   }
@@ -75,8 +73,8 @@ class TORCH_API OutputArchive final {
   }
  private:
-  std::shared_ptr<jit::script::CompilationUnit> cu_;
-  jit::script::Module module_;
+  std::shared_ptr<jit::CompilationUnit> cu_;
+  jit::Module module_;
 };
 } // namespace serialize
 } // namespace torch

View File

@@ -9,12 +9,12 @@
 namespace torch {
 namespace jit {
-std::shared_ptr<script::CompilationUnit> compile(const std::string& source) {
-  auto module = std::make_shared<script::CompilationUnit>();
+std::shared_ptr<CompilationUnit> compile(const std::string& source) {
+  auto module = std::make_shared<CompilationUnit>();
   module->define(
       c10::nullopt,
       source,
-      script::nativeResolver(),
+      nativeResolver(),
       nullptr);
   return module;
 }

View File

@ -16,7 +16,7 @@
namespace torch { namespace torch {
namespace serialize { namespace serialize {
InputArchive::InputArchive() : module_("Module", std::make_shared<jit::script::CompilationUnit>()) {} InputArchive::InputArchive() : module_("Module", std::make_shared<jit::CompilationUnit>()) {}
void InputArchive::read(const std::string& key, c10::IValue& ivalue) { void InputArchive::read(const std::string& key, c10::IValue& ivalue) {
ivalue = module_.attr(key); ivalue = module_.attr(key);
@ -161,7 +161,7 @@ std::vector<std::string> InputArchive::keys() {
std::vector<std::string> all_keys; std::vector<std::string> all_keys;
all_keys.reserve(module_.named_attributes(/*recurse=*/false).size()); all_keys.reserve(module_.named_attributes(/*recurse=*/false).size());
for (const torch::jit::script::NameValue& s : module_.named_attributes(/*recurse=*/false)) { for (const torch::jit::NameValue& s : module_.named_attributes(/*recurse=*/false)) {
all_keys.push_back(s.name); all_keys.push_back(s.name);
} }
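
And the matching read path (again illustrative rather than part of the diff):

#include <torch/torch.h>

int main() {
  // load_from() deserializes into the internal jit::Module; read() then
  // looks up an attribute by key, as keys() does above.
  torch::serialize::InputArchive in;
  in.load_from("archive.pt");
  torch::Tensor weight;
  in.read("weight", weight);
}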


@ -14,7 +14,7 @@
namespace torch { namespace torch {
namespace serialize { namespace serialize {
OutputArchive::OutputArchive(std::shared_ptr<jit::script::CompilationUnit> cu) OutputArchive::OutputArchive(std::shared_ptr<jit::CompilationUnit> cu)
: cu_(std::move(cu)), : cu_(std::move(cu)),
module_("__torch__.Module", cu_, /*shouldMangle=*/true) {} module_("__torch__.Module", cu_, /*shouldMangle=*/true) {}


@ -53,7 +53,7 @@ TypePtr tryInferTypeWithTypeHint(
const py::object& type_hint) { const py::object& type_hint) {
// If the py::object to be contained by the RRef is a ScriptModule, we require // If the py::object to be contained by the RRef is a ScriptModule, we require
// users to specify its ModuleInterface type. // users to specify its ModuleInterface type.
if (auto module = jit::script::as_module(value)) { if (auto module = jit::as_module(value)) {
TORCH_CHECK( TORCH_CHECK(
!type_hint.is_none(), !type_hint.is_none(),
"The RRef being created contains a ScriptModule, " "The RRef being created contains a ScriptModule, "


@ -27,8 +27,8 @@ namespace {
// PythonTypeResolver that inherits from jit::Resolver to // PythonTypeResolver that inherits from jit::Resolver to
// support resolving types together with ScriptTypeParser. // support resolving types together with ScriptTypeParser.
struct PythonTypeResolver : public jit::script::Resolver { struct PythonTypeResolver : public jit::Resolver {
std::shared_ptr<jit::script::SugaredValue> resolveValue( std::shared_ptr<jit::SugaredValue> resolveValue(
const std::string& /* unused */, const std::string& /* unused */,
torch::jit::Function& /* unused */, torch::jit::Function& /* unused */,
const jit::SourceRange& /* unused */) override { const jit::SourceRange& /* unused */) override {
@ -67,7 +67,7 @@ PythonRpcHandler::PythonRpcHandler() {
pyHandleException_ = getFunction(module, "_handle_exception"); pyHandleException_ = getFunction(module, "_handle_exception");
pyGetQualifiedName_ = py::module::import("torch.jit").attr("_qualified_name"); pyGetQualifiedName_ = py::module::import("torch.jit").attr("_qualified_name");
jitCompilationUnit_ = torch::jit::get_python_cu(); jitCompilationUnit_ = torch::jit::get_python_cu();
typeParser_ = std::make_shared<jit::script::ScriptTypeParser>( typeParser_ = std::make_shared<jit::ScriptTypeParser>(
std::make_shared<PythonTypeResolver>()); std::make_shared<PythonTypeResolver>());
} }
@ -97,7 +97,7 @@ PythonRpcHandler& PythonRpcHandler::getInstance() {
return *handler; return *handler;
} }
std::shared_ptr<torch::jit::script::CompilationUnit> PythonRpcHandler:: std::shared_ptr<torch::jit::CompilationUnit> PythonRpcHandler::
jitCompilationUnit() { jitCompilationUnit() {
return jitCompilationUnit_; return jitCompilationUnit_;
} }


@ -55,7 +55,7 @@ class PYBIND11_EXPORT PythonRpcHandler {
// PythonRpcHandler. // PythonRpcHandler.
void cleanup(); void cleanup();
std::shared_ptr<torch::jit::script::CompilationUnit> jitCompilationUnit(); std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit();
// Parse the string to recover the jit_type; this is used for RRef python // Parse the string to recover the jit_type; this is used for RRef python
// pickling/unpickling type recovery. The type string inference rule is as // pickling/unpickling type recovery. The type string inference rule is as
@ -97,11 +97,11 @@ class PYBIND11_EXPORT PythonRpcHandler {
// and imported in C++ (see get_python_cu() in // and imported in C++ (see get_python_cu() in
// csrc/jit/python/pybind_utils.h). We import the compilation unit here only // csrc/jit/python/pybind_utils.h). We import the compilation unit here only
// once for less cost and thread safety. // once for less cost and thread safety.
std::shared_ptr<torch::jit::script::CompilationUnit> jitCompilationUnit_; std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit_;
// jit type parser to parse type_str back to TypePtr for RRef type // jit type parser to parse type_str back to TypePtr for RRef type
// recovery when pickling and unpickling RRef // recovery when pickling and unpickling RRef
std::shared_ptr<jit::script::ScriptTypeParser> typeParser_; std::shared_ptr<jit::ScriptTypeParser> typeParser_;
}; };
} // namespace rpc } // namespace rpc


@ -24,7 +24,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
struct Def; struct Def;
struct ClassDef; struct ClassDef;
@ -281,22 +280,25 @@ struct TORCH_API CompilationUnit {
mutable size_t mangleIndex_ = 0; mutable size_t mangleIndex_ = 0;
}; };
} // namespace script
// An owning pointer to a Function. Just a pair of a raw Function ptr and its // An owning pointer to a Function. Just a pair of a raw Function ptr and its
// owning CU. We need this because pybind requires a ref-counted way to refer to // owning CU. We need this because pybind requires a ref-counted way to refer to
// Functions. // Functions.
struct StrongFunctionPtr { struct StrongFunctionPtr {
StrongFunctionPtr( StrongFunctionPtr(
std::shared_ptr<script::CompilationUnit> cu, std::shared_ptr<CompilationUnit> cu,
Function* function) Function* function)
: cu_(std::move(cu)), function_(function) { : cu_(std::move(cu)), function_(function) {
TORCH_INTERNAL_ASSERT(cu_); TORCH_INTERNAL_ASSERT(cu_);
TORCH_INTERNAL_ASSERT(function_); TORCH_INTERNAL_ASSERT(function_);
} }
std::shared_ptr<script::CompilationUnit> cu_; std::shared_ptr<CompilationUnit> cu_;
Function* function_; Function* function_;
}; };
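
As a hedged sketch of the ownership point made in the comment above, using only the `define()` and `nativeResolver()` APIs that appear elsewhere in this diff (the include path is an assumption):

#include <torch/script.h>  // assumed umbrella include; exact header may differ

int main() {
  auto cu = std::make_shared<torch::jit::CompilationUnit>();
  // define() returns the freshly defined Function*s.
  auto defined = cu->define(
      c10::nullopt,
      "def f(x: int) -> int:\n    return x + 1\n",
      torch::jit::nativeResolver(),
      /*self=*/nullptr);
  // The StrongFunctionPtr shares ownership of the CU, so the raw Function*
  // stays valid even after the local `cu` handle goes away.
  torch::jit::StrongFunctionPtr fn(std::move(cu), defined.at(0));
}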
namespace script {
// We once had a `script::` namespace that was deleted. This is for backcompat
// of the public API; new code should not use this type alias.
using CompilationUnit = ::torch::jit::CompilationUnit;
}
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch
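
To make the backcompat story concrete, a hedged sketch: code written against the old public spelling keeps compiling because the alias names exactly the same type (illustrative, not part of the diff):

#include <torch/jit.h>
#include <torch/script.h>

int main() {
  // Old spelling on the left; the typedef kept in api/ makes it identical
  // to the new torch::jit::CompilationUnit.
  std::shared_ptr<torch::jit::script::CompilationUnit> cu =
      torch::jit::compile("def f(x: int) -> int:\n    return x\n");
}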


@ -6,7 +6,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>; using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;
@ -61,6 +60,11 @@ struct TORCH_API Method {
Function* function_; Function* function_;
}; };
} // namespace script namespace script {
// We once had a `script::` namespace that was deleted. This is for backcompat
// of the public API; new code should not use this type alias.
using Method = ::torch::jit::Method;
}
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -11,7 +11,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
static ObjectPtr create_module_object( static ObjectPtr create_module_object(
c10::QualifiedName class_name, c10::QualifiedName class_name,
@ -389,14 +388,13 @@ void Module::dump(
<< std::endl; << std::endl;
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch
namespace c10 { namespace c10 {
torch::jit::script::Module IValue::toModule() const { torch::jit::Module IValue::toModule() const {
return torch::jit::script::Module(toObject()); return torch::jit::Module(toObject());
} }
bool IValue::isModule() const { bool IValue::isModule() const {
return isObject() && toObjectRef().type()->is_module(); return isObject() && toObjectRef().type()->is_module();


@ -33,7 +33,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
using ::c10::Argument; using ::c10::Argument;
using ::c10::FunctionSchema; using ::c10::FunctionSchema;
@ -554,6 +553,11 @@ struct NamedPolicy {
TORCH_API bool& getInlineEverythingMode(); TORCH_API bool& getInlineEverythingMode();
} // namespace script namespace script {
// We once had a `script::` namespace that was deleted. This is for backcompat
// of the public API; new code should not use this type alias.
using Module = ::torch::jit::Module;
}
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -3,7 +3,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
void Module::save(std::ostream& out, const ExtraFilesMap& extra_files) const { void Module::save(std::ostream& out, const ExtraFilesMap& extra_files) const {
ExportModule(*this, out, extra_files, false); ExportModule(*this, out, extra_files, false);
@ -26,6 +25,5 @@ void Module::_save_for_mobile(
ExportModule(*this, filename, extra_files, true); ExportModule(*this, filename, extra_files, true);
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -7,7 +7,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
Object::Object( Object::Object(
std::shared_ptr<CompilationUnit> cu, std::shared_ptr<CompilationUnit> cu,
@ -35,10 +34,9 @@ void Object::define(const std::string& src, const ResolverPtr& resolver) {
_ivalue()->compilation_unit()->define( _ivalue()->compilation_unit()->define(
*type()->name(), *type()->name(),
src, src,
resolver ? resolver : script::nativeResolver(), resolver ? resolver : nativeResolver(),
&self); &self);
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -6,7 +6,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
struct Resolver; struct Resolver;
using ResolverPtr = std::shared_ptr<Resolver>; using ResolverPtr = std::shared_ptr<Resolver>;
@ -132,6 +131,10 @@ struct TORCH_API Object {
mutable ObjectPtr _ivalue_; mutable ObjectPtr _ivalue_;
}; };
} // namespace script namespace script {
// We once had a `script::` namespace that was deleted. This is for backcompat
// of the public API; new code should not use this type alias.
using Object = ::torch::jit::Object;
}
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -285,7 +285,7 @@ The load process has the following steps:
1. Unpickle `constants.pkl`, which produces a tuple of all tensor constants 1. Unpickle `constants.pkl`, which produces a tuple of all tensor constants
referenced in code. referenced in code.
2. Unpickle `data.pkl` into the top-level `script::Module` and return it. 2. Unpickle `data.pkl` into the top-level `Module` and return it.
The unpickling process consists of a single call to unpickle the module The unpickling process consists of a single call to unpickle the module
object contained in `data.pkl`. The `Unpickler` is given a callback that lets it object contained in `data.pkl`. The `Unpickler` is given a callback that lets it
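
A rough sketch of those two steps; `unpickleArchive` is a hypothetical stand-in for the real unpickling code (which also wires up the tensor-table callback described above):

// Sketch only: unpickleArchive is hypothetical, not a real PyTorch API.
torch::jit::Module load_sketch(caffe2::serialize::PyTorchStreamReader& reader) {
  // 1. constants.pkl -> tuple of all tensor constants referenced in code
  c10::IValue constants = unpickleArchive(reader, "constants.pkl");
  // 2. data.pkl -> the top-level Module object, returned to the caller
  c10::IValue data = unpickleArchive(reader, "data.pkl");
  return torch::jit::Module(data.toObject());
}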


@ -5,7 +5,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
auto scalar_operators_source = CodeTemplate( auto scalar_operators_source = CodeTemplate(
R"SCRIPT( R"SCRIPT(
@ -81,7 +80,7 @@ struct BuiltinFunctionRegistry {
std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>(); std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>();
modules.emplace_back(cu); modules.emplace_back(cu);
cu->define( cu->define(
c10::nullopt, source, script::nativeResolver(), /*self=*/nullptr); c10::nullopt, source, nativeResolver(), /*self=*/nullptr);
for (auto& method : cu->get_functions()) { for (auto& method : cu->get_functions()) {
builtins_by_name_[Symbol::fromQualString( builtins_by_name_[Symbol::fromQualString(
the_namespace + "::" + method->name())] the_namespace + "::" + method->name())]
@ -134,6 +133,5 @@ const std::vector<Function*>& getAllBuiltinFunctionsFor(
return registry.getAllBuiltinFunctionsFor(name); return registry.getAllBuiltinFunctionsFor(name);
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -5,9 +5,7 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
TORCH_API const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name); TORCH_API const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name);
}
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -3,7 +3,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
ClassTypePtr ConcreteModuleTypeBuilder::createTypeFromThis() const { ClassTypePtr ConcreteModuleTypeBuilder::createTypeFromThis() const {
auto cu = get_python_cu(); auto cu = get_python_cu();
@ -303,6 +302,5 @@ ConcreteModuleType::getModulesPy() const {
return ret; return ret;
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -9,7 +9,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
enum class IterableModuleKind { NONE, LIST, DICT }; enum class IterableModuleKind { NONE, LIST, DICT };
class ConcreteModuleType; class ConcreteModuleType;
@ -228,6 +227,5 @@ class VISIBILITY_HIDDEN ConcreteModuleType {
TypePtr jitType_; TypePtr jitType_;
}; };
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -9,7 +9,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
// At the beginning of the pass the Graph has already undergone type checking, // At the beginning of the pass the Graph has already undergone type checking,
// and writes or reads to a variable are emitted as Loads and Stores in the // and writes or reads to a variable are emitted as Loads and Stores in the
@ -328,6 +327,5 @@ void ConvertToSSA(std::shared_ptr<Graph>& graph) {
TransformExits(graph); TransformExits(graph);
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -8,11 +8,9 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
// Convert a graph with Loads & Stores into SSA form // Convert a graph with Loads & Stores into SSA form
TORCH_API void ConvertToSSA(std::shared_ptr<Graph>& graph); TORCH_API void ConvertToSSA(std::shared_ptr<Graph>& graph);
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -5,7 +5,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
// computes levenshtein edit distance between two words // computes levenshtein edit distance between two words
// returns maxEditDistance + 1 if the edit distance exceeds MaxEditDistance // returns maxEditDistance + 1 if the edit distance exceeds MaxEditDistance
@ -52,6 +51,5 @@ size_t ComputeEditDistance(
return result; return result;
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch
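
For a sense of the contract, one illustrative call (not part of the diff; header path varies by version):

// "relu" -> "gelu" is one substitution, so this returns 1; any distance
// greater than the cap would come back as maxEditDistance + 1, i.e. 3 here.
size_t d = torch::jit::ComputeEditDistance("relu", "gelu", /*maxEditDistance=*/2);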


@ -5,13 +5,11 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
TORCH_API size_t ComputeEditDistance( TORCH_API size_t ComputeEditDistance(
const char* word1, const char* word1,
const char* word2, const char* word2,
size_t maxEditDistance); size_t maxEditDistance);
}
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -5,7 +5,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
// Avoid storing objects with destructor in thread_local for mobile build. // Avoid storing objects with destructor in thread_local for mobile build.
#ifndef C10_MOBILE #ifndef C10_MOBILE
@ -72,6 +71,5 @@ const char* ErrorReport::what() const noexcept {
return the_message.c_str(); return the_message.c_str();
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -5,7 +5,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
struct Call { struct Call {
std::string fn_name; std::string fn_name;
@ -49,6 +48,5 @@ const ErrorReport& operator<<(const ErrorReport& e, const T& t) {
return e; return e;
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -22,7 +22,6 @@ using at::TypeKind;
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
namespace { namespace {
struct SchemaParser { struct SchemaParser {
SchemaParser(const std::string& str) SchemaParser(const std::string& str)
@ -290,10 +289,9 @@ struct SchemaParser {
SchemaTypeParser type_parser; SchemaTypeParser type_parser;
}; };
} // namespace } // namespace
} // namespace script
C10_EXPORT either<OperatorName, FunctionSchema> parseSchemaOrName(const std::string& schemaOrName) { C10_EXPORT either<OperatorName, FunctionSchema> parseSchemaOrName(const std::string& schemaOrName) {
return script::SchemaParser(schemaOrName).parseDeclarations().at(0); return SchemaParser(schemaOrName).parseDeclarations().at(0);
} }
C10_EXPORT FunctionSchema parseSchema(const std::string& schema) { C10_EXPORT FunctionSchema parseSchema(const std::string& schema) {


@ -30,7 +30,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
using FunctionTable = std::unordered_map<std::string, Function&>; using FunctionTable = std::unordered_map<std::string, Function&>;
using ValueTable = std::unordered_map<std::string, SugaredValuePtr>; using ValueTable = std::unordered_map<std::string, SugaredValuePtr>;
@ -491,7 +490,7 @@ struct Environment {
if (!retval) { if (!retval) {
if (auto type = resolver->resolveType(ident, range)) { if (auto type = resolver->resolveType(ident, range)) {
if (auto tuple_type = type->cast<TupleType>()) { if (auto tuple_type = type->cast<TupleType>()) {
retval = std::make_shared<script::NamedTupleConstructor>(tuple_type); retval = std::make_shared<NamedTupleConstructor>(tuple_type);
} }
} }
} }
@ -503,7 +502,7 @@ struct Environment {
if (!retval) { if (!retval) {
if (auto type = resolver->resolveType(ident, range)) { if (auto type = resolver->resolveType(ident, range)) {
if (auto class_type = type->cast<ClassType>()) { if (auto class_type = type->cast<ClassType>()) {
retval = std::make_shared<script::ClassValue>(class_type); retval = std::make_shared<ClassValue>(class_type);
} }
} }
} }
@ -3598,7 +3597,7 @@ std::vector<Function*> CompilationUnit::define(
void runCleanupPasses(std::shared_ptr<Graph>& to_clean) { void runCleanupPasses(std::shared_ptr<Graph>& to_clean) {
liftClosures(to_clean); liftClosures(to_clean);
inlineForkedClosures(to_clean); inlineForkedClosures(to_clean);
if (script::getInlineEverythingMode()) { if (getInlineEverythingMode()) {
Inline(*to_clean); Inline(*to_clean);
} }
// remove any uses of tuples that we inserted that are not needed // remove any uses of tuples that we inserted that are not needed
@ -3669,6 +3668,5 @@ void CompilationUnit::define_interface(
this->register_type(iface); this->register_type(iface);
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -12,12 +12,10 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
TORCH_API void runCleanupPasses(std::shared_ptr<Graph>& to_clean); TORCH_API void runCleanupPasses(std::shared_ptr<Graph>& to_clean);
TORCH_API bool meaningfulName(const std::string& name); TORCH_API bool meaningfulName(const std::string& name);
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -8,7 +8,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
static const std::unordered_map<int, int> binary_prec = { static const std::unordered_map<int, int> binary_prec = {
{TK_IF, 1}, {TK_IF, 1},
@ -101,6 +100,5 @@ C10_EXPORT SharedParserData& sharedParserData() {
return data; return data;
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -16,7 +16,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
// single character tokens are just the character itself '+' // single character tokens are just the character itself '+'
// multi-character tokens need an entry here // multi-character tokens need an entry here
@ -180,7 +179,7 @@ struct CAFFE2_API SharedParserData {
return false; return false;
const char* startptr = str.c_str() + start; const char* startptr = str.c_str() + start;
char* endptr; char* endptr;
torch::jit::script::strtod_c(startptr, &endptr); torch::jit::strtod_c(startptr, &endptr);
*len = endptr - startptr; *len = endptr - startptr;
return *len > 0; return *len > 0;
} }
@ -515,6 +514,5 @@ struct Lexer {
std::vector<Token> next_tokens; std::vector<Token> next_tokens;
SharedParserData& shared; SharedParserData& shared;
}; };
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -5,7 +5,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
// Simple data structure for containing a type T in nested control blocks // Simple data structure for containing a type T in nested control blocks
// Should only be used after initial compilation where type checking and // Should only be used after initial compilation where type checking and
@ -52,6 +51,5 @@ struct MiniEnvironment {
std::unordered_map<std::string, T> table; std::unordered_map<std::string, T> table;
}; };
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -5,7 +5,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
inline bool isCharCount(char c, const std::string& str, size_t start, int len) { inline bool isCharCount(char c, const std::string& str, size_t start, int len) {
// count checks from [start, start + len) // count checks from [start, start + len)
@ -87,6 +86,5 @@ inline std::string parseStringLiteral(
return ret_str; return ret_str;
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -7,7 +7,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
Decl mergeTypesFromTypeComment( Decl mergeTypesFromTypeComment(
const Decl& decl, const Decl& decl,
@ -735,6 +734,5 @@ Expr Parser::parseExp() {
return pImpl->parseExp(); return pImpl->parseExp();
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -6,7 +6,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
struct Decl; struct Decl;
struct ParserImpl; struct ParserImpl;
@ -30,6 +29,5 @@ struct TORCH_API Parser {
std::unique_ptr<ParserImpl> pImpl; std::unique_ptr<ParserImpl> pImpl;
}; };
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -2,8 +2,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
static const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!&^|~"; static const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!&^|~";
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -6,7 +6,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
struct Resolver; struct Resolver;
using ResolverPtr = std::shared_ptr<Resolver>; using ResolverPtr = std::shared_ptr<Resolver>;
@ -65,6 +64,5 @@ struct NativeResolver : public Resolver {
inline std::shared_ptr<NativeResolver> nativeResolver() { inline std::shared_ptr<NativeResolver> nativeResolver() {
return std::make_shared<NativeResolver>(); return std::make_shared<NativeResolver>();
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -6,7 +6,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
static inline TypePtr unwrapOptional(TypePtr opt_type) { static inline TypePtr unwrapOptional(TypePtr opt_type) {
if (auto unwrap_list_type = opt_type->cast<OptionalType>()) { if (auto unwrap_list_type = opt_type->cast<OptionalType>()) {
@ -612,6 +611,5 @@ Value* emitBuiltinCall(
} }
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -7,7 +7,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
// try to match a list of inputs and keyword 'attributes' to this schema, // try to match a list of inputs and keyword 'attributes' to this schema,
// if it works, return the flat list of positional inputs to the call // if it works, return the flat list of positional inputs to the call
@ -60,6 +59,5 @@ TORCH_API Value* tryConvertToType(
const TypePtr& concrete_type, const TypePtr& concrete_type,
Value* value, Value* value,
bool allow_conversions); bool allow_conversions);
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -31,7 +31,6 @@ using c10::VarType;
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
TypePtr SchemaTypeParser::parseBaseType() { TypePtr SchemaTypeParser::parseBaseType() {
static std::unordered_map<std::string, TypePtr> type_map = { static std::unordered_map<std::string, TypePtr> type_map = {
@ -292,6 +291,5 @@ void SchemaTypeParser::parseList(
L.expect(end); L.expect(end);
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -7,7 +7,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
using TypePtr = c10::TypePtr; using TypePtr = c10::TypePtr;
@ -33,6 +32,5 @@ struct CAFFE2_API SchemaTypeParser {
Lexer& L; Lexer& L;
size_t next_id = 0; size_t next_id = 0;
}; };
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -4,7 +4,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
const std::unordered_map<std::string, TypePtr>& string_to_type_lut(); const std::unordered_map<std::string, TypePtr>& string_to_type_lut();
namespace { namespace {
@ -355,6 +354,5 @@ c10::IValue ScriptTypeParser::parseClassConstant(const Assign& assign) {
return *default_val.begin(); return *default_val.begin();
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -6,7 +6,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
/** /**
* class ScriptTypeParser * class ScriptTypeParser
@ -45,6 +44,5 @@ class TORCH_API ScriptTypeParser {
ResolverPtr resolver_ = nullptr; ResolverPtr resolver_ = nullptr;
}; };
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -2,7 +2,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
using namespace c10; using namespace c10;
const std::unordered_map<std::string, TypePtr>& string_to_type_lut() { const std::unordered_map<std::string, TypePtr>& string_to_type_lut() {
static std::unordered_map<std::string, TypePtr> map = { static std::unordered_map<std::string, TypePtr> map = {
@ -22,6 +21,5 @@ const std::unordered_map<std::string, TypePtr>& string_to_type_lut() {
{"tuple", AnyTupleType::get()}}; {"tuple", AnyTupleType::get()}};
return map; return map;
} }
}
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -76,7 +76,6 @@ double parse_inf_or_nan(const char *p, char **endptr)
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
#ifdef _MSC_VER #ifdef _MSC_VER
C10_EXPORT double strtod_c(const char *nptr, char **endptr) C10_EXPORT double strtod_c(const char *nptr, char **endptr)
@ -265,4 +264,3 @@ C10_EXPORT float strtof_c(const char *nptr, char **endptr)
} }
} }
}


@ -4,11 +4,9 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
CAFFE2_API double strtod_c(const char *nptr, char **endptr); CAFFE2_API double strtod_c(const char *nptr, char **endptr);
CAFFE2_API float strtof_c(const char *nptr, char **endptr); CAFFE2_API float strtof_c(const char *nptr, char **endptr);
} }
} }
}
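
The contract mirrors the standard strtod, minus the locale dependence; a small illustration (not part of the diff; header path varies by version):

char* end = nullptr;
// Parses the leading number regardless of the process locale.
double v = torch::jit::strtod_c("3.5e2 trailing", &end);  // v == 350.0, end -> " trailing"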


@ -6,7 +6,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
struct NoneValue : SugaredValue { struct NoneValue : SugaredValue {
NoneValue() = default; NoneValue() = default;
@ -623,6 +622,5 @@ std::shared_ptr<BuiltinFunction> BuiltinFunction::tryCreate(
return nullptr; return nullptr;
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -10,7 +10,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
using SugaredValuePtr = std::shared_ptr<SugaredValue>; using SugaredValuePtr = std::shared_ptr<SugaredValue>;
@ -365,7 +364,7 @@ struct FunctionValue : public SugaredValue {
try { try {
callee->ensure_defined(); callee->ensure_defined();
} catch (const RecursiveMethodCallError&) { } catch (const RecursiveMethodCallError&) {
throw script::ErrorReport(loc) throw ErrorReport(loc)
<< " function '" << callee->name() << "' is called recursively. " << " function '" << callee->name() << "' is called recursively. "
<< "Recursive calls are not supported"; << "Recursive calls are not supported";
} }
@ -428,7 +427,7 @@ struct MethodValue : public SugaredValue {
try { try {
method->ensure_defined(); method->ensure_defined();
} catch (const RecursiveMethodCallError&) { } catch (const RecursiveMethodCallError&) {
throw script::ErrorReport(loc) throw ErrorReport(loc)
<< " method '" << method->name() << "' is called recursively. " << " method '" << method->name() << "' is called recursively. "
<< "Recursive calls are not supported"; << "Recursive calls are not supported";
} }
@ -652,6 +651,5 @@ struct SimpleSelf : public Self {
private: private:
ClassTypePtr classType_; ClassTypePtr classType_;
}; };
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -152,7 +152,7 @@ Value* TracingState::getValue(const IValue& var) {
// Find torchbind classes // Find torchbind classes
if (isCustomClass(var)) { if (isCustomClass(var)) {
auto obj = script::Object(var.toObject()); auto obj = Object(var.toObject());
auto qualname = obj.type()->name(); auto qualname = obj.type()->name();
auto custom_class_type = getCustomClass(qualname->qualifiedName()); auto custom_class_type = getCustomClass(qualname->qualifiedName());
if (custom_class_type) { if (custom_class_type) {
@ -313,14 +313,14 @@ static IValue addInput(const std::shared_ptr<TracingState> & state, const IValue
static void gatherParametersAndBuffers( static void gatherParametersAndBuffers(
const std::shared_ptr<TracingState>& state, const std::shared_ptr<TracingState>& state,
Value* self_value, Value* self_value,
const script::Module& self, const Module& self,
const std::string& prefix) { const std::string& prefix) {
Graph& g = *self_value->owningGraph(); Graph& g = *self_value->owningGraph();
state->setValue(self._ivalue(), self_value); state->setValue(self._ivalue(), self_value);
auto self_ty = self.type(); auto self_ty = self.type();
for (const script::NameValue& s : self.named_attributes(/*recurse=*/false)) { for (const NameValue& s : self.named_attributes(/*recurse=*/false)) {
auto qualname = prefix + "." + s.name; auto qualname = prefix + "." + s.name;
Value* trace_get_attr = g.insertNode(g.create(prim::TracedAttr)) Value* trace_get_attr = g.insertNode(g.create(prim::TracedAttr))
->s_(attr::scope, qualname) ->s_(attr::scope, qualname)
@ -335,7 +335,7 @@ static void gatherParametersAndBuffers(
} }
if (self_ty->getAttribute(s.name)->is_module()) { if (self_ty->getAttribute(s.name)->is_module()) {
gatherParametersAndBuffers( gatherParametersAndBuffers(
state, trace_get_attr, script::Module(s.value.toObject()), qualname); state, trace_get_attr, Module(s.value.toObject()), qualname);
} }
} }
} }
@ -345,7 +345,7 @@ std::pair<std::shared_ptr<TracingState>, Stack> trace(
const std::function<Stack(Stack)>& traced_fn, const std::function<Stack(Stack)>& traced_fn,
std::function<std::string(const Variable&)> var_name_lookup_fn, std::function<std::string(const Variable&)> var_name_lookup_fn,
bool force_outplace, bool force_outplace,
script::Module* self) { Module* self) {
try { try {
// Start tracing, treating 'inputs' as inputs to the trace, which can be // Start tracing, treating 'inputs' as inputs to the trace, which can be
// varied on subsequent invocations of the trace. Any other variables // varied on subsequent invocations of the trace. Any other variables
@ -387,7 +387,7 @@ std::pair<std::shared_ptr<TracingState>, Stack> trace(
} }
setTracingState(nullptr); setTracingState(nullptr);
if (script::getInlineEverythingMode()) { if (getInlineEverythingMode()) {
Inline(*graph); Inline(*graph);
} }
FixupTraceScopeBlocks(graph, self); FixupTraceScopeBlocks(graph, self);


@ -21,10 +21,7 @@ namespace jit {
struct Node; struct Node;
struct Value; struct Value;
struct Graph; struct Graph;
struct Module;
namespace script {
struct Module;
}
namespace tracer { namespace tracer {
@ -210,7 +207,7 @@ TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> trace(
const std::function<Stack(Stack)>& traced_fn, const std::function<Stack(Stack)>& traced_fn,
std::function<std::string(const Variable&)> var_name_lookup_fn, std::function<std::string(const Variable&)> var_name_lookup_fn,
bool force_outplace = false, bool force_outplace = false,
script::Module* self = nullptr); Module* self = nullptr);
TORCH_API void abandon(); TORCH_API void abandon();


@ -11,7 +11,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
// Trees are used to represent all forms of TC IR, pre- and post-typechecking. // Trees are used to represent all forms of TC IR, pre- and post-typechecking.
// Rather than have a full class hierarchy for all TC statements, // Rather than have a full class hierarchy for all TC statements,
@ -221,6 +220,5 @@ static inline std::ostream& operator<<(std::ostream& out, const TreeRef& t) {
return out << pretty_tree(t); return out << pretty_tree(t);
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -10,7 +10,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
// clang-format off // clang-format off
// TreeView provides a statically-typed way to traverse the tree, which should // TreeView provides a statically-typed way to traverse the tree, which should
@ -813,7 +812,7 @@ struct Const : public Expr {
// We can't pass in nullptr as the dummy pointer gets dereferenced for // We can't pass in nullptr as the dummy pointer gets dereferenced for
// the Android version of strtod_c(). // the Android version of strtod_c().
char* dummy; char* dummy;
return torch::jit::script::strtod_c( return torch::jit::strtod_c(
subtree(0)->stringValue().c_str(), &dummy); subtree(0)->stringValue().c_str(), &dummy);
} }
const std::string& text() const { const std::string& text() const {
@ -1045,14 +1044,13 @@ struct Delete : public Stmt {
} }
}; };
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch
namespace std { namespace std {
template <typename T> template <typename T>
struct iterator_traits<torch::jit::script::ListIterator<T>> struct iterator_traits<torch::jit::ListIterator<T>>
: std::iterator_traits<torch::jit::script::TreeList::const_iterator> {}; : std::iterator_traits<torch::jit::TreeList::const_iterator> {};
} // namespace std } // namespace std


@ -963,7 +963,7 @@ const Operator& Node::getOperator() const {
if (maybe) if (maybe)
return *maybe; return *maybe;
auto er = script::ErrorReport(sourceRange()); auto er = ErrorReport(sourceRange());
er << "Schema not found for node. File a bug report.\n"; er << "Schema not found for node. File a bug report.\n";
er << "Node: " << *this << "\n"; er << "Node: " << *this << "\n";
er << "Input types:"; er << "Input types:";
@ -1491,7 +1491,7 @@ Value* Graph::insert(
at::ArrayRef<NamedValue> args, at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs, at::ArrayRef<NamedValue> kwargs,
const c10::optional<SourceRange>& range) { const c10::optional<SourceRange>& range) {
return script::emitBuiltinCall( return emitBuiltinCall(
range.value_or(fakeRange()), *this, opname, args, kwargs); range.value_or(fakeRange()), *this, opname, args, kwargs);
} }
@ -1721,7 +1721,7 @@ Value* Graph::insertToList(Value* v, TypePtr type) {
Value* Graph::insertFunctionCall( Value* Graph::insertFunctionCall(
Function* callee, Function* callee,
const script::MatchedSchema& matched) { const MatchedSchema& matched) {
std::string func_name = callee->name(); std::string func_name = callee->name();
Value* fn_constant = insertNode(create(prim::Constant)) Value* fn_constant = insertNode(create(prim::Constant))
->s_(attr::name, func_name) ->s_(attr::name, func_name)
@ -1737,7 +1737,7 @@ Value* Graph::insertFunctionCall(
Value* Graph::insertMethodCall( Value* Graph::insertMethodCall(
std::string method_name, std::string method_name,
const script::MatchedSchema& matched) { const MatchedSchema& matched) {
Value* result = insertNode(create(prim::CallMethod, matched.inputs)) Value* result = insertNode(create(prim::CallMethod, matched.inputs))
->s_(attr::name, std::move(method_name)) ->s_(attr::name, std::move(method_name))
->output() ->output()


@ -73,9 +73,7 @@ using namespace ::c10::aten;
} }
struct Function; struct Function;
namespace script {
struct MatchedSchema; struct MatchedSchema;
} // namespace script
// Graph represents one "function" of computation. // Graph represents one "function" of computation.
// It uses a simple ownership model where the graph owns all the nodes inside // It uses a simple ownership model where the graph owns all the nodes inside
@ -1134,10 +1132,10 @@ struct Graph {
TORCH_API Value* insertFunctionCall( TORCH_API Value* insertFunctionCall(
Function* callee, Function* callee,
const script::MatchedSchema& matched); const MatchedSchema& matched);
TORCH_API Value* insertMethodCall( TORCH_API Value* insertMethodCall(
std::string method_name, std::string method_name,
const script::MatchedSchema& matched); const MatchedSchema& matched);
// Note: defined in python_ir.cpp and can be used only in python extension // Note: defined in python_ir.cpp and can be used only in python extension
Node* createPythonOp( Node* createPythonOp(
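
A hedged sketch of these insertion helpers in use; the include path and the default Tensor typing of fresh graph inputs are assumptions:

#include <torch/script.h>  // assumed umbrella include; exact header may differ

std::shared_ptr<torch::jit::Graph> make_add_graph() {
  auto graph = std::make_shared<torch::jit::Graph>();
  torch::jit::Value* x = graph->addInput();
  torch::jit::Value* y = graph->addInput();
  // insert() routes through emitBuiltinCall (no script:: qualifier anymore),
  // matching aten::add's schema and filling in the default alpha argument.
  torch::jit::Value* sum =
      graph->insert(c10::Symbol::fromQualString("aten::add"), {x, y});
  graph->registerOutput(sum);
  return graph;
}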


@ -9,7 +9,6 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace script {
struct VarWithType; struct VarWithType;
struct ParsedLiteral; struct ParsedLiteral;
@ -57,7 +56,7 @@ class IRParser {
Value* findValueInVMap(const std::string& name); Value* findValueInVMap(const std::string& name);
torch::jit::script::Lexer L; torch::jit::Lexer L;
torch::jit::Graph* g = nullptr; torch::jit::Graph* g = nullptr;
std::unordered_map<std::string, Value*>& vmap; std::unordered_map<std::string, Value*>& vmap;
SchemaTypeParser type_parser; SchemaTypeParser type_parser;
@ -86,7 +85,7 @@ void parseIR(
const std::string& str, const std::string& str,
torch::jit::Graph* graph, torch::jit::Graph* graph,
std::unordered_map<std::string, Value*>& vmap) { std::unordered_map<std::string, Value*>& vmap) {
torch::jit::script::IRParser p(str, graph, vmap); torch::jit::IRParser p(str, graph, vmap);
p.parse(); p.parse();
} }
@ -473,6 +472,5 @@ Value* IRParser::findValueInVMap(const std::string& name) {
return vmap.at(name); return vmap.at(name);
} }
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch


@ -12,8 +12,6 @@ namespace jit {
struct Graph; struct Graph;
struct Value; struct Value;
namespace script {
// \brief Parse IR from \p STR, constructing the corresponding IR in \p GRAPH. // \brief Parse IR from \p STR, constructing the corresponding IR in \p GRAPH.
TORCH_API void parseIR(const std::string& str, torch::jit::Graph* graph); TORCH_API void parseIR(const std::string& str, torch::jit::Graph* graph);
@ -27,6 +25,5 @@ TORCH_API void parseIR(
torch::jit::Graph* graph, torch::jit::Graph* graph,
std::unordered_map<std::string, Value*>& vmap); std::unordered_map<std::string, Value*>& vmap);
} // namespace script
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch
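
One illustrative use of the two-argument overload declared above (the include path is an assumption; the IR text uses the same syntax the JIT prints for graphs):

#include <torch/csrc/jit/ir/irparser.h>  // assumed path; may differ by version

int main() {
  auto graph = std::make_shared<torch::jit::Graph>();
  // Build %c = %a * %b directly from textual IR.
  torch::jit::parseIR(R"IR(
graph(%a : Tensor, %b : Tensor):
  %c : Tensor = aten::mul(%a, %b)
  return (%c)
)IR", graph.get());
}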


@ -125,14 +125,14 @@ class BytecodeDeserializer final {
private: private:
c10::IValue readArchive(const std::string& archive_name, c10::IValue readArchive(const std::string& archive_name,
std::shared_ptr<mobile::CompilationUnit> mcu); std::shared_ptr<mobile::CompilationUnit> mcu);
std::shared_ptr<script::CompilationUnit> compilation_unit_; std::shared_ptr<CompilationUnit> compilation_unit_;
std::unordered_set<std::string> imported_libs_; std::unordered_set<std::string> imported_libs_;
std::unique_ptr<PyTorchStreamReader> reader_; std::unique_ptr<PyTorchStreamReader> reader_;
c10::optional<at::Device> device_; c10::optional<at::Device> device_;
}; };
BytecodeDeserializer::BytecodeDeserializer(std::unique_ptr<PyTorchStreamReader> reader) BytecodeDeserializer::BytecodeDeserializer(std::unique_ptr<PyTorchStreamReader> reader)
: compilation_unit_(std::make_shared<script::CompilationUnit>()), reader_(std::move(reader)) {} : compilation_unit_(std::make_shared<CompilationUnit>()), reader_(std::move(reader)) {}
mobile::Module BytecodeDeserializer::deserialize(c10::optional<at::Device> device) { mobile::Module BytecodeDeserializer::deserialize(c10::optional<at::Device> device) {
device_ = device; device_ = device;

Some files were not shown because too many files have changed in this diff.