Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[SR] Remove linear/relu fusion
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77620

Apparently, the fused linear/relu kernel is not implemented in fbgemm, so the fusion is strictly worse than just using NNC.

Differential Revision: [D36431811](https://our.internmc.facebook.com/intern/diff/D36431811/)

Approved by: https://github.com/hlu1
Committed by: PyTorch MergeBot
Parent: bb4653e736
Commit: 2ae3c59e4b
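For context, the QuantizedLinearReluFusion pass removed by this commit was a TorchScript pattern rewrite: it replaced a quantized::linear_dynamic_fp16 node followed by aten::relu with the fused quantized::linear_relu_dynamic_fp16 op. The sketch below shows roughly how such a rewrite is expressed with torch::jit::SubgraphRewriter; it is an approximation for illustration, not the exact deleted source.

#include <memory>
#include <string>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

namespace torch {
namespace jit {

// Approximate sketch of the removed pass: rewrite every match of
// linear_dynamic_fp16 + relu into the fused linear_relu_dynamic_fp16 op.
void QuantizedLinearReluFusion(std::shared_ptr<Graph>& graph) {
  const std::string pattern = R"IR(
    graph(%input, %packed_params):
        %x = quantized::linear_dynamic_fp16(%input, %packed_params)
        %y = aten::relu(%x)
        return (%y))IR";
  const std::string fused_pattern = R"IR(
    graph(%input, %packed_params):
        %x = quantized::linear_relu_dynamic_fp16(%input, %packed_params)
        return (%x))IR";

  SubgraphRewriter fuser;
  fuser.RegisterRewritePattern(pattern, fused_pattern);
  fuser.runOnGraph(graph);
}

} // namespace jit
} // namespace torch

Since the fused op has no real fbgemm kernel behind it, dropping the rewrite and letting NNC handle the unfused aten::relu is the better trade-off. The diff below deletes the Static Runtime test that exercised the pass.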
@@ -3323,30 +3323,6 @@ TEST(StaticRuntime, NestedBlockIfReturnList) {
   testStaticRuntime(src, args1, args2);
 }
 
-TEST(StaticRuntime, QuantizedLinearDynamicFp16ReluFusion) {
-  const auto src = R"IR(
-    graph(%input: Tensor, %weights: Tensor):
-        %bias: None = prim::Constant()
-        %packed_params = quantized::linear_prepack_fp16(%weights, %bias)
-        %x = quantized::linear_dynamic_fp16(%input, %packed_params)
-        %y = aten::relu(%x)
-        %ret = aten::clone(%y, %bias)
-        return (%ret)
-  )IR";
-  at::Tensor weight = torch::randn({3, 2}, torch::kFloat);
-  at::Tensor input = torch::randn({3, 2}, torch::kFloat);
-
-  at::Tensor weight_2 = torch::randn({4, 3}, torch::kFloat);
-  at::Tensor input_2 = torch::randn({5, 3}, torch::kFloat);
-
-  testStaticRuntime(src, {input, weight}, {input_2, weight_2});
-
-  auto graph = getGraphFromIR(src);
-  QuantizedLinearReluFusion(graph);
-  EXPECT_FALSE(hasNodeWithKind(graph, "quantized::linear_dynamic_fp16"));
-  EXPECT_TRUE(hasNodeWithKind(graph, "quantized::linear_relu_dynamic_fp16"));
-}
-
 TEST(StaticRuntime, ClampNaNToNum) {
   const auto src1 = R"JIT(
     def forward(self, a):
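The deleted test's assertions rely on a hasNodeWithKind helper from the Static Runtime test utilities. As a rough illustration of what that check does (assuming the standard TorchScript Graph/Node API; the real helper may differ, e.g. by also walking sub-blocks), it scans the graph for a node whose kind matches the given operator name:

#include <memory>
#include <string>
#include <torch/csrc/jit/ir/ir.h>

// Illustrative only: returns true if any top-level node in the graph has
// the given operator kind, e.g. "quantized::linear_dynamic_fp16".
bool hasNodeWithKind(
    const std::shared_ptr<torch::jit::Graph>& graph,
    const std::string& kind) {
  const auto symbol = c10::Symbol::fromQualString(kind);
  for (const auto* node : graph->nodes()) {
    if (node->kind() == symbol) {
      return true;
    }
  }
  return false;
}

With the fusion pass gone, the EXPECT_FALSE/EXPECT_TRUE pair above has nothing left to verify, which is why the whole test is removed.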