mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fix XPU CI UT test_circular_dependencies (#158189)
# Motivation fix https://github.com/pytorch/pytorch/issues/110040 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158189 Approved by: https://github.com/Skylion007, https://github.com/cyyever
This commit is contained in:
committed by
PyTorch MergeBot
parent
5aee022d8b
commit
c68af9af1b
@ -2369,6 +2369,7 @@ class TestImports(TestCase):
|
||||
"torch.distributed.benchmarks", # depends on RPC and DDP Optim
|
||||
"torch.distributed.examples", # requires CUDA and torchvision
|
||||
"torch.distributed.tensor.examples", # example scripts
|
||||
"torch.distributed._tools.sac_ilp", # depends on pulp
|
||||
"torch.csrc", # files here are devtools, not part of torch
|
||||
"torch.include", # torch include files after install
|
||||
]
|
||||
|
@ -6,7 +6,6 @@ import sympy
|
||||
import torch
|
||||
|
||||
from .ir import Pointwise, ShapeAsConstantBuffer, TensorBox
|
||||
from .lowering import fallback_handler, is_integer_type, register_lowering
|
||||
from .virtualized import ops
|
||||
|
||||
|
||||
@ -109,6 +108,9 @@ def jagged_idx_to_dense_idx(
|
||||
|
||||
|
||||
def register_jagged_ops():
|
||||
# Avoid circular import by importing here
|
||||
from .lowering import fallback_handler, is_integer_type, register_lowering
|
||||
|
||||
# pyre-ignore[56]
|
||||
@register_lowering(torch.ops.aten._jagged_to_padded_dense_forward.default)
|
||||
def _jagged_to_padded_dense_forward(
|
||||
|
@ -13,6 +13,7 @@ from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from . import ir
|
||||
from .exc import SubgraphLoweringException
|
||||
from .graph import GraphLowering
|
||||
from .ops_handler import SimpleCSEHandler
|
||||
from .virtualized import ops, V, WrapperHandler
|
||||
|
||||
@ -32,7 +33,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
|
||||
"""
|
||||
|
||||
graph_outputs: Optional[list[ir.IRNode]]
|
||||
root_graph: torch._inductor.graph.GraphLowering
|
||||
root_graph: GraphLowering
|
||||
_current_op: Optional[TargetType]
|
||||
# For backwards of buffer_grads with scatters we allow mutations
|
||||
allowed_mutations: Optional[OrderedSet[OpOverload]]
|
||||
@ -43,7 +44,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
|
||||
def __init__(
|
||||
self,
|
||||
gm: torch.fx.GraphModule,
|
||||
root_graph_lowering: torch._inductor.graph.GraphLowering,
|
||||
root_graph_lowering: GraphLowering,
|
||||
allowed_mutations: Optional[OrderedSet[OpOverload]] = None,
|
||||
additional_lowerings: Optional[LoweringDict] = None,
|
||||
) -> None:
|
||||
|
Reference in New Issue
Block a user