[Codemod][AddExplicitStrictExportForTrainingInferenceArg] caffe2/ (#149595)

internal diff: D71497480

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149595
Approved by: https://github.com/Skylion007
Author: Yanan Cao (PyTorch)
Date: 2025-04-03 23:50:13 +00:00
Committed by: PyTorch MergeBot
Parent: 8878289f89
Commit: 1ab6c4ff04
14 changed files with 792 additions and 479 deletions
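For reviewers scanning the diff below: every `export_for_training` call site gains an explicit `strict=True`, which preserves current behavior while making the strictness mode visible at each call. A minimal sketch of the before/after pattern (the toy module and inputs are illustrative, not taken from the diff):

    import torch
    from torch.export import export_for_training

    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    example_inputs = (torch.randn(1, 3),)

    # Before the codemod: strictness is left to the default.
    #   ep = export_for_training(Toy(), example_inputs)
    # After the codemod: strictness is pinned at the call site.
    ep = export_for_training(Toy(), example_inputs, strict=True)
    gm = ep.module()  # unlifted GraphModule, as used throughout the tests below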

@@ -173,7 +173,7 @@ class TestMatcher(JitTestCase):
             torch.randn(3, 3, 3, 3),
         )
         pattern_gm = export_for_training(
-            WrapperModule(pattern), example_inputs
+            WrapperModule(pattern), example_inputs, strict=True
         ).module()
         before_split_res = pattern_gm(*example_inputs)
         pattern_gm, _ = _split_to_graph_and_name_node_map(pattern_gm)
@@ -204,11 +204,11 @@ class TestMatcher(JitTestCase):
             torch.randn(3, 3, 3, 3),
         )
         pattern_gm = export_for_training(
-            WrapperModule(pattern), example_inputs
+            WrapperModule(pattern), example_inputs, strict=True
         ).module()
         matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
         target_gm = export_for_training(
-            WrapperModule(target_graph), example_inputs
+            WrapperModule(target_graph), example_inputs, strict=True
         ).module()
         internal_matches = matcher.match(target_gm.graph)
         for internal_match in internal_matches:
@@ -248,9 +248,11 @@ class TestMatcher(JitTestCase):
                 return linear, {"linear": linear, "x": x}
         example_inputs = (torch.randn(3, 5),)
-        pattern_gm = export_for_training(Pattern(), example_inputs).module()
+        pattern_gm = export_for_training(
+            Pattern(), example_inputs, strict=True
+        ).module()
         matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
-        target_gm = export_for_training(M(), example_inputs).module()
+        target_gm = export_for_training(M(), example_inputs, strict=True).module()
         internal_matches = matcher.match(target_gm.graph)
         for internal_match in internal_matches:
             name_node_map = internal_match.name_node_map

@@ -1837,7 +1837,9 @@ class AOTInductorTestsTemplate:
         with config.patch(
             {"freezing": True, "aot_inductor.force_mmap_weights": True}
         ), torch.no_grad():
-            exported_model = export_for_training(model, example_inputs).module()
+            exported_model = export_for_training(
+                model, example_inputs, strict=True
+            ).module()
             quantizer = X86InductorQuantizer()
             quantizer.set_global(
                 xiq.get_default_x86_inductor_quantization_config(reduce_range=True)

@@ -101,10 +101,7 @@ class TestDuplicateDQPass(QuantizationTestCase):
         # program capture
         m = copy.deepcopy(m_eager)
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, quantizer)
         # Calibrate

@@ -98,10 +98,7 @@ class TestMetaDataPorting(QuantizationTestCase):
         # program capture
         m = copy.deepcopy(m_eager)
-        m = torch.export.export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = torch.export.export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, quantizer)
         # Calibrate

@@ -81,7 +81,7 @@ class TestNumericDebugger(TestCase):
     def test_simple(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         self._assert_each_node_has_debug_handle(ep)
         debug_handle_map = self._extract_debug_handles(ep)
@@ -91,7 +91,7 @@ class TestNumericDebugger(TestCase):
     def test_control_flow(self):
         m = TestHelperModules.ControlFlow()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         self._assert_each_node_has_debug_handle(ep)
@@ -102,7 +102,7 @@ class TestNumericDebugger(TestCase):
     def test_quantize_pt2e_preserve_handle(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         m = ep.module()
@@ -162,14 +162,14 @@ class TestNumericDebugger(TestCase):
     def test_re_export_preserve_handle(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         m = ep.module()
         self._assert_each_node_has_debug_handle(ep)
         debug_handle_map_ref = self._extract_debug_handles(ep)
-        ep_reexport = export_for_training(m, example_inputs)
+        ep_reexport = export_for_training(m, example_inputs, strict=True)
         self._assert_each_node_has_debug_handle(ep_reexport)
         debug_handle_map = self._extract_debug_handles(ep_reexport)
@@ -179,7 +179,7 @@ class TestNumericDebugger(TestCase):
     def test_run_decompositions_same_handle_id(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         self._assert_each_node_has_debug_handle(ep)
@@ -204,7 +204,7 @@ class TestNumericDebugger(TestCase):
         for m in test_models:
             example_inputs = m.example_inputs()
-            ep = export_for_training(m, example_inputs)
+            ep = export_for_training(m, example_inputs, strict=True)
             generate_numeric_debug_handle(ep)
             self._assert_each_node_has_debug_handle(ep)
@@ -227,7 +227,7 @@ class TestNumericDebugger(TestCase):
     def test_prepare_for_propagation_comparison(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         m = ep.module()
         m_logger = prepare_for_propagation_comparison(m)
@@ -244,7 +244,7 @@ class TestNumericDebugger(TestCase):
     def test_extract_results_from_loggers(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         m = ep.module()
         m_ref_logger = prepare_for_propagation_comparison(m)
@@ -269,7 +269,7 @@ class TestNumericDebugger(TestCase):
     def test_extract_results_from_loggers_list_output(self):
         m = TestHelperModules.Conv2dWithSplit()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         m = ep.module()
         m_ref_logger = prepare_for_propagation_comparison(m)
@@ -299,7 +299,7 @@ class TestNumericDebugger(TestCase):
     def test_added_node_gets_unique_id(self) -> None:
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         ref_handles = self._extract_debug_handles(ep)
         ref_counter = Counter(ref_handles.values())

@@ -767,10 +767,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5))
         # program capture
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, BackendAQuantizer())
         # make sure the two observers for input are shared
         conv_output_obs = []
@@ -830,10 +827,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         )
         # program capture
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, quantizer)
         m(*example_inputs)
         # make sure the two input observers and output are shared
@@ -1152,10 +1146,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         )
         # program capture
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         quantizer = BackendAQuantizer()
         m = prepare_pt2e(m, quantizer)
         m(*example_inputs)
@@ -1305,7 +1296,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         m = M().eval()
         example_inputs = torch.randn(1, 2, 3, 3)
-        m = export_for_training(m, (example_inputs,)).module()
+        m = export_for_training(m, (example_inputs,), strict=True).module()
         with self.assertRaises(Exception):
             m = prepare_pt2e(m, BackendAQuantizer())
@@ -1428,10 +1419,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         quantizer.set_global(operator_config)
         example_inputs = (torch.randn(2, 2),)
         m = M().eval()
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         weight_meta = None
         for n in m.graph.nodes:
             if (
@@ -1518,7 +1506,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         m = M().eval()
         quantizer = TestQuantizer()
         example_inputs = (torch.randn(1, 2, 3, 3),)
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, quantizer)
         m(*example_inputs)
         node_occurrence = {
@@ -1569,7 +1557,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             torch.randn(1, 2, 3, 3),
            torch.randn(1, 2, 3, 3),
         )
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, quantizer)
         m(*example_inputs)
         node_occurrence = {
@@ -1824,7 +1812,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         example_inputs = (torch.randn(1),)
         m = M().train()
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         if inplace:
             target = torch.ops.aten.dropout_.default
         else:
@@ -1889,7 +1877,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         m = M().train()
         example_inputs = (torch.randn(1, 3, 3, 3),)
         bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         # Assert that batch norm op exists and is in train mode
         bn_node = self._get_node(m, bn_train_op)
@@ -1920,7 +1908,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         m.train()
         # After export: this is not OK
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         with self.assertRaises(NotImplementedError):
             m.eval()
         with self.assertRaises(NotImplementedError):
@@ -1961,7 +1949,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         m = M().train()
         example_inputs = (torch.randn(1, 3, 3, 3),)
         bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool):
             targets = [n.target for n in m.graph.nodes]
@@ -2027,7 +2015,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         m = M().train()
         example_inputs = (torch.randn(1, 3, 3, 3),)
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         torch.ao.quantization.allow_exported_model_train_eval(m)
         # Mock m.recompile() to count how many times it's been called
@@ -2059,7 +2047,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
     def test_model_is_exported(self):
         m = TestHelperModules.ConvWithBNRelu(relu=True)
         example_inputs = (torch.rand(3, 3, 5, 5),)
-        exported_gm = export_for_training(m, example_inputs).module()
+        exported_gm = export_for_training(m, example_inputs, strict=True).module()
         fx_traced_gm = torch.fx.symbolic_trace(m, example_inputs)
         self.assertTrue(
             torch.ao.quantization.pt2e.export_utils.model_is_exported(exported_gm)
@@ -2077,7 +2065,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         quantizer = XNNPACKQuantizer().set_global(
             get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
         )
-        m.conv_bn_relu = export_for_training(m.conv_bn_relu, example_inputs).module()
+        m.conv_bn_relu = export_for_training(
+            m.conv_bn_relu, example_inputs, strict=True
+        ).module()
         m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer)
         m(*example_inputs)
         m.conv_bn_relu = convert_pt2e(m.conv_bn_relu)
@@ -2085,7 +2075,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         quantizer = XNNPACKQuantizer().set_module_type(
             torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False)
         )
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, quantizer)
         m = convert_pt2e(m)
@@ -2257,7 +2247,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         def dynamic_quantize_pt2e(model, example_inputs):
             torch._dynamo.reset()
-            model = export_for_training(model, example_inputs).module()
+            model = export_for_training(model, example_inputs, strict=True).module()
             # Per channel quantization for weight
             # Dynamic quantization for activation
             # Please read a detail: https://fburl.com/code/30zds51q
@@ -2360,7 +2350,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         example_inputs = (torch.randn(1, 3, 5, 5),)
         m = M()
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         quantizer = XNNPACKQuantizer().set_global(
             get_symmetric_quantization_config(),
         )
@@ -2442,7 +2432,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                 edge_or_node_to_obs_or_fq[x] = new_observer
         example_inputs = (torch.rand(1, 32, 16, 16),)
-        gm = export_for_training(Model().eval(), example_inputs).module()
+        gm = export_for_training(Model().eval(), example_inputs, strict=True).module()
         gm = prepare_pt2e(gm, BackendAQuantizer())
         gm = convert_pt2e(gm)
         for n in gm.graph.nodes:
@@ -2469,7 +2459,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                 "ConvWithBNRelu" in node.meta["nn_module_stack"]["L__self__"][1]
             )
-        m.conv_bn_relu = export_for_training(m.conv_bn_relu, example_inputs).module()
+        m.conv_bn_relu = export_for_training(
+            m.conv_bn_relu, example_inputs, strict=True
+        ).module()
         for node in m.conv_bn_relu.graph.nodes:
             if node.op not in ["placeholder", "output", "get_attr"]:
                 check_nn_module(node)

@@ -140,8 +140,7 @@ class PT2EQATTestCase(QuantizationTestCase):
             )
         )
         model_pt2e = export_for_training(
-            model_pt2e,
-            example_inputs,
+            model_pt2e, example_inputs, strict=True
         ).module()
         model_pt2e = prepare_qat_pt2e(model_pt2e, quantizer)
         torch.manual_seed(MANUAL_SEED)
@@ -229,10 +228,7 @@ class PT2EQATTestCase(QuantizationTestCase):
         quantizer.set_global(
             get_symmetric_quantization_config(is_per_channel, is_qat=True)
         )
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_qat_pt2e(m, quantizer)
         m(*example_inputs)
@@ -621,7 +617,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
         m = M(self.conv_class, self.bn_class, backbone)
         quantizer = XNNPACKQuantizer()
         quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_qat_pt2e(m, quantizer)
         m(*example_inputs)
         m = convert_pt2e(m)
@@ -679,7 +675,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
     def test_qat_conv_bn_bias_derived_qspec(self):
         m = self._get_conv_bn_model()
         example_inputs = self.example_inputs
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         quantizer = ConvBnDerivedBiasQuantizer()
         m = prepare_qat_pt2e(m, quantizer)
         m(*example_inputs)
@@ -726,7 +722,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
     def test_qat_per_channel_weight_custom_dtype(self):
         m = self._get_conv_bn_model()
         example_inputs = self.example_inputs
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         quantizer = ConvBnInt32WeightQuantizer()
         m = prepare_qat_pt2e(m, quantizer)
         m(*example_inputs)
@@ -780,7 +776,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
     def test_qat_conv_bn_per_channel_weight_bias(self):
         m = self._get_conv_bn_model()
         example_inputs = self.example_inputs
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True)
         m = prepare_qat_pt2e(m, quantizer)
         m(*example_inputs)
@@ -837,7 +833,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
         it into conv in `convert_pt2e` even in train mode.
         """
         m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
-        m = export_for_training(m, self.example_inputs).module()
+        m = export_for_training(m, self.example_inputs, strict=True).module()
         quantizer = XNNPACKQuantizer()
         quantizer.set_global(
             get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
@@ -1085,7 +1081,9 @@ class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
                 in_channels = child.linear1.weight.size(1)
                 example_input = (torch.rand((1, in_channels)),)
-                traced_child = export_for_training(child, example_input).module()
+                traced_child = export_for_training(
+                    child, example_input, strict=True
+                ).module()
                 quantizer = XNNPACKQuantizer()
                 quantization_config = get_symmetric_quantization_config(
                     is_per_channel=True, is_qat=True
@@ -1116,10 +1114,7 @@ class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
         self._convert_qat_linears(model)
         model(*example_inputs)
-        model_pt2e = export_for_training(
-            model,
-            example_inputs,
-        ).module()
+        model_pt2e = export_for_training(model, example_inputs, strict=True).module()
         quantizer = XNNPACKQuantizer()
         quantizer.set_module_type(torch.nn.Linear, None)

@@ -33,10 +33,7 @@ class TestPT2ERepresentation(QuantizationTestCase):
     ) -> torch.nn.Module:
         # resetting dynamo cache
         torch._dynamo.reset()
-        model = export_for_training(
-            model,
-            example_inputs,
-        ).module()
+        model = export_for_training(model, example_inputs, strict=True).module()
         model_copy = copy.deepcopy(model)
         model = prepare_pt2e(model, quantizer)

@@ -665,10 +665,7 @@ class X86InductorQuantTestCase(QuantizationTestCase):
         # program capture
         m = copy.deepcopy(m_eager)
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         # QAT Model failed to deepcopy
         export_model = m if is_qat else copy.deepcopy(m)
@@ -2344,7 +2341,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
         )
         example_inputs = (torch.randn(2, 2),)
         m = M().eval()
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, quantizer)
         # Use a linear count instead of names because the names might change, but
         # the order should be the same.

@@ -361,7 +361,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
         )
         example_inputs = (torch.randn(2, 2),)
         m = M().eval()
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, quantizer)
         # Use a linear count instead of names because the names might change, but
         # the order should be the same.
@@ -497,10 +497,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
         example_inputs = (torch.randn(1, 3, 5, 5),)
         # program capture
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, quantizer)
         m(*example_inputs)
@@ -766,8 +763,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
         with torchdynamo.config.patch(allow_rnn=True):
             model_graph = export_for_training(
-                model_graph,
-                example_inputs,
+                model_graph, example_inputs, strict=True
             ).module()
             quantizer = XNNPACKQuantizer()
             quantization_config = get_symmetric_quantization_config(
@@ -829,8 +825,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
         with torchdynamo.config.patch(allow_rnn=True):
             model_graph = export_for_training(
-                model_graph,
-                example_inputs,
+                model_graph, example_inputs, strict=True
             ).module()
            quantizer = XNNPACKQuantizer()
             quantization_config = get_symmetric_quantization_config(
@@ -1039,10 +1034,7 @@ class TestXNNPACKQuantizerModels(PT2EQuantizationTestCase):
         m = torchvision.models.resnet18().eval()
         m_copy = copy.deepcopy(m)
         # program capture
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config(is_per_channel=True)

@@ -27,7 +27,9 @@ class TestQuantizePT2EModels(TestCase):
         m = m.eval()
         input_shape = (1, 3, 224, 224)
         example_inputs = (torch.randn(input_shape),)
-        m = torch.export.export_for_training(m, copy.deepcopy(example_inputs)).module()
+        m = torch.export.export_for_training(
+            m, copy.deepcopy(example_inputs), strict=True
+        ).module()
         m(*example_inputs)
         m = export.export(m, copy.deepcopy(example_inputs))
         ops = _get_ops_list(m.graph_module)

@@ -355,6 +355,7 @@ def _get_aten_graph_module_for_pattern(
         pattern,  # type: ignore[arg-type]
         example_inputs,
         kwargs,
+        strict=True,
     ).module()
     aten_pattern.graph.eliminate_dead_code()  # type: ignore[operator, union-attr]

@@ -1003,9 +1003,7 @@ class Pipe(torch.nn.Module):
         logger.info("Tracing model ...")
         try:
             ep = torch.export.export_for_training(
-                mod,
-                example_args,
-                example_kwargs,
+                mod, example_args, example_kwargs, strict=True
             )
         except Exception as e:
             raise RuntimeError(

One file's diff suppressed because it is too large.