// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "test_utils.h"

#include <ATen/core/ivalue.h>
#include <gtest/gtest.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/memory_planner.h>
#include <torch/csrc/jit/runtime/static/passes.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/allclose.h>
#endif

#include <memory>
#include <unordered_map>

using namespace torch::jit;
using namespace torch;
using c10::IValue;

namespace torch {
namespace jit {
namespace test {

namespace {

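// Minimal callable wrapper around GraphExecutor. Runs the wrapped graph on
// the given arguments and returns a single IValue; multiple outputs are
// packed into a tuple.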
class GraphExecutorWrapper {
 public:
  GraphExecutorWrapper() = default;

  explicit GraphExecutorWrapper(const std::shared_ptr<Graph>& graph)
      : graph_exec_(graph, "") {}

  c10::IValue operator()(const std::vector<c10::IValue>& args) {
    Stack stack(args);
    graph_exec_.run(stack);

    if (stack.size() == 1) {
      return stack[0];
    }
    return c10::ivalue::Tuple::create(stack);
  }

 private:
  GraphExecutor graph_exec_;
};

// Test scripts passed to testStaticRuntime can either be IR or JIT.
// The logic for running the script and producing a corresponding StaticModule
// is a bit different for each case. This logic is encapsulated within concrete
// implementations of this class, and testStaticRuntime is only aware of this
// interface.
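//
// For example (illustrative snippets only, not present in this file):
//
//   TorchScript source, handled by ModuleStaticRuntimeTestContext:
//     def forward(self, a, b):
//         return (a + b).clone()
//
//   IR source, handled by GraphStaticRuntimeContext:
//     graph(%a : Tensor, %b : Tensor):
//       %c : Tensor = aten::mul(%a, %b)
//       return (%c)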
class StaticRuntimeTestContext {
 public:
  virtual ~StaticRuntimeTestContext() = default;

  virtual IValue getExpected(const std::vector<IValue>& args) = 0;
  virtual StaticModule makeStaticModule(
      const StaticModuleOptions& opt) const = 0;
};

class ModuleStaticRuntimeTestContext : public StaticRuntimeTestContext {
 public:
  explicit ModuleStaticRuntimeTestContext(const std::string& source_jit)
      : module_("module") {
    module_.define(source_jit);
  }

  IValue getExpected(const std::vector<IValue>& args) override {
    return module_.forward(args);
  }

  StaticModule makeStaticModule(const StaticModuleOptions& opt) const override {
    return torch::jit::StaticModule(
        module_, /* is_frozen */ false, opt, /* sample_inputs */ {});
  }

 private:
  Module module_;
};

class GraphStaticRuntimeContext : public StaticRuntimeTestContext {
 public:
  explicit GraphStaticRuntimeContext(const std::string& source_ir) {
    graph_ = std::make_shared<Graph>();
    std::unordered_map<std::string, Value*> vmap;
    parseIR(source_ir, graph_.get(), vmap);

    graph_exec_ = GraphExecutorWrapper(graph_);
  }

  IValue getExpected(const std::vector<IValue>& args) override {
    return graph_exec_(args);
  }

  StaticModule makeStaticModule(const StaticModuleOptions& opt) const override {
    return StaticModule(graph_, opt, /* sample_inputs */ {});
  }

 private:
  std::shared_ptr<Graph> graph_;
  GraphExecutorWrapper graph_exec_;
};

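// Builds the appropriate test context for the given source: first tries to
// compile it as TorchScript; if that fails, falls back to parsing it as IR.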
std::unique_ptr<StaticRuntimeTestContext> makeTestContext(
    const std::string& source) {
  try {
    return std::make_unique<ModuleStaticRuntimeTestContext>(source);
  } catch (const std::runtime_error&) {
    // Could not parse as TorchScript; assume it's IR.
    return std::make_unique<GraphStaticRuntimeContext>(source);
  }
}

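// Element-wise comparison of two lists of tensors. With use_allclose, values
// are compared approximately (rtol=1e-5, atol=1e-8); otherwise they must be
// exactly equal via Tensor::equal.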
void compareTensorLists(
    const std::vector<IValue>& l, /* expects */
    const std::vector<IValue>& r, /* values */
    const bool use_allclose,
    const bool use_equalnan) {
  // ASSERT (not EXPECT) on the sizes: indexing r[i] below would be out of
  // bounds if the lists had different lengths.
  ASSERT_EQ(l.size(), r.size());
  for (auto i : c10::irange(l.size())) {
    ASSERT_TRUE(l[i].isTensor());
    ASSERT_TRUE(r[i].isTensor());
    VLOG(2) << "expect " << i << ": \n" << l[i] << std::endl;
    VLOG(2) << "output " << i << ": \n" << r[i] << std::endl;
    if (!l[i].toTensor().defined()) {
      EXPECT_TRUE(!r[i].toTensor().defined());
    } else {
      if (use_allclose) {
        EXPECT_TRUE(at::allclose(
            l[i].toTensor(),
            r[i].toTensor(),
            /*rtol*/ 1e-05,
            /*atol*/ 1e-08,
            use_equalnan));
      } else {
        EXPECT_TRUE(l[i].toTensor().equal(r[i].toTensor()));
      }
    }
  }
}

} // namespace

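// Recursively compares two IValues: tensors are compared directly, while
// tuples, lists, and dicts are compared element-wise. Any other type falls
// back to IValue::operator==.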
void compareResults(
    const IValue& expect,
    const IValue& actual,
    const bool use_allclose,
    const bool use_equalnan) {
  if (expect.isTensor()) {
    VLOG(2) << "expect " << expect.toTensor() << std::endl;
    VLOG(2) << "output " << actual.toTensor() << std::endl;
    EXPECT_TRUE(actual.isTensor());
    if (use_allclose) {
      EXPECT_TRUE(at::allclose(
          expect.toTensor(),
          actual.toTensor(),
          /*rtol*/ 1e-05,
          /*atol*/ 1e-08,
          use_equalnan));
    } else {
      EXPECT_TRUE(expect.toTensor().equal(actual.toTensor()));
    }
    return;
  } else if (expect.isTuple()) {
    EXPECT_TRUE(actual.isTuple());
    auto lhs = expect.toTupleRef().elements();
    auto rhs = actual.toTupleRef().elements();
    ASSERT_EQ(lhs.size(), rhs.size());
    for (size_t i = 0; i < lhs.size(); i++) {
      // Propagate the comparison flags into nested elements.
      compareResults(lhs[i], rhs[i], use_allclose, use_equalnan);
    }
  } else if (expect.isList()) {
    EXPECT_TRUE(actual.isList());
    auto lhs = expect.toList();
    auto rhs = actual.toList();
    ASSERT_EQ(lhs.size(), rhs.size());
    for (size_t i = 0; i < lhs.size(); i++) {
      compareResults(lhs[i], rhs[i], use_allclose, use_equalnan);
    }
  } else if (expect.isGenericDict()) {
    EXPECT_TRUE(actual.isGenericDict());
    auto lhs = expect.toGenericDict();
    auto rhs = actual.toGenericDict();
    EXPECT_EQ(lhs.size(), rhs.size());
    for (auto& lh : lhs) {
      auto f = rhs.find(lh.key());
      ASSERT_FALSE(f == rhs.end());
      compareResults(lh.value(), f->value(), use_allclose, use_equalnan);
    }
  } else {
    // Fall back to the default comparison impl in IValue.
    EXPECT_TRUE(expect == actual);
  }
}

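// Extracts the single tensor held by the IValue: either a plain tensor, a
// one-element tensor list, or a one-element tuple. Throws otherwise.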
at::Tensor getTensor(const at::IValue& ival) {
  if (ival.isTensor()) {
    return ival.toTensor();
  } else if (ival.isTensorList()) {
    auto tensor_vec = ival.toTensorVector();
    TORCH_CHECK(tensor_vec.size() == 1);
    return tensor_vec[0];
  } else if (ival.isTuple()) {
    auto tuple = ival.toTuple();
    auto ivalue_vec = tuple->elements();
    TORCH_CHECK(ivalue_vec.size() == 1);
    return ivalue_vec[0].toTensor();
  } else {
    CAFFE_THROW("Unknown input IValue");
  }
}

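// Locates the first node with the given kind (e.g. "aten::add") in a
// StaticModule or Graph; returns nullptr if no such node exists.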
Node* getNodeWithKind(const StaticModule& smodule, const std::string& kind) {
  return smodule.findNodeWithKindForTesting(kind);
}

Node* getNodeWithKind(std::shared_ptr<Graph>& graph, const std::string& kind) {
  const auto symbol = c10::Symbol::fromQualString(kind);
  DepthFirstGraphNodeIterator it(graph);
  for (auto* node = it.next(); node != nullptr; node = it.next()) {
    if (node->kind() == symbol) {
      return node;
    }
  }
  return nullptr;
}

bool hasNodeWithKind(const StaticModule& smodule, const std::string& kind) {
  return getNodeWithKind(smodule, kind) != nullptr;
}

bool hasNodeWithKind(std::shared_ptr<Graph>& graph, const std::string& kind) {
  return getNodeWithKind(graph, kind) != nullptr;
}

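// Compiles the TorchScript source into a throwaway module and returns the
// graph of its forward method.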
std::shared_ptr<Graph> getGraphFromScript(const std::string& jit_script) {
  script::Module module("module");
  module.define(jit_script);

  return module.get_method("forward").graph();
}

std::shared_ptr<Graph> getGraphFromIR(const std::string& ir) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);
  return graph;
}

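// Runs the graph through the reference JIT executor and through the static
// runtime, then checks that the two results match and that the static runtime
// did not leak memory.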
void compareResultsWithJIT(
    StaticRuntime& runtime,
    const std::shared_ptr<Graph>& graph,
    const std::vector<c10::IValue>& args,
    const bool use_allclose,
    const bool use_equalnan) {
  GraphExecutorWrapper graph_exec(graph);
  auto expected = graph_exec(args);
  auto actual = runtime(args, {});
  runtime.check_for_memory_leak();
  compareResults(expected, actual, use_allclose, use_equalnan);
}

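// End-to-end test driver. Sweeps every valid combination of the static
// runtime's out-variant, managed-output-tensor, and TensorExpr-fusion
// options, runs each configuration several times, compares against the
// reference executor, and finally verifies the inputs were not modified.
// Passing a non-empty args2 (with differently shaped inputs) additionally
// exercises memory-planner resizing.
//
// Typical usage (illustrative):
//   testStaticRuntime(source, {a, b});            // same-shape runs only
//   testStaticRuntime(source, {a, b}, {a2, b2});  // adds a resizing run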
void testStaticRuntime(
    const std::string& source,
    const std::vector<IValue>& args,
    const std::vector<IValue>& args2,
    const bool use_allclose,
    const bool use_equalnan,
    const bool check_resize) {
  auto test_context = makeTestContext(source);

  std::vector<IValue> args_tensors, args_copy;
  for (const auto& ival : args) {
    if (ival.isTensor()) {
      args_tensors.emplace_back(ival);
      const at::Tensor& t = ival.toTensor();
      args_copy.emplace_back(t.clone());
    }
  }

  auto expect = test_context->getExpected(args);

  for (bool enable_out_variant : {true, false}) {
    for (bool manage_output_tensors : {true, false}) {
      for (bool enable_tensorexpr_fusion : {true, false}) {
        if (!enable_out_variant && manage_output_tensors) {
          continue;
        }
        // Run static runtime three times:
        // 1st run: collect allocation profiles (args)
        // 2nd run: exercise memory planner and resizing with args2
        // 3rd run: run with args again
        StaticModuleOptions opts;
        opts.enable_out_variant = enable_out_variant;
        opts.optimize_memory = enable_out_variant;
        opts.manage_output_tensors = manage_output_tensors;
        opts.enable_tensorexpr_fusion = enable_tensorexpr_fusion;

        auto smodule = test_context->makeStaticModule(opts);
        StaticRuntime runtime(smodule);
        auto actual = runtime(args, {});
        if (actual.isTensor()) {
          EXPECT_GE(smodule.num_nodes(), 2)
              << "If we only have one node, the output of the op we are testing is "
              << "not being managed by the memory planner! A failure here "
              << "can typically be fixed by clone()ing the output of the test script.";
        }
        runtime.check_for_memory_leak();
        // First run.
        VLOG(2) << "enable_out_variant: " << enable_out_variant;
        VLOG(2) << "manage_output_tensors: " << manage_output_tensors;
        VLOG(2) << "enable_tensorexpr_fusion: " << enable_tensorexpr_fusion;
        VLOG(2) << "args: " << args;
        VLOG(2) << "args2: " << args2;
        VLOG(2) << "expect: " << expect;
        VLOG(2) << "actual: " << actual;
        compareResults(expect, actual, use_allclose, use_equalnan);
        VLOG(2) << "first run comparison done";
        if (manage_output_tensors) {
          actual = IValue();
          runtime.deallocateOutputTensors();
          runtime.checkOutputTensorMemoryLeaks();
        }

        if (!args2.empty()) {
          auto* memory_planner = runtime.get_memory_planner();
          size_t managed_bytes =
              memory_planner ? memory_planner->total_managed() : 0;

          // Run static runtime again with inputs of a different shape.
          expect = test_context->getExpected(args2);
          actual = runtime(args2, {});
          runtime.check_for_memory_leak();
          VLOG(2) << "comparing with args2";
          compareResults(expect, actual, use_allclose, use_equalnan);
          VLOG(2) << "second run comparison done";
          if (manage_output_tensors) {
            actual = IValue();
            runtime.deallocateOutputTensors();
            runtime.checkOutputTensorMemoryLeaks();
          }

          size_t new_managed_bytes =
              memory_planner ? memory_planner->total_managed() : 0;
          if (check_resize) {
            EXPECT_GE(new_managed_bytes, managed_bytes);
          }

          // Run static runtime again with an input of the shape observed
          // during the profile run.
          expect = test_context->getExpected(args);
          actual = runtime(args, {});
          runtime.check_for_memory_leak();
          // Third run.
          VLOG(2) << "comparing third run";
          compareResults(expect, actual, use_allclose, use_equalnan);
          VLOG(2) << "third run comparison done";
          if (manage_output_tensors) {
            actual = IValue();
            runtime.deallocateOutputTensors();
            runtime.checkOutputTensorMemoryLeaks();
          }
        } else {
          // Run static runtime again to exercise the memory planner
          // and allocate managed tensors.
          actual = runtime(args, {});
          runtime.check_for_memory_leak();
          VLOG(2) << "comparing second run with same args";
          compareResults(expect, actual, use_allclose, use_equalnan);
          VLOG(2) << "second run comparison done";
          if (manage_output_tensors) {
            actual = IValue();
            runtime.deallocateOutputTensors();
            runtime.checkOutputTensorMemoryLeaks();
          }
          // Third run to use the allocated managed tensors.
          actual = runtime(args, {});
          runtime.check_for_memory_leak();
          if (manage_output_tensors) {
            actual = IValue();
            runtime.deallocateOutputTensors();
            runtime.checkOutputTensorMemoryLeaks();
          }
        }
      }
    }
  }

  // Make sure inputs were not modified.
  VLOG(2) << "Printing out input tensors";
  compareTensorLists(args_tensors, args_copy, use_allclose, use_equalnan);
}

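// Returns true if the StaticModule contains a processed node whose op name
// matches the given name.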
bool hasProcessedNodeWithName(
    torch::jit::StaticModule& smodule,
    const char* name) {
  return smodule.findNodeWithKindForTesting(name) != nullptr;
}

} // namespace test
} // namespace jit
} // namespace torch