mirror of https://github.com/pytorch/pytorch.git
ProfilingGraphExecutor works like this:

1. Do some unrelated JIT optimizations.
2. Add profiling nodes to collect JIT information like tensor dtypes and shapes.
3. Do some more unrelated JIT optimizations.
4. Remove the profiling nodes, extract the tensor info, and then use the JIT tensor info to do optimizations.

This PR fixes a bug in step 4, where the profiling nodes are removed. It was previously assumed that everything profiled was either a Tensor or an Optional[Tensor] - otherwise, step 2 would not have introduced a profiling node. However, we saw a case where step 3 would replace Optional[Tensor] inputs with `None` inputs (e.g. when a conditional that returned either a Tensor or a None could be statically known to only follow the `None` branch). To fix this, we essentially just modify the RemoveProfileNodesAndSpecializeTypes assert so that it accepts Tensors, Optional[Tensor]s, or None (the new part).

Note that this issue is probably somewhat uncommon (maybe why we didn't see it for the first 4 years that this code existed). I expect that, typically, any time step 3 would convert `Optional[Tensor] -> None`, step 1 would already have done so, so it's difficult to reproduce in an end-to-end TorchScript workload.

Differential Revision: [D81068172](https://our.internmc.facebook.com/intern/diff/D81068172)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161538
Approved by: https://github.com/nmacchioni
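For illustration only, here is a minimal sketch (not from the PR; the function name and exact shape are hypothetical) of the kind of TorchScript pattern described above: a conditional producing an `Optional[Tensor]` that a later pass could statically resolve to `None`. As noted, this is hard to trigger end to end, which is why the regression test below constructs the profiled IR directly.

```python
from typing import Optional

import torch


# Hypothetical example: when `use_none` is statically known to be True
# (e.g. after constant propagation), the Optional[Tensor] result can be
# specialized to a plain None while a profile node may still be attached
# to it -- the case the updated assert now accepts.
@torch.jit.script
def maybe_tensor(x: torch.Tensor, use_none: bool) -> Optional[torch.Tensor]:
    if use_none:
        return None
    return x + 1
```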
#include <gtest/gtest.h>

#include <test/cpp/jit/test_utils.h>

#include <torch/csrc/jit/ir/irparser.h>

#include <torch/csrc/jit/passes/tensorexpr_fuser.h>

#include <iostream>

namespace torch {
namespace jit {

// Regression test: RemoveProfileNodesAndSpecializeTypes must tolerate a
// prim::profile node whose input has been specialized to None (here, the
// profile node in block0 is fed the None constant %1).
TEST(TETest, RemoveProfiling) {
  auto g = std::make_shared<Graph>();
  const auto graph_string = R"IR(
    graph(%a : Tensor,
          %b : bool):
      %1 : None = prim::Constant()
      %2 : Tensor? = prim::If(%b)
        block0():
          %3 : Tensor? = prim::profile[profiled_type=Tensor, seen_none=0](%1)
          -> (%3)
        block1():
          %4 : Tensor = prim::profile[profiled_type=Tensor, seen_none=0](%a)
          -> (%4)
      return (%2))IR";
  torch::jit::parseIR(graph_string, g.get());

  g->lint();
  RemoveProfileNodesAndSpecializeTypes(g);
  g->lint();

  // The profile nodes should be removed; the surrounding structure remains.
  testing::FileCheck()
      .check("prim::Constant")
      ->check("prim::If")
      ->check("block")
      ->check("block")
      ->check("return")
      ->run(*g);
}

} // namespace jit
} // namespace torch
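Assuming the usual PyTorch C++ test setup (the test lives under test/cpp/jit, judging by the test_utils.h include), it would typically be run through the JIT gtest binary with a filter along the lines of `./build/bin/test_jit --gtest_filter='TETest.RemoveProfiling'`; the exact binary name depends on how the test suite is built.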