Compare commits

...

11 Commits

Author SHA1 Message Date
f2d7f235a6 [dynamo][yolov3] Track UnspecializedNNModuleVariable for mutation (#128269)
Fixes https://github.com/pytorch/pytorch/issues/101168

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128269
Approved by: https://github.com/jansel
ghstack dependencies: #128295, #126578, #128268, #128254
2024-06-11 07:09:04 +00:00
402b289f3b Properly register parameter for binary folding test (#128356)
This PR properly registers the tensor used in the module's compute as a parameter. The bug was previously hidden because Dynamo treated every tensor on an nn module as a constant; with inlining of NN modules, that is no longer the case.
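A minimal sketch of the change in module terms (module name and shapes here are illustrative, not the test's exact code):

```python
import torch

class ConvAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        # Before: a plain tensor attribute. Dynamo used to treat every tensor
        # hanging off an nn.Module as a constant, which hid the missing
        # registration.
        # self.tensor = torch.rand(1, 8, 1, 1)
        # After: properly registered as a parameter, so it is tracked as
        # module state instead of being silently specialized on.
        self.tensor = torch.nn.Parameter(torch.rand(1, 8, 1, 1))

    def forward(self, x):
        return self.conv(x) + self.tensor
```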

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128356
Approved by: https://github.com/anijain2305
ghstack dependencies: #128355
2024-06-11 06:48:26 +00:00
a32157c67c Mark params static if inlining modules and freezing (#128355)
Today, inlining builtin nn modules is not compatible with parameter freezing. Freezing parameters and then constant-folding them through the graph relies on the assumption that they will not be inputs and will be static across calls to the same graph. When inlining builtin nn modules, this assumption is broken because the same graph is reused for different instances of the same nn module. There are three options: 1) abandon constant folding, 2) create a dispatcher layer (like cudagraphs) that dispatches to the correct constant-folded graph for each distinct set of parameters, or 3) recompile.

This PR implements option 3 by introducing guards on the parameter pointers. This choice was made because freezing is relatively rare and performance sensitive. Option 2 had many more unknowns, and option 1 is not viable due to the drop in performance.
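A hedged sketch of the intended behavior under option 3, assuming both inlining of builtin nn modules and inductor freezing are enabled (the two config knobs appear in the diff below; the module and shapes are illustrative):

```python
import torch

torch._dynamo.config.inline_inbuilt_nn_modules = True
torch._inductor.config.freezing = True

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

def run(mod, x):
    return mod(x)

opt_run = torch.compile(run, backend="inductor")
m1, m2 = Tiny(), Tiny()
x = torch.randn(4, 8)

with torch.no_grad():
    opt_run(m1, x)  # m1's parameters are constant-folded into the frozen graph
    opt_run(m2, x)  # pointer guards fail for m2's parameters, triggering a recompile
```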

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128355
Approved by: https://github.com/anijain2305
2024-06-11 06:48:26 +00:00
24e7f29099 Lowering for avg_pool_3d_backward (Fixes:#127101) (#127722)
We implemented a lowering for the avg_pool3d_backward operation and created tests for it.
We ran some benchmarks and achieved the following results:

```
[-------------- avgpool_3d_backwards --------------]
                             |  Decomposed  |  Eager
16 threads: ----------------------------------------
      (3, 5, 400, 200, 200)  |     6061     |  11160
      (3, 5, 300, 200, 200)  |     4547     |   8372
      (3, 5, 200, 200, 200)  |     3032     |   5585
      (3, 5, 300, 300, 300)  |    10100     |  18840
      (3, 5, 100, 100, 100)  |      381     |    703
      (3, 5, 100, 300, 200)  |     2270     |   4190
      (8, 8, 128, 128, 128)  |     3397     |   6253
      (2, 3, 150, 150, 150)  |      520     |    947
      (1, 3, 128, 128, 128)  |      161     |    299
      (8, 16, 64, 64, 64)    |      851     |   1569
      (1, 1, 50, 50, 50)     |       17     |     11
      (3, 5, 20, 40, 40)     |       17     |     30
      (3, 5, 10, 20, 20)     |       17     |     11
      (1, 1, 10, 10, 10)     |       16     |     11
      (3, 5, 5, 10, 10)      |       17     |     11
      (3, 5, 2, 5, 5)        |       17     |     11
```
These were run on an RTX 3050, so we were not able to allocate larger tensors due to memory limitations.
We believe it would be beneficial to benchmark this on more recent hardware, just to check if the performance holds up with larger sizes.
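A rough sketch of how such a comparison can be reproduced (the pooling arguments and shapes mirror one of the tests added below; the timing harness is illustrative, and tensors can be moved to a GPU to match the setup above):

```python
import torch
from torch.utils import benchmark

aten = torch.ops.aten

def fn(grad_out, inp):
    # args: kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
    return aten.avg_pool3d_backward(
        grad_out, inp, [2, 2, 2], [2, 2, 2], [0, 0, 0], True, False, None
    )

grad_out = torch.randn(2, 4, 7, 7, 7)
inp = torch.randn(2, 4, 14, 14, 14)
compiled_fn = torch.compile(fn)  # uses the new lowering instead of the eager fallback
compiled_fn(grad_out, inp)       # warm up / compile once before timing

for label, f in [("Eager", fn), ("Decomposed", compiled_fn)]:
    t = benchmark.Timer(
        stmt="f(grad_out, inp)",
        globals={"f": f, "grad_out": grad_out, "inp": inp},
    )
    print(label, t.timeit(100))
```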

We also refactored code from adaptive_avg_pool2d and adaptive_max_pool2d to reduce code duplication.
We diffed the kernels and they are identical.

Fixes #127101

Co-authored-by: Martim Mendes <martimccmendes@tecnico.ulisboa.pt>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127722
Approved by: https://github.com/jansel
2024-06-11 06:39:04 +00:00
5b5d269d34 Speed up fx graph iteration by implementing it in C++ (#128288)
Before this change
```
python benchmarks/dynamo/microbenchmarks/fx_microbenchmarks.py
iterating over 100000000 FX nodes took 19.5s (5132266 nodes/s)
```

After this change
```
python benchmarks/dynamo/microbenchmarks/fx_microbenchmarks.py
iterating over 100000000 FX nodes took 3.4s (29114001 nodes/s)
```

5.7x improvement
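A small sketch of the kind of iteration the microbenchmark measures (the harness here is illustrative; after this change, iterating `graph.nodes` is backed by the C++ `_NodeIter` added below):

```python
import time
import torch
import torch.fx

def f(x):
    return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(f)  # a tiny graph: placeholder, relu, add, output

start = time.perf_counter()
count = 0
for _ in range(1_000_000):
    for node in gm.graph.nodes:  # the hot loop this PR moves into C++
        count += 1
elapsed = time.perf_counter() - start
print(f"iterated over {count} FX nodes in {elapsed:.1f}s ({count / elapsed:.0f} nodes/s)")
```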

Differential Revision: [D58343997](https://our.internmc.facebook.com/intern/diff/D58343997)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128288
Approved by: https://github.com/jansel, https://github.com/albanD
2024-06-11 05:48:31 +00:00
fa88f390a0 Revert "[inductor] enable fx graph cache on torchbench (#128239)"
This reverts commit 734e8f6ad7e7f0fa0341fb658f1f986225173f5f.

Reverted https://github.com/pytorch/pytorch/pull/128239 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to surface a bunch of inductor failures in trunk 734e8f6ad7 ([comment](https://github.com/pytorch/pytorch/pull/128239#issuecomment-2159789242))
2024-06-11 04:53:38 +00:00
fe39c07826 [pipelining][doc] Remove duplicated words (#128368)
"for execution" is used in both step titles

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128368
Approved by: https://github.com/wconstab
ghstack dependencies: #128361
2024-06-11 04:52:57 +00:00
cba195c8ed Support aten operations with out tensor (#124926)
This PR intends to support aten operations that take an `out` tensor.

Currently, AOT compile never keeps input tensor mutations in the graph. According to the comments, this is because such a use case had not yet been encountered:
> For now there's no use case involving keeping input mutations in the graph (which we can only do in the inference case anyway). We can add this later if we need to.

However, for aten operations it is common for the `out` tensor to be an input parameter that needs to be mutated. This PR supports this by adding a `keep_inference_input_mutations` flag (`aot_inductor.keep_inference_input_mutations`). The flag gives the callee the flexibility to decide whether AOT compile should keep input tensor mutations in the graph.
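A minimal sketch of flipping that flag; the config path is taken from the description above and should be treated as an assumption:

```python
import torch._inductor.config as inductor_config

# Assumed knob (named above): let AOT compile keep input tensor mutations,
# e.g. writes into an `out=` tensor, in the inference graph.
inductor_config.aot_inductor.keep_inference_input_mutations = True
```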

Take `clamp` as an example:
```python
out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(-2.0)
inp_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(1.0)
min_tensor = inp_tensor - 0.05
max_tensor = inp_tensor + 0.05
torch.clamp(input=inp_tensor, min=min_tensor, max=max_tensor, out=out_tensor)
```

Without this PR:
```python
def forward(self):
    arg0_1: "f32[128]"; arg1_1: "f32[128]"; arg2_1: "f32[128]"; arg3_1: "f32[128]";

    arg0_1, arg1_1, arg2_1, arg3_1, = fx_pytree.tree_flatten_spec([], self._in_spec)
    clamp_min: "f32[128]" = torch.ops.aten.clamp_min.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
    clamp_max: "f32[128]" = torch.ops.aten.clamp_max.Tensor(clamp_min, arg2_1);  clamp_min = arg2_1 = None
    return (clamp_max, clamp_max)
```

With this PR:
```python
def forward(self):
    arg0_1: "f32[128]"; arg1_1: "f32[128]"; arg2_1: "f32[128]"; arg3_1: "f32[128]";

    arg0_1, arg1_1, arg2_1, arg3_1, = fx_pytree.tree_flatten_spec([], self._in_spec)
    clamp_min: "f32[128]" = torch.ops.aten.clamp_min.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
    clamp_max: "f32[128]" = torch.ops.aten.clamp_max.Tensor(clamp_min, arg2_1);  clamp_min = arg2_1 = None
    copy_: "f32[128]" = torch.ops.aten.copy_.default(arg3_1, clamp_max);  arg3_1 = clamp_max = None
    return (copy_,)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124926
Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/angelayi
2024-06-11 04:35:27 +00:00
16e67be7f1 Also preserve unbacked SymInts when partitioning as backward inputs (#128338)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128338
Approved by: https://github.com/IvanKobzarev
2024-06-11 04:27:09 +00:00
7afffdf48b [CI] Comment hf_T5_generate, hf_GPT2 and timm_efficientnet in inductor cpu smoketest for performance unstable issue (#127588)
Fixes #126993

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127588
Approved by: https://github.com/chuanqi129, https://github.com/jgong5, https://github.com/desertfire
2024-06-11 03:12:11 +00:00
ca45649eb5 [easy][dynamo][inline work] Fix test with inlining inbuilt nn modules (#128254)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128254
Approved by: https://github.com/williamwen42
ghstack dependencies: #128295, #126578, #128268
2024-06-11 03:02:51 +00:00
36 changed files with 922 additions and 213 deletions

View File

@ -378,4 +378,4 @@ vision_maskrcnn,pass,17
yolov3,pass,2
yolov3,pass,0


View File

@ -286,4 +286,4 @@ vision_maskrcnn,pass,34
yolov3,pass,9
yolov3,fail_accuracy,8


View File

@ -242,7 +242,7 @@ pyhpc_equation_of_state,pass,0
pyhpc_isoneutral_mixing,fail_to_run,0
pyhpc_isoneutral_mixing,pass,0
@ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0
yolov3,fail_to_run,0
yolov3,pass,0


View File

@ -338,4 +338,4 @@ vision_maskrcnn,pass,28
yolov3,pass,2
yolov3,pass,0


View File

@ -338,4 +338,4 @@ vision_maskrcnn,pass,28
yolov3,pass,2
yolov3,pass,0


View File

@ -242,7 +242,7 @@ pyhpc_equation_of_state,pass,0
pyhpc_isoneutral_mixing,fail_to_run,0
pyhpc_isoneutral_mixing,pass,0
@ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0
yolov3,fail_to_run,0
yolov3,pass,0


View File

@ -374,4 +374,4 @@ vision_maskrcnn,pass,17
yolov3,pass,2
yolov3,pass,0


View File

@ -282,4 +282,4 @@ vision_maskrcnn,pass,34
yolov3,pass,9
yolov3,fail_accuracy,8


View File

@ -298,4 +298,4 @@ vision_maskrcnn,pass,28
yolov3,pass,2
yolov3,pass,0


View File

@ -374,4 +374,4 @@ vision_maskrcnn,pass,17
yolov3,pass,2
yolov3,pass,0


View File

@ -282,4 +282,4 @@ vision_maskrcnn,pass,34
yolov3,pass,9
yolov3,pass,8


View File

@ -378,4 +378,4 @@ vision_maskrcnn,pass,17
yolov3,pass,2
yolov3,pass,0


View File

@ -286,4 +286,4 @@ vision_maskrcnn,pass,34
yolov3,pass,9
yolov3,pass,8


View File

@ -378,4 +378,4 @@ vision_maskrcnn,pass,17
yolov3,pass,2
yolov3,pass,0


View File

@ -286,4 +286,4 @@ vision_maskrcnn,pass,34
yolov3,pass,9
yolov3,pass,8


View File

@ -4,12 +4,11 @@ phlippe_densenet,float32,static,default,1.3988316
basic_gnn_gcn,float32,dynamic,default,1.074576405
llama_v2_7b_16h,float32,dynamic,default,1.211740245
resnet50,float32,dynamic,default,1.65984261
timm_efficientnet,float32,static,cpp,2.271561735
#timm_efficientnet,float32,static,cpp,2.1938112
mobilenet_v3_large,float32,static,cpp,2.63375628
timm_resnest,float32,dynamic,cpp,1.67998548
pyhpc_turbulent_kinetic_energy,float32,dynamic,cpp,1.59968463
#hf_GPT2,float32,dynamic,cpp,
hf_GPT2,float32,dynamic,cpp,1.379885175
#hf_GPT2,float32,dynamic,cpp,1.292704418
resnext50_32x4d,amp,static,default,1.461687045
vgg16,amp,static,default,1.267194285
hf_Longformer,amp,dynamic,default,0.997006035
@ -17,6 +16,6 @@ hf_Bert_large,amp,dynamic,default,0.99391146
llama,amp,static,default,1.32950568
timm_regnet,amp,static,cpp,1.157188305
lennard_jones,amp,static,cpp,2.240104485
hf_T5_generate,amp,dynamic,cpp,1.447656135
#hf_T5_generate,amp,dynamic,cpp,1.29339502
timm_vovnet,amp,dynamic,cpp,1.07856471
mobilenet_v2,amp,dynamic,cpp,2.27774577


View File

@ -25,10 +25,6 @@ from torch._dynamo.utils import clone_inputs
# We are primarily interested in tf32 datatype
torch.backends.cuda.matmul.allow_tf32 = True
# Enable FX graph caching
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
torch._inductor.config.fx_graph_cache = True
def _reassign_parameters(model):
# torch_geometric models register parameter as tensors due to

View File

@ -827,6 +827,7 @@ libtorch_python_core_sources = [
"torch/csrc/dynamo/guards.cpp",
"torch/csrc/dynamo/init.cpp",
"torch/csrc/functorch/init.cpp",
"torch/csrc/fx/node.cpp",
"torch/csrc/mps/Module.cpp",
"torch/csrc/mtia/Module.cpp",
"torch/csrc/inductor/aoti_runner/pybind.cpp",

View File

@ -62,8 +62,8 @@ Overall, the ``pipelining`` package provides the following features:
application on the Llama model.
Step 1: build ``PipelineStage`` for execution
*********************************************
Step 1: build ``PipelineStage``
*******************************
Before we can use a ``PipelineSchedule``, we need to create ``PipelineStage``
objects that wrap the part of the model running in that stage. The

View File

@ -304,13 +304,12 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
self.assertEqual(cnt.frame_count, 0)
def test_torch_guards_stack_frame_register_inlining_disable(self):
y = torch.nn.Parameter(torch.tensor([0.25, 0.25]))
x = torch.tensor([0.5, 0.5])
class encoder(torch.nn.Module):
def __init__(self, y):
super().__init__()
self.register_parameter("param", y)
self.a = y
@torch._dynamo.disable
def helper(self, x, y):
@ -318,9 +317,9 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
def forward(self, a, *args):
x = a + a
return self.helper(x, self.param)
return self.helper(x, self.a)
e = encoder(y)
e = encoder(2.0)
seen_frames = []
import contextlib

View File

@ -56,7 +56,7 @@ class BinaryFoldingTemplate(TestCase):
self.use_scalar = scalar
tensor_size = [1 for _ in range(self.conv.weight.ndim)]
tensor_size[1] = self.conv.weight.size(0)
self.tensor = (
self.tensor = torch.nn.Parameter(
add_tensor
if add_tensor is not None
else torch.rand(tensor_size).to(device)
@ -136,7 +136,11 @@ class BinaryFoldingTemplate(TestCase):
nn.Conv2d,
pytorch_op,
False,
add_tensor=torch.rand(32, 1, 32).to(self.device),
add_tensor=torch.rand(
32,
1,
32,
).to(self.device),
expect_success=False,
)
@ -156,7 +160,7 @@ class BinaryFoldingTemplate(TestCase):
nn.Conv2d,
pytorch_op,
False,
add_tensor=torch.tensor([2]).to(torch.int).to(self.device),
add_tensor=torch.tensor([2]).to(torch.float64).to(self.device),
expect_success=False,
)

View File

@ -233,6 +233,23 @@ def run_fw_bw_and_get_code(fn):
return run_and_get_code(run_with_backward)
def register_ops_with_aoti_compile(ns, op_set, dispatch_key, torch_compile_op_lib_impl):
for _op_name in op_set:
qualified_op_name = f"{ns}::{_op_name}"
_, overload_names = torch._C._jit_get_operation(qualified_op_name)
for overload_name in overload_names:
try:
reg_op_name = qualified_op_name
schema = torch._C._get_schema(qualified_op_name, overload_name)
if schema.overload_name:
reg_op_name = f"{qualified_op_name}.{schema.overload_name}"
torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821
reg_op_name, dispatch_key
)
except Exception as e:
continue
class TestCase(InductorTestCase):
@classmethod
def setUpClass(cls):
@ -751,6 +768,58 @@ class CommonTemplate:
),
)
@skipCUDAIf(not SM80OrLater, "Requires sm80")
def test_eager_aoti_support_out(self):
ns = "aten"
op_name = "clamp"
dispatch_key = "CPU"
device = "cpu"
if self.device.lower() == "cuda":
dispatch_key = "CUDA"
device = "cuda"
inp_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(1.0)
min_tensor = inp_tensor - 0.05
max_tensor = inp_tensor + 0.05
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
ref_out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(
-1
)
ref_tensor = torch.clamp(
max=max_tensor, min=min_tensor, input=inp_tensor, out=ref_out_tensor
)
ref_out_tensor1 = torch.randn(128, dtype=torch.float, device=device).fill_(
-1
)
ref_tensor1 = torch.clamp(
max=max_tensor, out=ref_out_tensor1, min=min_tensor, input=inp_tensor
)
register_ops_with_aoti_compile(
ns, [op_name], dispatch_key, torch_compile_op_lib_impl
)
res_out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(
-1
)
res_tensor = torch.clamp(
max=max_tensor, min=min_tensor, input=inp_tensor, out=res_out_tensor
)
self.assertEqual(ref_tensor, res_tensor)
self.assertEqual(ref_out_tensor, res_out_tensor)
res_out_tensor1 = torch.randn(128, dtype=torch.float, device=device).fill_(
-1
)
res_tensor1 = torch.clamp(
max=max_tensor, out=res_out_tensor1, min=min_tensor, input=inp_tensor
)
self.assertEqual(ref_tensor1, res_tensor1)
self.assertEqual(ref_out_tensor1, res_out_tensor1)
@skipCUDAIf(not SM80OrLater, "Requires sm80")
def test_eager_aoti_cache_hit(self):
ns = "aten"
@ -779,24 +848,13 @@ class CommonTemplate:
with mock.patch(
"torch._inductor.utils.aoti_compile_with_persistent_cache", None
):
qualified_op_name = f"{ns}::{op_name}"
_, overload_names = torch._C._jit_get_operation(qualified_op_name)
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
# Get ref result from eager
ref_value = getattr(torch.ops.aten, op_name)(input_tensor)
for overload_name in overload_names:
try:
reg_op_name = qualified_op_name
schema = torch._C._get_schema(qualified_op_name, overload_name)
if schema.overload_name:
reg_op_name = f"{qualified_op_name}.{schema.overload_name}"
torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821
reg_op_name, dispatch_key
)
except Exception as e:
continue
register_ops_with_aoti_compile(
ns, [op_name], dispatch_key, torch_compile_op_lib_impl
)
# Invoke the pre-compiled kernel and get result.
res_value = getattr(torch.ops.aten, op_name)(input_tensor)
@ -804,7 +862,7 @@ class CommonTemplate:
self.assertEqual(ref_value, res_value)
@skipCUDAIf(not SM80OrLater, "Requires sm80")
def test_aoti_compile_with_persistent_cache(self):
def test_eager_aoti_with_persistent_cache(self):
def fn(a):
return torch.abs(a)
@ -906,19 +964,9 @@ class CommonTemplate:
for scalar_value in scalar_values:
ref_values.append(torch.add(a, b, alpha=scalar_value))
qualified_op_name = f"{namespace_name}::{op_name}"
_, overload_names = torch._C._jit_get_operation(qualified_op_name)
for overload_name in overload_names:
try:
reg_op_name = qualified_op_name
schema = torch._C._get_schema(reg_op_name, overload_name)
if schema.overload_name:
reg_op_name = f"{reg_op_name}.{schema.overload_name}"
torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821
reg_op_name, dispatch_key
)
except Exception as e:
continue
register_ops_with_aoti_compile(
namespace_name, [op_name], dispatch_key, torch_compile_op_lib_impl
)
res_values = []
for scalar_value in scalar_values:
@ -928,8 +976,7 @@ class CommonTemplate:
self.assertEqual(ref_values, res_values)
@skipCUDAIf(not SM80OrLater, "Requires sm80")
def test_torch_compile_override_registration(self):
dynamic = False
def test_eager_aoti_override_registration(self):
namespace_name = "aten"
dispatch_key = "CPU"
device = torch.device("cpu")
@ -951,24 +998,10 @@ class CommonTemplate:
ref = opt_fn(x)
ref_array.append(ref)
def register_ops(op_set, dispatch_key, torch_compile_op_lib_impl):
for _op_name in op_set:
qualified_op_name = f"{namespace_name}::{_op_name}"
_, overload_names = torch._C._jit_get_operation(qualified_op_name)
for overload_name in overload_names:
try:
reg_op_name = qualified_op_name
schema = torch._C._get_schema(qualified_op_name, overload_name)
if schema.overload_name:
reg_op_name = f"{qualified_op_name}.{schema.overload_name}"
torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821
reg_op_name, dispatch_key
)
except Exception as e:
continue
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
register_ops(unary_op_set, dispatch_key, torch_compile_op_lib_impl)
register_ops_with_aoti_compile(
namespace_name, unary_op_set, dispatch_key, torch_compile_op_lib_impl
)
res_array = []
for unary_op_name in unary_op_set:
@ -985,7 +1018,9 @@ class CommonTemplate:
ref_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor)
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
register_ops(["clamp"], dispatch_key, torch_compile_op_lib_impl)
register_ops_with_aoti_compile(
namespace_name, ["clamp"], dispatch_key, torch_compile_op_lib_impl
)
res_with_min = torch.ops.aten.clamp(a, min_tensor)
res_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor)
self.assertEqual(ref_with_min, res_with_min)
@ -7816,6 +7851,95 @@ class CommonTemplate:
)
assertGeneratedKernelCountEqual(self, 0)
def test_avg_pool3d_backward(self):
def fn(a, b):
return aten.avg_pool3d_backward(
a,
b,
[2, 2, 2],
[2, 2, 2],
[0, 0, 0],
True,
False,
None,
)
self.common(
fn,
[
torch.randn([2, 4, 7, 7, 7]),
torch.randn([2, 4, 14, 14, 14]),
],
)
def test_avg_pool3d_backward2(self):
def fn(a, b):
return aten.avg_pool3d_backward(
a,
b,
[3, 3, 3],
[1, 1, 1],
[1, 1, 1],
True,
False,
None,
)
self.common(
fn,
[
torch.randn([1, 1, 20, 20, 15]),
torch.randn([1, 1, 20, 20, 15]),
],
)
def test_avg_pool3d_backward3(self):
def fn(a, b):
return aten.avg_pool3d_backward(
a,
b,
[1, 1, 1],
[2, 2, 2],
[0, 0, 0],
False,
False,
None,
)
torch._inductor.metrics.generated_kernel_count = 0
self.common(
fn,
[
torch.randn([1, 2016, 11, 11, 11]),
torch.randn([1, 2016, 21, 21, 21]),
],
)
assertGeneratedKernelCountEqual(self, 1)
def test_avg_pool3d_backward4(self):
def fn(a, b):
return aten.avg_pool3d_backward(
a,
b,
[13, 13, 13],
[1, 1, 1],
[0, 0, 0],
True,
False,
None,
)
torch._inductor.metrics.generated_kernel_count = 0
self.common(
fn,
[
torch.randn([1, 16, 12, 12, 12]),
torch.randn([1, 16, 24, 24, 24]),
],
check_lowp=False,
)
assertGeneratedKernelCountEqual(self, 0)
@config.patch(search_autotune_cache=False)
def test_mm_views(self):
def fn(a, b):

View File

@ -146,6 +146,7 @@ test_failures = {
"test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_avg_pool2d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_avg_pool3d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_baddbmm_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_bmm2_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_both_scalars_dynamic_shapes": TestFailure(("cpu", "cuda")),

View File

@ -8,6 +8,7 @@ import os
import sys
import unittest
from functools import partial
from typing import List
import torch
import torch.library
@ -369,6 +370,47 @@ class TestInductorDynamic(TestCase):
arg = torch.tensor(5, device=device)
self.assertEqual(f(arg), cf(arg))
@torch._dynamo.config.patch(
capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
)
@torch._inductor.config.patch(implicit_fallbacks=True)
def test_unbacked_save_for_backwards(self, device) -> None:
@torch.library.custom_op("_test::_cat", mutates_args=())
def _cat(t: torch.Tensor, ds: List[int]) -> torch.Tensor:
return t * t.new_ones([sum(ds)])
@torch.library.register_fake("_test::_cat")
def _cat_fake(t: torch.Tensor, ds: List[int]) -> torch.Tensor:
[torch._check_is_size(d) for d in ds]
return t.new_empty([sum(ds)])
def _cat_setup_context(ctx, inputs, output):
pass
def _cat_backward(ctx, grad):
return grad.sum(), None
torch.library.register_autograd(
"_test::_cat",
_cat_backward,
setup_context=_cat_setup_context,
)
def fn(t, sizes):
r = torch.ops._test._cat(t, sizes.tolist())
return r * t
t = torch.randn((), requires_grad=True, device=device)
sizes = torch.tensor([4, 8], dtype=torch.int64, device="cpu")
out = fn(t, sizes)
out.sum().backward()
expect = t.grad
t.grad = None
torch.compile(fn, backend="inductor", fullgraph=True, dynamic=True)(
t, sizes
).sum().backward()
self.assertEqual(t.grad, expect)
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_unbacked_reduction(self, device):
expect_fail = device == "cpu" and not IS_ARM64

View File

@ -2333,3 +2333,14 @@ def _save_pickle(obj: Any) -> bytes: ...
# Defined in torch/csrc/jit/runtime/static/init.cpp
def _jit_to_static_module(graph_or_module: Union[Graph,ScriptModule]) -> Any: ...
def _fuse_to_static_module(graph_or_module: Union[Graph,ScriptModule], min_size: _int) -> Any: ...
# Defined in torch/csrc/fx/node.cpp
class _NodeBase:
_erased: _bool
_prev: "_NodeBase"
_next: "_NodeBase"
class _NodeIter(Iterator):
def __init__(self, root: _NodeBase, reversed: _bool) -> None: ...
def __iter__(self) -> Iterator[_NodeBase]: ...
def __next__(self) -> _NodeBase: ...

View File

@ -752,7 +752,13 @@ class OutputGraph:
**options,
):
if is_dynamic_nn_module(target, self.root_tx.export):
return variables.UnspecializedNNModuleVariable(target, **options)
result = variables.UnspecializedNNModuleVariable(target, **options)
if not SideEffects.cls_supports_mutation_side_effects(type(target)):
# don't allow STORE_ATTR mutation with custom __setattr__
return result
return self.root_tx.output.side_effects.track_object_existing(
target, result
)
options = dict(options)
assert "source" in options

View File

@ -1128,6 +1128,19 @@ class VariableBuilder:
if mutation_guard.is_dynamic_nn_module(value, self.tx.export):
# created dynamically, don't specialize on it
self.install_guards(GuardBuilder.TYPE_MATCH)
if (
torch._dynamo.config.inline_inbuilt_nn_modules
and torch._inductor.config.freezing
and not torch.is_grad_enabled()
):
from ..decorators import mark_static_address
for p in value.parameters():
mark_static_address(p)
for b in value.buffers():
mark_static_address(b)
result = UnspecializedNNModuleVariable(value, source=self.source)
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
# don't allow STORE_ATTR mutation with custom __setattr__

View File

@ -201,11 +201,19 @@ def _unlift_graph(mod, gm, graph_signature):
outputs = list(gm.graph.nodes)[-1].args[0]
mutated_outputs = []
for out in outputs:
if out.name in graph_signature.buffers_to_mutate:
mutated_outputs.append(graph_signature.buffers_to_mutate[out.name])
else:
mutated_outputs.append(None)
buffer_mutations = graph_signature.buffers_to_mutate
user_input_mutations = graph_signature.user_inputs_to_mutate
output_tokens = graph_signature.output_tokens
for idx, out in enumerate(outputs):
value = None
if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens):
if out.name in buffer_mutations:
value = buffer_mutations[out.name]
elif out.name in user_input_mutations:
value = user_input_mutations[out.name]
mutated_outputs.append(value)
unlifted_gm = _unlift(
gm,

View File

@ -2155,7 +2155,6 @@ make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
# 4) Backwards (try py_impl'ing them) when fwd is written as a decomp
make_fallback(aten.avg_pool3d_backward)
make_fallback(aten.max_pool3d_with_indices_backward)
make_fallback(aten._adaptive_avg_pool2d_backward, require_dense)
make_fallback(aten._adaptive_avg_pool3d_backward)
@ -4034,11 +4033,32 @@ def pad_adaptive_loader(x, pad_val=0.0):
return load
def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns):
h_start_index_fn, w_start_index_fn = start_index_fns
h_end_index_fn, w_end_index_fn = end_index_fns
def compute_indices_adaptive_pooling(start_index, end_index, h_in, w_in, h_out, w_out):
h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in)
h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in)
def fn_sum(idx, loader):
w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in)
w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in)
return h_start_index, h_end_index, w_start_index, w_end_index
def _adaptive_pooling_fn(
start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
):
h_in, w_in = in_sizes
h_out, w_out = out_sizes
(
h_start_index_fn,
h_end_index_fn,
w_start_index_fn,
w_end_index_fn,
) = compute_indices_adaptive_pooling(
start_index, end_index, h_in, w_in, h_out, w_out
)
def fn(idx, loader):
*prefix, bh, bw = idx
h_start_index = h_start_index_fn(bh)
@ -4047,7 +4067,7 @@ def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns):
w_start_index = w_start_index_fn(bw)
w_end_index = w_end_index_fn(bw)
total = None
result = None
for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
val = loader(
prefix,
@ -4055,13 +4075,66 @@ def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns):
[h_start_index, w_start_index],
[h_end_index, w_end_index],
)
if total is None:
total = val
if result is None:
result = val
else:
total = ops.add(val, total)
return total
result = pooling_fn(val, result)
return result
return fn_sum
return fn
def _adaptive_pooling_fn_with_idx(
start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
):
h_in, w_in = in_sizes
h_out, w_out = out_sizes
(
h_start_index_fn,
h_end_index_fn,
w_start_index_fn,
w_end_index_fn,
) = compute_indices_adaptive_pooling(
start_index, end_index, h_in, w_in, h_out, w_out
)
def fn(idx, loader):
*prefix, bh, bw = idx
h_start_index = h_start_index_fn(bh)
h_end_index = h_end_index_fn(bh)
w_start_index = w_start_index_fn(bw)
w_end_index = w_end_index_fn(bw)
maxval = None
maxindex = None
for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
val = loader(
prefix,
[ih, iw],
[h_start_index, w_start_index],
[h_end_index, w_end_index],
)
index = ops.index_expr(
(h_start_index + ih) * w_in + w_start_index + iw, torch.int64
)
if maxindex is None:
maxindex = index
else:
maxindex = ops.where(ops.gt(val, maxval), index, maxindex)
if maxval is None:
maxval = val
else:
maxval = pooling_fn(val, maxval)
return maxindex
return fn
fallback_adaptive_avg_pool2d = fallback_handler(
@ -4099,27 +4172,24 @@ def _adaptive_avg_pool2d(x, output_size):
new_size = list(batch) + [h_out, w_out]
dtype = x.get_dtype()
window_size = h_kernel_max * w_kernel_max
if window_size > 25:
# Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
return fallback_adaptive_avg_pool2d(x, output_size)
def start_index(index, out_dim, inp_dim):
return FloorDiv((index * inp_dim), out_dim)
def end_index(index, out_dim, inp_dim):
return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in)
h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in)
w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in)
w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in)
window_size = h_kernel_max * w_kernel_max
if window_size > 25:
# Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
return fallback_adaptive_avg_pool2d(x, output_size)
fn_sum = _adaptive_pooling_idx_sum(
[h_kernel_max, w_kernel_max],
[h_start_index, w_start_index],
[h_end_index, w_end_index],
fn_sum = _adaptive_pooling_fn(
start_index=start_index,
end_index=end_index,
kernel_maxes=[h_kernel_max, w_kernel_max],
in_sizes=[h_in, w_in],
out_sizes=[h_out, w_out],
pooling_fn=ops.add,
)
ones_loader = pad_adaptive_loader(ones_like(x))
@ -4139,60 +4209,6 @@ def _adaptive_avg_pool2d(x, output_size):
return rv
def _adaptive_pooling_idx_max(kernel_maxes, in_sizes, out_sizes, return_index, loader):
# NOTE: There is some duplication between this and addaptive_avg_pool2d and max_pool2d
# Look into refactoring/deduplication after #116418 is merged.
h_in, w_in = in_sizes
h_out, w_out = out_sizes
def start_index(index, out_dim, inp_dim):
return FloorDiv((index * inp_dim), out_dim)
def end_index(index, out_dim, inp_dim):
return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
h_start_index_fn = functools.partial(start_index, out_dim=h_out, inp_dim=h_in)
h_end_index_fn = functools.partial(end_index, out_dim=h_out, inp_dim=h_in)
w_start_index_fn = functools.partial(start_index, out_dim=w_out, inp_dim=w_in)
w_end_index_fn = functools.partial(end_index, out_dim=w_out, inp_dim=w_in)
def fn_max(idx):
*prefix, bh, bw = idx
h_start_index = h_start_index_fn(bh)
h_end_index = h_end_index_fn(bh)
w_start_index = w_start_index_fn(bw)
w_end_index = w_end_index_fn(bw)
maxval = None
maxindex = None
for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
val = loader(
prefix,
[ih, iw],
[h_start_index, w_start_index],
[h_end_index, w_end_index],
)
index = ops.index_expr(
(h_start_index + ih) * w_in + w_start_index + iw, torch.int64
)
if return_index:
if maxindex is None:
maxindex = index
else:
maxindex = ops.where(ops.gt(val, maxval), index, maxindex)
if maxval is None:
maxval = val
else:
maxval = ops.maximum(val, maxval)
if return_index:
return maxindex
else:
return maxval
return fn_max
fallback_adaptive_max_pool2d = fallback_handler(
aten.adaptive_max_pool2d.default, add_to_fallback_set=False
)
@ -4245,32 +4261,46 @@ def adaptive_max_pool2d(x, output_size):
# Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
return fallback_adaptive_max_pool2d(x, output_size)
inner_func_max_val = _adaptive_pooling_idx_max(
def start_index(index, out_dim, inp_dim):
return FloorDiv((index * inp_dim), out_dim)
def end_index(index, out_dim, inp_dim):
return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
inner_func_max_val = _adaptive_pooling_fn(
start_index=start_index,
end_index=end_index,
kernel_maxes=[h_kernel_max, w_kernel_max],
in_sizes=[h_in, w_in],
out_sizes=[h_out, w_out],
return_index=False,
loader=pad_adaptive_loader(x, float("-inf")),
pooling_fn=ops.maximum,
)
inner_func_max_idx = _adaptive_pooling_idx_max(
inner_func_max_idx = _adaptive_pooling_fn_with_idx(
start_index=start_index,
end_index=end_index,
kernel_maxes=[h_kernel_max, w_kernel_max],
in_sizes=[h_in, w_in],
out_sizes=[h_out, w_out],
return_index=True,
loader=pad_adaptive_loader(x, float("-inf")),
pooling_fn=ops.maximum,
)
def inner_fn_max_val(idx):
return inner_func_max_val(idx, pad_adaptive_loader(x, float("-inf")))
def inner_fn_max_idx(idx):
return inner_func_max_idx(idx, pad_adaptive_loader(x, float("-inf")))
rv = Pointwise.create(
device=x.get_device(),
dtype=dtype,
inner_fn=inner_func_max_val,
inner_fn=inner_fn_max_val,
ranges=new_size,
)
ri = Pointwise.create(
device=x.get_device(),
dtype=torch.int64,
inner_fn=inner_func_max_idx,
inner_fn=inner_fn_max_idx,
ranges=new_size,
)
return rv, ri
@ -4400,16 +4430,13 @@ def upsample_nearest2d_backward(
def end_index(index, out_dim, inp_dim):
return start_index((index + 1), out_dim, inp_dim)
h_start_index = functools.partial(start_index, out_dim=out_h, inp_dim=inp_h)
h_end_index = functools.partial(end_index, out_dim=out_h, inp_dim=inp_h)
w_start_index = functools.partial(start_index, out_dim=out_w, inp_dim=inp_w)
w_end_index = functools.partial(end_index, out_dim=out_w, inp_dim=inp_w)
fn_sum = _adaptive_pooling_idx_sum(
[h_kernel_max, w_kernel_max],
[h_start_index, w_start_index],
[h_end_index, w_end_index],
fn_sum = _adaptive_pooling_fn(
start_index=start_index,
end_index=end_index,
kernel_maxes=[h_kernel_max, w_kernel_max],
in_sizes=[inp_h, inp_w],
out_sizes=[out_h, out_w],
pooling_fn=ops.add,
)
def fn(idx):
@ -4761,6 +4788,207 @@ def avg_pool2d_backward(
return rv
fallback_avg_pool3d_backward = fallback_handler(
aten.avg_pool3d_backward.default, add_to_fallback_set=False
)
@register_lowering(aten.avg_pool3d_backward, type_promotion_kind=None)
def avg_pool3d_backward(
grad_output,
x,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
divisor_override=None,
):
assert divisor_override is None or divisor_override != 0, "divisor must be not zero"
if not stride:
stride = kernel_size
if not padding:
padding = [0, 0, 0]
assert isinstance(grad_output, TensorBox)
assert isinstance(x, TensorBox)
assert len(kernel_size) == 3
assert len(stride) == 3
assert len(padding) == 3
assert len(x.get_size()) in (4, 5)
grad_output.realize_hint()
*batch, depth, height, width = x.get_size()
d_out, ceil_mode_d = pooling_size(depth, 0, kernel_size, stride, padding, ceil_mode)
h_out, ceil_mode_h = pooling_size(
height, 1, kernel_size, stride, padding, ceil_mode
)
w_out, ceil_mode_w = pooling_size(width, 2, kernel_size, stride, padding, ceil_mode)
grad_loader = grad_output.make_loader()
had_padding = any(padding) or ceil_mode_d or ceil_mode_h or ceil_mode_w
*_, pooled_depth, pooled_height, pooled_width = grad_output.get_size()
new_size = list(x.get_size())
dtype = x.get_dtype()
d_window_size, h_window_size, w_window_size = (
max(
max(d // stride[i] - max(0, (d - kernel_size[i]) // stride[i]), 1)
for d in range(kernel_size[i] * 2)
)
for i in range(3)
)
window_size = d_window_size * h_window_size * w_window_size
if window_size > 125:
# Kernel size too big. Results in hard-to-optimize Triton code.
return fallback_avg_pool3d_backward(
grad_output,
x,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
divisor_override,
)
def compute_pool_size_without_padding(pd, ph, pw):
stride_d, stride_h, stride_w = (ops.constant(s, torch.int32) for s in stride)
pad_d, pad_h, pad_w = (ops.constant(p, torch.int32) for p in padding)
kernel_d, kernel_h, kernel_w = (
ops.constant(k, torch.int32) for k in kernel_size
)
dstart, hstart, wstart = (
ops.sub(ops.mul(p, s), pad)
for p, s, pad in zip(
[pd, ph, pw], [stride_d, stride_h, stride_w], [pad_d, pad_h, pad_w]
)
)
dend, hend, wend = (
ops.minimum(
ops.add(start, k), ops.add(ops.index_expr(dim, torch.int32), pad)
)
for start, k, dim, pad in zip(
[dstart, hstart, wstart],
[kernel_d, kernel_h, kernel_w],
[depth, height, width],
[pad_d, pad_h, pad_w],
)
)
dstart, hstart, wstart = (
ops.maximum(start, ops.constant(0, torch.int32))
for start in [dstart, hstart, wstart]
)
dend, hend, wend = (
ops.minimum(end, ops.index_expr(dim, torch.int32))
for end, dim in zip([dend, hend, wend], [depth, height, width])
)
divide_factor = ops.mul(
ops.mul(ops.sub(dend, dstart), ops.sub(hend, hstart)), ops.sub(wend, wstart)
)
return divide_factor
def fn(idx):
*prefix, d, h, w = idx
d, h, w = (v + pad for v, pad in zip([d, h, w], padding))
pdstart, phstart, pwstart = (
ops.index_expr(FloorDiv(v - k + s, s), torch.int32)
for v, k, s in zip([d, h, w], kernel_size, stride)
)
pdend, phend, pwend = (
ops.index_expr(FloorDiv(v, s) + 1, torch.int32)
for v, s in zip([d, h, w], stride)
)
pdstart, phstart, pwstart = (
ops.maximum(pstart, ops.constant(0, torch.int32))
for pstart in [pdstart, phstart, pwstart]
)
pdend, phend, pwend = (
ops.minimum(pend, ops.index_expr(pooled_dim, torch.int32))
for pend, pooled_dim in zip(
[pdend, phend, pwend], [pooled_depth, pooled_height, pooled_width]
)
)
gradient = None
# Iterate over the 3D region to accumulate gradients
for pd_ in range(d_window_size):
for ph_ in range(h_window_size):
for pw_ in range(w_window_size):
pd, ph, pw = (
ops.add(pstart, ops.constant(p_, torch.int32))
for pstart, p_ in zip(
[pdstart, phstart, pwstart], [pd_, ph_, pw_]
)
)
if divisor_override is not None:
scale = divisor_override
elif count_include_pad or not had_padding:
scale = kernel_size[0] * kernel_size[1] * kernel_size[2]
else:
scale = compute_pool_size_without_padding(pd, ph, pw)
part = ops.truediv(
grad_loader(
[
*prefix,
ops.indirect_indexing(
ops.minimum(
pd, ops.sub(pdend, ops.constant(1, torch.int32))
),
pooled_depth,
check=False,
),
ops.indirect_indexing(
ops.minimum(
ph, ops.sub(phend, ops.constant(1, torch.int32))
),
pooled_height,
check=False,
),
ops.indirect_indexing(
ops.minimum(
pw, ops.sub(pwend, ops.constant(1, torch.int32))
),
pooled_width,
check=False,
),
]
),
scale,
)
mask = ops.and_(
ops.and_(ops.lt(pd, pdend), ops.lt(ph, phend)),
ops.lt(pw, pwend),
)
if gradient is None:
gradient = ops.where(
mask, part, ops.constant(0.0, torch.float32)
)
else:
gradient = ops.where(mask, ops.add(gradient, part), gradient)
assert gradient is not None
return gradient
rv = Pointwise.create(
device=grad_output.get_device(),
dtype=dtype,
inner_fn=fn,
ranges=new_size,
)
return rv
def _validate_reduction_axis(x, axis):
size = x.get_size()
if isinstance(axis, int):

View File

@ -67,6 +67,7 @@
#include <torch/csrc/cpu/Module.h>
#include <torch/csrc/dynamo/init.h>
#include <torch/csrc/functorch/init.h>
#include <torch/csrc/fx/node.h>
#include <torch/csrc/inductor/aoti_runner/pybind.h>
#include <torch/csrc/jit/python/init.h>
#include <torch/csrc/jit/python/python_ir.h>
@ -1602,6 +1603,8 @@ PyObject* initModule() {
THPDevice_init(module);
THPStream_init(module);
THPEvent_init(module);
NodeBase_init(module);
NodeIter_init(module);
ASSERT_TRUE(THPVariable_initModule(module));
ASSERT_TRUE(THPFunction_initModule(module));
ASSERT_TRUE(THPEngine_initModule(module));

torch/csrc/fx/node.cpp (new file, 257 lines)
View File

@ -0,0 +1,257 @@
#include <torch/csrc/fx/node.h>
#include <structmember.h>
#include <torch/csrc/utils/pythoncapi_compat.h>
////////////////////////////////
// NodeBase
///////////////////////////////
struct NodeBase {
PyObject_HEAD bool _erased;
NodeBase* _prev;
NodeBase* _next;
};
static PyObject* NodeBase_new(
PyTypeObject* type,
PyObject* args,
PyObject* kwds) {
PyObject* self = type->tp_alloc(type, 0);
if (!self)
return nullptr;
return self;
}
static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
self->_erased = false;
Py_INCREF(self);
self->_prev = self;
Py_INCREF(self);
self->_next = self;
return 0;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static struct PyMemberDef NodeBase_members[] = {
{"_erased", T_BOOL, offsetof(NodeBase, _erased), 0, nullptr},
{"_prev", T_OBJECT_EX, offsetof(NodeBase, _prev), 0, nullptr},
{"_next", T_OBJECT_EX, offsetof(NodeBase, _next), 0, nullptr},
{nullptr} /* Sentinel */
};
static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
Py_VISIT(self->_prev);
Py_VISIT(self->_next);
return 0;
}
static int NodeBase_clear(NodeBase* self) {
Py_CLEAR(self->_prev);
Py_CLEAR(self->_next);
return 0;
}
static void NodeBase_dealloc(PyObject* self) {
PyObject_GC_UnTrack(self);
(void)NodeBase_clear((NodeBase*)self);
Py_TYPE(self)->tp_free(self);
}
static PyTypeObject NodeBaseType = {
PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeBase", /* tp_name */
sizeof(NodeBase), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)NodeBase_dealloc, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
Py_TPFLAGS_HAVE_GC, /* tp_flags */
nullptr, /* tp_doc */
(traverseproc)NodeBase_traverse, /* tp_traverse */
(inquiry)NodeBase_clear, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
nullptr, /* tp_methods */
NodeBase_members, /* tp_members */
nullptr, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)NodeBase_init_fn, /* tp_init */
nullptr, /* tp_alloc */
NodeBase_new, /* tp_new */
};
bool NodeBase_init(PyObject* module) {
if (PyModule_AddType(module, &NodeBaseType) < 0) {
return false;
}
return true;
}
////////////////////////////////
// NodeIter
////////////////////////////////
struct NodeIter {
PyObject_HEAD bool _reversed;
NodeBase* _root;
NodeBase* _cur;
};
static PyObject* NodeIter_new(
PyTypeObject* type,
PyObject* args,
PyObject* kwds) {
PyObject* self = type->tp_alloc(type, 0);
if (!self)
return nullptr;
return self;
}
static int NodeIter_init_fn(NodeIter* self, PyObject* args, PyObject* kwargs) {
NodeBase* root = nullptr;
bool reversed = false;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr const char* keywords[] = {"root", "reversed", nullptr};
if (!PyArg_ParseTupleAndKeywords(
args,
kwargs,
"Ob|",
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<char**>(keywords),
&root,
&reversed)) {
return -1;
}
self->_reversed = reversed;
Py_INCREF(root);
self->_root = root;
Py_INCREF(root);
self->_cur = root;
return 0;
}
template <bool reversed>
PyObject* NodeIter_iternext_helper(NodeIter* self) {
// It should be possible to relax the ref counting here
// but in practice, we do not have that many _erased Nodes,
// so probably not worth it.
if constexpr (reversed) {
NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev);
Py_CLEAR(self->_cur);
self->_cur = prev;
} else {
NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next);
Py_CLEAR(self->_cur);
self->_cur = next;
}
while (self->_cur != self->_root) {
if (!self->_cur->_erased) {
Py_INCREF(self->_cur);
return (PyObject*)self->_cur;
}
if constexpr (reversed) {
NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev);
Py_CLEAR(self->_cur);
self->_cur = prev;
} else {
NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next);
Py_CLEAR(self->_cur);
self->_cur = next;
}
}
PyErr_SetNone(PyExc_StopIteration);
return nullptr;
}
PyObject* NodeIter_iternext(PyObject* _self) {
NodeIter* self = (NodeIter*)_self;
if (self->_reversed) {
return NodeIter_iternext_helper<true>(self);
} else {
return NodeIter_iternext_helper<false>(self);
}
}
static int NodeIter_traverse(NodeIter* self, visitproc visit, void* arg) {
Py_VISIT(self->_root);
Py_VISIT(self->_cur);
return 0;
}
static int NodeIter_clear(NodeIter* self) {
Py_CLEAR(self->_root);
Py_CLEAR(self->_cur);
return 0;
}
static void NodeIter_dealloc(PyObject* self) {
PyObject_GC_UnTrack(self);
(void)NodeIter_clear((NodeIter*)self);
Py_TYPE(self)->tp_free(self);
}
static PyTypeObject NodeIterType = {
PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeIter", /* tp_name */
sizeof(NodeIter), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)NodeIter_dealloc, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */
nullptr, /* tp_doc */
(traverseproc)NodeIter_traverse, /* tp_traverse */
(inquiry)NodeIter_clear, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
PyObject_SelfIter, /* tp_iter */
NodeIter_iternext, /* tp_iternext */
nullptr, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)NodeIter_init_fn, /* tp_init */
nullptr, /* tp_alloc */
NodeIter_new, /* tp_new */
};
bool NodeIter_init(PyObject* module) {
if (PyModule_AddType(module, &NodeIterType) < 0) {
return false;
}
return true;
}

torch/csrc/fx/node.h (new file, 6 lines)
View File

@ -0,0 +1,6 @@
#pragma once
#include <torch/csrc/python_headers.h>
bool NodeBase_init(PyObject* module);
bool NodeIter_init(PyObject* module);

View File

@ -86,6 +86,7 @@ def _insert_copy_for_mutations(
assert len(outputs) == len(mutated_outputs)
user_output_nodes = []
return_nodes_to_copy = {}
for return_node, mutated_node_name in zip(outputs, mutated_outputs):
if mutated_node_name is None:
user_output_nodes.append(return_node)
@ -101,13 +102,18 @@ def _insert_copy_for_mutations(
)
with gm.graph.inserting_before(output_node):
_ = gm.graph.call_function(
copy_node = gm.graph.call_function(
torch.ops.aten.copy_.default, (mutated_node, return_node)
)
return_nodes_to_copy[return_node] = copy_node
output_args = [
return_nodes_to_copy[node] if node in return_nodes_to_copy else node
for node in user_output_nodes
]
with gm.graph.inserting_before(output_node):
# Only return user outputs
new_output = gm.graph.output(tuple(user_output_nodes))
new_output = gm.graph.output(tuple(output_args))
output_node.replace_all_uses_with(new_output)
gm.graph.erase_node(output_node)

View File

@ -453,20 +453,21 @@ def free_unbacked_symbols(x):
# setup!
def is_symbol_binding_fx_node(node) -> Optional[sympy.Symbol]:
if (
node.op == "placeholder" and
"val" in node.meta and
isinstance(node.meta["val"], torch.SymInt) and
isinstance(node.meta["val"].node.expr, sympy.Symbol)
isinstance(node.meta["val"].node.expr, sympy.Symbol) and
(node.op == "placeholder" or free_unbacked_symbols(node.meta["val"].node.expr))
):
return node.meta["val"].node.expr
return None
def find_symbol_binding_fx_nodes(graph):
return {
node.meta["val"].node.expr: node
for node in graph.nodes
if is_symbol_binding_fx_node(node)
}
r = {}
# NB: Prefer first occurrence of symbol
for node in graph.nodes:
if is_symbol_binding_fx_node(node) and node.meta["val"].node.expr not in r:
r[node.meta["val"].node.expr] = node
return r
# Analogous to ConvertIntSource

View File

@ -4,6 +4,7 @@ from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_na
import torch.utils._pytree as pytree
from . import _pytree as fx_pytree
from ._compatibility import compatibility
from torch._C import _NodeIter
import os
import contextlib
@ -271,20 +272,8 @@ class _node_list:
return self.graph._len
def __iter__(self):
root = self.graph._root
if self.direction == "_next":
cur = root._next
while cur is not root:
if not cur._erased:
yield cur
cur = cur._next
else:
assert self.direction == "_prev"
cur = root._prev
while cur is not root:
if not cur._erased:
yield cur
cur = cur._prev
assert self.direction == "_prev" or self.direction == "_next"
yield from _NodeIter(self.graph._root, self.direction == "_prev")
def __reversed__(self):
return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev')

View File

@ -11,6 +11,7 @@ import inspect
import warnings
from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair
from .._ops import ops as _ops
from torch._C import _NodeBase
if TYPE_CHECKING:
from .graph import Graph
@ -139,7 +140,7 @@ def _format_arg(arg, max_list_len=float('inf')) -> str:
return str(arg)
@compatibility(is_backward_compatible=True)
class Node:
class Node(_NodeBase):
"""
``Node`` is the data structure that represents individual operations within
a ``Graph``. For the most part, Nodes represent callsites to various entities,
@ -197,6 +198,7 @@ class Node:
annotation of values in the generated code or for other types
of analyses.
"""
super().__init__()
self.graph = graph
self.name = name # unique name of value being created
assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root']
@ -235,9 +237,6 @@ class Node:
# does not produce a value, it's more of a notation. Thus, this value
# describes the type of args[0] in the ``return`` node.
self.type : Optional[Any] = return_type
self._prev = self
self._next = self
self._erased = False
self._sort_key: Any = ()
# If set, use this fn to print this node
@ -247,6 +246,22 @@ class Node:
# transformations. This metadata is preserved across node copies
self.meta : Dict[str, Any] = {}
def __getstate__(self):
state = self.__dict__.copy()
state["_erased"] = self._erased
state["_prev"] = self._prev
state["_next"] = self._next
return state
def __setstate__(self, state):
_erased = state.pop("_erased")
_prev = state.pop("_prev")
_next = state.pop("_next")
self.__dict__.update(state)
self._erased = _erased
self._prev = _prev
self._next = _next
@property
def next(self) -> 'Node':
"""