mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
[export] Remove Proxy from exported programs and modules (#132956)
Summary: Remove Proxy from exported programs and modules because they cannot be deepcopied or pickeled. Test Plan: CI ``` buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r qat_conv2d buck2 run 'fbcode//mode/dev-nosan' fbcode//modai/test:test_modai -- -r test_qat_stinson_htp_export buck2 run 'fbcode//mode/dev-nosan' fbcode//vizard_projects/ml_depth/tests:test_model -- -r test_qat_model_et buck2 run 'fbcode//mode/dev-nosan' fbcode//bolt/nn/executorch/backends/tests:qnn_test -- -r test_qat_bias=False,use_3d_input=False buck2 run 'fbcode//mode/dev-nosan' fbcode//bolt/nn/executorch/backends/tests:qnn_test -- -r test_qat_bias=True,use_3d_input=False buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_fold_bn_erases_bn_node ``` Differential Revision: D60940832 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132956 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
e2b94923ba
commit
3c5b246d3c
@ -560,9 +560,6 @@ class X86InductorQuantTestCase(QuantizationTestCase):
|
||||
m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer)
|
||||
# Calibrate
|
||||
m(*example_inputs)
|
||||
torch._export.utils.remove_proxy_from_state_dict(
|
||||
m.__dict__["_buffers"], in_place=True
|
||||
)
|
||||
prepare_model = copy.deepcopy(m)
|
||||
m = convert_pt2e(m)
|
||||
convert_model = copy.deepcopy(m)
|
||||
|
||||
@ -208,6 +208,12 @@ def capture_pre_autograd_graph(
|
||||
|
||||
module.train = types.MethodType(_train, module) # type: ignore[method-assign]
|
||||
module.eval = types.MethodType(_eval, module) # type: ignore[method-assign]
|
||||
|
||||
# Remove Proxy because they cannot be deepcopied or pickled.
|
||||
if hasattr(module, "_buffers"):
|
||||
torch._export.utils.remove_proxy_from_state_dict(
|
||||
module._buffers, in_place=True
|
||||
)
|
||||
return module
|
||||
|
||||
|
||||
|
||||
@ -616,19 +616,19 @@ def placeholder_naming_pass(
|
||||
|
||||
def remove_proxy_from_state_dict(state_dict: Dict, in_place: bool) -> Dict:
|
||||
"""
|
||||
If `in_place` is false, remove a new copy of `state_dict` with "proxy" removed from `v.__dict__`.
|
||||
If `in_place` is false, return a new copy of `state_dict` with "proxy" removed from `v.__dict__`.
|
||||
`v` is the values in the dictionary.
|
||||
If `in_place` is true, modify `state_dict` in place.
|
||||
"""
|
||||
if in_place:
|
||||
for k, v in state_dict.items():
|
||||
if "proxy" in v.__dict__:
|
||||
state_dict[k] = v.clone().detach()
|
||||
if hasattr(v, "proxy"):
|
||||
delattr(state_dict[k], "proxy")
|
||||
return state_dict
|
||||
else:
|
||||
new_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if "proxy" in v.__dict__:
|
||||
if hasattr(v, "proxy"):
|
||||
new_state_dict[k] = v.clone().detach()
|
||||
else:
|
||||
new_state_dict[k] = v
|
||||
|
||||
@ -2071,6 +2071,9 @@ def _export(
|
||||
if not _is_torch_jit_trace:
|
||||
_verify_placeholder_names(gm, export_graph_signature)
|
||||
|
||||
# Remove Proxy because they cannot be deepcopied or pickled.
|
||||
torch._export.utils.remove_proxy_from_state_dict(original_state_dict, in_place=True)
|
||||
|
||||
from torch._export.verifier import Verifier
|
||||
|
||||
exported_program = ExportedProgram(
|
||||
|
||||
@ -408,20 +408,20 @@ class Proxy:
|
||||
def __getstate__(self) -> Dict:
|
||||
raise NotImplementedError(
|
||||
"""__getstate__ not implemented for Proxy. """
|
||||
"""Proxy is created for {self.node.name}, {self.node.target}. Please remove "proxy" from __dict__."""
|
||||
f"""Proxy is created for {self.node.name}, {self.node.target}. Please remove "proxy" from __dict__."""
|
||||
)
|
||||
|
||||
def __deepcopy__(self, memo) -> Dict:
|
||||
raise NotImplementedError(
|
||||
"""__deepcopy__ not implemented for Proxy. """
|
||||
"""Proxy is created for {self.node.name}, {self.node.target}. Please remove "proxy" from __dict__."""
|
||||
f"""Proxy is created for {self.node.name}, {self.node.target}. Please remove "proxy" from __dict__."""
|
||||
)
|
||||
|
||||
def __setstate__(self, d):
|
||||
# This is called when being unpickled/loaded.
|
||||
raise NotImplementedError(
|
||||
"""__setstate__ not implemented for Proxy. """
|
||||
"""Proxy is created for {self.node.name}, {self.node.target}. Please remove "proxy" from __dict__."""
|
||||
f"""Proxy is created for {self.node.name}, {self.node.target}. Please remove "proxy" from __dict__."""
|
||||
)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> 'Proxy':
|
||||
|
||||
Reference in New Issue
Block a user