mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	Pull Request resolved: https://github.com/pytorch/pytorch/pull/157636 Approved by: https://github.com/yewentao256, https://github.com/mlazos ghstack dependencies: #156311, #156609
		
			
				
	
	
		
			59 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			59 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Owner(s): ["module: unknown"]
 | 
						|
 | 
						|
import glob
 | 
						|
import io
 | 
						|
import os
 | 
						|
import unittest
 | 
						|
 | 
						|
import torch
 | 
						|
from torch.testing._internal.common_utils import run_tests, TestCase
 | 
						|
 | 
						|
 | 
						|
try:
 | 
						|
    from third_party.build_bundled import create_bundled
 | 
						|
except ImportError:
 | 
						|
    create_bundled = None
 | 
						|
 | 
						|
# Checked-in aggregate of third-party license texts. This is a relative path;
# presumably it is resolved from the repository root (TODO confirm).
license_file = "third_party/LICENSES_BUNDLED.txt"
# Phrase that must appear in an installed wheel's dist-info LICENSE file.
starting_txt = "The PyTorch repository and source distributions bundle"
# Directory that contains the imported ``torch`` package: site-packages for an
# installed wheel, the checkout root for an in-tree build.
site_packages = os.path.split(os.path.split(torch.__file__)[0])[0]
# Every ``torch-*dist-info`` metadata directory found next to the package.
distinfo = glob.glob(os.path.join(site_packages, "torch-*dist-info"))
 | 
						|
 | 
						|
 | 
						|
class TestLicense(TestCase):
    """Checks that third-party license texts are bundled and up to date."""

    @unittest.skipIf(not create_bundled, "can only be run in a source tree")
    def test_license_for_wheel(self):
        """Regenerate the bundled license text and compare it to ``license_file``.

        ``create_bundled`` is ``None`` unless ``third_party.build_bundled`` could
        be imported, so this only runs from a source tree. Both ``"third_party"``
        and ``license_file`` are relative paths; presumably the test is run from
        the repository root (TODO confirm).

        Raises:
            AssertionError: if the checked-in file differs from the freshly
                generated text.
        """
        current = io.StringIO()
        create_bundled("third_party", current)
        # Decode explicitly as UTF-8 so the comparison does not depend on the
        # platform's default locale encoding (e.g. cp1252 on Windows).
        with open(license_file, encoding="utf-8") as fid:
            src_tree = fid.read()
        if src_tree != current.getvalue():
            raise AssertionError(
                f'the contents of "{license_file}" do not '
                "match the current state of the third_party files. Use "
                '"python third_party/build_bundled.py" to regenerate it'
            )

    @unittest.skipIf(len(distinfo) == 0, "no installation in site-package to test")
    def test_distinfo_license(self):
        """If run when pytorch is installed via a wheel, the license will be in
        site-packages/torch-*dist-info/LICENSE (or under ``licenses/`` with
        setuptools >= 77). Make sure it contains the third party bundle of
        licenses.

        Raises:
            AssertionError: if more than one ``torch-*dist-info`` directory is
                found, or if the license text does not contain ``starting_txt``.
        """
        # The skipIf above guarantees ``distinfo`` is non-empty here.
        if len(distinfo) > 1:
            raise AssertionError(
                'Found too many "torch-*dist-info" directories '
                f'in "{site_packages}", expected only one'
            )
        # setuptools renamed *dist-info/LICENSE to *dist-info/licenses/LICENSE
        # since 77.0: prefer the new location, fall back to the old one.
        # Named ``license_path`` to avoid shadowing the module-level
        # ``license_file`` constant.
        license_path = os.path.join(distinfo[0], "licenses", "LICENSE")
        if not os.path.exists(license_path):
            license_path = os.path.join(distinfo[0], "LICENSE")
        with open(license_path, encoding="utf-8") as fid:
            txt = fid.read()
        # Report which file was checked rather than a bare "False is not true";
        # the license text itself is large, so it is not dumped on failure.
        if starting_txt not in txt:
            raise AssertionError(
                f'"{license_path}" does not contain the bundled third-party '
                f'licenses (expected to find "{starting_txt}")'
            )
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
 | 
						|
    run_tests()
 |