[SR] Remove linear/relu fusion

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77620

Apparently, this is not implemented in fbgemm, so it's strictly worse than using NNC.

Differential Revision: [D36431811](https://our.internmc.facebook.com/intern/diff/D36431811/)

Approved by: https://github.com/hlu1
This commit is contained in:
mikeiovine
2022-05-23 07:51:21 -07:00
committed by PyTorch MergeBot
parent bb4653e736
commit 2ae3c59e4b
4 changed files with 0 additions and 42 deletions

View File

@ -3323,30 +3323,6 @@ TEST(StaticRuntime, NestedBlockIfReturnList) {
testStaticRuntime(src, args1, args2);
}
TEST(StaticRuntime, QuantizedLinearDynamicFp16ReluFusion) {
const auto src = R"IR(
graph(%input: Tensor, %weights: Tensor):
%bias: None = prim::Constant()
%packed_params = quantized::linear_prepack_fp16(%weights, %bias)
%x = quantized::linear_dynamic_fp16(%input, %packed_params)
%y = aten::relu(%x)
%ret = aten::clone(%y, %bias)
return (%ret)
)IR";
at::Tensor weight = torch::randn({3, 2}, torch::kFloat);
at::Tensor input = torch::randn({3, 2}, torch::kFloat);
at::Tensor weight_2 = torch::randn({4, 3}, torch::kFloat);
at::Tensor input_2 = torch::randn({5, 3}, torch::kFloat);
testStaticRuntime(src, {input, weight}, {input_2, weight_2});
auto graph = getGraphFromIR(src);
QuantizedLinearReluFusion(graph);
EXPECT_FALSE(hasNodeWithKind(graph, "quantized::linear_dynamic_fp16"));
EXPECT_TRUE(hasNodeWithKind(graph, "quantized::linear_relu_dynamic_fp16"));
}
TEST(StaticRuntime, ClampNaNToNum) {
const auto src1 = R"JIT(
def forward(self, a):

View File

@ -180,7 +180,6 @@ void OptimizeGraph(
graph, /* custom_ops */ {fromQualString("fb::scale_gradient")});
AddIfThenElseOp(graph);
UseSplitAndSqueeze(graph);
QuantizedLinearReluFusion(graph);
GRAPH_DUMP("Final graph after optimizations: ", graph);
}

View File

@ -1349,21 +1349,6 @@ void EliminateNoOpSlice(std::shared_ptr<Graph>& graph) {
}
}
void QuantizedLinearReluFusion(std::shared_ptr<Graph>& graph) {
std::string pattern = R"IR(
graph(%input, %packed_params):
%x : Tensor = quantized::linear_dynamic_fp16(%input, %packed_params)
%y : Tensor = aten::relu(%x)
return (%y))IR";
std::string fused_pattern = R"IR(
graph(%input, %packed_params):
%x : Tensor = quantized::linear_relu_dynamic_fp16(%input, %packed_params)
return (%x))IR";
SubgraphRewriter fuse;
fuse.RegisterRewritePattern(pattern, fused_pattern);
fuse.runOnGraph(graph);
}
void FuseClampNaNToNum(std::shared_ptr<Graph>& graph) {
#ifdef FBCODE_CAFFE2
std::string pattern = R"IR(

View File

@ -80,8 +80,6 @@ TORCH_API void RemoveUnnecessaryOutputs(std::shared_ptr<Graph>& graph);
TORCH_API void RemoveUnnecessaryEmbeddingBagOutputs(
std::shared_ptr<Graph>& graph);
TORCH_API void QuantizedLinearReluFusion(std::shared_ptr<Graph>& graph);
TORCH_API void FuseClampNaNToNum(std::shared_ptr<Graph>& graph);
} // namespace jit