Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67934

This reduces the memory requirements of ProcessedNode: by allocating outputs sequentially into a shared array and supporting at most 2**16 - 1 values (current models seem to have 10-20x fewer than that), we only need to store a 2-byte offset into that array and a 2-byte output count in each ProcessedNode.

ghstack-source-id: 143429113

Test Plan: Patched d1jang's diff to measure memory turnover around SR startup.

Previous diff, CMF local:
```
I1104 12:19:39.900211 597593 PyTorchStaticRuntimePredictor.cpp:82] memory turnover after creating an instance of StaticRuntime: 427120
```

This diff, CMF local:
```
I1105 12:17:36.459688 866763 PyTorchStaticRuntimePredictor.cpp:82] memory turnover after creating an instance of StaticRuntime: 354208
```

That is 72912 bytes (17%) of savings.

Perf looks neutral; see the test plan of the next diff (D32216573) for details.

Reviewed By: hlu1

Differential Revision: D32190751

fbshipit-source-id: 30c1e2caa9460f0d83b2d9bb24c68ccfcef757cc
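A minimal sketch of the layout change this describes. All names here are hypothetical (`outputs_arena_`, `outputs_offset_`, `num_outputs_` are illustrative, not the actual ProcessedNode members), and the shared-array pointer is shown inline for simplicity:

```
#include <ATen/core/ivalue.h>

#include <cassert>
#include <cstdint>

// Sketch only: every node's outputs live in one shared array owned by the
// runtime; each node keeps just a 16-bit offset and a 16-bit count into it.
class ProcessedNodeSketch {
 public:
  ProcessedNodeSketch(
      c10::IValue* outputs_arena,
      uint16_t offset,
      uint16_t num_outputs)
      : outputs_arena_(outputs_arena),
        outputs_offset_(offset),
        num_outputs_(num_outputs) {}

  // All nodes index into the same shared storage.
  c10::IValue& Output(uint16_t i) {
    assert(i < num_outputs_);
    return outputs_arena_[outputs_offset_ + i];
  }

 private:
  c10::IValue* outputs_arena_; // shared storage, owned by the runtime
  // 2 + 2 bytes per node instead of a heap-allocated std::vector<IValue>;
  // this caps the total number of output values at 2**16 - 1.
  uint16_t outputs_offset_;
  uint16_t num_outputs_;
};
```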
336 lines · 10 KiB · C++
// (c) Facebook, Inc. and its affiliates. Confidential and proprietary.

#include "test_utils.h"

#include <ATen/core/ivalue.h>
#include <gtest/gtest.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/memory_planner.h>
#include <torch/csrc/jit/runtime/static/passes.h>
#include <memory>
#include <unordered_map>

using namespace torch::jit;
using namespace torch;
using c10::IValue;

namespace torch {
namespace jit {
namespace test {

namespace {

// Test scripts passed to testStaticRuntime can either be IR or JIT.
// The logic for running the script and producing a corresponding StaticModule
// is a bit different for each case. This logic is encapsulated within concrete
// implementations of this class, and testStaticRuntime is only aware of this
// interface.
class StaticRuntimeTestContext {
 public:
  virtual ~StaticRuntimeTestContext() = default;

  virtual IValue getExpected(const std::vector<IValue>& args) = 0;
  virtual StaticModule makeStaticModule(
      const StaticModuleOptions& opt) const = 0;
};

class ModuleStaticRuntimeTestContext : public StaticRuntimeTestContext {
 public:
  explicit ModuleStaticRuntimeTestContext(const std::string& source_jit)
      : module_("module") {
    module_.define(source_jit);
  }

  IValue getExpected(const std::vector<IValue>& args) override {
    return module_.forward(args);
  }

  StaticModule makeStaticModule(const StaticModuleOptions& opt) const override {
    return torch::jit::StaticModule(module_, /* is_frozen */ false, opt);
  }

 private:
  Module module_;
};

class GraphStaticRuntimeContext : public StaticRuntimeTestContext {
 public:
  explicit GraphStaticRuntimeContext(const std::string& source_ir) {
    graph_ = std::make_shared<Graph>();
    std::unordered_map<std::string, Value*> vmap;
    parseIR(source_ir, graph_.get(), vmap);

    graph_exec_ = GraphExecutor(graph_, "");
  }

  IValue getExpected(const std::vector<IValue>& args) override {
    Stack stack(args);
    graph_exec_.run(stack);

    if (stack.size() == 1) {
      return stack[0];
    }
    return c10::ivalue::Tuple::create(stack);
  }

  StaticModule makeStaticModule(const StaticModuleOptions& opt) const override {
    return StaticModule(graph_, opt);
  }

 private:
  std::shared_ptr<Graph> graph_;
  GraphExecutor graph_exec_;
};

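// Illustrative examples of the two accepted source formats (not from the
// original file; real test scripts vary):
//
//   TorchScript:
//     def forward(self, a, b):
//         return (a + b).clone()
//
//   JIT IR:
//     graph(%a : Tensor, %b : Tensor):
//         %1 : int = prim::Constant[value=1]()
//         %c : Tensor = aten::add(%a, %b, %1)
//         return (%c)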
std::unique_ptr<StaticRuntimeTestContext> makeTestContext(
    const std::string& source) {
  try {
    return std::make_unique<ModuleStaticRuntimeTestContext>(source);
  } catch (const std::runtime_error&) {
    // Could not parse the source as TorchScript, so assume it is IR.
    return std::make_unique<GraphStaticRuntimeContext>(source);
  }
}

void compareTensorLists(
    const std::vector<IValue>& l, /* expects */
    const std::vector<IValue>& r, /* values */
    const bool use_allclose,
    const bool use_equalnan) {
  // ASSERT_EQ so a size mismatch aborts before the loop indexes r out of
  // bounds.
  ASSERT_EQ(l.size(), r.size());
  for (size_t i = 0; i < l.size(); ++i) {
    ASSERT_TRUE(l[i].isTensor());
    ASSERT_TRUE(r[i].isTensor());
    VLOG(2) << "expect " << i << ": \n" << l[i] << std::endl;
    VLOG(2) << "output " << i << ": \n" << r[i] << std::endl;
    if (!l[i].toTensor().defined()) {
      EXPECT_TRUE(!r[i].toTensor().defined());
    } else {
      if (use_allclose) {
        EXPECT_TRUE(at::allclose(
            l[i].toTensor(),
            r[i].toTensor(),
            /*rtol*/ 1e-05,
            /*atol*/ 1e-08,
            use_equalnan));
      } else {
        EXPECT_TRUE(l[i].toTensor().equal(r[i].toTensor()));
      }
    }
  }
}

void compareResults(
    const IValue& expect,
    const IValue& actual,
    const bool use_allclose = false,
    const bool use_equalnan = false) {
  if (expect.isTensor()) {
    VLOG(2) << "expect " << expect.toTensor() << std::endl;
    VLOG(2) << "output " << actual.toTensor() << std::endl;
    EXPECT_TRUE(actual.isTensor());
    if (use_allclose) {
      EXPECT_TRUE(at::allclose(
          expect.toTensor(),
          actual.toTensor(),
          /*rtol*/ 1e-05,
          /*atol*/ 1e-08,
          use_equalnan));
    } else {
      EXPECT_TRUE(expect.toTensor().equal(actual.toTensor()));
    }
    return;
  } else if (expect.isTuple()) {
    EXPECT_TRUE(actual.isTuple());
    auto lhs = expect.toTupleRef().elements();
    auto rhs = actual.toTupleRef().elements();
    ASSERT_EQ(lhs.size(), rhs.size());
    for (size_t i = 0; i < lhs.size(); i++) {
      // Propagate the comparison flags so nested tensors are compared the
      // same way as top-level ones.
      compareResults(lhs[i], rhs[i], use_allclose, use_equalnan);
    }
  } else if (expect.isList()) {
    EXPECT_TRUE(actual.isList());
    auto lhs = expect.toList();
    auto rhs = actual.toList();
    ASSERT_EQ(lhs.size(), rhs.size());
    for (size_t i = 0; i < lhs.size(); i++) {
      compareResults(lhs[i], rhs[i], use_allclose, use_equalnan);
    }
  } else if (expect.isGenericDict()) {
    EXPECT_TRUE(actual.isGenericDict());
    auto lhs = expect.toGenericDict();
    auto rhs = actual.toGenericDict();
    ASSERT_EQ(lhs.size(), rhs.size());
    for (auto& lh : lhs) {
      auto f = rhs.find(lh.key());
      // ASSERT_FALSE so a missing key aborts before f is dereferenced.
      ASSERT_FALSE(f == rhs.end());
      compareResults(lh.value(), f->value(), use_allclose, use_equalnan);
    }
  } else {
    // fall back to the default comparison impl in IValue
    EXPECT_TRUE(expect == actual);
  }
}

} // namespace

at::Tensor getTensor(const at::IValue& ival) {
  if (ival.isTensor()) {
    return ival.toTensor();
  } else if (ival.isTensorList()) {
    auto tensor_vec = ival.toTensorVector();
    TORCH_CHECK(tensor_vec.size() == 1);
    return tensor_vec[0];
  } else if (ival.isTuple()) {
    auto tuple = ival.toTuple();
    auto ivalue_vec = tuple->elements();
    TORCH_CHECK(ivalue_vec.size() == 1);
    return ivalue_vec[0].toTensor();
  } else {
    CAFFE_THROW("Unknown input IValue");
  }
}

Node* getNodeWithKind(const StaticModule& smodule, const std::string& kind) {
  return smodule.findNodeWithKindForTesting(kind);
}

bool hasNodeWithKind(const StaticModule& smodule, const std::string& kind) {
  return getNodeWithKind(smodule, kind) != nullptr;
}

std::shared_ptr<Graph> getGraphFromIR(const std::string& ir) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);
  return graph;
}

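// Illustrative call from a test (assumes the declaration in test_utils.h
// provides defaults for the trailing bool parameters):
//
//   const auto src = R"JIT(
//     def forward(self, a, b):
//         return (a + b).clone()
//   )JIT";
//   testStaticRuntime(src, {at::randn({2, 3}), at::randn({2, 3})},
//                     {at::randn({6, 7}), at::randn({6, 7})});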
void testStaticRuntime(
    const std::string& source,
    const std::vector<IValue>& args,
    const std::vector<IValue>& args2,
    const bool use_allclose,
    const bool use_equalnan,
    const bool check_resize) {
  auto test_context = makeTestContext(source);

  std::vector<IValue> args_tensors, args_copy;
  for (const auto& ival : args) {
    if (ival.isTensor()) {
      args_tensors.emplace_back(ival);
      const at::Tensor& t = ival.toTensor();
      args_copy.emplace_back(t.clone());
    }
  }

  auto expect = test_context->getExpected(args);

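  // Exercise all valid combinations of static runtime options. Managing
  // output tensors requires out variants, so that combination is skipped
  // inside the loop.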
  for (bool enable_out_variant : {true, false}) {
    for (bool manage_output_tensors : {true, false}) {
      if (!enable_out_variant && manage_output_tensors) {
        continue;
      }
      // run static runtime three times
      // 1st run: collect allocation profiles (args)
      // 2nd run: exercise memory planner and resizing with args2
      // 3rd run: run with args again
      StaticModuleOptions opts{
          .cleanup_activations = true,
          .enable_out_variant = enable_out_variant,
          .optimize_memory = enable_out_variant,
          .manage_output_tensors = manage_output_tensors};
      auto smodule = test_context->makeStaticModule(opts);
      StaticRuntime runtime(smodule);
      auto actual = runtime(args, {});
      if (actual.isTensor()) {
        EXPECT_GE(smodule.num_nodes(), 2)
            << "If we only have one node, the output of the op we are testing is "
            << "not being managed by the memory planner! A failure here "
            << "can typically be fixed by clone()ing the output of the test script.";
      }
      runtime.check_for_memory_leak();
      // first run
      compareResults(expect, actual, use_allclose, use_equalnan);
      if (manage_output_tensors) {
        actual = IValue();
        runtime.deallocateOutputTensors();
        runtime.checkOutputTensorMemoryLeaks();
      }

      if (!args2.empty()) {
        auto* memory_planner = runtime.get_memory_planner();
        size_t managed_bytes =
            memory_planner ? memory_planner->total_managed() : 0;

        // Run static runtime again with inputs of a different shape.
        expect = test_context->getExpected(args2);
        actual = runtime(args2, {});
        runtime.check_for_memory_leak();
        compareResults(expect, actual, use_allclose, use_equalnan);
        if (manage_output_tensors) {
          actual = IValue();
          runtime.deallocateOutputTensors();
          runtime.checkOutputTensorMemoryLeaks();
        }

        // When check_resize is set, args2 should produce larger intermediate
        // tensors, so the planner's managed arena must have grown.
        size_t new_managed_bytes =
            memory_planner ? memory_planner->total_managed() : 0;
        if (check_resize && new_managed_bytes > 0) {
          VLOG(1) << "managed_bytes: " << managed_bytes
                  << ", new_managed_bytes: " << new_managed_bytes;
          EXPECT_GT(new_managed_bytes, managed_bytes);
        }

        // Run static runtime again with an input of the shape observed during
        // the profile run.
        expect = test_context->getExpected(args);
        actual = runtime(args, {});
        runtime.check_for_memory_leak();
        // third run
        compareResults(expect, actual, use_allclose, use_equalnan);
        if (manage_output_tensors) {
          actual = IValue();
          runtime.deallocateOutputTensors();
          runtime.checkOutputTensorMemoryLeaks();
        }
      } else {
        // run static runtime again to exercise the memory planner
        // and allocate managed tensors.
        actual = runtime(args, {});
        runtime.check_for_memory_leak();
        compareResults(expect, actual, use_allclose, use_equalnan);
        if (manage_output_tensors) {
          actual = IValue();
          runtime.deallocateOutputTensors();
          runtime.checkOutputTensorMemoryLeaks();
        }
        // third run to use the allocated managed tensors.
        actual = runtime(args, {});
        runtime.check_for_memory_leak();
        if (manage_output_tensors) {
          actual = IValue();
          runtime.deallocateOutputTensors();
          runtime.checkOutputTensorMemoryLeaks();
        }
      }
    }
  }

  // make sure inputs were not modified
  VLOG(2) << "Printing out input tensors";
  compareTensorLists(args_tensors, args_copy, use_allclose, use_equalnan);
}

bool hasProcessedNodeWithName(
    torch::jit::StaticModule& smodule,
    const char* name) {
  return smodule.findNodeWithKindForTesting(name) != nullptr;
}

} // namespace test
} // namespace jit
} // namespace torch