mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
detect mocked module on saving pass (#70641)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70641 Raises a not implemented error if we attempt to pickle an object which uses a mocked module. Now we no longer have to load the object to get this check, and instead happens right on the saving path. Review History is on https://github.com/pytorch/pytorch/pull/69793 PR was moved to a different branch due to original branch getting corrupted. Test Plan: Imported from OSS Reviewed By: suo Differential Revision: D33414365 Pulled By: PaliC fbshipit-source-id: 6d72ddb05c47a3d060e9622ec0b6e5cd6c6c71c8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c4400fc431
commit
118bd82dde
@ -182,16 +182,24 @@ class TestDependencyAPI(PackageTestCase):
|
|||||||
obj2 = package_a.PackageAObject(obj)
|
obj2 = package_a.PackageAObject(obj)
|
||||||
|
|
||||||
buffer = BytesIO()
|
buffer = BytesIO()
|
||||||
with PackageExporter(buffer) as he:
|
|
||||||
he.mock(include="package_a.subpackage")
|
|
||||||
he.intern("**")
|
|
||||||
he.save_pickle("obj", "obj.pkl", obj2)
|
|
||||||
|
|
||||||
buffer.seek(0)
|
|
||||||
|
|
||||||
hi = PackageImporter(buffer)
|
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
hi.load_pickle("obj", "obj.pkl")
|
with PackageExporter(buffer) as he:
|
||||||
|
he.mock(include="package_a.subpackage")
|
||||||
|
he.intern("**")
|
||||||
|
he.save_pickle("obj", "obj.pkl", obj2)
|
||||||
|
|
||||||
|
@skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
|
||||||
|
def test_pickle_mocked_all(self):
|
||||||
|
import package_a.subpackage
|
||||||
|
|
||||||
|
obj = package_a.subpackage.PackageASubpackageObject()
|
||||||
|
obj2 = package_a.PackageAObject(obj)
|
||||||
|
|
||||||
|
buffer = BytesIO()
|
||||||
|
with PackageExporter(buffer) as he:
|
||||||
|
he.intern(include="package_a.**")
|
||||||
|
he.mock("**")
|
||||||
|
he.save_pickle("obj", "obj.pkl", obj2)
|
||||||
|
|
||||||
def test_allow_empty_with_error(self):
|
def test_allow_empty_with_error(self):
|
||||||
"""If an error occurs during packaging, it should not be shadowed by the allow_empty error."""
|
"""If an error occurs during packaging, it should not be shadowed by the allow_empty error."""
|
||||||
|
@ -592,6 +592,17 @@ class PackageExporter:
|
|||||||
module, field = arg.split(" ")
|
module, field = arg.split(" ")
|
||||||
if module not in all_dependencies:
|
if module not in all_dependencies:
|
||||||
all_dependencies.append(module)
|
all_dependencies.append(module)
|
||||||
|
for pattern, pattern_info in self.patterns.items():
|
||||||
|
if pattern.matches(module):
|
||||||
|
if pattern_info.action == _ModuleProviderAction.MOCK:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Object '{field}' from module {module} was mocked out during packaging "
|
||||||
|
f"but is being used in resource - {resource} in package {package}. "
|
||||||
|
"If this error is happening during 'save_pickle', please ensure that your "
|
||||||
|
"pickled object doesn't contain any mocked objects."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
for module_name in all_dependencies:
|
for module_name in all_dependencies:
|
||||||
self.dependency_graph.add_edge(name_in_dependency_graph, module_name)
|
self.dependency_graph.add_edge(name_in_dependency_graph, module_name)
|
||||||
|
Reference in New Issue
Block a user