mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[package] implement get_resource_reader
API (#51674)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51674 See https://docs.python.org/3/library/importlib.html#importlib.abc.ResourceReader Test Plan: Imported from OSS Reviewed By: zdevito Differential Revision: D26237034 Pulled By: suo fbshipit-source-id: 4c19f6172d16b710737528d3de48372873b9368d
This commit is contained in:
committed by
Facebook GitHub Bot
parent
bfc80b3566
commit
b4d8f4af82
@ -1,6 +1,7 @@
|
||||
from torch.package.importer import ObjMismatchError
|
||||
from unittest import skipIf
|
||||
import inspect
|
||||
from textwrap import dedent
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
|
||||
from tempfile import NamedTemporaryFile
|
||||
from torch.package import (
|
||||
@ -722,6 +723,81 @@ def load():
|
||||
self.assertTrue(packaged_dependency is not package_a.subpackage)
|
||||
|
||||
|
||||
class TestPackageResources(TestCase):
|
||||
def test_resource_reader(self):
|
||||
"""Test compliance with the get_resource_reader importlib API."""
|
||||
buffer = BytesIO()
|
||||
with PackageExporter(buffer, verbose=False) as pe:
|
||||
# Layout looks like:
|
||||
# package
|
||||
# ├── one/
|
||||
# │ ├── a.txt
|
||||
# │ ├── b.txt
|
||||
# │ ├── c.txt
|
||||
# │ └── three/
|
||||
# │ ├── d.txt
|
||||
# │ └── e.txt
|
||||
# └── two/
|
||||
# ├── f.txt
|
||||
# └── g.txt
|
||||
pe.save_text('one', 'a.txt', 'hello, a!')
|
||||
pe.save_text('one', 'b.txt', 'hello, b!')
|
||||
pe.save_text('one', 'c.txt', 'hello, c!')
|
||||
|
||||
pe.save_text('one.three', 'd.txt', 'hello, d!')
|
||||
pe.save_text('one.three', 'e.txt', 'hello, e!')
|
||||
|
||||
pe.save_text('two', 'f.txt', 'hello, f!')
|
||||
pe.save_text('two', 'g.txt', 'hello, g!')
|
||||
|
||||
buffer.seek(0)
|
||||
importer = PackageImporter(buffer)
|
||||
|
||||
reader_one = importer.get_resource_reader('one')
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
reader_one.resource_path('a.txt')
|
||||
|
||||
self.assertTrue(reader_one.is_resource('a.txt'))
|
||||
self.assertEqual(reader_one.open_resource('a.txt').getbuffer(), b'hello, a!')
|
||||
self.assertFalse(reader_one.is_resource('three'))
|
||||
reader_one_contents = list(reader_one.contents())
|
||||
self.assertSequenceEqual(reader_one_contents, ['a.txt', 'b.txt', 'c.txt', 'three'])
|
||||
|
||||
reader_two = importer.get_resource_reader('two')
|
||||
self.assertTrue(reader_two.is_resource('f.txt'))
|
||||
self.assertEqual(reader_two.open_resource('f.txt').getbuffer(), b'hello, f!')
|
||||
reader_two_contents = list(reader_two.contents())
|
||||
self.assertSequenceEqual(reader_two_contents, ['f.txt', 'g.txt'])
|
||||
|
||||
reader_one_three = importer.get_resource_reader('one.three')
|
||||
self.assertTrue(reader_one_three.is_resource('d.txt'))
|
||||
self.assertEqual(reader_one_three.open_resource('d.txt').getbuffer(), b'hello, d!')
|
||||
reader_one_three_contenst = list(reader_one_three.contents())
|
||||
self.assertSequenceEqual(reader_one_three_contenst, ['d.txt', 'e.txt'])
|
||||
|
||||
self.assertIsNone(importer.get_resource_reader('nonexistent_package'))
|
||||
|
||||
def test_package_resource_access(self):
|
||||
"""Packaged modules should be able to use the importlib.resources API to access
|
||||
resources saved in the package.
|
||||
"""
|
||||
mod_src = """\
|
||||
import importlib.resources
|
||||
import my_cool_resources
|
||||
|
||||
def secret_message():
|
||||
return importlib.resources.read_text(my_cool_resources, 'sekrit.txt')
|
||||
"""
|
||||
buffer = BytesIO()
|
||||
with PackageExporter(buffer, verbose=False) as pe:
|
||||
pe.save_source_string("foo.bar", dedent(mod_src))
|
||||
pe.save_text('my_cool_resources', 'sekrit.txt', 'my sekrit plays')
|
||||
|
||||
buffer.seek(0)
|
||||
importer = PackageImporter(buffer)
|
||||
self.assertEqual(importer.import_module('foo.bar').secret_message(), 'my sekrit plays')
|
||||
|
||||
|
||||
class ManglingTest(TestCase):
|
||||
def test_unique_manglers(self):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user