mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-26 16:44:54 +08:00 
			
		
		
		
	Compare commits
	
		
			13 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 207c2248a8 | |||
| a206dcc79e | |||
| f2d7f235a6 | |||
| 402b289f3b | |||
| a32157c67c | |||
| 24e7f29099 | |||
| 5b5d269d34 | |||
| fa88f390a0 | |||
| fe39c07826 | |||
| cba195c8ed | |||
| 16e67be7f1 | |||
| 7afffdf48b | |||
| ca45649eb5 | 
| @ -378,4 +378,4 @@ vision_maskrcnn,pass,17 | ||||
|  | ||||
|  | ||||
|  | ||||
| yolov3,pass,2 | ||||
| yolov3,pass,0 | ||||
|  | ||||
| 
 | 
| @ -286,4 +286,4 @@ vision_maskrcnn,pass,34 | ||||
|  | ||||
|  | ||||
|  | ||||
| yolov3,pass,9 | ||||
| yolov3,fail_accuracy,8 | ||||
|  | ||||
| 
 | 
| @ -242,7 +242,7 @@ pyhpc_equation_of_state,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| pyhpc_isoneutral_mixing,fail_to_run,0 | ||||
| pyhpc_isoneutral_mixing,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| yolov3,fail_to_run,0 | ||||
| yolov3,pass,0 | ||||
|  | ||||
| 
 | 
| @ -338,4 +338,4 @@ vision_maskrcnn,pass,28 | ||||
|  | ||||
|  | ||||
|  | ||||
| yolov3,pass,2 | ||||
| yolov3,pass,0 | ||||
|  | ||||
| 
 | 
| @ -338,4 +338,4 @@ vision_maskrcnn,pass,28 | ||||
|  | ||||
|  | ||||
|  | ||||
| yolov3,pass,2 | ||||
| yolov3,pass,0 | ||||
|  | ||||
| 
 | 
| @ -242,7 +242,7 @@ pyhpc_equation_of_state,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| pyhpc_isoneutral_mixing,fail_to_run,0 | ||||
| pyhpc_isoneutral_mixing,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| yolov3,fail_to_run,0 | ||||
| yolov3,pass,0 | ||||
|  | ||||
| 
 | 
| @ -374,4 +374,4 @@ vision_maskrcnn,pass,17 | ||||
|  | ||||
|  | ||||
|  | ||||
| yolov3,pass,2 | ||||
| yolov3,pass,0 | ||||
|  | ||||
| 
 | 
| @ -282,4 +282,4 @@ vision_maskrcnn,pass,34 | ||||
|  | ||||
|  | ||||
|  | ||||
| yolov3,pass,9 | ||||
| yolov3,fail_accuracy,8 | ||||
|  | ||||
| 
 | 
| @ -298,4 +298,4 @@ vision_maskrcnn,pass,28 | ||||
|  | ||||
|  | ||||
|  | ||||
| yolov3,pass,2 | ||||
| yolov3,pass,0 | ||||
|  | ||||
| 
 | 
| @ -374,4 +374,4 @@ vision_maskrcnn,pass,17 | ||||
|  | ||||
|  | ||||
|  | ||||
| yolov3,pass,2 | ||||
| yolov3,pass,0 | ||||
|  | ||||
| 
 | 
| @ -282,4 +282,4 @@ vision_maskrcnn,pass,34 | ||||
|  | ||||
|  | ||||
|  | ||||
| yolov3,pass,9 | ||||
| yolov3,pass,8 | ||||
|  | ||||
| 
 | 
| @ -378,4 +378,4 @@ vision_maskrcnn,pass,17 | ||||
|  | ||||
|  | ||||
|  | ||||
| yolov3,pass,2 | ||||
| yolov3,pass,0 | ||||
|  | ||||
| 
 | 
| @ -286,4 +286,4 @@ vision_maskrcnn,pass,34 | ||||
|  | ||||
|  | ||||
|  | ||||
| yolov3,pass,9 | ||||
| yolov3,pass,8 | ||||
|  | ||||
| 
 | 
| @ -378,4 +378,4 @@ vision_maskrcnn,pass,17 | ||||
|  | ||||
|  | ||||
|  | ||||
| yolov3,pass,2 | ||||
| yolov3,pass,0 | ||||
|  | ||||
| 
 | 
| @ -286,4 +286,4 @@ vision_maskrcnn,pass,34 | ||||
|  | ||||
|  | ||||
|  | ||||
| yolov3,pass,9 | ||||
| yolov3,pass,8 | ||||
|  | ||||
| 
 | 
| @ -4,12 +4,11 @@ phlippe_densenet,float32,static,default,1.3988316 | ||||
| basic_gnn_gcn,float32,dynamic,default,1.074576405 | ||||
| llama_v2_7b_16h,float32,dynamic,default,1.211740245 | ||||
| resnet50,float32,dynamic,default,1.65984261 | ||||
| timm_efficientnet,float32,static,cpp,2.271561735 | ||||
| #timm_efficientnet,float32,static,cpp,2.1938112 | ||||
| mobilenet_v3_large,float32,static,cpp,2.63375628 | ||||
| timm_resnest,float32,dynamic,cpp,1.67998548 | ||||
| pyhpc_turbulent_kinetic_energy,float32,dynamic,cpp,1.59968463 | ||||
| #hf_GPT2,float32,dynamic,cpp, | ||||
| hf_GPT2,float32,dynamic,cpp,1.379885175 | ||||
| #hf_GPT2,float32,dynamic,cpp,1.292704418 | ||||
| resnext50_32x4d,amp,static,default,1.461687045 | ||||
| vgg16,amp,static,default,1.267194285 | ||||
| hf_Longformer,amp,dynamic,default,0.997006035 | ||||
| @ -17,6 +16,6 @@ hf_Bert_large,amp,dynamic,default,0.99391146 | ||||
| llama,amp,static,default,1.32950568 | ||||
| timm_regnet,amp,static,cpp,1.157188305 | ||||
| lennard_jones,amp,static,cpp,2.240104485 | ||||
| hf_T5_generate,amp,dynamic,cpp,1.447656135 | ||||
| #hf_T5_generate,amp,dynamic,cpp,1.29339502 | ||||
| timm_vovnet,amp,dynamic,cpp,1.07856471 | ||||
| mobilenet_v2,amp,dynamic,cpp,2.27774577 | ||||
|  | ||||
| 
 | 
| @ -25,10 +25,6 @@ from torch._dynamo.utils import clone_inputs | ||||
| # We are primarily interested in tf32 datatype | ||||
| torch.backends.cuda.matmul.allow_tf32 = True | ||||
|  | ||||
| # Enable FX graph caching | ||||
| if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ: | ||||
|     torch._inductor.config.fx_graph_cache = True | ||||
|  | ||||
|  | ||||
| def _reassign_parameters(model): | ||||
|     # torch_geometric models register parameter as tensors due to | ||||
|  | ||||
| @ -827,6 +827,7 @@ libtorch_python_core_sources = [ | ||||
|     "torch/csrc/dynamo/guards.cpp", | ||||
|     "torch/csrc/dynamo/init.cpp", | ||||
|     "torch/csrc/functorch/init.cpp", | ||||
|     "torch/csrc/fx/node.cpp", | ||||
|     "torch/csrc/mps/Module.cpp", | ||||
|     "torch/csrc/mtia/Module.cpp", | ||||
|     "torch/csrc/inductor/aoti_runner/pybind.cpp", | ||||
|  | ||||
| @ -62,8 +62,8 @@ Overall, the ``pipelining`` package provides the following features: | ||||
|   application on the Llama model. | ||||
|  | ||||
|  | ||||
| Step 1: build ``PipelineStage`` for execution | ||||
| ********************************************* | ||||
| Step 1: build ``PipelineStage`` | ||||
| ******************************* | ||||
|  | ||||
| Before we can use a ``PipelineSchedule``, we need to create ``PipelineStage`` | ||||
| objects that wrap the part of the model running in that stage.  The | ||||
|  | ||||
| @ -304,13 +304,12 @@ class DecoratorTests(torch._dynamo.test_case.TestCase): | ||||
|         self.assertEqual(cnt.frame_count, 0) | ||||
|  | ||||
|     def test_torch_guards_stack_frame_register_inlining_disable(self): | ||||
|         y = torch.nn.Parameter(torch.tensor([0.25, 0.25])) | ||||
|         x = torch.tensor([0.5, 0.5]) | ||||
|  | ||||
|         class encoder(torch.nn.Module): | ||||
|             def __init__(self, y): | ||||
|                 super().__init__() | ||||
|                 self.register_parameter("param", y) | ||||
|                 self.a = y | ||||
|  | ||||
|             @torch._dynamo.disable | ||||
|             def helper(self, x, y): | ||||
| @ -318,9 +317,9 @@ class DecoratorTests(torch._dynamo.test_case.TestCase): | ||||
|  | ||||
|             def forward(self, a, *args): | ||||
|                 x = a + a | ||||
|                 return self.helper(x, self.param) | ||||
|                 return self.helper(x, self.a) | ||||
|  | ||||
|         e = encoder(y) | ||||
|         e = encoder(2.0) | ||||
|  | ||||
|         seen_frames = [] | ||||
|         import contextlib | ||||
|  | ||||
| @ -56,7 +56,7 @@ class BinaryFoldingTemplate(TestCase): | ||||
|                     self.use_scalar = scalar | ||||
|                     tensor_size = [1 for _ in range(self.conv.weight.ndim)] | ||||
|                     tensor_size[1] = self.conv.weight.size(0) | ||||
|                     self.tensor = ( | ||||
|                     self.tensor = torch.nn.Parameter( | ||||
|                         add_tensor | ||||
|                         if add_tensor is not None | ||||
|                         else torch.rand(tensor_size).to(device) | ||||
| @ -136,7 +136,11 @@ class BinaryFoldingTemplate(TestCase): | ||||
|                 nn.Conv2d, | ||||
|                 pytorch_op, | ||||
|                 False, | ||||
|                 add_tensor=torch.rand(32, 1, 32).to(self.device), | ||||
|                 add_tensor=torch.rand( | ||||
|                     32, | ||||
|                     1, | ||||
|                     32, | ||||
|                 ).to(self.device), | ||||
|                 expect_success=False, | ||||
|             ) | ||||
|  | ||||
| @ -156,7 +160,7 @@ class BinaryFoldingTemplate(TestCase): | ||||
|                 nn.Conv2d, | ||||
|                 pytorch_op, | ||||
|                 False, | ||||
|                 add_tensor=torch.tensor([2]).to(torch.int).to(self.device), | ||||
|                 add_tensor=torch.tensor([2]).to(torch.float64).to(self.device), | ||||
|                 expect_success=False, | ||||
|             ) | ||||
|  | ||||
|  | ||||
| @ -195,7 +195,7 @@ class TestFxGraphCache(TestCase): | ||||
|                 num_put += 1 | ||||
|  | ||||
|         cache_module = ( | ||||
|             "triton.runtime.fb_memcache.FbMemcacheRemoteFxGraphCacheBackend" | ||||
|             "triton.fb.fb_memcache.FbMemcacheRemoteFxGraphCacheBackend" | ||||
|             if config.is_fbcode() | ||||
|             else "torch._inductor.remote_cache.RedisRemoteCacheBackend" | ||||
|         ) | ||||
|  | ||||
| @ -267,7 +267,7 @@ class TestMaxAutotune(TestCase): | ||||
|                 num_put += 1 | ||||
|  | ||||
|         cache_module = ( | ||||
|             "triton.runtime.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend" | ||||
|             "triton.fb.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend" | ||||
|             if config.is_fbcode() | ||||
|             else "torch._inductor.remote_cache.RedisRemoteCacheBackend" | ||||
|         ) | ||||
|  | ||||
| @ -233,6 +233,23 @@ def run_fw_bw_and_get_code(fn): | ||||
|     return run_and_get_code(run_with_backward) | ||||
|  | ||||
|  | ||||
| def register_ops_with_aoti_compile(ns, op_set, dispatch_key, torch_compile_op_lib_impl): | ||||
|     for _op_name in op_set: | ||||
|         qualified_op_name = f"{ns}::{_op_name}" | ||||
|         _, overload_names = torch._C._jit_get_operation(qualified_op_name) | ||||
|         for overload_name in overload_names: | ||||
|             try: | ||||
|                 reg_op_name = qualified_op_name | ||||
|                 schema = torch._C._get_schema(qualified_op_name, overload_name) | ||||
|                 if schema.overload_name: | ||||
|                     reg_op_name = f"{qualified_op_name}.{schema.overload_name}" | ||||
|                 torch_compile_op_lib_impl._impl_with_aoti_compile(  # noqa: F821 | ||||
|                     reg_op_name, dispatch_key | ||||
|                 ) | ||||
|             except Exception as e: | ||||
|                 continue | ||||
|  | ||||
|  | ||||
| class TestCase(InductorTestCase): | ||||
|     @classmethod | ||||
|     def setUpClass(cls): | ||||
| @ -751,6 +768,58 @@ class CommonTemplate: | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|     @skipCUDAIf(not SM80OrLater, "Requires sm80") | ||||
|     def test_eager_aoti_support_out(self): | ||||
|         ns = "aten" | ||||
|         op_name = "clamp" | ||||
|         dispatch_key = "CPU" | ||||
|         device = "cpu" | ||||
|         if self.device.lower() == "cuda": | ||||
|             dispatch_key = "CUDA" | ||||
|             device = "cuda" | ||||
|  | ||||
|         inp_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(1.0) | ||||
|         min_tensor = inp_tensor - 0.05 | ||||
|         max_tensor = inp_tensor + 0.05 | ||||
|         with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: | ||||
|             ref_out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_( | ||||
|                 -1 | ||||
|             ) | ||||
|             ref_tensor = torch.clamp( | ||||
|                 max=max_tensor, min=min_tensor, input=inp_tensor, out=ref_out_tensor | ||||
|             ) | ||||
|  | ||||
|             ref_out_tensor1 = torch.randn(128, dtype=torch.float, device=device).fill_( | ||||
|                 -1 | ||||
|             ) | ||||
|             ref_tensor1 = torch.clamp( | ||||
|                 max=max_tensor, out=ref_out_tensor1, min=min_tensor, input=inp_tensor | ||||
|             ) | ||||
|  | ||||
|             register_ops_with_aoti_compile( | ||||
|                 ns, [op_name], dispatch_key, torch_compile_op_lib_impl | ||||
|             ) | ||||
|  | ||||
|             res_out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_( | ||||
|                 -1 | ||||
|             ) | ||||
|             res_tensor = torch.clamp( | ||||
|                 max=max_tensor, min=min_tensor, input=inp_tensor, out=res_out_tensor | ||||
|             ) | ||||
|  | ||||
|             self.assertEqual(ref_tensor, res_tensor) | ||||
|             self.assertEqual(ref_out_tensor, res_out_tensor) | ||||
|  | ||||
|             res_out_tensor1 = torch.randn(128, dtype=torch.float, device=device).fill_( | ||||
|                 -1 | ||||
|             ) | ||||
|             res_tensor1 = torch.clamp( | ||||
|                 max=max_tensor, out=res_out_tensor1, min=min_tensor, input=inp_tensor | ||||
|             ) | ||||
|  | ||||
|             self.assertEqual(ref_tensor1, res_tensor1) | ||||
|             self.assertEqual(ref_out_tensor1, res_out_tensor1) | ||||
|  | ||||
|     @skipCUDAIf(not SM80OrLater, "Requires sm80") | ||||
|     def test_eager_aoti_cache_hit(self): | ||||
|         ns = "aten" | ||||
| @ -779,24 +848,13 @@ class CommonTemplate: | ||||
|         with mock.patch( | ||||
|             "torch._inductor.utils.aoti_compile_with_persistent_cache", None | ||||
|         ): | ||||
|             qualified_op_name = f"{ns}::{op_name}" | ||||
|             _, overload_names = torch._C._jit_get_operation(qualified_op_name) | ||||
|  | ||||
|             with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: | ||||
|                 # Get ref result from eager | ||||
|                 ref_value = getattr(torch.ops.aten, op_name)(input_tensor) | ||||
|  | ||||
|                 for overload_name in overload_names: | ||||
|                     try: | ||||
|                         reg_op_name = qualified_op_name | ||||
|                         schema = torch._C._get_schema(qualified_op_name, overload_name) | ||||
|                         if schema.overload_name: | ||||
|                             reg_op_name = f"{qualified_op_name}.{schema.overload_name}" | ||||
|                         torch_compile_op_lib_impl._impl_with_aoti_compile(  # noqa: F821 | ||||
|                             reg_op_name, dispatch_key | ||||
|                         ) | ||||
|                     except Exception as e: | ||||
|                         continue | ||||
|                 register_ops_with_aoti_compile( | ||||
|                     ns, [op_name], dispatch_key, torch_compile_op_lib_impl | ||||
|                 ) | ||||
|  | ||||
|                 # Invoke the pre-compiled kernel and get result. | ||||
|                 res_value = getattr(torch.ops.aten, op_name)(input_tensor) | ||||
| @ -804,7 +862,7 @@ class CommonTemplate: | ||||
|                 self.assertEqual(ref_value, res_value) | ||||
|  | ||||
|     @skipCUDAIf(not SM80OrLater, "Requires sm80") | ||||
|     def test_aoti_compile_with_persistent_cache(self): | ||||
|     def test_eager_aoti_with_persistent_cache(self): | ||||
|         def fn(a): | ||||
|             return torch.abs(a) | ||||
|  | ||||
| @ -906,19 +964,9 @@ class CommonTemplate: | ||||
|             for scalar_value in scalar_values: | ||||
|                 ref_values.append(torch.add(a, b, alpha=scalar_value)) | ||||
|  | ||||
|             qualified_op_name = f"{namespace_name}::{op_name}" | ||||
|             _, overload_names = torch._C._jit_get_operation(qualified_op_name) | ||||
|             for overload_name in overload_names: | ||||
|                 try: | ||||
|                     reg_op_name = qualified_op_name | ||||
|                     schema = torch._C._get_schema(reg_op_name, overload_name) | ||||
|                     if schema.overload_name: | ||||
|                         reg_op_name = f"{reg_op_name}.{schema.overload_name}" | ||||
|                     torch_compile_op_lib_impl._impl_with_aoti_compile(  # noqa: F821 | ||||
|                         reg_op_name, dispatch_key | ||||
|                     ) | ||||
|                 except Exception as e: | ||||
|                     continue | ||||
|             register_ops_with_aoti_compile( | ||||
|                 namespace_name, [op_name], dispatch_key, torch_compile_op_lib_impl | ||||
|             ) | ||||
|  | ||||
|             res_values = [] | ||||
|             for scalar_value in scalar_values: | ||||
| @ -928,8 +976,7 @@ class CommonTemplate: | ||||
|             self.assertEqual(ref_values, res_values) | ||||
|  | ||||
|     @skipCUDAIf(not SM80OrLater, "Requires sm80") | ||||
|     def test_torch_compile_override_registration(self): | ||||
|         dynamic = False | ||||
|     def test_eager_aoti_override_registration(self): | ||||
|         namespace_name = "aten" | ||||
|         dispatch_key = "CPU" | ||||
|         device = torch.device("cpu") | ||||
| @ -951,24 +998,10 @@ class CommonTemplate: | ||||
|             ref = opt_fn(x) | ||||
|             ref_array.append(ref) | ||||
|  | ||||
|         def register_ops(op_set, dispatch_key, torch_compile_op_lib_impl): | ||||
|             for _op_name in op_set: | ||||
|                 qualified_op_name = f"{namespace_name}::{_op_name}" | ||||
|                 _, overload_names = torch._C._jit_get_operation(qualified_op_name) | ||||
|                 for overload_name in overload_names: | ||||
|                     try: | ||||
|                         reg_op_name = qualified_op_name | ||||
|                         schema = torch._C._get_schema(qualified_op_name, overload_name) | ||||
|                         if schema.overload_name: | ||||
|                             reg_op_name = f"{qualified_op_name}.{schema.overload_name}" | ||||
|                         torch_compile_op_lib_impl._impl_with_aoti_compile(  # noqa: F821 | ||||
|                             reg_op_name, dispatch_key | ||||
|                         ) | ||||
|                     except Exception as e: | ||||
|                         continue | ||||
|  | ||||
|         with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: | ||||
|             register_ops(unary_op_set, dispatch_key, torch_compile_op_lib_impl) | ||||
|             register_ops_with_aoti_compile( | ||||
|                 namespace_name, unary_op_set, dispatch_key, torch_compile_op_lib_impl | ||||
|             ) | ||||
|  | ||||
|             res_array = [] | ||||
|             for unary_op_name in unary_op_set: | ||||
| @ -985,7 +1018,9 @@ class CommonTemplate: | ||||
|         ref_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor) | ||||
|  | ||||
|         with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: | ||||
|             register_ops(["clamp"], dispatch_key, torch_compile_op_lib_impl) | ||||
|             register_ops_with_aoti_compile( | ||||
|                 namespace_name, ["clamp"], dispatch_key, torch_compile_op_lib_impl | ||||
|             ) | ||||
|             res_with_min = torch.ops.aten.clamp(a, min_tensor) | ||||
|             res_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor) | ||||
|             self.assertEqual(ref_with_min, res_with_min) | ||||
| @ -5502,6 +5537,14 @@ class CommonTemplate: | ||||
|         for dtype in all_types(): | ||||
|             self.common(fn, (make_tensor(8, dtype=dtype, device=self.device),)) | ||||
|  | ||||
|     def test_full_boolean(self): | ||||
|         def fn(n): | ||||
|             x = torch.full((1,), n >= 1024, device=self.device) | ||||
|             return x, x + 1 | ||||
|  | ||||
|         self.common(fn, (1024,)) | ||||
|         self.common(fn, (1023,)) | ||||
|  | ||||
|     def test_index1(self): | ||||
|         def fn(a, b, c): | ||||
|             return aten.index(a, [b, c]) | ||||
| @ -7816,6 +7859,95 @@ class CommonTemplate: | ||||
|         ) | ||||
|         assertGeneratedKernelCountEqual(self, 0) | ||||
|  | ||||
|     def test_avg_pool3d_backward(self): | ||||
|         def fn(a, b): | ||||
|             return aten.avg_pool3d_backward( | ||||
|                 a, | ||||
|                 b, | ||||
|                 [2, 2, 2], | ||||
|                 [2, 2, 2], | ||||
|                 [0, 0, 0], | ||||
|                 True, | ||||
|                 False, | ||||
|                 None, | ||||
|             ) | ||||
|  | ||||
|         self.common( | ||||
|             fn, | ||||
|             [ | ||||
|                 torch.randn([2, 4, 7, 7, 7]), | ||||
|                 torch.randn([2, 4, 14, 14, 14]), | ||||
|             ], | ||||
|         ) | ||||
|  | ||||
|     def test_avg_pool3d_backward2(self): | ||||
|         def fn(a, b): | ||||
|             return aten.avg_pool3d_backward( | ||||
|                 a, | ||||
|                 b, | ||||
|                 [3, 3, 3], | ||||
|                 [1, 1, 1], | ||||
|                 [1, 1, 1], | ||||
|                 True, | ||||
|                 False, | ||||
|                 None, | ||||
|             ) | ||||
|  | ||||
|         self.common( | ||||
|             fn, | ||||
|             [ | ||||
|                 torch.randn([1, 1, 20, 20, 15]), | ||||
|                 torch.randn([1, 1, 20, 20, 15]), | ||||
|             ], | ||||
|         ) | ||||
|  | ||||
|     def test_avg_pool3d_backward3(self): | ||||
|         def fn(a, b): | ||||
|             return aten.avg_pool3d_backward( | ||||
|                 a, | ||||
|                 b, | ||||
|                 [1, 1, 1], | ||||
|                 [2, 2, 2], | ||||
|                 [0, 0, 0], | ||||
|                 False, | ||||
|                 False, | ||||
|                 None, | ||||
|             ) | ||||
|  | ||||
|         torch._inductor.metrics.generated_kernel_count = 0 | ||||
|         self.common( | ||||
|             fn, | ||||
|             [ | ||||
|                 torch.randn([1, 2016, 11, 11, 11]), | ||||
|                 torch.randn([1, 2016, 21, 21, 21]), | ||||
|             ], | ||||
|         ) | ||||
|         assertGeneratedKernelCountEqual(self, 1) | ||||
|  | ||||
|     def test_avg_pool3d_backward4(self): | ||||
|         def fn(a, b): | ||||
|             return aten.avg_pool3d_backward( | ||||
|                 a, | ||||
|                 b, | ||||
|                 [13, 13, 13], | ||||
|                 [1, 1, 1], | ||||
|                 [0, 0, 0], | ||||
|                 True, | ||||
|                 False, | ||||
|                 None, | ||||
|             ) | ||||
|  | ||||
|         torch._inductor.metrics.generated_kernel_count = 0 | ||||
|         self.common( | ||||
|             fn, | ||||
|             [ | ||||
|                 torch.randn([1, 16, 12, 12, 12]), | ||||
|                 torch.randn([1, 16, 24, 24, 24]), | ||||
|             ], | ||||
|             check_lowp=False, | ||||
|         ) | ||||
|         assertGeneratedKernelCountEqual(self, 0) | ||||
|  | ||||
|     @config.patch(search_autotune_cache=False) | ||||
|     def test_mm_views(self): | ||||
|         def fn(a, b): | ||||
|  | ||||
| @ -121,6 +121,7 @@ test_failures = { | ||||
|     "test_conv2d_channels_last_dynamic_shapes": TestFailure(("cpu",)), | ||||
|     "test_conv3d_channels_last_dynamic_shapes": TestFailure(("cpu",)), | ||||
|     "test_expand_dynamic_shapes": TestFailure(("cpu",)), | ||||
|     "test_full_boolean_dynamic_shapes": TestFailure(("cpu",)), | ||||
|     "test_glu_dynamic_shapes": TestFailure(("cpu",)), | ||||
|     "test_isinf2_dynamic_shapes": TestFailure(("cpu",)), | ||||
|     "test_linspace1_dynamic_shapes": TestFailure(("cpu",)), | ||||
| @ -146,6 +147,7 @@ test_failures = { | ||||
|     "test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda")), | ||||
|     "test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda")), | ||||
|     "test_avg_pool2d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda")), | ||||
|     "test_avg_pool3d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda")), | ||||
|     "test_baddbmm_dynamic_shapes": TestFailure(("cpu", "cuda")), | ||||
|     "test_bmm2_dynamic_shapes": TestFailure(("cpu", "cuda")), | ||||
|     "test_both_scalars_dynamic_shapes": TestFailure(("cpu", "cuda")), | ||||
|  | ||||
| @ -8,6 +8,7 @@ import os | ||||
| import sys | ||||
| import unittest | ||||
| from functools import partial | ||||
| from typing import List | ||||
|  | ||||
| import torch | ||||
| import torch.library | ||||
| @ -369,6 +370,47 @@ class TestInductorDynamic(TestCase): | ||||
|         arg = torch.tensor(5, device=device) | ||||
|         self.assertEqual(f(arg), cf(arg)) | ||||
|  | ||||
|     @torch._dynamo.config.patch( | ||||
|         capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True | ||||
|     ) | ||||
|     @torch._inductor.config.patch(implicit_fallbacks=True) | ||||
|     def test_unbacked_save_for_backwards(self, device) -> None: | ||||
|         @torch.library.custom_op("_test::_cat", mutates_args=()) | ||||
|         def _cat(t: torch.Tensor, ds: List[int]) -> torch.Tensor: | ||||
|             return t * t.new_ones([sum(ds)]) | ||||
|  | ||||
|         @torch.library.register_fake("_test::_cat") | ||||
|         def _cat_fake(t: torch.Tensor, ds: List[int]) -> torch.Tensor: | ||||
|             [torch._check_is_size(d) for d in ds] | ||||
|             return t.new_empty([sum(ds)]) | ||||
|  | ||||
|         def _cat_setup_context(ctx, inputs, output): | ||||
|             pass | ||||
|  | ||||
|         def _cat_backward(ctx, grad): | ||||
|             return grad.sum(), None | ||||
|  | ||||
|         torch.library.register_autograd( | ||||
|             "_test::_cat", | ||||
|             _cat_backward, | ||||
|             setup_context=_cat_setup_context, | ||||
|         ) | ||||
|  | ||||
|         def fn(t, sizes): | ||||
|             r = torch.ops._test._cat(t, sizes.tolist()) | ||||
|             return r * t | ||||
|  | ||||
|         t = torch.randn((), requires_grad=True, device=device) | ||||
|         sizes = torch.tensor([4, 8], dtype=torch.int64, device="cpu") | ||||
|         out = fn(t, sizes) | ||||
|         out.sum().backward() | ||||
|         expect = t.grad | ||||
|         t.grad = None | ||||
|         torch.compile(fn, backend="inductor", fullgraph=True, dynamic=True)( | ||||
|             t, sizes | ||||
|         ).sum().backward() | ||||
|         self.assertEqual(t.grad, expect) | ||||
|  | ||||
|     @torch._dynamo.config.patch(capture_scalar_outputs=True) | ||||
|     def test_unbacked_reduction(self, device): | ||||
|         expect_fail = device == "cpu" and not IS_ARM64 | ||||
|  | ||||
| @ -2333,3 +2333,14 @@ def _save_pickle(obj: Any) -> bytes: ... | ||||
| # Defined in torch/csrc/jit/runtime/static/init.cpp | ||||
| def _jit_to_static_module(graph_or_module: Union[Graph,ScriptModule]) -> Any: ... | ||||
| def _fuse_to_static_module(graph_or_module: Union[Graph,ScriptModule], min_size: _int) -> Any: ... | ||||
|  | ||||
| # Defined in torch/csrc/fx/node.cpp | ||||
| class _NodeBase: | ||||
|     _erased: _bool | ||||
|     _prev: "_NodeBase" | ||||
|     _next: "_NodeBase" | ||||
|  | ||||
| class _NodeIter(Iterator): | ||||
|     def __init__(self, root: _NodeBase, reversed: _bool) -> None: ... | ||||
|     def __iter__(self) -> Iterator[_NodeBase]: ... | ||||
|     def __next__(self) -> _NodeBase: ... | ||||
|  | ||||
| @ -752,7 +752,13 @@ class OutputGraph: | ||||
|         **options, | ||||
|     ): | ||||
|         if is_dynamic_nn_module(target, self.root_tx.export): | ||||
|             return variables.UnspecializedNNModuleVariable(target, **options) | ||||
|             result = variables.UnspecializedNNModuleVariable(target, **options) | ||||
|             if not SideEffects.cls_supports_mutation_side_effects(type(target)): | ||||
|                 # don't allow STORE_ATTR mutation with custom __setattr__ | ||||
|                 return result | ||||
|             return self.root_tx.output.side_effects.track_object_existing( | ||||
|                 target, result | ||||
|             ) | ||||
|  | ||||
|         options = dict(options) | ||||
|         assert "source" in options | ||||
|  | ||||
| @ -1128,6 +1128,19 @@ class VariableBuilder: | ||||
|         if mutation_guard.is_dynamic_nn_module(value, self.tx.export): | ||||
|             # created dynamically, don't specialize on it | ||||
|             self.install_guards(GuardBuilder.TYPE_MATCH) | ||||
|             if ( | ||||
|                 torch._dynamo.config.inline_inbuilt_nn_modules | ||||
|                 and torch._inductor.config.freezing | ||||
|                 and not torch.is_grad_enabled() | ||||
|             ): | ||||
|                 from ..decorators import mark_static_address | ||||
|  | ||||
|                 for p in value.parameters(): | ||||
|                     mark_static_address(p) | ||||
|  | ||||
|                 for b in value.buffers(): | ||||
|                     mark_static_address(b) | ||||
|  | ||||
|             result = UnspecializedNNModuleVariable(value, source=self.source) | ||||
|             if not SideEffects.cls_supports_mutation_side_effects(type(value)): | ||||
|                 # don't allow STORE_ATTR mutation with custom __setattr__ | ||||
|  | ||||
| @ -1021,7 +1021,7 @@ class FxGraphCache: | ||||
|                 cache_id = "fx-graph-v1" | ||||
|                 try: | ||||
|                     if config.is_fbcode(): | ||||
|                         from triton.runtime.fb_memcache import ( | ||||
|                         from triton.fb.fb_memcache import ( | ||||
|                             FbMemcacheRemoteFxGraphCacheBackend, | ||||
|                         ) | ||||
|  | ||||
|  | ||||
| @ -201,11 +201,19 @@ def _unlift_graph(mod, gm, graph_signature): | ||||
|  | ||||
|     outputs = list(gm.graph.nodes)[-1].args[0] | ||||
|     mutated_outputs = [] | ||||
|     for out in outputs: | ||||
|         if out.name in graph_signature.buffers_to_mutate: | ||||
|             mutated_outputs.append(graph_signature.buffers_to_mutate[out.name]) | ||||
|         else: | ||||
|             mutated_outputs.append(None) | ||||
|     buffer_mutations = graph_signature.buffers_to_mutate | ||||
|     user_input_mutations = graph_signature.user_inputs_to_mutate | ||||
|     output_tokens = graph_signature.output_tokens | ||||
|     for idx, out in enumerate(outputs): | ||||
|         value = None | ||||
|  | ||||
|         if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens): | ||||
|             if out.name in buffer_mutations: | ||||
|                 value = buffer_mutations[out.name] | ||||
|             elif out.name in user_input_mutations: | ||||
|                 value = user_input_mutations[out.name] | ||||
|  | ||||
|         mutated_outputs.append(value) | ||||
|  | ||||
|     unlifted_gm = _unlift( | ||||
|         gm, | ||||
| @ -392,7 +400,7 @@ def should_use_remote_fx_graph_cache(): | ||||
|         return False | ||||
|  | ||||
|     try: | ||||
|         from triton.runtime.fb_memcache import MEMCACHE_VERSION | ||||
|         from triton.fb.fb_memcache import MEMCACHE_VERSION | ||||
|     except ModuleNotFoundError: | ||||
|         return False | ||||
|  | ||||
|  | ||||
| @ -177,7 +177,7 @@ def is_boolean_type(x): | ||||
|  | ||||
| def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND): | ||||
|     def construct_input(inp): | ||||
|         if isinstance(inp, (Number, sympy.Expr)): | ||||
|         if isinstance(inp, (Number, sympy.Basic)): | ||||
|             return inp | ||||
|         else: | ||||
|             assert hasattr(inp, "get_dtype") | ||||
| @ -216,7 +216,7 @@ def transform_args(args, broadcast, type_promotion_kind, convert_input_to_bool): | ||||
|             promoting_args = [ | ||||
|                 a | ||||
|                 for a in args | ||||
|                 if isinstance(a, (Number, sympy.Expr)) | ||||
|                 if isinstance(a, (Number, sympy.Basic)) | ||||
|                 or getattr(a, "dtype", None) is not None | ||||
|             ] | ||||
|             dtype = get_promoted_dtype( | ||||
| @ -368,15 +368,15 @@ def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=No | ||||
|     if override_return_dtype is None and type_promotion_kind is None: | ||||
|         type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT | ||||
|  | ||||
|     if not any(isinstance(x, (sympy.Expr, int, float)) for x in inputs): | ||||
|     if not any(isinstance(x, (sympy.Basic, int, float)) for x in inputs): | ||||
|         return inputs | ||||
|     if all(isinstance(x, (int, float, sympy.Expr)) for x in inputs): | ||||
|     if all(isinstance(x, (int, float, sympy.Basic)) for x in inputs): | ||||
|         dtype = override_return_dtype or get_promoted_dtype( | ||||
|             *inputs, type_promotion_kind=type_promotion_kind | ||||
|         ) | ||||
|  | ||||
|         def const_func(x): | ||||
|             if isinstance(x, sympy.Expr): | ||||
|             if isinstance(x, sympy.Basic): | ||||
|                 return ir.IndexingConstant(x, dtype, decode_device(None)) | ||||
|             else: | ||||
|                 return ir.Constant(x, dtype, decode_device(None)) | ||||
| @ -391,7 +391,7 @@ def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=No | ||||
|                     ir.Constant(x, ex.get_dtype(), ex.get_device()), list(ex.get_size()) | ||||
|                 ) | ||||
|             ) | ||||
|         elif isinstance(x, sympy.Expr): | ||||
|         elif isinstance(x, sympy.Basic): | ||||
|             out.append( | ||||
|                 ExpandView.create( | ||||
|                     IndexingConstant(x, ex.get_dtype(), ex.get_device()), | ||||
| @ -2155,7 +2155,6 @@ make_fallback(aten._fused_moving_avg_obs_fq_helper_functional) | ||||
|  | ||||
|  | ||||
| # 4) Backwards (try py_impl'ing them) when fwd is written as a decomp | ||||
| make_fallback(aten.avg_pool3d_backward) | ||||
| make_fallback(aten.max_pool3d_with_indices_backward) | ||||
| make_fallback(aten._adaptive_avg_pool2d_backward, require_dense) | ||||
| make_fallback(aten._adaptive_avg_pool3d_backward) | ||||
| @ -2471,7 +2470,7 @@ def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False): | ||||
|  | ||||
|     ranges: List[sympy.Expr] = [] | ||||
|  | ||||
|     if isinstance(data, sympy.Expr): | ||||
|     if isinstance(data, sympy.Basic): | ||||
|  | ||||
|         def inner_fn(index): | ||||
|             return ops.index_expr(data, dtype) | ||||
| @ -2597,7 +2596,7 @@ def _full(fill_value, device, dtype, size): | ||||
|         def inner_fn(index): | ||||
|             return ops.constant(value, dtype) | ||||
|  | ||||
|     elif isinstance(value, sympy.Expr): | ||||
|     elif isinstance(value, sympy.Basic): | ||||
|  | ||||
|         def inner_fn(index): | ||||
|             return ops.index_expr(value, dtype) | ||||
| @ -4034,11 +4033,32 @@ def pad_adaptive_loader(x, pad_val=0.0): | ||||
|     return load | ||||
|  | ||||
|  | ||||
| def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns): | ||||
|     h_start_index_fn, w_start_index_fn = start_index_fns | ||||
|     h_end_index_fn, w_end_index_fn = end_index_fns | ||||
| def compute_indices_adaptive_pooling(start_index, end_index, h_in, w_in, h_out, w_out): | ||||
|     h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) | ||||
|     h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) | ||||
|  | ||||
|     def fn_sum(idx, loader): | ||||
|     w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) | ||||
|     w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) | ||||
|  | ||||
|     return h_start_index, h_end_index, w_start_index, w_end_index | ||||
|  | ||||
|  | ||||
| def _adaptive_pooling_fn( | ||||
|     start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn | ||||
| ): | ||||
|     h_in, w_in = in_sizes | ||||
|     h_out, w_out = out_sizes | ||||
|  | ||||
|     ( | ||||
|         h_start_index_fn, | ||||
|         h_end_index_fn, | ||||
|         w_start_index_fn, | ||||
|         w_end_index_fn, | ||||
|     ) = compute_indices_adaptive_pooling( | ||||
|         start_index, end_index, h_in, w_in, h_out, w_out | ||||
|     ) | ||||
|  | ||||
|     def fn(idx, loader): | ||||
|         *prefix, bh, bw = idx | ||||
|  | ||||
|         h_start_index = h_start_index_fn(bh) | ||||
| @ -4047,7 +4067,7 @@ def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns): | ||||
|         w_start_index = w_start_index_fn(bw) | ||||
|         w_end_index = w_end_index_fn(bw) | ||||
|  | ||||
|         total = None | ||||
|         result = None | ||||
|         for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): | ||||
|             val = loader( | ||||
|                 prefix, | ||||
| @ -4055,13 +4075,66 @@ def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns): | ||||
|                 [h_start_index, w_start_index], | ||||
|                 [h_end_index, w_end_index], | ||||
|             ) | ||||
|             if total is None: | ||||
|                 total = val | ||||
|             if result is None: | ||||
|                 result = val | ||||
|             else: | ||||
|                 total = ops.add(val, total) | ||||
|         return total | ||||
|                 result = pooling_fn(val, result) | ||||
|         return result | ||||
|  | ||||
|     return fn_sum | ||||
|     return fn | ||||
|  | ||||
|  | ||||
def _adaptive_pooling_fn_with_idx(
    start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
):
    """Build an inner function that yields the flat index of the pooled element.

    ``start_index`` / ``end_index`` are bound per spatial dimension through
    ``compute_indices_adaptive_pooling``.  The returned ``fn(idx, loader)``
    walks every ``(ih, iw)`` offset below ``kernel_maxes``, loads the value
    through ``loader`` and tracks, alongside the running ``pooling_fn``
    reduction, the flattened ``row * w_in + col`` position (as ``torch.int64``)
    of the element that compared strictly greater than the running value.

    Only the index is returned; the reduced value itself is discarded.
    """
    in_h, in_w = in_sizes
    out_h, out_w = out_sizes

    row_lo_fn, row_hi_fn, col_lo_fn, col_hi_fn = compute_indices_adaptive_pooling(
        start_index, end_index, in_h, in_w, out_h, out_w
    )

    def fn(idx, loader):
        *lead, out_row, out_col = idx

        row_lo = row_lo_fn(out_row)
        row_hi = row_hi_fn(out_row)

        col_lo = col_lo_fn(out_col)
        col_hi = col_hi_fn(out_col)

        best_val = None
        best_idx = None
        for dh in range(kernel_maxes[0]):
            for dw in range(kernel_maxes[1]):
                candidate = loader(
                    lead,
                    [dh, dw],
                    [row_lo, col_lo],
                    [row_hi, col_hi],
                )

                flat = ops.index_expr(
                    (row_lo + dh) * in_w + col_lo + dw, torch.int64
                )

                # The index must be updated before the running value, because
                # the comparison is against the value from earlier offsets.
                best_idx = (
                    flat
                    if best_idx is None
                    else ops.where(ops.gt(candidate, best_val), flat, best_idx)
                )
                best_val = (
                    candidate
                    if best_val is None
                    else pooling_fn(candidate, best_val)
                )

        return best_idx

    return fn
|  | ||||
|  | ||||
| fallback_adaptive_avg_pool2d = fallback_handler( | ||||
| @ -4099,27 +4172,24 @@ def _adaptive_avg_pool2d(x, output_size): | ||||
|     new_size = list(batch) + [h_out, w_out] | ||||
|     dtype = x.get_dtype() | ||||
|  | ||||
|     window_size = h_kernel_max * w_kernel_max | ||||
|     if window_size > 25: | ||||
|         # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. | ||||
|         return fallback_adaptive_avg_pool2d(x, output_size) | ||||
|  | ||||
|     def start_index(index, out_dim, inp_dim): | ||||
|         return FloorDiv((index * inp_dim), out_dim) | ||||
|  | ||||
|     def end_index(index, out_dim, inp_dim): | ||||
|         return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) | ||||
|  | ||||
|     h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) | ||||
|     h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) | ||||
|  | ||||
|     w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) | ||||
|     w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) | ||||
|  | ||||
|     window_size = h_kernel_max * w_kernel_max | ||||
|     if window_size > 25: | ||||
|         # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. | ||||
|         return fallback_adaptive_avg_pool2d(x, output_size) | ||||
|  | ||||
|     fn_sum = _adaptive_pooling_idx_sum( | ||||
|         [h_kernel_max, w_kernel_max], | ||||
|         [h_start_index, w_start_index], | ||||
|         [h_end_index, w_end_index], | ||||
|     fn_sum = _adaptive_pooling_fn( | ||||
|         start_index=start_index, | ||||
|         end_index=end_index, | ||||
|         kernel_maxes=[h_kernel_max, w_kernel_max], | ||||
|         in_sizes=[h_in, w_in], | ||||
|         out_sizes=[h_out, w_out], | ||||
|         pooling_fn=ops.add, | ||||
|     ) | ||||
|  | ||||
|     ones_loader = pad_adaptive_loader(ones_like(x)) | ||||
| @ -4139,60 +4209,6 @@ def _adaptive_avg_pool2d(x, output_size): | ||||
|     return rv | ||||
|  | ||||
|  | ||||
| def _adaptive_pooling_idx_max(kernel_maxes, in_sizes, out_sizes, return_index, loader): | ||||
|     # NOTE: There is some duplication between this and addaptive_avg_pool2d and max_pool2d | ||||
|     # Look into refactoring/deduplication after #116418 is merged. | ||||
|     h_in, w_in = in_sizes | ||||
|     h_out, w_out = out_sizes | ||||
|  | ||||
|     def start_index(index, out_dim, inp_dim): | ||||
|         return FloorDiv((index * inp_dim), out_dim) | ||||
|  | ||||
|     def end_index(index, out_dim, inp_dim): | ||||
|         return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) | ||||
|  | ||||
|     h_start_index_fn = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) | ||||
|     h_end_index_fn = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) | ||||
|     w_start_index_fn = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) | ||||
|     w_end_index_fn = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) | ||||
|  | ||||
|     def fn_max(idx): | ||||
|         *prefix, bh, bw = idx | ||||
|  | ||||
|         h_start_index = h_start_index_fn(bh) | ||||
|         h_end_index = h_end_index_fn(bh) | ||||
|  | ||||
|         w_start_index = w_start_index_fn(bw) | ||||
|         w_end_index = w_end_index_fn(bw) | ||||
|         maxval = None | ||||
|         maxindex = None | ||||
|         for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): | ||||
|             val = loader( | ||||
|                 prefix, | ||||
|                 [ih, iw], | ||||
|                 [h_start_index, w_start_index], | ||||
|                 [h_end_index, w_end_index], | ||||
|             ) | ||||
|             index = ops.index_expr( | ||||
|                 (h_start_index + ih) * w_in + w_start_index + iw, torch.int64 | ||||
|             ) | ||||
|             if return_index: | ||||
|                 if maxindex is None: | ||||
|                     maxindex = index | ||||
|                 else: | ||||
|                     maxindex = ops.where(ops.gt(val, maxval), index, maxindex) | ||||
|             if maxval is None: | ||||
|                 maxval = val | ||||
|             else: | ||||
|                 maxval = ops.maximum(val, maxval) | ||||
|         if return_index: | ||||
|             return maxindex | ||||
|         else: | ||||
|             return maxval | ||||
|  | ||||
|     return fn_max | ||||
|  | ||||
|  | ||||
| fallback_adaptive_max_pool2d = fallback_handler( | ||||
|     aten.adaptive_max_pool2d.default, add_to_fallback_set=False | ||||
| ) | ||||
| @ -4245,32 +4261,46 @@ def adaptive_max_pool2d(x, output_size): | ||||
|         # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. | ||||
|         return fallback_adaptive_max_pool2d(x, output_size) | ||||
|  | ||||
|     inner_func_max_val = _adaptive_pooling_idx_max( | ||||
|     def start_index(index, out_dim, inp_dim): | ||||
|         return FloorDiv((index * inp_dim), out_dim) | ||||
|  | ||||
|     def end_index(index, out_dim, inp_dim): | ||||
|         return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) | ||||
|  | ||||
|     inner_func_max_val = _adaptive_pooling_fn( | ||||
|         start_index=start_index, | ||||
|         end_index=end_index, | ||||
|         kernel_maxes=[h_kernel_max, w_kernel_max], | ||||
|         in_sizes=[h_in, w_in], | ||||
|         out_sizes=[h_out, w_out], | ||||
|         return_index=False, | ||||
|         loader=pad_adaptive_loader(x, float("-inf")), | ||||
|         pooling_fn=ops.maximum, | ||||
|     ) | ||||
|  | ||||
|     inner_func_max_idx = _adaptive_pooling_idx_max( | ||||
|     inner_func_max_idx = _adaptive_pooling_fn_with_idx( | ||||
|         start_index=start_index, | ||||
|         end_index=end_index, | ||||
|         kernel_maxes=[h_kernel_max, w_kernel_max], | ||||
|         in_sizes=[h_in, w_in], | ||||
|         out_sizes=[h_out, w_out], | ||||
|         return_index=True, | ||||
|         loader=pad_adaptive_loader(x, float("-inf")), | ||||
|         pooling_fn=ops.maximum, | ||||
|     ) | ||||
|  | ||||
|     def inner_fn_max_val(idx): | ||||
|         return inner_func_max_val(idx, pad_adaptive_loader(x, float("-inf"))) | ||||
|  | ||||
|     def inner_fn_max_idx(idx): | ||||
|         return inner_func_max_idx(idx, pad_adaptive_loader(x, float("-inf"))) | ||||
|  | ||||
|     rv = Pointwise.create( | ||||
|         device=x.get_device(), | ||||
|         dtype=dtype, | ||||
|         inner_fn=inner_func_max_val, | ||||
|         inner_fn=inner_fn_max_val, | ||||
|         ranges=new_size, | ||||
|     ) | ||||
|     ri = Pointwise.create( | ||||
|         device=x.get_device(), | ||||
|         dtype=torch.int64, | ||||
|         inner_fn=inner_func_max_idx, | ||||
|         inner_fn=inner_fn_max_idx, | ||||
|         ranges=new_size, | ||||
|     ) | ||||
|     return rv, ri | ||||
| @ -4400,16 +4430,13 @@ def upsample_nearest2d_backward( | ||||
|     def end_index(index, out_dim, inp_dim): | ||||
|         return start_index((index + 1), out_dim, inp_dim) | ||||
|  | ||||
|     h_start_index = functools.partial(start_index, out_dim=out_h, inp_dim=inp_h) | ||||
|     h_end_index = functools.partial(end_index, out_dim=out_h, inp_dim=inp_h) | ||||
|  | ||||
|     w_start_index = functools.partial(start_index, out_dim=out_w, inp_dim=inp_w) | ||||
|     w_end_index = functools.partial(end_index, out_dim=out_w, inp_dim=inp_w) | ||||
|  | ||||
|     fn_sum = _adaptive_pooling_idx_sum( | ||||
|         [h_kernel_max, w_kernel_max], | ||||
|         [h_start_index, w_start_index], | ||||
|         [h_end_index, w_end_index], | ||||
|     fn_sum = _adaptive_pooling_fn( | ||||
|         start_index=start_index, | ||||
|         end_index=end_index, | ||||
|         kernel_maxes=[h_kernel_max, w_kernel_max], | ||||
|         in_sizes=[inp_h, inp_w], | ||||
|         out_sizes=[out_h, out_w], | ||||
|         pooling_fn=ops.add, | ||||
|     ) | ||||
|  | ||||
|     def fn(idx): | ||||
| @ -4761,6 +4788,207 @@ def avg_pool2d_backward( | ||||
|     return rv | ||||
|  | ||||
|  | ||||
fallback_avg_pool3d_backward = fallback_handler(
    aten.avg_pool3d_backward.default, add_to_fallback_set=False
)


@register_lowering(aten.avg_pool3d_backward, type_promotion_kind=None)
def avg_pool3d_backward(
    grad_output,
    x,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override=None,
):
    """Lower ``aten.avg_pool3d_backward`` to a single pointwise kernel.

    For every input position the kernel visits the pooled outputs whose
    window can cover it (at most ``d_window * h_window * w_window`` of them),
    loads the matching ``grad_output`` entry, divides it by the window's
    divisor and accumulates the masked contributions.

    The divisor is ``divisor_override`` when given; otherwise the full kernel
    volume when padding counts (or there was no padding / ceil-mode
    overhang); otherwise the number of non-padded elements in that window.

    When the unrolled window would exceed 125 iterations the ATen fallback is
    used instead.  The result has the size and dtype of ``x`` and lives on the
    device of ``grad_output``.
    """
    assert divisor_override is None or divisor_override != 0, "divisor must be not zero"
    stride = stride or kernel_size
    padding = padding or [0, 0, 0]

    assert isinstance(grad_output, TensorBox)
    assert isinstance(x, TensorBox)
    assert len(kernel_size) == 3
    assert len(stride) == 3
    assert len(padding) == 3
    assert len(x.get_size()) in (4, 5)

    grad_output.realize_hint()

    *_, depth, height, width = x.get_size()
    input_dims = (depth, height, width)

    # Only the "ceil mode changed the output size" flags are needed here.
    _, ceil_d = pooling_size(depth, 0, kernel_size, stride, padding, ceil_mode)
    _, ceil_h = pooling_size(height, 1, kernel_size, stride, padding, ceil_mode)
    _, ceil_w = pooling_size(width, 2, kernel_size, stride, padding, ceil_mode)

    grad_loader = grad_output.make_loader()
    had_padding = any(padding) or ceil_d or ceil_h or ceil_w

    *_, pooled_depth, pooled_height, pooled_width = grad_output.get_size()
    pooled_dims = (pooled_depth, pooled_height, pooled_width)
    new_size = list(x.get_size())
    dtype = x.get_dtype()

    def window_extent(axis):
        # Largest number of pooling windows along `axis` that can overlap a
        # single input coordinate (never less than one).
        k, s = kernel_size[axis], stride[axis]
        return max(max(pos // s - max(0, (pos - k) // s), 1) for pos in range(k * 2))

    d_window_size, h_window_size, w_window_size = (
        window_extent(axis) for axis in range(3)
    )

    if d_window_size * h_window_size * w_window_size > 125:
        # Kernel size too big. Results in hard-to-optimize Triton code.
        return fallback_avg_pool3d_backward(
            grad_output,
            x,
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )

    def unpadded_window_volume(pd, ph, pw):
        # Number of real (non-padding) input elements covered by the pooling
        # window whose output coordinate is (pd, ph, pw).
        def int32_consts(values):
            return [ops.constant(v, torch.int32) for v in values]

        stride_c = int32_consts(stride)
        pad_c = int32_consts(padding)
        kernel_c = int32_consts(kernel_size)

        lows = [
            ops.sub(ops.mul(p, s), pad)
            for p, s, pad in zip((pd, ph, pw), stride_c, pad_c)
        ]
        highs = [
            ops.minimum(
                ops.add(lo, k), ops.add(ops.index_expr(dim, torch.int32), pad)
            )
            for lo, k, dim, pad in zip(lows, kernel_c, input_dims, pad_c)
        ]
        # Clip the window to the real tensor extent.
        lows = [ops.maximum(lo, ops.constant(0, torch.int32)) for lo in lows]
        highs = [
            ops.minimum(hi, ops.index_expr(dim, torch.int32))
            for hi, dim in zip(highs, input_dims)
        ]
        plane = ops.mul(ops.sub(highs[0], lows[0]), ops.sub(highs[1], lows[1]))
        return ops.mul(plane, ops.sub(highs[2], lows[2]))

    def fn(idx):
        *lead, d, h, w = idx
        shifted = [coord + pad for coord, pad in zip((d, h, w), padding)]

        # Half-open range [first, last) of pooled coordinates per dimension
        # whose window may contain this input position.
        first = [
            ops.index_expr(FloorDiv(c - k + s, s), torch.int32)
            for c, k, s in zip(shifted, kernel_size, stride)
        ]
        last = [
            ops.index_expr(FloorDiv(c, s) + 1, torch.int32)
            for c, s in zip(shifted, stride)
        ]
        first = [ops.maximum(lo, ops.constant(0, torch.int32)) for lo in first]
        last = [
            ops.minimum(hi, ops.index_expr(n, torch.int32))
            for hi, n in zip(last, pooled_dims)
        ]

        gradient = None
        # Iterate over the 3D region to accumulate gradients
        for offsets in itertools.product(
            range(d_window_size), range(h_window_size), range(w_window_size)
        ):
            cur = [
                ops.add(lo, ops.constant(off, torch.int32))
                for lo, off in zip(first, offsets)
            ]

            if divisor_override is not None:
                scale = divisor_override
            elif count_include_pad or not had_padding:
                scale = kernel_size[0] * kernel_size[1] * kernel_size[2]
            else:
                scale = unpadded_window_volume(*cur)

            # Clamp to the last valid pooled coordinate so the load stays in
            # bounds; out-of-range contributions are masked out below.
            gather = [
                ops.indirect_indexing(
                    ops.minimum(pos, ops.sub(hi, ops.constant(1, torch.int32))),
                    n,
                    check=False,
                )
                for pos, hi, n in zip(cur, last, pooled_dims)
            ]
            part = ops.truediv(grad_loader([*lead, *gather]), scale)

            in_range = ops.and_(
                ops.and_(ops.lt(cur[0], last[0]), ops.lt(cur[1], last[1])),
                ops.lt(cur[2], last[2]),
            )
            if gradient is None:
                gradient = ops.where(
                    in_range, part, ops.constant(0.0, torch.float32)
                )
            else:
                gradient = ops.where(in_range, ops.add(gradient, part), gradient)
        assert gradient is not None
        return gradient

    return Pointwise.create(
        device=grad_output.get_device(),
        dtype=dtype,
        inner_fn=fn,
        ranges=new_size,
    )
|  | ||||
|  | ||||
| def _validate_reduction_axis(x, axis): | ||||
|     size = x.get_size() | ||||
|     if isinstance(axis, int): | ||||
|  | ||||
| @ -1031,7 +1031,7 @@ def should_use_remote_autotune_cache(inductor_meta): | ||||
|     if inductor_meta.get("is_hip"): | ||||
|         return False | ||||
|  | ||||
|     from triton.runtime.fb_memcache import MEMCACHE_VERSION | ||||
|     from triton.fb.fb_memcache import MEMCACHE_VERSION | ||||
|  | ||||
|     return MEMCACHE_VERSION >= torch._utils_internal.justknobs_getval_int( | ||||
|         "pytorch/remote_cache:autotune_memcache_version" | ||||
| @ -1075,8 +1075,12 @@ def cached_autotune( | ||||
|  | ||||
|                 try: | ||||
|                     if inductor_meta.get("is_fbcode"): | ||||
|                         remote_cache = triton.runtime.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend( | ||||
|                             key | ||||
|                         import triton.fb.fb_memcache | ||||
|  | ||||
|                         remote_cache = ( | ||||
|                             triton.fb.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend( | ||||
|                                 key | ||||
|                             ) | ||||
|                         ) | ||||
|                     else: | ||||
|                         from torch._inductor.remote_cache import RedisRemoteCacheBackend | ||||
|  | ||||
| @ -193,7 +193,16 @@ class SizeVarAllocator: | ||||
|         """ | ||||
|         sizes = list(map(self.simplify, sizes)) | ||||
|  | ||||
|         strides = [self.stride_vars(x, index_vars) for x in index_formulas] | ||||
|         strides = [ | ||||
|             # index_formulas may contain boolean expressions (e.g. s0 < 10), | ||||
|             # for which "strides" don't make sense so we ignore them here. | ||||
|             # NOTE: These expressions may still block merging dims in the sound | ||||
|             # substitution test performed in can_merge_dims. | ||||
|             self.stride_vars(x, index_vars) | ||||
|             if isinstance(x, sympy.Expr) | ||||
|             else [0] * len(index_vars) | ||||
|             for x in index_formulas | ||||
|         ] | ||||
|         assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0])) | ||||
|  | ||||
|         for i in range(len(sizes)): | ||||
|  | ||||
| @ -1046,17 +1046,17 @@ def type_to_dtype(typ: type) -> torch.dtype: | ||||
|  | ||||
|     assert isinstance(typ, type) | ||||
|  | ||||
|     if typ is bool: | ||||
|     if typ in (bool, torch.SymBool): | ||||
|         return torch.bool | ||||
|     if typ in [int, torch.SymInt]: | ||||
|     if typ in (int, torch.SymInt): | ||||
|         return torch.long | ||||
|     if typ in [float, torch.SymFloat]: | ||||
|     if typ in (float, torch.SymFloat): | ||||
|         return torch.get_default_dtype() | ||||
|     # TODO: sym_complex_float? | ||||
|     if typ is complex: | ||||
|         return corresponding_complex_dtype(torch.get_default_dtype()) | ||||
|  | ||||
|     raise ValueError("Invalid type!") | ||||
|     raise ValueError(f"Invalid type {typ}!") | ||||
|  | ||||
|  | ||||
| def get_dtype(x: Union[torch.Tensor, NumberType]): | ||||
| @ -1363,8 +1363,12 @@ def number_type( | ||||
|         return type(x) | ||||
|  | ||||
|  | ||||
| def expr_type(x: sympy.Expr) -> Type: | ||||
|     if x.is_integer:  # type: ignore[attr-defined] | ||||
| def expr_type(x: sympy.Basic) -> Type: | ||||
|     import sympy | ||||
|  | ||||
|     if x.kind is sympy.core.kind.BooleanKind: | ||||
|         return bool | ||||
|     elif x.is_integer:  # type: ignore[attr-defined] | ||||
|         return int | ||||
|     else: | ||||
|         # NB: Not strictly correct, but we don't support SymPy complex or bool. | ||||
| @ -1471,13 +1475,13 @@ def elementwise_dtypes( | ||||
|     import sympy | ||||
|  | ||||
|     for x in args: | ||||
|         if not isinstance(x, (Number, TensorLike, sympy.Expr)): | ||||
|         if not isinstance(x, (Number, TensorLike, sympy.Basic)): | ||||
|             msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!" | ||||
|             raise ValueError(msg) | ||||
|  | ||||
|         if isinstance(x, Number): | ||||
|             highest_type = get_higher_type(highest_type, number_type(x)) | ||||
|         elif isinstance(x, sympy.Expr): | ||||
|         elif isinstance(x, sympy.Basic): | ||||
|             highest_type = get_higher_type(highest_type, expr_type(x)) | ||||
|         else: | ||||
|             # x is a TensorLike | ||||
|  | ||||
| @ -67,6 +67,7 @@ | ||||
| #include <torch/csrc/cpu/Module.h> | ||||
| #include <torch/csrc/dynamo/init.h> | ||||
| #include <torch/csrc/functorch/init.h> | ||||
| #include <torch/csrc/fx/node.h> | ||||
| #include <torch/csrc/inductor/aoti_runner/pybind.h> | ||||
| #include <torch/csrc/jit/python/init.h> | ||||
| #include <torch/csrc/jit/python/python_ir.h> | ||||
| @ -1602,6 +1603,8 @@ PyObject* initModule() { | ||||
|   THPDevice_init(module); | ||||
|   THPStream_init(module); | ||||
|   THPEvent_init(module); | ||||
|   NodeBase_init(module); | ||||
|   NodeIter_init(module); | ||||
|   ASSERT_TRUE(THPVariable_initModule(module)); | ||||
|   ASSERT_TRUE(THPFunction_initModule(module)); | ||||
|   ASSERT_TRUE(THPEngine_initModule(module)); | ||||
|  | ||||
							
								
								
									
										257
									
								
								torch/csrc/fx/node.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										257
									
								
								torch/csrc/fx/node.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,257 @@ | ||||
| #include <torch/csrc/fx/node.h> | ||||
|  | ||||
| #include <structmember.h> | ||||
| #include <torch/csrc/utils/pythoncapi_compat.h> | ||||
|  | ||||
| //////////////////////////////// | ||||
| // NodeBase | ||||
| /////////////////////////////// | ||||
|  | ||||
| struct NodeBase { | ||||
|   PyObject_HEAD bool _erased; | ||||
|   NodeBase* _prev; | ||||
|   NodeBase* _next; | ||||
| }; | ||||
|  | ||||
| static PyObject* NodeBase_new( | ||||
|     PyTypeObject* type, | ||||
|     PyObject* args, | ||||
|     PyObject* kwds) { | ||||
|   PyObject* self = type->tp_alloc(type, 0); | ||||
|   if (!self) | ||||
|     return nullptr; | ||||
|   return self; | ||||
| } | ||||
|  | ||||
| static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) { | ||||
|   self->_erased = false; | ||||
|   Py_INCREF(self); | ||||
|   self->_prev = self; | ||||
|   Py_INCREF(self); | ||||
|   self->_next = self; | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) | ||||
| static struct PyMemberDef NodeBase_members[] = { | ||||
|     {"_erased", T_BOOL, offsetof(NodeBase, _erased), 0, nullptr}, | ||||
|     {"_prev", T_OBJECT_EX, offsetof(NodeBase, _prev), 0, nullptr}, | ||||
|     {"_next", T_OBJECT_EX, offsetof(NodeBase, _next), 0, nullptr}, | ||||
|     {nullptr} /* Sentinel */ | ||||
| }; | ||||
|  | ||||
| // tp_traverse slot: reports the two owned link references to the cyclic | ||||
| // garbage collector. Py_VISIT returns early if the visitor returns non-zero. | ||||
| static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) { | ||||
|   Py_VISIT(self->_prev); | ||||
|   Py_VISIT(self->_next); | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| // tp_clear slot: releases the references held in _prev and _next and leaves | ||||
| // both fields null. Also called from NodeBase_dealloc. Always returns 0. | ||||
| static int NodeBase_clear(NodeBase* self) { | ||||
|   Py_CLEAR(self->_prev); | ||||
|   Py_CLEAR(self->_next); | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| // tp_dealloc slot for _NodeBase. | ||||
| static void NodeBase_dealloc(PyObject* self) { | ||||
|   // Stop GC tracking first so the collector cannot visit a half-destroyed | ||||
|   // object while its links are being released. | ||||
|   PyObject_GC_UnTrack(self); | ||||
|   (void)NodeBase_clear((NodeBase*)self); | ||||
|   // Free through the (possibly derived) type's tp_free. | ||||
|   Py_TYPE(self)->tp_free(self); | ||||
| } | ||||
|  | ||||
| // Type object for "torch._C._NodeBase". Slots are filled positionally, so | ||||
| // the order of the entries below must follow the PyTypeObject field order. | ||||
| // The type is subclassable (Py_TPFLAGS_BASETYPE) and participates in cyclic | ||||
| // GC (Py_TPFLAGS_HAVE_GC with tp_traverse/tp_clear set). | ||||
| static PyTypeObject NodeBaseType = { | ||||
|     PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeBase", /* tp_name */ | ||||
|     sizeof(NodeBase), /* tp_basicsize */ | ||||
|     0, /* tp_itemsize */ | ||||
|     (destructor)NodeBase_dealloc, /* tp_dealloc */ | ||||
|     0, /* tp_vectorcall_offset */ | ||||
|     nullptr, /* tp_getattr */ | ||||
|     nullptr, /* tp_setattr */ | ||||
|     nullptr, /* tp_reserved */ | ||||
|     nullptr, /* tp_repr */ | ||||
|     nullptr, /* tp_as_number */ | ||||
|     nullptr, /* tp_as_sequence */ | ||||
|     nullptr, /* tp_as_mapping */ | ||||
|     nullptr, /* tp_hash  */ | ||||
|     nullptr, /* tp_call */ | ||||
|     nullptr, /* tp_str */ | ||||
|     nullptr, /* tp_getattro */ | ||||
|     nullptr, /* tp_setattro */ | ||||
|     nullptr, /* tp_as_buffer */ | ||||
|     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | | ||||
|         Py_TPFLAGS_HAVE_GC, /* tp_flags */ | ||||
|     nullptr, /* tp_doc */ | ||||
|     (traverseproc)NodeBase_traverse, /* tp_traverse */ | ||||
|     (inquiry)NodeBase_clear, /* tp_clear */ | ||||
|     nullptr, /* tp_richcompare */ | ||||
|     0, /* tp_weaklistoffset */ | ||||
|     nullptr, /* tp_iter */ | ||||
|     nullptr, /* tp_iternext */ | ||||
|     nullptr, /* tp_methods */ | ||||
|     NodeBase_members, /* tp_members */ | ||||
|     nullptr, /* tp_getset */ | ||||
|     nullptr, /* tp_base */ | ||||
|     nullptr, /* tp_dict */ | ||||
|     nullptr, /* tp_descr_get */ | ||||
|     nullptr, /* tp_descr_set */ | ||||
|     0, /* tp_dictoffset */ | ||||
|     (initproc)NodeBase_init_fn, /* tp_init */ | ||||
|     nullptr, /* tp_alloc */ | ||||
|     NodeBase_new, /* tp_new */ | ||||
| }; | ||||
|  | ||||
| // Registers the _NodeBase type on `module`. | ||||
| // Returns true on success; false when PyModule_AddType reports an error. | ||||
| bool NodeBase_init(PyObject* module) { | ||||
|   const int status = PyModule_AddType(module, &NodeBaseType); | ||||
|   return status >= 0; | ||||
| } | ||||
|  | ||||
| //////////////////////////////// | ||||
| // NodeIter | ||||
| //////////////////////////////// | ||||
|  | ||||
| // Instance layout behind the Python type "torch._C._NodeIter", an iterator | ||||
| // over a linked list of NodeBase objects. | ||||
| struct NodeIter { | ||||
|   // _reversed selects the link that is followed: _prev when true, else _next. | ||||
|   PyObject_HEAD bool _reversed; | ||||
|   // Owned reference to the node that marks both the start and the end of | ||||
|   // the iteration; it is never yielded itself. | ||||
|   NodeBase* _root; | ||||
|   // Owned reference to the current position. | ||||
|   NodeBase* _cur; | ||||
| }; | ||||
|  | ||||
| // tp_new slot for _NodeIter: allocation is delegated to the type's tp_alloc. | ||||
| // args/kwds are ignored here; they are parsed in NodeIter_init_fn. | ||||
| static PyObject* NodeIter_new( | ||||
|     PyTypeObject* type, | ||||
|     PyObject* args, | ||||
|     PyObject* kwds) { | ||||
|   // A failed allocation comes back as nullptr and is forwarded unchanged. | ||||
|   return type->tp_alloc(type, 0); | ||||
| } | ||||
|  | ||||
| // tp_init slot for _NodeIter(root, reversed=False). | ||||
| // | ||||
| // root:     a _NodeBase (or subclass) instance; iteration starts after it | ||||
| //           and stops once it is reached again. | ||||
| // reversed: integer/bool flag; non-zero walks the _prev links instead of | ||||
| //           the _next links. | ||||
| // | ||||
| // Returns 0 on success, -1 with a Python exception set when the arguments | ||||
| // do not parse (including a root that is not a _NodeBase). | ||||
| static int NodeIter_init_fn(NodeIter* self, PyObject* args, PyObject* kwargs) { | ||||
|   PyObject* root = nullptr; | ||||
|   // The "b" format unit stores an unsigned char, so parse into one rather | ||||
|   // than into a bool and convert afterwards. | ||||
|   unsigned char reversed = 0; | ||||
|   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) | ||||
|   constexpr const char* keywords[] = {"root", "reversed", nullptr}; | ||||
|   // "O!" type-checks root against NodeBaseType: the iterator later reads | ||||
|   // NodeBase fields straight out of this object, so anything else must be | ||||
|   // rejected here with a TypeError. "|" makes `reversed` optional. | ||||
|   if (!PyArg_ParseTupleAndKeywords( | ||||
|           args, | ||||
|           kwargs, | ||||
|           "O!|b", | ||||
|           // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) | ||||
|           const_cast<char**>(keywords), | ||||
|           &NodeBaseType, | ||||
|           &root, | ||||
|           &reversed)) { | ||||
|     return -1; | ||||
|   } | ||||
|   // __init__ may be invoked again on an initialized iterator; keep the old | ||||
|   // references so they can be released instead of leaked. | ||||
|   NodeBase* old_root = self->_root; | ||||
|   NodeBase* old_cur = self->_cur; | ||||
|   self->_reversed = reversed != 0; | ||||
|   Py_INCREF(root); | ||||
|   self->_root = (NodeBase*)root; | ||||
|   Py_INCREF(root); | ||||
|   self->_cur = (NodeBase*)root; | ||||
|   Py_XDECREF((PyObject*)old_root); | ||||
|   Py_XDECREF((PyObject*)old_cur); | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| // Moves self->_cur one link along the list (_prev when `reversed`, else | ||||
| // _next) while keeping the reference owned by _cur balanced. | ||||
| // Returns true on success; false with a RuntimeError set when there is no | ||||
| // current node or the link to follow is null. | ||||
| template <bool reversed> | ||||
| static bool NodeIter_advance(NodeIter* self) { | ||||
|   NodeBase* cur = self->_cur; | ||||
|   if (!cur) { | ||||
|     // _cur is only assigned by NodeIter_init_fn and is reset to null by | ||||
|     // NodeIter_clear, so an object that skipped __init__ lands here. | ||||
|     PyErr_SetString(PyExc_RuntimeError, "_NodeIter is not initialized"); | ||||
|     return false; | ||||
|   } | ||||
|   NodeBase* link = reversed ? cur->_prev : cur->_next; | ||||
|   if (!link) { | ||||
|     // The link slots can be null, e.g. after NodeBase_clear has run. | ||||
|     PyErr_SetString( | ||||
|         PyExc_RuntimeError, | ||||
|         "_NodeIter encountered a node whose _prev/_next link is unset"); | ||||
|     return false; | ||||
|   } | ||||
|   // Take the new reference and publish it before releasing the old one, so | ||||
|   // _cur never points at an object this iterator does not own. | ||||
|   Py_INCREF(link); | ||||
|   self->_cur = link; | ||||
|   Py_DECREF(cur); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| // Produces the next non-erased node after self->_cur in the direction given | ||||
| // by `reversed`. Returns a new reference, or nullptr with StopIteration set | ||||
| // once the walk arrives back at self->_root (nullptr with RuntimeError set | ||||
| // if the list cannot be advanced). | ||||
| template <bool reversed> | ||||
| PyObject* NodeIter_iternext_helper(NodeIter* self) { | ||||
|   // It should be possible to relax the ref counting here | ||||
|   // but in practice, we do not have that many _erased Nodes, | ||||
|   // so probably not worth it. | ||||
|   if (!NodeIter_advance<reversed>(self)) { | ||||
|     return nullptr; | ||||
|   } | ||||
|   while (self->_cur != self->_root) { | ||||
|     if (!self->_cur->_erased) { | ||||
|       // The caller gets its own reference; _cur keeps the one it holds. | ||||
|       Py_INCREF(self->_cur); | ||||
|       return (PyObject*)self->_cur; | ||||
|     } | ||||
|     // Erased nodes stay linked but are skipped. | ||||
|     if (!NodeIter_advance<reversed>(self)) { | ||||
|       return nullptr; | ||||
|     } | ||||
|   } | ||||
|   PyErr_SetNone(PyExc_StopIteration); | ||||
|   return nullptr; | ||||
| } | ||||
|  | ||||
| // tp_iternext slot: picks the helper instantiation that matches the | ||||
| // iterator's stored direction and returns whatever that helper returns. | ||||
| PyObject* NodeIter_iternext(PyObject* _self) { | ||||
|   NodeIter* iter = (NodeIter*)_self; | ||||
|   return iter->_reversed ? NodeIter_iternext_helper<true>(iter) | ||||
|                          : NodeIter_iternext_helper<false>(iter); | ||||
| } | ||||
|  | ||||
| // tp_traverse slot: reports the two owned node references to the cyclic | ||||
| // garbage collector. Py_VISIT returns early if the visitor returns non-zero. | ||||
| static int NodeIter_traverse(NodeIter* self, visitproc visit, void* arg) { | ||||
|   Py_VISIT(self->_root); | ||||
|   Py_VISIT(self->_cur); | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| // tp_clear slot: releases the references held in _root and _cur and leaves | ||||
| // both fields null. Also called from NodeIter_dealloc. Always returns 0. | ||||
| static int NodeIter_clear(NodeIter* self) { | ||||
|   Py_CLEAR(self->_root); | ||||
|   Py_CLEAR(self->_cur); | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| // tp_dealloc slot for _NodeIter. | ||||
| static void NodeIter_dealloc(PyObject* self) { | ||||
|   // Stop GC tracking first so the collector cannot visit a half-destroyed | ||||
|   // object while its references are being released. | ||||
|   PyObject_GC_UnTrack(self); | ||||
|   (void)NodeIter_clear((NodeIter*)self); | ||||
|   Py_TYPE(self)->tp_free(self); | ||||
| } | ||||
|  | ||||
| // Type object for "torch._C._NodeIter". Slots are filled positionally, so | ||||
| // the order of the entries below must follow the PyTypeObject field order. | ||||
| // Unlike NodeBaseType this type does not set Py_TPFLAGS_BASETYPE. It is its | ||||
| // own iterator (tp_iter = PyObject_SelfIter) and takes part in cyclic GC. | ||||
| static PyTypeObject NodeIterType = { | ||||
|     PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeIter", /* tp_name */ | ||||
|     sizeof(NodeIter), /* tp_basicsize */ | ||||
|     0, /* tp_itemsize */ | ||||
|     (destructor)NodeIter_dealloc, /* tp_dealloc */ | ||||
|     0, /* tp_vectorcall_offset */ | ||||
|     nullptr, /* tp_getattr */ | ||||
|     nullptr, /* tp_setattr */ | ||||
|     nullptr, /* tp_reserved */ | ||||
|     nullptr, /* tp_repr */ | ||||
|     nullptr, /* tp_as_number */ | ||||
|     nullptr, /* tp_as_sequence */ | ||||
|     nullptr, /* tp_as_mapping */ | ||||
|     nullptr, /* tp_hash  */ | ||||
|     nullptr, /* tp_call */ | ||||
|     nullptr, /* tp_str */ | ||||
|     nullptr, /* tp_getattro */ | ||||
|     nullptr, /* tp_setattro */ | ||||
|     nullptr, /* tp_as_buffer */ | ||||
|     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */ | ||||
|     nullptr, /* tp_doc */ | ||||
|     (traverseproc)NodeIter_traverse, /* tp_traverse */ | ||||
|     (inquiry)NodeIter_clear, /* tp_clear */ | ||||
|     nullptr, /* tp_richcompare */ | ||||
|     0, /* tp_weaklistoffset */ | ||||
|     PyObject_SelfIter, /* tp_iter */ | ||||
|     NodeIter_iternext, /* tp_iternext */ | ||||
|     nullptr, /* tp_methods */ | ||||
|     nullptr, /* tp_members */ | ||||
|     nullptr, /* tp_getset */ | ||||
|     nullptr, /* tp_base */ | ||||
|     nullptr, /* tp_dict */ | ||||
|     nullptr, /* tp_descr_get */ | ||||
|     nullptr, /* tp_descr_set */ | ||||
|     0, /* tp_dictoffset */ | ||||
|     (initproc)NodeIter_init_fn, /* tp_init */ | ||||
|     nullptr, /* tp_alloc */ | ||||
|     NodeIter_new, /* tp_new */ | ||||
| }; | ||||
|  | ||||
| // Registers the _NodeIter type on `module`. | ||||
| // Returns true on success; false when PyModule_AddType reports an error. | ||||
| bool NodeIter_init(PyObject* module) { | ||||
|   const int status = PyModule_AddType(module, &NodeIterType); | ||||
|   return status >= 0; | ||||
| } | ||||
							
								
								
									
										6
									
								
								torch/csrc/fx/node.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								torch/csrc/fx/node.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,6 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <torch/csrc/python_headers.h> | ||||
|  | ||||
| // Each function adds one FX helper type to `module` and returns false if | ||||
| // the registration fails. | ||||
| bool NodeBase_init(PyObject* module); | ||||
| bool NodeIter_init(PyObject* module); | ||||
| @ -86,6 +86,7 @@ def _insert_copy_for_mutations( | ||||
|     assert len(outputs) == len(mutated_outputs) | ||||
|  | ||||
|     user_output_nodes = [] | ||||
|     return_nodes_to_copy = {} | ||||
|     for return_node, mutated_node_name in zip(outputs, mutated_outputs): | ||||
|         if mutated_node_name is None: | ||||
|             user_output_nodes.append(return_node) | ||||
| @ -101,13 +102,18 @@ def _insert_copy_for_mutations( | ||||
|             ) | ||||
|  | ||||
|         with gm.graph.inserting_before(output_node): | ||||
|             _ = gm.graph.call_function( | ||||
|             copy_node = gm.graph.call_function( | ||||
|                 torch.ops.aten.copy_.default, (mutated_node, return_node) | ||||
|             ) | ||||
|             return_nodes_to_copy[return_node] = copy_node | ||||
|  | ||||
|     output_args = [ | ||||
|         return_nodes_to_copy[node] if node in return_nodes_to_copy else node | ||||
|         for node in user_output_nodes | ||||
|     ] | ||||
|     with gm.graph.inserting_before(output_node): | ||||
|         # Only return user outputs | ||||
|         new_output = gm.graph.output(tuple(user_output_nodes)) | ||||
|         new_output = gm.graph.output(tuple(output_args)) | ||||
|         output_node.replace_all_uses_with(new_output) | ||||
|         gm.graph.erase_node(output_node) | ||||
|  | ||||
|  | ||||
| @ -453,20 +453,21 @@ def free_unbacked_symbols(x): | ||||
| # setup! | ||||
| def is_symbol_binding_fx_node(node) -> Optional[sympy.Symbol]: | ||||
|     if ( | ||||
|         node.op == "placeholder" and | ||||
|         "val" in node.meta and | ||||
|         isinstance(node.meta["val"], torch.SymInt) and | ||||
|         isinstance(node.meta["val"].node.expr, sympy.Symbol) | ||||
|         isinstance(node.meta["val"].node.expr, sympy.Symbol) and | ||||
|         (node.op == "placeholder" or free_unbacked_symbols(node.meta["val"].node.expr)) | ||||
|     ): | ||||
|         return node.meta["val"].node.expr | ||||
|     return None | ||||
|  | ||||
| def find_symbol_binding_fx_nodes(graph): | ||||
|     return { | ||||
|         node.meta["val"].node.expr: node | ||||
|         for node in graph.nodes | ||||
|         if is_symbol_binding_fx_node(node) | ||||
|     } | ||||
|     r = {} | ||||
|     # NB: Prefer first occurrence of symbol | ||||
|     for node in graph.nodes: | ||||
|         if is_symbol_binding_fx_node(node) and node.meta["val"].node.expr not in r: | ||||
|             r[node.meta["val"].node.expr] = node | ||||
|     return r | ||||
|  | ||||
|  | ||||
| # Analogous to ConvertIntSource | ||||
|  | ||||
| @ -4,6 +4,7 @@ from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_na | ||||
| import torch.utils._pytree as pytree | ||||
| from . import _pytree as fx_pytree | ||||
| from ._compatibility import compatibility | ||||
| from torch._C import _NodeIter | ||||
|  | ||||
| import os | ||||
| import contextlib | ||||
| @ -271,20 +272,8 @@ class _node_list: | ||||
|         return self.graph._len | ||||
|  | ||||
|     def __iter__(self): | ||||
|         root = self.graph._root | ||||
|         if self.direction == "_next": | ||||
|             cur = root._next | ||||
|             while cur is not root: | ||||
|                 if not cur._erased: | ||||
|                     yield cur | ||||
|                 cur = cur._next | ||||
|         else: | ||||
|             assert self.direction == "_prev" | ||||
|             cur = root._prev | ||||
|             while cur is not root: | ||||
|                 if not cur._erased: | ||||
|                     yield cur | ||||
|                 cur = cur._prev | ||||
|         assert self.direction == "_prev" or self.direction == "_next" | ||||
|         yield from _NodeIter(self.graph._root, self.direction == "_prev") | ||||
|  | ||||
|     def __reversed__(self): | ||||
|         return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev') | ||||
|  | ||||
| @ -11,6 +11,7 @@ import inspect | ||||
| import warnings | ||||
| from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair | ||||
| from .._ops import ops as _ops | ||||
| from torch._C import _NodeBase | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from .graph import Graph | ||||
| @ -139,7 +140,7 @@ def _format_arg(arg, max_list_len=float('inf')) -> str: | ||||
|         return str(arg) | ||||
|  | ||||
| @compatibility(is_backward_compatible=True) | ||||
| class Node: | ||||
| class Node(_NodeBase): | ||||
|     """ | ||||
|     ``Node`` is the data structure that represents individual operations within | ||||
|     a ``Graph``. For the most part, Nodes represent callsites to various entities, | ||||
| @ -197,6 +198,7 @@ class Node: | ||||
|                 annotation of values in the generated code or for other types | ||||
|                 of analyses. | ||||
|         """ | ||||
|         super().__init__() | ||||
|         self.graph = graph | ||||
|         self.name = name  # unique name of value being created | ||||
|         assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root'] | ||||
| @ -235,9 +237,6 @@ class Node: | ||||
|         # does not produce a value, it's more of a notation. Thus, this value | ||||
|         # describes the type of args[0] in the ``return`` node. | ||||
|         self.type : Optional[Any] = return_type | ||||
|         self._prev = self | ||||
|         self._next = self | ||||
|         self._erased = False | ||||
|         self._sort_key: Any = () | ||||
|  | ||||
|         # If set, use this fn to print this node | ||||
| @ -247,6 +246,22 @@ class Node: | ||||
|         # transformations. This metadata is preserved across node copies | ||||
|         self.meta : Dict[str, Any] = {} | ||||
|  | ||||
|     def __getstate__(self): | ||||
|         state = self.__dict__.copy() | ||||
|         state["_erased"] = self._erased | ||||
|         state["_prev"] = self._prev | ||||
|         state["_next"] = self._next | ||||
|         return state | ||||
|  | ||||
|     def __setstate__(self, state): | ||||
|         _erased = state.pop("_erased") | ||||
|         _prev = state.pop("_prev") | ||||
|         _next = state.pop("_next") | ||||
|         self.__dict__.update(state) | ||||
|         self._erased = _erased | ||||
|         self._prev = _prev | ||||
|         self._next = _next | ||||
|  | ||||
|     @property | ||||
|     def next(self) -> 'Node': | ||||
|         """ | ||||
|  | ||||
| @ -89,7 +89,10 @@ def sympy_generic_le(lower, upper): | ||||
|         return lower <= upper | ||||
|     else: | ||||
|         # only negative condition is True > False | ||||
|         assert isinstance(lower, SympyBoolean) and isinstance(upper, SympyBoolean) | ||||
|         assert isinstance(lower, SympyBoolean) and isinstance(upper, SympyBoolean), ( | ||||
|             lower, | ||||
|             upper, | ||||
|         ) | ||||
|         return not (lower and not upper) | ||||
|  | ||||
|  | ||||
| @ -945,6 +948,8 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis): | ||||
|         if dtype == torch.bool: | ||||
|             if x.is_singleton(): | ||||
|                 return ValueRanges.wrap(x.lower != 0) | ||||
|             elif x.is_bool: | ||||
|                 return x | ||||
|             elif 0 not in x: | ||||
|                 return ValueRanges.wrap(sympy.true) | ||||
|             else: | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	