Fix XPU CI UT test_circular_dependencies (#158189)

# Motivation
Fixes https://github.com/pytorch/pytorch/issues/110040.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158189
Approved by: https://github.com/Skylion007, https://github.com/cyyever
This commit is contained in:
Yu, Guangye
2025-07-13 00:09:24 +00:00
committed by PyTorch MergeBot
parent 5aee022d8b
commit c68af9af1b
3 changed files with 7 additions and 3 deletions

View File

@@ -2369,6 +2369,7 @@ class TestImports(TestCase):
 "torch.distributed.benchmarks", # depends on RPC and DDP Optim
 "torch.distributed.examples", # requires CUDA and torchvision
 "torch.distributed.tensor.examples", # example scripts
+"torch.distributed._tools.sac_ilp", # depends on pulp
 "torch.csrc", # files here are devtools, not part of torch
 "torch.include", # torch include files after install
 ]

View File

@@ -6,7 +6,6 @@ import sympy
 import torch
 from .ir import Pointwise, ShapeAsConstantBuffer, TensorBox
-from .lowering import fallback_handler, is_integer_type, register_lowering
 from .virtualized import ops
@@ -109,6 +108,9 @@ def jagged_idx_to_dense_idx(
 def register_jagged_ops():
+    # Avoid circular import by importing here
+    from .lowering import fallback_handler, is_integer_type, register_lowering
 # pyre-ignore[56]
 @register_lowering(torch.ops.aten._jagged_to_padded_dense_forward.default)
 def _jagged_to_padded_dense_forward(

View File

@@ -13,6 +13,7 @@ from torch.utils._ordered_set import OrderedSet
 from . import ir
 from .exc import SubgraphLoweringException
+from .graph import GraphLowering
 from .ops_handler import SimpleCSEHandler
 from .virtualized import ops, V, WrapperHandler
@@ -32,7 +33,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
 """
 graph_outputs: Optional[list[ir.IRNode]]
-root_graph: torch._inductor.graph.GraphLowering
+root_graph: GraphLowering
 _current_op: Optional[TargetType]
 # For backwards of buffer_grads with scatters we allow mutations
 allowed_mutations: Optional[OrderedSet[OpOverload]]
@@ -43,7 +44,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
 def __init__(
 self,
 gm: torch.fx.GraphModule,
-root_graph_lowering: torch._inductor.graph.GraphLowering,
+root_graph_lowering: GraphLowering,
 allowed_mutations: Optional[OrderedSet[OpOverload]] = None,
 additional_lowerings: Optional[LoweringDict] = None,
 ) -> None: