pytorch/torch/_inductor/codegen
Mwiza Kunda ce97a5dcfa [Inductor] Restrict block analysis to only match integer dims and strides (#149615)
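Restrict block analysis to only match dimension sizes and strides that are integers. For example, `sympy` can match an index expression such as `(ModularIndexing(xindex, 4, 4)) + 4*(ModularIndexing(xindex, 32, 2))` against the candidate pattern below. The resulting match is algebraically consistent, since `dim_mod2_*dim_mod3_*dim_mod4_ = (1/16)*2*32 = 4`, but it is invalid because it binds `dim_mod2_` to the non-integer `1/16`.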
```python
# Candidate pattern: one floor-division term plus four ModularIndexing terms.
match_expr = (
    stride_mod0_*((xindex//(dim_mod1_*dim_mod2_*dim_mod3_*dim_mod4_)))
    + stride_mod1_*(ModularIndexing(xindex, dim_mod2_*dim_mod3_*dim_mod4_, dim_mod1_))
    + stride_mod2_*(ModularIndexing(xindex, dim_mod3_*dim_mod4_, dim_mod2_))
    + stride_mod3_*(ModularIndexing(xindex, dim_mod4_, dim_mod3_))
    + stride_mod4_*(ModularIndexing(xindex, 1, dim_mod4_))
)
# Match returned by sympy; dim_mod2_ = 1/16 is not an integer.
match = {
    dim_mod4_: 32, dim_mod3_: 2, stride_mod3_: 4, dim_mod2_: 1/16,
    dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0, stride_mod0_: 0,
}
```
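
The arithmetic behind the spurious match, and the kind of integer check that rejects it, can be sketched in plain Python. This is an illustration only: string keys and `fractions.Fraction` stand in for the `sympy` wildcards and rationals, `is_integer_match` is a hypothetical helper, and the actual change may constrain the pattern itself rather than filter matches afterwards.

```python
from fractions import Fraction

# The match reported above; plain strings stand in for the sympy wildcards.
match = {
    "dim_mod4_": 32, "dim_mod3_": 2, "stride_mod3_": 4, "dim_mod2_": Fraction(1, 16),
    "dim_mod1_": 4, "stride_mod1_": 1, "stride_mod4_": 0, "stride_mod2_": 0, "stride_mod0_": 0,
}

# Divisor of the stride_mod1_ term: dim_mod2_ * dim_mod3_ * dim_mod4_.
# The fractional dim makes it equal 4, so that term reproduces
# ModularIndexing(xindex, 4, 4) even though no real dimension has size 1/16.
divisor = match["dim_mod2_"] * match["dim_mod3_"] * match["dim_mod4_"]


def is_integer_match(candidate):
    """Hypothetical helper: accept a match only if every dim and stride is integral."""
    return all(Fraction(value).denominator == 1 for value in candidate.values())


print(divisor == 4)             # the fractional dim yields an integer divisor
print(is_integer_match(match))  # the match is rejected once integrality is required
```

Requiring integrality rejects this match as a whole, which is the behaviour the restriction is meant to guarantee for dimension sizes and strides.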

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149615
Approved by: https://github.com/blaine-rister
2025-06-24 22:43:12 +00:00