#include #include "test/cpp/jit/test_utils.h" #include "torch/csrc/jit/frontend/tracer.h" #include "torch/csrc/jit/passes/common_subexpression_elimination.h" #include "torch/csrc/jit/passes/constant_propagation.h" #include "torch/csrc/jit/passes/create_autodiff_subgraphs.h" #include "torch/csrc/jit/passes/dead_code_elimination.h" #include "torch/csrc/jit/passes/graph_fuser.h" #include "torch/csrc/jit/passes/lower_grad_of.h" #include "torch/csrc/jit/passes/requires_grad_analysis.h" #include "torch/csrc/jit/passes/shape_analysis.h" #include "torch/csrc/jit/passes/utils/subgraph_utils.h" #include "torch/csrc/jit/runtime/argument_spec.h" #include "torch/csrc/jit/runtime/autodiff.h" #include "torch/csrc/jit/runtime/graph_iterator.h" #include "torch/csrc/jit/runtime/profiling_graph_executor_impl.h" #include "torch/torch.h" #include #include "torch/csrc/autograd/engine.h" #include "torch/csrc/autograd/generated/variable_factories.h" #include "torch/csrc/autograd/variable.h" namespace torch { namespace jit { using namespace torch::autograd; using var_meta_type = std::vector; using var_meta_list = std::vector; using test_fn_type = std::function; struct ADTestSpec { ADTestSpec( const char* name, // NOLINTNEXTLINE(modernize-pass-by-value) var_meta_list input_meta, // NOLINTNEXTLINE(modernize-pass-by-value) test_fn_type test_fn, float clampMax = -1.0f) : name(name), input_meta(input_meta), test_fn(test_fn), clampMax(clampMax) {} variable_list operator()(const variable_list& inputs) const { return test_fn(inputs); }; std::vector make_vars() const { std::vector out; for (const auto& m : input_meta) { if (clampMax > 0.0f) { out.push_back(torch::randn(m, at::requires_grad(true)) .clamp(-clampMax, clampMax)); continue; } out.push_back(torch::randn(m, at::requires_grad(true))); } return out; } const char* name; var_meta_list input_meta; test_fn_type test_fn; float clampMax; }; variable_list get_grad_outputs(const variable_list& vars) { return fmap(vars, [](const Variable& v) -> Variable { return at::randn(v.sizes(), v.options()); }); } variable_list grad( const variable_list& outputs, const variable_list& inputs, const variable_list& grad_outputs) { const auto get_edge = [](const Variable& v) { return torch::autograd::impl::gradient_edge(v); }; auto& engine = torch::autograd::Engine::get_default_engine(); return engine.execute( fmap(outputs, get_edge), grad_outputs, true, false, false, fmap(inputs, get_edge)); } TEST(AutodiffTest, ADFormulas) { const auto cast = [](const Variable& v) { return static_cast(v); }; using VL = variable_list; const var_meta_list binary_pointwise = {{2, 3, 4, 5}, {2, 3, 4, 5}}; const var_meta_list unary_pointwise = {{2, 3, 4, 5}}; const var_meta_list unary_pointwise_2d = {{2, 3}}; const std::vector ad_tests = { {"add", binary_pointwise, [](const VL& v) -> VL { return {v[0] + v[1]}; }}, {"sub", binary_pointwise, [](const VL& v) -> VL { return {v[0] - v[1]}; }}, {"mul", binary_pointwise, [](const VL& v) -> VL { return {v[0] * v[1]}; }}, {"sigmoid", unary_pointwise, [](const VL& v) -> VL { return {v[0].sigmoid()}; }}, // Clamp tanh input tensor values to [-3, 3] // to set a minimum on gradient absolute values {"tanh", unary_pointwise, [](const VL& v) -> VL { return {v[0].tanh()}; }, 3.0f}, {"t", unary_pointwise_2d, [](const VL& v) -> VL { return {v[0].t()}; }}, {"view", unary_pointwise_2d, [](const VL& v) -> VL { return {v[0].view({3, 2})}; }}, {"expand", {{2, 1}}, [](const VL& v) -> VL { return {v[0].expand({2, 3})}; }}, {"mm", {{10, 12}, {12, 15}}, [](const VL& v) -> VL { return 
TEST(AutodiffTest, ADFormulas) {
  const auto cast = [](const Variable& v) {
    return static_cast<at::Tensor>(v);
  };

  using VL = variable_list;
  const var_meta_list binary_pointwise = {{2, 3, 4, 5}, {2, 3, 4, 5}};
  const var_meta_list unary_pointwise = {{2, 3, 4, 5}};
  const var_meta_list unary_pointwise_2d = {{2, 3}};
  const std::vector<ADTestSpec> ad_tests = {
      {"add",
       binary_pointwise,
       [](const VL& v) -> VL { return {v[0] + v[1]}; }},
      {"sub",
       binary_pointwise,
       [](const VL& v) -> VL { return {v[0] - v[1]}; }},
      {"mul",
       binary_pointwise,
       [](const VL& v) -> VL { return {v[0] * v[1]}; }},
      {"sigmoid",
       unary_pointwise,
       [](const VL& v) -> VL { return {v[0].sigmoid()}; }},
      // Clamp tanh input tensor values to [-3, 3]
      // to set a minimum on gradient absolute values
      {"tanh",
       unary_pointwise,
       [](const VL& v) -> VL { return {v[0].tanh()}; },
       3.0f},
      {"t", unary_pointwise_2d, [](const VL& v) -> VL { return {v[0].t()}; }},
      {"view",
       unary_pointwise_2d,
       [](const VL& v) -> VL { return {v[0].view({3, 2})}; }},
      {"expand",
       {{2, 1}},
       [](const VL& v) -> VL { return {v[0].expand({2, 3})}; }},
      {"mm",
       {{10, 12}, {12, 15}},
       [](const VL& v) -> VL { return {v[0].mm(v[1])}; }},
      // TODO: enable once we'll be able to capture lists across
      // forward-backward
      //{"chunk", {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap(v[0].chunk(4, 1)); }},
      //{"chunk", {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap(v[0].chunk(3, 2)); }},
      //{"split", {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap(v[0].split(4, 1)); }},
      //{"split", {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap(v[0].split(3, 2)); }},
  };

  for (const auto& test : ad_tests) {
    // Get reference values from autograd
    auto vars_in = test.make_vars();
    auto vars_out = test(vars_in);
    auto var_grads_in = get_grad_outputs(vars_out);
    auto var_grads_out = grad(vars_out, vars_in, var_grads_in);

    // Trace and differentiate the op
    auto graph = tracer::trace(
                     fmap<IValue>(vars_in),
                     [&test](Stack in) -> Stack {
                       auto ivalue_inps = fmap(in, [](const IValue& v) {
                         return Variable(v.toTensor());
                       });
                       return fmap<IValue>(test(ivalue_inps));
                     },
                     [](const Variable& var) { return ""; })
                     .first->graph;
    EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
    ConstantPropagation(graph);
    auto grad_spec = differentiate(graph);
    LowerGradOf(*grad_spec.df);

    // Get outputs from the interpreter
    auto tensors_in = fmap(vars_in, cast);
    auto tensor_grads_in = fmap(var_grads_in, cast);
    auto [tensors_out, tensor_grads_out] =
        runGradient(grad_spec, tensors_in, tensor_grads_in);

    // Compare results
    auto expected_tensors_out = fmap(vars_out, cast);
    auto expected_tensor_grads_out = fmap(var_grads_out, cast);
    assertAllClose(tensors_out, expected_tensors_out);
    assertAllClose(tensor_grads_out, expected_tensor_grads_out);
  }
}

TEST(AutodiffTest, Differentiate) {
  // Note: can't use IRParser for this test due to issue #23989
  auto graph = std::make_shared<Graph>();
  std::vector<int64_t> sizes{2, 3, 4};
  std::vector<int64_t> strides{12, 4, 1};
  const auto type = TensorType::create(
      at::ScalarType::Float,
      at::kCPU,
      c10::VaryingShape<int64_t>{sizes},
      c10::VaryingShape<int64_t>{strides},
      true);

  // Builds graph a * b * a + b
  auto* a = graph->addInput()->setType(type);
  auto* b = graph->addInput()->setType(type);
  auto* cOne = graph->insertConstant(1);

  auto* ab = graph->insertNode(graph->create(aten::mul, /*num_outputs =*/1));
  ab->addInput(a);
  ab->addInput(b);

  auto* aba = graph->insertNode(graph->create(aten::mul, /*num_outputs =*/1));
  aba->addInput(ab->output());
  aba->addInput(a);

  auto* abaplusb =
      graph->insertNode(graph->create(aten::add, /*num_outputs =*/1));
  abaplusb->addInput(aba->output());
  abaplusb->addInput(b);
  abaplusb->addInput(cOne);

  graph->registerOutput(abaplusb->output());

  auto grad_spec = differentiate(graph);
  std::vector<size_t> expected_captured_inputs = {0, 1};
  std::vector<size_t> expected_captured_outputs = {1, 2, 3, 4, 5, 6, 7};
  std::vector<size_t> expected_input_vjps = {0, 1};
  std::vector<size_t> expected_output_vjps = {0, 1};
  ASSERT_EQ(grad_spec.f_real_outputs, 1);
  ASSERT_EQ(grad_spec.df_input_captured_inputs, expected_captured_inputs);
  ASSERT_EQ(grad_spec.df_input_captured_outputs, expected_captured_outputs);
  ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
  ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
  testing::FileCheck()
      .check_count("aten::mul", 2)
      ->check("aten::size")
      ->check("aten::add")
      ->run(*grad_spec.f);
  testing::FileCheck()
      .check("prim::GradOf[name=\"aten::add\"]")
      ->check_count("prim::GradOf[name=\"aten::mul\"]", 2)
      ->check_count("AutogradAdd", 2)
      ->run(*grad_spec.df);
}
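
// Orientation for the assertions above and below: differentiate() returns a
// Gradient struct (see torch/csrc/jit/runtime/autodiff.h for the
// authoritative definition) that splits the graph into a forward graph `f`
// and a backward graph `df`. `f_real_outputs` counts how many of f's outputs
// are user-visible (the rest are intermediates captured for df);
// `df_input_captured_inputs` / `df_input_captured_outputs` list which of f's
// inputs/outputs are captured and fed to df; `df_input_vjps` lists the
// outputs of f that receive incoming gradients, and `df_output_vjps` lists
// the inputs of f for which df produces gradients.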
TEST(AutodiffTest, DifferentiateWithRequiresGrad) {
  const auto graph_string = R"IR(
graph(%0 : Tensor, %1 : Tensor):
  %2 : int = prim::Constant[value=1]()
  %3 : Tensor = aten::mul(%1, %1)
  %4 : Tensor = aten::add(%3, %1, %2)
  %5 : Tensor = aten::add(%4, %0, %2)
  %6 : Tensor = aten::mul(%5, %0)
  %7 : Tensor = aten::add(%6, %1, %2)
  return (%4, %7))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  auto a_var = autograd::make_variable(
      at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true);
  auto b_var = autograd::make_variable(
      at::empty_strided(2, 2, at::CPU(at::kFloat).options()), false);

  ArgumentSpecCreator asc(*g);
  asc.specializeTypes(*g, asc.create(true, {a_var, b_var}));

  PropagateInputShapes(g);
  PropagateRequiresGrad(g);

  auto grad_spec = differentiate(g);
  std::vector<size_t> expected_input_vjps = {1, 2}; // for e and %4 = (d + a)
  std::vector<size_t> expected_output_vjps = {0}; // only a requires grad
  ASSERT_EQ(grad_spec.f_real_outputs, 2);
  ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector<size_t>({0}));
  ASSERT_EQ(
      grad_spec.df_input_captured_outputs,
      std::vector<size_t>({2, 3, 4, 5, 6}));
  ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
  ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
  testing::FileCheck()
      .check("aten::mul")
      ->check_count("aten::add", 2)
      ->check("aten::mul")
      ->check("aten::size")
      ->check("aten::add")
      ->run(*grad_spec.f);
  testing::FileCheck()
      .check_count("prim::GradOf[name=\"aten::mul\"]", 1, /*exactly*/ true)
      ->run(*grad_spec.df);
}

class AutodiffRemoveUnusedGradientsTest : public ::testing::Test {
 protected:
  void SetUp() override {
    prev_exec = getExecutorMode();
    getExecutorMode() = true;
    prev_inline_autodiff = getAutodiffSubgraphInlining();
    debugSetAutodiffSubgraphInlining(false);
  }

  void TearDown() override {
    getExecutorMode() = prev_exec;
    debugSetAutodiffSubgraphInlining(prev_inline_autodiff);
  }

  bool prev_exec;
  bool prev_profiling;
  bool prev_inline_autodiff;
};

TEST_F(AutodiffRemoveUnusedGradientsTest, Linear) {
  auto graph = std::make_shared<Graph>();
  const std::string input = R"IR(
graph(%inp.1 : Tensor, %weight.1 : Tensor, %bias.1 : Tensor):
  %6 : Tensor = aten::linear(%inp.1, %weight.1, %bias.1)
  return (%6))IR";
  parseIR(input, graph.get());

  auto inp = torch::randn({10, 10}).requires_grad_(false);
  auto weight = torch::randn({10, 10}).requires_grad_(true);
  auto bias = torch::randn({1, 10}).requires_grad_(true);
  auto stack = createStack({inp, weight, bias});

  ProfilingGraphExecutorImpl executor(graph, "linear");

  // initial run to profile requires_grad information
  auto plan = executor.getPlanFor(stack, 20);
  InterpreterState is{plan.code};
  is.run(stack);

  auto optimized_plan = executor.getPlanFor(stack, 20);
  DepthFirstGraphNodeIterator it(optimized_plan.graph);
  Node* diff_graph_node = nullptr;
  while ((diff_graph_node = it.next()) != nullptr) {
    if (diff_graph_node->kind() == prim::DifferentiableGraph) {
      break;
    }
  }
  ASSERT_NE(nullptr, diff_graph_node);
  auto backward_graph = diff_graph_node->g(attr::ReverseSubgraph);

  // we expect to compute grad_weight (which requires a matmul) but we don't
  // expect to compute grad_input. So, we expect exactly 1 matmul.
  // Note: this could change, e.g. if mm is used instead
  testing::FileCheck().check_count("matmul", 1, true)->run(*backward_graph);
}

} // namespace jit
} // namespace torch