mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Test is flaky and sometimes hangs in CI Here's an example of the failure: https://github.com/pytorch/pytorch/actions/runs/16946153494/job/48027937663 ``` 2025-08-13T20:54:00.1223688Z ==================================== RERUNS ==================================== 2025-08-13T20:54:00.1224156Z ___________________________ RecordDebugHandles.Basic ___________________________ 2025-08-13T20:54:00.1224682Z [gw2] linux -- Python 3.13.5 /opt/conda/envs/py_3.13/bin/python3.13 2025-08-13T20:54:00.1225568Z Internal Error: calling /opt/conda/envs/py_3.13/lib/python3.13/site-packages/torch/bin/test_jit for test RecordDebugHandles.Basic failed (returncode=-6): 2025-08-13T20:54:00.1226430Z CUDA not available. Disabling CUDA and MultiCUDA tests 2025-08-13T20:54:00.1226988Z Note: Google Test filter = RecordDebugHandles.Basic-*_CUDA:*_MultiCUDA 2025-08-13T20:54:00.1227450Z [==========] Running 1 test from 1 test suite. 2025-08-13T20:54:00.1227792Z [----------] Global test environment set-up. 2025-08-13T20:54:00.1228145Z [----------] 1 test from RecordDebugHandles 2025-08-13T20:54:00.1228492Z [ RUN ] RecordDebugHandles.Basic 2025-08-13T20:54:00.1228822Z [ OK ] RecordDebugHandles.Basic (1 ms) 2025-08-13T20:54:00.1229204Z [----------] 1 test from RecordDebugHandles (1 ms total) 2025-08-13T20:54:00.1229501Z 2025-08-13T20:54:00.1229666Z [----------] Global test environment tear-down 2025-08-13T20:54:00.1230033Z [==========] 1 test from 1 test suite ran. (1 ms total) 2025-08-13T20:54:00.1230355Z [ PASSED ] 1 test. 2025-08-13T20:54:00.1230727Z terminate called after throwing an instance of 'std::system_error' 2025-08-13T20:54:00.1231154Z what(): Invalid argument 2025-08-13T20:54:00.1231416Z unknown file:0: C++ failure 2025-08-13T20:54:00.1231788Z ------------------------------ Captured c++ call ------------------------------- 2025-08-13T20:54:00.1232262Z CUDA not available. Disabling CUDA and MultiCUDA tests 2025-08-13T20:54:00.1232745Z Note: Google Test filter = RecordDebugHandles.Basic-*_CUDA:*_MultiCUDA 2025-08-13T20:54:00.1233199Z [==========] Running 1 test from 1 test suite. 2025-08-13T20:54:00.1233557Z [----------] Global test environment set-up. 2025-08-13T20:54:00.1233915Z [----------] 1 test from RecordDebugHandles 2025-08-13T20:54:00.1234247Z [ RUN ] RecordDebugHandles.Basic 2025-08-13T20:54:00.1234590Z [ OK ] RecordDebugHandles.Basic (1 ms) 2025-08-13T20:54:00.1235020Z [----------] 1 test from RecordDebugHandles (1 ms total) 2025-08-13T20:54:00.1235304Z 2025-08-13T20:54:00.1235431Z [----------] Global test environment tear-down 2025-08-13T20:54:00.1235793Z [==========] 1 test from 1 test suite ran. (1 ms total) 2025-08-13T20:54:00.1236126Z [ PASSED ] 1 test. 2025-08-13T20:54:00.1236481Z terminate called after throwing an instance of 'std::system_error' 2025-08-13T20:54:00.1236906Z what(): Invalid argument 2025-08-13T20:54:00.1237287Z ___________________________ RecordDebugHandles.Basic ___________________________ 2025-08-13T20:54:00.1237800Z [gw2] linux -- Python 3.13.5 /opt/conda/envs/py_3.13/bin/python3.13 2025-08-13T20:54:00.1238686Z Internal Error: calling /opt/conda/envs/py_3.13/lib/python3.13/site-packages/torch/bin/test_jit for test RecordDebugHandles.Basic failed (returncode=-6): 2025-08-13T20:54:00.1239551Z CUDA not available. Disabling CUDA and MultiCUDA tests 2025-08-13T20:54:00.1240048Z Note: Google Test filter = RecordDebugHandles.Basic-*_CUDA:*_MultiCUDA 2025-08-13T20:54:00.1240495Z [==========] Running 1 test from 1 test suite. 
2025-08-13T20:54:00.1240848Z [----------] Global test environment set-up. 2025-08-13T20:54:00.1241199Z [----------] 1 test from RecordDebugHandles 2025-08-13T20:54:00.1241542Z [ RUN ] RecordDebugHandles.Basic 2025-08-13T20:54:00.1241871Z [ OK ] RecordDebugHandles.Basic (1 ms) 2025-08-13T20:54:00.1242249Z [----------] 1 test from RecordDebugHandles (1 ms total) 2025-08-13T20:54:00.1242503Z 2025-08-13T20:54:00.1242641Z [----------] Global test environment tear-down 2025-08-13T20:54:00.1242993Z [==========] 1 test from 1 test suite ran. (19 ms total) 2025-08-13T20:54:00.1243329Z [ PASSED ] 1 test. 2025-08-13T20:54:00.1243697Z terminate called after throwing an instance of 'std::system_error' 2025-08-13T20:54:00.1244113Z what(): Invalid argument 2025-08-13T20:54:00.1244392Z unknown file:0: C++ failure 2025-08-13T20:54:00.1244759Z ------------------------------ Captured c++ call ------------------------------- 2025-08-13T20:54:00.1245235Z CUDA not available. Disabling CUDA and MultiCUDA tests 2025-08-13T20:54:00.1283768Z ============== 1 failed, 568 passed, 2 rerun in 115.57s (0:01:55) ============== ``` Here's an example of the hang: https://github.com/pytorch/pytorch/actions/runs/16942186826/job/48015238944 Logs aren't super helpful other than stating that it took a long time. Usually this file takes <2min to run ``` 2025-08-13T18:43:24.6586481Z [gw0] [ 97%] PASSED [1.4119s] ../../../../../opt/conda/envs/py_3.13/lib/python3.13/site-packages/torch/bin/test_jit::PyTorch/LiteInterpreterDynamicTypeTestFixture::Conformance/8 2025-08-13T18:43:24.6587278Z [gw1] [ 97%] PASSED [1.4866s] ../../../../../opt/conda/envs/py_3.13/lib/python3.13/site-packages/torch/bin/test_jit::PyTorch/LiteInterpreterDynamicTypeTestFixture::Conformance/9 Command took >30min, returning 124 2025-08-13T18:43:24.6587288Z 2025-08-13T18:43:24.6587632Z FINISHED PRINTING LOG FILE of cpp/test_jit 1/1 (test/test-reports/cpp.test_jit_1.1_c259e5a152845991_.log) 2025-08-13T18:43:24.6587639Z ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/160577 Approved by: https://github.com/huydhn
3186 lines
99 KiB
C++
3186 lines
99 KiB
C++
#include <gmock/gmock.h>
|
|
#include <gtest/gtest.h>
|
|
|
|
#include <ATen/ATen.h>
|
|
#include <ATen/Parallel.h>
|
|
#include <ATen/core/interned_strings.h>
|
|
#include <ATen/core/ivalue.h>
|
|
#include <ATen/core/jit_type_base.h>
|
|
#include <c10/macros/Macros.h>
|
|
#include <test/cpp/jit/test_utils.h>
|
|
#include <torch/csrc/jit/passes/remove_mutation.h>
|
|
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
|
|
#include <torch/csrc/jit/tensorexpr/kernel.h>
|
|
|
|
#include <torch/csrc/autograd/engine.h>
|
|
#include <torch/csrc/autograd/generated/variable_factories.h>
|
|
#include <torch/csrc/autograd/profiler.h>
|
|
#include <torch/csrc/autograd/variable.h>
|
|
#include <torch/csrc/jit/api/function_impl.h>
|
|
#include <torch/csrc/jit/api/module.h>
|
|
#include <torch/csrc/jit/codegen/fuser/interface.h>
|
|
#include <torch/csrc/jit/frontend/ir_emitter.h>
|
|
#include <torch/csrc/jit/frontend/tracer.h>
|
|
#include <torch/csrc/jit/ir/alias_analysis.h>
|
|
#include <torch/csrc/jit/ir/attributes.h>
|
|
#include <torch/csrc/jit/ir/irparser.h>
|
|
#include <torch/csrc/jit/ir/scope.h>
|
|
#include <torch/csrc/jit/ir/type_hashing.h>
|
|
#include <torch/csrc/jit/jit_log.h>
|
|
#include <torch/csrc/jit/passes/bailout_graph.h>
|
|
#include <torch/csrc/jit/passes/canonicalize.h>
|
|
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
|
|
#include <torch/csrc/jit/passes/constant_propagation.h>
|
|
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
|
|
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
|
#include <torch/csrc/jit/passes/graph_fuser.h>
|
|
#include <torch/csrc/jit/passes/guard_elimination.h>
|
|
#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
|
|
#include <torch/csrc/jit/passes/insert_guards.h>
|
|
#include <torch/csrc/jit/passes/liveness.h>
|
|
#include <torch/csrc/jit/passes/loop_unrolling.h>
|
|
#include <torch/csrc/jit/passes/lower_grad_of.h>
|
|
#include <torch/csrc/jit/passes/lower_tuples.h>
|
|
#include <torch/csrc/jit/passes/pass_manager.h>
|
|
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
|
|
#include <torch/csrc/jit/passes/restore_mutation.h>
|
|
#include <torch/csrc/jit/passes/shape_analysis.h>
|
|
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
|
|
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
|
|
#include <torch/csrc/jit/runtime/argument_spec.h>
|
|
#include <torch/csrc/jit/runtime/autodiff.h>
|
|
#include <torch/csrc/jit/runtime/custom_operator.h>
|
|
#include <torch/csrc/jit/runtime/decomposition_registry.h>
|
|
#include <torch/csrc/jit/runtime/graph_executor.h>
|
|
#include <torch/csrc/jit/runtime/interpreter.h>
|
|
#include <torch/csrc/jit/runtime/jit_trace.h>
|
|
#include <torch/csrc/jit/runtime/profiling_record.h>
|
|
#include <torch/csrc/jit/runtime/symbolic_script.h>
|
|
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
|
|
#include <torch/csrc/jit/serialization/import.h>
|
|
#include <torch/csrc/jit/testing/file_check.h>
|
|
#include <torch/jit.h>
|
|
#include <torch/script.h>
|
|
|
|
#include <onnx/onnx_pb.h>
|
|
|
|
#include <c10/util/Exception.h>
|
|
#include <c10/util/ThreadLocalDebugInfo.h>
|
|
|
|
#include <torch/csrc/jit/passes/freeze_module.h>
|
|
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
|
|
#include <algorithm>
|
|
#include <cstddef>
|
|
#include <functional>
|
|
#include <iostream>
|
|
#include <memory>
|
|
#include <set>
|
|
#include <stdexcept>
|
|
#include <string>
|
|
#include <tuple>
|
|
#include <unordered_map>
|
|
#include <unordered_set>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
|
|
return c10::AliasAnalysisKind::FROM_SCHEMA;
|
|
}
|
|
|
|
template <typename T>
|
|
std::ostream& operator<<(std::ostream& out, const std::vector<T>& list) {
|
|
size_t i = 0;
|
|
out << "{";
|
|
for (auto&& e : list) {
|
|
if (i++ > 0)
|
|
out << ", ";
|
|
out << e;
|
|
}
|
|
out << "}";
|
|
return out;
|
|
}
|
|
|
|
TEST(InternedStringsTest, Basic) {
|
|
ASSERT_EQ(prim::Param, Symbol::prim("Param"));
|
|
ASSERT_EQ(prim::Return, Symbol::prim("Return"));
|
|
ASSERT_EQ(prim::Return.toUnqualString(), std::string("Return"));
|
|
ASSERT_EQ(prim::Return.toQualString(), std::string("prim::Return"));
|
|
Symbol newsym = Symbol::aten("__NEW_SYMBOL");
|
|
size_t symstart = newsym;
|
|
ASSERT_EQ(newsym.toQualString(), std::string("aten::__NEW_SYMBOL"));
|
|
// TODO: This test is a bit too close to the implementation details.
|
|
ASSERT_EQ(Symbol::aten("What"), symstart + 1);
|
|
ASSERT_EQ(Symbol::aten("What2"), symstart + 2);
|
|
ASSERT_EQ(Symbol::aten("What"), symstart + 1);
|
|
ASSERT_EQ(Symbol::aten("What2"), symstart + 2);
|
|
ASSERT_EQ(Symbol(symstart + 2).toUnqualString(), std::string("What2"));
|
|
}
|
|
|
|
TEST(FromQualStringTest, Basic) {
|
|
ASSERT_EQ(Symbol::fromQualString("prim::Param"), Symbol::prim("Param"));
|
|
ASSERT_EQ(Symbol::fromQualString("aten::mm"), Symbol::aten("mm"));
|
|
ASSERT_EQ(Symbol::fromQualString("onnx::LSTM"), Symbol::onnx("LSTM"));
|
|
ASSERT_EQ(Symbol::fromQualString("attr::value"), Symbol::attr("value"));
|
|
ASSERT_EQ(Symbol::fromQualString("scope::"), Symbol::scope(""));
|
|
ASSERT_EQ(Symbol::fromQualString("::").toUnqualString(), std::string(""));
|
|
ASSERT_EQ(
|
|
Symbol::fromQualString("::").ns().toQualString(),
|
|
std::string("namespaces::"));
|
|
ASSERT_EQ(
|
|
Symbol::fromQualString("new_ns::param").toUnqualString(),
|
|
std::string("param"));
|
|
ASSERT_EQ(
|
|
Symbol::fromQualString("new_ns::param").ns().toUnqualString(),
|
|
std::string("new_ns"));
|
|
ASSERT_EQ(
|
|
Symbol::fromQualString("new_ns::param").ns(),
|
|
Symbol::fromQualString("namespaces::new_ns"));
|
|
|
|
auto bad_inputs = {"scope", ":", ""};
|
|
for (auto input : bad_inputs) {
|
|
try {
|
|
Symbol::fromQualString(input);
|
|
ASSERT_TRUE(0);
|
|
} catch (const std::exception& c) {
|
|
}
|
|
}
|
|
}
|
|
|
|
TEST(THNNConvTest, Basic) {
|
|
std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
|
|
std::vector<int64_t> kernel_size = {3, 5};
|
|
std::vector<int64_t> stride = {1, 2};
|
|
std::vector<int64_t> padding = {2, 1};
|
|
constexpr int out_channels = 5;
|
|
|
|
// make inputs
|
|
at::Tensor input = torch::randn(input_size);
|
|
at::Tensor weight = torch::randn(
|
|
{out_channels, input_size[1], kernel_size[0], kernel_size[1]});
|
|
at::Tensor bias = torch::randn({out_channels});
|
|
|
|
// run forward eagerly
|
|
at::Tensor output = at::_slow_conv2d_forward(
|
|
input, weight, kernel_size, bias, stride, padding);
|
|
|
|
// make grad_outputs
|
|
at::Tensor grad_output =
|
|
torch::randn_like(output, at::MemoryFormat::Preserve);
|
|
|
|
// run backward eagerly
|
|
auto [grad_input, grad_weight, grad_bias] = at::_slow_conv2d_backward(
|
|
grad_output,
|
|
input,
|
|
weight,
|
|
kernel_size,
|
|
stride,
|
|
padding,
|
|
{true, true, true});
|
|
|
|
// make JIT graph
|
|
auto graph = std::make_shared<Graph>();
|
|
auto ksz_val = graph->insertConstant(kernel_size);
|
|
auto kst_val = graph->insertConstant(stride);
|
|
auto pad_val = graph->insertConstant(padding);
|
|
|
|
auto inputg = graph->addInput("self");
|
|
auto weightg = graph->addInput("weight");
|
|
auto biasg = graph->addInput("bias");
|
|
|
|
Value* conv = graph->insert(
|
|
aten::_slow_conv2d_forward,
|
|
{inputg, weightg, ksz_val, biasg, kst_val, pad_val});
|
|
auto outputs = conv->node()->outputs();
|
|
for (auto output : outputs) {
|
|
graph->registerOutput(output);
|
|
}
|
|
LowerAllTuples(graph);
|
|
graph->lint();
|
|
|
|
// differentiate JIT graph
|
|
EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
|
|
ConstantPropagation(graph);
|
|
auto grad_spec = differentiate(graph);
|
|
LowerGradOf(*grad_spec.df);
|
|
|
|
// prepare JIT inputs / gradients
|
|
tensor_list tensors_in;
|
|
tensors_in.push_back(input);
|
|
tensors_in.push_back(weight);
|
|
tensors_in.push_back(bias);
|
|
|
|
tensor_list tensor_grads_in;
|
|
tensor_grads_in.push_back(grad_output);
|
|
|
|
// Get outputs from the interpreter
|
|
auto [tensors_out, tensor_grads_out] =
|
|
runGradient(grad_spec, tensors_in, tensor_grads_in);
|
|
|
|
// prepare expected structs
|
|
tensor_list expected_tensors_out, expected_tensor_grads_out;
|
|
expected_tensors_out.push_back(output);
|
|
expected_tensor_grads_out.push_back(grad_input);
|
|
expected_tensor_grads_out.push_back(grad_weight);
|
|
expected_tensor_grads_out.push_back(grad_bias);
|
|
|
|
// Compare results
|
|
assertAllClose(tensors_out, expected_tensors_out);
|
|
assertAllClose(tensor_grads_out, expected_tensor_grads_out);
|
|
}
|
|
|
|
TEST(ATenNativeBatchNormTest, Basic) {
|
|
// aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor
|
|
// running_mean, Tensor running_var, bool training, float momentum, float eps)
|
|
// -> (Tensor, Tensor, Tensor)
|
|
std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
|
|
bool training = true;
|
|
float momentum = 0.9;
|
|
float eps = 1e-5;
|
|
|
|
// make inputs
|
|
at::Tensor input = torch::randn(input_size);
|
|
at::Tensor weight = torch::randn({input_size[1]});
|
|
at::Tensor bias = torch::randn({input_size[1]});
|
|
at::Tensor running_mean = torch::randn({input_size[1]});
|
|
at::Tensor running_var = torch::randn({input_size[1]});
|
|
|
|
// running_mean and running_var are changed in-place, so clone and send them
|
|
at::Tensor running_mean_eager = running_mean.clone();
|
|
at::Tensor running_var_eager = running_var.clone();
|
|
at::Tensor running_mean_jit = running_mean.clone();
|
|
at::Tensor running_var_jit = running_var.clone();
|
|
|
|
// run forward eagerly
|
|
auto [output, savemean, saveinvstd] = at::native_batch_norm(
|
|
input,
|
|
weight,
|
|
bias,
|
|
running_mean_eager,
|
|
running_var_eager,
|
|
training,
|
|
momentum,
|
|
eps);
|
|
|
|
// make grad_outputs
|
|
at::Tensor grad_output =
|
|
torch::randn_like(output, at::MemoryFormat::Preserve);
|
|
at::Tensor grad_savemean =
|
|
torch::zeros_like(savemean, at::MemoryFormat::Preserve);
|
|
at::Tensor grad_saveinvstd =
|
|
torch::zeros_like(saveinvstd, at::MemoryFormat::Preserve);
|
|
|
|
// run backward eagerly
|
|
// aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor
|
|
// weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor
|
|
// save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor,
|
|
// Tensor, Tensor)
|
|
auto [grad_input, grad_weight, grad_bias] = at::native_batch_norm_backward(
|
|
grad_output,
|
|
input,
|
|
weight,
|
|
running_mean_eager,
|
|
running_var_eager,
|
|
savemean,
|
|
saveinvstd,
|
|
training,
|
|
eps,
|
|
{true, true, true});
|
|
|
|
// make JIT graph
|
|
auto graph = std::make_shared<Graph>();
|
|
auto training_val = graph->insertConstant(IValue(training));
|
|
auto momentum_val = graph->insertConstant(IValue(momentum));
|
|
auto eps_val = graph->insertConstant(IValue(eps));
|
|
|
|
auto inputg = graph->addInput("self");
|
|
auto weightg = graph->addInput("weight");
|
|
auto biasg = graph->addInput("bias");
|
|
auto running_meang = graph->addInput("running_mean");
|
|
auto running_varg = graph->addInput("running_var");
|
|
|
|
Value* bn = graph->insert(
|
|
aten::native_batch_norm,
|
|
{inputg,
|
|
weightg,
|
|
biasg,
|
|
running_meang,
|
|
running_varg,
|
|
training_val,
|
|
momentum_val,
|
|
eps_val});
|
|
auto outputs = bn->node()->outputs();
|
|
for (auto output : outputs) {
|
|
graph->registerOutput(output);
|
|
}
|
|
LowerAllTuples(graph);
|
|
graph->lint();
|
|
|
|
// differentiate JIT graph
|
|
EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
|
|
ConstantPropagation(graph);
|
|
auto grad_spec = differentiate(graph);
|
|
LowerGradOf(*grad_spec.df);
|
|
|
|
// prepare JIT inputs / gradients
|
|
tensor_list tensors_in;
|
|
tensors_in.push_back(input);
|
|
tensors_in.push_back(weight);
|
|
tensors_in.push_back(bias);
|
|
tensors_in.push_back(running_mean_jit);
|
|
tensors_in.push_back(running_var_jit);
|
|
|
|
tensor_list tensor_grads_in;
|
|
tensor_grads_in.push_back(grad_output);
|
|
tensor_grads_in.push_back(grad_savemean);
|
|
tensor_grads_in.push_back(grad_saveinvstd);
|
|
|
|
// Get outputs from the interpreter
|
|
auto [tensors_out, tensor_grads_out] =
|
|
runGradient(grad_spec, tensors_in, tensor_grads_in);
|
|
|
|
// prepare expected structs
|
|
tensor_list expected_tensors_out, expected_tensor_grads_out;
|
|
expected_tensors_out.push_back(output);
|
|
expected_tensors_out.push_back(savemean);
|
|
expected_tensors_out.push_back(saveinvstd);
|
|
expected_tensors_out.push_back(running_mean_eager);
|
|
expected_tensors_out.push_back(running_var_eager);
|
|
expected_tensor_grads_out.push_back(grad_input);
|
|
expected_tensor_grads_out.push_back(grad_weight);
|
|
expected_tensor_grads_out.push_back(grad_bias);
|
|
|
|
tensors_out.push_back(running_mean_jit);
|
|
tensors_out.push_back(running_var_jit);
|
|
|
|
// Compare results
|
|
assertAllClose(tensors_out, expected_tensors_out);
|
|
assertAllClose(tensor_grads_out, expected_tensor_grads_out);
|
|
}
|
|
|
|
TEST(CustomFusionTest, Basic) {
|
|
#if defined(FBCODE_CAFFE2)
|
|
return;
|
|
#endif
|
|
|
|
auto graph_string = R"IR(
|
|
graph(%0 : Float(2, 3, 4),
|
|
%1 : Float(2, 3, 4)):
|
|
%2 : Tensor = aten::mul(%0, %1)
|
|
%3 : Tensor = aten::mul(%2, %0)
|
|
return (%3))IR";
|
|
auto g = std::make_shared<Graph>();
|
|
torch::jit::parseIR(graph_string, g.get());
|
|
|
|
torch::jit::overrideCanFuseOnCPU(true);
|
|
CustomFuseGraph(
|
|
g,
|
|
[](Node* n) { return n->kind() != prim::Param; },
|
|
Symbol::fromQualString("prim::FusionGroup"));
|
|
torch::jit::overrideCanFuseOnCPU(false);
|
|
|
|
const auto& nodes = g->nodes();
|
|
auto fusion_group =
|
|
std::find_if(nodes.begin(), nodes.end(), [](const Node* node) {
|
|
return node->kind() == Symbol::fromQualString("prim::FusionGroup");
|
|
});
|
|
AT_ASSERT(fusion_group != nodes.end());
|
|
|
|
auto subgraph = fusion_group->g(attr::Subgraph);
|
|
auto hits = 0;
|
|
// two multiplications
|
|
for (const auto& n : subgraph->nodes()) {
|
|
(void)n;
|
|
hits++;
|
|
}
|
|
AT_ASSERT(hits == 2);
|
|
}
|
|
|
|
TEST(CustomFusionTest, NestedBlocks) {
|
|
#if defined(FBCODE_CAFFE2)
|
|
return;
|
|
#endif
|
|
|
|
auto graph_string = R"IR(
|
|
graph(%0 : Float(2, 3, 4),
|
|
%1 : Float(2, 3, 4),
|
|
%2 : Float(2, 3, 4)):
|
|
%3 : int = prim::Constant[value=1]()
|
|
%4 : Tensor = prim::If(%2)
|
|
block0():
|
|
%5 : Tensor = aten::mul(%0, %2)
|
|
%6 : Tensor = aten::mul(%5, %1)
|
|
-> (%6)
|
|
block1():
|
|
%7 : Tensor = aten::add(%0, %2, %3)
|
|
%8 : Tensor = aten::add(%7, %1, %3)
|
|
-> (%8)
|
|
%9 : Tensor = aten::add(%4, %2, %3)
|
|
return (%4))IR";
|
|
auto g = std::make_shared<Graph>();
|
|
torch::jit::parseIR(graph_string, g.get());
|
|
|
|
CustomFuseGraph(
|
|
g,
|
|
[](Node* n) { return n->kind() == aten::mul; },
|
|
Symbol::fromQualString("prim::FusionGroup"));
|
|
|
|
// Could be done in more efficient ways, but this is only a test.
|
|
std::function<bool(const Block*, Symbol)> dfs = [&](const Block* b,
|
|
Symbol s) {
|
|
for (auto node : b->nodes()) {
|
|
if (node->kind() == s)
|
|
return true;
|
|
for (auto nested_b : node->blocks())
|
|
if (dfs(nested_b, s))
|
|
return true;
|
|
}
|
|
return false;
|
|
};
|
|
|
|
AT_ASSERT(dfs(g->block(), Symbol::fromQualString("prim::FusionGroup")));
|
|
}
|
|
|
|
static const auto cf_examples = R"JIT(
|
|
def if_test(a, b):
|
|
# FIXME: use 0 instead of a.
|
|
# c = 0
|
|
c = a
|
|
if bool(a < b):
|
|
c = b
|
|
else:
|
|
c = a
|
|
return c
|
|
def if_one(a, b):
|
|
c = b
|
|
if bool(a < b):
|
|
c = a
|
|
return c
|
|
def while_test(a, i):
|
|
while bool(i < 3):
|
|
a *= a
|
|
i += 1
|
|
return a
|
|
)JIT";
|
|
|
|
TEST(ControlFlowTest, Basic) {
|
|
auto cu = compile(cf_examples);
|
|
|
|
auto run = [&](const std::string& name, std::vector<IValue> stack) {
|
|
auto graph = toGraphFunction(cu->get_function(name)).graph();
|
|
Code code(graph, "");
|
|
InterpreterState interp(code);
|
|
interp.run(stack);
|
|
return stack;
|
|
};
|
|
|
|
auto L = [](int64_t l) { return IValue(scalar_to_tensor(at::Scalar(l))); };
|
|
auto V = [](IValue t) { return std::move(t).toTensor().item<int64_t>(); };
|
|
auto run_binary = [&](const std::string& name, int64_t a, int64_t b) {
|
|
return V(run(name, {L(a), L(b)})[0]);
|
|
};
|
|
ASSERT_EQ(2, run_binary("if_test", 1, 2));
|
|
ASSERT_EQ(3, run_binary("if_test", 3, 2));
|
|
ASSERT_EQ(2, run_binary("if_one", 2, 3));
|
|
ASSERT_EQ(2, run_binary("if_one", 3, 2));
|
|
ASSERT_EQ(256, run_binary("while_test", 2, 0));
|
|
}
|
|
|
|
#if !(C10_ASAN_ENABLED || C10_UBSAN_ENABLED)
|
|
// This test fails vptr UBSAN checks
|
|
|
|
TEST(ProtoTest, Basic) {
|
|
::ONNX_NAMESPACE::ModelProto proto;
|
|
proto.set_producer_name("foo");
|
|
}
|
|
#endif
|
|
|
|
// test a few features that are not directly used in schemas yet
|
|
TEST(SchemaParserTest, NestedArrays) {
|
|
// nested arrays
|
|
auto s = parseSchema("at::what(int[][4] foo) -> ()");
|
|
ASSERT_TRUE(s.arguments().at(0).N() == 4);
|
|
ASSERT_TRUE(IntType::get()->isSubtypeOf(*s.arguments()
|
|
.at(0)
|
|
.type()
|
|
->expectRef<ListType>()
|
|
.getElementType()
|
|
->expectRef<ListType>()
|
|
.getElementType()));
|
|
auto s2 = parseSchema("at::what(int[][] foo) -> ()");
|
|
ASSERT_TRUE(IntType::get()->isSubtypeOf(*s2.arguments()
|
|
.at(0)
|
|
.type()
|
|
->expectRef<ListType>()
|
|
.getElementType()
|
|
->expectRef<ListType>()
|
|
.getElementType()));
|
|
}
|
|
|
|
TEST(SchemaParserTest, OutVariant) {
|
|
auto schema_with_out = parseSchema(
|
|
"at::foo(Tensor self, *, Tensor(a!) f, Tensor(b!) l) -> (Tensor(a!) f, Tensor(b!) l)");
|
|
ASSERT_TRUE(schema_with_out.arguments().at(1).is_out());
|
|
ASSERT_TRUE(schema_with_out.arguments().at(2).is_out());
|
|
|
|
auto schema_without_out =
|
|
parseSchema("at::foo(Tensor self, *, int scalar) -> (int)");
|
|
|
|
for (const auto& arg : schema_without_out.arguments()) {
|
|
ASSERT_TRUE(!arg.is_out());
|
|
}
|
|
|
|
auto schema_with_is_write = parseSchema(
|
|
"aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> (Tensor(a!))");
|
|
|
|
for (const auto& arg : schema_with_is_write.arguments()) {
|
|
ASSERT_TRUE(!arg.is_out());
|
|
}
|
|
}
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
|
TEST(SchemaParserTest, NamedReturns) {
|
|
// named returns
|
|
parseSchema("at::what(Tensor! i_will_be_written_to) -> ()");
|
|
auto s3 =
|
|
parseSchema("at::what() -> (Tensor the_return, Tensor the_return2)");
|
|
ASSERT_TRUE(s3.returns().at(0).name() == "the_return");
|
|
ASSERT_TRUE(s3.returns().at(1).name() == "the_return2");
|
|
}
|
|
|
|
TEST(SchemaParserTest, Futures) {
|
|
// futures
|
|
auto s4 = parseSchema("at::what(Future(int) foo) -> ()");
|
|
ASSERT_TRUE(IntType::get()->isSubtypeOf(
|
|
*s4.arguments().at(0).type()->expectRef<FutureType>().getElementType()));
|
|
}
|
|
|
|
TEST(SchemaParserTest, AnnotatedAliasSets) {
|
|
// test tensor with annotated alias sets
|
|
parseSchema("at::what(Tensor(a) foo) -> (Tensor(a))");
|
|
}
|
|
|
|
TEST(SchemaParserTest, TensorListAnnotatedAliasSets) {
|
|
const auto s = parseSchema(
|
|
"at::foo(Tensor(a!) self, Tensor(b!)[] out)"
|
|
" -> ()");
|
|
const AliasInfo* selfAliasInfo = s.arguments().at(0).alias_info();
|
|
const AliasInfo* outAliasInfo = s.arguments().at(1).alias_info();
|
|
ASSERT_TRUE(
|
|
selfAliasInfo->beforeSets() ==
|
|
std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
|
|
ASSERT_TRUE(selfAliasInfo->isWrite());
|
|
|
|
ASSERT_TRUE(outAliasInfo->isWrite());
|
|
ASSERT_TRUE(outAliasInfo->beforeSets().empty());
|
|
ASSERT_EQ(outAliasInfo->containedTypes().size(), 1);
|
|
|
|
auto containedType = outAliasInfo->containedTypes()[0];
|
|
|
|
ASSERT_TRUE(containedType.isWrite());
|
|
ASSERT_TRUE(
|
|
containedType.beforeSets() ==
|
|
std::unordered_set<Symbol>{Symbol::fromQualString("alias::b")});
|
|
}
|
|
|
|
TEST(SchemaParserTest, AnnotatedAliasWithoutBeforeSet) {
|
|
EXPECT_THAT(
|
|
[]() { parseSchema("at::foo(Tensor(!) self) -> Tensor"); },
|
|
::testing::Throws<std::runtime_error>(::testing::Property(
|
|
&std::runtime_error::what,
|
|
::testing::HasSubstr("expected ident but found '!' here"))));
|
|
}
|
|
|
|
TEST(SchemaParserTest, BeforeAfterSets) {
|
|
const auto s = parseSchema(
|
|
"at::what(Tensor(b|c)[](a!) list, Tensor(c) element)"
|
|
" -> (Tensor(b|c)[](a!))");
|
|
|
|
// The list itself is annotated with `a`
|
|
const AliasInfo* aliasInfo = s.arguments().at(0).alias_info();
|
|
ASSERT_NE(aliasInfo, nullptr);
|
|
ASSERT_TRUE(
|
|
aliasInfo->beforeSets() ==
|
|
std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
|
|
ASSERT_TRUE(aliasInfo->isWrite());
|
|
|
|
// Check the contained types
|
|
ASSERT_TRUE(!aliasInfo->containedTypes().empty());
|
|
const auto& containedAliasInfo = aliasInfo->containedTypes()[0];
|
|
const auto expected = std::unordered_set<Symbol>{
|
|
Symbol::fromQualString("alias::b"),
|
|
Symbol::fromQualString("alias::c"),
|
|
};
|
|
ASSERT_TRUE(containedAliasInfo.beforeSets() == expected);
|
|
ASSERT_TRUE(containedAliasInfo.afterSets() == expected);
|
|
ASSERT_FALSE(containedAliasInfo.isWrite());
|
|
}
|
|
|
|
TEST(SchemaParserTest, BeforeAfterSets2) {
|
|
const auto s = parseSchema(
|
|
"at::what(Tensor(b -> b|c)[](a!) list, Tensor(c) element)"
|
|
" -> (Tensor(b|c)[](a!))");
|
|
|
|
// The list itself is annotated with `a`
|
|
const AliasInfo* aliasInfo = s.arguments().at(0).alias_info();
|
|
ASSERT_NE(aliasInfo, nullptr);
|
|
ASSERT_EQ(
|
|
aliasInfo->beforeSets(),
|
|
std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
|
|
ASSERT_EQ(
|
|
aliasInfo->afterSets(),
|
|
std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
|
|
ASSERT_TRUE(aliasInfo->isWrite());
|
|
ASSERT_EQ(aliasInfo->containedTypes().size(), 1);
|
|
|
|
// Check the contained types
|
|
ASSERT_TRUE(!aliasInfo->containedTypes().empty());
|
|
const auto& containedAliasInfo = aliasInfo->containedTypes()[0];
|
|
const auto expectedBefore = std::unordered_set<Symbol>{
|
|
Symbol::fromQualString("alias::b"),
|
|
};
|
|
const auto expectedAfter = std::unordered_set<Symbol>{
|
|
Symbol::fromQualString("alias::b"), Symbol::fromQualString("alias::c")};
|
|
ASSERT_TRUE(containedAliasInfo.beforeSets() == expectedBefore);
|
|
ASSERT_TRUE(containedAliasInfo.afterSets() == expectedAfter);
|
|
ASSERT_FALSE(containedAliasInfo.isWrite());
|
|
}
|
|
|
|
TEST(TopologicalIndexTest, Basic) {
|
|
Graph graph;
|
|
auto node1 = graph.create(prim::AutogradZero);
|
|
auto node2 = graph.create(prim::AutogradZero);
|
|
auto node3 = graph.create(prim::AutogradZero);
|
|
auto node4 = graph.create(prim::AutogradZero);
|
|
|
|
graph.appendNode(node4);
|
|
graph.prependNode(node1);
|
|
node2->insertAfter(node1);
|
|
node3->insertBefore(node4);
|
|
|
|
// nodes should be in numerical order
|
|
ASSERT_TRUE(node1->isBefore(node2));
|
|
ASSERT_TRUE(node1->isBefore(node3));
|
|
ASSERT_TRUE(node1->isBefore(node4));
|
|
ASSERT_TRUE(node2->isAfter(node1));
|
|
ASSERT_TRUE(node2->isBefore(node3));
|
|
ASSERT_TRUE(node2->isBefore(node4));
|
|
ASSERT_FALSE(node3->isBefore(node1));
|
|
ASSERT_FALSE(node3->isBefore(node2));
|
|
ASSERT_FALSE(node3->isAfter(node4));
|
|
|
|
// Built up a block structure
|
|
// node3
|
|
// /\ ...
|
|
// A B block1
|
|
// \ ...
|
|
// C block2
|
|
auto block1 = node3->addBlock();
|
|
auto A = graph.create(prim::AutogradZero);
|
|
block1->appendNode(A);
|
|
auto B = graph.create(prim::AutogradZero);
|
|
block1->appendNode(B);
|
|
auto block2 = B->addBlock();
|
|
auto C = graph.create(prim::AutogradZero);
|
|
block2->appendNode(C);
|
|
|
|
// Check isAfter on different block levels
|
|
ASSERT_TRUE(node1->isBefore(A));
|
|
ASSERT_TRUE(A->isBefore(B));
|
|
ASSERT_TRUE(A->isBefore(C));
|
|
|
|
// make sure things don't blow up on deletions
|
|
node2->destroy();
|
|
auto node2p = graph.create(prim::AutogradZero);
|
|
node2p->insertAfter(node1);
|
|
ASSERT_TRUE(node1->isBefore(node2p));
|
|
ASSERT_TRUE(node2p->isBefore(node3));
|
|
}
|
|
|
|
TEST(TopologicalIndexTest, Reindex) {
|
|
// Induce reindexing to test that path
|
|
Graph graph;
|
|
std::map<size_t, Node*> nodes;
|
|
|
|
auto anchor = graph.create(prim::AutogradZero);
|
|
graph.appendNode(anchor);
|
|
// Inserting to the same place a lot will trigger reindexing
|
|
for (auto i = 0; i < 100; ++i) {
|
|
auto n = graph.create(prim::AutogradZero);
|
|
n->insertAfter(anchor);
|
|
nodes[i] = n;
|
|
}
|
|
|
|
// Nodes should be in reverse order
|
|
for (auto i = 0; i < 100; ++i) {
|
|
for (auto j = i + 1; j < 100; ++j) {
|
|
ASSERT_TRUE(nodes[i]->isAfter(nodes[j]));
|
|
}
|
|
}
|
|
}
|
|
|
|
at::Tensor invokeTestRecordFunction(at::Tensor& t) {
|
|
RECORD_FUNCTION("test", std::vector<c10::IValue>({t}));
|
|
|
|
auto t2 = t.pow(2);
|
|
return t2;
|
|
}
|
|
|
|
static const auto invokeTestRecordFunction_JIT = R"JIT(
|
|
def foo(self, t):
|
|
t2 = t.pow(2)
|
|
return t2
|
|
|
|
def forward(self, t):
|
|
return self.foo(t)
|
|
)JIT";
|
|
|
|
at::Tensor invokeTestRecordFunctionJIT(at::Tensor& t) {
|
|
RECORD_FUNCTION("test", std::vector<c10::IValue>({t}));
|
|
|
|
auto module = std::make_shared<script::Module>(
|
|
"RecordFunctionTestModule", std::make_shared<script::CompilationUnit>());
|
|
module->define(invokeTestRecordFunction_JIT);
|
|
return module->forward({t}).toTensor();
|
|
}
|
|
|
|
using TracedTestValues =
|
|
std::vector<std::tuple<std::string, std::vector<std::vector<int64_t>>>>;
|
|
|
|
void checkTracedInputs(const TracedTestValues& inputs) {
|
|
bool found_test = false;
|
|
bool found_pow = false;
|
|
bool found_mul = false;
|
|
for (const auto& input : inputs) {
|
|
const auto& fn = std::get<0>(input);
|
|
const auto& sizes = std::get<1>(input);
|
|
|
|
if (fn == "test") {
|
|
found_test = true;
|
|
TORCH_CHECK(sizes.size() == 1);
|
|
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
|
|
} else if (fn == "aten::pow") {
|
|
found_pow = true;
|
|
TORCH_CHECK(sizes.size() == 2);
|
|
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
|
|
TORCH_CHECK(sizes[1].empty());
|
|
} else if (fn == "aten::mul") {
|
|
found_mul = true;
|
|
TORCH_CHECK(sizes.size() > 1);
|
|
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
|
|
}
|
|
}
|
|
TORCH_CHECK(found_test);
|
|
TORCH_CHECK(found_pow);
|
|
TORCH_CHECK(found_mul);
|
|
}
|
|
|
|
void checkTracedOutputs(const TracedTestValues& outputs) {
|
|
bool found_test = false;
|
|
bool found_pow = false;
|
|
bool found_mul = false;
|
|
for (const auto& output : outputs) {
|
|
const auto& fn = std::get<0>(output);
|
|
const auto& sizes = std::get<1>(output);
|
|
|
|
if (fn == "test") {
|
|
found_test = true;
|
|
TORCH_CHECK(sizes.empty());
|
|
} else if (fn == "aten::pow") {
|
|
found_pow = true;
|
|
TORCH_CHECK(sizes.size() == 1);
|
|
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
|
|
} else if (fn == "aten::mul") {
|
|
found_mul = true;
|
|
TORCH_CHECK(sizes.size() == 1);
|
|
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
|
|
}
|
|
}
|
|
TORCH_CHECK(found_test);
|
|
TORCH_CHECK(found_pow);
|
|
TORCH_CHECK(found_mul);
|
|
}
|
|
|
|
static bool bad_scope = false;
|
|
template <RecordScope scope, size_t* cnt>
|
|
std::unique_ptr<at::ObserverContext> checkScopeCallback(
|
|
const at::RecordFunction& fn) {
|
|
if (fn.scope() == scope) {
|
|
++(*cnt);
|
|
} else {
|
|
bad_scope = true;
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
template <RecordScope scope, size_t* cnt>
|
|
void pushScopedCallback() {
|
|
at::addGlobalCallback(
|
|
at::RecordFunctionCallback(checkScopeCallback<scope, cnt>)
|
|
.scopes({scope}));
|
|
}
|
|
|
|
// These cannot be function-local because that would prohibit them
|
|
// from being used as template arguments prior to C++17.
|
|
static size_t fun_cnt;
|
|
static size_t ts_fun_cnt;
|
|
static size_t user_scope_cnt;
|
|
|
|
void checkScopeCallbacks() {
|
|
static bool found_function_scope;
|
|
static bool found_method_scope;
|
|
static bool found_user_scope;
|
|
found_function_scope = false;
|
|
found_method_scope = false;
|
|
found_user_scope = false;
|
|
at::addGlobalCallback(at::RecordFunctionCallback(
|
|
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
|
if (fn.scope() == at::RecordScope::FUNCTION &&
|
|
std::string(fn.name()) == "test_function") {
|
|
found_function_scope = true;
|
|
}
|
|
if (fn.scope() == at::RecordScope::TORCHSCRIPT_FUNCTION &&
|
|
std::string(fn.name()) == "test_method") {
|
|
found_method_scope = true;
|
|
}
|
|
if (fn.scope() == at::RecordScope::USER_SCOPE &&
|
|
std::string(fn.name()) == "test_user_scope") {
|
|
found_user_scope = true;
|
|
}
|
|
return nullptr;
|
|
}));
|
|
|
|
bad_scope = false;
|
|
fun_cnt = 0;
|
|
pushScopedCallback<at::RecordScope::FUNCTION, &fun_cnt>();
|
|
ts_fun_cnt = 0;
|
|
pushScopedCallback<at::RecordScope::TORCHSCRIPT_FUNCTION, &ts_fun_cnt>();
|
|
user_scope_cnt = 0;
|
|
pushScopedCallback<at::RecordScope::USER_SCOPE, &user_scope_cnt>();
|
|
|
|
TORCH_CHECK(at::hasCallbacks());
|
|
|
|
{
|
|
RECORD_TORCHSCRIPT_FUNCTION("test_method", {});
|
|
{
|
|
RECORD_FUNCTION("test_function", {});
|
|
}
|
|
{
|
|
RECORD_USER_SCOPE("test_user_scope");
|
|
}
|
|
}
|
|
|
|
TORCH_CHECK(!bad_scope);
|
|
TORCH_CHECK(fun_cnt == 1);
|
|
TORCH_CHECK(ts_fun_cnt == 1);
|
|
TORCH_CHECK(user_scope_cnt == 1);
|
|
|
|
TORCH_CHECK(found_function_scope);
|
|
TORCH_CHECK(found_method_scope);
|
|
TORCH_CHECK(found_user_scope);
|
|
}
|
|
|
|
static TracedTestValues traced_inputs;
|
|
static TracedTestValues traced_outputs;
|
|
static std::unordered_set<std::string> ts_input_names;
|
|
static std::unordered_set<std::string> ts_output_names;
|
|
|
|
std::unique_ptr<at::ObserverContext> tracedInputsCallback(
|
|
const RecordFunction& fn) {
|
|
if (fn.scope() == RecordScope::FUNCTION) {
|
|
auto inputs = fn.inputs();
|
|
std::vector<std::vector<int64_t>> sizes;
|
|
for (const auto& input : inputs) {
|
|
if (input.isTensor()) {
|
|
sizes.push_back(input.toTensor().sizes().vec());
|
|
} else if (input.isScalar()) {
|
|
// NOLINTNEXTLINE(modernize-use-emplace)
|
|
sizes.push_back(std::vector<int64_t>());
|
|
}
|
|
}
|
|
traced_inputs.push_back(std::make_tuple(fn.name(), sizes));
|
|
} else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
|
|
ts_input_names.insert(fn.name());
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
void tracedOutputsCallback(const RecordFunction& fn, ObserverContext* ctx_ptr) {
|
|
if (fn.scope() == RecordScope::FUNCTION) {
|
|
auto outputs = fn.outputs();
|
|
std::vector<std::vector<int64_t>> sizes;
|
|
for (const auto& output : outputs) {
|
|
if (output.isTensor()) {
|
|
sizes.push_back(output.toTensor().sizes().vec());
|
|
} else if (output.isScalar()) {
|
|
sizes.emplace_back();
|
|
}
|
|
}
|
|
traced_outputs.push_back(std::make_tuple(fn.name(), sizes));
|
|
} else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
|
|
ts_output_names.insert(fn.name());
|
|
}
|
|
}
|
|
|
|
TEST(RecordFunctionTest, TracedTestInputsOutputs) {
|
|
// disabling the inlining of method calls
|
|
GraphOptimizerEnabledGuard opt_guard(false);
|
|
|
|
// [(fn, [[sizes], [sizes], ...]), ...]
|
|
addGlobalCallback(
|
|
RecordFunctionCallback(tracedInputsCallback, tracedOutputsCallback)
|
|
.needsInputs(true)
|
|
.needsOutputs(true));
|
|
|
|
TracedTestValues eager_inputs, eager_outputs, jit_inputs, jit_outputs;
|
|
{
|
|
auto t = torch::randn({1, 2, 3}, at::kCPU);
|
|
t.set_requires_grad(true);
|
|
auto t2 = invokeTestRecordFunction(t);
|
|
t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
|
|
eager_inputs = traced_inputs;
|
|
eager_outputs = traced_outputs;
|
|
traced_inputs.clear();
|
|
traced_outputs.clear();
|
|
|
|
TORCH_CHECK(ts_input_names.empty());
|
|
TORCH_CHECK(ts_output_names.empty());
|
|
|
|
t = torch::randn({1, 2, 3}, at::kCPU);
|
|
t.set_requires_grad(true);
|
|
t2 = invokeTestRecordFunctionJIT(t);
|
|
t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
|
|
jit_inputs = traced_inputs;
|
|
jit_outputs = traced_outputs;
|
|
traced_inputs.clear();
|
|
traced_outputs.clear();
|
|
}
|
|
|
|
TORCH_CHECK(ts_input_names.find("forward") != ts_input_names.end());
|
|
TORCH_CHECK(ts_input_names.find("foo") != ts_input_names.end());
|
|
TORCH_CHECK(ts_output_names.find("forward") != ts_output_names.end());
|
|
TORCH_CHECK(ts_output_names.find("foo") != ts_output_names.end());
|
|
|
|
checkTracedInputs(eager_inputs);
|
|
checkTracedOutputs(eager_outputs);
|
|
checkTracedInputs(jit_inputs);
|
|
checkTracedOutputs(jit_outputs);
|
|
at::clearCallbacks();
|
|
}
|
|
|
|
static int sampled_cb_ctr = 0;
|
|
std::unique_ptr<ObserverContext> sampledCallback(const RecordFunction& fn) {
|
|
if (std::string(fn.name()) == "test") {
|
|
++sampled_cb_ctr;
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
static int non_sampled_cb_ctr = 0;
|
|
std::unique_ptr<ObserverContext> nonSampledCallback(const RecordFunction& fn) {
|
|
if (std::string(fn.name()) == "test") {
|
|
++non_sampled_cb_ctr;
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
TEST(RecordFunctionTest, SampledCallbacks) {
|
|
// disabling the inlining of method calls
|
|
GraphOptimizerEnabledGuard opt_guard(false);
|
|
|
|
// test sampled callbacks
|
|
sampled_cb_ctr = 0;
|
|
auto setup_sampled_callback = [](double sampling_prob) {
|
|
return addGlobalCallback(
|
|
RecordFunctionCallback(sampledCallback).samplingProb(sampling_prob));
|
|
};
|
|
|
|
addGlobalCallback(RecordFunctionCallback(nonSampledCallback));
|
|
|
|
auto handle = setup_sampled_callback(0.5);
|
|
|
|
auto run_test_function = []() {
|
|
auto t = torch::randn({1, 2, 3}, at::kCPU);
|
|
for (auto k = 0; k < 1000; k++) {
|
|
invokeTestRecordFunction(t);
|
|
}
|
|
};
|
|
|
|
run_test_function();
|
|
TORCH_CHECK(non_sampled_cb_ctr == 1000);
|
|
TORCH_CHECK(sampled_cb_ctr > 0 && sampled_cb_ctr < 1000);
|
|
|
|
sampled_cb_ctr = 0;
|
|
removeCallback(handle);
|
|
handle = setup_sampled_callback(0.0);
|
|
run_test_function();
|
|
|
|
TORCH_CHECK(non_sampled_cb_ctr == 2000);
|
|
TORCH_CHECK(sampled_cb_ctr == 0);
|
|
|
|
sampled_cb_ctr = 0;
|
|
removeCallback(handle);
|
|
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
|
|
handle = setup_sampled_callback(1.0);
|
|
run_test_function();
|
|
|
|
TORCH_CHECK(non_sampled_cb_ctr == 3000);
|
|
TORCH_CHECK(sampled_cb_ctr == 1000);
|
|
clearCallbacks();
|
|
|
|
// test the scope of the callbacks
|
|
checkScopeCallbacks();
|
|
clearCallbacks();
|
|
}
|
|
|
|
TEST(RecordFunctionTest, RecordFunctionGuard) {
|
|
// disabling the inlining of method calls
|
|
GraphOptimizerEnabledGuard opt_guard(false);
|
|
|
|
static std::vector<std::string> fn_names;
|
|
static std::mutex guard_mtx;
|
|
|
|
// check record function guard
|
|
addGlobalCallback(RecordFunctionCallback(
|
|
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
|
std::lock_guard<std::mutex> lock(guard_mtx);
|
|
// NOLINTNEXTLINE(modernize-use-emplace)
|
|
fn_names.push_back(fn.name());
|
|
return nullptr;
|
|
}));
|
|
{
|
|
RecordFunctionGuard g1(false);
|
|
{
|
|
RECORD_USER_SCOPE("A");
|
|
{
|
|
RecordFunctionGuard g2(true);
|
|
RECORD_USER_SCOPE("B");
|
|
{
|
|
DisableRecordFunctionGuard g3;
|
|
RECORD_USER_SCOPE("C");
|
|
}
|
|
}
|
|
{
|
|
RECORD_USER_SCOPE("D");
|
|
}
|
|
}
|
|
}
|
|
TORCH_CHECK(fn_names.size() == 1);
|
|
TORCH_CHECK(fn_names[0] == "B");
|
|
clearCallbacks();
|
|
}
|
|
|
|
static std::vector<size_t> ids;
|
|
|
|
template <size_t id>
|
|
auto add_remove_test_add_cb() {
|
|
return addGlobalCallback(RecordFunctionCallback(
|
|
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
|
ids.push_back(id);
|
|
return nullptr;
|
|
}));
|
|
}
|
|
|
|
TEST(RecordFunctionTest, Callbacks) {
|
|
// disabling the inlining of method calls
|
|
GraphOptimizerEnabledGuard opt_guard(false);
|
|
|
|
auto h1 = add_remove_test_add_cb<1>();
|
|
add_remove_test_add_cb<2>();
|
|
auto h3 = add_remove_test_add_cb<3>();
|
|
|
|
{
|
|
RECORD_USER_SCOPE("test");
|
|
}
|
|
|
|
TORCH_CHECK(ids.size() == 3);
|
|
TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
|
|
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
|
|
TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end());
|
|
|
|
ids.clear();
|
|
removeCallback(h1);
|
|
|
|
{
|
|
RECORD_USER_SCOPE("test");
|
|
}
|
|
|
|
TORCH_CHECK(ids.size() == 2);
|
|
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
|
|
TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end());
|
|
|
|
ids.clear();
|
|
removeCallback(h3);
|
|
|
|
{
|
|
RECORD_USER_SCOPE("test");
|
|
}
|
|
|
|
TORCH_CHECK(ids.size() == 1);
|
|
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
|
|
|
|
clearCallbacks();
|
|
|
|
// thread local / global callbacks
|
|
|
|
ids.clear();
|
|
add_remove_test_add_cb<1>();
|
|
|
|
{
|
|
RECORD_USER_SCOPE("test");
|
|
}
|
|
|
|
TORCH_CHECK(ids.size() == 1);
|
|
TORCH_CHECK(ids[0] == 1);
|
|
ids.clear();
|
|
|
|
auto th = std::thread([]() {
|
|
addThreadLocalCallback(RecordFunctionCallback(
|
|
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
|
ids.push_back(2);
|
|
return nullptr;
|
|
}));
|
|
|
|
{
|
|
RECORD_USER_SCOPE("test_thread");
|
|
}
|
|
});
|
|
th.join();
|
|
TORCH_CHECK(ids.size() == 2);
|
|
TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
|
|
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
|
|
ids.clear();
|
|
|
|
{
|
|
RECORD_USER_SCOPE("test");
|
|
}
|
|
|
|
TORCH_CHECK(ids.size() == 1);
|
|
TORCH_CHECK(ids[0] == 1);
|
|
ids.clear();
|
|
|
|
clearCallbacks();
|
|
|
|
// START: thread local / global context check callbacks
|
|
struct TestContext : public ObserverContext {
|
|
int a{0};
|
|
std::string b;
|
|
};
|
|
ids.clear();
|
|
{ // START: global test
|
|
addGlobalCallback(RecordFunctionCallback(
|
|
[](const RecordFunction&
|
|
/* unused */) -> std::unique_ptr<at::ObserverContext> {
|
|
auto ctx = std::make_unique<TestContext>();
|
|
ctx->a = 123;
|
|
ctx->b = "test_str";
|
|
ids.push_back(1);
|
|
return ctx;
|
|
},
|
|
[](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
|
|
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
|
|
TORCH_CHECK(ctx != nullptr);
|
|
TORCH_CHECK(ctx->a == 123);
|
|
TORCH_CHECK(ctx->b == "test_str");
|
|
}));
|
|
|
|
{
|
|
RECORD_USER_SCOPE("test");
|
|
}
|
|
|
|
TORCH_CHECK(ids.size() == 1);
|
|
TORCH_CHECK(ids[0] == 1);
|
|
ids.clear();
|
|
} // END: global test
|
|
{ // START: thread local test
|
|
auto ctx_th = std::thread([]() {
|
|
const std::string test_str = "test thread str";
|
|
addThreadLocalCallback(RecordFunctionCallback(
|
|
[](const RecordFunction&
|
|
/* unused */) -> std::unique_ptr<at::ObserverContext> {
|
|
auto ctx = std::make_unique<TestContext>();
|
|
ctx->a = 234;
|
|
ctx->b = "test_thread_str";
|
|
ids.push_back(2);
|
|
return ctx;
|
|
},
|
|
[](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
|
|
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
|
|
TORCH_CHECK(ctx_ptr != nullptr);
|
|
TORCH_CHECK(ctx->a == 234);
|
|
TORCH_CHECK(ctx->b == "test_thread_str");
|
|
}));
|
|
|
|
// Will call both global and thread local callbacks.
|
|
{
|
|
RECORD_USER_SCOPE("test_thread");
|
|
}
|
|
});
|
|
ctx_th.join();
|
|
TORCH_CHECK(ids.size() == 2);
|
|
TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
|
|
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
|
|
ids.clear();
|
|
} // END: thread local test
|
|
|
|
clearCallbacks();
|
|
}
|
|
|
|
TEST(RecordFunctionTest, ShouldRun) {
|
|
// disabling the inlining of method calls
|
|
GraphOptimizerEnabledGuard opt_guard(false);
|
|
|
|
static bool ran = false;
|
|
auto handle = addGlobalCallback(RecordFunctionCallback(
|
|
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
|
ran = true;
|
|
return nullptr;
|
|
}));
|
|
|
|
{
|
|
RECORD_USER_SCOPE("test");
|
|
}
|
|
|
|
EXPECT_TRUE(ran) << "first run didn't happen";
|
|
ran = false;
|
|
|
|
disableCallback(handle);
|
|
|
|
{
|
|
RECORD_USER_SCOPE("test");
|
|
}
|
|
|
|
EXPECT_FALSE(ran) << "second run happened but shouldn't have";
|
|
ran = false;
|
|
|
|
reenableCallback(handle);
|
|
|
|
{
|
|
RECORD_USER_SCOPE("test");
|
|
}
|
|
|
|
EXPECT_TRUE(ran) << "run after re-enable didn't happen";
|
|
ran = false;
|
|
|
|
clearCallbacks();
|
|
}
|
|
|
|
TEST(RecordFunctionTest, Basic) {
|
|
// disabling the inlining of method calls
|
|
GraphOptimizerEnabledGuard opt_guard(false);
|
|
|
|
static std::string recorded_op;
|
|
static bool has_ids = false;
|
|
|
|
// test propagation of TLS callbacks
|
|
std::thread t([]() {
|
|
RecordFunctionGuard enable_rec_fn;
|
|
auto handle = addThreadLocalCallback(RecordFunctionCallback(
|
|
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
|
recorded_op = fn.name();
|
|
return nullptr;
|
|
}));
|
|
ThreadLocalState state;
|
|
std::thread t_child([state]() {
|
|
ThreadLocalStateGuard g_tls(state);
|
|
RECORD_USER_SCOPE("test_in_thread");
|
|
});
|
|
t_child.join();
|
|
EXPECT_EQ(recorded_op, "test_in_thread");
|
|
removeCallback(handle);
|
|
});
|
|
t.join();
|
|
clearCallbacks();
|
|
|
|
// test set ids
|
|
addGlobalCallback(
|
|
RecordFunctionCallback(
|
|
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
|
has_ids = fn.handle() > 0;
|
|
return nullptr;
|
|
})
|
|
.needsIds(true));
|
|
{
|
|
RECORD_USER_SCOPE("test");
|
|
}
|
|
TORCH_CHECK(has_ids);
|
|
clearCallbacks();
|
|
has_ids = false;
|
|
addGlobalCallback(RecordFunctionCallback(
|
|
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
|
has_ids = fn.handle() > 0;
|
|
return nullptr;
|
|
}));
|
|
{
|
|
RECORD_USER_SCOPE("test");
|
|
}
|
|
TORCH_CHECK(!has_ids);
|
|
clearCallbacks();
|
|
}
|
|
|
|
TEST(RecordFunctionTest, OperatorNameOverload) {
|
|
static std::set<std::string> operator_names;
|
|
at::addGlobalCallback(at::RecordFunctionCallback(
|
|
[](const at::RecordFunction& fn)
|
|
-> std::unique_ptr<at::ObserverContext> {
|
|
std::optional<c10::OperatorName> op_name =
|
|
fn.operator_name();
|
|
if (op_name.has_value()) {
|
|
operator_names.insert(c10::toString(*op_name));
|
|
} else {
|
|
operator_names.insert("No Operator Name");
|
|
}
|
|
return nullptr;
|
|
})
|
|
.scopes({at::RecordScope::FUNCTION}));
|
|
auto t = torch::randn({1, 2, 3}, at::kCPU);
|
|
t.set_requires_grad(false);
|
|
auto t2 = t.pow(2);
|
|
|
|
at::clearCallbacks();
|
|
EXPECT_TRUE(operator_names.count("No Operator Name") == 0)
|
|
<< "Expected that all traced operators had an associated OperatorName object";
|
|
EXPECT_TRUE(operator_names.count("aten::randn") == 1)
|
|
<< "Expected aten::randn to have been called and recorded, but it was not";
|
|
EXPECT_TRUE(operator_names.count("aten::pow.Tensor_Scalar") == 1)
|
|
<< "Expected aten::pow.Tensor_Scalar to have been called and recorded, but it was not";
|
|
}
|
|
|
|
class TestThreadLocalDebugInfo : public c10::DebugInfoBase {
|
|
public:
|
|
int getModelId() const {
|
|
return model_id_;
|
|
}
|
|
|
|
void setModelId(int model_id) {
|
|
model_id_ = model_id;
|
|
}
|
|
|
|
// NOLINTNEXTLINE(modernize-use-equals-default)
|
|
virtual ~TestThreadLocalDebugInfo() override {}
|
|
|
|
private:
|
|
int model_id_ = 0;
|
|
};
|
|
|
|
void checkDebugInfo(c10::DebugInfoKind kind, int model_id) {
|
|
auto* debug_info = c10::ThreadLocalDebugInfo::get(kind);
|
|
TORCH_CHECK(debug_info != nullptr);
|
|
auto* test_debug_info = dynamic_cast<TestThreadLocalDebugInfo*>(debug_info);
|
|
TORCH_CHECK(test_debug_info != nullptr);
|
|
TORCH_CHECK(test_debug_info->getModelId() == model_id);
|
|
}
|
|
|
|
TEST(ThreadLocalDebugInfoTest, Basic) {
|
|
static std::atomic<bool> done{false};
|
|
|
|
TORCH_CHECK(
|
|
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
|
|
auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
|
|
debug_info->setModelId(42);
|
|
{
|
|
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
|
|
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
|
|
}
|
|
|
|
// check that thread local debug info is propagated through fork calls
|
|
TORCH_CHECK(
|
|
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
|
|
{
|
|
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
|
|
at::launch([]() {
|
|
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
|
|
done = true;
|
|
});
|
|
}
|
|
while (!done) {
|
|
}
|
|
|
|
// check that thread local debug info is propagated through backward pass
|
|
TORCH_CHECK(
|
|
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
|
|
done = false;
|
|
auto handle = addGlobalCallback(RecordFunctionCallback(
|
|
[](const RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
|
|
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
|
|
done = true;
|
|
return nullptr;
|
|
}));
|
|
{
|
|
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
|
|
auto t = torch::randn({1, 2, 3}, at::kCPU);
|
|
t.set_requires_grad(true);
|
|
auto t2 = t.pow(2);
|
|
t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
|
|
}
|
|
removeCallback(handle);
|
|
TORCH_CHECK(done);
|
|
|
|
// check nested debug info
|
|
TORCH_CHECK(
|
|
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
|
|
{
|
|
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
|
|
{
|
|
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
|
|
{
|
|
auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
|
|
debug_info->setModelId(314);
|
|
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO_2, debug_info);
|
|
{
|
|
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
|
|
checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
|
|
done = false;
|
|
at::launch([]() {
|
|
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
|
|
checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
|
|
done = true;
|
|
});
|
|
while (!done) {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
TEST(TestSymIntArrayRef, BasicConversion) {
|
|
const size_t X = 2, Y = 4, Z = 5;
|
|
std::vector<int64_t> tgt_size_v{2, 4, 5};
|
|
std::vector<c10::SymInt> tgt_size({SymInt(X), SymInt(Y), SymInt(Z)});
|
|
auto a = at::randn({1, 4, 1}, at::kCPU);
|
|
auto b = a.expand_symint(tgt_size);
|
|
auto c = a.expand(tgt_size_v);
|
|
ASSERT_TRUE(torch::allclose(b, c));
|
|
}
|
|
|
|
TEST(TestSymInt, NarrowCopyWithSymbolicInt) {
|
|
static const size_t LENGTH = 5;
|
|
auto a = at::randn({10}, at::kCPU);
|
|
c10::SymInt si(LENGTH);
|
|
auto b = a.narrow_copy_symint(0, 0, si);
|
|
auto c = a.narrow(0, 0, LENGTH);
|
|
ASSERT_TRUE(torch::allclose(b, c));
|
|
}
|
|
|
|
TEST(TestSymInt, NarrowCopy) {
|
|
static const size_t LENGTH = 5;
|
|
auto a = at::randn({10}, at::kCPU);
|
|
auto b = a.narrow_copy(0, 0, LENGTH);
|
|
auto c = a.narrow(0, 0, LENGTH);
|
|
ASSERT_TRUE(torch::allclose(b, c));
|
|
}
|
|
|
|
TEST(TestSymInt, AddSymbolicInt) {
|
|
c10::SymInt a(5);
|
|
c10::SymInt b(3);
|
|
ASSERT_TRUE((a + b).expect_int() == 8);
|
|
}
|
|
|
|
TEST(FallbackGraphsTest, Basic) {
|
|
auto x = at::randn({1}, at::kCPU);
|
|
auto y = at::randn({1}, at::kCPU);
|
|
auto stack = createStack({x.clone(), y.clone()});
|
|
|
|
auto graph_string = R"IR(
|
|
graph(%0 : Float(1),
|
|
%1 : Float(1)):
|
|
%2 : Tensor = aten::mul(%0, %1)
|
|
%3 : Tensor = aten::mul(%2, %0)
|
|
return (%3))IR";
|
|
auto graph = std::make_shared<Graph>();
|
|
torch::jit::parseIR(graph_string, graph.get());
|
|
|
|
{
|
|
Code code(graph, "");
|
|
InterpreterState interpreter{code};
|
|
interpreter.run(stack);
|
|
}
|
|
at::Tensor et;
|
|
pop(stack, et);
|
|
float ef = et.item<float>();
|
|
{
|
|
EnableProfilingGuard epg;
|
|
GraphFunction f("fallbackGraphs", graph, nullptr);
|
|
for (size_t i = 0; i < getNumProfiledRuns() + 1; i++) {
|
|
stack.emplace_back(x.clone());
|
|
stack.emplace_back(y.clone());
|
|
if (i == getNumProfiledRuns()) {
|
|
// we will be modifying a profiled graph
|
|
// before ProfilingGraphExecutor
|
|
// will optimize it in the next iteration
|
|
auto opt_graph = lastExecutedOptimizedGraph();
|
|
// this is safe to do since we are done profiling
|
|
ProfilingRecord::removeProfileCounter(opt_graph->block());
|
|
replaceBlockWithFallbackGraph(opt_graph->block(), opt_graph->inputs());
|
|
auto it = opt_graph->block()->nodes().begin();
|
|
ASSERT_EQ(it->kind(), prim::FallbackGraph);
|
|
auto fallback = *it++;
|
|
ASSERT_EQ(it, opt_graph->block()->nodes().end());
|
|
ASSERT_TRUE(fallback->hasAttribute(attr::Subgraph));
|
|
testing::FileCheck()
|
|
.check("Tensor = aten::mul")
|
|
->check("Tensor = aten::mul")
|
|
->run(*fallback->g(attr::Subgraph));
|
|
}
|
|
f.run(stack);
|
|
at::Tensor at;
|
|
pop(stack, at);
|
|
float af = at.item<float>();
|
|
ASSERT_EQ(af, ef);
|
|
}
|
|
|
|
auto opt_graph = lastExecutedOptimizedGraph();
|
|
testing::FileCheck()
|
|
.check("(Tensor) = prim::CallFunction")
|
|
->run(*opt_graph);
|
|
}
|
|
}
|
|
|
|
// TODO this test wasn't running and is broken.
|
|
// TEST(AutogradProfilerTest, Basic) {
|
|
// constexpr int batch_size = 4;
|
|
// constexpr int input_size = 256;
|
|
// constexpr int seq_len = 32;
|
|
|
|
// int hidden_size = 2 * input_size;
|
|
// auto input = torch::randn({seq_len, batch_size, input_size}, at::kCPU);
|
|
// auto hx = torch::randn({batch_size, hidden_size}, at::kCPU);
|
|
// auto cx = torch::randn({batch_size, hidden_size}, at::kCPU);
|
|
// auto w_ih = t_def(torch::randn({4 * hidden_size, input_size}, at::kCPU));
|
|
// auto w_hh = t_def(torch::randn({4 * hidden_size, hidden_size}, at::kCPU));
|
|
|
|
// std::stringstream ss;
|
|
// {
|
|
// RecordProfile guard(ss);
|
|
// for (size_t i = 0; i < 100; ++i) {
|
|
// std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
|
|
// }
|
|
// }
|
|
|
|
// std::string result = ss.str();
|
|
// size_t count = 0;
|
|
// for (size_t pos = 0; (pos = result.find("tanh", pos)) != std::string::npos;
|
|
// count++, pos++) {
|
|
// }
|
|
// ASSERT_EQ((count, 200);
|
|
// }
|
|
|
|
TEST(NoneSchemaMatchTest, Basic) {
|
|
RegisterOperators reg({
|
|
Operator(
|
|
"prim::test_none() -> int?",
|
|
[](Stack& stack) { push(stack, IValue()); },
|
|
aliasAnalysisFromSchema()),
|
|
Operator(
|
|
"prim::is_none(int? a) -> bool",
|
|
[](Stack& stack) {
|
|
IValue a = pop(stack);
|
|
if (a.isNone()) {
|
|
push(stack, true);
|
|
} else {
|
|
push(stack, false);
|
|
}
|
|
},
|
|
aliasAnalysisFromSchema()),
|
|
});
|
|
|
|
// Constant propagation will run test_none and produce a None,
|
|
// testing that its type is set appropriately and schema matching doesn't
|
|
// fail when running is_none
|
|
|
|
auto r = std::make_shared<Graph>();
|
|
auto& g = *r;
|
|
auto opt_int = g.insert(Symbol::fromQualString("prim::test_none"), {});
|
|
auto out_bool = g.insert(Symbol::fromQualString("prim::is_none"), {opt_int});
|
|
g.registerOutput(out_bool);
|
|
ConstantPropagation(r);
|
|
|
|
auto nodes = r->block()->nodes();
|
|
// checking that constant propagation ran wo/failure
|
|
AT_ASSERT(std::distance(nodes.begin(), nodes.end()) == 1);
|
|
}
|
|
|
|
static int testPassValue = 0;
|
|
void fakePass(std::shared_ptr<Graph>& g) {
|
|
testPassValue++;
|
|
return;
|
|
}
|
|
|
|
RegisterPass p(fakePass);
|
|
|
|
TEST(PassManagementTest, Basic) {
|
|
std::shared_ptr<Graph> graph = std::make_shared<Graph>();
|
|
parseIR(
|
|
R"IR(
|
|
graph(%a):
|
|
return (%a))IR",
|
|
&*graph);
|
|
|
|
std::vector<IValue> stack = {IValue(torch::randn({22}, at::kCPU))};
|
|
auto run = [&](std::shared_ptr<Graph>& graph, std::vector<IValue> stack) {
|
|
GraphExecutor executor(graph, "");
|
|
executor.run(stack);
|
|
return stack;
|
|
};
|
|
run(graph, stack);
|
|
// we will not run fusion in simple mode
|
|
if (!getExecutorMode()) {
|
|
AT_ASSERT(testPassValue);
|
|
}
|
|
}
|
|
|
|
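// Helpers that assert a TensorType (or the type recorded on a profile node)
// carries the expected concrete sizes.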
static void checkShape(TypePtr typ, std::vector<int64_t> expected) {
|
|
auto ptp = typ->expect<TensorType>();
|
|
ASSERT_EQ(ptp->sizes().concrete_sizes().value(), expected);
|
|
}
|
|
|
|
static void checkShape(
|
|
Node* n,
|
|
std::vector<int64_t> expected,
|
|
bool prev = true) {
|
|
auto profile = (prev) ? n->inputs().at(0)->node() : n;
|
|
checkShape(profile->output()->type(), expected);
|
|
}
|
|
|
|
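// Recursively counts the nodes in a block (including nested blocks) that
// satisfy a predicate.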
void count_(
|
|
Block* block,
|
|
const std::function<bool(Node* n)>& pred,
|
|
size_t& count) {
|
|
for (Node* n : block->nodes()) {
|
|
if (pred(n)) {
|
|
count++;
|
|
}
|
|
|
|
for (Block* ib : n->blocks()) {
|
|
count_(ib, pred, count);
|
|
}
|
|
}
|
|
}
|
|
|
|
size_t countNodes(
|
|
const std::shared_ptr<Graph>& graph,
|
|
const std::function<bool(Node* n)>& pred) {
|
|
size_t count = 0;
|
|
count_(graph->block(), pred, count);
|
|
return count;
|
|
}
|
|
|
|
bool true_pred(Node* n) {
|
|
return true;
|
|
}
|
|
|
|
bool is_loop(Node* n) {
|
|
return n->kind() == prim::Loop;
|
|
}
|
|
|
|
TEST(LoopPeelerTest, NoInductionVariableUse) {
|
|
// do not use an induction variable explicitly
|
|
static const auto str_func_def = R"JIT(
|
|
def test_peel_n_times():
|
|
sum = 0
|
|
for i in range(10):
|
|
sum += 2
|
|
return sum
|
|
)JIT";
|
|
|
|
auto cu = compile(str_func_def);
|
|
auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
|
|
auto stack = createStack({});
|
|
// peeling loop once
|
|
{
|
|
LoopsPeeler peeler(true_pred, 1);
|
|
auto copy = f.graph()->copy();
|
|
peeler.run(copy);
|
|
int num_loops =
|
|
std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
|
|
ASSERT_EQ(num_loops, 2);
|
|
Code code(copy, "");
|
|
InterpreterState interpreter{code};
|
|
interpreter.run(stack);
|
|
ASSERT_EQ(stack.back().toInt(), 20);
|
|
}
|
|
|
|
// test peeling more than one iteration
|
|
{
|
|
LoopsPeeler peeler(true_pred, 3);
|
|
auto copy = f.graph()->copy();
|
|
peeler.run(copy);
|
|
int num_loops =
|
|
std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
|
|
ASSERT_EQ(num_loops, 2);
|
|
Code code(copy, "");
|
|
InterpreterState interpreter{code};
|
|
interpreter.run(stack);
|
|
ASSERT_EQ(stack.back().toInt(), 20);
|
|
}
|
|
}
|
|
|
|
TEST(LoopPeelerTest, YesInductionVariableUse) {
|
|
// uses the induction variable
|
|
static const auto str_func_def = R"JIT(
|
|
def test_peel_n_times():
|
|
sum = 0
|
|
for i in range(10):
|
|
sum += i
|
|
return sum
|
|
)JIT";
|
|
|
|
auto cu = compile(str_func_def);
|
|
auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
|
|
auto stack = createStack({});
|
|
// peeling loop once
|
|
{
|
|
LoopsPeeler peeler(true_pred, 1);
|
|
auto copy = f.graph()->copy();
|
|
peeler.run(copy);
|
|
int num_loops =
|
|
std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
|
|
ASSERT_EQ(num_loops, 2);
|
|
Code code(copy, "");
|
|
InterpreterState interpreter{code};
|
|
interpreter.run(stack);
|
|
ASSERT_EQ(stack.back().toInt(), 45);
|
|
}
|
|
|
|
// test peeling more than one iteration
|
|
{
|
|
LoopsPeeler peeler(true_pred, 3);
|
|
auto copy = f.graph()->copy();
|
|
peeler.run(copy);
|
|
int num_loops =
|
|
std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
|
|
ASSERT_EQ(num_loops, 2);
|
|
Code code(copy, "");
|
|
InterpreterState interpreter{code};
|
|
interpreter.run(stack);
|
|
ASSERT_EQ(stack.back().toInt(), 45);
|
|
}
|
|
}
|
|
|
|
TEST(LoopPeelerTest, LoopWithTerminationCondition) {
|
|
// tests with explicit termination conditions
|
|
static const auto str_func_def = R"JIT(
|
|
def test_with_cond_times():
|
|
sum = 0
|
|
i = 0
|
|
while (sum < 2):
|
|
sum += i
|
|
i += 1
|
|
return sum
|
|
)JIT";
|
|
|
|
// the peel changes the termination condition to false
|
|
// so the original loop doesn't run
|
|
auto cu = compile(str_func_def);
|
|
auto& f = toGraphFunction(cu->get_function("test_with_cond_times"));
|
|
auto stack = createStack({});
|
|
// peeling 5 iterations should update the termination
|
|
// condition to false
|
|
{
|
|
LoopsPeeler peeler(true_pred, 5);
|
|
auto copy = f.graph()->copy();
|
|
peeler.run(copy);
|
|
int num_loops =
|
|
std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
|
|
ASSERT_EQ(num_loops, 2);
|
|
Code code(copy, "");
|
|
InterpreterState interpreter{code};
|
|
interpreter.run(stack);
|
|
ASSERT_EQ(stack.back().toInt(), 3);
|
|
}
|
|
|
|
// the termination condition remains true
|
|
{
|
|
LoopsPeeler peeler(true_pred, 1);
|
|
auto copy = f.graph()->copy();
|
|
peeler.run(copy);
|
|
int num_loops =
|
|
std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
|
|
ASSERT_EQ(num_loops, 2);
|
|
Code code(copy, "");
|
|
InterpreterState interpreter{code};
|
|
interpreter.run(stack);
|
|
ASSERT_EQ(stack.back().toInt(), 3);
|
|
}
|
|
}
|
|
|
|
// tests simple nested loops
|
|
TEST(LoopPeelerTest, SimpleNestedLoops) {
|
|
static const auto str_func_def = R"JIT(
|
|
def test_nested_loops():
|
|
sum = 0
|
|
i = 0
|
|
for i in range(10):
|
|
for j in range(10):
|
|
sum += i + j
|
|
return sum
|
|
)JIT";
|
|
|
|
auto cu = compile(str_func_def);
|
|
auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
|
|
auto stack = createStack({});
|
|
|
|
{
|
|
LoopsPeeler peeler(true_pred, 1);
|
|
auto copy = f.graph()->copy();
|
|
peeler.run(copy);
|
|
ASSERT_EQ(countNodes(copy, is_loop), 5);
|
|
Code code(copy, "");
|
|
InterpreterState interpreter{code};
|
|
interpreter.run(stack);
|
|
ASSERT_EQ(stack.back().toInt(), 900);
|
|
}
|
|
|
|
{
|
|
LoopsPeeler peeler(true_pred, 5);
|
|
auto copy = f.graph()->copy();
|
|
peeler.run(copy);
|
|
ASSERT_EQ(countNodes(copy, is_loop), 5);
|
|
Code code(copy, "");
|
|
InterpreterState interpreter{code};
|
|
interpreter.run(stack);
|
|
ASSERT_EQ(stack.back().toInt(), 900);
|
|
}
|
|
}
|
|
|
|
TEST(LoopPeelerTest, SimpleNestedLoops2) {
|
|
static const auto str_func_def = R"JIT(
|
|
def test_nested_loops():
|
|
sum = 0
|
|
i = 0
|
|
for i in range(10):
|
|
j = 0
|
|
while sum < 2:
|
|
sum += i + j
|
|
j += 1
|
|
return sum
|
|
)JIT";
|
|
|
|
auto cu = compile(str_func_def);
|
|
auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
|
|
auto stack = createStack({});
|
|
{
|
|
LoopsPeeler peeler(true_pred, 1);
|
|
auto copy = f.graph()->copy();
|
|
peeler.run(copy);
|
|
ASSERT_EQ(countNodes(copy, is_loop), 5);
|
|
Code code(copy, "");
|
|
InterpreterState interpreter{code};
|
|
interpreter.run(stack);
|
|
ASSERT_EQ(stack.back().toInt(), 3);
|
|
}
|
|
|
|
{
|
|
LoopsPeeler peeler(true_pred, 5);
|
|
auto copy = f.graph()->copy();
|
|
peeler.run(copy);
|
|
ASSERT_EQ(countNodes(copy, is_loop), 5);
|
|
Code code(copy, "");
|
|
InterpreterState interpreter{code};
|
|
interpreter.run(stack);
|
|
ASSERT_EQ(stack.back().toInt(), 3);
|
|
}
|
|
}
|
|
|
|
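// Traces the LSTM cell graph with concrete inputs and checks that the traced
// graph keeps the input types and computes the same result as the original.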
TEST(JitTracing, Basic) {
|
|
constexpr int batch_size = 4;
|
|
constexpr int input_size = 256;
|
|
|
|
int hidden_size = 2 * input_size;
|
|
|
|
auto input = at::randn({batch_size, input_size}, at::kCPU);
|
|
auto hx = at::randn({batch_size, hidden_size}, at::kCPU);
|
|
auto cx = at::randn({batch_size, hidden_size}, at::kCPU);
|
|
auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU));
|
|
auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU));
|
|
|
|
auto graph = build_lstm();
|
|
auto stack = createStack({input, hx, cx, w_ih, w_hh});
|
|
auto traced = TraceGraph(graph, stack);
|
|
|
|
// Check that the inputs of traced graph have the same type as the inputs
|
|
// specified here.
|
|
ASSERT_EQ(*traced->inputs().at(0)->type(), *TensorType::create(input));
|
|
ASSERT_EQ(*traced->inputs().at(1)->type(), *TensorType::create(hx));
|
|
ASSERT_EQ(*traced->inputs().at(2)->type(), *TensorType::create(cx));
|
|
ASSERT_EQ(*traced->inputs().at(3)->type(), *TensorType::create(w_ih));
|
|
ASSERT_EQ(*traced->inputs().at(4)->type(), *TensorType::create(w_hh));
|
|
|
|
Tensor prof_out;
|
|
pop(stack, prof_out);
|
|
|
|
{
|
|
stack = createStack({input, hx, cx, w_ih, w_hh});
|
|
Code cd(traced, "traced");
|
|
InterpreterState is{cd};
|
|
is.run(stack);
|
|
Tensor traced_out;
|
|
pop(stack, traced_out);
|
|
ASSERT_TRUE(torch::allclose(prof_out, traced_out));
|
|
}
|
|
|
|
{
|
|
stack = createStack({input, hx, cx, w_ih, w_hh});
|
|
Code cd(graph, "graph");
|
|
InterpreterState is{cd};
|
|
is.run(stack);
|
|
Tensor scripted_out;
|
|
pop(stack, scripted_out);
|
|
ASSERT_TRUE(torch::allclose(prof_out, scripted_out));
|
|
}
|
|
}
|
|
|
|
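// Profiles a small function, inserts prim::Guard nodes from the profiled
// types, and checks that EliminateRedundantGuards keeps only the guards on
// the inputs' definitions.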
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
|
TEST(InsertAndEliminateRedundantGuardsTest, Basic) {
|
|
static const auto basic_example = R"JIT(
|
|
def basic(x, y):
|
|
a = x + y
|
|
b = x * y
|
|
c = x + 1
|
|
d = a - c
|
|
e = b - c
|
|
return d + e
|
|
)JIT";
|
|
|
|
auto cu = compile(basic_example);
|
|
auto& fun = toGraphFunction(cu->get_function("basic"));
|
|
auto pr = ProfilingRecord::instrumentGraph(fun.graph());
|
|
auto x = at::randn({2, 3}, at::kCPU);
|
|
auto y = at::randn({2, 3}, at::kCPU);
|
|
auto stack = createStack({x, y});
|
|
// introduce some profiling information
|
|
Code cd(pr->profiled_graph_, "");
|
|
InterpreterState is{cd};
|
|
is.run(stack);
|
|
auto copy = pr->profiled_graph_->copy();
|
|
ProfilingRecord::removeProfileCounter(copy->block());
|
|
InsertGuards(copy);
|
|
auto nodes = copy->block()->nodes();
|
|
auto guard = std::find_if(nodes.begin(), nodes.end(), [](Node* n) {
|
|
return n->kind() == prim::Guard;
|
|
});
|
|
ASSERT_NE(guard, nodes.end());
|
|
ASSERT_EQ(
|
|
guard->input()->type()->expectRef<TensorType>().sizes().size(),
|
|
std::nullopt);
|
|
checkShape(*guard, {2, 3}, false);
|
|
auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
|
|
int num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
|
|
ASSERT_EQ(num_guards, 12);
|
|
// now eliminate as many guards as possible
|
|
// we should be left with two guards on x and y's defs
|
|
EliminateRedundantGuards(copy);
|
|
num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
|
|
ASSERT_EQ(num_guards, 2);
|
|
}
|
|
|
|
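// After guard insertion and elimination, InsertBailOuts should turn each
// remaining guard into a prim::BailOut whose first input is the bailout
// template node.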
TEST(InsertBailOutsTest, Basic) {
|
|
static const auto basic_example = R"JIT(
|
|
def basic_loop(x, y):
|
|
|
|
a = x + 1
|
|
b = y + 2
|
|
c = x + y + 3
|
|
|
|
for i in range(10):
|
|
a = a + b
|
|
# invariant
|
|
d = b * c
|
|
#
|
|
a = a - d
|
|
|
|
e = a + 4
|
|
return e
|
|
)JIT";
|
|
|
|
auto cu = compile(basic_example);
|
|
auto& fun = toGraphFunction(cu->get_function("basic_loop"));
|
|
auto pr = ProfilingRecord::instrumentGraph(fun.graph());
|
|
auto x = at::randn({2, 3}, at::kCPU);
|
|
auto y = at::randn({2, 3}, at::kCPU);
|
|
auto stack = createStack({x, y});
|
|
// introduce some profiling information
|
|
Code cd(pr->profiled_graph_, "");
|
|
InterpreterState is{cd};
|
|
is.run(stack);
|
|
auto copy = pr->profiled_graph_->copy();
|
|
ProfilingRecord::removeProfileCounter(copy->block());
|
|
InsertGuards(copy);
|
|
EliminateRedundantGuards(copy);
|
|
auto nodes = copy->block()->nodes();
|
|
auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
|
|
auto num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
|
|
ASSERT_EQ(num_guards, 3);
|
|
InsertBailOuts(copy);
|
|
auto is_bailout = [](Node* n) { return n->kind() == prim::BailOut; };
|
|
auto num_bailouts = std::count_if(nodes.begin(), nodes.end(), is_bailout);
|
|
ASSERT_EQ(num_guards, num_bailouts);
|
|
std::vector<Node*> bailouts(num_bailouts);
|
|
std::copy_if(nodes.begin(), nodes.end(), bailouts.begin(), is_bailout);
|
|
|
|
for (auto blo : bailouts) {
|
|
ASSERT_EQ(blo->inputs().at(0)->node()->kind(), prim::BailoutTemplate);
|
|
}
|
|
}
|
|
|
|
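// Instruments an LSTM graph with the profiler and checks the shapes recorded
// on the prim::profile nodes feeding the add, mul, and tanh ops.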
TEST(ProfilerTest, Basic) {
|
|
constexpr int batch_size = 4;
|
|
constexpr int input_size = 256;
|
|
|
|
int hidden_size = 2 * input_size;
|
|
|
|
auto input = at::randn({batch_size, input_size}, at::kCPU);
|
|
auto hx = at::randn({batch_size, hidden_size}, at::kCPU);
|
|
auto cx = at::randn({batch_size, hidden_size}, at::kCPU);
|
|
auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU));
|
|
auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU));
|
|
|
|
auto g = build_lstm();
|
|
auto stack = createStack({input, hx, cx, w_ih, w_hh});
|
|
|
|
auto& opt_graph = *g.get();
|
|
ArgumentSpecCreator arg_spec_creator(opt_graph);
|
|
ArgumentSpec spec =
|
|
arg_spec_creator.create(autograd::GradMode::is_enabled(), stack);
|
|
arg_spec_creator.specializeTypes(opt_graph, spec);
|
|
auto pr = ProfilingRecord::instrumentGraph(g);
|
|
Code cd(pr->profiled_graph_, "");
|
|
InterpreterState is{cd};
|
|
is.run(stack);
|
|
|
|
// profiled types are stored as attributes and show up in the dump, e.g.
|
|
// Tensor = prim::profile[profiled_type=Double(4, 256, strides=[256, 1],
|
|
// requires_grad=0, device=cpu)]
|
|
testing::FileCheck()
|
|
.check("Tensor = prim::profile[profiled_type")
|
|
->check_same("256")
|
|
->run(*pr->profiled_graph_);
|
|
|
|
auto begin = pr->profiled_graph_->block()->nodes().begin();
|
|
auto end = pr->profiled_graph_->block()->nodes().end();
|
|
auto mm =
|
|
std::find_if(begin, end, [](Node* n) { return n->kind() == aten::add; });
|
|
ASSERT_NE(mm, end);
|
|
std::vector<int64_t> mm_expected{4, 2048};
|
|
std::vector<int64_t> eltwise{4, 512};
|
|
checkShape(mm->inputs().at(0)->node()->ty(attr::profiled_type), mm_expected);
|
|
auto mul_n =
|
|
std::find_if(begin, end, [](Node* n) { return n->kind() == aten::mul; });
|
|
ASSERT_NE(mul_n, end);
|
|
checkShape(mul_n->inputs().at(0)->node()->ty(attr::profiled_type), eltwise);
|
|
auto tanh_n =
|
|
std::find_if(begin, end, [](Node* n) { return n->kind() == aten::tanh; });
|
|
checkShape(tanh_n->inputs().at(0)->node()->ty(attr::profiled_type), eltwise);
|
|
}
|
|
|
|
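// Checks that optional Tensor inputs are profiled: the shape of a real bias
// is recorded, and the seen_none attribute reflects whether a None was
// observed on a later run.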
TEST(ProfilerTest, OptionalProfiling) {
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph(%inp : Tensor,
|
|
%weight : Tensor,
|
|
%bias : Tensor?):
|
|
%1 : Tensor = aten::linear(%inp, %weight, %bias)
|
|
return (%1))IR",
|
|
&*graph,
|
|
vmap);
|
|
|
|
auto pr = ProfilingRecord::instrumentGraph(graph);
|
|
pr->profiling_count_ = 2;
|
|
|
|
auto input = torch::randn({1, 2});
|
|
auto weight = torch::randn({2, 2});
|
|
auto bias = torch::randn({1, 2});
|
|
|
|
auto stack = createStack({input, weight, bias});
|
|
Code cd(pr->profiled_graph_, "");
|
|
InterpreterState is{cd};
|
|
is.run(stack);
|
|
|
|
testing::FileCheck()
|
|
.check_count("Tensor? = prim::profile[profiled_type", 1, true)
|
|
->run(*pr->profiled_graph_);
|
|
|
|
// make sure we recorded the shape
|
|
auto begin = pr->profiled_graph_->block()->nodes().begin();
|
|
auto end = pr->profiled_graph_->block()->nodes().end();
|
|
auto linear = std::find_if(
|
|
begin, end, [](Node* n) { return n->kind() == aten::linear; });
|
|
ASSERT_NE(linear, end);
|
|
std::vector<int64_t> bias_expected_shape = {1, 2};
|
|
auto profiled_bias = linear->namedInput("bias")->node();
|
|
checkShape(profiled_bias->ty(attr::profiled_type), bias_expected_shape);
|
|
ASSERT_EQ(0, profiled_bias->i(attr::seen_none));
|
|
|
|
auto none_bias = c10::IValue();
|
|
|
|
stack.clear();
|
|
stack.emplace_back(input);
|
|
stack.emplace_back(weight);
|
|
stack.emplace_back(none_bias);
|
|
is = InterpreterState{cd};
|
|
is.run(stack);
|
|
|
|
// make sure we recorded that "None" was seen.
|
|
begin = pr->profiled_graph_->block()->nodes().begin();
|
|
end = pr->profiled_graph_->block()->nodes().end();
|
|
linear = std::find_if(
|
|
begin, end, [](Node* n) { return n->kind() == aten::linear; });
|
|
ASSERT_NE(linear, end);
|
|
profiled_bias = linear->namedInput("bias")->node();
|
|
checkShape(profiled_bias->ty(attr::profiled_type), bias_expected_shape);
|
|
ASSERT_EQ(1, profiled_bias->i(attr::seen_none));
|
|
}
|
|
|
|
TEST(CallStackTest, Basic) {
|
|
const auto text = R"(
|
|
def ham(x):
|
|
return x/7
|
|
|
|
def bar(x):
|
|
return x*3
|
|
|
|
def baz(x):
|
|
return ham(x)*x
|
|
|
|
def foo(x):
|
|
return bar(x)*baz(x)*11
|
|
)";
|
|
auto cu = compile(text);
|
|
const auto& foo = toGraphFunction(cu->get_function("foo"));
|
|
for (Node* n : foo.optimized_graph()->nodes()) {
|
|
if (n->kind() == prim::Constant) {
|
|
if (!n->hasAttribute(attr::value) ||
|
|
n->kindOf(attr::value) != AttributeKind::i) {
|
|
continue;
|
|
}
|
|
int v = n->i(attr::value);
|
|
switch (v) {
|
|
case 3: {
|
|
// Const 3 comes from function 'bar', which gets inlined to 'foo'.
|
|
// The callstack for the corresponding node should contain only the
|
|
// function 'bar'.
|
|
ASSERT_TRUE(n->callstack());
|
|
auto callstack_vector = (*n->callstack())->vec();
|
|
ASSERT_EQ(callstack_vector.size(), 1);
|
|
ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("bar"));
|
|
break;
|
|
}
|
|
case 7: {
|
|
// Const 7 comes from function 'ham', which gets inlined to 'baz',
|
|
// which is then inlined to 'foo'. The callstack for the corresponding
|
|
// node should contain these two functions.
|
|
ASSERT_TRUE(n->callstack());
|
|
auto callstack_vector = (*n->callstack())->vec();
|
|
ASSERT_EQ(callstack_vector.size(), 2);
|
|
ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("baz"));
|
|
ASSERT_EQ(std::get<0>(callstack_vector[1]), &cu->get_function("ham"));
|
|
break;
|
|
}
|
|
case 11: {
|
|
// Const 11 comes from function 'foo', which is not inlined anywhere
|
|
// and thus it should not have a callstack.
|
|
ASSERT_FALSE(n->callstack());
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check that inlining doesn't corrupt callstack of the callee's nodes.
|
|
const auto& baz = toGraphFunction(cu->get_function("baz"));
|
|
for (Node* n : baz.optimized_graph()->nodes()) {
|
|
if (n->kind() == prim::Constant) {
|
|
if (!n->hasAttribute(attr::value) ||
|
|
n->kindOf(attr::value) != AttributeKind::i) {
|
|
continue;
|
|
}
|
|
int v = n->i(attr::value);
|
|
ASSERT_TRUE(v == 7);
|
|
// Const 7 comes from function 'ham', which gets inlined to 'baz'. 'baz'
|
|
// was also inlined into 'foo', but when looking at the graph of 'baz' we
|
|
// should only see a callstack of depth 1 (containing only 'ham').
|
|
ASSERT_TRUE(n->callstack());
|
|
auto callstack_vector = (*n->callstack())->vec();
|
|
ASSERT_EQ(callstack_vector.size(), 1);
|
|
ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("ham"));
|
|
}
|
|
}
|
|
}
|
|
|
|
TEST(CallStackTest, Caching) {
|
|
const auto text = R"(
|
|
|
|
def a(x):
|
|
print("a1")
|
|
print("a2")
|
|
return x
|
|
|
|
def b(x):
|
|
print("b1")
|
|
print("b2")
|
|
a(x)
|
|
return x
|
|
|
|
def c(x):
|
|
print("c1")
|
|
print("c2")
|
|
b(x)
|
|
return x
|
|
)";
|
|
auto cu = compile(text);
|
|
const auto& baz = toGraphFunction(cu->get_function("c"));
|
|
std::unordered_map<std::string, InlinedCallStack*> callstack_objects;
|
|
for (Node* n : baz.optimized_graph()->nodes()) {
|
|
if (n->kind() == prim::Constant) {
|
|
if (!n->hasAttribute(attr::value) ||
|
|
n->kindOf(attr::value) != AttributeKind::s) {
|
|
continue;
|
|
}
|
|
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
|
|
std::string v = n->s(attr::value);
|
|
if (n->callstack()) {
|
|
callstack_objects[v] = n->callstack()->get();
|
|
}
|
|
}
|
|
}
|
|
// We expect to see nodes prim::Constant[value="a1"] and
|
|
// prim::Constant[value="a2"] inlined to function 'c'. Their callstacks are
|
|
// the same (a->b->c), so we want to make sure we're not creating different
|
|
// callstack entries for them.
|
|
ASSERT_TRUE(callstack_objects.count("a1") && callstack_objects.count("a2"));
|
|
ASSERT_TRUE(callstack_objects.at("a1") == callstack_objects.at("a2"));
|
|
}
|
|
|
|
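// Checks that nodes inlined into if-blocks keep inlined call stacks and
// source ranges pointing back to the original module code.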
TEST(InlinedCallStackTest, BlockAnnotation) {
|
|
Module a("A");
|
|
a.define(R"(
|
|
def forward(self, x, y, z: int):
|
|
if (z == 1):
|
|
return x + y
|
|
else:
|
|
return x * y
|
|
)");
|
|
Module b("B");
|
|
b.define(R"(
|
|
def forward(self, x):
|
|
return x + 2
|
|
)");
|
|
Module c("C");
|
|
c.register_module("A0", a);
|
|
c.register_module("B0", b);
|
|
c.define(R"(
|
|
def forward(self, x, y, z: int):
|
|
return self.A0.forward(x, y, z) + self.B0.forward(x)
|
|
)");
|
|
|
|
auto graph =
|
|
toGraphFunction(c.get_method("forward").function()).optimized_graph();
|
|
std::stringstream add_ss, mul_ss;
|
|
for (Node* n : graph->nodes()) {
|
|
if (n->kind() == prim::If) {
|
|
for (Block* block : n->blocks()) {
|
|
for (Node* if_node : block->nodes()) {
|
|
if (if_node->kind() == aten::add) {
|
|
for (const auto& e : if_node->callstack().value()->vec()) {
|
|
add_ss << std::get<1>(e);
|
|
}
|
|
add_ss << if_node->sourceRange();
|
|
}
|
|
if (if_node->kind() == aten::mul) {
|
|
for (const auto& e : if_node->callstack().value()->vec()) {
|
|
mul_ss << std::get<1>(e);
|
|
}
|
|
mul_ss << if_node->sourceRange();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
ASSERT_NE(add_ss.str().find("line 3"), std::string::npos);
|
|
ASSERT_NE(add_ss.str().find("line 4"), std::string::npos);
|
|
ASSERT_NE(
|
|
add_ss.str().find("return self.A0.forward(x, y, z)"), std::string::npos);
|
|
ASSERT_NE(add_ss.str().find("return x + y"), std::string::npos);
|
|
ASSERT_NE(mul_ss.str().find("line 3"), std::string::npos);
|
|
ASSERT_NE(mul_ss.str().find("line 6"), std::string::npos);
|
|
ASSERT_NE(
|
|
mul_ss.str().find("return self.A0.forward(x, y, z)"), std::string::npos);
|
|
ASSERT_NE(mul_ss.str().find("return x * y"), std::string::npos);
|
|
}
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
|
TEST(InlinedCallStackTest, SelfCallMethods) {
|
|
Module a("A");
|
|
a.define(R"(
|
|
def my_new_method(self, x):
|
|
return x * 3
|
|
def forward_impl_(self, x, y):
|
|
return self.my_new_method(x) + y
|
|
def forward(self, x, y):
|
|
y = y + 2
|
|
return self.forward_impl_(x, y)
|
|
)");
|
|
Module b("B");
|
|
b.define(R"(
|
|
def forward(self, x):
|
|
return x + 2
|
|
)");
|
|
Module c("C");
|
|
c.register_module("A0", a);
|
|
c.register_module("B0", b);
|
|
c.define(R"(
|
|
def call_b(self, x):
|
|
return self.B0.forward(x)
|
|
def forward(self, x, y):
|
|
return self.A0.forward(x, y) + self.call_b(x)
|
|
)");
|
|
|
|
auto graph =
|
|
toGraphFunction(c.get_method("forward").function()).optimized_graph();
|
|
std::unordered_map<std::string, size_t> module_hierarchies;
|
|
for (Node* n : graph->nodes()) {
|
|
auto hierarchy = torch::jit::utils::getNodesModuleHierarchy(*n);
|
|
if (module_hierarchies.count(hierarchy) == 0) {
|
|
module_hierarchies[hierarchy] = 0;
|
|
}
|
|
module_hierarchies[hierarchy] += 1;
|
|
}
|
|
ASSERT_EQ(module_hierarchies["A0(A)"], 2);
|
|
ASSERT_EQ(module_hierarchies["A0(A).SELF(A).SELF(A)"], 2);
|
|
ASSERT_EQ(module_hierarchies["A0(A).SELF(A)"], 1);
|
|
ASSERT_EQ(module_hierarchies["SELF(C)"], 1);
|
|
ASSERT_EQ(module_hierarchies["SELF(C).B0(B)"], 1);
|
|
}
|
|
|
|
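// canRunWithAutograd should accept aten and prim symbols but reject fusion
// groups and symbols from custom namespaces.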
TEST(AutogradSymbolsTest, Basic) {
|
|
Symbol sym = Symbol::fromQualString("aten::test_symbol");
|
|
Graph graph;
|
|
auto node = graph.create(sym);
|
|
TORCH_CHECK(canRunWithAutograd(node));
|
|
|
|
sym = Symbol::fromQualString("prim::test_symbol");
|
|
node = graph.create(sym);
|
|
TORCH_CHECK(canRunWithAutograd(node));
|
|
|
|
sym = Symbol::fromQualString("prim::FusionGroup");
|
|
node = graph.create(sym);
|
|
TORCH_CHECK(!canRunWithAutograd(node));
|
|
|
|
sym = Symbol::fromQualString("custom::test_symbol");
|
|
node = graph.create(sym);
|
|
TORCH_CHECK(!canRunWithAutograd(node));
|
|
}
|
|
|
|
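// Compiling a default argument without a type hint should throw; the
// annotated version should compile.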
TEST(DefaultArgTypeHintingTest, Basic) {
|
|
const auto text_non_hinted = R"(
|
|
|
|
def a(x, y=1):
|
|
print("a1")
|
|
print("a2")
|
|
return x
|
|
)";
|
|
|
|
const auto text_hinted = R"(
|
|
|
|
def a(x, y:int=1):
|
|
print("a1")
|
|
print("a2")
|
|
return x
|
|
)";
|
|
|
|
try {
|
|
compile(text_non_hinted);
|
|
ASSERT_TRUE(0);
|
|
} catch (const std::exception&) {
|
|
}
|
|
|
|
auto cu = compile(text_hinted);
|
|
}
|
|
|
|
// Basic set case.
|
|
TEST(FuturesTest, Basic) {
|
|
auto f1 = c10::make_intrusive<Future>(IntType::get());
|
|
ASSERT_FALSE(f1->completed());
|
|
ASSERT_FALSE(f1->hasValue());
|
|
int32_t sat1 = 0;
|
|
int32_t sat2 = 0;
|
|
f1->addCallback([&](Future& /* unused */) { ++sat1; });
|
|
f1->markCompleted(43);
|
|
ASSERT_TRUE(f1->completed());
|
|
ASSERT_TRUE(f1->hasValue());
|
|
ASSERT_FALSE(f1->hasError());
|
|
ASSERT_EQ(sat1, 1);
|
|
ASSERT_EQ(f1->constValue().toInt(), 43);
|
|
ASSERT_EQ(f1->value().toInt(), 43);
|
|
f1->addCallback([&](Future& /* unused */) { ++sat2; });
|
|
ASSERT_EQ(sat1, 1);
|
|
ASSERT_EQ(sat2, 1);
|
|
}
|
|
|
|
// Sparse CUDA tensor test
|
|
TEST(FutureTest, SparseTensor) {
|
|
// Skip test if CUDA is not available.
|
|
bool has_cuda = at::globalContext().hasCUDA();
|
|
  if (!has_cuda) {
    LOG(INFO) << "CUDA not available, skipping test";
    return;
  }
|
|
for (int i = 0; i < 2; ++i) {
|
|
auto f = c10::make_intrusive<Future>(TensorType::get());
|
|
at::TensorOptions opts = at::TensorOptions().device(at::DeviceType::CUDA);
|
|
auto sparse_tensor = i == 0 ? at::ones(10).to_sparse()
|
|
: at::sparse_coo_tensor(
|
|
at::arange(10).unsqueeze(0).to(at::kLong),
|
|
at::ones({10, 10}),
|
|
opts);
|
|
// Runs storage extraction for sparse CUDA tensors
|
|
f->markCompleted(sparse_tensor);
|
|
ASSERT_TRUE(f->completed());
|
|
ASSERT_FALSE(f->hasError());
|
|
}
|
|
}
|
|
|
|
// Basic error cases.
|
|
TEST(FuturesTest, Error) {
|
|
auto f1 = c10::make_intrusive<Future>(IntType::get());
|
|
int sat1 = 0;
|
|
int sat2 = 0;
|
|
f1->addCallback([&](Future& /* unused */) { ++sat1; });
|
|
f1->setError(
|
|
std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
|
|
ASSERT_EQ(sat1, 1);
|
|
ASSERT_TRUE(f1->completed());
|
|
ASSERT_TRUE(f1->hasError());
|
|
ASSERT_FALSE(f1->hasValue());
|
|
try {
|
|
(void)f1->value();
|
|
ASSERT_TRUE(false); // Supposed to throw.
|
|
} catch (const std::exception& e) {
|
|
ASSERT_TRUE(strcmp(e.what(), "Failed") == 0);
|
|
}
|
|
f1->addCallback([&](Future& /* unused */) { ++sat2; });
|
|
ASSERT_EQ(sat1, 1);
|
|
ASSERT_EQ(sat2, 1);
|
|
f1->setErrorIfNeeded(
|
|
std::make_exception_ptr(c10::ivalue::Future::FutureError("Dup")));
|
|
ASSERT_TRUE(strcmp(f1->tryRetrieveErrorMessage().c_str(), "Failed") == 0);
|
|
ASSERT_EQ(sat1, 1);
|
|
ASSERT_EQ(sat2, 1);
|
|
try {
|
|
(void)f1->constValue();
|
|
ASSERT_TRUE(false); // Supposed to throw.
|
|
} catch (const std::exception& e) {
|
|
// Original error should be logged.
|
|
ASSERT_TRUE(std::string(e.what()).find("Failed") != std::string::npos);
|
|
}
|
|
}
|
|
|
|
// then(): chained continuations run when their upstream future completes.
|
|
TEST(FuturesTest, Then) {
|
|
auto f1 = c10::make_intrusive<Future>(IntType::get());
|
|
auto f2 = f1->then(
|
|
[](Future& f1) -> IValue { return f1.constValue().toInt() + 1; },
|
|
IntType::get());
|
|
auto f3 = f2->then(
|
|
[](Future& f2) -> IValue { return f2.constValue().toInt() * 3; },
|
|
IntType::get());
|
|
bool done = false;
|
|
f3->addCallback([&done](Future& f3) {
|
|
ASSERT_EQ(f3.constValue().toInt(), (42 + 1) * 3);
|
|
done = true;
|
|
});
|
|
ASSERT_FALSE(done);
|
|
f1->markCompleted(42);
|
|
ASSERT_TRUE(done);
|
|
}
|
|
|
|
// collectAll()
|
|
TEST(FuturesTest, CollectAll) {
|
|
auto s1 = c10::make_intrusive<Future>(IntType::get());
|
|
auto s2 = c10::make_intrusive<Future>(IntType::get());
|
|
auto s3 = c10::make_intrusive<Future>(IntType::get());
|
|
|
|
// Empty case
|
|
c10::List<intrusive_ptr<ivalue::Future>> futures(
|
|
FutureType::create(IntType::get()));
|
|
auto c1 = collectAll(futures);
|
|
ASSERT_TRUE(c1->completed());
|
|
ASSERT_EQ(c1->value().toList().size(), 0);
|
|
ASSERT_TRUE(
|
|
*(c1->value().toList().elementType()) ==
|
|
*FutureType::create(IntType::get()));
|
|
|
|
// 1-element, initially not completed.
|
|
futures.push_back(s1);
|
|
auto c2 = collectAll(futures);
|
|
ASSERT_FALSE(c2->completed());
|
|
s1->markCompleted(5);
|
|
ASSERT_TRUE(c2->completed());
|
|
ASSERT_EQ(c2->value().toList().size(), 1);
|
|
ASSERT_TRUE(
|
|
*(c2->value().toList().elementType()) ==
|
|
*FutureType::create(IntType::get()));
|
|
ASSERT_EQ(c2->value().toList().get(0).toFuture()->value().toInt(), 5);
|
|
|
|
// 1-element, already completed
|
|
auto c3 = collectAll(futures);
|
|
ASSERT_TRUE(c3->completed());
|
|
ASSERT_EQ(c3->value().toList().size(), 1);
|
|
ASSERT_EQ(c3->value().toList().get(0).toFuture()->value().toInt(), 5);
|
|
|
|
// 3 elements.
|
|
futures.push_back(s2);
|
|
futures.push_back(s3);
|
|
auto c4 = collectAll(futures);
|
|
ASSERT_FALSE(c4->completed());
|
|
s3->markCompleted(7);
|
|
ASSERT_FALSE(c4->completed());
|
|
s2->markCompleted(6);
|
|
ASSERT_TRUE(c4->completed());
|
|
ASSERT_EQ(c4->value().toList().size(), 3);
|
|
ASSERT_EQ(c4->value().toList().get(0).toFuture()->value().toInt(), 5);
|
|
ASSERT_EQ(c4->value().toList().get(1).toFuture()->value().toInt(), 6);
|
|
ASSERT_EQ(c4->value().toList().get(2).toFuture()->value().toInt(), 7);
|
|
ASSERT_TRUE(
|
|
*(c4->value().toList().elementType()) ==
|
|
*FutureType::create(IntType::get()));
|
|
|
|
// Handle exception in the list.
|
|
auto s4 = c10::make_intrusive<Future>(IntType::get());
|
|
futures.push_back(s4);
|
|
auto c5 = collectAll(futures);
|
|
ASSERT_FALSE(c5->completed());
|
|
s4->setError(
|
|
std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
|
|
ASSERT_TRUE(c5->completed());
|
|
try {
|
|
c5->value();
|
|
ASSERT_TRUE(false); // supposed to throw
|
|
} catch (const std::exception& e) {
|
|
ASSERT_EQ(std::string(e.what()), "Failed");
|
|
}
|
|
}
|
|
|
|
// collectAny()
|
|
TEST(FuturesTest, CollectAny) {
|
|
auto s1 = c10::make_intrusive<Future>(IntType::get());
|
|
|
|
// Empty case
|
|
c10::List<intrusive_ptr<ivalue::Future>> futures(
|
|
FutureType::create(IntType::get()));
|
|
auto c1 = collectAny(futures);
|
|
ASSERT_TRUE(c1->completed());
|
|
|
|
// 1 element, not yet satisfied
|
|
futures.push_back(s1);
|
|
auto c2 = collectAny(futures);
|
|
ASSERT_FALSE(c2->completed());
|
|
s1->markCompleted(5);
|
|
ASSERT_TRUE(c2->completed());
|
|
ASSERT_TRUE(c2->value().isInt());
|
|
ASSERT_EQ(c2->value().toInt(), 5);
|
|
|
|
// 1 element already satisfied.
|
|
auto c3 = collectAny(futures);
|
|
ASSERT_TRUE(c3->completed());
|
|
ASSERT_TRUE(c3->value().isInt());
|
|
ASSERT_EQ(c3->value().toInt(), 5);
|
|
|
|
// 2 elements
|
|
futures.clear();
|
|
auto s2 = c10::make_intrusive<Future>(IntType::get());
|
|
auto s3 = c10::make_intrusive<Future>(IntType::get());
|
|
futures.push_back(s2);
|
|
futures.push_back(s3);
|
|
auto c4 = collectAny(futures);
|
|
ASSERT_FALSE(c4->completed());
|
|
s3->markCompleted(7);
|
|
ASSERT_TRUE(c4->completed());
|
|
ASSERT_EQ(c4->value().toInt(), 7);
|
|
s2->markCompleted(1);
|
|
ASSERT_EQ(c4->value().toInt(), 7);
|
|
}
|
|
|
|
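// Checks that wrapPropagateTLSState propagates thread-local state (here the
// legacy profiler) into future callbacks that run on another thread.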
TEST(TLSFutureCallbacksTest, Basic) {
|
|
// cb that verifies the profiler is enabled
|
|
auto profilerEnabledCb = [](Future& /* unused */) {
|
|
ASSERT_TRUE(torch::autograd::profiler::profilerEnabled());
|
|
};
|
|
// test running callbacks with propagation of TLS state.
|
|
{
|
|
// Enable the profiler in this thread
|
|
torch::autograd::profiler::enableProfilerLegacy(
|
|
torch::autograd::profiler::ProfilerConfig(
|
|
torch::autograd::profiler::ProfilerState::CPU, false, false));
|
|
auto s1 = c10::make_intrusive<Future>(IntType::get());
|
|
s1->addCallback(wrapPropagateTLSState(profilerEnabledCb));
|
|
std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); });
|
|
// Since we join here, we can ensure that all callbacks corresponding to
|
|
// markCompleted() have finished.
|
|
t.join();
|
|
torch::autograd::profiler::disableProfilerLegacy();
|
|
}
|
|
// then() with TLS State
|
|
{
|
|
// Enable the profiler in this thread
|
|
torch::autograd::profiler::enableProfilerLegacy(
|
|
torch::autograd::profiler::ProfilerConfig(
|
|
torch::autograd::profiler::ProfilerState::CPU, false, false));
|
|
auto s1 = c10::make_intrusive<Future>(IntType::get());
|
|
auto s2 = s1->then(
|
|
wrapPropagateTLSState([&profilerEnabledCb](Future& s1) {
|
|
profilerEnabledCb(s1);
|
|
return at::IValue(1);
|
|
}),
|
|
IntType::get());
|
|
std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); });
|
|
t.join();
|
|
s2->wait();
|
|
torch::autograd::profiler::disableProfilerLegacy();
|
|
}
|
|
}
|
|
|
|
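// A future callback should be able to consolidate and disable the legacy
// profiler itself and still find the ops it ran in the recorded events.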
TEST(ProfilerDisableInCallbackTest, Basic) {
|
|
// cb that verifies the profiler is enabled
|
|
auto profilerEnabledCb = []() {
|
|
ASSERT_TRUE(torch::autograd::profiler::profilerEnabled());
|
|
};
|
|
torch::autograd::profiler::enableProfilerLegacy(
|
|
torch::autograd::profiler::ProfilerConfig(
|
|
torch::autograd::profiler::ProfilerState::CPU, false, false));
|
|
auto s1 = c10::make_intrusive<Future>(IntType::get());
|
|
auto verifyProfilerCb =
|
|
wrapPropagateTLSState([&profilerEnabledCb](Future& /* unused */) {
|
|
// Ensure the profiler is still enabled in this thread.
|
|
profilerEnabledCb();
|
|
auto t1 = torch::ones({2, 2});
|
|
auto t2 = torch::ones({2, 2});
|
|
torch::add(t1, t2);
|
|
// Don't clean up TLS state, just consolidate.
|
|
auto opts =
|
|
torch::autograd::profiler::ProfilerDisableOptions(false, true);
|
|
auto thread_event_lists =
|
|
// NOLINTNEXTLINE(performance-move-const-arg)
|
|
torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
|
|
// Ensure that the events from this thread are still profiled and we
|
|
// obtain the expected events in our consolidated list when calling
|
|
// disableProfilerLegacy().
|
|
bool found_ones = false;
|
|
bool found_add = false;
|
|
for (const auto& li : thread_event_lists) {
|
|
for (const auto& evt : li) {
|
|
if (strcmp(evt.name(), "aten::add") == 0) {
|
|
found_add = true;
|
|
} else if (strcmp(evt.name(), "aten::ones") == 0) {
|
|
found_ones = true;
|
|
}
|
|
}
|
|
if (found_add && found_ones) {
|
|
break;
|
|
}
|
|
}
|
|
ASSERT_TRUE(found_ones);
|
|
ASSERT_TRUE(found_add);
|
|
});
|
|
|
|
s1->addCallback(verifyProfilerCb);
|
|
// Disable the profiler, but do not consolidate results in the main thread.
|
|
auto opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
|
|
// NOLINTNEXTLINE(performance-move-const-arg)
|
|
torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
|
|
std::thread t([s1 = std::move(s1)]() { s1->markCompleted(at::IValue(1)); });
|
|
t.join();
|
|
|
|
// Similar to above test, but verifies correctness in the case where
|
|
// the continuation runs on the main thread.
|
|
torch::autograd::profiler::enableProfilerLegacy(
|
|
torch::autograd::profiler::ProfilerConfig(
|
|
torch::autograd::profiler::ProfilerState::CPU, false, false));
|
|
s1 = c10::make_intrusive<Future>(IntType::get());
|
|
s1->addCallback(verifyProfilerCb);
|
|
// Runs callback inline
|
|
s1->markCompleted(at::IValue(1));
|
|
opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
|
|
// NOLINTNEXTLINE(performance-move-const-arg)
|
|
torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
|
|
}
|
|
|
|
TEST(RecordDebugHandles, Basic) {
|
|
GTEST_SKIP() << "Test is flaky and sometimes hangs on CI. ";
|
|
// Enable the profiler in this thread
|
|
const std::set<torch::autograd::profiler::ActivityType> activities(
|
|
{torch::autograd::profiler::ActivityType::CPU});
|
|
torch::autograd::profiler::prepareProfiler(
|
|
torch::autograd::profiler::ProfilerConfig(
|
|
torch::autograd::profiler::ProfilerState::KINETO, false, false),
|
|
activities);
|
|
torch::autograd::profiler::enableProfiler(
|
|
torch::autograd::profiler::ProfilerConfig(
|
|
torch::autograd::profiler::ProfilerState::KINETO, false, false),
|
|
activities);
|
|
{
|
|
RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS("my_function", 42, {});
|
|
float x{5.9999}, y{2.1212};
|
|
float z = x / y;
|
|
(void)z;
|
|
}
|
|
{
|
|
RECORD_USER_SCOPE_WITH_INPUTS("not_my_function", {});
|
|
float x{5.9999}, y{2.1212};
|
|
float z = x / y;
|
|
(void)z;
|
|
}
|
|
auto profiler_results_ptr = torch::autograd::profiler::disableProfiler();
|
|
const auto& kineto_events = profiler_results_ptr->events();
|
|
size_t my_events{0};
|
|
for (const auto& e : kineto_events) {
|
|
if (e.name() == "my_function") {
|
|
ASSERT_EQ(e.debugHandle(), 42);
|
|
my_events++;
|
|
} else if (e.name() == "not_my_function") {
|
|
ASSERT_EQ(e.debugHandle(), -1);
|
|
my_events++;
|
|
}
|
|
}
|
|
ASSERT_EQ(my_events, 2);
|
|
}
|
|
|
|
TEST(RecordDebugHandles, ScopedCallbacks) {
|
|
// Enable the profiler in this thread
|
|
torch::autograd::profiler::prepareProfiler(
|
|
torch::autograd::profiler::ProfilerConfig(
|
|
torch::autograd::profiler::ProfilerState::KINETO, false, false),
|
|
{torch::autograd::profiler::ActivityType::CPU});
|
|
torch::autograd::profiler::enableProfiler(
|
|
torch::autograd::profiler::ProfilerConfig(
|
|
torch::autograd::profiler::ProfilerState::KINETO, false, false),
|
|
{torch::autograd::profiler::ActivityType::CPU});
|
|
|
|
{
|
|
auto a = torch::rand({128, 128});
|
|
auto b = torch::rand({128, 128});
|
|
auto c = a + b;
|
|
}
|
|
auto profiler_results_ptr = torch::autograd::profiler::disableProfiler();
|
|
ASSERT_TRUE(profiler_results_ptr->events().size() > 0);
|
|
|
|
// Enable the profiler in this thread
|
|
torch::autograd::profiler::prepareProfiler(
|
|
torch::autograd::profiler::ProfilerConfig(
|
|
torch::autograd::profiler::ProfilerState::KINETO, false, false),
|
|
{torch::autograd::profiler::ActivityType::CPU});
|
|
torch::autograd::profiler::enableProfiler(
|
|
torch::autograd::profiler::ProfilerConfig(
|
|
torch::autograd::profiler::ProfilerState::KINETO, false, false),
|
|
{torch::autograd::profiler::ActivityType::CPU},
|
|
{at::RecordScope::LITE_INTERPRETER});
|
|
{
|
|
auto a = torch::rand({128, 128});
|
|
auto b = torch::rand({128, 128});
|
|
auto c = a + b;
|
|
}
|
|
profiler_results_ptr = torch::autograd::profiler::disableProfiler();
|
|
ASSERT_TRUE(profiler_results_ptr->events().size() == 0);
|
|
|
|
torch::autograd::profiler::prepareProfiler(
|
|
torch::autograd::profiler::ProfilerConfig(
|
|
torch::autograd::profiler::ProfilerState::KINETO, false, false),
|
|
{torch::autograd::profiler::ActivityType::CPU});
|
|
torch::autograd::profiler::enableProfiler(
|
|
torch::autograd::profiler::ProfilerConfig(
|
|
torch::autograd::profiler::ProfilerState::KINETO, false, false),
|
|
{torch::autograd::profiler::ActivityType::CPU},
|
|
{at::RecordScope::LITE_INTERPRETER});
|
|
{
|
|
RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS("my_function", 42, {});
|
|
auto a = torch::rand({128, 128});
|
|
auto b = torch::rand({128, 128});
|
|
auto c = a + b;
|
|
}
|
|
{
|
|
RECORD_USER_SCOPE_WITH_INPUTS("not_my_function", {});
|
|
auto a = torch::rand({128, 128});
|
|
auto b = torch::rand({128, 128});
|
|
auto c = a + b;
|
|
}
|
|
profiler_results_ptr = torch::autograd::profiler::disableProfiler();
|
|
const auto& kineto_events = profiler_results_ptr->events();
|
|
for (const auto& e : kineto_events) {
|
|
if (e.name() == "my_function") {
|
|
ASSERT_EQ(e.debugHandle(), 42);
|
|
}
|
|
}
|
|
ASSERT_TRUE(profiler_results_ptr->events().size() == 1);
|
|
}
|
|
|
|
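// foo(1, b=3) falls back to the default c=4, so the result is
// 1 + 2*3 + 3*4 == 19.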
TEST(IValueKWargsTest, Basic) {
|
|
const auto text = R"(
|
|
def foo(a : int, b : int, c : int = 4):
|
|
return a + 2*b + 3*c
|
|
)";
|
|
auto cu = compile(text);
|
|
auto result = cu->get_function("foo")({1}, {{"b", 3}});
|
|
ASSERT_EQ(result.toInt(), 19);
|
|
}
|
|
|
|
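// Exercises computeFlops for conv2d, mm/addmm, bmm/baddbmm, and elementwise
// add/mul, including malformed argument sets that should report 0 flops.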
TEST(ComputeFlopsTest, Basic) {
|
|
uint64_t flops = 0;
|
|
|
|
// Test unknown operator
|
|
std::unordered_map<std::string, c10::IValue> extra_args;
|
|
flops = torch::profiler::impl::computeFlops(
|
|
std::string("aten::unknown"), extra_args);
|
|
ASSERT_EQ(flops, 0);
|
|
|
|
// Test aten::conv2d
|
|
extra_args.clear();
|
|
std::vector<int64_t> input_size = {4, 5, 6, 7};
|
|
std::vector<int64_t> weight_size = {3, 5, 2, 1};
|
|
std::vector<int64_t> padding = {1, 0};
|
|
std::vector<int64_t> stride = {1, 1};
|
|
std::vector<int64_t> dilation = {0, 0};
|
|
extra_args["input_size"] = at::IValue(at::IntArrayRef(input_size));
|
|
extra_args["weight_size"] = at::IValue(at::IntArrayRef(weight_size));
|
|
extra_args["groups"] = 1;
|
|
extra_args["padding"] = at::IValue(at::IntArrayRef(padding));
|
|
extra_args["stride"] = at::IValue(at::IntArrayRef(stride));
|
|
extra_args["dilation"] = at::IValue(at::IntArrayRef(dilation));
|
|
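  // Assuming the standard conv FLOP count
  // 2 * N * C_out * H_out * W_out * (C_in / groups) * kH * kW, these shapes
  // give 2 * 4 * 3 * 8 * 7 * 5 * 2 * 1 = 13440, matching the assertion below.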
flops = torch::profiler::impl::computeFlops(
|
|
std::string("aten::conv2d"), extra_args);
|
|
ASSERT_EQ(flops, 13440);
|
|
|
|
// Test aten::conv2d fail
|
|
input_size = {4, 5, 6, 7};
|
|
weight_size = {4, 5, 6};
|
|
extra_args["input_size"] = at::IValue(at::IntArrayRef(input_size));
|
|
extra_args["weight_size"] = at::IValue(at::IntArrayRef(weight_size));
|
|
flops = torch::profiler::impl::computeFlops(
|
|
std::string("aten::conv2d"), extra_args);
|
|
ASSERT_EQ(flops, 0);
|
|
|
|
// Test aten::conv2d fail 2
|
|
weight_size = {3, 5, 2, 1};
|
|
stride = {0, 0};
|
|
extra_args["weight_size"] = at::IValue(at::IntArrayRef(input_size));
|
|
extra_args["stride"] = at::IValue(at::IntArrayRef(stride));
|
|
flops = torch::profiler::impl::computeFlops(
|
|
std::string("aten::conv2d"), extra_args);
|
|
ASSERT_EQ(flops, 0);
|
|
|
|
// Test aten::conv2d fail 3
|
|
extra_args.clear();
|
|
input_size = {4, 5, 6, 7};
|
|
extra_args["input_size"] = at::IValue(at::IntArrayRef(input_size));
|
|
flops = torch::profiler::impl::computeFlops(
|
|
std::string("aten::conv2d"), extra_args);
|
|
ASSERT_EQ(flops, 0);
|
|
|
|
// Test aten::mm
|
|
extra_args.clear();
|
|
std::vector<int64_t> mat1_sizes = {3, 4, 5, 6};
|
|
std::vector<int64_t> mat2_sizes = {6, 5, 4, 3};
|
|
extra_args["mat1_size"] = at::IValue(at::IntArrayRef(mat1_sizes));
|
|
extra_args["mat2_size"] = at::IValue(at::IntArrayRef(mat2_sizes));
|
|
flops =
|
|
torch::profiler::impl::computeFlops(std::string("aten::mm"), extra_args);
|
|
ASSERT_EQ(flops, 43200);
|
|
|
|
// Test aten::addmm
|
|
flops = torch::profiler::impl::computeFlops(
|
|
std::string("aten::addmm"), extra_args);
|
|
ASSERT_EQ(flops, 43200);
|
|
|
|
// Test aten::bmm
|
|
extra_args.clear();
|
|
mat1_sizes = {7, 5, 6};
|
|
mat2_sizes = {7, 6, 3};
|
|
extra_args["mat1_size"] = at::IValue(at::IntArrayRef(mat1_sizes));
|
|
extra_args["mat2_size"] = at::IValue(at::IntArrayRef(mat2_sizes));
|
|
flops =
|
|
torch::profiler::impl::computeFlops(std::string("aten::bmm"), extra_args);
|
|
ASSERT_EQ(flops, 1260);
|
|
|
|
// Test aten::baddbmm
|
|
flops = torch::profiler::impl::computeFlops(
|
|
std::string("aten::baddbmm"), extra_args);
|
|
ASSERT_EQ(flops, 1260);
|
|
|
|
// Test mm out of range
|
|
extra_args.clear();
|
|
flops =
|
|
torch::profiler::impl::computeFlops(std::string("aten::mm"), extra_args);
|
|
ASSERT_EQ(flops, 0);
|
|
|
|
// Test aten::add.Tensor
|
|
extra_args.clear();
|
|
std::vector<int64_t> mat_sizes = {3, 4, 5, 6};
|
|
extra_args["mat_size"] = at::IValue(at::IntArrayRef(mat_sizes));
|
|
flops =
|
|
torch::profiler::impl::computeFlops(std::string("aten::add"), extra_args);
|
|
ASSERT_EQ(flops, 360);
|
|
|
|
// Test aten::mul.Tensor
|
|
extra_args.clear();
|
|
mat_sizes = {3, 4, 5, 6};
|
|
extra_args["mat_size"] = at::IValue(at::IntArrayRef(mat_sizes));
|
|
flops =
|
|
torch::profiler::impl::computeFlops(std::string("aten::mul"), extra_args);
|
|
ASSERT_EQ(flops, 360);
|
|
}
|
|
|
|
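// A tensor that requires grad cannot be folded into a constant, so
// tryInsertConstant returns nullopt.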
TEST(TestConstant, TensorGrad) {
|
|
auto graph = std::make_shared<Graph>();
|
|
IValue ten = torch::randn({3, 5}).requires_grad_(true);
|
|
auto con = tryInsertConstant(*graph, ten);
|
|
ASSERT_TRUE(con == std::nullopt);
|
|
}
|
|
|
|
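// RemoveTensorMutation only rewrites in-place ops for which the filter
// returns true: aten::add_ survives the first call and is removed by the
// second.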
TEST(TestMutation, Basic) {
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph(%x.1 : Tensor):
|
|
%2 : int = prim::Constant[value=1]()
|
|
%9 : int = prim::Constant[value=4]()
|
|
%x.3 : Tensor = aten::add(%x.1, %2, %2)
|
|
%7 : Tensor = aten::add_(%x.3, %2, %2)
|
|
%y.1 : Tensor = aten::add(%x.3, %9, %2)
|
|
return (%y.1))IR",
|
|
&*graph,
|
|
vmap);
|
|
RemoveTensorMutation(graph, [](Node*) { return false; });
|
|
testing::FileCheck().check("aten::add_")->run(*graph);
|
|
RemoveTensorMutation(graph, [](Node*) { return true; });
|
|
testing::FileCheck().check_not("aten::add_")->run(*graph);
|
|
}
|
|
|
|
TEST(TestInplaceToFunctionalActivation, Basic) {
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph(%x.1 : Tensor):
|
|
%2 : int = prim::Constant[value=1]()
|
|
%x.3 : Tensor = aten::add(%x.1, %2, %2)
|
|
%y : Tensor = aten::relu_(%x.3)
|
|
return (%y))IR",
|
|
&*graph,
|
|
vmap);
|
|
InplaceToFunctionalActivation(graph);
|
|
testing::FileCheck().check("aten::relu")->run(*graph);
|
|
testing::FileCheck().check_not("aten::relu_")->run(*graph);
|
|
}
|
|
|
|
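// Registers a hand-written shape compute graph for prim::MakeTestTensor's
// schema and checks that shape propagation uses it to infer a [5, 5] output.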
TEST(TestRegisterShapeOp, Basic) {
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph():
|
|
%2 : int = prim::Constant[value=5]()
|
|
%3: int[] = prim::ListConstruct(%2, %2)
|
|
return (%3))IR",
|
|
&*graph,
|
|
vmap);
|
|
|
|
auto g2 = std::make_shared<Graph>();
|
|
parseIR(
|
|
R"IR(
|
|
graph():
|
|
%2 : Tensor = prim::MakeTestTensor()
|
|
return (%2))IR",
|
|
&*g2,
|
|
vmap);
|
|
|
|
const FunctionSchema& schema = g2->nodes().begin()->schema();
|
|
torch::jit::RegisterShapeComputeGraphForSchema(schema, graph);
|
|
PropagateShapesOnGraph(g2);
|
|
testing::FileCheck().check("5, 5")->run(*g2);
|
|
}
|
|
|
|
TEST(TestFunctionalToInplaceActivation, Basic) {
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(
|
|
R"IR(
|
|
graph(%x.1 : Tensor):
|
|
%2 : int = prim::Constant[value=1]()
|
|
%x.3 : Tensor = aten::add(%x.1, %2, %2)
|
|
%y : Tensor = aten::relu(%x.3)
|
|
return (%y))IR",
|
|
&*graph,
|
|
vmap);
|
|
FunctionalToInplaceActivation(graph);
|
|
testing::FileCheck().check("aten::relu_")->run(*graph);
|
|
testing::FileCheck().check_not("aten::relu(")->run(*graph);
|
|
}
|
|
|
|
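// The PROFILING executor instruments the graph with prim::profile nodes; the
// SIMPLE executor runs the graph unmodified.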
TEST(TestFunctionExecutor, SimpleExecutorTest) {
|
|
auto graph = std::make_shared<Graph>();
|
|
parseIR(
|
|
R"IR(
|
|
graph(%x.1 : Tensor):
|
|
%2 : int = prim::Constant[value=1]()
|
|
%x.3 : Tensor = aten::add(%x.1, %2, %2)
|
|
%y : Tensor = aten::relu(%x.3)
|
|
return (%y))IR",
|
|
&*graph);
|
|
{
|
|
auto func = std::make_unique<GraphFunction>(
|
|
"name", graph, [](GraphFunction&) {}, ExecutorExecutionMode::PROFILING);
|
|
auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
|
|
Stack stack = {a};
|
|
func->run(stack);
|
|
auto g = lastExecutedOptimizedGraph();
|
|
testing::FileCheck()
|
|
.check("prim::profile")
|
|
->check("aten::add")
|
|
->check("aten::relu")
|
|
->run(*g);
|
|
}
|
|
{
|
|
auto func = std::make_unique<GraphFunction>(
|
|
"name", graph, [](GraphFunction&) {}, ExecutorExecutionMode::SIMPLE);
|
|
auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
|
|
Stack stack = {a};
|
|
func->run(stack);
|
|
auto g = func->getDebugState().graph;
|
|
testing::FileCheck()
|
|
.check_not("prim::profile")
|
|
->check("aten::add")
|
|
->check("aten::relu")
|
|
->run(*g);
|
|
}
|
|
}
|
|
|
|
TEST(TestFunctionExecutor, RunDecompositionTest) {
|
|
static auto* func = torch::jit::GetDecompositionExecutor(
|
|
"aten::var(Tensor self, bool unbiased=True) -> Tensor");
|
|
for (bool unbiased : {true, false}) {
|
|
auto input = at::rand({4, 4});
|
|
Stack stack = {input, unbiased};
|
|
func->run(stack);
|
|
at::Tensor out = pop(stack).toTensor();
|
|
ASSERT_TRUE(at::allclose(out, input.var(unbiased)));
|
|
}
|
|
}
|
|
|
|
TEST(TestShapeGraphLinting, Basic) {
|
|
auto schemas = RegisteredShapeComputeSchemas();
|
|
for (const auto& schema : schemas) {
|
|
// arange does not actually support complex, leave as
|
|
// union[int, float] for now
|
|
if (schema->name() == "aten::arange") {
|
|
continue;
|
|
}
|
|
auto g = shapeComputeGraphForSchema(*schema);
|
|
TORCH_INTERNAL_ASSERT(g);
|
|
LintShapeComputeGraph(schema, *g);
|
|
}
|
|
}
|
|
|
|
// TODO: move to test_kernel once the global fusion-parameter settings are
// explicit
|
|
class Composed : public ::testing::Test {
|
|
public:
|
|
void SetUp() override {
|
|
torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = false;
|
|
}
|
|
};
|
|
|
|
TEST_F(Composed, ComposedOp) {
|
|
struct WithCPUFuser {
|
|
WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
|
|
overrideCanFuseOnCPU(val);
|
|
}
|
|
|
|
~WithCPUFuser() {
|
|
overrideCanFuseOnCPU(cpuFuserEnabled);
|
|
}
|
|
|
|
bool cpuFuserEnabled;
|
|
};
|
|
|
|
#ifdef TORCH_ENABLE_LLVM
|
|
const auto graph_string = R"IR(
|
|
graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
|
|
%1 : Float(5, 3, strides=[1, 5], device=cpu)):
|
|
%2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1)
|
|
%3 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %2)
|
|
%4 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %3)
|
|
return (%3, %4))IR";
|
|
auto graph = std::make_shared<Graph>();
|
|
parseIR(graph_string, &*graph);
|
|
|
|
// wrong input sizes so we hit the fallback path
|
|
auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
|
|
auto b = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat))
|
|
.transpose(0, 1);
|
|
auto ref1 = a * (a * b);
|
|
auto ref2 = a * ref1;
|
|
WithCPUFuser g(true);
|
|
bool fusable_on_device = torch::jit::tensorexpr::getTEMustUseLLVMOnCPU();
|
|
torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = false;
|
|
FuseTensorExprs(
|
|
graph,
|
|
/*min_group_size*/ 2,
|
|
/*add_composed_op*/ true,
|
|
/*fuse_to_dynamic_shapes*/ true);
|
|
Code code(graph, "");
|
|
InterpreterState interpreter{code};
|
|
std::vector<IValue> stack = {a, b};
|
|
interpreter.run(stack);
|
|
at::Tensor out2 = pop(stack).toTensor();
|
|
at::Tensor out1 = pop(stack).toTensor();
|
|
ASSERT_TRUE(at::allclose(ref1, out1));
|
|
ASSERT_TRUE(at::allclose(ref2, out2));
|
|
|
|
auto inp_1 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
|
|
auto inp_2 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
|
|
stack = {inp_1, inp_2, a, b};
|
|
InterpreterState interpreter2{code};
|
|
interpreter2.run(stack);
|
|
out2 = pop(stack).toTensor();
|
|
out1 = pop(stack).toTensor();
|
|
ASSERT_TRUE(at::allclose(ref1, out1));
|
|
ASSERT_TRUE(at::allclose(ref2, out2));
|
|
  // inp_1 is at the bottom of the stack and corresponds to the second output;
  // inp_2 is at the top and corresponds to the first output.
|
|
ASSERT_TRUE(at::allclose(inp_1, ref2));
|
|
ASSERT_TRUE(at::allclose(inp_2, ref1));
|
|
torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = fusable_on_device;
|
|
#endif
|
|
}
|
|
|
|
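// Constant propagation should be able to fold quantized::linear_prepack,
// which produces a custom class instance (only checked when QNNPACK is
// built).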
TEST(ConstantPropagation, CustomClassesCanBePropagated) {
|
|
#ifdef USE_PYTORCH_QNNPACK
|
|
const auto src = R"IR(
|
|
graph():
|
|
%none: NoneType = prim::Constant()
|
|
%dim: int = prim::Constant[value=3]()
|
|
%shape: int[] = prim::ListConstruct(%dim, %dim)
|
|
%weight: Tensor = aten::ones(%shape, %none, %none, %none, %none)
|
|
%scale: float = prim::Constant[value=1.]()
|
|
%zero_point: int = prim::Constant[value=0]()
|
|
%dtype: int = prim::Constant[value=12]()
|
|
%weight_q: Tensor = aten::quantize_per_tensor(%weight, %scale, %zero_point, %dtype)
|
|
%params: __torch__.torch.classes.quantized.LinearPackedParamsBase = quantized::linear_prepack(%weight_q, %none)
|
|
return (%params)
|
|
)IR";
|
|
auto graph = std::make_shared<Graph>();
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
parseIR(src, graph.get(), vmap);
|
|
|
|
ConstantPropagation(graph);
|
|
|
|
testing::FileCheck().check_not("quantized::linear_prepack")->run(*graph);
|
|
#endif
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|