mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Codemod][AddExplicitStrictExportForTrainingInferenceArg] caffe2/ (#149595)
internal diff: D71497480 Pull Request resolved: https://github.com/pytorch/pytorch/pull/149595 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
8878289f89
commit
1ab6c4ff04
@ -173,7 +173,7 @@ class TestMatcher(JitTestCase):
|
||||
torch.randn(3, 3, 3, 3),
|
||||
)
|
||||
pattern_gm = export_for_training(
|
||||
WrapperModule(pattern), example_inputs
|
||||
WrapperModule(pattern), example_inputs, strict=True
|
||||
).module()
|
||||
before_split_res = pattern_gm(*example_inputs)
|
||||
pattern_gm, _ = _split_to_graph_and_name_node_map(pattern_gm)
|
||||
@ -204,11 +204,11 @@ class TestMatcher(JitTestCase):
|
||||
torch.randn(3, 3, 3, 3),
|
||||
)
|
||||
pattern_gm = export_for_training(
|
||||
WrapperModule(pattern), example_inputs
|
||||
WrapperModule(pattern), example_inputs, strict=True
|
||||
).module()
|
||||
matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
|
||||
target_gm = export_for_training(
|
||||
WrapperModule(target_graph), example_inputs
|
||||
WrapperModule(target_graph), example_inputs, strict=True
|
||||
).module()
|
||||
internal_matches = matcher.match(target_gm.graph)
|
||||
for internal_match in internal_matches:
|
||||
@ -248,9 +248,11 @@ class TestMatcher(JitTestCase):
|
||||
return linear, {"linear": linear, "x": x}
|
||||
|
||||
example_inputs = (torch.randn(3, 5),)
|
||||
pattern_gm = export_for_training(Pattern(), example_inputs).module()
|
||||
pattern_gm = export_for_training(
|
||||
Pattern(), example_inputs, strict=True
|
||||
).module()
|
||||
matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
|
||||
target_gm = export_for_training(M(), example_inputs).module()
|
||||
target_gm = export_for_training(M(), example_inputs, strict=True).module()
|
||||
internal_matches = matcher.match(target_gm.graph)
|
||||
for internal_match in internal_matches:
|
||||
name_node_map = internal_match.name_node_map
|
||||
|
@ -1837,7 +1837,9 @@ class AOTInductorTestsTemplate:
|
||||
with config.patch(
|
||||
{"freezing": True, "aot_inductor.force_mmap_weights": True}
|
||||
), torch.no_grad():
|
||||
exported_model = export_for_training(model, example_inputs).module()
|
||||
exported_model = export_for_training(
|
||||
model, example_inputs, strict=True
|
||||
).module()
|
||||
quantizer = X86InductorQuantizer()
|
||||
quantizer.set_global(
|
||||
xiq.get_default_x86_inductor_quantization_config(reduce_range=True)
|
||||
|
@ -101,10 +101,7 @@ class TestDuplicateDQPass(QuantizationTestCase):
|
||||
|
||||
# program capture
|
||||
m = copy.deepcopy(m_eager)
|
||||
m = export_for_training(
|
||||
m,
|
||||
example_inputs,
|
||||
).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
|
||||
m = prepare_pt2e(m, quantizer)
|
||||
# Calibrate
|
||||
|
@ -98,10 +98,7 @@ class TestMetaDataPorting(QuantizationTestCase):
|
||||
|
||||
# program capture
|
||||
m = copy.deepcopy(m_eager)
|
||||
m = torch.export.export_for_training(
|
||||
m,
|
||||
example_inputs,
|
||||
).module()
|
||||
m = torch.export.export_for_training(m, example_inputs, strict=True).module()
|
||||
|
||||
m = prepare_pt2e(m, quantizer)
|
||||
# Calibrate
|
||||
|
@ -81,7 +81,7 @@ class TestNumericDebugger(TestCase):
|
||||
def test_simple(self):
|
||||
m = TestHelperModules.Conv2dThenConv1d()
|
||||
example_inputs = m.example_inputs()
|
||||
ep = export_for_training(m, example_inputs)
|
||||
ep = export_for_training(m, example_inputs, strict=True)
|
||||
generate_numeric_debug_handle(ep)
|
||||
self._assert_each_node_has_debug_handle(ep)
|
||||
debug_handle_map = self._extract_debug_handles(ep)
|
||||
@ -91,7 +91,7 @@ class TestNumericDebugger(TestCase):
|
||||
def test_control_flow(self):
|
||||
m = TestHelperModules.ControlFlow()
|
||||
example_inputs = m.example_inputs()
|
||||
ep = export_for_training(m, example_inputs)
|
||||
ep = export_for_training(m, example_inputs, strict=True)
|
||||
generate_numeric_debug_handle(ep)
|
||||
|
||||
self._assert_each_node_has_debug_handle(ep)
|
||||
@ -102,7 +102,7 @@ class TestNumericDebugger(TestCase):
|
||||
def test_quantize_pt2e_preserve_handle(self):
|
||||
m = TestHelperModules.Conv2dThenConv1d()
|
||||
example_inputs = m.example_inputs()
|
||||
ep = export_for_training(m, example_inputs)
|
||||
ep = export_for_training(m, example_inputs, strict=True)
|
||||
generate_numeric_debug_handle(ep)
|
||||
m = ep.module()
|
||||
|
||||
@ -162,14 +162,14 @@ class TestNumericDebugger(TestCase):
|
||||
def test_re_export_preserve_handle(self):
|
||||
m = TestHelperModules.Conv2dThenConv1d()
|
||||
example_inputs = m.example_inputs()
|
||||
ep = export_for_training(m, example_inputs)
|
||||
ep = export_for_training(m, example_inputs, strict=True)
|
||||
generate_numeric_debug_handle(ep)
|
||||
m = ep.module()
|
||||
|
||||
self._assert_each_node_has_debug_handle(ep)
|
||||
debug_handle_map_ref = self._extract_debug_handles(ep)
|
||||
|
||||
ep_reexport = export_for_training(m, example_inputs)
|
||||
ep_reexport = export_for_training(m, example_inputs, strict=True)
|
||||
|
||||
self._assert_each_node_has_debug_handle(ep_reexport)
|
||||
debug_handle_map = self._extract_debug_handles(ep_reexport)
|
||||
@ -179,7 +179,7 @@ class TestNumericDebugger(TestCase):
|
||||
def test_run_decompositions_same_handle_id(self):
|
||||
m = TestHelperModules.Conv2dThenConv1d()
|
||||
example_inputs = m.example_inputs()
|
||||
ep = export_for_training(m, example_inputs)
|
||||
ep = export_for_training(m, example_inputs, strict=True)
|
||||
generate_numeric_debug_handle(ep)
|
||||
|
||||
self._assert_each_node_has_debug_handle(ep)
|
||||
@ -204,7 +204,7 @@ class TestNumericDebugger(TestCase):
|
||||
|
||||
for m in test_models:
|
||||
example_inputs = m.example_inputs()
|
||||
ep = export_for_training(m, example_inputs)
|
||||
ep = export_for_training(m, example_inputs, strict=True)
|
||||
generate_numeric_debug_handle(ep)
|
||||
|
||||
self._assert_each_node_has_debug_handle(ep)
|
||||
@ -227,7 +227,7 @@ class TestNumericDebugger(TestCase):
|
||||
def test_prepare_for_propagation_comparison(self):
|
||||
m = TestHelperModules.Conv2dThenConv1d()
|
||||
example_inputs = m.example_inputs()
|
||||
ep = export_for_training(m, example_inputs)
|
||||
ep = export_for_training(m, example_inputs, strict=True)
|
||||
generate_numeric_debug_handle(ep)
|
||||
m = ep.module()
|
||||
m_logger = prepare_for_propagation_comparison(m)
|
||||
@ -244,7 +244,7 @@ class TestNumericDebugger(TestCase):
|
||||
def test_extract_results_from_loggers(self):
|
||||
m = TestHelperModules.Conv2dThenConv1d()
|
||||
example_inputs = m.example_inputs()
|
||||
ep = export_for_training(m, example_inputs)
|
||||
ep = export_for_training(m, example_inputs, strict=True)
|
||||
generate_numeric_debug_handle(ep)
|
||||
m = ep.module()
|
||||
m_ref_logger = prepare_for_propagation_comparison(m)
|
||||
@ -269,7 +269,7 @@ class TestNumericDebugger(TestCase):
|
||||
def test_extract_results_from_loggers_list_output(self):
|
||||
m = TestHelperModules.Conv2dWithSplit()
|
||||
example_inputs = m.example_inputs()
|
||||
ep = export_for_training(m, example_inputs)
|
||||
ep = export_for_training(m, example_inputs, strict=True)
|
||||
generate_numeric_debug_handle(ep)
|
||||
m = ep.module()
|
||||
m_ref_logger = prepare_for_propagation_comparison(m)
|
||||
@ -299,7 +299,7 @@ class TestNumericDebugger(TestCase):
|
||||
def test_added_node_gets_unique_id(self) -> None:
|
||||
m = TestHelperModules.Conv2dThenConv1d()
|
||||
example_inputs = m.example_inputs()
|
||||
ep = export_for_training(m, example_inputs)
|
||||
ep = export_for_training(m, example_inputs, strict=True)
|
||||
generate_numeric_debug_handle(ep)
|
||||
ref_handles = self._extract_debug_handles(ep)
|
||||
ref_counter = Counter(ref_handles.values())
|
||||
|
@ -767,10 +767,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5))
|
||||
|
||||
# program capture
|
||||
m = export_for_training(
|
||||
m,
|
||||
example_inputs,
|
||||
).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
m = prepare_pt2e(m, BackendAQuantizer())
|
||||
# make sure the two observers for input are shared
|
||||
conv_output_obs = []
|
||||
@ -830,10 +827,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
)
|
||||
|
||||
# program capture
|
||||
m = export_for_training(
|
||||
m,
|
||||
example_inputs,
|
||||
).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
m = prepare_pt2e(m, quantizer)
|
||||
m(*example_inputs)
|
||||
# make sure the two input observers and output are shared
|
||||
@ -1152,10 +1146,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
)
|
||||
|
||||
# program capture
|
||||
m = export_for_training(
|
||||
m,
|
||||
example_inputs,
|
||||
).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
quantizer = BackendAQuantizer()
|
||||
m = prepare_pt2e(m, quantizer)
|
||||
m(*example_inputs)
|
||||
@ -1305,7 +1296,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
|
||||
m = M().eval()
|
||||
example_inputs = torch.randn(1, 2, 3, 3)
|
||||
m = export_for_training(m, (example_inputs,)).module()
|
||||
m = export_for_training(m, (example_inputs,), strict=True).module()
|
||||
with self.assertRaises(Exception):
|
||||
m = prepare_pt2e(m, BackendAQuantizer())
|
||||
|
||||
@ -1428,10 +1419,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
quantizer.set_global(operator_config)
|
||||
example_inputs = (torch.randn(2, 2),)
|
||||
m = M().eval()
|
||||
m = export_for_training(
|
||||
m,
|
||||
example_inputs,
|
||||
).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
weight_meta = None
|
||||
for n in m.graph.nodes:
|
||||
if (
|
||||
@ -1518,7 +1506,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
m = M().eval()
|
||||
quantizer = TestQuantizer()
|
||||
example_inputs = (torch.randn(1, 2, 3, 3),)
|
||||
m = export_for_training(m, example_inputs).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
m = prepare_pt2e(m, quantizer)
|
||||
m(*example_inputs)
|
||||
node_occurrence = {
|
||||
@ -1569,7 +1557,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
torch.randn(1, 2, 3, 3),
|
||||
torch.randn(1, 2, 3, 3),
|
||||
)
|
||||
m = export_for_training(m, example_inputs).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
m = prepare_pt2e(m, quantizer)
|
||||
m(*example_inputs)
|
||||
node_occurrence = {
|
||||
@ -1824,7 +1812,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
|
||||
example_inputs = (torch.randn(1),)
|
||||
m = M().train()
|
||||
m = export_for_training(m, example_inputs).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
if inplace:
|
||||
target = torch.ops.aten.dropout_.default
|
||||
else:
|
||||
@ -1889,7 +1877,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
m = M().train()
|
||||
example_inputs = (torch.randn(1, 3, 3, 3),)
|
||||
bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()
|
||||
m = export_for_training(m, example_inputs).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
|
||||
# Assert that batch norm op exists and is in train mode
|
||||
bn_node = self._get_node(m, bn_train_op)
|
||||
@ -1920,7 +1908,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
m.train()
|
||||
|
||||
# After export: this is not OK
|
||||
m = export_for_training(m, example_inputs).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
with self.assertRaises(NotImplementedError):
|
||||
m.eval()
|
||||
with self.assertRaises(NotImplementedError):
|
||||
@ -1961,7 +1949,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
m = M().train()
|
||||
example_inputs = (torch.randn(1, 3, 3, 3),)
|
||||
bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()
|
||||
m = export_for_training(m, example_inputs).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
|
||||
def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool):
|
||||
targets = [n.target for n in m.graph.nodes]
|
||||
@ -2027,7 +2015,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
|
||||
m = M().train()
|
||||
example_inputs = (torch.randn(1, 3, 3, 3),)
|
||||
m = export_for_training(m, example_inputs).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
torch.ao.quantization.allow_exported_model_train_eval(m)
|
||||
|
||||
# Mock m.recompile() to count how many times it's been called
|
||||
@ -2059,7 +2047,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
def test_model_is_exported(self):
|
||||
m = TestHelperModules.ConvWithBNRelu(relu=True)
|
||||
example_inputs = (torch.rand(3, 3, 5, 5),)
|
||||
exported_gm = export_for_training(m, example_inputs).module()
|
||||
exported_gm = export_for_training(m, example_inputs, strict=True).module()
|
||||
fx_traced_gm = torch.fx.symbolic_trace(m, example_inputs)
|
||||
self.assertTrue(
|
||||
torch.ao.quantization.pt2e.export_utils.model_is_exported(exported_gm)
|
||||
@ -2077,7 +2065,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
quantizer = XNNPACKQuantizer().set_global(
|
||||
get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
|
||||
)
|
||||
m.conv_bn_relu = export_for_training(m.conv_bn_relu, example_inputs).module()
|
||||
m.conv_bn_relu = export_for_training(
|
||||
m.conv_bn_relu, example_inputs, strict=True
|
||||
).module()
|
||||
m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer)
|
||||
m(*example_inputs)
|
||||
m.conv_bn_relu = convert_pt2e(m.conv_bn_relu)
|
||||
@ -2085,7 +2075,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
quantizer = XNNPACKQuantizer().set_module_type(
|
||||
torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False)
|
||||
)
|
||||
m = export_for_training(m, example_inputs).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
m = prepare_pt2e(m, quantizer)
|
||||
m = convert_pt2e(m)
|
||||
|
||||
@ -2257,7 +2247,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
|
||||
def dynamic_quantize_pt2e(model, example_inputs):
|
||||
torch._dynamo.reset()
|
||||
model = export_for_training(model, example_inputs).module()
|
||||
model = export_for_training(model, example_inputs, strict=True).module()
|
||||
# Per channel quantization for weight
|
||||
# Dynamic quantization for activation
|
||||
# Please read a detail: https://fburl.com/code/30zds51q
|
||||
@ -2360,7 +2350,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
|
||||
example_inputs = (torch.randn(1, 3, 5, 5),)
|
||||
m = M()
|
||||
m = export_for_training(m, example_inputs).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
quantizer = XNNPACKQuantizer().set_global(
|
||||
get_symmetric_quantization_config(),
|
||||
)
|
||||
@ -2442,7 +2432,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
edge_or_node_to_obs_or_fq[x] = new_observer
|
||||
|
||||
example_inputs = (torch.rand(1, 32, 16, 16),)
|
||||
gm = export_for_training(Model().eval(), example_inputs).module()
|
||||
gm = export_for_training(Model().eval(), example_inputs, strict=True).module()
|
||||
gm = prepare_pt2e(gm, BackendAQuantizer())
|
||||
gm = convert_pt2e(gm)
|
||||
for n in gm.graph.nodes:
|
||||
@ -2469,7 +2459,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
"ConvWithBNRelu" in node.meta["nn_module_stack"]["L__self__"][1]
|
||||
)
|
||||
|
||||
m.conv_bn_relu = export_for_training(m.conv_bn_relu, example_inputs).module()
|
||||
m.conv_bn_relu = export_for_training(
|
||||
m.conv_bn_relu, example_inputs, strict=True
|
||||
).module()
|
||||
for node in m.conv_bn_relu.graph.nodes:
|
||||
if node.op not in ["placeholder", "output", "get_attr"]:
|
||||
check_nn_module(node)
|
||||
|
@ -140,8 +140,7 @@ class PT2EQATTestCase(QuantizationTestCase):
|
||||
)
|
||||
)
|
||||
model_pt2e = export_for_training(
|
||||
model_pt2e,
|
||||
example_inputs,
|
||||
model_pt2e, example_inputs, strict=True
|
||||
).module()
|
||||
model_pt2e = prepare_qat_pt2e(model_pt2e, quantizer)
|
||||
torch.manual_seed(MANUAL_SEED)
|
||||
@ -229,10 +228,7 @@ class PT2EQATTestCase(QuantizationTestCase):
|
||||
quantizer.set_global(
|
||||
get_symmetric_quantization_config(is_per_channel, is_qat=True)
|
||||
)
|
||||
m = export_for_training(
|
||||
m,
|
||||
example_inputs,
|
||||
).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
m = prepare_qat_pt2e(m, quantizer)
|
||||
m(*example_inputs)
|
||||
|
||||
@ -621,7 +617,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
|
||||
m = M(self.conv_class, self.bn_class, backbone)
|
||||
quantizer = XNNPACKQuantizer()
|
||||
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
|
||||
m = export_for_training(m, example_inputs).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
m = prepare_qat_pt2e(m, quantizer)
|
||||
m(*example_inputs)
|
||||
m = convert_pt2e(m)
|
||||
@ -679,7 +675,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
|
||||
def test_qat_conv_bn_bias_derived_qspec(self):
|
||||
m = self._get_conv_bn_model()
|
||||
example_inputs = self.example_inputs
|
||||
m = export_for_training(m, example_inputs).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
quantizer = ConvBnDerivedBiasQuantizer()
|
||||
m = prepare_qat_pt2e(m, quantizer)
|
||||
m(*example_inputs)
|
||||
@ -726,7 +722,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
|
||||
def test_qat_per_channel_weight_custom_dtype(self):
|
||||
m = self._get_conv_bn_model()
|
||||
example_inputs = self.example_inputs
|
||||
m = export_for_training(m, example_inputs).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
quantizer = ConvBnInt32WeightQuantizer()
|
||||
m = prepare_qat_pt2e(m, quantizer)
|
||||
m(*example_inputs)
|
||||
@ -780,7 +776,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
|
||||
def test_qat_conv_bn_per_channel_weight_bias(self):
|
||||
m = self._get_conv_bn_model()
|
||||
example_inputs = self.example_inputs
|
||||
m = export_for_training(m, example_inputs).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True)
|
||||
m = prepare_qat_pt2e(m, quantizer)
|
||||
m(*example_inputs)
|
||||
@ -837,7 +833,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
|
||||
it into conv in `convert_pt2e` even in train mode.
|
||||
"""
|
||||
m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
|
||||
m = export_for_training(m, self.example_inputs).module()
|
||||
m = export_for_training(m, self.example_inputs, strict=True).module()
|
||||
quantizer = XNNPACKQuantizer()
|
||||
quantizer.set_global(
|
||||
get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
|
||||
@ -1085,7 +1081,9 @@ class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
|
||||
in_channels = child.linear1.weight.size(1)
|
||||
|
||||
example_input = (torch.rand((1, in_channels)),)
|
||||
traced_child = export_for_training(child, example_input).module()
|
||||
traced_child = export_for_training(
|
||||
child, example_input, strict=True
|
||||
).module()
|
||||
quantizer = XNNPACKQuantizer()
|
||||
quantization_config = get_symmetric_quantization_config(
|
||||
is_per_channel=True, is_qat=True
|
||||
@ -1116,10 +1114,7 @@ class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
|
||||
self._convert_qat_linears(model)
|
||||
model(*example_inputs)
|
||||
|
||||
model_pt2e = export_for_training(
|
||||
model,
|
||||
example_inputs,
|
||||
).module()
|
||||
model_pt2e = export_for_training(model, example_inputs, strict=True).module()
|
||||
|
||||
quantizer = XNNPACKQuantizer()
|
||||
quantizer.set_module_type(torch.nn.Linear, None)
|
||||
|
@ -33,10 +33,7 @@ class TestPT2ERepresentation(QuantizationTestCase):
|
||||
) -> torch.nn.Module:
|
||||
# resetting dynamo cache
|
||||
torch._dynamo.reset()
|
||||
model = export_for_training(
|
||||
model,
|
||||
example_inputs,
|
||||
).module()
|
||||
model = export_for_training(model, example_inputs, strict=True).module()
|
||||
model_copy = copy.deepcopy(model)
|
||||
|
||||
model = prepare_pt2e(model, quantizer)
|
||||
|
@ -665,10 +665,7 @@ class X86InductorQuantTestCase(QuantizationTestCase):
|
||||
|
||||
# program capture
|
||||
m = copy.deepcopy(m_eager)
|
||||
m = export_for_training(
|
||||
m,
|
||||
example_inputs,
|
||||
).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
|
||||
# QAT Model failed to deepcopy
|
||||
export_model = m if is_qat else copy.deepcopy(m)
|
||||
@ -2344,7 +2341,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
|
||||
)
|
||||
example_inputs = (torch.randn(2, 2),)
|
||||
m = M().eval()
|
||||
m = export_for_training(m, example_inputs).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
m = prepare_pt2e(m, quantizer)
|
||||
# Use a linear count instead of names because the names might change, but
|
||||
# the order should be the same.
|
||||
|
@ -361,7 +361,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
|
||||
)
|
||||
example_inputs = (torch.randn(2, 2),)
|
||||
m = M().eval()
|
||||
m = export_for_training(m, example_inputs).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
m = prepare_pt2e(m, quantizer)
|
||||
# Use a linear count instead of names because the names might change, but
|
||||
# the order should be the same.
|
||||
@ -497,10 +497,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
|
||||
example_inputs = (torch.randn(1, 3, 5, 5),)
|
||||
|
||||
# program capture
|
||||
m = export_for_training(
|
||||
m,
|
||||
example_inputs,
|
||||
).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
|
||||
m = prepare_pt2e(m, quantizer)
|
||||
m(*example_inputs)
|
||||
@ -766,8 +763,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
|
||||
|
||||
with torchdynamo.config.patch(allow_rnn=True):
|
||||
model_graph = export_for_training(
|
||||
model_graph,
|
||||
example_inputs,
|
||||
model_graph, example_inputs, strict=True
|
||||
).module()
|
||||
quantizer = XNNPACKQuantizer()
|
||||
quantization_config = get_symmetric_quantization_config(
|
||||
@ -829,8 +825,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
|
||||
|
||||
with torchdynamo.config.patch(allow_rnn=True):
|
||||
model_graph = export_for_training(
|
||||
model_graph,
|
||||
example_inputs,
|
||||
model_graph, example_inputs, strict=True
|
||||
).module()
|
||||
quantizer = XNNPACKQuantizer()
|
||||
quantization_config = get_symmetric_quantization_config(
|
||||
@ -1039,10 +1034,7 @@ class TestXNNPACKQuantizerModels(PT2EQuantizationTestCase):
|
||||
m = torchvision.models.resnet18().eval()
|
||||
m_copy = copy.deepcopy(m)
|
||||
# program capture
|
||||
m = export_for_training(
|
||||
m,
|
||||
example_inputs,
|
||||
).module()
|
||||
m = export_for_training(m, example_inputs, strict=True).module()
|
||||
|
||||
quantizer = XNNPACKQuantizer()
|
||||
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
|
||||
|
@ -27,7 +27,9 @@ class TestQuantizePT2EModels(TestCase):
|
||||
m = m.eval()
|
||||
input_shape = (1, 3, 224, 224)
|
||||
example_inputs = (torch.randn(input_shape),)
|
||||
m = torch.export.export_for_training(m, copy.deepcopy(example_inputs)).module()
|
||||
m = torch.export.export_for_training(
|
||||
m, copy.deepcopy(example_inputs), strict=True
|
||||
).module()
|
||||
m(*example_inputs)
|
||||
m = export.export(m, copy.deepcopy(example_inputs))
|
||||
ops = _get_ops_list(m.graph_module)
|
||||
|
@ -355,6 +355,7 @@ def _get_aten_graph_module_for_pattern(
|
||||
pattern, # type: ignore[arg-type]
|
||||
example_inputs,
|
||||
kwargs,
|
||||
strict=True,
|
||||
).module()
|
||||
|
||||
aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr]
|
||||
|
@ -1003,9 +1003,7 @@ class Pipe(torch.nn.Module):
|
||||
logger.info("Tracing model ...")
|
||||
try:
|
||||
ep = torch.export.export_for_training(
|
||||
mod,
|
||||
example_args,
|
||||
example_kwargs,
|
||||
mod, example_args, example_kwargs, strict=True
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
|
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user