mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[FX] Ensure BC coverage for all of torch.fx.passes (#65081)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65081 Test Plan: Imported from OSS Reviewed By: jbschlosser, khabinov Differential Revision: D30967428 Pulled By: jamesr66a fbshipit-source-id: 2ff83da728dc469f086cf504e71b43396db612d8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
cf7409e184
commit
0559cb37cd
@ -5,8 +5,10 @@ import torch.fx
|
||||
import torch.nn as nn
|
||||
from torch.fx.graph import map_arg
|
||||
from .tools_common import NodeList, NodeSet
|
||||
from torch.fx._compatibility import compatibility
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
@dataclass
|
||||
class Component:
|
||||
"""
|
||||
@ -32,6 +34,7 @@ class Component:
|
||||
gm: Optional[torch.fx.GraphModule] = None
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
class HolderModule(nn.Module):
|
||||
"""
|
||||
HolderModule is used to copy all the attributes from original module to submodules
|
||||
@ -44,6 +47,7 @@ class HolderModule(nn.Module):
|
||||
self.add_module(k, v)
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphModule:
|
||||
"""
|
||||
Splits a GraphModule using tags on its graph nodes. We honor the order of
|
||||
|
Reference in New Issue
Block a user