From a4ea16dbc623864ee8a3a9d24e647c707bb41bd8 Mon Sep 17 00:00:00 2001
From: Martin Yuan
Date: Wed, 25 Mar 2020 14:06:49 -0700
Subject: [PATCH] Put prim ops used in full jit only in a separate file
 (#35232)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35232

Some prim operators, such as profile and fusion, are not used in mobile (at
least in the short term) and are coupled with JIT code. Put them in a
separate file (register_prim_ops_fulljit.cpp).

ghstack-source-id: 100807055

Test Plan: buck build //xplat/caffe2:torch

Reviewed By: dreiss

Differential Revision: D20408827

fbshipit-source-id: 9013093357cf75723ef00c34bbfdb6b7ea40a4cf
---
 caffe2/CMakeLists.txt                         |   1 +
 tools/build_variables.bzl                     |   1 +
 torch/csrc/jit/runtime/register_prim_ops.cpp  | 238 +-------------
 .../jit/runtime/register_prim_ops_fulljit.cpp | 308 ++++++++++++++++++
 4 files changed, 311 insertions(+), 237 deletions(-)
 create mode 100644 torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp

diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index b4af7fba6e73..c105f4abdc30 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -441,6 +441,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/jit/runtime/print_handler.cpp
     ${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/interface.cpp
     ${TORCH_SRC_DIR}/csrc/jit/runtime/register_prim_ops.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/runtime/register_prim_ops_fulljit.cpp
     ${TORCH_SRC_DIR}/csrc/jit/runtime/register_prim_ops_c10.cpp
     ${TORCH_SRC_DIR}/csrc/jit/runtime/register_string_ops.cpp
     ${TORCH_SRC_DIR}/csrc/jit/runtime/register_special_ops.cpp
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index 9ce1fb09816f..e4774d2c1eaa 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -154,6 +154,7 @@ libtorch_sources = [
     "torch/csrc/jit/passes/freeze_module.cpp",
     "torch/csrc/jit/runtime/print_handler.cpp",
     "torch/csrc/jit/runtime/register_prim_ops.cpp",
+    "torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp",
     "torch/csrc/jit/runtime/register_prim_ops_c10.cpp",
     "torch/csrc/jit/runtime/register_string_ops.cpp",
     "torch/csrc/jit/runtime/register_special_ops.cpp",
diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp
index b3ead317d14f..57023050ebf2 100644
--- a/torch/csrc/jit/runtime/register_prim_ops.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -283,67 +283,7 @@ int64_t normalizeIndex(int64_t idx, int64_t list_size) {
   return idx;
 }
 
-RegisterOperators reg(
-    {Operator(
-         prim::profile,
-         [](const Node* node) -> Operation {
-           auto callback = node->cast<ProfileOp>()->getCallback();
-           return [callback](Stack& stack) {
-             callback(stack);
-             return 0;
-           };
-         },
-         aliasAnalysisSpecialCase()),
-     Operator(
-         prim::CudaFusionGroup,
-         [](const Node* node) -> Operation {
-           const auto key = registerFusion(node);
-           return [key](Stack& stack) {
-             RECORD_FUNCTION("CudaFusionGroup", std::vector<c10::IValue>());
-             runFusion(key, stack);
-             return 0;
-           };
-         },
-         aliasAnalysisSpecialCase()),
-     Operator(
-         prim::FusionGroup,
-         [](const Node* node) -> Operation {
-           const auto key = registerFusion(node);
-           return [key](Stack& stack) {
-             RECORD_FUNCTION("FusionGroup", std::vector<c10::IValue>());
-             runFusion(key, stack);
-             return 0;
-           };
-         },
-         aliasAnalysisSpecialCase()),
-     Operator(
-         "prim::Guard(Tensor(a) t) -> Tensor(a)",
-         [](Stack& stack) {
-           AT_ERROR("Should be replaced by prim::BailOut");
-           return 0;
-         },
-         aliasAnalysisFromSchema()),
-     Operator(
-         "prim::BailOut(...) -> Tensor(a)",
-> Tensor(a)", - [](Stack& /* stack */) { - AT_ERROR("prim::BailOut not yet implemented"); // NOLINT - return 0; - }, - aliasAnalysisFromSchema()), - Operator( - "prim::BailoutTemplate() -> int", - [](Stack& stack) { - // TODO: today, we put a single bailout template at the front to - // carry the un-optimized graph for bailout nodes to use. Ideally - // this should never run, but we haven't written the code to remove - // it yet. - // TORCH_INTERNAL_ASSERT(false); - - // Returns an int so that we have an easy way to do graph traversal - push(stack, 1); - return 0; - }, - aliasAnalysisFromSchema()), +RegisterOperators reg({ Operator( "prim::rangelist(int n) -> int[]", [](Stack& stack) { @@ -724,84 +664,6 @@ RegisterOperators reg( return 0; }, aliasAnalysisFromSchema()), - Operator( - "aten::grad(Tensor[] outputs, Tensor[] inputs, Tensor?[]? grad_outputs=None, bool? retain_graph=None, bool create_graph=False, bool allow_unused=False) -> Tensor?[]", - [](Stack& stack) { - bool allow_unused = pop(stack).toBool(); - bool create_graph = pop(stack).toBool(); - auto retain_graph = pop(stack).toOptional(); - auto grad_outputs = pop(stack); - auto inputs = pop(stack).toTensorList(); - auto outputs = pop(stack).toTensorList(); - std::vector input_vars( - inputs.begin(), inputs.end()); - std::vector output_vars( - outputs.begin(), outputs.end()); - std::vector gradients; - - if (!grad_outputs.isNone()) { - for (const IValue& v : grad_outputs.toListRef()) { - gradients.emplace_back(v.isNone() ? at::Tensor() : v.toTensor()); - } - } - - auto res = torch::autograd::grad( - output_vars, - input_vars, - gradients, - retain_graph, - create_graph, - allow_unused); - - c10::impl::GenericList res_list{OptionalType::ofTensor()}; - for (const at::Tensor& t : res) { - res_list.emplace_back(t.defined() ? t : IValue()); - } - push(stack, res_list); - return 0; - }, - aliasAnalysisFromSchema()), - // NB: backward op might write to every input tensors in the graph and it's - // much more expensive to analayze the leaves and sometimes it might retain - // the whole gradients in every tensor of the Autograd graph with - // create_graph=True so we use aliasAnalysisConservative for these two OPs - Operator( - "aten::backward.list(Tensor[](a!) tensors, Tensor?[]? grad_tensors=None, bool? retain_graph=None, bool create_graph=False) -> ()", - [](Stack& stack) { - bool create_graph = pop(stack).toBool(); - auto retain_graph = pop(stack).toOptional(); - auto grad_tensors = pop(stack); - auto outputs = pop(stack).toTensorList(); - std::vector output_vars( - outputs.begin(), outputs.end()); - std::vector gradients; - - if (!grad_tensors.isNone()) { - for (const IValue& v : grad_tensors.toListRef()) { - gradients.emplace_back(v.isNone() ? at::Tensor() : v.toTensor()); - } - } - - torch::autograd::backward( - output_vars, gradients, retain_graph, create_graph); - return 0; - }, - aliasAnalysisConservative()), - Operator( - "aten::backward(Tensor(a!) self, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()", - [](Stack& stack) { - bool create_graph = pop(stack).toBool(); - auto retain_graph = pop(stack).toOptional(); - IValue gradient_ivalue = pop(stack); - at::Tensor gradient = gradient_ivalue.isNone() - ? at::Tensor() - : gradient_ivalue.toTensor(); - at::Tensor self = pop(stack).toTensor(); - bool keep_graph = retain_graph ? 
-           self.backward(gradient, keep_graph, create_graph);
-           return 0;
-         },
-         aliasAnalysisConservative()),
     Operator(
         "aten::requires_grad_(Tensor(a!) self, bool _requires_grad=True) -> Tensor(a!)",
         [](Stack& stack) {
@@ -818,41 +680,6 @@ RegisterOperators reg(
           return 0;
         },
         aliasAnalysisSpecialCase()),
-     Operator(
-         "aten::save(t item, str filename) -> ()",
-         [](Stack& stack) {
-           auto filename = pop(stack).toStringRef();
-           auto ivalue = pop(stack);
-
-           // Pickle the tensor
-           auto data = jit::pickle_save(ivalue);
-
-           // Write file
-           std::fstream output(filename, std::ios::out | std::ios::binary);
-           output.write(data.data(), data.size());
-           return 0;
-         },
-         aliasAnalysisFromSchema()),
-     Operator(
-         "prim::Print(...) -> ()",
-         [](Stack& stack) {
-           auto num_inputs = pop(stack).toInt();
-           std::stringstream ss;
-           bool first = true;
-           for (const IValue& i : last(stack, num_inputs)) {
-             if (!first)
-               ss << " ";
-             first = false;
-             ss << i;
-           }
-           drop(stack, num_inputs);
-           ss << std::endl;
-           auto* handler = getPrintHandler();
-           TORCH_INTERNAL_ASSERT(handler);
-           handler(ss.str());
-           return 0;
-         },
-         aliasAnalysisSpecialCase()),
     Operator(
         "prim::BroadcastSizes(...) -> int[]",
        [](Stack& stack) {
@@ -906,26 +733,6 @@ RegisterOperators reg(
          return 0;
        },
        aliasAnalysisFromSchema()),
-     Operator(
-         "prim::RaiseException(str msg) -> ()",
-         [](Stack& stack) {
-           throw JITException(pop(stack).toStringRef());
-           return 0;
-         },
-         aliasAnalysisFromSchema()),
-
-     Operator(
-         "prim::IgnoredPythonOp(...) -> None",
-         [](Stack& stack) {
-           throw JITException(
-               "This Python function is annotated to be ignored"
-               " and cannot be and has not been included in the exported"
-               " binary, meaning that it cannot be executed now."
-               " Make sure that ignored operations are never executed after"
-               " import");
-           return 0;
-         },
-         aliasAnalysisFromSchema()),
 
     Operator(
        "onnx::Reshape(Tensor input, Tensor shape) -> Tensor",
@@ -1205,50 +1012,7 @@ RegisterOperators reg(
        },
        aliasAnalysisSpecialCase())});
 
-RegisterOperators logging_operators(
-    {Operator(
-         "prim::AddStatValue(str key, int val) -> ()",
-         [](Stack& stack) {
-           auto val = pop(stack).toInt();
-           auto key = pop(stack).toString();
-           auto schema =
-               parseSchema("prim::AddStatValue(str key, int val) -> ()");
-           // TODO: remove this custom tracing code once the custom op bugfix
-           // lands
-           if (jit::tracer::isTracing()) {
-             const auto& graph = tracer::getTracingState()->graph;
-             Node* node = graph->create(prim::AddStatValue, /*num_outputs=*/0);
-             tracer::recordSourceLocation(node);
-             node->addInput(insertConstant(*graph, key));
-             tracer::addInputs(node, "val", val);
-             graph->insertNode(node);
-           }
-           torch::jit::logging::getLogger()->addStatValue(*key, val);
-           return 0;
-         },
-         aliasAnalysisFromSchema()),
-     Operator(
-         "prim::TimePoint() -> int",
-         [](Stack& stack) {
-           auto schema = parseSchema("prim::TimePoint() -> int");
-           Node* node = nullptr;
-           // TODO: remove this custom tracing code once the custom op bugfix
-           // lands
-           if (jit::tracer::isTracing()) {
-             const auto& graph = tracer::getTracingState()->graph;
-             Node* node = graph->create(prim::TimePoint, /*num_outputs=*/0);
-             tracer::recordSourceLocation(node);
-             graph->insertNode(node);
-           }
-           auto output = autograd::profiler::getTime();
-           push(stack, output);
-           if (jit::tracer::isTracing()) {
-             jit::tracer::addOutput(node, output);
-           }
-           return 0;
-         },
-         aliasAnalysisFromSchema())});
 
 // define implementations for primitive number ops
 #define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
new file mode 100644
index 000000000000..cb3bdf60560b
--- /dev/null
+++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
@@ -0,0 +1,308 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace torch {
+namespace jit {
+
+namespace {
+
+c10::AliasAnalysisKind aliasAnalysisFromSchema() {
+  return c10::AliasAnalysisKind::FROM_SCHEMA;
+}
+
+c10::AliasAnalysisKind aliasAnalysisConservative() {
+  return c10::AliasAnalysisKind::CONSERVATIVE;
+}
+
+c10::AliasAnalysisKind aliasAnalysisSpecialCase() {
+  return c10::AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
+}
+
+RegisterOperators reg(
+    {Operator(
+         prim::profile,
+         [](const Node* node) -> Operation {
+           auto callback = node->cast<ProfileOp>()->getCallback();
+           return [callback](Stack& stack) {
+             callback(stack);
+             return 0;
+           };
+         },
+         aliasAnalysisSpecialCase()),
+     Operator(
+         prim::CudaFusionGroup,
+         [](const Node* node) -> Operation {
+           const auto key = registerFusion(node);
+           return [key](Stack& stack) {
+             RECORD_FUNCTION("CudaFusionGroup", std::vector<c10::IValue>());
+             runFusion(key, stack);
+             return 0;
+           };
+         },
+         aliasAnalysisSpecialCase()),
+     Operator(
+         prim::FusionGroup,
+         [](const Node* node) -> Operation {
+           const auto key = registerFusion(node);
+           return [key](Stack& stack) {
+             RECORD_FUNCTION("FusionGroup", std::vector<c10::IValue>());
+             runFusion(key, stack);
+             return 0;
+           };
+         },
+         aliasAnalysisSpecialCase()),
+     Operator(
+         "prim::Guard(Tensor(a) t) -> Tensor(a)",
+         [](Stack& stack) {
+           AT_ERROR("Should be replaced by prim::BailOut");
+           return 0;
+         },
+         aliasAnalysisFromSchema()),
+     Operator(
+         "prim::BailOut(...) -> Tensor(a)",
+         [](Stack& /* stack */) {
+           AT_ERROR("prim::BailOut not yet implemented"); // NOLINT
+           return 0;
+         },
+         aliasAnalysisFromSchema()),
+     Operator(
+         "prim::BailoutTemplate() -> int",
+         [](Stack& stack) {
+           // TODO: today, we put a single bailout template at the front to
+           // carry the un-optimized graph for bailout nodes to use. Ideally
+           // this should never run, but we haven't written the code to remove
+           // it yet.
+           // TORCH_INTERNAL_ASSERT(false);
+
+           // Returns an int so that we have an easy way to do graph traversal
+           push(stack, 1);
+           return 0;
+         },
+         aliasAnalysisFromSchema()),
+     Operator(
+         "aten::grad(Tensor[] outputs, Tensor[] inputs, Tensor?[]? grad_outputs=None, bool? retain_graph=None, bool create_graph=False, bool allow_unused=False) -> Tensor?[]",
+         [](Stack& stack) {
+           bool allow_unused = pop(stack).toBool();
+           bool create_graph = pop(stack).toBool();
+           auto retain_graph = pop(stack).toOptional<bool>();
+           auto grad_outputs = pop(stack);
+           auto inputs = pop(stack).toTensorList();
+           auto outputs = pop(stack).toTensorList();
+           std::vector<torch::autograd::Variable> input_vars(
+               inputs.begin(), inputs.end());
+           std::vector<torch::autograd::Variable> output_vars(outputs.begin(), outputs.end());
+           std::vector<torch::autograd::Variable> gradients;
+
+           if (!grad_outputs.isNone()) {
+             for (const IValue& v : grad_outputs.toListRef()) {
+               gradients.emplace_back(v.isNone() ? at::Tensor() : v.toTensor());
+             }
+           }
+
+           auto res = torch::autograd::grad(
+               output_vars,
+               input_vars,
+               gradients,
+               retain_graph,
+               create_graph,
+               allow_unused);
+
+           c10::impl::GenericList res_list{OptionalType::ofTensor()};
+           for (const at::Tensor& t : res) {
+             res_list.emplace_back(t.defined() ? t : IValue());
+           }
+           push(stack, res_list);
+           return 0;
+         },
+         aliasAnalysisFromSchema()),
+     // NB: backward op might write to every input tensor in the graph and it's
+     // much more expensive to analyze the leaves and sometimes it might retain
+     // the whole gradients in every tensor of the Autograd graph with
+     // create_graph=True so we use aliasAnalysisConservative for these two OPs
+     Operator(
+         "aten::backward(Tensor[](a!) tensors, Tensor?[]? grad_tensors=None, bool? retain_graph=None, bool create_graph=False) -> ()",
+         [](Stack& stack) {
+           bool create_graph = pop(stack).toBool();
+           auto retain_graph = pop(stack).toOptional<bool>();
+           auto grad_tensors = pop(stack);
+           auto outputs = pop(stack).toTensorList();
+           std::vector<torch::autograd::Variable> output_vars(
+               outputs.begin(), outputs.end());
+           std::vector<torch::autograd::Variable> gradients;
+
+           if (!grad_tensors.isNone()) {
+             for (const IValue& v : grad_tensors.toListRef()) {
+               gradients.emplace_back(v.isNone() ? at::Tensor() : v.toTensor());
+             }
+           }
+
+           torch::autograd::backward(
+               output_vars, gradients, retain_graph, create_graph);
+           return 0;
+         },
+         aliasAnalysisConservative()),
+     Operator(
+         "aten::backward(Tensor(a!) self, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()",
+         [](Stack& stack) {
+           bool create_graph = pop(stack).toBool();
+           auto retain_graph = pop(stack).toOptional<bool>();
+           IValue gradient_ivalue = pop(stack);
+           at::Tensor gradient = gradient_ivalue.isNone()
+               ? at::Tensor()
+               : gradient_ivalue.toTensor();
+           at::Tensor self = pop(stack).toTensor();
+           bool keep_graph = retain_graph ? retain_graph.value() : create_graph;
+           self.backward(gradient, keep_graph, create_graph);
+           return 0;
+         },
+         aliasAnalysisConservative()),
+     Operator(
+         "aten::save(t item, str filename) -> ()",
+         [](Stack& stack) {
+           auto filename = pop(stack).toStringRef();
+           auto ivalue = pop(stack);
+
+           // Pickle the tensor
+           auto data = jit::pickle_save(ivalue);
+
+           // Write file
+           std::fstream output(filename, std::ios::out | std::ios::binary);
+           output.write(data.data(), data.size());
+           return 0;
+         },
+         aliasAnalysisFromSchema()),
+     Operator(
+         "prim::Print(...) -> ()",
+         [](Stack& stack) {
+           auto num_inputs = pop(stack).toInt();
+           std::stringstream ss;
+           bool first = true;
+           for (const IValue& i : last(stack, num_inputs)) {
+             if (!first)
+               ss << " ";
+             first = false;
+             ss << i;
+           }
+           drop(stack, num_inputs);
+           ss << std::endl;
+           auto* handler = getPrintHandler();
+           TORCH_INTERNAL_ASSERT(handler);
+           handler(ss.str());
+           return 0;
+         },
+         aliasAnalysisSpecialCase()),
+     Operator(
+         "prim::RaiseException(str msg) -> ()",
+         [](Stack& stack) {
+           throw JITException(pop(stack).toStringRef());
+           return 0;
+         },
+         aliasAnalysisFromSchema()),
+     Operator(
+         "prim::IgnoredPythonOp(...) -> None",
+         [](Stack& stack) {
+           throw JITException(
+               "This Python function is annotated to be ignored"
+               " and cannot be and has not been included in the exported"
+               " binary, meaning that it cannot be executed now."
+ " Make sure that ignored operations are never executed after" + " import"); + return 0; + }, + aliasAnalysisFromSchema()), +}); + +RegisterOperators logging_operators( + {Operator( + "prim::AddStatValue(str key, int val) -> ()", + [](Stack& stack) { + auto val = pop(stack).toInt(); + auto key = pop(stack).toString(); + + auto schema = + parseSchema("prim::AddStatValue(str key, int val) -> ()"); + // TODO: remove this custom tracing code once the custom op bugfix + // lands + if (jit::tracer::isTracing()) { + const auto& graph = tracer::getTracingState()->graph; + Node* node = graph->create(prim::AddStatValue, /*num_outputs=*/0); + tracer::recordSourceLocation(node); + node->addInput(insertConstant(*graph, key)); + tracer::addInputs(node, "val", val); + graph->insertNode(node); + } + torch::jit::logging::getLogger()->addStatValue(*key, val); + return 0; + }, + aliasAnalysisFromSchema()), + Operator( + "prim::TimePoint() -> int", + [](Stack& stack) { + auto schema = parseSchema("prim::TimePoint() -> int"); + Node* node = nullptr; + // TODO: remove this custom tracing code once the custom op bugfix + // lands + if (jit::tracer::isTracing()) { + const auto& graph = tracer::getTracingState()->graph; + Node* node = graph->create(prim::TimePoint, /*num_outputs=*/0); + tracer::recordSourceLocation(node); + graph->insertNode(node); + } + auto output = autograd::profiler::getTime(); + push(stack, output); + if (jit::tracer::isTracing()) { + jit::tracer::addOutput(node, output); + } + return 0; + }, + aliasAnalysisFromSchema())}); + +} // namespace +} // namespace jit +} // namespace torch