Fix XPU CI UT test_circular_dependencies (#158189)

# Motivation
fix https://github.com/pytorch/pytorch/issues/110040

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158189
Approved by: https://github.com/Skylion007, https://github.com/cyyever
This commit is contained in:
Yu, Guangye
2025-07-13 00:09:24 +00:00
committed by PyTorch MergeBot
parent 5aee022d8b
commit c68af9af1b
3 changed files with 7 additions and 3 deletions

View File

@ -2369,6 +2369,7 @@ class TestImports(TestCase):
"torch.distributed.benchmarks", # depends on RPC and DDP Optim
"torch.distributed.examples", # requires CUDA and torchvision
"torch.distributed.tensor.examples", # example scripts
"torch.distributed._tools.sac_ilp", # depends on pulp
"torch.csrc", # files here are devtools, not part of torch
"torch.include", # torch include files after install
]

View File

@ -6,7 +6,6 @@ import sympy
import torch
from .ir import Pointwise, ShapeAsConstantBuffer, TensorBox
from .lowering import fallback_handler, is_integer_type, register_lowering
from .virtualized import ops
@ -109,6 +108,9 @@ def jagged_idx_to_dense_idx(
def register_jagged_ops():
# Avoid circular import by importing here
from .lowering import fallback_handler, is_integer_type, register_lowering
# pyre-ignore[56]
@register_lowering(torch.ops.aten._jagged_to_padded_dense_forward.default)
def _jagged_to_padded_dense_forward(

View File

@ -13,6 +13,7 @@ from torch.utils._ordered_set import OrderedSet
from . import ir
from .exc import SubgraphLoweringException
from .graph import GraphLowering
from .ops_handler import SimpleCSEHandler
from .virtualized import ops, V, WrapperHandler
@ -32,7 +33,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
"""
graph_outputs: Optional[list[ir.IRNode]]
root_graph: torch._inductor.graph.GraphLowering
root_graph: GraphLowering
_current_op: Optional[TargetType]
# For backwards of buffer_grads with scatters we allow mutations
allowed_mutations: Optional[OrderedSet[OpOverload]]
@ -43,7 +44,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
def __init__(
self,
gm: torch.fx.GraphModule,
root_graph_lowering: torch._inductor.graph.GraphLowering,
root_graph_lowering: GraphLowering,
allowed_mutations: Optional[OrderedSet[OpOverload]] = None,
additional_lowerings: Optional[LoweringDict] = None,
) -> None: