[export] Remove Proxy from exported programs and modules (#132956)

Summary: Remove Proxy from exported programs and modules because they cannot be deepcopied or pickeled.

Test Plan:
CI

```
buck2 run 'fbcode//mode/dev-nosan'  fbcode//caffe2/test/quantization:test_quantization -- -r  qat_conv2d
buck2 run 'fbcode//mode/dev-nosan' fbcode//modai/test:test_modai -- -r test_qat_stinson_htp_export
buck2 run 'fbcode//mode/dev-nosan' fbcode//vizard_projects/ml_depth/tests:test_model -- -r test_qat_model_et
buck2 run 'fbcode//mode/dev-nosan' fbcode//bolt/nn/executorch/backends/tests:qnn_test -- -r test_qat_bias=False,use_3d_input=False
buck2 run 'fbcode//mode/dev-nosan' fbcode//bolt/nn/executorch/backends/tests:qnn_test -- -r test_qat_bias=True,use_3d_input=False
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r  test_fold_bn_erases_bn_node
```

Differential Revision: D60940832

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132956
Approved by: https://github.com/angelayi
This commit is contained in:
Shangdi Yu
2024-08-09 00:00:20 +00:00
committed by PyTorch MergeBot
parent e2b94923ba
commit 3c5b246d3c
5 changed files with 16 additions and 10 deletions

View File

@ -560,9 +560,6 @@ class X86InductorQuantTestCase(QuantizationTestCase):
m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
torch._export.utils.remove_proxy_from_state_dict(
m.__dict__["_buffers"], in_place=True
)
prepare_model = copy.deepcopy(m)
m = convert_pt2e(m)
convert_model = copy.deepcopy(m)

View File

@ -208,6 +208,12 @@ def capture_pre_autograd_graph(
module.train = types.MethodType(_train, module) # type: ignore[method-assign]
module.eval = types.MethodType(_eval, module) # type: ignore[method-assign]
# Remove Proxy because they cannot be deepcopied or pickled.
if hasattr(module, "_buffers"):
torch._export.utils.remove_proxy_from_state_dict(
module._buffers, in_place=True
)
return module

View File

@ -616,19 +616,19 @@ def placeholder_naming_pass(
def remove_proxy_from_state_dict(state_dict: Dict, in_place: bool) -> Dict:
"""
If `in_place` is false, remove a new copy of `state_dict` with "proxy" removed from `v.__dict__`.
If `in_place` is false, return a new copy of `state_dict` with "proxy" removed from `v.__dict__`.
`v` is the values in the dictionary.
If `in_place` is true, modify `state_dict` in place.
"""
if in_place:
for k, v in state_dict.items():
if "proxy" in v.__dict__:
state_dict[k] = v.clone().detach()
if hasattr(v, "proxy"):
delattr(state_dict[k], "proxy")
return state_dict
else:
new_state_dict = {}
for k, v in state_dict.items():
if "proxy" in v.__dict__:
if hasattr(v, "proxy"):
new_state_dict[k] = v.clone().detach()
else:
new_state_dict[k] = v

View File

@ -2071,6 +2071,9 @@ def _export(
if not _is_torch_jit_trace:
_verify_placeholder_names(gm, export_graph_signature)
# Remove Proxy because they cannot be deepcopied or pickled.
torch._export.utils.remove_proxy_from_state_dict(original_state_dict, in_place=True)
from torch._export.verifier import Verifier
exported_program = ExportedProgram(

View File

@ -408,20 +408,20 @@ class Proxy:
def __getstate__(self) -> Dict:
raise NotImplementedError(
"""__getstate__ not implemented for Proxy. """
"""Proxy is created for {self.node.name}, {self.node.target}. Please remove "proxy" from __dict__."""
f"""Proxy is created for {self.node.name}, {self.node.target}. Please remove "proxy" from __dict__."""
)
def __deepcopy__(self, memo) -> Dict:
raise NotImplementedError(
"""__deepcopy__ not implemented for Proxy. """
"""Proxy is created for {self.node.name}, {self.node.target}. Please remove "proxy" from __dict__."""
f"""Proxy is created for {self.node.name}, {self.node.target}. Please remove "proxy" from __dict__."""
)
def __setstate__(self, d):
# This is called when being unpickled/loaded.
raise NotImplementedError(
"""__setstate__ not implemented for Proxy. """
"""Proxy is created for {self.node.name}, {self.node.target}. Please remove "proxy" from __dict__."""
f"""Proxy is created for {self.node.name}, {self.node.target}. Please remove "proxy" from __dict__."""
)
def __call__(self, *args, **kwargs) -> 'Proxy':