Compare commits

10 Commits

SHA1 Message Date
de4bd2b3a4 Update 2025-11-03 14:47:02 -08:00
0571524a0e Update 2025-11-03 13:04:11 -08:00
55f9503b47 Update 2025-11-03 12:07:47 -08:00
e122994d51 Update 2025-11-03 12:05:34 -08:00
7f9450a68c Update 2025-11-03 10:46:42 -08:00
6e0311b37e Change python doc push script to print the undocumented modules 2025-11-03 09:08:04 -08:00
20f8edab38 Update 2025-11-03 09:08:04 -08:00
3ef57af18f Test 2025-11-03 09:08:03 -08:00
104b868618 Fix build error by checking cuda version in CUDAGreenContext (#166800)
Fixes #166799
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166800
Approved by: https://github.com/mlazos, https://github.com/eqy, https://github.com/malfet
2025-11-03 16:41:38 +00:00
94f2657c4b [Inductor] addmm with bias -> unfuse bias if there is a pointwise/reduction consumer (#166165)
Prefer unfused addmm when there is at least one elementwise/reduction consumer.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166165
Approved by: https://github.com/eellison
2025-11-03 15:50:32 +00:00
15 changed files with 192 additions and 72 deletions

View File

@@ -1,15 +1,11 @@
-sphinx==5.3.0
+sphinx==7.2.6
 #Description: This is used to generate PyTorch docs
-#Pinned versions: 5.3.0
+#Pinned versions: 7.2.6
-standard-imghdr==3.13.0; python_version >= "3.13"
-#Description: This is needed by Sphinx, so it needs to be added here.
-# The reasons are as follows:
-# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
-# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
-# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
--e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
+pytorch_sphinx_theme2==0.2.0
+#Description: This is needed to generate PyTorch docs
+#Pinned versions: 0.2.0
 # TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
 # but it doesn't seem to work and hangs around idly. The initial thought that it is probably
 # something related to Docker setup. We can investigate this later.

@@ -36,17 +32,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
 #Description: This is used to generate PyTorch docs
 #Pinned versions: 2.13.0
-breathe==4.34.0
+breathe==4.36.0
 #Description: This is used to generate PyTorch C++ docs
-#Pinned versions: 4.34.0
+#Pinned versions: 4.36.0
-exhale==0.2.3
+exhale==0.3.7
 #Description: This is used to generate PyTorch C++ docs
-#Pinned versions: 0.2.3
+#Pinned versions: 0.3.7
-docutils==0.16
+docutils==0.20
 #Description: This is used to generate PyTorch C++ docs
-#Pinned versions: 0.16
+#Pinned versions: 0.20
 bs4==0.0.1
 #Description: This is used to generate PyTorch C++ docs

@@ -56,13 +52,13 @@ IPython==8.12.0
 #Description: This is used to generate PyTorch functorch docs
 #Pinned versions: 8.12.0
-myst-nb==0.17.2
+myst-nb==1.3.0
 #Description: This is used to generate PyTorch functorch and torch.compile docs.
-#Pinned versions: 0.17.2
+#Pinned versions: 1.3.0
 # The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
 python-etcd==0.4.5
 sphinx-copybutton==0.5.0
-sphinx-design==0.4.0
+sphinx-design==0.6.1
 sphinxcontrib-mermaid==1.0.0
-myst-parser==0.18.1
+myst-parser==4.0.1

View File

@@ -89,20 +89,23 @@ if [ "$is_main_doc" = true ]; then
   make coverage
   # Now we have the coverage report, we need to make sure it is empty.
-  # Count the number of lines in the file and turn that number into a variable
-  # $lines. The `cut -f1 ...` is to only parse the number, not the filename
-  # Skip the report header by subtracting 2: the header will be output even if
-  # there are no undocumented items.
+  # Sphinx 7.2.6+ format: python.txt contains a statistics table with a TOTAL row
+  # showing the undocumented count in the third column.
+  # Example: | TOTAL | 99.83% | 2 |
   #
   # Also: see docs/source/conf.py for "coverage_ignore*" items, which should
   # be documented then removed from there.
-  lines=$(wc -l build/coverage/python.txt 2>/dev/null |cut -f1 -d' ')
-  undocumented=$((lines - 2))
-  if [ $undocumented -lt 0 ]; then
+  # Extract undocumented count from TOTAL row in Sphinx 7.2.6 statistics table
+  # The table format is: | Module | Coverage | Undocumented |
+  # Extract the third column (undocumented count) from the TOTAL row
+  undocumented=$(grep "| TOTAL" build/coverage/python.txt | awk -F'|' '{print $4}' | tr -d ' ')
+  if [ -z "$undocumented" ] || ! [[ "$undocumented" =~ ^[0-9]+$ ]]; then
     echo coverage output not found
     exit 1
-  elif [ $undocumented -gt 0 ]; then
-    echo undocumented objects found:
+  elif [ "$undocumented" -gt 0 ]; then
+    echo "undocumented objects found:"
     cat build/coverage/python.txt
     echo "Make sure you've updated relevant .rsts in docs/source!"
     echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"

View File

@@ -1,6 +1,6 @@
 #include <ATen/cuda/CUDAGreenContext.h>
-#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
+#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
 #include <c10/cuda/driver_api.h>
 #include <stdexcept>
 #include <vector>

View File

@@ -207,6 +207,56 @@ templates_path = [
 ]
 # TODO: document these and remove them from here.

+# coverage_ignore_modules uses regex patterns to match module names
+# These modules will be completely ignored by the coverage checker
+coverage_ignore_modules = [
+    # Internal/private modules (anything starting with _)
+    r".*\._.*",
+    # Experimental modules that are not yet stable
+    r"torch\.fx\.experimental\..*",
+    r"torch\.distributed\.elastic\..*",
+    # Test and example modules
+    r".*\.test_.*",
+    r".*\.tests\..*",
+    r".*\.examples\..*",
+]
+
+# Fixes the duplicated
+autosummary_filename_map = {
+    "torch.nn.utils.prune.identity": "torch.nn.utils.prune.identity_function",
+    "torch.nn.utils.prune.Identity": "torch.nn.utils.prune.Identity_class",
+    "torch.optim.adamw.adamw": "torch.optim.adamw.adamw_function",
+    "torch.optim.adamw.AdamW": "torch.optim.adamw.AdamW_class",
+    "torch.optim.asgd.asgd": "torch.optim.asgd.asgd_function",
+    "torch.optim.asgd.ASGD": "torch.optim.asgd.ASGD_class",
+    "torch.optim.nadam.nadam": "torch.optim.nadam.nadam_function",
+    "torch.optim.nadam.NAdam": "torch.optim.nadam.NAdam_class",
+    "torch.optim.radam.radam": "torch.optim.radam.radam_function",
+    "torch.optim.radam.RAdam": "torch.optim.radam.RAdam_class",
+    "torch.optim.rmsprop.rmsprop": "torch.optim.rmsprop.rmsprop_function",
+    "torch.optim.rmsprop.RMSprop": "torch.optim.rmsprop.RMSprop_class",
+    "torch.optim.rprop.rprop": "torch.optim.rprop.rprop_function",
+    "torch.optim.rprop.Rprop": "torch.optim.rprop.Rprop_class",
+    "torch.optim.sgd.sgd": "torch.optim.sgd.sgd_function",
+    "torch.optim.sgd.SGD": "torch.optim.sgd.SGD_class",
+    "torch.optim.adadelta.adadelta": "torch.optim.adadelta.adadelta_function",
+    "torch.optim.adadelta.Adadelta": "torch.optim.adadelta.Adadelta_class",
+    "torch.optim.adagrad.adagrad": "torch.optim.adagrad.adagrad_function",
+    "torch.optim.adagrad.Adagrad": "torch.optim.adagrad.Adagrad_class",
+    "torch.optim.adam.adam": "torch.optim.adam.adam_function",
+    "torch.optim.adam.Adam": "torch.optim.adam.Adam_class",
+    "torch.optim.adamax.adamax": "torch.optim.adamax.adamax_function",
+    "torch.optim.adamax.Adamax": "torch.optim.adamax.Adamax_class",
+    "torch.mtia.stream": "torch.mtia.stream_function",
+    "torch.mtia.Stream": "torch.mtia.Stream_class",
+    "torch.cpu.stream": "torch.cpu.stream_function",
+    "torch.cpu.Stream": "torch.cpu.Stream_class",
+    "torch.cuda.stream": "torch.cuda.stream_function",
+    "torch.cuda.Stream": "torch.cuda.Stream_class",
+    "torch.xpu.stream": "torch.xpu.stream_function",
+    "torch.xpu.Stream": "torch.xpu.Stream_class",
+}
+
 coverage_ignore_functions = [
     # torch
     "typename",

@@ -3195,6 +3245,11 @@ autodoc_type_aliases = {
 # Enable overriding of function signatures in the first line of the docstring.
 autodoc_docstring_signature = True

+# Exclude inherited IntEnum methods that have RST formatting issues in their docstrings
+autodoc_default_options = {
+    "exclude-members": "from_bytes, to_bytes",
+}
+
 # -- katex javascript in header
 #
 # def setup(app):
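To make the intent of the new coverage_ignore_modules list concrete, here is a small standalone sketch of regex matching against module names. How exactly sphinx.ext.coverage anchors these patterns is an implementation detail; the example uses re.fullmatch purely for illustration.

# Illustration only: which module names do the ignore patterns above cover?
import re

coverage_ignore_modules = [
    r".*\._.*",                          # private modules
    r"torch\.fx\.experimental\..*",      # experimental trees
    r"torch\.distributed\.elastic\..*",
    r".*\.test_.*",                      # tests and examples
    r".*\.tests\..*",
    r".*\.examples\..*",
]

def is_ignored(module_name: str) -> bool:
    return any(re.fullmatch(p, module_name) for p in coverage_ignore_modules)

assert is_ignored("torch.fx.experimental.symbolic_shapes")
assert is_ignored("torch._dynamo.convert_frame")   # matches the private-module pattern
assert not is_ignored("torch.nn.functional")

The autosummary_filename_map added above serves a different purpose: autosummary writes one page per entry, and pairs such as torch.optim.adamw.adamw and torch.optim.adamw.AdamW would otherwise collide on case-insensitive filesystems, so each gets an explicit, distinct output filename.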

View File

@@ -253,7 +253,6 @@ regular full-precision tensor.
 .. autosummary::
     :toctree: generated
     :nosignatures:
-    :template: classtemplate.rst

     view
     as_strided

View File

@@ -4,7 +4,6 @@ import os
 import tempfile
 from threading import Event

-import torch._inductor.config as config
 from torch._inductor.compile_worker.subproc_pool import (
     raise_testexc,
     SubprocException,

@@ -17,12 +16,9 @@ from torch.testing._internal.inductor_utils import HAS_CPU

 class TestCompileWorker(TestCase):
-    def make_pool(self, size):
-        return SubprocPool(size)
-
     @skipIfWindows(msg="pass_fds not supported on Windows.")
     def test_basic_jobs(self):
-        pool = self.make_pool(2)
+        pool = SubprocPool(2)
         try:
             a = pool.submit(operator.add, 100, 1)
             b = pool.submit(operator.sub, 100, 1)

@@ -33,7 +29,7 @@ class TestCompileWorker(TestCase):
     @skipIfWindows(msg="pass_fds not supported on Windows.")
     def test_exception(self):
-        pool = self.make_pool(2)
+        pool = SubprocPool(2)
         try:
             a = pool.submit(raise_testexc)
             with self.assertRaisesRegex(

@@ -46,7 +42,7 @@ class TestCompileWorker(TestCase):
     @skipIfWindows(msg="pass_fds not supported on Windows.")
     def test_crash(self):
-        pool = self.make_pool(2)
+        pool = SubprocPool(2)
         try:
             with self.assertRaises(Exception):
                 a = pool.submit(os._exit, 1)

@@ -62,7 +58,7 @@ class TestCompileWorker(TestCase):
     @skipIfWindows(msg="pass_fds not supported on Windows.")
     def test_quiesce(self):
-        pool = self.make_pool(2)
+        pool = SubprocPool(2)
         try:
             a = pool.submit(operator.add, 100, 1)
             pool.quiesce()

@@ -79,7 +75,7 @@ class TestCompileWorker(TestCase):
         os.environ["ROLE_RANK"] = "0"
         with tempfile.NamedTemporaryFile(delete=True) as temp_log:
             os.environ["TORCHINDUCTOR_WORKER_LOGPATH"] = temp_log.name
-            pool = self.make_pool(2)
+            pool = SubprocPool(2)
             try:
                 pool.submit(operator.add, 100, 1)
                 self.assertEqual(os.path.exists(temp_log.name), True)

@@ -87,12 +83,6 @@ class TestCompileWorker(TestCase):
                 pool.shutdown()

-@config.patch("quiesce_async_compile_time", 0.1)
-class TestCompileWorkerWithTimer(TestCompileWorker):
-    def make_pool(self, size):
-        return SubprocPool(size, quiesce=True)
-
-
 class TestTimer(TestCase):
     def test_basics(self):
         done = Event()
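The simplified tests above construct the pool directly instead of going through a make_pool helper. Their basic lifecycle, using only the names visible in this diff (SubprocPool, submit, quiesce, shutdown), is roughly:

# Rough lifecycle exercised by the tests above; argument values are illustrative.
import operator

from torch._inductor.compile_worker.subproc_pool import SubprocPool

pool = SubprocPool(2)  # sidecar process backing a 2-worker ProcessPoolExecutor
try:
    a = pool.submit(operator.add, 100, 1)
    b = pool.submit(operator.sub, 100, 1)
    assert a.result() == 101 and b.result() == 99
    pool.quiesce()   # ask the sidecar to shut down its executor (workers released)
finally:
    pool.shutdown()  # tear down the sidecar and its pipes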

View File

@@ -15280,7 +15280,7 @@ if RUN_GPU:
             ),
             (
                 fn3,
-                "triton_poi_fused_native_layer_norm_relu",
+                "triton_poi_fused_addmm_native_layer_norm",
                 (torch.randn(4, 4, device=GPU_TYPE),),
             ),
         ]

@@ -15293,7 +15293,7 @@
             ),
             (
                 fn3,
-                "triton_poi_fused_LayerNorm_ReLU",
+                "triton_poi_fused_LayerNorm_Linear_ReLU",
                 (torch.randn(4, 4, device=GPU_TYPE),),
             ),
         ]

View File

@@ -1282,6 +1282,7 @@ def _compile(
     # in the case of normal and exception code paths
     convert_frame_box: Optional[ConvertFrameBox] = None,
 ) -> ConvertFrameReturn:
+    from torch._inductor.async_compile import async_compile_pool_manager
     from torch.fx.experimental.validator import (
         BisectValidationException,
         ValidationException,

@@ -1475,6 +1476,7 @@ def _compile(
     with (
         _use_lazy_graph_module(config.use_lazy_graph_module),
         compile_context(CompileContext(compile_id)),
+        async_compile_pool_manager(),
         chromium_event_timed(
             "dynamo", reset_event_log_on_exit=True, log_pt2_compile_event=True
         ),
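The with (...) block being extended here is the parenthesized multi-context form available since Python 3.10; async_compile_pool_manager() simply joins the stack so its finally branch runs when the whole dynamo compile exits, on both the normal and the exception path. A generic sketch of that pattern with placeholder context managers (not dynamo internals):

# Placeholder context managers; only the structure mirrors the diff above.
from contextlib import contextmanager

@contextmanager
def pool_manager():
    try:
        yield
    finally:
        print("quiesce worker pool")  # runs even if compilation raised

@contextmanager
def timed(name):
    print(f"start {name}")
    try:
        yield
    finally:
        print(f"end {name}")

with (
    timed("dynamo"),
    pool_manager(),
):
    print("compiling...")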

View File

@@ -2365,6 +2365,8 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
     @staticmethod
     def _backward_impl(ctx, all_args):
+        from torch._inductor.async_compile import async_compile_pool_manager
+
         # compiled autograd reimplements this function at proxy_call_aot_backward
         assert not backward_state_indices, (
             "BackwardState requires CompiledAutograd"

@@ -2444,6 +2446,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
             with (
                 tracing(saved_context),
                 compile_context(saved_compile_context),
+                async_compile_pool_manager(),
                 context(),
                 track_graph_compiling(aot_config, "backward"),
                 metrics_context,

View File

@@ -228,6 +228,18 @@ class CompiledTritonKernels:
         del CompiledTritonKernels._cache[key]


+@contextlib.contextmanager
+def async_compile_pool_manager():
+    """
+    Context manager to quiesce the subproc pool at the end of compilation, i.e.,
+    when dynamo is done.
+    """
+    try:
+        yield
+    finally:
+        AsyncCompile.quiesce()
+
+
 class AsyncCompile:
     """
     Utilities to compile in thread pools or subprocess pools (in the case of Triton).

@@ -263,9 +275,7 @@ class AsyncCompile:
         pool: AnyPool
         if config.worker_start_method == "subprocess":
             # Wrapper around ProcessPoolExecutor forks in a new process we control
-            pool = SubprocPool(
-                get_compile_threads(), quiesce = config.quiesce_async_compile_pool
-            )
+            pool = SubprocPool(get_compile_threads())
         else:
             if config.worker_start_method == "spawn":
                 # Avoid creating pools in the spawned subprocs themselves:

@@ -321,6 +331,20 @@ class AsyncCompile:
             cls._ready_future = cls.process_pool().submit(cls._get_ready)
         return cls._ready_future.done()

+    @classmethod
+    def quiesce(cls) -> None:
+        """
+        If using a SubprocPool, signal the sidecar process to shut down its
+        ProcessPoolExecutor.
+        """
+        # Don't inadvertently create a process pool if it doesn't already exist:
+        if not cls.process_pool.cache_info().currsize:
+            return
+        if config.quiesce_async_compile_pool:
+            pool = cls.process_pool()
+            if isinstance(pool, SubprocPool):
+                pool.quiesce()
+
     @classmethod
     def wakeup(cls) -> None:
         """

View File

@@ -23,7 +23,6 @@ from typing_extensions import Never, ParamSpec
 import torch._thread_safe_fork # noqa: F401
 from torch._inductor import config
 from torch._inductor.codecache import torch_key
-from torch._inductor.compile_worker.timer import Timer
 from torch._inductor.compile_worker.tracked_process_pool import (
     TrackedProcessPoolExecutor,
 )

@@ -132,7 +131,6 @@ class SubprocPool:
         nprocs: int,
         pickler: Optional[SubprocPickler] = None,
         kind: SubprocKind = SubprocKind.FORK,
-        quiesce: bool = False,
     ) -> None:
         entry = os.path.join(os.path.dirname(__file__), "__main__.py")
         self.pickler = pickler or SubprocPickler()

@@ -217,13 +215,6 @@ class SubprocPool:
             "pytorch.wait_counter.subproc_pool.first_job"
         ).guard()

-        if quiesce:
-            self.timer: Optional[Timer] = Timer(
-                config.quiesce_async_compile_time, self.quiesce
-            )
-        else:
-            self.timer = None
-
         # Start thread last to ensure all member variables are initialized
         # before any access.
         self.read_thread.start()

@@ -296,8 +287,6 @@ class SubprocPool:
         with self.futures_lock:
             if not self.running:
                 return
-            if self.timer:
-                self.timer.record_call()
             if isinstance(result, _SubprocExceptionInfo):
                 # An exception occurred in the submitted job
                 self.pending_futures[job_id].set_exception(

@@ -332,8 +321,6 @@ class SubprocPool:
         with self.write_lock:
             if not self.running:
                 return
-            if self.timer:
-                self.timer.quit()
             self.running = False
             self.running_waitcounter.__exit__()
             _send_msg(self.write_pipe, MsgHeader.SHUTDOWN)

View File

@@ -17,7 +17,7 @@ class Timer:
         self.background_thread: Optional[Thread] = None
         self.last_called: Optional[float] = None
         self.duration = duration
-        self.sleep_time = duration / 2
+        self.sleep_time = 60
         self.call = call
         self.exit = False
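For context on this change: Timer appears to be an idle-timeout helper whose background thread wakes up every sleep_time seconds and invokes call once duration seconds have elapsed since the last record_call(); the diff pins the polling interval to 60 seconds instead of deriving it from duration. A self-contained sketch of that pattern (an illustration of the idea, not the torch._inductor.compile_worker.timer source):

# Self-contained sketch of an idle-timeout timer (illustrative, not the real Timer).
import time
from threading import Thread
from typing import Callable, Optional

class IdleTimer:
    def __init__(self, duration: float, call: Callable[[], None], sleep_time: float = 60) -> None:
        self.duration = duration        # fire `call` after this much idle time
        self.sleep_time = sleep_time    # how often the background thread checks
        self.call = call
        self.last_called: Optional[float] = None
        self.exit = False
        self.background_thread: Optional[Thread] = None

    def record_call(self) -> None:
        self.last_called = time.time()
        if self.background_thread is None:
            self.background_thread = Thread(target=self._loop, daemon=True)
            self.background_thread.start()

    def _loop(self) -> None:
        while not self.exit:
            time.sleep(self.sleep_time)
            if self.last_called is not None and time.time() - self.last_called > self.duration:
                self.call()
                self.last_called = None

    def quit(self) -> None:
        self.exit = True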

View File

@@ -960,11 +960,6 @@ quiesce_async_compile_pool: bool = Config(
     default=False,
 )

-# Time in seconds to wait before quiescing
-quiesce_async_compile_time: int = Config(
-    default=60,
-)
-
 # Whether or not to enable statically launching CUDA kernels
 # compiled by triton (instead of using triton's own launcher)
 use_static_cuda_launcher: bool = static_cuda_launcher_default()

View File

@@ -51,8 +51,8 @@ from ..utils import (
     decode_device,
     get_all_devices,
     get_gpu_type,
+    has_uses_tagged_as,
     is_gpu,
-    is_pointwise_use,
     OPTIMUS_EXCLUDE_POST_GRAD,
 )
 from ..virtualized import V

@@ -1510,8 +1510,10 @@ def should_prefer_unfused_addmm(match):
     if not is_gpu(inp.meta["val"].device.type):
         return False

-    output = match.output_node()
-    return all(is_pointwise_use(use) for use in output.users)
+    return has_uses_tagged_as(
+        match.output_node(),
+        (torch.Tag.pointwise, torch.Tag.reduction),
+    )


 @register_graph_pattern(
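The heuristic change above broadens should_prefer_unfused_addmm: instead of requiring every user of the addmm output to be pointwise, it now asks whether any consumer (following through views) is tagged pointwise or reduction, so the bias add can be unfused and fused into that consumer instead. The kind of graph this targets is sketched below; it mirrors the fn3 pattern from the test diff earlier (Linear -> LayerNorm -> ReLU on a small GPU tensor), though actual kernel names and fusion decisions depend on the Inductor version and device.

# Sketch of the pattern the heuristic targets: nn.Linear lowers to addmm, and its
# output feeds pointwise/reduction ops (layer_norm, relu), so Inductor can prefer
# mm + separate bias add and fuse the add into the consumer kernel.
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.norm = nn.LayerNorm(4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.norm(self.linear(x)))

if torch.cuda.is_available():
    m = torch.compile(M().cuda())
    out = m(torch.randn(4, 4, device="cuda"))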

View File

@@ -553,6 +553,70 @@ def is_pointwise_use(
     return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)


+class LogicalConnective(enum.Enum):
+    OR = enum.auto()
+    AND = enum.auto()
+
+
+def has_uses(
+    target: Node,
+    use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
+    use_aggregate_type: LogicalConnective = LogicalConnective.OR,
+) -> bool:
+    """
+    Given a target, explore the uses of `target` by applying `use_selector_fn`
+    on them, and then aggregate these booleans with the `use_aggregate_type`
+    logical connective.
+
+    Uses in view ops will follow the views uses.
+    """
+
+    def get_use_aggregate_fn(
+        use_aggregate_type: LogicalConnective,
+    ) -> Callable[[Iterator[Any]], bool]:
+        match use_aggregate_type:
+            case LogicalConnective.AND:
+                return all
+            case LogicalConnective.OR:
+                return any
+            case _:
+                return any
+
+    use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type)
+
+    def has_uses_impl(use: Node) -> bool:
+        if use.op != "call_function":
+            return False
+        if not (
+            isinstance(use.target, torch._ops.OpOverload)
+            or use.target is operator.getitem
+        ):
+            return False
+
+        target = cast(torch._ops.OpOverload, use.target)
+        # Process getitem and view
+        if target is operator.getitem or is_view(target):
+            return use_aggregate_fn(has_uses_impl(user) for user in use.users)
+
+        return use_selector_fn(target)
+
+    return use_aggregate_fn(has_uses_impl(user) for user in target.users)
+
+
+def has_uses_tagged_as(
+    target: Node,
+    use_tags: Collection[torch.Tag],
+    use_aggregate_type: LogicalConnective = LogicalConnective.OR,
+) -> bool:
+    """
+    Is there a use with given tags?
+    """
+
+    return has_uses(
+        target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type
+    )
+
+
 def gen_gm_and_inputs(
     target: Any, args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[GraphModule, list[torch.Tensor]]:
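Finally, a usage sketch of the new helper on an ATen-level FX graph. Only has_uses_tagged_as and its signature come from the diff; the traced function, the make_fx call, and the import location (torch/_inductor/utils.py, inferred from the `from ..utils import` block in the fx_passes diff above) are assumptions for illustration.

# Sketch: does this addmm have a pointwise/reduction consumer, looking through views?
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._inductor.utils import has_uses_tagged_as  # import path assumed from this diff

def f(x, w, b):
    y = torch.addmm(b, x, w)       # bias + x @ w
    return torch.relu(y.view(16))  # pointwise consumer reached through a view

gm = make_fx(f)(torch.randn(4, 4), torch.randn(4, 4), torch.randn(4))
addmm = next(n for n in gm.graph.nodes if n.target == torch.ops.aten.addmm.default)
print(has_uses_tagged_as(addmm, (torch.Tag.pointwise, torch.Tag.reduction)))  # expected: True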