mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Fix derived dim bugs in ep.run_decomp (#123326)
Differential Revision: [D55730289](https://our.internmc.facebook.com/intern/diff/D55730289) Pull Request resolved: https://github.com/pytorch/pytorch/pull/123326 Approved by: https://github.com/avikchaudhuri
This commit is contained in:
committed by
PyTorch MergeBot
parent
cd3c1132a9
commit
4322874282
@ -465,6 +465,16 @@ class ExportedProgram:
|
||||
module.eval = types.MethodType(_eval, module) # type: ignore[method-assign]
|
||||
return module
|
||||
|
||||
def _num_lifted_params_buffers(self):
|
||||
return next(
|
||||
(
|
||||
i
|
||||
for i, s in enumerate(self._graph_signature.input_specs)
|
||||
if s.kind == InputKind.USER_INPUT
|
||||
),
|
||||
len(self._graph_signature.input_specs),
|
||||
)
|
||||
|
||||
@_disable_prexisiting_fake_mode
|
||||
def run_decompositions(
|
||||
self, decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None
|
||||
@ -594,7 +604,11 @@ class ExportedProgram:
|
||||
# (The node-level meta is addressed above.)
|
||||
gm.meta.update(self.graph_module.meta)
|
||||
|
||||
new_range_constraints = _get_updated_range_constraints(gm)
|
||||
new_range_constraints = _get_updated_range_constraints(
|
||||
gm,
|
||||
self._num_lifted_params_buffers(),
|
||||
pytree.tree_leaves(self.example_inputs),
|
||||
)
|
||||
|
||||
constants = lift_constants_pass(gm, new_graph_signature, ConstantAttrMap())
|
||||
for k, v in constants.items():
|
||||
@ -613,7 +627,6 @@ class ExportedProgram:
|
||||
verifier=self.verifier,
|
||||
constants=self.constants,
|
||||
)
|
||||
|
||||
if len(new_range_constraints) > 0:
|
||||
exported_program = exported_program._transform_do_not_use(
|
||||
_AddRuntimeAssertionsForInlineConstraintsPass(new_range_constraints)
|
||||
@ -700,7 +713,11 @@ class ExportedProgram:
|
||||
self.graph_signature, transformed_gm
|
||||
),
|
||||
state_dict=self.state_dict,
|
||||
range_constraints=_get_updated_range_constraints(transformed_gm),
|
||||
range_constraints=_get_updated_range_constraints(
|
||||
transformed_gm,
|
||||
self._num_lifted_params_buffers(),
|
||||
pytree.tree_leaves(self.example_inputs),
|
||||
),
|
||||
module_call_graph=copy.deepcopy(self._module_call_graph),
|
||||
example_inputs=self.example_inputs,
|
||||
verifier=self.verifier,
|
||||
@ -744,7 +761,7 @@ class ExportedProgram:
|
||||
|
||||
|
||||
def _get_updated_range_constraints(
|
||||
gm: torch.fx.GraphModule,
|
||||
gm: torch.fx.GraphModule, num_lifted: int, example_inputs: List[Any]
|
||||
) -> "Dict[sympy.Symbol, Any]":
|
||||
def get_shape_env(gm):
|
||||
vals = [
|
||||
@ -756,18 +773,21 @@ def _get_updated_range_constraints(
|
||||
|
||||
fake_mode = detect_fake_mode(vals)
|
||||
if fake_mode is not None:
|
||||
return fake_mode.shape_env
|
||||
return fake_mode.shape_env, fake_mode
|
||||
for v in vals:
|
||||
if isinstance(v, torch.SymInt):
|
||||
return v.node.shape_env
|
||||
return v.node.shape_env, fake_mode
|
||||
|
||||
shape_env = get_shape_env(gm)
|
||||
shape_env, fake_mode = get_shape_env(gm)
|
||||
if shape_env is None:
|
||||
return {}
|
||||
|
||||
from torch.export.dynamic_shapes import _process_constraints
|
||||
|
||||
range_constraints = _process_constraints(fake_mode, gm, num_lifted, example_inputs)
|
||||
|
||||
range_constraints = {
|
||||
k: v
|
||||
for k, v in shape_env.var_to_range.items()
|
||||
if k not in shape_env.replacements
|
||||
k: v for k, v in range_constraints.items() if k not in shape_env.replacements
|
||||
}
|
||||
# Only when we have an unbacked symint, and it's used as constructor inputs,
|
||||
# runtime_var_to_range will make a difference compated to var_to_range.
|
||||
|
Reference in New Issue
Block a user