Sets `prefer_deferred_runtime_asserts_over_guards=True` for export, so any guards emitted from `SymNode.expect_true` (for example, guards that are implicitly required to be true for an op to succeed) won't lead to constraint violations. Instead these should appear in the graph as runtime asserts, or potentially as replacement expressions for placeholder shapes. For example, this reshape op should emit `Eq(s0*s1, s2)`, deferred as a runtime assert:

```
x = torch.randn(4, 8)    # [s0, s1]
y = torch.randn(32)      # [s2]
out = x.reshape(-1) + y  # this emits Eq(s0*s1, s2), and we represent y's shape as [s0*s1] in the graph
```

However, other complex guards can still cause export to fail. For instance, guards emitted from `SymNode.guard_bool/guard_size_oblivious` (e.g. explicit if-else conditions in user code, or lower-level op implementations hit during tracing) can still raise constraint violations. These can be deferred with `allow_complex_guards_as_runtime_asserts=True`. We don't yet make this the default because, while it makes export more likely to succeed, it results in non-trivial asserts being emitted that often represent specialization to a variant of the op, or checks related to 0/1 specialization.

We also remove forced specializations for export and kill the `_disable_forced_specializations` flag: any guard we can't express with Dims/DerivedDims is now either handled with Hybrid SymInts, or should be resolved by rewriting or deferring.

Follow-up: currently, `ShapeEnv._set_replacement()` is called for complex equality expressions (e.g. `s2 -> s0*s1` in the example above), and the ExportedProgram stores `s0*s1` in the input placeholder. This isn't checked for validity when the program is run, so an option is to avoid the replacement and/or runtime-assert on the equality.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130775
Approved by: https://github.com/avikchaudhuri
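
As an illustration (not part of the PR), here is a minimal sketch of exporting the reshape example above with dynamic shapes; the `Reshape` module and the `Dim` names are assumptions for this example, and the exact graph output varies across versions:

```
import torch
from torch.export import Dim, export

class Reshape(torch.nn.Module):
    def forward(self, x, y):
        # Succeeding here implicitly requires s0*s1 == s2; with
        # prefer_deferred_runtime_asserts_over_guards=True this is
        # recorded as a runtime assert (or a placeholder replacement)
        # instead of raising a constraint violation at export time.
        return x.reshape(-1) + y

s0, s1, s2 = Dim("s0"), Dim("s1"), Dim("s2")
ep = export(
    Reshape(),
    (torch.randn(4, 8), torch.randn(32)),
    dynamic_shapes={"x": {0: s0, 1: s1}, "y": {0: s2}},
)
# Expect Eq(s0*s1, s2) to show up as a runtime assert, and/or y's
# placeholder to carry the replaced shape [s0*s1].
print(ep.graph)
```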
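
For the `SymNode.guard_bool` case, a sketch of the kind of explicit branch the description refers to; the `Branch` module is hypothetical, and whether the call below raises depends on whether complex guards are deferred:

```
import torch
from torch.export import Dim, export

class Branch(torch.nn.Module):
    def forward(self, x):
        # Explicit shape-dependent control flow in user code: tracing
        # evaluates (s0 % 2 == 0) via SymNode.guard_bool and guards on
        # the branch that was taken.
        if x.shape[0] % 2 == 0:
            return x + 1
        return x - 1

# The guard Eq(Mod(s0, 2), 0) can't be expressed as a Dim constraint,
# so by default export may raise a constraint violation here; deferring
# complex guards as runtime asserts would record the check in the graph
# instead (at the cost of a non-trivial assert).
try:
    ep_branch = export(Branch(), (torch.randn(4, 3),), dynamic_shapes={"x": {0: Dim("s0")}})
except Exception as e:
    print(f"constraint violation: {e}")
```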
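
And a sketch of the follow-up concern, reusing `ep` from the first sketch: because `s2` is replaced by `s0*s1` in the placeholder rather than asserted, per the follow-up note a call with inconsistent shapes is not validated up front (the mismatched add may still fail inside the program, but nothing checks the replaced equality itself):

```
# y's placeholder carries the replacement shape [s0*s1]; the equality
# s0*s1 == s2 is not re-checked when the program runs, so this
# mismatched call (3*5 = 15 elements vs. 16) is not rejected up front.
out = ep.module()(torch.randn(3, 5), torch.randn(16))
```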