Compare commits

6 Commits

Author SHA1 Message Date
a206dcc79e fb_memcache: Move to fbcode from thirdparty (#128174)
Summary: The location and import path of the fb_memcache injections are changing.

Test Plan: Existing tests should pass.

Reviewed By: bertmaher, oulgen

Differential Revision: D57973772

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128174
Approved by: https://github.com/oulgen
2024-06-11 07:46:12 +00:00
f2d7f235a6 [dynamo][yolov3] Track UnspecializedNNModuleVariable for mutation (#128269)
Fixes https://github.com/pytorch/pytorch/issues/101168

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128269
Approved by: https://github.com/jansel
ghstack dependencies: #128295, #126578, #128268, #128254
2024-06-11 07:09:04 +00:00
402b289f3b Properly register parameter for binary folding test (#128356)
This PR properly registers the tensor used in the module's compute as a parameter. The bug was previously hidden because all tensors on nn modules were treated as constants by Dynamo; with NN module inlining, this is no longer the case.
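
As a minimal illustrative sketch (not code from this PR; the module and attribute names below are hypothetical), the distinction is between assigning a plain tensor attribute and registering an `nn.Parameter`:

```python
import torch
import torch.nn as nn

class ConvAdd(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        # Plain attribute: with NN module inlining, Dynamo no longer treats
        # this as a constant, so binary folding cannot fold it.
        self.plain_tensor = torch.rand(1, 8, 1, 1)
        # Registered parameter: appears in self.parameters() and can be
        # frozen and constant-folded.
        self.add_tensor = nn.Parameter(torch.rand(1, 8, 1, 1))

    def forward(self, x):
        return self.conv(x) + self.add_tensor

m = ConvAdd()
assert any(p is m.add_tensor for p in m.parameters())
assert all(p is not m.plain_tensor for p in m.parameters())
```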

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128356
Approved by: https://github.com/anijain2305
ghstack dependencies: #128355
2024-06-11 06:48:26 +00:00
a32157c67c Mark params static if inlining modules and freezing (#128355)
Today, inlining built-in nn modules is not compatible with parameter freezing. Freezing parameters and then constant-folding them through the graph relies on the assumption that they will not be graph inputs and will be static across calls to the same graph. When inlining built-in nn modules this assumption is broken, because we reuse the same graph for different instances of the same nn module. There are three options: 1) abandon constant folding, 2) create a dispatcher layer (like cudagraphs) which dispatches to the correct constant-folded graph for each distinct set of parameters, or 3) recompile.

This PR implements option 3 by introducing guards on the parameter pointers, since freezing is relatively rare and performance sensitive. Option 2 had many more unknowns, and option 1 is not viable due to the drop in performance.
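
A rough sketch of the approach (adapted from the VariableBuilder change in this stack; the helper function name below is made up):

```python
import torch
import torch._dynamo.config
import torch._inductor.config
from torch._dynamo.decorators import mark_static_address

def maybe_mark_params_static(mod: torch.nn.Module) -> None:
    # When built-in nn modules are inlined and Inductor freezing is enabled,
    # mark parameters and buffers as static so their storage pointers are
    # guarded on; a different module instance then triggers a recompile
    # instead of reusing a graph with stale constant-folded weights.
    if (
        torch._dynamo.config.inline_inbuilt_nn_modules
        and torch._inductor.config.freezing
        and not torch.is_grad_enabled()
    ):
        for p in mod.parameters():
            mark_static_address(p)
        for b in mod.buffers():
            mark_static_address(b)
```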

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128355
Approved by: https://github.com/anijain2305
2024-06-11 06:48:26 +00:00
24e7f29099 Lowering for avg_pool_3d_backward (Fixes:#127101) (#127722)
We implemented a lowering for the avg_pool3d_backward operation and created tests for it.
We ran some benchmarks and achieved the following results:

```
[-------------- avgpool_3d_backwards --------------]
                             |  Decomposed  |  Eager
16 threads: ----------------------------------------
      (3, 5, 400, 200, 200)  |     6061     |  11160
      (3, 5, 300, 200, 200)  |     4547     |   8372
      (3, 5, 200, 200, 200)  |     3032     |   5585
      (3, 5, 300, 300, 300)  |    10100     |  18840
      (3, 5, 100, 100, 100)  |      381     |    703
      (3, 5, 100, 300, 200)  |     2270     |   4190
      (8, 8, 128, 128, 128)  |     3397     |   6253
      (2, 3, 150, 150, 150)  |      520     |    947
      (1, 3, 128, 128, 128)  |      161     |    299
      (8, 16, 64, 64, 64)    |      851     |   1569
      (1, 1, 50, 50, 50)     |       17     |     11
      (3, 5, 20, 40, 40)     |       17     |     30
      (3, 5, 10, 20, 20)     |       17     |     11
      (1, 1, 10, 10, 10)     |       16     |     11
      (3, 5, 5, 10, 10)      |       17     |     11
      (3, 5, 2, 5, 5)        |       17     |     11
```
These were run on an RTX 3050, so we were not able to allocate larger tensors due to memory limitations.
We believe it would be beneficial to benchmark this on more recent hardware, just to check if the performance holds up with larger sizes.
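
For reference, a hypothetical way to reproduce a comparable measurement (the exact benchmark script is not included here; the shape, CUDA device, and timing setup below are assumptions):

```python
import torch
import torch.nn.functional as F
from torch.utils.benchmark import Timer

def bench_avg_pool3d_backward(compiled: bool, shape=(3, 5, 100, 100, 100)):
    # Eager path hits the ATen kernel; the compiled path exercises the
    # Inductor lowering of avg_pool3d_backward added in this PR.
    x = torch.randn(shape, device="cuda", requires_grad=True)
    pool = lambda t: F.avg_pool3d(t, kernel_size=2)
    if compiled:
        pool = torch.compile(pool)
    y = pool(x)
    grad = torch.ones_like(y)
    y.backward(grad, retain_graph=True)  # warm-up / compile
    timer = Timer(
        stmt="y.backward(grad, retain_graph=True)",
        globals={"y": y, "grad": grad},
    )
    return timer.timeit(20).median

print("eager:   ", bench_avg_pool3d_backward(compiled=False))
print("compiled:", bench_avg_pool3d_backward(compiled=True))
```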

We also refactored code from adaptive_avg_pool2d and adaptive_max_pool2d to reduce code duplication. We diffed the generated kernels and they are identical.

Fixes #127101

Co-authored-by: Martim Mendes <martimccmendes@tecnico.ulisboa.pt>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127722
Approved by: https://github.com/jansel
2024-06-11 06:39:04 +00:00
5b5d269d34 Speed up fx graph iteration by implementing it in C++ (#128288)
Before this change
```
python benchmarks/dynamo/microbenchmarks/fx_microbenchmarks.py
iterating over 100000000 FX nodes took 19.5s (5132266 nodes/s)
```

After this change
```
python benchmarks/dynamo/microbenchmarks/fx_microbenchmarks.py
iterating over 100000000 FX nodes took 3.4s (29114001 nodes/s)
```

5.7x improvement
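
A rough sketch of what that microbenchmark exercises (assumed; see benchmarks/dynamo/microbenchmarks/fx_microbenchmarks.py for the real script): building an FX graph and repeatedly iterating over `graph.nodes`, the path that now goes through the C++ `_NodeIter`:

```python
import time
import torch
import torch.fx as fx

def make_graph_module(n_ops: int = 1000) -> fx.GraphModule:
    g = fx.Graph()
    x = g.placeholder("x")
    for _ in range(n_ops):
        x = g.call_function(torch.relu, (x,))
    g.output(x)
    return fx.GraphModule(torch.nn.Module(), g)

gm = make_graph_module()
start = time.perf_counter()
count = 0
for _ in range(1000):
    for node in gm.graph.nodes:  # iteration path sped up by this change
        count += 1
elapsed = time.perf_counter() - start
print(f"iterating over {count} FX nodes took {elapsed:.1f}s ({count / elapsed:.0f} nodes/s)")
```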

Differential Revision: [D58343997](https://our.internmc.facebook.com/intern/diff/D58343997)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128288
Approved by: https://github.com/jansel, https://github.com/albanD
2024-06-11 05:48:31 +00:00
33 changed files with 769 additions and 142 deletions

@ -378,4 +378,4 @@ vision_maskrcnn,pass,17
yolov3,pass,2
yolov3,pass,0

@ -286,4 +286,4 @@ vision_maskrcnn,pass,34
yolov3,pass,9
yolov3,fail_accuracy,8

@ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0
yolov3,fail_to_run,0
yolov3,pass,0

@ -338,4 +338,4 @@ vision_maskrcnn,pass,28
yolov3,pass,2
yolov3,pass,0

@ -338,4 +338,4 @@ vision_maskrcnn,pass,28
yolov3,pass,2
yolov3,pass,0

@ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0
yolov3,fail_to_run,0
yolov3,pass,0

@ -374,4 +374,4 @@ vision_maskrcnn,pass,17
yolov3,pass,2
yolov3,pass,0

@ -282,4 +282,4 @@ vision_maskrcnn,pass,34
yolov3,pass,9
yolov3,fail_accuracy,8

@ -298,4 +298,4 @@ vision_maskrcnn,pass,28
yolov3,pass,2
yolov3,pass,0

@ -374,4 +374,4 @@ vision_maskrcnn,pass,17
yolov3,pass,2
yolov3,pass,0

@ -282,4 +282,4 @@ vision_maskrcnn,pass,34
yolov3,pass,9
yolov3,pass,8

@ -378,4 +378,4 @@ vision_maskrcnn,pass,17
yolov3,pass,2
yolov3,pass,0

@ -286,4 +286,4 @@ vision_maskrcnn,pass,34
yolov3,pass,9
yolov3,pass,8

@ -378,4 +378,4 @@ vision_maskrcnn,pass,17
yolov3,pass,2
yolov3,pass,0

@ -286,4 +286,4 @@ vision_maskrcnn,pass,34
yolov3,pass,9
yolov3,pass,8

@ -827,6 +827,7 @@ libtorch_python_core_sources = [
"torch/csrc/dynamo/guards.cpp",
"torch/csrc/dynamo/init.cpp",
"torch/csrc/functorch/init.cpp",
"torch/csrc/fx/node.cpp",
"torch/csrc/mps/Module.cpp",
"torch/csrc/mtia/Module.cpp",
"torch/csrc/inductor/aoti_runner/pybind.cpp",

@ -56,7 +56,7 @@ class BinaryFoldingTemplate(TestCase):
self.use_scalar = scalar
tensor_size = [1 for _ in range(self.conv.weight.ndim)]
tensor_size[1] = self.conv.weight.size(0)
self.tensor = (
self.tensor = torch.nn.Parameter(
add_tensor
if add_tensor is not None
else torch.rand(tensor_size).to(device)
@ -136,7 +136,11 @@ class BinaryFoldingTemplate(TestCase):
nn.Conv2d,
pytorch_op,
False,
add_tensor=torch.rand(32, 1, 32).to(self.device),
add_tensor=torch.rand(
32,
1,
32,
).to(self.device),
expect_success=False,
)
@ -156,7 +160,7 @@ class BinaryFoldingTemplate(TestCase):
nn.Conv2d,
pytorch_op,
False,
add_tensor=torch.tensor([2]).to(torch.int).to(self.device),
add_tensor=torch.tensor([2]).to(torch.float64).to(self.device),
expect_success=False,
)

@ -195,7 +195,7 @@ class TestFxGraphCache(TestCase):
num_put += 1
cache_module = (
"triton.runtime.fb_memcache.FbMemcacheRemoteFxGraphCacheBackend"
"triton.fb.fb_memcache.FbMemcacheRemoteFxGraphCacheBackend"
if config.is_fbcode()
else "torch._inductor.remote_cache.RedisRemoteCacheBackend"
)

@ -267,7 +267,7 @@ class TestMaxAutotune(TestCase):
num_put += 1
cache_module = (
"triton.runtime.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend"
"triton.fb.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend"
if config.is_fbcode()
else "torch._inductor.remote_cache.RedisRemoteCacheBackend"
)

@ -7851,6 +7851,95 @@ class CommonTemplate:
)
assertGeneratedKernelCountEqual(self, 0)
def test_avg_pool3d_backward(self):
def fn(a, b):
return aten.avg_pool3d_backward(
a,
b,
[2, 2, 2],
[2, 2, 2],
[0, 0, 0],
True,
False,
None,
)
self.common(
fn,
[
torch.randn([2, 4, 7, 7, 7]),
torch.randn([2, 4, 14, 14, 14]),
],
)
def test_avg_pool3d_backward2(self):
def fn(a, b):
return aten.avg_pool3d_backward(
a,
b,
[3, 3, 3],
[1, 1, 1],
[1, 1, 1],
True,
False,
None,
)
self.common(
fn,
[
torch.randn([1, 1, 20, 20, 15]),
torch.randn([1, 1, 20, 20, 15]),
],
)
def test_avg_pool3d_backward3(self):
def fn(a, b):
return aten.avg_pool3d_backward(
a,
b,
[1, 1, 1],
[2, 2, 2],
[0, 0, 0],
False,
False,
None,
)
torch._inductor.metrics.generated_kernel_count = 0
self.common(
fn,
[
torch.randn([1, 2016, 11, 11, 11]),
torch.randn([1, 2016, 21, 21, 21]),
],
)
assertGeneratedKernelCountEqual(self, 1)
def test_avg_pool3d_backward4(self):
def fn(a, b):
return aten.avg_pool3d_backward(
a,
b,
[13, 13, 13],
[1, 1, 1],
[0, 0, 0],
True,
False,
None,
)
torch._inductor.metrics.generated_kernel_count = 0
self.common(
fn,
[
torch.randn([1, 16, 12, 12, 12]),
torch.randn([1, 16, 24, 24, 24]),
],
check_lowp=False,
)
assertGeneratedKernelCountEqual(self, 0)
@config.patch(search_autotune_cache=False)
def test_mm_views(self):
def fn(a, b):

@ -146,6 +146,7 @@ test_failures = {
"test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_avg_pool2d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_avg_pool3d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_baddbmm_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_bmm2_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_both_scalars_dynamic_shapes": TestFailure(("cpu", "cuda")),

@ -2333,3 +2333,14 @@ def _save_pickle(obj: Any) -> bytes: ...
# Defined in torch/csrc/jit/runtime/static/init.cpp
def _jit_to_static_module(graph_or_module: Union[Graph,ScriptModule]) -> Any: ...
def _fuse_to_static_module(graph_or_module: Union[Graph,ScriptModule], min_size: _int) -> Any: ...
# Defined in torch/csrc/fx/node.cpp
class _NodeBase:
_erased: _bool
_prev: "_NodeBase"
_next: "_NodeBase"
class _NodeIter(Iterator):
def __init__(self, root: _NodeBase, reversed: _bool) -> None: ...
def __iter__(self) -> Iterator[_NodeBase]: ...
def __next__(self) -> _NodeBase: ...

@ -752,7 +752,13 @@ class OutputGraph:
**options,
):
if is_dynamic_nn_module(target, self.root_tx.export):
return variables.UnspecializedNNModuleVariable(target, **options)
result = variables.UnspecializedNNModuleVariable(target, **options)
if not SideEffects.cls_supports_mutation_side_effects(type(target)):
# don't allow STORE_ATTR mutation with custom __setattr__
return result
return self.root_tx.output.side_effects.track_object_existing(
target, result
)
options = dict(options)
assert "source" in options

@ -1128,6 +1128,19 @@ class VariableBuilder:
if mutation_guard.is_dynamic_nn_module(value, self.tx.export):
# created dynamically, don't specialize on it
self.install_guards(GuardBuilder.TYPE_MATCH)
if (
torch._dynamo.config.inline_inbuilt_nn_modules
and torch._inductor.config.freezing
and not torch.is_grad_enabled()
):
from ..decorators import mark_static_address
for p in value.parameters():
mark_static_address(p)
for b in value.buffers():
mark_static_address(b)
result = UnspecializedNNModuleVariable(value, source=self.source)
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
# don't allow STORE_ATTR mutation with custom __setattr__

@ -1021,7 +1021,7 @@ class FxGraphCache:
cache_id = "fx-graph-v1"
try:
if config.is_fbcode():
from triton.runtime.fb_memcache import (
from triton.fb.fb_memcache import (
FbMemcacheRemoteFxGraphCacheBackend,
)

@ -400,7 +400,7 @@ def should_use_remote_fx_graph_cache():
return False
try:
from triton.runtime.fb_memcache import MEMCACHE_VERSION
from triton.fb.fb_memcache import MEMCACHE_VERSION
except ModuleNotFoundError:
return False

@ -2155,7 +2155,6 @@ make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
# 4) Backwards (try py_impl'ing them) when fwd is written as a decomp
make_fallback(aten.avg_pool3d_backward)
make_fallback(aten.max_pool3d_with_indices_backward)
make_fallback(aten._adaptive_avg_pool2d_backward, require_dense)
make_fallback(aten._adaptive_avg_pool3d_backward)
@ -4034,11 +4033,32 @@ def pad_adaptive_loader(x, pad_val=0.0):
return load
def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns):
h_start_index_fn, w_start_index_fn = start_index_fns
h_end_index_fn, w_end_index_fn = end_index_fns
def compute_indices_adaptive_pooling(start_index, end_index, h_in, w_in, h_out, w_out):
h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in)
h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in)
def fn_sum(idx, loader):
w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in)
w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in)
return h_start_index, h_end_index, w_start_index, w_end_index
def _adaptive_pooling_fn(
start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
):
h_in, w_in = in_sizes
h_out, w_out = out_sizes
(
h_start_index_fn,
h_end_index_fn,
w_start_index_fn,
w_end_index_fn,
) = compute_indices_adaptive_pooling(
start_index, end_index, h_in, w_in, h_out, w_out
)
def fn(idx, loader):
*prefix, bh, bw = idx
h_start_index = h_start_index_fn(bh)
@ -4047,7 +4067,7 @@ def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns):
w_start_index = w_start_index_fn(bw)
w_end_index = w_end_index_fn(bw)
total = None
result = None
for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
val = loader(
prefix,
@ -4055,13 +4075,66 @@ def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns):
[h_start_index, w_start_index],
[h_end_index, w_end_index],
)
if total is None:
total = val
if result is None:
result = val
else:
total = ops.add(val, total)
return total
result = pooling_fn(val, result)
return result
return fn_sum
return fn
def _adaptive_pooling_fn_with_idx(
start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
):
h_in, w_in = in_sizes
h_out, w_out = out_sizes
(
h_start_index_fn,
h_end_index_fn,
w_start_index_fn,
w_end_index_fn,
) = compute_indices_adaptive_pooling(
start_index, end_index, h_in, w_in, h_out, w_out
)
def fn(idx, loader):
*prefix, bh, bw = idx
h_start_index = h_start_index_fn(bh)
h_end_index = h_end_index_fn(bh)
w_start_index = w_start_index_fn(bw)
w_end_index = w_end_index_fn(bw)
maxval = None
maxindex = None
for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
val = loader(
prefix,
[ih, iw],
[h_start_index, w_start_index],
[h_end_index, w_end_index],
)
index = ops.index_expr(
(h_start_index + ih) * w_in + w_start_index + iw, torch.int64
)
if maxindex is None:
maxindex = index
else:
maxindex = ops.where(ops.gt(val, maxval), index, maxindex)
if maxval is None:
maxval = val
else:
maxval = pooling_fn(val, maxval)
return maxindex
return fn
fallback_adaptive_avg_pool2d = fallback_handler(
@ -4099,27 +4172,24 @@ def _adaptive_avg_pool2d(x, output_size):
new_size = list(batch) + [h_out, w_out]
dtype = x.get_dtype()
window_size = h_kernel_max * w_kernel_max
if window_size > 25:
# Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
return fallback_adaptive_avg_pool2d(x, output_size)
def start_index(index, out_dim, inp_dim):
return FloorDiv((index * inp_dim), out_dim)
def end_index(index, out_dim, inp_dim):
return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in)
h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in)
w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in)
w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in)
window_size = h_kernel_max * w_kernel_max
if window_size > 25:
# Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
return fallback_adaptive_avg_pool2d(x, output_size)
fn_sum = _adaptive_pooling_idx_sum(
[h_kernel_max, w_kernel_max],
[h_start_index, w_start_index],
[h_end_index, w_end_index],
fn_sum = _adaptive_pooling_fn(
start_index=start_index,
end_index=end_index,
kernel_maxes=[h_kernel_max, w_kernel_max],
in_sizes=[h_in, w_in],
out_sizes=[h_out, w_out],
pooling_fn=ops.add,
)
ones_loader = pad_adaptive_loader(ones_like(x))
@ -4139,60 +4209,6 @@ def _adaptive_avg_pool2d(x, output_size):
return rv
def _adaptive_pooling_idx_max(kernel_maxes, in_sizes, out_sizes, return_index, loader):
# NOTE: There is some duplication between this and addaptive_avg_pool2d and max_pool2d
# Look into refactoring/deduplication after #116418 is merged.
h_in, w_in = in_sizes
h_out, w_out = out_sizes
def start_index(index, out_dim, inp_dim):
return FloorDiv((index * inp_dim), out_dim)
def end_index(index, out_dim, inp_dim):
return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
h_start_index_fn = functools.partial(start_index, out_dim=h_out, inp_dim=h_in)
h_end_index_fn = functools.partial(end_index, out_dim=h_out, inp_dim=h_in)
w_start_index_fn = functools.partial(start_index, out_dim=w_out, inp_dim=w_in)
w_end_index_fn = functools.partial(end_index, out_dim=w_out, inp_dim=w_in)
def fn_max(idx):
*prefix, bh, bw = idx
h_start_index = h_start_index_fn(bh)
h_end_index = h_end_index_fn(bh)
w_start_index = w_start_index_fn(bw)
w_end_index = w_end_index_fn(bw)
maxval = None
maxindex = None
for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
val = loader(
prefix,
[ih, iw],
[h_start_index, w_start_index],
[h_end_index, w_end_index],
)
index = ops.index_expr(
(h_start_index + ih) * w_in + w_start_index + iw, torch.int64
)
if return_index:
if maxindex is None:
maxindex = index
else:
maxindex = ops.where(ops.gt(val, maxval), index, maxindex)
if maxval is None:
maxval = val
else:
maxval = ops.maximum(val, maxval)
if return_index:
return maxindex
else:
return maxval
return fn_max
fallback_adaptive_max_pool2d = fallback_handler(
aten.adaptive_max_pool2d.default, add_to_fallback_set=False
)
@ -4245,32 +4261,46 @@ def adaptive_max_pool2d(x, output_size):
# Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
return fallback_adaptive_max_pool2d(x, output_size)
inner_func_max_val = _adaptive_pooling_idx_max(
def start_index(index, out_dim, inp_dim):
return FloorDiv((index * inp_dim), out_dim)
def end_index(index, out_dim, inp_dim):
return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
inner_func_max_val = _adaptive_pooling_fn(
start_index=start_index,
end_index=end_index,
kernel_maxes=[h_kernel_max, w_kernel_max],
in_sizes=[h_in, w_in],
out_sizes=[h_out, w_out],
return_index=False,
loader=pad_adaptive_loader(x, float("-inf")),
pooling_fn=ops.maximum,
)
inner_func_max_idx = _adaptive_pooling_idx_max(
inner_func_max_idx = _adaptive_pooling_fn_with_idx(
start_index=start_index,
end_index=end_index,
kernel_maxes=[h_kernel_max, w_kernel_max],
in_sizes=[h_in, w_in],
out_sizes=[h_out, w_out],
return_index=True,
loader=pad_adaptive_loader(x, float("-inf")),
pooling_fn=ops.maximum,
)
def inner_fn_max_val(idx):
return inner_func_max_val(idx, pad_adaptive_loader(x, float("-inf")))
def inner_fn_max_idx(idx):
return inner_func_max_idx(idx, pad_adaptive_loader(x, float("-inf")))
rv = Pointwise.create(
device=x.get_device(),
dtype=dtype,
inner_fn=inner_func_max_val,
inner_fn=inner_fn_max_val,
ranges=new_size,
)
ri = Pointwise.create(
device=x.get_device(),
dtype=torch.int64,
inner_fn=inner_func_max_idx,
inner_fn=inner_fn_max_idx,
ranges=new_size,
)
return rv, ri
@ -4400,16 +4430,13 @@ def upsample_nearest2d_backward(
def end_index(index, out_dim, inp_dim):
return start_index((index + 1), out_dim, inp_dim)
h_start_index = functools.partial(start_index, out_dim=out_h, inp_dim=inp_h)
h_end_index = functools.partial(end_index, out_dim=out_h, inp_dim=inp_h)
w_start_index = functools.partial(start_index, out_dim=out_w, inp_dim=inp_w)
w_end_index = functools.partial(end_index, out_dim=out_w, inp_dim=inp_w)
fn_sum = _adaptive_pooling_idx_sum(
[h_kernel_max, w_kernel_max],
[h_start_index, w_start_index],
[h_end_index, w_end_index],
fn_sum = _adaptive_pooling_fn(
start_index=start_index,
end_index=end_index,
kernel_maxes=[h_kernel_max, w_kernel_max],
in_sizes=[inp_h, inp_w],
out_sizes=[out_h, out_w],
pooling_fn=ops.add,
)
def fn(idx):
@ -4761,6 +4788,207 @@ def avg_pool2d_backward(
return rv
fallback_avg_pool3d_backward = fallback_handler(
aten.avg_pool3d_backward.default, add_to_fallback_set=False
)
@register_lowering(aten.avg_pool3d_backward, type_promotion_kind=None)
def avg_pool3d_backward(
grad_output,
x,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
divisor_override=None,
):
assert divisor_override is None or divisor_override != 0, "divisor must be not zero"
if not stride:
stride = kernel_size
if not padding:
padding = [0, 0, 0]
assert isinstance(grad_output, TensorBox)
assert isinstance(x, TensorBox)
assert len(kernel_size) == 3
assert len(stride) == 3
assert len(padding) == 3
assert len(x.get_size()) in (4, 5)
grad_output.realize_hint()
*batch, depth, height, width = x.get_size()
d_out, ceil_mode_d = pooling_size(depth, 0, kernel_size, stride, padding, ceil_mode)
h_out, ceil_mode_h = pooling_size(
height, 1, kernel_size, stride, padding, ceil_mode
)
w_out, ceil_mode_w = pooling_size(width, 2, kernel_size, stride, padding, ceil_mode)
grad_loader = grad_output.make_loader()
had_padding = any(padding) or ceil_mode_d or ceil_mode_h or ceil_mode_w
*_, pooled_depth, pooled_height, pooled_width = grad_output.get_size()
new_size = list(x.get_size())
dtype = x.get_dtype()
d_window_size, h_window_size, w_window_size = (
max(
max(d // stride[i] - max(0, (d - kernel_size[i]) // stride[i]), 1)
for d in range(kernel_size[i] * 2)
)
for i in range(3)
)
window_size = d_window_size * h_window_size * w_window_size
if window_size > 125:
# Kernel size too big. Results in hard-to-optimize Triton code.
return fallback_avg_pool3d_backward(
grad_output,
x,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
divisor_override,
)
def compute_pool_size_without_padding(pd, ph, pw):
stride_d, stride_h, stride_w = (ops.constant(s, torch.int32) for s in stride)
pad_d, pad_h, pad_w = (ops.constant(p, torch.int32) for p in padding)
kernel_d, kernel_h, kernel_w = (
ops.constant(k, torch.int32) for k in kernel_size
)
dstart, hstart, wstart = (
ops.sub(ops.mul(p, s), pad)
for p, s, pad in zip(
[pd, ph, pw], [stride_d, stride_h, stride_w], [pad_d, pad_h, pad_w]
)
)
dend, hend, wend = (
ops.minimum(
ops.add(start, k), ops.add(ops.index_expr(dim, torch.int32), pad)
)
for start, k, dim, pad in zip(
[dstart, hstart, wstart],
[kernel_d, kernel_h, kernel_w],
[depth, height, width],
[pad_d, pad_h, pad_w],
)
)
dstart, hstart, wstart = (
ops.maximum(start, ops.constant(0, torch.int32))
for start in [dstart, hstart, wstart]
)
dend, hend, wend = (
ops.minimum(end, ops.index_expr(dim, torch.int32))
for end, dim in zip([dend, hend, wend], [depth, height, width])
)
divide_factor = ops.mul(
ops.mul(ops.sub(dend, dstart), ops.sub(hend, hstart)), ops.sub(wend, wstart)
)
return divide_factor
def fn(idx):
*prefix, d, h, w = idx
d, h, w = (v + pad for v, pad in zip([d, h, w], padding))
pdstart, phstart, pwstart = (
ops.index_expr(FloorDiv(v - k + s, s), torch.int32)
for v, k, s in zip([d, h, w], kernel_size, stride)
)
pdend, phend, pwend = (
ops.index_expr(FloorDiv(v, s) + 1, torch.int32)
for v, s in zip([d, h, w], stride)
)
pdstart, phstart, pwstart = (
ops.maximum(pstart, ops.constant(0, torch.int32))
for pstart in [pdstart, phstart, pwstart]
)
pdend, phend, pwend = (
ops.minimum(pend, ops.index_expr(pooled_dim, torch.int32))
for pend, pooled_dim in zip(
[pdend, phend, pwend], [pooled_depth, pooled_height, pooled_width]
)
)
gradient = None
# Iterate over the 3D region to accumulate gradients
for pd_ in range(d_window_size):
for ph_ in range(h_window_size):
for pw_ in range(w_window_size):
pd, ph, pw = (
ops.add(pstart, ops.constant(p_, torch.int32))
for pstart, p_ in zip(
[pdstart, phstart, pwstart], [pd_, ph_, pw_]
)
)
if divisor_override is not None:
scale = divisor_override
elif count_include_pad or not had_padding:
scale = kernel_size[0] * kernel_size[1] * kernel_size[2]
else:
scale = compute_pool_size_without_padding(pd, ph, pw)
part = ops.truediv(
grad_loader(
[
*prefix,
ops.indirect_indexing(
ops.minimum(
pd, ops.sub(pdend, ops.constant(1, torch.int32))
),
pooled_depth,
check=False,
),
ops.indirect_indexing(
ops.minimum(
ph, ops.sub(phend, ops.constant(1, torch.int32))
),
pooled_height,
check=False,
),
ops.indirect_indexing(
ops.minimum(
pw, ops.sub(pwend, ops.constant(1, torch.int32))
),
pooled_width,
check=False,
),
]
),
scale,
)
mask = ops.and_(
ops.and_(ops.lt(pd, pdend), ops.lt(ph, phend)),
ops.lt(pw, pwend),
)
if gradient is None:
gradient = ops.where(
mask, part, ops.constant(0.0, torch.float32)
)
else:
gradient = ops.where(mask, ops.add(gradient, part), gradient)
assert gradient is not None
return gradient
rv = Pointwise.create(
device=grad_output.get_device(),
dtype=dtype,
inner_fn=fn,
ranges=new_size,
)
return rv
def _validate_reduction_axis(x, axis):
size = x.get_size()
if isinstance(axis, int):

@ -1031,7 +1031,7 @@ def should_use_remote_autotune_cache(inductor_meta):
if inductor_meta.get("is_hip"):
return False
from triton.runtime.fb_memcache import MEMCACHE_VERSION
from triton.fb.fb_memcache import MEMCACHE_VERSION
return MEMCACHE_VERSION >= torch._utils_internal.justknobs_getval_int(
"pytorch/remote_cache:autotune_memcache_version"
@ -1075,8 +1075,12 @@ def cached_autotune(
try:
if inductor_meta.get("is_fbcode"):
remote_cache = triton.runtime.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend(
key
import triton.fb.fb_memcache
remote_cache = (
triton.fb.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend(
key
)
)
else:
from torch._inductor.remote_cache import RedisRemoteCacheBackend

@ -67,6 +67,7 @@
#include <torch/csrc/cpu/Module.h>
#include <torch/csrc/dynamo/init.h>
#include <torch/csrc/functorch/init.h>
#include <torch/csrc/fx/node.h>
#include <torch/csrc/inductor/aoti_runner/pybind.h>
#include <torch/csrc/jit/python/init.h>
#include <torch/csrc/jit/python/python_ir.h>
@ -1602,6 +1603,8 @@ PyObject* initModule() {
THPDevice_init(module);
THPStream_init(module);
THPEvent_init(module);
NodeBase_init(module);
NodeIter_init(module);
ASSERT_TRUE(THPVariable_initModule(module));
ASSERT_TRUE(THPFunction_initModule(module));
ASSERT_TRUE(THPEngine_initModule(module));

torch/csrc/fx/node.cpp (new file, 257 lines)
@ -0,0 +1,257 @@
#include <torch/csrc/fx/node.h>
#include <structmember.h>
#include <torch/csrc/utils/pythoncapi_compat.h>
////////////////////////////////
// NodeBase
///////////////////////////////
struct NodeBase {
PyObject_HEAD bool _erased;
NodeBase* _prev;
NodeBase* _next;
};
static PyObject* NodeBase_new(
PyTypeObject* type,
PyObject* args,
PyObject* kwds) {
PyObject* self = type->tp_alloc(type, 0);
if (!self)
return nullptr;
return self;
}
static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
self->_erased = false;
Py_INCREF(self);
self->_prev = self;
Py_INCREF(self);
self->_next = self;
return 0;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static struct PyMemberDef NodeBase_members[] = {
{"_erased", T_BOOL, offsetof(NodeBase, _erased), 0, nullptr},
{"_prev", T_OBJECT_EX, offsetof(NodeBase, _prev), 0, nullptr},
{"_next", T_OBJECT_EX, offsetof(NodeBase, _next), 0, nullptr},
{nullptr} /* Sentinel */
};
static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
Py_VISIT(self->_prev);
Py_VISIT(self->_next);
return 0;
}
static int NodeBase_clear(NodeBase* self) {
Py_CLEAR(self->_prev);
Py_CLEAR(self->_next);
return 0;
}
static void NodeBase_dealloc(PyObject* self) {
PyObject_GC_UnTrack(self);
(void)NodeBase_clear((NodeBase*)self);
Py_TYPE(self)->tp_free(self);
}
static PyTypeObject NodeBaseType = {
PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeBase", /* tp_name */
sizeof(NodeBase), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)NodeBase_dealloc, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
Py_TPFLAGS_HAVE_GC, /* tp_flags */
nullptr, /* tp_doc */
(traverseproc)NodeBase_traverse, /* tp_traverse */
(inquiry)NodeBase_clear, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
nullptr, /* tp_methods */
NodeBase_members, /* tp_members */
nullptr, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)NodeBase_init_fn, /* tp_init */
nullptr, /* tp_alloc */
NodeBase_new, /* tp_new */
};
bool NodeBase_init(PyObject* module) {
if (PyModule_AddType(module, &NodeBaseType) < 0) {
return false;
}
return true;
}
////////////////////////////////
// NodeIter
////////////////////////////////
struct NodeIter {
PyObject_HEAD bool _reversed;
NodeBase* _root;
NodeBase* _cur;
};
static PyObject* NodeIter_new(
PyTypeObject* type,
PyObject* args,
PyObject* kwds) {
PyObject* self = type->tp_alloc(type, 0);
if (!self)
return nullptr;
return self;
}
static int NodeIter_init_fn(NodeIter* self, PyObject* args, PyObject* kwargs) {
NodeBase* root = nullptr;
bool reversed = false;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr const char* keywords[] = {"root", "reversed", nullptr};
if (!PyArg_ParseTupleAndKeywords(
args,
kwargs,
"Ob|",
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<char**>(keywords),
&root,
&reversed)) {
return -1;
}
self->_reversed = reversed;
Py_INCREF(root);
self->_root = root;
Py_INCREF(root);
self->_cur = root;
return 0;
}
template <bool reversed>
PyObject* NodeIter_iternext_helper(NodeIter* self) {
// It should be possible to relax the ref counting here
// but in practice, we do not have that many _erased Nodes,
// so probably not worth it.
if constexpr (reversed) {
NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev);
Py_CLEAR(self->_cur);
self->_cur = prev;
} else {
NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next);
Py_CLEAR(self->_cur);
self->_cur = next;
}
while (self->_cur != self->_root) {
if (!self->_cur->_erased) {
Py_INCREF(self->_cur);
return (PyObject*)self->_cur;
}
if constexpr (reversed) {
NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev);
Py_CLEAR(self->_cur);
self->_cur = prev;
} else {
NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next);
Py_CLEAR(self->_cur);
self->_cur = next;
}
}
PyErr_SetNone(PyExc_StopIteration);
return nullptr;
}
PyObject* NodeIter_iternext(PyObject* _self) {
NodeIter* self = (NodeIter*)_self;
if (self->_reversed) {
return NodeIter_iternext_helper<true>(self);
} else {
return NodeIter_iternext_helper<false>(self);
}
}
static int NodeIter_traverse(NodeIter* self, visitproc visit, void* arg) {
Py_VISIT(self->_root);
Py_VISIT(self->_cur);
return 0;
}
static int NodeIter_clear(NodeIter* self) {
Py_CLEAR(self->_root);
Py_CLEAR(self->_cur);
return 0;
}
static void NodeIter_dealloc(PyObject* self) {
PyObject_GC_UnTrack(self);
(void)NodeIter_clear((NodeIter*)self);
Py_TYPE(self)->tp_free(self);
}
static PyTypeObject NodeIterType = {
PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeIter", /* tp_name */
sizeof(NodeIter), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)NodeIter_dealloc, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */
nullptr, /* tp_doc */
(traverseproc)NodeIter_traverse, /* tp_traverse */
(inquiry)NodeIter_clear, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
PyObject_SelfIter, /* tp_iter */
NodeIter_iternext, /* tp_iternext */
nullptr, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)NodeIter_init_fn, /* tp_init */
nullptr, /* tp_alloc */
NodeIter_new, /* tp_new */
};
bool NodeIter_init(PyObject* module) {
if (PyModule_AddType(module, &NodeIterType) < 0) {
return false;
}
return true;
}

torch/csrc/fx/node.h (new file, 6 lines)
@ -0,0 +1,6 @@
#pragma once
#include <torch/csrc/python_headers.h>
bool NodeBase_init(PyObject* module);
bool NodeIter_init(PyObject* module);

@ -4,6 +4,7 @@ from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_na
import torch.utils._pytree as pytree
from . import _pytree as fx_pytree
from ._compatibility import compatibility
from torch._C import _NodeIter
import os
import contextlib
@ -271,20 +272,8 @@ class _node_list:
return self.graph._len
def __iter__(self):
root = self.graph._root
if self.direction == "_next":
cur = root._next
while cur is not root:
if not cur._erased:
yield cur
cur = cur._next
else:
assert self.direction == "_prev"
cur = root._prev
while cur is not root:
if not cur._erased:
yield cur
cur = cur._prev
assert self.direction == "_prev" or self.direction == "_next"
yield from _NodeIter(self.graph._root, self.direction == "_prev")
def __reversed__(self):
return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev')

@ -11,6 +11,7 @@ import inspect
import warnings
from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair
from .._ops import ops as _ops
from torch._C import _NodeBase
if TYPE_CHECKING:
from .graph import Graph
@ -139,7 +140,7 @@ def _format_arg(arg, max_list_len=float('inf')) -> str:
return str(arg)
@compatibility(is_backward_compatible=True)
class Node:
class Node(_NodeBase):
"""
``Node`` is the data structure that represents individual operations within
a ``Graph``. For the most part, Nodes represent callsites to various entities,
@ -197,6 +198,7 @@ class Node:
annotation of values in the generated code or for other types
of analyses.
"""
super().__init__()
self.graph = graph
self.name = name # unique name of value being created
assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root']
@ -235,9 +237,6 @@ class Node:
# does not produce a value, it's more of a notation. Thus, this value
# describes the type of args[0] in the ``return`` node.
self.type : Optional[Any] = return_type
self._prev = self
self._next = self
self._erased = False
self._sort_key: Any = ()
# If set, use this fn to print this node
@ -247,6 +246,22 @@ class Node:
# transformations. This metadata is preserved across node copies
self.meta : Dict[str, Any] = {}
def __getstate__(self):
state = self.__dict__.copy()
state["_erased"] = self._erased
state["_prev"] = self._prev
state["_next"] = self._next
return state
def __setstate__(self, state):
_erased = state.pop("_erased")
_prev = state.pop("_prev")
_next = state.pop("_next")
self.__dict__.update(state)
self._erased = _erased
self._prev = _prev
self._next = _next
@property
def next(self) -> 'Node':
"""