[Static Runtime] Move PrepackWeights to internal-only graph passes (#87799)
Summary: The pass introduces an `fb::` operator and thus cannot be used in OSS. The test failure was not exposed because the Static Runtime tests have been disabled in OSS for a while; the Dev Infra folks encountered the failure when re-enabling the tests.

Test Plan: Existing tests

Differential Revision: D40724547

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87799
Approved by: https://github.com/huydhn
Committed by: PyTorch MergeBot
Parent: ce7fcab9bd
Commit: 81c4049f4d
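The diff below deletes the OSS PrepackWeights test and moves the PrepackWeights(graph) call into the internal-only section of OptimizeGraph. A minimal sketch of the resulting guard pattern follows; the macro name FBCODE_CAFFE2, the header path, and the abbreviated pass list are assumptions for illustration, since the diff itself only shows the closing #endif.

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/static/passes.h>  // assumed header for the static runtime passes

namespace torch::jit {

// Sketch only: the real OptimizeGraph runs many more passes than shown here.
void OptimizeGraphSketch(std::shared_ptr<Graph>& graph) {
  EliminateExtraPermuteOps(graph);  // OSS-safe passes run unconditionally.
#ifdef FBCODE_CAFFE2  // assumed guard macro for internal-only builds
  // Internal-only: PrepackWeights rewrites fb::quantized_linear_unpacked_weight_v2
  // (an operator not registered in OSS builds) into quantized::linear_prepack
  // followed by quantized::linear, per the assertions in the deleted test below.
  PrepackWeights(graph);
#endif
}

} // namespace torch::jit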
@@ -3676,41 +3676,6 @@ TEST(StaticRuntime, ClampNaNToNum) {
   testStaticRuntime(src1, {a.to(at::kDouble)}, {b.to(at::kDouble)}, /*use_allclose=*/true, /*use_equalnan=*/true);
 }
 
-TEST(StaticRuntime, PrepackWeights) {
-  const std::string src = R"IR(
-    graph(%input: Tensor, %weight: Tensor, %bias: Tensor?, %scale: Tensor, %zero_point: Tensor):
-        %none: NoneType = prim::Constant()
-        %result: Tensor = fb::quantized_linear_unpacked_weight_v2(%input, %weight, %bias, %scale, %zero_point)
-        %dequantized: Tensor = aten::dequantize(%result)
-        return (%dequantized)
-  )IR";
-
-  auto graph = getGraphFromIR(src);
-  PrepackWeights(graph);
-  ASSERT_TRUE(graphHasOp(graph, "quantized::linear"));
-  ASSERT_TRUE(graphHasOp(graph, "quantized::linear_prepack"));
-  ASSERT_FALSE(graphHasOp(graph, "fb::quantized_linear_unpacked_weight_v2"));
-
-  auto scale = at::tensor({2}, at::kFloat);
-  auto zero_point = at::tensor({3}, at::kLong);
-
-  auto weight =
-      at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQInt8);
-  auto input =
-      at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQUInt8);
-  auto args1 = std::vector<IValue>{input, weight, c10::nullopt, scale, zero_point};
-
-  auto weight_2 =
-      at::quantize_per_tensor(torch::randn({8, 3}), 2, 3, torch::kQInt8);
-  auto input_2 =
-      at::quantize_per_tensor(torch::randn({9, 3}), 2, 3, torch::kQUInt8);
-  auto bias_2 = torch::randn({3}, torch::kFloat);
-  auto args2 = std::vector<IValue>{input, weight, bias_2, scale, zero_point};
-
-  testStaticRuntime(src, args1);
-  testStaticRuntime(src, args2);
-}
-
 TEST(StaticRuntime, IfReturningTuple) {
   const auto src = R"JIT(
     def forward(self, x, y, cond: bool, idx: int):
@@ -172,7 +172,6 @@ void OptimizeGraph(
   UseVariadicStack(graph);
   EliminateTrivialEquallySplit(graph);
   EliminateExtraPermuteOps(graph);
-  PrepackWeights(graph);
 
   if (opts.enable_out_variant) {
     UseVariadicOp(
@@ -199,6 +198,7 @@ void OptimizeGraph(
   }
   FuseListUnpack(graph);
   RemoveUnnecessaryOutputs(graph);
+  PrepackWeights(graph);
 #endif
 }
 