diff --git a/test/test_testing.py b/test/test_testing.py index d24329c825ba..a69fb8ac9532 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -2369,6 +2369,7 @@ class TestImports(TestCase): "torch.distributed.benchmarks", # depends on RPC and DDP Optim "torch.distributed.examples", # requires CUDA and torchvision "torch.distributed.tensor.examples", # example scripts + "torch.distributed._tools.sac_ilp", # depends on pulp "torch.csrc", # files here are devtools, not part of torch "torch.include", # torch include files after install ] diff --git a/torch/_inductor/jagged_lowerings.py b/torch/_inductor/jagged_lowerings.py index 5d4e17ed538a..83848c5a9612 100644 --- a/torch/_inductor/jagged_lowerings.py +++ b/torch/_inductor/jagged_lowerings.py @@ -6,7 +6,6 @@ import sympy import torch from .ir import Pointwise, ShapeAsConstantBuffer, TensorBox -from .lowering import fallback_handler, is_integer_type, register_lowering from .virtualized import ops @@ -109,6 +108,9 @@ def jagged_idx_to_dense_idx( def register_jagged_ops(): + # Avoid circular import by importing here + from .lowering import fallback_handler, is_integer_type, register_lowering + # pyre-ignore[56] @register_lowering(torch.ops.aten._jagged_to_padded_dense_forward.default) def _jagged_to_padded_dense_forward( diff --git a/torch/_inductor/subgraph_lowering.py b/torch/_inductor/subgraph_lowering.py index d79923857359..3c8116d402c9 100644 --- a/torch/_inductor/subgraph_lowering.py +++ b/torch/_inductor/subgraph_lowering.py @@ -13,6 +13,7 @@ from torch.utils._ordered_set import OrderedSet from . import ir from .exc import SubgraphLoweringException +from .graph import GraphLowering from .ops_handler import SimpleCSEHandler from .virtualized import ops, V, WrapperHandler @@ -32,7 +33,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter): """ graph_outputs: Optional[list[ir.IRNode]] - root_graph: torch._inductor.graph.GraphLowering + root_graph: GraphLowering _current_op: Optional[TargetType] # For backwards of buffer_grads with scatters we allow mutations allowed_mutations: Optional[OrderedSet[OpOverload]] @@ -43,7 +44,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter): def __init__( self, gm: torch.fx.GraphModule, - root_graph_lowering: torch._inductor.graph.GraphLowering, + root_graph_lowering: GraphLowering, allowed_mutations: Optional[OrderedSet[OpOverload]] = None, additional_lowerings: Optional[LoweringDict] = None, ) -> None: