Compare commits

...

10 Commits

Author SHA1 Message Date
7b5c5f8b36 Update on "Fix comparing inductor actual strides vs bw graph for activations should not throw DDE. "
Fix https://github.com/pytorch/pytorch/issues/163894

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-29 12:09:01 -07:00
0c31b1d46c Update on "Fix comparing inductor actual strides vs bw graph for activations should not throw DDE. "
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-29 09:18:34 -07:00
0069cf49a0 Update on "Fix comparing inductor actual strides vs bw graph for activations should not throw DDE. "
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 11:23:13 -07:00
121b28ee26 Update on "Fix comparing inductor actual strides vs bw graph for activations should not throw DDE. "
[ghstack-poisoned]
2025-10-27 23:17:44 -07:00
404bae6094 Update on "Fix comparing inductor actual strides vs bw graph for activations should not throw DDE. "
[ghstack-poisoned]
2025-10-27 17:47:37 -07:00
2d5bf015b6 Update on "WIP: fix comparing inductor strides vs bw graph strides for bw compile"
[ghstack-poisoned]
2025-10-27 13:33:48 -07:00
20233b08ce Update on "WIP: fix comparing inductor strides vs bw graph strides for bw compile"
[ghstack-poisoned]
2025-10-27 09:16:43 -07:00
4fc5aed281 Update on "WIP: fix comparing inductor strides vs bw graph strides for bw compile"
[ghstack-poisoned]
2025-10-26 18:21:39 -07:00
af34c67082 Update on "WIP: fix comparing inductor strides vs bw graph strides for bw compile"
[ghstack-poisoned]
2025-10-26 18:16:51 -07:00
6c1a16f03c WIP: fix comparing inductor strides vs bw graph strides for bw compile
[ghstack-poisoned]
2025-10-26 15:28:50 -07:00
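The "DDE" in the PR title is a data-dependent error: the GuardOnDataDependentSymNode exception raised when tracing must decide a condition that depends on an unbacked SymInt, such as the row count of torch.nonzero. Before the diffs, a minimal sketch of that failure mode and of the guard_or_true escape hatch the PR adopts, assuming a PyTorch version that exports guard_or_true; the tensors and the comparison are illustrative, not taken from this PR:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv, guard_or_true

shape_env = ShapeEnv()
with FakeTensorMode(shape_env=shape_env):
    x = torch.randn(4, 4)
    # nonzero's row count depends on tensor data, so under fake tensors
    # it becomes an unbacked SymInt (u0) with no concrete hint.
    u0 = torch.nonzero(x).size(0)
    # bool(u0 != 3) would raise GuardOnDataDependentSymNode (a "DDE").
    # guard_or_true instead answers True when the comparison cannot be
    # decided statically.
    print(guard_or_true(u0 != 3))  # True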
2 changed files with 56 additions and 4 deletions


@@ -1718,6 +1718,39 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
         self.assertEqual(eager_no_sq, comp_ind_no_sq)
         self.assertEqual(eager_no_sq.stride(), comp_ind_no_sq.stride())
 
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
+    def test_unbacked_activation_specialized_in_inductor(self):
+        """Test compilation with unbacked operations like nonzero."""
+        torch._dynamo.reset()
+
+        def fuzzed_program(arg_0, sentinel):
+            var_node_1 = arg_0
+            var_node_5 = torch.full((1, 2), -66, dtype=torch.int32)
+            var_node_6 = torch.full((1, 2), 77, dtype=torch.int64)
+            var_node_4 = torch.ops.aten.add(var_node_5, var_node_6)
+            var_node_7 = torch.full((1, 2), -64, dtype=torch.int32)
+            var_node_3 = torch.ops.aten.mul(var_node_4, var_node_7)
+            var_node_9 = torch.full((3, 4), False, dtype=torch.bool)
+            var_node_8 = torch.nonzero(var_node_9)
+            var_node_2 = torch.ops.aten.add(var_node_3, var_node_8)
+            var_node_0 = torch.ops.aten.div(var_node_1, var_node_2)
+            result = var_node_0 * sentinel
+            if result.is_complex():
+                result = result.real
+            return result
+
+        sentinel = torch.tensor(1.0, requires_grad=True)
+        arg_0 = torch.randint(0, 3, (1, 2), dtype=torch.int64)
+        args = (arg_0,) + (sentinel,)
+        result_original = fuzzed_program(*args)
+        compiled_program = torch.compile(fuzzed_program, fullgraph=True, dynamic=True)
+        result_compiled = compiled_program(*args)
+        self.assertTrue(torch.allclose(result_original, result_compiled))
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
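The test's two config patches are what let Dynamo trace torch.nonzero at all under fullgraph=True. A standalone sketch of the same pattern, separate from the diff (the function names g and demo are illustrative):

import torch

# Without capture_dynamic_output_shape_ops, an op whose output shape
# depends on tensor data (nonzero here) cannot be captured under
# fullgraph=True; with it, the row count becomes an unbacked SymInt.
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
def demo():
    def g(x):
        return torch.nonzero(x).float().sum()

    compiled = torch.compile(g, fullgraph=True, dynamic=True)
    return compiled(torch.tensor([1, 0, 2, 0, 3]))

print(demo())  # tensor(6.) since the nonzero rows are indices 0, 2, 4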


@@ -41,7 +41,7 @@ from torch._subclasses import FakeTensor
 from torch._subclasses.meta_utils import is_sparse_any
 from torch.fx.experimental._backward_state import BackwardState
 from torch.fx.experimental.proxy_tensor import is_sym_node
-from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals
+from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals, guard_or_true
 from torch.fx.graph_module import GraphModule
 from torch.fx.passes._tensorify_python_scalars import tensorify_python_scalars
 from torch.multiprocessing.reductions import StorageWeakRef
@@ -89,7 +89,6 @@ from .schemas import (
 )
 from .subclass_utils import compute_inner_mutated_inp_indices_from_subclass_meta
 from .utils import (
-    _get_symint_hints,
     contain_metadata_mutation_ops,
     get_cuda_generator_meta_val,
     make_boxed_func,
@@ -1743,8 +1742,27 @@ def _aot_stage2b_bw_compile(
                 # Comparing ph_arg.stride() with real_stride directly may
                 # cause dynamic dimensions in ph_arg being specialized to static
-                # value. Using the hints to avoid that.
-                if _get_symint_hints(ph_arg.stride()) != real_stride:
+                # value. Using suppress_guards and guard_or_true to avoid that.
+                stride_different = False
+                fake_mode = detect_fake_mode()
+                suppress_ctx = (
+                    fake_mode.shape_env.suppress_guards()
+                    if fake_mode is not None and fake_mode.shape_env is not None
+                    else nullcontext()
+                )
+
+                # Inductor can choose different strides for activations than
+                # what backward graph has. if we can't statically tell that
+                # strides are the same, we assume they are not.
+                with suppress_ctx:
+                    for k in range(len(ph_arg.stride())):
+                        # real_stride can't be symbolic.
+                        if guard_or_true(ph_arg.stride()[k] != int(real_stride[k])):
+                            stride_different = True
+                            break
+
+                if stride_different:
                     # Note that here we use the stride of the real tensor to
                     # restride a FakeTensor. This does not cause trouble
                     # for dynamic shape since this code path only get
@@ -2045,6 +2063,7 @@ def _aot_stage2b_compile_forward_or_inference(
     - FunctionalizedRngRuntimeWrapper
     - FakifiedOutWrapper
     """
+    # Validation
     if not is_inference and num_fw_outs_saved_for_bw is None:
         raise ValueError(
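The heart of the change is the stride-comparison hunk: instead of comparing stride hints exactly, which can specialize dynamic dimensions or raise a data-dependent error for unbacked ones, the backward compiler now asks whether strides are provably equal and conservatively restrides the activation when it cannot tell. A standalone sketch of that decision rule under the same assumptions as the hunk above (the helper name strides_may_differ is mine, not the PR's):

from contextlib import nullcontext

from torch._guards import detect_fake_mode
from torch.fx.experimental.symbolic_shapes import guard_or_true

def strides_may_differ(fake_strides, real_strides):
    # Mirrors the loop in the hunk: guard_or_true turns "cannot decide"
    # into True, and suppress_guards keeps the symbolic comparisons from
    # installing guards or specializing dynamic dims as a side effect.
    fake_mode = detect_fake_mode()
    ctx = (
        fake_mode.shape_env.suppress_guards()
        if fake_mode is not None and fake_mode.shape_env is not None
        else nullcontext()
    )
    with ctx:
        return any(
            guard_or_true(s != int(r))
            for s, r in zip(fake_strides, real_strides)
        )

Erring toward True is the safe direction here: a false "different" only makes the runtime wrapper restride the tensor to the real strides, which costs a little time but never produces wrong results.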