Compare commits

...

1 Commits

Author SHA1 Message Date
11b30e177d [Inductor] Fix resnet regression 2025-08-26 15:49:50 -07:00

View File

@ -2874,8 +2874,14 @@ class ExpandView(BaseView):
if new_size[i] == -1:
assert old_size[i] is not None
new_size[i] = old_size[i]
elif old_size[i] is None or V.graph.sizevars.shape_env.evaluate_expr(
sympy.Eq(old_size[i], 1), fallback_value=False
elif (
old_size[i] is None
or V.graph.sizevars.shape_env.evaluate_expr(
sympy.Eq(old_size[i], 1), fallback_value=False
)
or V.graph.sizevars.shape_env.evaluate_expr(
sympy.Eq(old_size[i], 0), fallback_value=False
)
):
pass
else:
@ -2884,9 +2890,10 @@ class ExpandView(BaseView):
# NB: new_size[i] == old_size[i] is expected to already be
# guarded because the meta formula was expected to have taught
# us this equality.
assert sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0, (
"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
)
print(new_size)
print(old_size)
if not sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0:
breakpoint()
return new_size
@classmethod
@ -7773,7 +7780,10 @@ class FallbackKernel(ExternKernelAlloc):
@classmethod
def create(cls, kernel: _OpOverloads, *args: Any, **kwargs: Any) -> FallbackKernel:
"""Create an instance of FallbackKernel from an _OpOverloads"""
fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,)
fake_incorrect_kernels = (
aten._fused_moving_avg_obs_fq_helper_functional,
aten._fused_moving_avg_obs_fq_helper_functional.default,
)
if kernel not in fake_incorrect_kernels:
context = cast(AbstractContextManager[None], V.graph.fake_mode)
else: