[Dynamo] Fix arg ordering in tf modes (#159707)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159707
Approved by: https://github.com/zou3519
Author: Michael Lazos
Date: 2025-08-04 11:19:12 -07:00
Committed by: PyTorch MergeBot
Parent: e273ff028a
Commit: 9f8cfe7476

6 changed files with 50 additions and 50 deletions

The functional change is in _flatten_vts (last file below): the worklist is now consumed FIFO via popleft() instead of LIFO via pop(), and flattened list/dict containers are no longer themselves appended to the output, so torch function modes see arguments in their original order. The other five files update test expectations to match the corrected ordering.

View File

@@ -717,10 +717,10 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
self.assertExpectedInline(
munge_shape_guards(record.getMessage()),
"""\
-+- __SHAPE_GUARD__: L['x'].size()[0] == 2*L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in #
-+- __SHAPE_GUARD__: L['y'].size()[0] == L['z'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
-+- __SHAPE_GUARD__: ((2*L['z'].size()[0]) % 3) == 0 # if x.size(0) % 3 == 0: # #:# in # #:# in #
-+- __SHAPE_GUARD__: 2 <= L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950
++- __SHAPE_GUARD__: L['x'].size()[0] == 2*L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in #
++- __SHAPE_GUARD__: L['z'].size()[0] == L['y'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
++- __SHAPE_GUARD__: ((2*L['y'].size()[0]) % 3) == 0 # if x.size(0) % 3 == 0: # #:# in # #:# in #
++- __SHAPE_GUARD__: 2 <= L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950
)
@make_logging_test(guards=True)
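For orientation, the guard comments above imply a test function of roughly this shape; the following is a hedged reconstruction (the decorator, return values, and concrete sizes are assumptions, not taken from this diff):

import torch

def fn(x, y, z):
    # The guards above reference x.size(0) % 3 and torch.cat([y, z]).
    if x.size(0) % 3 == 0:
        return x * 2
    return x + torch.cat([y, z])

# Dynamic shapes make Dynamo emit the symbolic size guards shown above;
# sizes 6/3/3 are consistent with the "same size 3" duck-sizing note.
compiled = torch.compile(fn, dynamic=True)
compiled(torch.randn(6), torch.randn(3), torch.randn(3))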

View File

@@ -319,11 +319,11 @@ class StructuredTraceTest(TestCase):
{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['y']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"describe_storage": {"id": 1, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_y_": [1000, 1000], "l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['y']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "l_y_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}

View File

@@ -8913,7 +8913,7 @@ def forward(self, p_lin_weight, p_lin_bias, x):
self.assertExpectedInline(
str(ep_decompose_linear.graph_module.code).strip(),
"""\
-def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_bias, c_linear_weight, x, y):
+def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y):
conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None
permute = torch.ops.aten.permute.default(c_linear_weight, [1, 0]); c_linear_weight = None
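The swapped c_linear_weight / c_linear_bias placeholders are a downstream effect of the ordering fix. For reference, graphs like this come from decomposing an exported program; a minimal, hedged sketch follows (the real test module also has conv/conv1d submodules and lifts the linear tensors as constants, hence the c_ prefix, while this simplified version uses an ordinary parameterized linear):

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x)

ep = torch.export.export(M(), (torch.randn(2, 4),))
# Decomposition lowers aten.linear into permute + addmm,
# matching the permute(c_linear_weight, [1, 0]) call above.
print(ep.run_decompositions().graph_module.code)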

View File

@@ -185,8 +185,7 @@ class AutoFunctionalizeTests(torch._inductor.test_case.TestCase):
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
-# Custom comment for test
-foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = foo_default = None
+foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg0_1 = arg3_1 = arg4_1 = arg1_1 = arg2_1 = foo_default = None
return ()""", # noqa: B950
ignore_comments=True,
)
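Every foo_default call in this file exercises auto-functionalization of a mutating custom op. The op's schema is not part of this diff, but the call sites (an optional tensor, a tensor list, two more tensors, and an integer, with mutations written back through the aten.copy_ epilogues below) suggest something like this hedged sketch:

import torch

lib = torch.library.Library("mylib", "FRAGMENT")
# Hypothetical schema inferred from the call sites; the (a!)/(b!)
# annotations mark the arguments the op mutates in place.
lib.define("foo(Tensor(a!)? x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()")

Inductor re-inplaces the functionalized call and re-emits the mutations as the copy_ epilogues visible in the graphs that follow.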
@@ -247,7 +246,7 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
-foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = None
+foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg0_1 = arg3_1 = arg4_1 = arg1_1 = arg2_1 = None
getitem_4: "f32[3][1]cpu" = foo_default[0]
getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None
return (getitem_4, getitem_5)""", # noqa: B950
@@ -334,9 +333,8 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
-# Custom comment for test
-foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); \
-arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
+foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg0_1, 2, arg1_1); arg2_1 = \
+arg3_1 = arg0_1 = arg1_1 = foo_default = None
return ()""",
ignore_comments=True,
)
@@ -416,10 +414,10 @@ arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
self.assertExpectedInline(
post_grad_graphs,
"""\
def forward(self, arg0_1: "Sym(s72)", arg1_1: "f32[s72][1]cpu", arg2_1: "f32[s72][1]cpu", arg3_1: "f32[s72][1]cpu", arg4_1: "f32[s72][1]cpu", arg5_1: "f32[s72][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg3_1, [arg4_1, arg5_1], arg2_1, 2, arg1_1); arg4_1 = arg5_1 = arg1_1 = foo_default = None
copy_: "f32[s72][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None
copy__1: "f32[s72][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu", arg2_1: "f32[s77][1]cpu", arg3_1: "f32[s77][1]cpu", arg4_1: "f32[s77][1]cpu", arg5_1: "f32[s77][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg1_1, [arg4_1, arg5_1], arg2_1, 2, arg3_1); arg4_1 = arg5_1 = arg3_1 = foo_default = None
copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@@ -429,9 +427,9 @@ def forward(self, arg0_1: "Sym(s72)", arg1_1: "f32[s72][1]cpu", arg2_1: "f32[s72
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
-foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg3_1 = arg4_1 = arg0_1 = foo_default = None
-copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
-copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
+foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg3_1 = arg4_1 = arg2_1 = foo_default = None
+copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None
+copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@@ -523,11 +521,11 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
-foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg3_1 = arg4_1 = arg0_1 = None
+foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg3_1 = arg4_1 = arg2_1 = None
getitem_4: "f32[3][1]cpu" = foo_default[0]
getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None
copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None
return (getitem_4, getitem_5)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@@ -581,13 +579,13 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
self.assertExpectedInline(
graph_aot,
"""\
def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17][1]cpu"):
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg2_1, arg1_1])
getitem_1: "f32[s17][1]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[s17][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
add: "f32[s17][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
copy_: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None
copy__1: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu", arg2_1: "f32[s77][1]cpu"):
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg2_1])
getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[s77][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
add: "f32[s77][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None
copy__1: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_2); arg2_1 = getitem_2 = copy__1 = None
return (add,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
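In these graphs auto_functionalized_v2 is the functional wrapper recorded for a mutating op: _all_bases lists the storage bases the op mutates, and outputs [1], [2], ... are the updated values that the copy_ epilogue writes back into the originals. The two-tensor foo used here is a different variant from the five-argument one above (the tests presumably define their own op per case); a hedged sketch of what it may look like, since the exact test code is not shown in this diff:

import torch

@torch.library.custom_op("mylib::foo", mutates_args={"x", "y"})
def foo(x: torch.Tensor, y: torch.Tensor) -> None:
    x.add_(1)
    y.add_(1)

def f(x, y):
    foo(x, y)
    # With the ordering fix, _all_bases is [x, y] rather than [y, x].
    return x + y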
@@ -597,12 +595,12 @@ def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17
graph_aot,
"""\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
-auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg0_1])
+auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg0_1, arg1_1])
getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_2); arg0_1 = getitem_2 = copy_ = None
copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy__1 = None
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None
copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy__1 = None
return (add,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@@ -613,11 +611,11 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
self.assertExpectedInline(
graph_inductor,
"""\
def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg2_1, arg1_1); foo_default = None
add: "f32[s17][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1)
copy_: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu", arg2_1: "f32[s77][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg1_1, arg2_1); foo_default = None
add: "f32[s77][1]cpu" = torch.ops.aten.add.Tensor(arg1_1, arg2_1)
copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
return (add,)""",
ignore_comments=True,
ignore_empty_lines=True,
@@ -627,8 +625,8 @@ def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17
graph_inductor,
"""\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
-foo_default = torch.ops.mylib.foo.default(arg1_1, arg0_1); foo_default = None
-add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(arg1_1, arg0_1)
+foo_default = torch.ops.mylib.foo.default(arg0_1, arg1_1); foo_default = None
+add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None
copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None
return (add,)""",
@@ -843,11 +841,11 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
graph_aot,
"""\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
-auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg0_1])
+auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg0_1, arg1_1])
getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_2); arg0_1 = getitem_2 = copy_ = None
copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy__1 = None
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None
copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy__1 = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@@ -859,7 +857,7 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
graph_inductor,
"""\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
-foo_default = torch.ops.mylib.foo.default(arg1_1, arg0_1); foo_default = None
+foo_default = torch.ops.mylib.foo.default(arg0_1, arg1_1); foo_default = None
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None
copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None
return ()""", # noqa: B950
@@ -979,8 +977,8 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
-foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None
-copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
+foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg0_1, 2, arg1_1); arg2_1 = arg3_1 = arg1_1 = foo_default = None
+copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,

View File

@@ -1333,8 +1333,8 @@ class TestMaxAutotune(TestCase):
if not TEST_WITH_ROCM:
expected = """{
'input_nodes':[
"[[s77,s17],[s17,1],torch.float32,device(type='cuda',index=0),0]",
"[[s17,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]"],
"[[s77,s27],[s27,1],torch.float32,device(type='cuda',index=0),0]",
"[[s27,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]"],
'num_stages':1,'num_warps':2,'prefix_args':0,'suffix_args':0,'call_sizes':[s77,s94],
'layout':"[[s77,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]",'num_consumer_groups':0,
'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','kwargs':{'EVEN_K':False,'ALLOW_TF32':True,

View File

@@ -389,7 +389,7 @@ def _flatten_vts(vts):
    output = []
    while vts:
-        vt = vts.pop()
+        vt = vts.popleft()

        if not vt.is_realized() and vt.peek_type() in (dict, list, tuple):
            vt.realize()
@@ -397,8 +397,10 @@ def _flatten_vts(vts):
        if vt.is_realized():
            if isinstance(vt, ListVariable):
                vts.extend(vt.items)
+                continue
            elif isinstance(vt, ConstDictVariable):
                vts.extend(vt.items.values())
+                continue

        output.append(vt)
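The behavioral change in _flatten_vts is easiest to see with a plain deque: pop() drains from the right, so the worklist is processed last-first, while popleft() keeps first-in, first-out order. A minimal illustration using plain lists in place of VariableTrackers (the take parameter is just a stand-in for the pop()/popleft() choice):

from collections import deque

def flatten(items, take):
    vts = deque(items)
    output = []
    while vts:
        vt = take(vts)
        if isinstance(vt, (list, tuple)):
            vts.extend(vt)
            continue  # like the fix above, skip emitting the container itself
        output.append(vt)
    return output

args = ["a", ["b", "c"], "d"]
print(flatten(args, deque.pop))      # ['d', 'c', 'b', 'a'] -- LIFO reverses arg order
print(flatten(args, deque.popleft))  # ['a', 'd', 'b', 'c'] -- FIFO preserves arg order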