mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283 Almost there! Test plan: dmypy restart && python3 scripts/lintrunner.py -a pyrefly check step 1: delete lines in the pyrefly.toml file from the project-excludes field step 2: run pyrefly check step 3: add suppressions, clean up unused suppressions before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199 after: INFO 0 errors (6,884 ignored) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164913 Approved by: https://github.com/oulgen
		
			
				
	
	
		
			446 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			446 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python3
 | |
| # Owner(s): ["oncall: mobile"]
 | |
| # mypy: allow-untyped-defs
 | |
| 
 | |
| import io
 | |
| import textwrap
 | |
| from typing import Optional
 | |
| 
 | |
| import torch
 | |
| import torch.utils.bundled_inputs
 | |
| from torch.testing._internal.common_utils import run_tests, TestCase
 | |
| 
 | |
| 
 | |
def model_size(sm):
    """Return the number of bytes ``torch.jit.save`` writes for ``sm``.

    The module is serialized into an in-memory stream, so nothing is
    written to disk.
    """
    stream = io.BytesIO()
    torch.jit.save(sm, stream)
    serialized = stream.getvalue()
    return len(serialized)
 | |
| 
 | |
| 
 | |
def save_and_load(sm):
    """Round-trip ``sm`` through ``torch.jit.save`` / ``torch.jit.load``.

    Serialization goes through an in-memory stream; the freshly loaded
    module is returned.
    """
    stream = io.BytesIO()
    torch.jit.save(sm, stream)
    # Rewind so that loading starts from the first serialized byte.
    stream.seek(0)
    reloaded = torch.jit.load(stream)
    return reloaded
 | |
| 
 | |
| 
 | |
class TestBundledInputs(TestCase):
    # Tests for attaching bundled inputs to TorchScript modules through
    # ``torch.utils.bundled_inputs`` and reading them back after a
    # save/load round trip.

    def test_single_tensors(self):
        # Bundle several tensors that have a compact representation and
        # verify both the size overhead and the inflated values.
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        original_size = model_size(sm)
        get_expr: list[str] = []
        samples = [
            # Tensor with small numel and small storage.
            (torch.tensor([1]),),
            # Tensor with large numel and small storage.
            (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],),
            # Tensor with small numel and large storage.
            (torch.tensor(range(1 << 16))[-8:],),
            # Large zero tensor.
            (torch.zeros(1 << 16),),
            # Large channels-last ones tensor.
            (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),),
            # Special encoding of random tensor.
            (torch.utils.bundled_inputs.bundle_randn(1 << 16),),
            # Quantized uniform tensor.
            (torch.quantize_per_tensor(torch.zeros(4, 8, 32, 32), 1, 0, torch.qint8),),
        ]
        # Position of the ``bundle_randn`` sample above. Its inflated value
        # is random, so it is checked statistically instead of by equality.
        random_sample_idx = 5

        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, samples, get_expr
        )

        # Make sure the model only grew a little bit,
        # despite having nominally large bundled inputs.
        augmented_size = model_size(sm)
        self.assertLess(augmented_size, original_size + (1 << 12))

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
        self.assertEqual(len(inflated), len(samples))

        # The model is the identity, so it must hand back the very same object.
        self.assertIs(loaded(*inflated[0]), inflated[0][0])

        for idx, inp in enumerate(inflated):
            self.assertIsInstance(inp, tuple)
            self.assertEqual(len(inp), 1)
            self.assertIsInstance(inp[0], torch.Tensor)
            if idx != random_sample_idx:
                # Strides might be important for benchmarking.
                self.assertEqual(inp[0].stride(), samples[idx][0].stride())
                self.assertEqual(inp[0], samples[idx][0], exact_dtype=True)

        # This tensor is random, but with 100,000 trials,
        # mean and std had ranges of (-0.0154, 0.0144) and (0.9907, 1.0105).
        random_tensor = inflated[random_sample_idx][0]
        self.assertEqual(random_tensor.shape, (1 << 16,))
        self.assertEqual(random_tensor.mean().item(), 0, atol=0.025, rtol=0)
        self.assertEqual(random_tensor.std().item(), 1, atol=0.02, rtol=0)

    def test_large_tensor_with_inflation(self):
        # A large tensor wrapped with ``bundle_large_tensor`` is accepted
        # and comes back unchanged.
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        sample_tensor = torch.randn(1 << 16)
        # We can store tensors with custom inflation functions regardless
        # of size, even if inflation is just the identity.
        sample = torch.utils.bundled_inputs.bundle_large_tensor(sample_tensor)
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, [(sample,)])

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(len(inflated), 1)
        self.assertEqual(inflated[0][0], sample_tensor)

    def test_rejected_tensors(self):
        # Tensors without a compact representation must be rejected.
        def check_tensor(sample):
            # Need to define the class in this scope to get a fresh type for each run.
            class SingleTensorModel(torch.nn.Module):
                def forward(self, arg):
                    return arg

            sm = torch.jit.script(SingleTensorModel())
            with self.assertRaisesRegex(Exception, "Bundled input argument"):
                torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                    sm, [(sample,)]
                )

        # Plain old big tensor.
        check_tensor(torch.randn(1 << 16))
        # This tensor has two elements, but they're far apart in memory.
        # We currently cannot represent this compactly while preserving
        # the strides.
        small_sparse = torch.randn(2, 1 << 16)[:, 0:1]
        self.assertEqual(small_sparse.numel(), 2)
        check_tensor(small_sparse)

    def test_non_tensors(self):
        # Non-tensor arguments (str, int) can be bundled as well.
        class StringAndIntModel(torch.nn.Module):
            def forward(self, fmt: str, num: int):
                return fmt.format(num)

        sm = torch.jit.script(StringAndIntModel())
        samples = [
            ("first {}", 1),
            ("second {}", 2),
        ]
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, samples)

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(inflated, samples)
        self.assertEqual(loaded(*inflated[0]), "first 1")

    def test_multiple_methods_with_inputs(self):
        # Bundle the same inputs for ``forward`` and an exported method and
        # check every generated accessor.
        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        mm = torch.jit.script(MultipleMethodModel())
        samples = [
            # Tensor with small numel and small storage.
            (torch.tensor([1]),),
            # Tensor with large numel and small storage.
            (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],),
            # Tensor with small numel and large storage.
            (torch.tensor(range(1 << 16))[-8:],),
            # Large zero tensor.
            (torch.zeros(1 << 16),),
            # Large channels-last ones tensor.
            (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),),
        ]
        # NOTE(review): ``info`` has six entries while ``samples`` has five;
        # the assertions below only require the list to be stored verbatim.
        # Confirm whether the extra entry is intentional.
        info = [
            "Tensor with small numel and small storage.",
            "Tensor with large numel and small storage.",
            "Tensor with small numel and large storage.",
            "Large zero tensor.",
            "Large channels-last ones tensor.",
            "Special encoding of random tensor.",
        ]
        torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
            mm,
            inputs={mm.forward: samples, mm.foo: samples},
            info={mm.forward: info, mm.foo: info},
        )
        loaded = save_and_load(mm)
        inflated = loaded.get_all_bundled_inputs()

        # Make sure these functions are all consistent.
        self.assertEqual(inflated, samples)
        self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_forward())
        self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_foo())

        # Check running and size helpers
        self.assertIs(loaded(*inflated[0]), inflated[0][0])
        self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))

        # Check helper that work on all functions
        all_info = loaded.get_bundled_inputs_functions_and_info()
        self.assertEqual(set(all_info.keys()), {"forward", "foo"})
        self.assertEqual(
            all_info["forward"]["get_inputs_function_name"],
            ["get_all_bundled_inputs_for_forward"],
        )
        self.assertEqual(
            all_info["foo"]["get_inputs_function_name"],
            ["get_all_bundled_inputs_for_foo"],
        )
        self.assertEqual(all_info["forward"]["info"], info)
        self.assertEqual(all_info["foo"]["info"], info)

        # example of how to turn the 'get_inputs_function_name' into the actual list of bundled inputs
        for func_info in all_info.values():
            input_func_name = func_info["get_inputs_function_name"][0]
            func_to_run = getattr(loaded, input_func_name)
            self.assertEqual(func_to_run(), samples)

    def test_multiple_methods_with_inputs_both_defined_failure(self):
        # Supplying inputs for a function that already defines its own
        # ``_generate_bundled_inputs_for_<name>`` must fail.
        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        samples = [(torch.tensor([1]),)]

        # Set-up stays outside ``assertRaises`` so that only a failure of the
        # augment call itself can satisfy the assertion.
        mm = torch.jit.script(MultipleMethodModel())
        definition = textwrap.dedent(
            """
            def _generate_bundled_inputs_for_forward(self):
                return []
            """
        )
        mm.define(definition)

        # inputs defined 2 ways so should fail
        with self.assertRaises(Exception):
            torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
                mm,
                inputs={
                    mm.forward: samples,
                    mm.foo: samples,
                },
            )

    def test_multiple_methods_with_inputs_neither_defined_failure(self):
        # Passing ``None`` for a function that has no
        # ``_generate_bundled_inputs_for_<name>`` must fail.
        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        samples = [(torch.tensor([1]),)]

        mm = torch.jit.script(MultipleMethodModel())
        # Precondition: the generator method really is absent. Calling it
        # inside ``assertRaises`` (as was done previously) raised
        # AttributeError and made the test pass without ever reaching the
        # augment call.
        self.assertFalse(hasattr(mm, "_generate_bundled_inputs_for_forward"))

        # inputs not defined so should fail
        with self.assertRaises(Exception):
            torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
                mm,
                inputs={
                    mm.forward: None,
                    mm.foo: samples,
                },
            )

    def test_bad_inputs(self):
        # Malformed ``inputs`` arguments must raise TypeError.
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        # Scripting happens outside ``assertRaises`` so that only the augment
        # call can satisfy the assertion.

        # Non list for input list
        m = torch.jit.script(SingleTensorModel())
        with self.assertRaises(TypeError):
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m,
                inputs="foo",  # type: ignore[arg-type]
            )

        # List of non tuples. Most common error using the api.
        m = torch.jit.script(SingleTensorModel())
        with self.assertRaises(TypeError):
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m,
                inputs=[torch.ones(1, 2)],  # type: ignore[list-item]
            )

    def test_double_augment_fail(self):
        # Augmenting the same module in place twice must be rejected.
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            m, inputs=[(torch.ones(1),)]
        )
        with self.assertRaisesRegex(
            Exception, "Models can only be augmented with bundled inputs once."
        ):
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m, inputs=[(torch.ones(1),)]
            )

    def test_double_augment_non_mutator(self):
        # ``bundle_inputs`` returns an augmented module and leaves the
        # module it was given without bundled-input accessors.
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        bundled_model = torch.utils.bundled_inputs.bundle_inputs(
            m, inputs=[(torch.ones(1),)]
        )
        with self.assertRaises(AttributeError):
            m.get_all_bundled_inputs()
        self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)])
        self.assertEqual(bundled_model.forward(torch.ones(1)), torch.ones(1))

    def test_double_augment_success(self):
        # ``bundle_inputs`` may be applied to an already bundled module; the
        # result reports the newer inputs.
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        bundled_model = torch.utils.bundled_inputs.bundle_inputs(
            m, inputs={m.forward: [(torch.ones(1),)]}
        )
        self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)])

        bundled_model2 = torch.utils.bundled_inputs.bundle_inputs(
            bundled_model, inputs=[(torch.ones(2),)]
        )
        self.assertEqual(bundled_model2.get_all_bundled_inputs(), [(torch.ones(2),)])

    def test_dict_args(self):
        # Optional dict/list arguments bundled through ``InflatableArg`` with
        # a custom ``fmt_fn`` stay small on disk and produce one inflate
        # helper per (bundled input, inflatable argument) pair.
        class MyModel(torch.nn.Module):
            def forward(
                self,
                arg1: Optional[dict[str, torch.Tensor]],
                arg2: Optional[list[torch.Tensor]],
                arg3: torch.Tensor,
            ):
                if arg1 is None:
                    return arg3
                elif arg2 is None:
                    return arg1["a"] + arg1["b"]
                else:
                    return arg1["a"] + arg1["b"] + arg2[0]

        small_sample = dict(
            a=torch.zeros([10, 20]),
            b=torch.zeros([1, 1]),
            c=torch.zeros([10, 20]),
        )
        small_list = [torch.zeros([10, 20])]

        big_sample = dict(
            a=torch.zeros([1 << 5, 1 << 8, 1 << 10]),
            b=torch.zeros([1 << 5, 1 << 8, 1 << 10]),
            c=torch.zeros([1 << 5, 1 << 8, 1 << 10]),
        )
        big_list = [torch.zeros([1 << 5, 1 << 8, 1 << 10])]

        def condensed(t):
            # Build a tensor with ``t``'s shape that is backed by a single
            # stored element (expanded view of a one-element clone).
            ret = torch.empty_like(t).flatten()[0].clone().expand(t.shape)
            assert ret.storage().size() == 1
            return ret

        def bundle_optional_dict_of_randn(template):
            # The stored value only carries the shapes; the TorchScript
            # ``fmt_fn`` below replaces every entry with ``randn_like``.
            return torch.utils.bundled_inputs.InflatableArg(
                value=(
                    None
                    if template is None
                    else {k: condensed(v) for (k, v) in template.items()}
                ),
                fmt="{}",
                fmt_fn="""
                def {}(self, value: Optional[Dict[str, Tensor]]):
                    if value is None:
                        return None
                    output = {{}}
                    for k, v in value.items():
                        output[k] = torch.randn_like(v)
                    return output
                """,
            )

        def bundle_optional_list_of_randn(template):
            # List counterpart of ``bundle_optional_dict_of_randn``.
            return torch.utils.bundled_inputs.InflatableArg(
                value=(None if template is None else [condensed(v) for v in template]),
                fmt="{}",
                fmt_fn="""
                def {}(self, value: Optional[List[Tensor]]):
                    if value is None:
                        return None
                    output = []
                    for v in value:
                        output.append(torch.randn_like(v))
                    return output
                """,
            )

        out: list[str] = []
        sm = torch.jit.script(MyModel())
        original_size = model_size(sm)
        small_inputs = (
            bundle_optional_dict_of_randn(small_sample),
            bundle_optional_list_of_randn(small_list),
            torch.zeros([3, 4]),
        )
        big_inputs = (
            bundle_optional_dict_of_randn(big_sample),
            bundle_optional_list_of_randn(big_list),
            torch.zeros([1 << 5, 1 << 8, 1 << 10]),
        )

        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm,
            [big_inputs, small_inputs],
            _receive_inflate_expr=out,
        )
        augmented_size = model_size(sm)
        # assert the size has not increased more than 8KB
        self.assertLess(augmented_size, original_size + (1 << 13))

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(len(inflated[0]), len(small_inputs))

        methods, _ = (
            torch.utils.bundled_inputs._get_bundled_inputs_attributes_and_methods(
                loaded
            )
        )

        # One Function (forward)
        # two bundled inputs (big_inputs and small_inputs)
        # two args which have InflatableArg with fmt_fn
        # 1 * 2 * 2 = 4
        self.assertEqual(
            sum(method.startswith("_inflate_helper") for method in methods), 4
        )
 | |
| 
 | |
| 
 | |
# Run the suite through PyTorch's test runner when executed as a script.
if __name__ == "__main__":
    run_tests()
 |