Compare commits


1 Commit

66bcd2236c  Add cudagraph static inputs logging  (2024-08-06 01:42:12 -07:00)

    ghstack-source-id: 307d4c6c74395a3f2da711fd413f09936ca5d4a0
    Pull Request resolved: https://github.com/pytorch/pytorch/pull/132726
8 changed files with 63 additions and 0 deletions
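
The new artifact can be exercised end to end much like the test added below does. A minimal sketch (hedged: the script name is a placeholder, and the TORCH_LOGS route assumes the artifact registration at the bottom of this diff):

# Run with:  TORCH_LOGS="cudagraph_static_inputs" python repro.py
# ("repro.py" is only a placeholder name for this sketch.)
import torch

@torch.compile(mode="reduce-overhead")  # reduce-overhead enables cudagraphs
def fn(x):
    return x + 1

x = torch.ones(2, 2)
torch._dynamo.mark_static_address(x)  # marks x as a cudagraph static input
fn(x)  # the static_inputs_log.debug calls added in this PR fire during compilation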

View File

@@ -621,6 +621,18 @@ print("arf")
            record_str,
        )

    @make_logging_test(cudagraph_static_inputs=True)
    def test_cudagraph_static_inputs(self, records):
        @torch.compile(mode="reduce-overhead")
        def fn(x):
            return x + 1

        x = torch.ones(2, 2)
        torch._dynamo.mark_static_address(x)
        fn(x)
        self.assertGreater(len(records), 0)
        self.assertLess(len(records), 4)

    @skipIfTorchDynamo("too slow")
    @make_logging_test(**torch._logging.DEFAULT_LOGGING)
    def test_default_logging(self, records):
@@ -699,6 +711,7 @@ exclusions = {
    "sym_node",
    "export",
    "trace_shape_events",
    "cudagraph_static_inputs",
}
for name in torch._logging._internal.log_registry.artifact_names:
    if name not in exclusions:

View File

@@ -199,6 +199,9 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)

static_inputs_log = torch._logging.getArtifactLogger(
    __name__, "cudagraph_static_inputs"
)

DimList = List
@@ -1195,6 +1198,9 @@ class VariableBuilder:
    def mark_static_input(self, value: torch.Tensor, guard: bool):
        from ..decorators import mark_static_address

        static_inputs_log.debug(
            "Marking static input %s, id: %s)", self.source.name(), id(value)
        )
        mark_static_address(value, guard=guard)

        # Check if we've seen this tensor before and update graph metadata if needed

View File

@@ -17,6 +17,7 @@ import torch
import torch.utils._pytree as pytree
from torch import Tensor
from torch._guards import detect_fake_mode
from torch._logging import getArtifactLogger
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
from torch._subclasses.meta_utils import safe_is_leaf
from torch.fx.experimental.symbolic_shapes import is_concrete_int
@@ -51,6 +52,7 @@ from .utils import _get_autocast_states, KNOWN_TYPES, strict_zip
zip = strict_zip

log = logging.getLogger(__name__)
static_input_logger = getArtifactLogger("torch._dynamo", "cudagraph_static_inputs")

# Note [Tangents must be contiguous]
@@ -679,6 +681,10 @@ from a multi-output view call"
            if (isinstance(arg, torch.nn.Parameter) or i in passed_indices)
        ]

        static_input_logger.debug(
            "static input indices metadata analysis: %s", static_input_indices
        )

        f_mutated_inputs = [
            inp
            for inp, info in zip(flat_f_args, input_info)

View File

@@ -7,6 +7,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch
import torch
import torch._dynamo.logging
import torch.nn as nn
import torch.utils._pytree as pytree
import torch.utils.dlpack
@@ -21,6 +22,11 @@ from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

static_inputs_log = torch._logging.getArtifactLogger(
    __name__, "cudagraph_static_inputs"
)

from . import config
from ._aot_autograd.autograd_cache import (  # noqa: F401
    AOTAutogradCache,
@@ -964,11 +970,19 @@ def aot_module_simplified(
                assert source not in seen_sources, source
                seen_sources.add(source)
                aot_autograd_arg_pos_to_source.append(source)
                source_name = source.name() if source else str(source)

                if "tensor_dict" in node.meta and node.meta["tensor_dict"].get(
                    "_dynamo_static_input_type", None
                ):
                    static_inputs_log.debug(
                        "Adding static input pos %s for source %s", pos, source_name
                    )
                    static_input_indices.append(pos)
                else:
                    static_inputs_log.debug(
                        "Non-static input pos %s for source %s", pos, source_name
                    )

    if aot_autograd_arg_pos_to_source is not None:
        assert len(full_args) == len(aot_autograd_arg_pos_to_source)

View File

@@ -94,6 +94,9 @@ else:
log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
post_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "post_grad_graphs")
static_inputs_log = torch._logging.getArtifactLogger(
    __name__, "cudagraph_static_inputs"
)
# copy_ fails when trying to write to tensors with memory overlap,
@@ -489,6 +492,8 @@ def compile_fx_inner(
    if static_input_idxs is None:
        static_input_idxs = []

    static_inputs_log.debug("static input idxs compile_fx_inner: %s", static_input_idxs)

    assert isinstance(
        next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)
    ), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"

View File

@@ -11,6 +11,9 @@ from torch._inductor.utils import InputType
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")

static_inputs_log = torch._logging.getArtifactLogger(
    __name__, "cudagraph_static_inputs"
)

OutputType = List[Optional[Union[int, torch.Tensor]]]
@@ -136,6 +139,11 @@ def check_for_mutation(
    else:
        mutation_indices = func.mutated_input_idxs

    static_inputs_log.debug(
        "check mutation static input indices: %s", func.static_input_idxs
    )
    static_inputs_log.debug("check mutation mutation indices: %s", mutation_indices)

    return (
        get_mutation_stack_trace(func.placeholders, mutation_indices)
        if mutation_indices

View File

@@ -231,6 +231,7 @@ def set_logs(
    cudagraphs: bool = False,
    sym_node: bool = False,
    compiled_autograd_verbose: bool = False,
    cudagraph_static_inputs: bool = False,
):
    """
    Sets the log level for individual components and toggles individual log
@@ -404,6 +405,9 @@ def set_logs(
            needs to be set. This can be done by providing the fully-qualified module
            name as the key, with the log level as the value. Default: ``None``

        cudagraph_static_inputs (:class:`bool`):
            Whether to emit debug info for cudagraph static input detection. Default: ``False``

    Example::
@@ -499,6 +503,7 @@ def set_logs(
        export=export,
        cudagraphs=cudagraphs,
        compiled_autograd_verbose=compiled_autograd_verbose,
        cudagraph_static_inputs=cudagraph_static_inputs,
    )
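
With the keyword wired through set_logs above, the artifact can also be toggled from Python. A small usage sketch (cudagraphs already existed; only cudagraph_static_inputs is new in this PR):

import torch

# Both keywords default to False per the set_logs signature above.
torch._logging.set_logs(cudagraphs=True, cudagraph_static_inputs=True)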

View File

@@ -161,4 +161,10 @@ register_artifact(
    off_by_default=True,
)

register_artifact(
    "cudagraph_static_inputs",
    "Logs static inputs handling in dynamo, AOT, and cudagraphs",
    off_by_default=True,
)

register_artifact("custom_format_test_artifact", "Testing only", log_format="")