[Static Runtime] Move PrepackWeights to internal-only graph passes (#87799)

Summary:
The pass relies on an internal-only `fb::` operator and thus cannot be used in OSS.

The test failure was not exposed because the Static Runtime tests have been disabled in OSS for a while. The Dev Infra folks encountered this failure when re-enabling the tests.

Test Plan: Existing tests

Differential Revision: D40724547

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87799
Approved by: https://github.com/huydhn
This commit is contained in:
Mike Iovine
2022-10-28 01:28:34 +00:00
committed by PyTorch MergeBot
parent ce7fcab9bd
commit 81c4049f4d
2 changed files with 1 additions and 36 deletions

View File

@ -3676,41 +3676,6 @@ TEST(StaticRuntime, ClampNaNToNum) {
testStaticRuntime(src1, {a.to(at::kDouble)}, {b.to(at::kDouble)}, /*use_allclose=*/true, /*use_equalnan=*/true);
}
// Exercises the PrepackWeights graph pass on a graph that contains
// fb::quantized_linear_unpacked_weight_v2.
//
// Part 1 checks the rewrite structurally: after the pass the graph must
// contain quantized::linear_prepack and quantized::linear, and must no longer
// contain the fb:: op.
//
// Part 2 runs the same IR through testStaticRuntime twice: once with the
// optional bias left empty and once with a bias tensor supplied.
TEST(StaticRuntime, PrepackWeights) {
  const std::string src = R"IR(
graph(%input: Tensor, %weight: Tensor, %bias: Tensor?, %scale: Tensor, %zero_point: Tensor):
%none: NoneType = prim::Constant()
%result: Tensor = fb::quantized_linear_unpacked_weight_v2(%input, %weight, %bias, %scale, %zero_point)
%dequantized: Tensor = aten::dequantize(%result)
return (%dequantized)
)IR";

  // Structural check: apply the pass directly and inspect the resulting ops.
  auto graph = getGraphFromIR(src);
  PrepackWeights(graph);
  ASSERT_TRUE(graphHasOp(graph, "quantized::linear"));
  ASSERT_TRUE(graphHasOp(graph, "quantized::linear_prepack"));
  ASSERT_FALSE(graphHasOp(graph, "fb::quantized_linear_unpacked_weight_v2"));

  // Output quantization parameters passed to the op as tensors.
  auto scale = at::tensor({2}, at::kFloat);
  auto zero_point = at::tensor({3}, at::kLong);

  // Quantized operands: qint8 weight and quint8 activation, both quantized
  // with scale 2 and zero point 3.
  auto weight =
      at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQInt8);
  auto input =
      at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQUInt8);

  // Case 1: %bias is absent.
  auto args1 =
      std::vector<IValue>{input, weight, c10::nullopt, scale, zero_point};

  // Case 2: %bias is present. Its length (3) equals the leading dimension of
  // the {3, 2} weight -- presumably the output-feature count; confirm against
  // the op's schema if the shapes are ever changed.
  auto bias_2 = torch::randn({3}, torch::kFloat);
  auto args2 = std::vector<IValue>{input, weight, bias_2, scale, zero_point};

  testStaticRuntime(src, args1);
  testStaticRuntime(src, args2);
}
TEST(StaticRuntime, IfReturningTuple) {
const auto src = R"JIT(
def forward(self, x, y, cond: bool, idx: int):

View File

@ -172,7 +172,6 @@ void OptimizeGraph(
UseVariadicStack(graph);
EliminateTrivialEquallySplit(graph);
EliminateExtraPermuteOps(graph);
PrepackWeights(graph);
if (opts.enable_out_variant) {
UseVariadicOp(
@ -199,6 +198,7 @@ void OptimizeGraph(
}
FuseListUnpack(graph);
RemoveUnnecessaryOutputs(graph);
PrepackWeights(graph);
#endif
}