Compare commits

...

6 Commits

Author SHA1 Message Date
d97f5b20ef https://github.com/pytorch/pytorch/pull/164205
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
2025-10-01 10:47:54 -07:00
5667052b85 https://github.com/pytorch/pytorch/pull/164174
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
2025-10-01 07:48:38 -07:00
dcb592a1cc https://github.com/pytorch/pytorch/pull/164315/
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
2025-10-01 07:48:38 -07:00
b5e8d5a976 silu hack
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
2025-10-01 07:48:38 -07:00
c7b66a86d8 Remove CompositeImplicitAutograd from _fused_rms_norm
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: 83c7f47e6d9d0a00d975a5eed59a94f7ee07bbec
Pull-Request: https://github.com/pytorch/pytorch/pull/164289
2025-10-01 07:48:22 -07:00
f76eb15283 Add failing bitwise equivalence UT for aot_eager on rms_norm
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: 94fcb543743c81f6277becf76e1a2082779006c7
Pull-Request: https://github.com/pytorch/pytorch/pull/164280
2025-10-01 07:48:21 -07:00
13 changed files with 183 additions and 22 deletions

View File

@@ -362,7 +362,11 @@ Tensor rms_norm_symint(
return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
}
#endif
return std::get<0>(at::_fused_rms_norm(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
if (input.device().type() == DeviceType::CUDA) {
return std::get<0>(at::_fused_rms_norm(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
} else {
return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
}
}
} // namespace at::native
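
For context (not part of this diff): both branches above compute the same RMSNorm, the CUDA path via the fused kernel and the other devices via rms_norm_composite. A minimal sketch of the underlying math, with eps passed explicitly so the kernel-default question does not arise:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)
w = torch.randn(8)
eps = 1e-6

# rms_norm(x) = x * rsqrt(mean(x^2 over the normalized dims) + eps) * weight
ref = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * w
out = F.rms_norm(x, normalized_shape=(8,), weight=w, eps=eps)
torch.testing.assert_close(out, ref)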

View File

@@ -3332,7 +3332,7 @@
dispatch:
CUDA: _fused_rms_norm_cuda
MPS: _fused_rms_norm_mps
CompositeImplicitAutograd: rms_norm_composite
CompositeExplicitAutograd: rms_norm_composite
- func: _fused_rms_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor rstd, Tensor? weight, bool[2] output_mask) -> (Tensor, Tensor)
dispatch:
@@ -5327,7 +5327,7 @@
structured_delegate: silu_backward.grad_input
python_module: nn
dispatch:
CompositeImplicitAutograd: math_silu_backward
CompositeExplicitAutograd: math_silu_backward
NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: silu_backward_nested
tags: pointwise
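
Background for the dispatch-key changes above (editorial note, not part of the diff): a CompositeImplicitAutograd kernel is decomposed by the dispatcher before AOTAutograd ever traces it, so the op never shows up in an aot_eager graph; registered as CompositeExplicitAutograd, the op reaches the traced ATen graph and is handled by the explicit decomposition added later in this stack. One hedged way to observe the difference is to dump the AOTAutograd graphs and look for aten._fused_rms_norm versus its decomposition:

# sketch: run with TORCH_LOGS="aot_graphs" and inspect the printed forward/backward graphs
import torch
import torch.nn.functional as F

@torch.compile(backend="aot_eager")
def f(x):
    return F.rms_norm(x, normalized_shape=(8,))

f(torch.randn(2, 4, 8, device="cuda" if torch.cuda.is_available() else "cpu"))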

View File

@@ -258,6 +258,16 @@ class TestDTensorDebugMode(TestCase):
# Verify that cond operations are captured in debug mode
self.assertIn("torch.ops.higher_order.cond", debug_mode.debug_string())
def test_compile(self):
@torch.compile
def f(x):
return x.sin().cos()
x = torch.randn(8)
with DebugMode() as debug_mode:
f(x)
self.assertEqual(len(debug_mode.debug_string()), 0)
instantiate_parametrized_tests(TestDTensorDebugMode)

View File

@@ -910,7 +910,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
7|aten.view.default||l__self___fc1
6|aten.t.default||l__self___fc1
5|aten.view.default||l__self___fc1
4|aten.view.default||
4|aten.view.default||flatten
2|aten.detach.default||l__self___relu1
2|aten.detach.default||l__self___relu1
2|aten.threshold_backward.default||l__self___relu1

View File

@@ -792,18 +792,10 @@ class inner_f(torch.nn.Module):
)
for node in joint_with_descriptors.graph_module.graph.nodes:
if (
node.target
in (
torch.ops.prims.transpose.default,
torch.ops.aten.mm.default,
torch.ops.prims.mul.default,
torch.ops.prims.broadcast_in_dim.default,
torch.ops.prims.add.default,
)
# TODO: add annotation to backward graph nodes
and node.meta.get("partitioner_tag") != "is_backward"
):
if node.target not in (
torch.ops.prims.sub.default,
torch.ops.aten.sub.default,
) and node.op not in ("placeholder", "output"):
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
if node.target == torch.ops.aten.sub.default:
self.assertTrue(node.meta.get("custom", {}), {})

View File

@@ -28,6 +28,7 @@ from common_utils import (
import torch
import torch._dynamo as torchdynamo
import torch.nn as nn
import torch.nn.functional as F
import torch.utils._pytree as pytree
from functorch import grad, jacrev, make_fx, vjp, vmap
from functorch.compile import (
@@ -7199,6 +7200,26 @@
torch.compile(fn, backend="inductor", fullgraph=True)(x)
torch.compile(fn_, backend="inductor", fullgraph=True)(x)
def test_layer_norm(self):
def fn(x):
return F.layer_norm(x, normalized_shape=(8,))
x = torch.randn(2, 4, 8)
eager = fn(x)
aot_eager = torch.compile(backend="aot_eager")(fn)(x)
self.assertEqual(eager, aot_eager, atol=0, rtol=0)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
def test_rms_norm(self):
# Only CUDA rms norm fails to be decomposed
def fn(x):
return F.rms_norm(x, normalized_shape=(8,))
x = torch.randn(2, 4, 8, device="cuda")
eager = fn(x)
aot_eager = torch.compile(backend="aot_eager")(fn)(x)
self.assertEqual(eager, aot_eager, atol=0, rtol=0)
def test_subclass_parameters(self):
class _M(torch.nn.Module):
def __init__(self):

View File

@@ -1747,6 +1747,58 @@ def native_layer_norm_backward_out(
return grad_input
@register_decomposition(aten._fused_rms_norm.default)
def _fused_rms_norm(
input: Tensor,
normalized_shape: list[int],
weight: Optional[Tensor],
eps: Optional[float],
) -> tuple[Tensor, Tensor]:
dims_to_reduce: list[int] = []
for i in range(len(normalized_shape)):
dims_to_reduce.append(input.dim() - i - 1)
# upcast is needed for fp16 and bf16
computation_dtype = utils.get_computation_dtype(input.dtype)
upcasted_input = input.to(computation_dtype)
# computation_dtype would be one of [Double, Float, ComplexFloat, ComplexDouble]
if eps is None:
        if computation_dtype in (torch.float32, torch.complex64):
            # single-precision machine epsilon for float/complex-float inputs
            eps_val = torch.finfo(torch.float32).eps
        else:
            # double-precision machine epsilon otherwise
            eps_val = sys.float_info.epsilon
else:
eps_val = eps
rqrst_input = torch.rsqrt(
# NB: don't inplace here, will violate functional IR invariant
torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True).add(eps_val)
)
upcasted_result = upcasted_input.mul(rqrst_input)
if weight is not None:
upcasted_result = upcasted_result.mul(weight)
# NB: nested should be dead here, just here for fidelity
is_nested = input.is_nested or (weight is not None and weight.is_nested)
memory_format = utils.suggest_memory_format(input)
is_channels_last = memory_format in (
torch.channels_last,
torch.channels_last_3d,
)
if not is_nested and not is_channels_last:
upcasted_result = upcasted_result.contiguous()
rqrst_input = rqrst_input.contiguous()
# Cast normalized result back to original input type
result = upcasted_result.type_as(input)
return result, rqrst_input
@register_decomposition(aten._fused_rms_norm_backward.default)
def _fused_rms_norm_backward(
grad_out: Tensor,
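
An editorial sketch (not part of the diff) of exercising the forward decomposition above directly against the fused CUDA kernel. It assumes a CUDA build, that aten._fused_rms_norm takes (input, normalized_shape, weight, eps) positionally as suggested by the yaml hunk earlier, and that the decomposition is reachable via torch._decomp.decomposition_table:

import torch
from torch._decomp import decomposition_table

op = torch.ops.aten._fused_rms_norm.default
decomp = decomposition_table[op]

x = torch.randn(2, 4, 8, device="cuda", dtype=torch.float16)
w = torch.randn(8, device="cuda", dtype=torch.float16)

out_fused, rstd_fused = op(x, [8], w, 1e-6)        # fused CUDA kernel
out_decomp, rstd_decomp = decomp(x, [8], w, 1e-6)  # the decomposition registered above
torch.testing.assert_close(out_fused, out_decomp)
torch.testing.assert_close(rstd_fused, rstd_decomp)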

View File

@@ -420,14 +420,18 @@ def copy_fwd_metadata_to_bw_nodes(fx_g):
# the descendants of graph inputs corresponding to fwd inputs, didn't
# seem obvious at first glance on how to partition graph inputs into
# fwd vs bwd without relying on string names.
return "nn_module_stack" in node.meta and "seq_nr" in node.meta
return (
node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta
)
def _is_backward_node_with_seq_nr(node):
    # A node is treated as coming from the backward if the partitioner tagged
    # it with "is_backward". Ignore nodes without `seq_nr`.
# TODO(future): there is likely a less brittle way to do this, same
# as with the forward.
return ("nn_module_stack" not in node.meta) and "seq_nr" in node.meta
return (
node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta
)
fwd_seq_nr_to_node = {}
for node in fx_g.graph.nodes:
@@ -447,8 +451,10 @@ def copy_fwd_metadata_to_bw_nodes(fx_g):
# fwd_node should always exist, but handle non-existence just in case
fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"])
if fwd_node is not None:
node.meta["fwd_nn_module_stack"] = fwd_node.meta["nn_module_stack"]
node.meta["fwd_nn_module_stack"] = fwd_node.meta.get("nn_module_stack")
node.meta["fwd_source_fn_stack"] = fwd_node.meta.get("source_fn_stack")
# TODO: better to change to a specific field of custom?
node.meta["custom"] = fwd_node.meta.get("custom")
def register_buffer_assignment_hook(mod, assigned_buffers):
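
As an aside (not part of the diff), the two predicates above can be exercised on a hand-built toy graph to see how the partitioner_tag / seq_nr metadata drives the forward/backward classification; the node structure and meta values below are made up for illustration:

import torch
import torch.fx as fx

def _is_forward_node_with_seq_nr(node):
    return node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta

def _is_backward_node_with_seq_nr(node):
    return node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta

g = fx.Graph()
x = g.placeholder("x")
fwd = g.call_function(torch.sin, (x,))
fwd.meta["seq_nr"] = 0                       # forward node: has seq_nr, no backward tag
bwd = g.call_function(torch.cos, (x,))
bwd.meta["seq_nr"] = 0                       # backward node matched to the same seq_nr
bwd.meta["partitioner_tag"] = "is_backward"
g.output((fwd, bwd))

for n in g.nodes:
    print(n.name, _is_forward_node_with_seq_nr(n), _is_backward_node_with_seq_nr(n))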

View File

@@ -88,6 +88,7 @@ inductor_decompositions = get_decompositions(
aten.native_batch_norm,
aten.native_group_norm,
aten.native_layer_norm,
aten._fused_rms_norm,
aten.nll_loss2d_backward,
aten.permute_copy,
aten.rrelu_with_noise_backward,

View File

@@ -1,6 +1,10 @@
import os
# Whether to disable showing progress on compilation passes
# A separate config is needed here; importing the dynamo config would create a circular import
disable_progress = True
# If True, also show the node names in each pass; great for small models, but quite noisy for larger ones
verbose_progress = False
profiler_interpreter_stack_trace = os.environ.get("TORCH_PROFILE_INTERPRETER_STACK_TRACE", "0") == "1"

View File

@@ -1,12 +1,15 @@
# mypy: allow-untyped-defs
import inspect
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from typing import Any, Optional, TYPE_CHECKING, Union
import torch
import torch.fx.traceback as fx_traceback
from torch._logging import trace_structured
from torch.hub import tqdm
from torch.profiler import profile, record_function, ProfilerActivity
import torch._C._profiler as _profiler
import json
from . import config
from ._compatibility import compatibility
@@ -161,6 +164,16 @@ class Interpreter:
delay=0,
)
graph_id = id(self.graph)
if config.profiler_interpreter_stack_trace:
stack_traces = {}
for node in self.graph.nodes:
if node.stack_trace:
stack_traces[f"## {node.name}:{graph_id} interpreter ##"] = node.stack_trace.replace("\"", "'")
# add stack traces to profiler metadata
torch.autograd._add_metadata_json(f"node_stack_traces:{graph_id}", json.dumps(stack_traces))
for node in self.graph.nodes:
pbar.update(1)
if node in self.env:
@@ -169,9 +182,13 @@
# where the caller has pre-populated `env` with
# values for a subset of the program.
continue
profiler_context = nullcontext()
if config.profiler_interpreter_stack_trace:
profiler_context = torch.profiler.record_function(f"## {node.name}:{graph_id} interpreter ##")
try:
self.env[node] = self.run_node(node)
with profiler_context:
self.env[node] = self.run_node(node)
except Exception as e:
if self.extra_traceback:
msg = f"While executing {node.format_node()}"

View File

@@ -5,6 +5,7 @@ import traceback
from contextlib import contextmanager
from enum import Enum
from typing import Any, Optional, Union
import json
from torch._utils_internal import signpost_event
@@ -396,3 +397,49 @@ def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
},
)
return {}
def populate_stack_traces_to_kineto_trace(file_name: str, update_file: bool = True):
"""
Process traces by attaching stack traces to user_annotation entries.
Args:
file_name (str): The filename of the exported kineto trace json.
update_file (bool): Whether to update the kineto trace json file with the stack traces.
Returns:
dict: Modified trace data with stack traces attached to matching entries
"""
    with open(file_name, "r") as f:
        trace_data = json.load(f)
all_stack_traces = {}
# Get the trace events
for key in trace_data.keys():
if not key.startswith("node_stack_traces"):
continue
# Get the node stack traces mapping
node_stack_traces = trace_data.get(key, {})
all_stack_traces.update(node_stack_traces)
if len(all_stack_traces) == 0:
log.warning("No stack traces found in kineto trace data")
return trace_data
trace_events = trace_data.get("traceEvents", [])
# Process each trace event
for event in trace_events:
# Check if this is a user_annotation event
if event.get("cat") == "user_annotation":
event_name = event.get("name")
# If the event name matches a node in node_stack_traces, attach the stack trace
if event_name in all_stack_traces:
event["args"]["stack_trace"] = all_stack_traces[event_name]
    if update_file:
        with open(file_name, "w") as f:
            json.dump(trace_data, f)
return trace_data
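
A hedged end-to-end sketch of how these pieces are meant to fit together. The flag name and the record_function annotation format are taken from this diff; the import path of populate_stack_traces_to_kineto_trace is not shown here, and enabling stack-trace recording via a Tracer subclass is an assumption:

import os
os.environ["TORCH_PROFILE_INTERPRETER_STACK_TRACE"] = "1"  # read by the fx config added above; set before importing torch

import torch
import torch.fx
from torch.profiler import profile, ProfilerActivity

class TracerWithStacks(torch.fx.Tracer):
    record_stack_traces = True  # so nodes carry the stack_trace metadata the Interpreter records

def f(x):
    return x.sin().cos()

tracer = TracerWithStacks()
graph = tracer.trace(f)
gm = torch.fx.GraphModule(tracer.root, graph)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    torch.fx.Interpreter(gm).run(torch.randn(8))  # each node runs under record_function("## <node>:<graph_id> interpreter ##")

prof.export_chrome_trace("trace.json")
# then attach the recorded stack traces to the user_annotation events in the exported JSON:
# populate_stack_traces_to_kineto_trace("trace.json")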

View File

@@ -101,6 +101,13 @@ class DebugMode(TorchDispatchMode):
self.operators = []
self.call_depth = 0
    # Without this override, running torch.compile under DebugMode
    # would force torch.compile to always fall back to the "eager" backend.
    # With it, DebugMode does not take effect inside torch.compile'd regions.
@classmethod
def ignore_compile_internals(cls):
return True
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
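
A short usage sketch of the behavior the comment above describes, mirroring the new test_compile test earlier in this compare (the import path of DebugMode is assumed, since it is not shown in this diff):

import torch
from torch.utils._debug_mode import DebugMode  # assumed location of DebugMode

@torch.compile
def f(x):
    return x.sin().cos()

with DebugMode() as debug_mode:
    f(torch.randn(8))  # ignore_compile_internals() lets the compiled region run without DebugMode interception

print(repr(debug_mode.debug_string()))  # expected to be empty, per test_compile above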