mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Since caffe2 and torch have been consolidated, CAFFE2_API should be merged with TORCH_API. Addresses a TODO. Manually edited some references of the removed `CAFFE2_API`: * `CONTRIBUTING.md` * `caffe2/proto/CMakeLists.txt` * `cmake/ProtoBuf.cmake` * `c10/macros/Export.h` * `torch/csrc/WindowsTorchApiMacro.h` Pull Request resolved: https://github.com/pytorch/pytorch/pull/49496 Reviewed By: malfet, samestep Differential Revision: D25600726 Pulled By: janeyx99 fbshipit-source-id: 7e068d959e397ac183c097d7e9a9afeca5ddd782
53 lines
1.5 KiB
C++
53 lines
1.5 KiB
C++
|
|
#pragma once
|
|
|
|
#include "caffe2/core/common.h"
|
|
#include "caffe2/core/transform.h"
|
|
#include "caffe2/proto/caffe2_pb.h"
|
|
#include "caffe2/utils/proto_utils.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
/**
|
|
* Common Subexpression Elimination
|
|
*
|
|
* This transform looks for specific operators (denoted by allowed_ops_),
|
|
* and removes unnecessary repetition of that operator.
|
|
*
|
|
* Consider some operator of X, that reads from blob b_ written to by W.
|
|
* X_a and X_b read the output of X. However, another operator Y, is the same
|
|
* type as X, has the same arguments as X, and reads from the same input b_,
|
|
* written to by W. Its output is the same as X's. Y_a, Y_b, and Y_c read from Y.
|
|
*
|
|
* Then, we can eliminate the common subexpressions X and Y, and merge them to
|
|
* Z, where X_a, X_b, Y_a, Y_b, and Y_c all read from Z.
|
|
*
|
|
*
|
|
* TODO(benz): Fix the error to not match nodes that write to external output.
|
|
*/
|
|
class TORCH_API CommonSubexpressionEliminationTransform : public Transform {
 public:
  // Selects the SORTED_WRT_EXECUTION_ORDER pattern match type for this
  // transform.
  CommonSubexpressionEliminationTransform() {
    SetPatternMatchType(SORTED_WRT_EXECUTION_ORDER);
  }

 protected:
  // Overrides of the Transform rule hooks. They are only declared in this
  // header; their definitions live elsewhere.
  bool PatternRule(
      const transform::Graph& g,
      const std::vector<int>& subgraph,
      int idx) override;
  bool ValidatorRule(
      const transform::Graph& g,
      const std::vector<int>& subgraph) override;
  bool ReplaceRule(const std::vector<int>& subgraph, transform::Graph* g_ptr)
      override;

 private:
  // Returns true iff `op_type` is one of the operator types in allowed_ops_.
  // The name is taken by const reference so a lookup does not copy the
  // string, and the method is const because it only reads allowed_ops_.
  bool IsAllowed(const string& op_type) const {
    return allowed_ops_.count(op_type) != 0;
  }

  // Operator types whose repeated occurrences this transform may eliminate.
  std::set<string> allowed_ops_ = {"LearningRate", "FC"};
};
|
|
|
|
} // namespace caffe2
|