Fix derived dim bugs in ep.run_decomp (#123326)

Differential Revision: [D55730289](https://our.internmc.facebook.com/intern/diff/D55730289)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123326
Approved by: https://github.com/avikchaudhuri
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2024-04-09 23:38:17 +00:00
committed by PyTorch MergeBot
parent cd3c1132a9
commit 4322874282
2 changed files with 31 additions and 14 deletions

View File

@@ -465,6 +465,16 @@ class ExportedProgram:
module.eval = types.MethodType(_eval, module) # type: ignore[method-assign]
return module
def _num_lifted_params_buffers(self):
return next(
(
i
for i, s in enumerate(self._graph_signature.input_specs)
if s.kind == InputKind.USER_INPUT
),
len(self._graph_signature.input_specs),
)
@_disable_prexisiting_fake_mode
def run_decompositions(
self, decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None
@@ -594,7 +604,11 @@ class ExportedProgram:
# (The node-level meta is addressed above.)
gm.meta.update(self.graph_module.meta)
new_range_constraints = _get_updated_range_constraints(gm)
new_range_constraints = _get_updated_range_constraints(
gm,
self._num_lifted_params_buffers(),
pytree.tree_leaves(self.example_inputs),
)
constants = lift_constants_pass(gm, new_graph_signature, ConstantAttrMap())
for k, v in constants.items():
@@ -613,7 +627,6 @@ class ExportedProgram:
verifier=self.verifier,
constants=self.constants,
)
if len(new_range_constraints) > 0:
exported_program = exported_program._transform_do_not_use(
_AddRuntimeAssertionsForInlineConstraintsPass(new_range_constraints)
@@ -700,7 +713,11 @@ class ExportedProgram:
self.graph_signature, transformed_gm
),
state_dict=self.state_dict,
range_constraints=_get_updated_range_constraints(transformed_gm),
range_constraints=_get_updated_range_constraints(
transformed_gm,
self._num_lifted_params_buffers(),
pytree.tree_leaves(self.example_inputs),
),
module_call_graph=copy.deepcopy(self._module_call_graph),
example_inputs=self.example_inputs,
verifier=self.verifier,
@@ -744,7 +761,7 @@ class ExportedProgram:
def _get_updated_range_constraints(
gm: torch.fx.GraphModule,
gm: torch.fx.GraphModule, num_lifted: int, example_inputs: List[Any]
) -> "Dict[sympy.Symbol, Any]":
def get_shape_env(gm):
vals = [
@@ -756,18 +773,21 @@ def _get_updated_range_constraints(
fake_mode = detect_fake_mode(vals)
if fake_mode is not None:
return fake_mode.shape_env
return fake_mode.shape_env, fake_mode
for v in vals:
if isinstance(v, torch.SymInt):
return v.node.shape_env
return v.node.shape_env, fake_mode
shape_env = get_shape_env(gm)
shape_env, fake_mode = get_shape_env(gm)
if shape_env is None:
return {}
from torch.export.dynamic_shapes import _process_constraints
range_constraints = _process_constraints(fake_mode, gm, num_lifted, example_inputs)
range_constraints = {
k: v
for k, v in shape_env.var_to_range.items()
if k not in shape_env.replacements
k: v for k, v in range_constraints.items() if k not in shape_env.replacements
}
# Only when we have an unbacked symint, and it's used as constructor inputs,
# runtime_var_to_range will make a difference compared to var_to_range.