#include <test/cpp/jit/test_utils.h>

#include <gtest/gtest.h>

#include <c10/core/TensorOptions.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/mobile/compatibility/backport.h>
#include <torch/csrc/jit/mobile/compatibility/backport_manager.h>
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
#include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/custom_class.h>
#include <torch/torch.h>

#include <unordered_set>

// Tests go in torch::jit
namespace torch {
namespace jit {

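// Most tests in this file follow the same round-trip pattern: build a JIT
// Module, lower it to a mobile::Module with jitModuleToMobile(), run both on
// the same inputs, and check that the outputs agree. A minimal sketch of the
// pattern (the module and inputs here are illustrative, not from any
// particular test below):
//
//   Module m("m");
//   m.define(R"(
//     def forward(self, x):
//       return x + 1
//   )");
//   CompilationOptions options;
//   mobile::Module bc = jitModuleToMobile(m, options);
//   std::vector<IValue> inputs({torch::ones({})});
//   AT_ASSERT(bc.forward(inputs).toTensor().equal(
//       m.forward(inputs).toTensor()));
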
TEST(LiteInterpreterDirectTest, UpsampleNearest2d) {
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
  inputs.emplace_back(at::Scalar(2.0));
  auto ref = m.forward(inputs);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  res = bc.forward(inputs);

  auto resd = res.toTensor();
  auto refd = ref.toTensor();
  ASSERT_TRUE(resd.equal(refd));
}

TEST(LiteInterpreterDirectTest, CheckAttrAccess) {
  Module m("m");
  m.register_attribute("mobile_optimized", BoolType::get(), true);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  bool mobile_optimized = bc.attr("mobile_optimized", false).toBool();

  AT_ASSERT(mobile_optimized);
  m.setattr("mobile_optimized", false);
  bc = jitModuleToMobile(m, options);
  mobile_optimized = bc.attr("mobile_optimized", false).toBool();
  AT_ASSERT(!mobile_optimized);
}

TEST(
    LiteInterpreterDirectTest,
    MethodInvocation) { // NOLINT (use =delete in gtest)
  const std::vector<std::string> test_programs{
      // test invoking a method with default parameter
      R"(
      def test_func(self, x, b : int = 4):
        return self.foo + x + b
      )",
      // inner method call with default parameter (gets inlined)
      R"(
      def add_with_default_arg(self, x, b : int = 4):
        return self.foo + x + b
      def test_func(self, x):
        return self.add_with_default_arg(x)  # invoke method w/ default arg
      )",
      // simple method call
      R"(
      def test_func(self, x):
        b = 4
        return self.foo + x + b
      )",
  };
  for (const auto& test_program : test_programs) {
    Module m("m");
    m.register_parameter("foo", torch::ones({}), false);
    m.define(test_program);

    const int fortyTwo = 42; // (keep linter happy)
    auto minput = fortyTwo * torch::ones({});
    auto ref = m.run_method("test_func", minput);

    CompilationOptions options;
    mobile::Module bc = jitModuleToMobile(m, options);
    const auto& test_func = bc.get_method("test_func");
    IValue res;
    for (int i = 0; i < 3; ++i) {
      res = test_func({minput});
    }

    auto resd = res.toTensor().item<float>();
    auto refd = ref.toTensor().item<float>();
    AT_ASSERT(resd == refd);
  }
}

TEST(LiteInterpreterDirectTest, Conv) {
  auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (s && strcmp(s, "1") == 0)
    return;

  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
  )");

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
  inputs.push_back(torch::ones({1, 1, 28, 28}));

  auto outputref = m.forward(inputs).toTensor();

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.get_method("forward")(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
}

TEST(LiteInterpreterDirectTest, Inline) {
  Module m("m");
  m.define(R"JIT(
  def foo1(self, x):
      return x + 1

  def foo2(self, x):
      return self.foo1(x) + 2

  def foo3(self, x):
      return self.foo2(x) + 3
  )JIT");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("foo3")(inputs);
  AT_ASSERT(output.toTensor().item<float>() == 7.0);
}

TEST(LiteInterpreterDirectTest, Tuple) {
  Module m("m");
  m.define(R"JIT(
  def foo(self, x):
      return (1, 2, x + 3)

  def forward(self, x):
      tuple = self.foo(x)
      return tuple
  )JIT");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("forward")(inputs);
  AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2);
}

TEST(LiteInterpreterDirectTest, Dict) {
  Module m("m");
  m.define(R"JIT(
  def foo(self, x):
      return {"result": x + 1}

  def forward(self, x):
      d = self.foo(x)
      return d
  )JIT");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("forward")(inputs);
  AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2);
}

TEST(LiteInterpreterDirectTest, Prim) {
  Module m("m");
  m.define(R"JIT(
        def forward(self, x):
            return int(x)
  )JIT");

  std::vector<IValue> inputs;
  auto minput = 3.5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.run_method("forward", minput);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);

  IValue res;
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc.get_method("forward")(bcinputs);
  }

  auto resi = res.toInt();
  auto refi = ref.toInt();
  AT_ASSERT(resi == refi);
}

TEST(LiteInterpreterDirectTest, PrimScalar) {
  Module m("m");
  m.define(R"JIT(
        def forward(self, x):
            return int(x.item())
  )JIT");

  std::vector<IValue> inputs;
  auto minput = 3.5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.run_method("forward", minput);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc.get_method("forward")(bcinputs);
  }

  auto resi = res.toInt();
  auto refi = ref.toInt();
  AT_ASSERT(resi == refi);
}

TEST(LiteInterpreterDirectTest, WrongMethodName) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add(self, x):
      b = 4
      return self.foo + x + b
  )");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);
  ASSERT_THROWS_WITH_MESSAGE(
      bc.get_method("forward")(inputs), "is not defined");
}

TEST(LiteInterpreterDirectTest, SetState) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def __getstate__(self):
      return self.foo
    def __setstate__(self, a):
      self.foo = a
    def forward(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);

  std::stringstream ms;
  m.save(ms);
  auto loaded_m = load(ms);
  auto ref = loaded_m.run_method("forward", minput);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc.get_method("forward")(bcinputs);
  }

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);
}

class TorchBindLiteInterpreterDirectTestStruct
    : public torch::jit::CustomClassHolder {
 public:
  std::string get(at::Tensor t) {
    std::stringstream ss;
    ss << "Hello! Your tensor has ";
    ss << t.numel();
    ss << " elements!";
    return ss.str();
  }
};

namespace {
struct ClassNamespaceValue : public SugaredValue {
  explicit ClassNamespaceValue(c10::QualifiedName name)
      : basename_(std::move(name)) {}

  std::shared_ptr<SugaredValue> attr(
      const SourceRange&,
      GraphFunction&,
      const std::string& name) override {
    const auto fullName = c10::QualifiedName(basename_, name);

    // Check to see if it is a custom class.
    if (auto custom_class = getCustomClass(fullName.qualifiedName())) {
      return std::make_shared<ClassValue>(custom_class);
    }

    // If it's not a custom class, assume it's another namespace
    // NOLINTNEXTLINE(performance-move-const-arg)
    return std::make_shared<ClassNamespaceValue>(fullName);
  }

  std::string kind() const override {
    return "Class Namespace";
  }

 private:
  c10::QualifiedName basename_;
};

struct TestModuleResolver : public Resolver {
  std::shared_ptr<SugaredValue> resolveValue(
      const std::string& name,
      GraphFunction&,
      const SourceRange&) override {
    if (name == "torch") {
      return std::make_shared<BuiltinModule>("aten");
    } else if (name == "__torch__") {
      return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name));
    }

    return nullptr;
  }

  TypePtr resolveType(const std::string&, const SourceRange&) override {
    return nullptr;
  }
};
} // namespace

TEST(LiteInterpreterDirectTest, BuiltinFunction) {
  script::Module m("m");
  auto custom_class_obj =
      make_custom_class<TorchBindLiteInterpreterDirectTestStruct>();
  m.register_attribute("my_obj", custom_class_obj.type(), custom_class_obj);
  m.define(R"(
    def forward(self, x) -> str:
      return self.my_obj.get(x)
  )");

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  auto res =
      bc.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
  // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
  auto str = res.toStringRef();
  std::string expected = "Hello! Your tensor has 12 elements!";
  AT_ASSERT(str == expected);
}

#if !defined FB_XPLAT_BUILD
TEST(LiteInterpreterDirectTest, GetRuntimeByteCodeVersion) {
  auto runtime_bytecode_version = _get_runtime_bytecode_version();
  AT_ASSERT(
      runtime_bytecode_version ==
      caffe2::serialize::kMaxSupportedBytecodeVersion);
}

TEST(LiteInterpreterDirectTest, GetRuntimeOperatorsVersion) {
  auto runtime_operators_version = _get_runtime_operators_min_max_versions();
  AT_ASSERT(
      runtime_operators_version.first ==
          caffe2::serialize::kMinSupportedFileFormatVersion &&
      runtime_operators_version.second ==
          caffe2::serialize::kMaxSupportedFileFormatVersion);
}

/**
 * The test below is disarmed for FB internal xplat builds since
 * BUCK requires us to pass the script_module_v4.ptl file in as a
 * resource dependency of the build rule for this file, and we would
 * need to access it via the C++ Resources API instead of directly
 * reading from disk (which is what the open source build/run does).
 */
TEST(LiteInterpreterDirectTest, GetByteCodeVersion) {
  std::string filePath(__FILE__);
  auto test_model_file_v4 =
      filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file_v4.append("script_module_v4.ptl");

  auto version_v4 = _get_model_bytecode_version(test_model_file_v4);
  AT_ASSERT(version_v4 == 4);
}

#endif // !defined(FB_XPLAT_BUILD)

TEST(LiteInterpreterDirectTest, GetRuntimeOpsAndInfo) {
  auto runtime_ops = _get_runtime_ops_and_info();
  // Ballpark estimate of the minimal number of ops; just used to
  // verify API returns a reasonably large number.
  AT_ASSERT(runtime_ops.size() > 2900);
}

TEST(LiteInterpreterDirectTest, Eval) {
  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.define(R"(
    def __init__(self, x):
      self.training = True

    def forward(self, input):
      return torch.dropout(input, 1.0, self.training)
  )");

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
  inputs.push_back(torch::ones({1, 1, 28, 28}));
  m.eval();
  auto outputref = m.forward(inputs).toTensor();

  // save m in training mode to make sure that mobile eval() will correctly
  // change back to eval mode
  m.train();
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  bc.eval();
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.get_method("forward")(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
}

TEST(LiteInterpreterDirectTest, FindWrongMethodName) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add(self, x):
      b = 4
      return self.foo + x + b
  )");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  ASSERT_TRUE(bc.find_method("forward") == std::nullopt);
}

TEST(LiteInterpreterDirectTest, FindAndRunMethod) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add_it(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.get_method("add_it")(inputs);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    auto bcinputs = inputs;
    auto method = bc.find_method("add_it");
    AT_ASSERT(method != std::nullopt);
    res = (*method)(std::move(bcinputs));
  }

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);
}

						|
TEST(LiteInterpreterDirectTest, RunMethodVariadic) {
 | 
						|
  Module m("m");
 | 
						|
  m.register_parameter("foo", torch::ones({}), false);
 | 
						|
  m.define(R"(
 | 
						|
    def add_three(self, x, y):
 | 
						|
      return self.foo + x + y
 | 
						|
  )");
 | 
						|
 | 
						|
  std::vector<IValue> inputs;
 | 
						|
  auto inputx = 5 * torch::ones({});
 | 
						|
  auto inputy = 4 * torch::ones({});
 | 
						|
  auto ref = m.run_method("add_three", inputx, inputy);
 | 
						|
 | 
						|
  CompilationOptions options;
 | 
						|
  mobile::Module bc = jitModuleToMobile(m, options);
 | 
						|
  IValue res = bc.run_method("add_three", inputx, inputy);
 | 
						|
 | 
						|
  auto resd = res.toTensor().item<float>();
 | 
						|
  auto refd = ref.toTensor().item<float>();
 | 
						|
  AT_ASSERT(resd == refd);
 | 
						|
}
 | 
						|
 | 
						|
TEST(LiteInterpreterDirectTest, DuplicateSetState) {
 | 
						|
  Module m("M");
 | 
						|
  m.register_parameter("foo", torch::ones({}), false);
 | 
						|
  m.define(R"(
 | 
						|
    def __getstate__(self):
 | 
						|
      return self.foo + self.foo
 | 
						|
    def __setstate__(self, a):
 | 
						|
      self.foo = a
 | 
						|
    def forward(self, x):
 | 
						|
      b = 4
 | 
						|
      return self.foo + x + b
 | 
						|
  )");
 | 
						|
 | 
						|
  Module b("B");
 | 
						|
  b.register_module("M0", m);
 | 
						|
  b.register_module("M1", m);
 | 
						|
  b.define(R"(
 | 
						|
    def forward(self, x):
 | 
						|
      return self.M0.forward(x) + self.M1.forward(x)
 | 
						|
  )");
 | 
						|
 | 
						|
  CompilationOptions options;
 | 
						|
  mobile::Module bc = jitModuleToMobile(m, options);
 | 
						|
  const auto methods = bc.get_methods();
 | 
						|
  const size_t expected_n = 3;
 | 
						|
  ASSERT_EQ(methods.size(), expected_n);
 | 
						|
}
 | 
						|
 | 
						|
TEST(LiteInterpreterDirectTest, OpNameExportFetchRootOperators) {
  torch::jit::Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      x1 = torch.zeros(2, 2)
      x2 = torch.empty_like(torch.empty(2, 2))
      x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
      return (x1, x2, x3)
  )");
  m.eval();

  CompilationOptions options;
  mobile::Module ptl_model = jitModuleToMobile(m, options);
  std::set<std::string> operator_names =
      torch::jit::mobile::_export_operator_list(ptl_model);
  std::set<std::string> expected_operator_names = {
      "aten::_convolution",
      "aten::empty.memory_format",
      "aten::empty_like",
      "aten::zeros",
  };
  EXPECT_EQ(operator_names, expected_operator_names)
      << "Expected the root operator lists to be the same";
}

TEST(LiteInterpreterDirectTest, DefaultArgsConv) {
  auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (s && strcmp(s, "1") == 0)
    return;

  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1)
  )");

  inputs.emplace_back(torch::ones({1, 1, 28, 28}));

  auto outputref = m.forward(inputs).toTensor();

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 1; ++i) {
    res = bc.get_method("forward")(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(output.equal(outputref));
}

namespace {
void testLiteModuleCompareResultTensors(
    Module& m,
    const std::vector<torch::jit::IValue>& inputs,
    const std::string& method_name = "forward") {
  auto outputref = m.get_method(method_name)(inputs).toTensor();

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.get_method(method_name)(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(output.equal(outputref));
}

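// The helpers below exercise default-argument handling: mobile bytecode
// records how many arguments were actually specified at the call site (see
// the ('aten::linalg_pinv', '', N) operator entries in the bytecode dumps
// further down), and unspecified trailing arguments are expected to be
// filled in from the operator schema when the op is resolved. Each branch
// pins a different number of specified arguments for the same op.
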
void testDefaultArgsPinv2(int num_args) {
  Module m("m");
  if (num_args == 1) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input)
    )");
  } else if (num_args == 2) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5)
    )");
  } else if (num_args == 3) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5, True)
    )");
  }

  std::vector<torch::jit::IValue> inputs;
  const int N = 28;
  auto input = torch::range(1, N * N, 1);
  input[0] = 1; // a more stable matrix
  input = input.view({N, N});
  inputs.emplace_back(input);
  testLiteModuleCompareResultTensors(m, inputs);
}
} // namespace

#if !defined FB_XPLAT_BUILD
TEST(LiteInterpreterDirectTest, DefaultArgsPinv) {
  // Test with different number of specified arguments.
  // Arguments not specified take default value.
  for (int num_args = 1; num_args <= 3; ++num_args) {
    testDefaultArgsPinv2(num_args);
  }

  //  bytecode with one specified argument:
  //  (6,
  //      ('__torch__.m.forward',
  //          (('instructions',
  //              (('STOREN', 1, 2),
  //                  ('DROPR', 1, 0),
  //                  ('MOVE', 2, 0),
  //                  ('OP', 0, 0),
  //                  ('RET', 0, 0))),
  //              ('operators', (('aten::linalg_pinv', '', 1),)),
  //              ('constants', (False, 1e-15)), # default constants are not used
  //              ('types', ()),
  //              ('register_size', 2)),
  //          (('arguments',
  //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
  //              None)),
  //                  (('name', 'input'), ('type', 'Tensor'), ('default_value',
  //                  None)))),
  //              ('returns',
  //                  ((('name', ''), ('type', 'Tensor'), ('default_value',
  //                  None)),)))))

  //  bytecode with 2 specified arguments:
  //  (6,
  //      ('__torch__.m.forward',
  //          (('instructions',
  //              (('STOREN', 1, 2),
  //                  ('DROPR', 1, 0),
  //                  ('MOVE', 2, 0),
  //                  ('LOADC', 1, 0), # added LOADC for specified argument
  //                  ('OP', 0, 0),
  //                  ('RET', 0, 0))),
  //              ('operators', (('aten::linalg_pinv', '', 2),)),
  //              ('constants', (False, 1e-05)), # updated constant table
  //              ('types', ()),
  //              ('register_size', 2)),
  //          (('arguments',
  //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
  //              None)),
  //                  (('name', 'input'), ('type', 'Tensor'), ('default_value',
  //                  None)))),
  //              ('returns',
  //                  ((('name', ''), ('type', 'Tensor'), ('default_value',
  //                  None)),)))))

  //  bytecode with 3 specified arguments:
  //  (6,
  //      ('__torch__.m.forward',
  //          (('instructions',
  //              (('STOREN', 1, 2),
  //                  ('DROPR', 1, 0),
  //                  ('MOVE', 2, 0),
  //                  ('LOADC', 1, 0),
  //                  ('LOADC', 0, 0),
  //                  ('OP', 0, 0),
  //                  ('RET', 0, 0))),
  //              ('operators', (('aten::linalg_pinv', '', 3),)),
  //              ('constants', (True, 1e-05)),
  //              ('types', ()),
  //              ('register_size', 2)),
  //          (('arguments',
  //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
  //              None)),
  //                  (('name', 'input'), ('type', 'Tensor'), ('default_value',
  //                  None)))),
  //              ('returns',
  //                  ((('name', ''), ('type', 'Tensor'), ('default_value',
  //                  None)),)))))
}

TEST(LiteInterpreterDirectTest, DefaultArgsTensorinvSpecifyDefault) {
  // The second argument is specified, but the value is the same as the default
  // value. It's treated as "not specified" since the value can be fetched from
  // schema.
  Module m("m");
  m.define(R"(
    def forward(self, input):
      return torch.linalg_tensorinv(input, 2)
  )");
  torch::jit::MobileCode code(m.get_method("forward").graph(), "forward");
  auto arg_nums = code.op_to_num_specified_args();
  ASSERT_EQ(arg_nums.size(), 1);
  ASSERT_EQ(arg_nums["aten::linalg_tensorinv"], 1);
  std::vector<torch::jit::IValue> inputs;
  const int N = 4;
  auto input = torch::rand({N, N, N, N});
  inputs.emplace_back(input);
  testLiteModuleCompareResultTensors(m, inputs);
}

void testDefaultArgsPinvWithOutArg2(int num_args) {
  Module m("m");
  if (num_args == 1) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, out=input)
    )");
  } else if (num_args == 2) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5, out=input)
    )");
  } else if (num_args == 3) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5, True, out=input)
    )");
  }

  const int N = 28;
  auto input = torch::range(1, N * N, 1);
  input[0] = 10000; // a more stable matrix
  input = input.view({N, N});
  auto ref = m.run_method("forward", input);
  TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
  TORCH_CHECK(input.equal(ref.toTensor()));
}

TEST(LiteInterpreterDirectTest, DefaultArgsPinvWithOutArg) {
  // Test with different number of specified arguments + out arg.
  // Arguments not specified take default value.
  for (int num_args = 1; num_args <= 3; ++num_args) {
    testDefaultArgsPinvWithOutArg2(num_args);
  }
}

TEST(LiteInterpreterDirectTest, DefaultArgsWithOutArg) {
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
      torch.add(x, h, out=x)
  )");

  std::vector<IValue> inputs;
  auto input_x = 2 * torch::ones({});
  auto input_h = torch::ones({});
  auto ref = m.run_method("forward", input_x, input_h);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  bc.run_method("forward", input_x, input_h);
  AT_ASSERT(input_x.equal(4 * torch::ones({})));
}

TEST(LiteInterpreterDirectTest, TestExceptionStackWithTwoLevelModuleHierarchy) {
  Module a("A");
  a.define(R"(
    def bar(self, x, y):
      return x + y
  )");
  Module b("B");
  b.register_module("A0", a);
  b.define(R"(
    def foo(self, x, y):
      return self.A0.bar(x, y) + 2
  )");
  Module c("C");
  c.register_module("B0", b);
  c.define(R"(
    def forward(self, x, y):
      return self.B0.foo(x, y) + 3
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  CompilationOptions options;
  auto lite_m = jitModuleToMobile(c, options);
  std::string error_pattern = R"(
  Module hierarchy:top(C)::<unknown>.B0(B)::foo.A0(A)::bar.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

    def forward(self, x, y):
      return self.B0.foo(x, y) + 3
             ~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in foo

    def foo(self, x, y):
      return self.A0.bar(x, y) + 2
             ~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in bar

    def bar(self, x, y):
      return x + y
             ~~~~~ <--- HERE
  )";
  ASSERT_THROWS_WITH_MESSAGE(lite_m.forward(inputs), error_pattern);
}
#endif // !defined(FB_XPLAT_BUILD)

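// Register the custom class used by the BuiltinFunction test above. The
// def_pickle() lambdas provide __getstate__/__setstate__ so that module
// attributes of this type can be serialized when the module is exported;
// the struct is stateless, so a dummy int64_t stands in for its state.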
namespace {
static auto reg =
    torch::class_<TorchBindLiteInterpreterDirectTestStruct>(
        "_TorchScriptTesting",
        "_LiteInterpreterDirectTest")
        .def(torch::init<>())
        .def("get", &TorchBindLiteInterpreterDirectTestStruct::get)
        .def_pickle(
            // __getstate__
            [](const c10::intrusive_ptr<
                TorchBindLiteInterpreterDirectTestStruct>&) -> int64_t {
              return 0;
            },
            // __setstate__
            [](int64_t) {
              return c10::make_intrusive<
                  TorchBindLiteInterpreterDirectTestStruct>();
            });

} // namespace

TEST(LiteInterpreterDirectTest, OperatorCacheDifferentiatesDefaultArgs) {
  // Create 3 methods:
  //
  // 1. forward() returns a tensor with dtype=torch.int64 (4)
  // 2. forward2() returns a tensor with dtype=torch.float32 (6)
  // 3. forward3() returns a tensor with dtype=torch.float32 but
  //    the dtype is inferred by the input tensor's dtype
  //
  // If caching works correctly, then the result from the full-jit
  // module and the lite module will be the same. Otherwise, it
  // will be different if we don't correctly ignore the cache
  // entry for an operator that has a different number of
  // arguments.
  Module m("m");
  m.define(R"(
    def forward(self):
      ret1 = torch.new_empty(torch.zeros(10), [10], dtype=4)
      return ret1.fill_(25)
  )");
  m.define(R"(
    def forward2(self):
      ret1 = torch.new_empty(torch.zeros(10), [10], dtype=6)
      return ret1.fill_(32.0)
  )");
  m.define(R"(
    def forward3(self):
      ret1 = torch.new_empty(torch.zeros(10), [10])
      return ret1.fill_(12.0)
  )");

  std::vector<torch::jit::IValue> inputs;
  testLiteModuleCompareResultTensors(m, inputs, "forward");
  testLiteModuleCompareResultTensors(m, inputs, "forward2");
  testLiteModuleCompareResultTensors(m, inputs, "forward3");
}

} // namespace jit
} // namespace torch