mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145165 Approved by: https://github.com/bobrenjc93
		
			
				
	
	
		
			70 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			70 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import copy
 | |
| import typing
 | |
| 
 | |
| import torch
 | |
| from torch.export.exported_program import _decompose_exported_program
 | |
| 
 | |
| 
 | |
| def _copy_graph_module_and_signature(
 | |
|     ep: torch.fx.GraphModule,
 | |
| ) -> tuple[torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature]:
 | |
|     # copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(),
 | |
|     # and this can break placeholder names in some particular cases.
 | |
|     # For example, node copying will avoid Python keywords like 'input', suffixing and renaming to 'input_1'.
 | |
|     # So we manually overwrite placeholder names by reading the old graph.
 | |
|     gm = copy.deepcopy(ep.graph_module)
 | |
|     new_graph_signature = copy.deepcopy(ep.graph_signature)
 | |
| 
 | |
|     # iterate over old/new graph modules
 | |
|     for old_gm, new_gm in zip(ep.graph_module.modules(), gm.modules()):  # type: ignore[union-attr]
 | |
|         old_phs = [node for node in old_gm.graph.nodes if node.op == "placeholder"]
 | |
|         new_phs = [node for node in new_gm.graph.nodes if node.op == "placeholder"]
 | |
|         # iterate over placeholders
 | |
|         assert len(old_phs) == len(new_phs)
 | |
|         for old_node, new_node in zip(old_phs, new_phs):
 | |
|             new_node.name = old_node.name
 | |
| 
 | |
|     return gm, new_graph_signature  # type: ignore[return-value]
 | |
| 
 | |
| 
 | |
| def _remove_detach_pass(
 | |
|     gm: torch.fx.GraphModule, sig: torch.export.graph_signature.ExportGraphSignature
 | |
| ) -> None:
 | |
|     with gm._set_replace_hook(sig.get_replace_hook()):
 | |
|         for node in list(reversed(gm.graph.nodes)):
 | |
|             if node.op != "call_function":
 | |
|                 continue
 | |
|             if (
 | |
|                 node.target == torch.ops.aten.detach.default
 | |
|                 and len(node.users) == 1
 | |
|                 and next(iter(node.users)).target == torch.ops.aten.detach.default
 | |
|             ):
 | |
|                 next(iter(node.users)).replace_all_uses_with(node)
 | |
| 
 | |
|     gm.graph.eliminate_dead_code()
 | |
|     gm.recompile()
 | |
| 
 | |
| 
 | |
| def _export_forward_backward(
 | |
|     ep: torch.export.ExportedProgram, joint_loss_index: int = 0
 | |
| ) -> torch.export.ExportedProgram:
 | |
|     """
 | |
|     WARNING: This API is highly unstable and will be subject to change in the future.
 | |
|     """
 | |
|     from torch._decomp import core_aten_decompositions
 | |
| 
 | |
|     ep = _decompose_exported_program(
 | |
|         ep,
 | |
|         cia_to_decomp={},
 | |
|         python_decomp_table=core_aten_decompositions(),
 | |
|         joint_loss_index=joint_loss_index,
 | |
|         # For serialization purpose, we don't want to decompose custom triton ops.
 | |
|         # If users would like to decompose custom triton ops, they could do it
 | |
|         # with run_decompositions() API.
 | |
|         decompose_custom_triton_ops=False,
 | |
|     )
 | |
|     gm, new_graph_signature = _copy_graph_module_and_signature(ep)
 | |
|     _remove_detach_pass(gm, new_graph_signature)
 | |
| 
 | |
|     return ep._update(gm, new_graph_signature)
 |