Compare commits


15 Commits

Author SHA1 Message Date
896fa02589 Update on "address DDE in matmul decomp"
Address https://github.com/pytorch/pytorch/issues/165081

[ghstack-poisoned]
2025-10-29 15:32:55 -07:00
6321680a77 Update base for Update on "address DDE in matmul decomp"
Address https://github.com/pytorch/pytorch/issues/165081

[ghstack-poisoned]
2025-10-29 15:32:55 -07:00
2763e7216b Update on "address DDE in matmul decomp"
Address https://github.com/pytorch/pytorch/issues/165081

[ghstack-poisoned]
2025-10-29 12:16:41 -07:00
7c4d3e9a8c Update base for Update on "address DDE in matmul decomp"
Address https://github.com/pytorch/pytorch/issues/165081

[ghstack-poisoned]
2025-10-29 12:16:41 -07:00
78f0718a5a Update on "address DDE in matmul decomp"
Address https://github.com/pytorch/pytorch/issues/165081

[ghstack-poisoned]
2025-10-29 10:42:41 -07:00
b07f7cdd8c address DDE in matmul decomp
[ghstack-poisoned]
2025-10-29 09:44:23 -07:00
0c31b1d46c Update on "Fix comparing inductor actual strides vs bw graph for activations should not throw DDE. "
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-29 09:18:34 -07:00
0069cf49a0 Update on "Fix comparing inductor actual strides vs bw graph for activations should not throw DDE. "
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 11:23:13 -07:00
121b28ee26 Update on "Fix comparing inductor actual strides vs bw graph for activations should not throw DDE. "
[ghstack-poisoned]
2025-10-27 23:17:44 -07:00
404bae6094 Update on "Fix comparing inductor actual strides vs bw graph for activations should not throw DDE. "
[ghstack-poisoned]
2025-10-27 17:47:37 -07:00
2d5bf015b6 Update on "WIP: fix comparing inductor strides vs bw graph strides for bw compile"
[ghstack-poisoned]
2025-10-27 13:33:48 -07:00
20233b08ce Update on "WIP: fix comparing inductor strides vs bw graph strides for bw compile"
[ghstack-poisoned]
2025-10-27 09:16:43 -07:00
4fc5aed281 Update on "WIP: fix comparing inductor strides vs bw graph strides for bw compile"
[ghstack-poisoned]
2025-10-26 18:21:39 -07:00
af34c67082 Update on "WIP: fix comparing inductor strides vs bw graph strides for bw compile"
[ghstack-poisoned]
2025-10-26 18:16:51 -07:00
6c1a16f03c WIP: fix comparing inductor strides vs bw graph strides for bw compile
[ghstack-poisoned]
2025-10-26 15:28:50 -07:00
2 changed files with 356 additions and 3 deletions
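For readers of the diff below (this note is not part of the PR): guard_or_true and guard_or_false are the symbolic-shapes helpers the fix switches to. Where a plain bool(...) on an expression involving an unbacked, data-dependent size raises a data-dependent error (DDE) at trace time, these helpers return a fixed fallback answer without installing a guard. A minimal sketch of their semantics, using a hand-made unbacked symbol (the ShapeEnv setup here is illustrative, not code from the PR):

from torch.fx.experimental.symbolic_shapes import (
    ShapeEnv,
    guard_or_false,
    guard_or_true,
)

# Stand-in for a data-dependent size, e.g. the row count of a nonzero() result.
shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()

# bool(u0 == 1) would raise a data-dependent error at trace time.
print(guard_or_false(u0 == 1))  # False: undecidable, fall back to False
print(guard_or_true(u0 != 1))   # True: undecidable, fall back to True
print(guard_or_false(True))     # True: decidable, so the real answer is returned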


@@ -1395,6 +1395,357 @@ class HasDecompTest(TestCase):
check_case(groups=1, C_in=8, C_out=12) # groups=1 bigger
check_case(groups=2, C_in=8, C_out=12) # grouped conv
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
def test_mm_decompose_mm_dde(self):
def fuzzed_program(
arg_0,
arg_1,
arg_2,
arg_3,
arg_4,
arg_5,
arg_6,
arg_7,
arg_8,
arg_9,
arg_10,
arg_11,
arg_12,
arg_13,
arg_14,
arg_15,
arg_16,
arg_17,
arg_18,
sentinel,
):
var_node_6 = (
arg_0 # size=(9, 9, 9), stride=(81, 9, 1), dtype=float64, device=cuda
)
var_node_7 = (
arg_1 # size=(9, 9, 11), stride=(99, 11, 1), dtype=float64, device=cuda
)
var_node_5 = torch.matmul(
var_node_6.to(torch.float64), var_node_7.to(torch.float64)
) # size=(9, 9, 11), stride=(99, 11, 1), dtype=float64, device=cuda
var_node_9 = torch.full(
(9, 11, 12), 1.5758497316910556, dtype=torch.float64
) # size=(9, 11, 12), stride=(132, 12, 1), dtype=float64, device=cuda
var_node_10 = (
arg_2 # size=(9, 12, 8), stride=(96, 8, 1), dtype=float64, device=cuda
)
var_node_8 = torch.matmul(
var_node_9.to(torch.float64), var_node_10.to(torch.float64)
) # size=(9, 11, 8), stride=(88, 8, 1), dtype=float64, device=cuda
var_node_4 = torch.matmul(
var_node_5.to(torch.float64), var_node_8.to(torch.float64)
) # size=(9, 9, 8), stride=(72, 8, 1), dtype=float64, device=cuda
var_node_13 = arg_3 # size=(9, 8, 13), stride=(104, 13, 1), dtype=float64, device=cuda
var_node_14 = (
arg_4 # size=(9, 13, 7), stride=(91, 7, 1), dtype=float64, device=cuda
)
var_node_12 = torch.matmul(
var_node_13.to(torch.float64), var_node_14.to(torch.float64)
) # size=(9, 8, 7), stride=(56, 7, 1), dtype=float64, device=cuda
var_node_15 = arg_5 # size=(9, 7, 16), stride=(112, 16, 1), dtype=float64, device=cuda
var_node_11 = torch.matmul(
var_node_12.to(torch.float64), var_node_15.to(torch.float64)
) # size=(9, 8, 16), stride=(128, 16, 1), dtype=float64, device=cuda
var_node_3 = torch.matmul(
var_node_4.to(torch.float64), var_node_11.to(torch.float64)
) # size=(9, 9, 16), stride=(144, 16, 1), dtype=float64, device=cuda
var_node_17 = arg_6 # size=(9, 16, 12), stride=(192, 12, 1), dtype=float64, device=cuda
var_node_18 = arg_7 # size=(9, 12, 11), stride=(132, 11, 1), dtype=float64, device=cuda
var_node_16 = torch.matmul(
var_node_17.to(torch.float64), var_node_18.to(torch.float64)
) # size=(9, 16, 11), stride=(176, 11, 1), dtype=float64, device=cuda
var_node_2 = torch.matmul(
var_node_3.to(torch.float64), var_node_16.to(torch.float64)
) # size=(9, 9, 11), stride=(99, 11, 1), dtype=float64, device=cuda
var_node_23 = torch.full(
(156, 8), -0.5249394453404403, dtype=torch.float64
) # size=(156, 8), stride=(8, 1), dtype=float64, device=cuda
var_node_24 = torch.full(
(8, 9), 0.9331226188585692, dtype=torch.float64
) # size=(8, 9), stride=(9, 1), dtype=float64, device=cuda
var_node_22 = torch.matmul(
var_node_23.to(torch.float64), var_node_24.to(torch.float64)
) # size=(156, 9), stride=(9, 1), dtype=float64, device=cuda
var_node_26 = torch.full(
(9, 13), -0.9276381954691514, dtype=torch.float64
) # size=(9, 13), stride=(13, 1), dtype=float64, device=cuda
var_node_27 = torch.full(
(13, 16), 0.024752238943232543, dtype=torch.float64
) # size=(13, 16), stride=(16, 1), dtype=float64, device=cuda
var_node_25 = torch.matmul(
var_node_26.to(torch.float64), var_node_27.to(torch.float64)
) # size=(9, 16), stride=(16, 1), dtype=float64, device=cuda
var_node_21 = torch.matmul(
var_node_22.to(torch.float64), var_node_25.to(torch.float64)
) # size=(156, 16), stride=(16, 1), dtype=float64, device=cuda
var_node_29 = arg_8
_x_nz = torch.zeros(
(9, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
dtype=torch.bool,
device=var_node_29.device,
)
_x_nz_flat = _x_nz.reshape(-1)
_x_nz_flat[:9] = True
var_node_28 = torch.nonzero(
_x_nz
) # size=(9, 11), stride=(11, 1), dtype=int64, device=cuda
var_node_20 = torch.nn.functional.embedding(
torch.clamp(var_node_28.to(torch.int64), 0, var_node_21.size(0) - 1),
var_node_21,
) # size=(9, 11, 16), stride=(176, 16, 1), dtype=float64, device=cuda
var_node_33 = torch.full(
(9, 16, 5), 1.0707914920634904, dtype=torch.float64
) # size=(9, 16, 5), stride=(80, 5, 1), dtype=float64, device=cuda
var_node_34 = torch.full(
(9, 5, 10), -0.44934093079047227, dtype=torch.float64
) # size=(9, 5, 10), stride=(50, 10, 1), dtype=float64, device=cuda
var_node_32 = torch.matmul(
var_node_33.to(torch.float64), var_node_34.to(torch.float64)
) # size=(9, 16, 10), stride=(160, 10, 1), dtype=float64, device=cuda
var_node_36 = (
arg_9 # size=(9, 10, 1), stride=(10, 1, 1), dtype=float64, device=cuda
)
var_node_37 = torch.full(
(9, 1, 11), -1.874293687140311, dtype=torch.float64
) # size=(9, 1, 11), stride=(11, 11, 1), dtype=float64, device=cuda
var_node_35 = torch.matmul(
var_node_36.to(torch.float64), var_node_37.to(torch.float64)
) # size=(9, 10, 11), stride=(110, 11, 1), dtype=float64, device=cuda
var_node_31 = torch.matmul(
var_node_32.to(torch.float64), var_node_35.to(torch.float64)
) # size=(9, 16, 11), stride=(176, 11, 1), dtype=float64, device=cuda
var_node_40 = torch.full(
(990, 2), 0.4084376380351558, dtype=torch.float64
) # size=(990, 2), stride=(2, 1), dtype=float64, device=cuda
var_node_41 = torch.full(
(2,), 0.982671965550022, dtype=torch.float64
) # size=(2,), stride=(1,), dtype=float64, device=cuda
var_node_39 = torch.matmul(
var_node_40.to(torch.float64), var_node_41.to(torch.float64)
) # size=(990,), stride=(1,), dtype=float64, device=cuda
var_node_38 = torch.reshape(
var_node_39, [9, 11, 10]
) # size=(9, 11, 10), stride=(110, 10, 1), dtype=float64, device=cuda
var_node_30 = torch.matmul(
var_node_31.to(torch.float64), var_node_38.to(torch.float64)
) # size=(9, 16, 10), stride=(160, 10, 1), dtype=float64, device=cuda
var_node_19 = torch.matmul(
var_node_20.to(torch.float64), var_node_30.to(torch.float64)
) # size=(9, 11, 10), stride=(110, 10, 1), dtype=float64, device=cuda
var_node_1 = torch.matmul(
var_node_2.to(torch.float64), var_node_19.to(torch.float64)
) # size=(9, 9, 10), stride=(90, 10, 1), dtype=float64, device=cuda
var_node_47 = arg_10 # size=(9, 10, 15), stride=(150, 15, 1), dtype=float64, device=cuda
var_node_48 = torch.full(
(9, 15, 2), -0.3349339402390618, dtype=torch.float64
) # size=(9, 15, 2), stride=(30, 2, 1), dtype=float64, device=cuda
var_node_46 = torch.matmul(
var_node_47.to(torch.float64), var_node_48.to(torch.float64)
) # size=(9, 10, 2), stride=(20, 2, 1), dtype=float64, device=cuda
var_node_50 = (
arg_11 # size=(9, 2, 7), stride=(14, 7, 1), dtype=float64, device=cuda
)
var_node_51 = (
arg_12 # size=(9, 7, 2), stride=(14, 2, 1), dtype=float64, device=cuda
)
var_node_49 = torch.matmul(
var_node_50.to(torch.float64), var_node_51.to(torch.float64)
) # size=(9, 2, 2), stride=(4, 2, 1), dtype=float64, device=cuda
var_node_45 = torch.matmul(
var_node_46.to(torch.float64), var_node_49.to(torch.float64)
) # size=(9, 10, 2), stride=(20, 2, 1), dtype=float64, device=cuda
var_node_52 = torch.full(
(9, 2, 1), -0.4046675639434615, dtype=torch.float64
) # size=(9, 2, 1), stride=(2, 1, 1), dtype=float64, device=cuda
var_node_44 = torch.matmul(
var_node_45.to(torch.float64), var_node_52.to(torch.float64)
) # size=(9, 10, 1), stride=(10, 1, 1), dtype=float64, device=cuda
var_node_56 = (
arg_13 # size=(9, 1, 1), stride=(1, 1, 1), dtype=float64, device=cuda
)
var_node_55 = torch.nn.functional.rms_norm(
var_node_56.to(torch.float64), (1,)
) # size=(9, 1, 1), stride=(1, 1, 1), dtype=float64, device=cuda
var_node_57 = torch.full(
(9, 1, 8), 0.17877664640931384, dtype=torch.float64
) # size=(9, 1, 8), stride=(8, 8, 1), dtype=float64, device=cuda
var_node_54 = torch.matmul(
var_node_55.to(torch.float64), var_node_57.to(torch.float64)
) # size=(9, 1, 8), stride=(8, 8, 1), dtype=float64, device=cuda
var_node_60 = arg_14 # size=(9, 8, 10), stride=(80, 10, 1), dtype=float64, device=cuda
var_node_61 = torch.full(
(9, 10, 6), 0.43614806380221494, dtype=torch.float64
) # size=(9, 10, 6), stride=(60, 6, 1), dtype=float64, device=cuda
var_node_59 = torch.matmul(
var_node_60.to(torch.float64), var_node_61.to(torch.float64)
) # size=(9, 8, 6), stride=(48, 6, 1), dtype=float64, device=cuda
var_node_63 = (
arg_15 # size=(9, 6, 3), stride=(18, 3, 1), dtype=float64, device=cuda
)
var_node_64 = torch.full(
(9, 3, 8), -0.042774422041922854, dtype=torch.float64
) # size=(9, 3, 8), stride=(24, 8, 1), dtype=float64, device=cuda
var_node_62 = torch.matmul(
var_node_63.to(torch.float64), var_node_64.to(torch.float64)
) # size=(9, 6, 8), stride=(48, 8, 1), dtype=float64, device=cuda
var_node_58 = torch.matmul(
var_node_59.to(torch.float64), var_node_62.to(torch.float64)
) # size=(9, 8, 8), stride=(64, 8, 1), dtype=float64, device=cuda
var_node_53 = torch.matmul(
var_node_54.to(torch.float64), var_node_58.to(torch.float64)
) # size=(9, 1, 8), stride=(8, 8, 1), dtype=float64, device=cuda
var_node_43 = torch.matmul(
var_node_44.to(torch.float64), var_node_53.to(torch.float64)
) # size=(9, 10, 8), stride=(80, 8, 1), dtype=float64, device=cuda
var_node_68 = arg_16 # size=(9, 8, 16), stride=(128, 16, 1), dtype=float64, device=cuda
var_node_70 = torch.full(
(9, 16, 15), 0.24947808634496438, dtype=torch.float64
) # size=(9, 16, 15), stride=(240, 15, 1), dtype=float64, device=cuda
var_node_71 = torch.full(
(9, 15, 7), -0.09035245509773453, dtype=torch.float64
) # size=(9, 15, 7), stride=(105, 7, 1), dtype=float64, device=cuda
var_node_69 = torch.matmul(
var_node_70.to(torch.float64), var_node_71.to(torch.float64)
) # size=(9, 16, 7), stride=(112, 7, 1), dtype=float64, device=cuda
var_node_67 = torch.matmul(
var_node_68.to(torch.float64), var_node_69.to(torch.float64)
) # size=(9, 8, 7), stride=(56, 7, 1), dtype=float64, device=cuda
var_node_74 = torch.full(
(9, 7, 1), 0.05671950481832341, dtype=torch.float64
) # size=(9, 7, 1), stride=(7, 1, 1), dtype=float64, device=cuda
var_node_73 = torch.nn.functional.gelu(
var_node_74
) # size=(9, 7, 1), stride=(7, 1, 1), dtype=float64, device=cuda
var_node_76 = torch.full(
(9, 1, 2), -0.019912810353597852, dtype=torch.float64
) # size=(9, 1, 2), stride=(2, 2, 1), dtype=float64, device=cuda
var_node_77 = (
arg_17 # size=(9, 2, 7), stride=(14, 7, 1), dtype=float64, device=cuda
)
var_node_75 = torch.matmul(
var_node_76.to(torch.float64), var_node_77.to(torch.float64)
) # size=(9, 1, 7), stride=(7, 7, 1), dtype=float64, device=cuda
var_node_72 = torch.matmul(
var_node_73.to(torch.float64), var_node_75.to(torch.float64)
) # size=(9, 7, 7), stride=(49, 7, 1), dtype=float64, device=cuda
var_node_66 = torch.matmul(
var_node_67.to(torch.float64), var_node_72.to(torch.float64)
) # size=(9, 8, 7), stride=(56, 7, 1), dtype=float64, device=cuda
var_node_78 = arg_18 # size=(9, 7, 13), stride=(91, 13, 1), dtype=float64, device=cuda
var_node_65 = torch.matmul(
var_node_66.to(torch.float64), var_node_78.to(torch.float64)
) # size=(9, 8, 13), stride=(104, 13, 1), dtype=float64, device=cuda
var_node_42 = torch.matmul(
var_node_43.to(torch.float64), var_node_65.to(torch.float64)
) # size=(9, 10, 13), stride=(130, 13, 1), dtype=float64, device=cuda
var_node_0 = torch.matmul(
var_node_1.to(torch.float64), var_node_42.to(torch.float64)
) # size=(9, 9, 13), stride=(117, 13, 1), dtype=float64, device=cuda
# Ensure gradient computation by multiplying with sentinel and taking real part
result = var_node_0 * sentinel
if result.is_complex():
result = result.real
return result
# Sentinel tensor to ensure gradient computation
sentinel = torch.tensor(1.0, requires_grad=True)
arg_0 = torch.as_strided(
torch.randn(729).to(torch.float64), (9, 9, 9), (81, 9, 1)
)
arg_1 = torch.as_strided(
torch.randn(891).to(torch.float64), (9, 9, 11), (99, 11, 1)
)
arg_2 = torch.as_strided(
torch.randn(864).to(torch.float64), (9, 12, 8), (96, 8, 1)
)
arg_3 = torch.as_strided(
torch.randn(936).to(torch.float64), (9, 8, 13), (104, 13, 1)
)
arg_4 = torch.as_strided(
torch.randn(819).to(torch.float64), (9, 13, 7), (91, 7, 1)
)
arg_5 = torch.as_strided(
torch.randn(1008).to(torch.float64), (9, 7, 16), (112, 16, 1)
)
arg_6 = torch.as_strided(
torch.randn(1728).to(torch.float64), (9, 16, 12), (192, 12, 1)
)
arg_7 = torch.as_strided(
torch.randn(1188).to(torch.float64), (9, 12, 11), (132, 11, 1)
)
arg_8 = torch.as_strided(
torch.randint(0, 2, (1,), dtype=torch.int8).bool(),
(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
)
arg_9 = torch.as_strided(
torch.randn(90).to(torch.float64), (9, 10, 1), (10, 1, 1)
)
arg_10 = torch.as_strided(
torch.randn(1350).to(torch.float64), (9, 10, 15), (150, 15, 1)
)
arg_11 = torch.as_strided(
torch.randn(126).to(torch.float64), (9, 2, 7), (14, 7, 1)
)
arg_12 = torch.as_strided(
torch.randn(126).to(torch.float64), (9, 7, 2), (14, 2, 1)
)
arg_13 = torch.as_strided(
torch.randn(9).to(torch.float64), (9, 1, 1), (1, 1, 1)
)
arg_14 = torch.as_strided(
torch.randn(720).to(torch.float64), (9, 8, 10), (80, 10, 1)
)
arg_15 = torch.as_strided(
torch.randn(162).to(torch.float64), (9, 6, 3), (18, 3, 1)
)
arg_16 = torch.as_strided(
torch.randn(1152).to(torch.float64), (9, 8, 16), (128, 16, 1)
)
arg_17 = torch.as_strided(
torch.randn(126).to(torch.float64), (9, 2, 7), (14, 7, 1)
)
arg_18 = torch.as_strided(
torch.randn(819).to(torch.float64), (9, 7, 13), (91, 13, 1)
)
args = (
arg_0,
arg_1,
arg_2,
arg_3,
arg_4,
arg_5,
arg_6,
arg_7,
arg_8,
arg_9,
arg_10,
arg_11,
arg_12,
arg_13,
arg_14,
arg_15,
arg_16,
arg_17,
arg_18,
) + (sentinel,)
result_original = fuzzed_program(*args)
compiled_program = torch.compile(fuzzed_program, fullgraph=True, dynamic=True)
result_compiled = compiled_program(*args)
# Both should succeed without NameError
self.assertTrue(
torch.allclose(result_original, result_compiled, rtol=1e-5, atol=1e-5)
)
if __name__ == "__main__":
run_tests()


@@ -4507,6 +4507,8 @@ def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> bool:
@aten.matmul.out.py_impl(DispatchKey.CompositeImplicitAutograd)
@out_wrapper(pass_is_out=True)
def matmul(tensor1, tensor2, *, is_out=False):
+from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
dim_tensor1 = tensor1.dim()
dim_tensor2 = tensor2.dim()
assert dim_tensor1 != 0 and dim_tensor2 != 0
@@ -4575,11 +4577,11 @@ def matmul(tensor1, tensor2, *, is_out=False):
if (
dim_tensor1 == 3
and dim_tensor2 == 3
-and batch_tensor1[0] != batch_tensor2[0]
+and guard_or_true(batch_tensor1[0] != batch_tensor2[0])
):
-if batch_tensor1[0] == 1 and tensor1.requires_grad:
+if guard_or_false(batch_tensor1[0] == 1) and tensor1.requires_grad:
return matmul(tensor1.squeeze(0), tensor2)
-if batch_tensor2[0] == 1 and tensor2.requires_grad:
+if guard_or_false(batch_tensor2[0] == 1) and tensor2.requires_grad:
return matmul(tensor1, tensor2.squeeze(0))
# expand the batch portion (i.e. cut off matrix dimensions and expand rest)
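To make the behavioral change concrete, here is a small standalone sketch that mirrors the rewritten branch (an illustration under assumptions, not the real decomposition; pick_path and its return labels are made up): with the old direct comparisons, evaluating batch_tensor1[0] != batch_tensor2[0] on an unbacked batch size raised at trace time, while the new conditions route such cases to the general expand-plus-bmm path without guarding.

from torch.fx.experimental.symbolic_shapes import (
    ShapeEnv,
    guard_or_false,
    guard_or_true,
)

def pick_path(batch1, batch2, t1_requires_grad=False, t2_requires_grad=False):
    # Mirrors the structure of the rewritten 3-D x 3-D matmul branch above.
    if guard_or_true(batch1 != batch2):
        if guard_or_false(batch1 == 1) and t1_requires_grad:
            return "squeeze lhs batch dim and recurse"
        if guard_or_false(batch2 == 1) and t2_requires_grad:
            return "squeeze rhs batch dim and recurse"
    return "general path: expand batch dims and bmm"

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()  # e.g. a batch size derived from nonzero()
print(pick_path(u0, 1))  # general path; previously bool(u0 != 1) raised a DDE
print(pick_path(3, 3))   # general path; both sizes are known and equal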