Allow symbols to reach conv_layout stride argument (#125829)

#Fix https://github.com/pytorch/pytorch/issues/125638

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125829
Approved by: https://github.com/anijain2305
This commit is contained in:
laithsakka
2024-05-09 09:40:02 -07:00
committed by PyTorch MergeBot
parent fcbf2b61e6
commit 013722bcb8

View File

@ -5,6 +5,7 @@ import logging
from typing import cast, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypedDict
import torch
from .. import config, ir
from ..lowering import (
@ -245,11 +246,11 @@ def conv_layout(
ir.ir_node_to_tensor(x, guard_shape=True),
ir.ir_node_to_tensor(weight, guard_shape=True),
ir.ir_node_to_tensor(bias, guard_shape=True),
stride,
tuple(V.graph.sizevars.size_hint(p) for p in padding), # type: ignore[arg-type]
V.graph.sizevars.size_hints(stride), # type: ignore[arg-type]
V.graph.sizevars.size_hints(padding), # type: ignore[arg-type]
dilation,
transposed,
tuple(V.graph.sizevars.size_hint(p) for p in output_padding), # type: ignore[arg-type]
V.graph.sizevars.size_hints(output_padding), # type: ignore[arg-type]
groups,
)
sizes = ir.convert_shape_to_inductor(output.size())