[FX] Ensure BC coverage for all of torch.fx.passes (#65081)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65081

Test Plan: Imported from OSS

Reviewed By: jbschlosser, khabinov

Differential Revision: D30967428

Pulled By: jamesr66a

fbshipit-source-id: 2ff83da728dc469f086cf504e71b43396db612d8
This commit is contained in:
James Reed
2021-09-17 09:26:37 -07:00
committed by Facebook GitHub Bot
parent cf7409e184
commit 0559cb37cd
10 changed files with 50 additions and 22 deletions

View File

@@ -5,8 +5,10 @@ import torch.fx
import torch.nn as nn
from torch.fx.graph import map_arg
from .tools_common import NodeList, NodeSet
from torch.fx._compatibility import compatibility
@compatibility(is_backward_compatible=False)
@dataclass
class Component:
"""
@@ -32,6 +34,7 @@ class Component:
gm: Optional[torch.fx.GraphModule] = None
@compatibility(is_backward_compatible=False)
class HolderModule(nn.Module):
"""
HolderModule is used to copy all the attributes from original module to submodules
@@ -44,6 +47,7 @@ class HolderModule(nn.Module):
self.add_module(k, v)
@compatibility(is_backward_compatible=False)
def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphModule:
"""
Splits a GraphModule using tags on its graph nodes. We honor the order of