Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
			
		
		
		
	Pull Request resolved: https://github.com/pytorch/pytorch/pull/157636
	Approved by: https://github.com/yewentao256, https://github.com/mlazos
	ghstack dependencies: #156311, #156609
		
			
				
	
	
		
			372 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			372 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Owner(s): ["module: unknown"]
 | |
| 
 | |
| from functools import partial
 | |
| from textwrap import dedent
 | |
| 
 | |
| import torch
 | |
| from torch.testing import FileCheck
 | |
| from torch.testing._internal.common_device_type import (
 | |
|     instantiate_device_type_tests,
 | |
|     OpDTypes,
 | |
|     ops,
 | |
| )
 | |
| from torch.testing._internal.common_jit import (
 | |
|     check_against_reference,
 | |
|     JitCommonTestCase,
 | |
| )
 | |
| from torch.testing._internal.common_methods_invocations import op_db
 | |
| from torch.testing._internal.common_utils import (
 | |
|     clone_input_helper,
 | |
|     first_sample,
 | |
|     IS_SANDCASTLE,
 | |
|     run_tests,
 | |
|     TestCase,
 | |
|     unMarkDynamoStrictTest,
 | |
| )
 | |
| from torch.testing._internal.jit_metaprogramming_utils import (
 | |
|     check_alias_annotation,
 | |
|     create_script_fn,
 | |
|     create_traced_fn,
 | |
| )
 | |
| from torch.testing._internal.jit_utils import (
 | |
|     disable_autodiff_subgraph_inlining,
 | |
|     is_lambda,
 | |
| )
 | |
| 
 | |
| 
 | |
| # variant testing is only done with torch.float and torch.cfloat to avoid
 | |
| #   excessive test times and maximize signal to noise ratio
 | |
| _variant_ops = partial(
 | |
|     ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat)
 | |
| )
 | |
| 
 | |
| 
 | |
# Tests operators for consistency between JIT and eager, also checks
#   correctness of JIT specific alias schemas and intended
#   autodifferentiation behavior.
# Inherits from JitCommonTestCase instead of TestCase directly to share
#   functionality with original test_jit.py method operator tests
@unMarkDynamoStrictTest
class TestJit(JitCommonTestCase):
    exact_dtype = True

    # Tests that the forward and backward passes of operations produce the
    #   same values for the cross-product of op variants (function, method, inplace)
    #   and runtimes (eager, traced, scripted).
    # TODO WARNING: inplace x {traced, scripted} not currently tested
    @_variant_ops(op_db)
    def test_variant_consistency_jit(self, device, dtype, op):
        # Gradients are only checked for dtypes the op declares backward support for.
        _requires_grad = dtype in op.supported_backward_dtypes(
            torch.device(device).type
        )

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
        samples = op.sample_inputs(
            device,
            dtype,
            requires_grad=_requires_grad,
            include_conjugated_inputs=include_conjugated_inputs,
        )

        # Acquires variants to test
        func = op.get_op()
        method = op.get_method()
        variants = {
            # TODO: inplace tests currently fail, fix and add inplace variant
            "function": func,
            "method": method,
        }

        # Scripting incorrectly strips the torch.ops prefix from these
        # operators, so OpOverload ops are skipped rather than tested.
        if isinstance(func, torch._ops.OpOverload):
            self.skipTest("variant consistency doesn't work on torch.ops")

        # TODO: find better way to standardize on op registration itself..
        has_fake_function = op.name in ["resize_", "resize_as_"]

        if has_fake_function:
            # Only the Tensor method exists for these ops; test it without grad.
            variants = {"method": getattr(torch.Tensor, op.name)}
            samples = op.sample_inputs(device, dtype, requires_grad=False)

        tested = False
        for sample in samples:
            # Test traced and scripted consistency
            for func_type, variant in variants.items():
                if variant is None:
                    continue

                # scripting and check_alias_analysis do not work with lambdas
                # lambdas are typically used as a way to simulate methods without
                # functional variants, so rely on the other variant for testing
                # for now
                if is_lambda(variant):
                    continue

                tested = True
                try:
                    self.indiv_variant_test_jit(
                        device, dtype, op, sample, func_type, variant, has_fake_function
                    )
                except Exception as e:
                    # Re-raise with the op/variant/dtype/sample that failed so
                    # the report identifies the failing combination.
                    variant_error_info = dedent(
                        f"""
                        Error testing {op.name} {func_type} variant
                        with dtype: {dtype}
                        with inputs {sample}:
                    """
                    )
                    raise Exception(variant_error_info) from e  # noqa: TRY002

        self.assertTrue(tested, "JIT Test does not execute any logic")

    def indiv_variant_test_jit(
        self, device, dtype, op, sample, func_type, variant, has_fake_function
    ):
        """Check one op variant on one sample against eager under script and trace.

        Compares scripted (when ``op.supports_scripting``) and traced (when
        ``op.supports_tracing`` and the op is not a fake function) versions of
        the variant against ``op.get_op()`` via ``check_against_reference``.
        For float32 it also checks alias annotations, JIT shape analysis of
        the traced graph, and the expected autodiff nodes.
        """
        _requires_grad = dtype in op.supported_backward_dtypes(
            torch.device(device).type
        )
        support_script = op.supports_scripting
        # Create accessor for script function variant
        name = op.name + "_" if func_type == "inplace" else op.name

        # run with disable_autodiff_subgraph_inlining(True) to test
        #   autodiff support. Context manager forces the graph to contain
        #   DifferentiableGraph nodes if they are present
        with disable_autodiff_subgraph_inlining():
            # Check scripted forward, grad, and grad grad
            if support_script:
                script_fn = create_script_fn(self, name, func_type)

            def out_fn(output):
                # Processes the output for autograd
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            def get_sample():
                # In-place ops (name ending in "_") get a fresh clone so each
                # call starts from the original input.
                return (
                    clone_input_helper(sample.input)
                    if op.name[-1] == "_"
                    else sample.input
                )

            if support_script:
                check_against_reference(
                    self,
                    script_fn,
                    op.get_op(),
                    out_fn,
                    (get_sample(),) + sample.args,
                    sample.kwargs,
                    no_grad=not _requires_grad,
                    no_gradgrad=not op.supports_gradgrad,
                )

            # Check traced forward, grad, and grad grad
            # TODO: fix tracing here
            supports_tracing = op.supports_tracing and not has_fake_function
            if op.assert_jit_shape_analysis:
                self.assertTrue(supports_tracing)

            if supports_tracing:
                traced_fn = create_traced_fn(self, variant)
                check_against_reference(
                    self,
                    traced_fn,
                    op.get_op(),
                    out_fn,
                    (get_sample(),) + sample.args,
                    sample.kwargs,
                    no_grad=not _requires_grad,
                    no_gradgrad=not op.supports_gradgrad,
                )

            # Check alias annotation schema for correctness (make
            #   sure inputs that aren't supposed to be modified aren't)
            # Note: only runs in float32 because schema isn't affected by dtype,
            #   so running it on all dtypes would be excessive
            if dtype == torch.float32:
                # TODO: no reason why we can't run this with tracing graph
                if support_script and op.name != "rsub":
                    check_alias_annotation(
                        name,
                        (get_sample(),) + sample.args,
                        sample.kwargs,
                        func_type=func_type,
                        aten_name=op.aten_name,
                    )

                # TODO: use script graph as well
                checked_shape_analysis = False
                if supports_tracing:
                    out = variant(get_sample(), *sample.args, **sample.kwargs)

                    # right now, tuple of outputs and tensor output supported
                    # TODO: list of tensor outputs
                    tuple_of_tensors = isinstance(out, tuple) and all(
                        isinstance(elem, torch.Tensor) for elem in out
                    )

                    if isinstance(out, torch.Tensor) or tuple_of_tensors:
                        if tuple_of_tensors:
                            sizes = [elem.size() for elem in out]
                        else:
                            sizes = out.size()
                        self.checkShapeAnalysis(
                            sizes, traced_fn.graph, op.assert_jit_shape_analysis
                        )
                        checked_shape_analysis = True
                if op.assert_jit_shape_analysis:
                    self.assertTrue(checked_shape_analysis)

            # Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample
            if dtype is torch.float32:
                # Sandcastle doesn't fuse nodes
                if IS_SANDCASTLE:
                    # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                    nonfusible_nodes = (
                        op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                    )
                    fusible_nodes = []
                else:
                    nonfusible_nodes = op.autodiff_nonfusible_nodes
                    fusible_nodes = op.autodiff_fusible_nodes

                if supports_tracing:
                    self.assertAutodiffNode(
                        traced_fn.last_graph,
                        op.assert_autodiffed,
                        nonfusible_nodes,
                        fusible_nodes,
                    )
                if support_script:
                    self.assertAutodiffNode(
                        script_fn.last_graph,
                        op.assert_autodiffed,
                        nonfusible_nodes,
                        fusible_nodes,
                    )

    # alias testing is only done with torch.float for the same reason
    _alias_ops = partial(ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float,))

    # Checks that calling an op through one of its aliases (function, method,
    #   or in-place variant) yields a scripted/traced graph that contains the
    #   original aten op name and not the alias name.
    @_alias_ops(op for op in op_db if op.aliases)
    def test_jit_alias_remapping(self, device, dtype, op):
        # NOTE: only tests on first sample
        samples = op.sample_inputs(device, dtype, requires_grad=True)
        sample = first_sample(self, samples)

        # [Scripting Data Preparation]
        # Prepare data for test scripting
        # Below we prepare strings of args/kwargs with and without type annotations.
        # These strings are inserted into function template strings which is then torch scripted.
        # - args string is ["t0"] corresponding to the "input" tensor required by the op
        # - args_kw is the value of args and strings of kwargs used to call the op (without type annotations), for example,
        # ["t0", "1.0", "(1,)", "True", "tensor(1.0)"] -> def fn(t0): return variant(t0, 1.0, (1,), True, tensor(1.0))
        args = ["t0"]

        def quote_strs(v):
            if isinstance(v, str):
                return f"'{v}'"

            return str(v)

        args_kw = (
            args
            + [f"{v}" for v in sample.args]
            + [f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()]
        )

        # [Tracing Data Preparation]
        # Only the input is a trace input; args/kwargs are baked into the trace
        # as constants. The tracer refuses to capture a tensor that requires
        # grad as a constant, so tensors (also inside plain lists/tuples) are
        # detached. This test only inspects op names in the graph, so
        # detaching does not affect what is checked.
        def detach_constants(v):
            if isinstance(v, torch.Tensor):
                return v.detach()
            if type(v) in (list, tuple):
                return type(v)(detach_constants(e) for e in v)
            return v

        trace_args = tuple(detach_constants(a) for a in sample.args)
        trace_kwargs = {k: detach_constants(v) for k, v in sample.kwargs.items()}

        original_name = op.aten_name
        original_name_inplace = original_name + "_"
        expected_dtype = op(sample.input, *sample.args, **sample.kwargs).dtype

        for a_op in op.aliases:
            inplace = a_op.inplace_variant
            method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
            # Must be a tuple, not a generator: it is iterated twice below
            # (scripting, then tracing). A generator would be exhausted by the
            # scripting loop and silently leave the tracing loop empty.
            variants = tuple(
                v
                for v in (a_op.op, a_op.method_variant, a_op.inplace_variant)
                if v is not None
            )

            # Test scripting:
            for variant in variants:
                variant_name = variant.__name__

                if variant in method_or_inplace:
                    fn_template = """
                        def _fn(t0{c}):
                            return t0.{alias_name}({args_kw})
                    """
                    # remove the first input tensor
                    script = fn_template.format(
                        c=", " if len(args_kw[1:]) > 1 else "",
                        args_kw=", ".join(args_kw[1:]),
                        alias_name=variant_name,
                    )
                else:
                    fn_template = """
                        def _fn({args}):
                            return variant({args_kw})
                    """
                    script = fn_template.format(
                        args=", ".join(args),
                        args_kw=", ".join(args_kw),
                    )

                # Required to avoid undefined value: tensor error in JIT
                # compilation of the function template
                script = script.replace("tensor(", "torch.tensor(")

                scripted = torch.jit.CompilationUnit(script)._fn

                # An in-place alias whose result dtype cannot be cast back to
                # the input dtype must fail; nothing else to check for it.
                if variant is inplace and not torch.can_cast(expected_dtype, dtype):
                    try:
                        inp = clone_input_helper(sample.input)
                        scripted(inp)
                    except Exception:
                        continue
                    self.fail(
                        "Inplace operation on integer tensor that should be promoted to float didn't fail!"
                    )

                inp = clone_input_helper(sample.input)
                scripted(inp)
                inp = clone_input_helper(sample.input)
                graph = scripted.graph_for(inp)
                FileCheck().check(op.aten_name).check_not(variant_name).run(graph)

            # Test tracing:
            for variant in variants:
                # The in-place type-promotion failure is already asserted by
                # the scripting loop above and produces no graph to inspect.
                if variant is inplace and not torch.can_cast(expected_dtype, dtype):
                    continue

                variant_name = variant.__name__
                op_name = original_name_inplace if variant is inplace else original_name

                def _fn(t0):
                    return variant(t0, *trace_args, **trace_kwargs)

                # check_trace is disabled: only the op names in the graph are
                # inspected here, and numerical re-checking adds nothing.
                traced = torch.jit.trace(
                    _fn, (clone_input_helper(sample.input),), check_trace=False
                )
                traced(clone_input_helper(sample.input))
                graph = traced.graph_for(clone_input_helper(sample.input))
                FileCheck().check(op_name).check_not(variant_name).run(graph)
| 
 | |
| 
 | |
# Instantiate TestJit's device-specific test classes into this module's
# globals (via instantiate_device_type_tests) so the runner can discover them.
instantiate_device_type_tests(TestJit, globals())

if __name__ == "__main__":
    # Turn on TestCase's default-dtype check before handing off to run_tests.
    setattr(TestCase, "_default_dtype_check_enabled", True)
    run_tests()
 |