Compare commits

...

7 Commits

Author SHA1 Message Date
9ebed6d17c [BE] Remove unnecessary semicolon in Module.cpp 2025-11-13 12:04:35 -08:00
fadb62f592 [PyTorch] fix profiler issue with empty exported trace file (#167601)
Summary:
The previous implementation incorrectly attempted to read from a `NamedTemporaryFile` file pointer after calling `profiler.export_chrome_trace(fp.name)`. The issue is that `export_chrome_trace()` writes to a file at the path `fp.name`, but doesn't write to the file pointer `fp` itself. This meant when the code tried to read from `fp`, it got empty content.

The fix explicitly closes the temporary file first, then calls `export_chrome_trace(fp.name)` which writes the JSON trace to a file at that path. We then open that file separately for reading and copy its contents to the gzipped output file. This ensures we're reading from the actual file that was written to, not an empty file pointer.
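
A minimal sketch of the corrected pattern (the `profiler` object and gzipped `path` are stand-ins for illustration):
```
import gzip
import tempfile

def export_chrome_trace_gz(profiler, path):
    # export_chrome_trace() writes the trace to the path fp.name on disk, not to
    # the file pointer fp, so re-open that path for reading before gzipping.
    with tempfile.NamedTemporaryFile("w+b", suffix=".json") as fp:
        retvalue = profiler.export_chrome_trace(fp.name)
        with open(fp.name, "rb") as fin, gzip.open(path, "wb") as fout:
            fout.writelines(fin)
    return retvalue
```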

Changes made in both `fbcode/caffe2/torch/profiler/profiler.py` and `xplat/caffe2/torch/profiler/profiler.py`:
- `export_chrome_trace()`: Fixed file reading for gzipped chrome trace exports by opening the written file separately
- `export_memory_timeline()`: Fixed file reading for gzipped memory timeline exports by opening the written file separately

Test Plan:
* run benchmark
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_train_pipeline -- \
    --yaml_config=fbcode/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml
```
* upload trace
```
DIFF=D86737513 fbcode/torchrec/fb/scripts/trace_to_manifold.sh
```

[manifold folder](https://www.internalfb.com/manifold/explorer/torchrec_benchmark_traces/tree/permanent_traces/DIFF/D86737513)
[trace-sparse_data_dist_base-rank0.json.gz](https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/permanent_traces/DIFF/D86737513/trace-sparse_data_dist_base-rank0.json.gz&bucket=torchrec_benchmark_traces)

Differential Revision: D86737513

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167601
Approved by: https://github.com/angelayi
2025-11-13 19:40:09 +00:00
e5eb89e111 remove allocation of new unbacked symbols during mod eval (#167123)
When executing code like torch._check(numel % newsize == 0, ...), we previously allocated a new unbacked symbol due to #113165. However, this allocation is no longer necessary and can cause issues due to inconsistent behavior when tracing torch._check multiple times.

In particular, the allocation can interact badly with memoization: the previously allocated symbol may be returned instead of a fresh one, causing unexpected behavior.

This PR removes the unnecessary allocation, ensuring consistent behavior and avoiding potential issues. The change is validated by the following code, which now compiles without issues:
```
import torch

def fn(x):
    i0 = x.nonzero().size(0)
    y = torch.zeros((i0, 192))
    return y.view([12, -1, 192])
with torch._dynamo.config.patch({"capture_dynamic_output_shape_ops": True}):
    torch.compile(fn, fullgraph=True)(torch.ones((12,)))
```

By removing this unnecessary allocation, we simplify the code and avoid potential issues.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167123
Approved by: https://github.com/Lucaskabela
2025-11-13 18:52:41 +00:00
b5e0e6932a Correctly populate storage offset in DTensor constructor (#167597)
The storage offset always matches the local tensor's offset, because the offset is never rank-dependent: your shard may differ, but your view into it is the same across all ranks.
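
For illustration, a minimal sketch of the now-correct behavior (assumes an already-initialized process group; the device type is a placeholder), mirroring the tests added below:
```
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor

# Slice a replicated DTensor; the wrapper's storage_offset now matches the
# local tensor's storage_offset (1 here) on every rank.
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
dtensor = distribute_tensor(torch.randn(10), mesh, [Replicate()])
sliced = dtensor[1:]
assert sliced.storage_offset() == sliced.to_local().storage_offset() == 1
```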

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167597
Approved by: https://github.com/malfet
ghstack dependencies: #166868, #166867, #167076
2025-11-13 18:26:11 +00:00
6ea779188c [DebugMode] torch.hash_tensor option (#167486)
Adds `torch.hash_tensor` (#154149) as a tensor-hashing variant, and allows a tuple of hashes in log annotations for more detail (e.g. `with DebugMode.log_tensor_hashes(hash_fn=["norm", "hash_tensor"]): ...`).

Also fixes some corner cases around norm hashing (preserves NaNs/infs, avoids erroring on smaller dtypes).
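
A rough usage sketch with plain tensors (the recorded op and its `log["hash"]` tuple mirror the test added below; shapes are illustrative):
```
import torch
from torch.utils._debug_mode import DebugMode

# Log a ("norm", "hash_tensor") pair for each recorded op's output.
with (
    DebugMode() as debug_mode,
    DebugMode.log_tensor_hashes(hash_fn=["norm", "hash_tensor"]),
):
    torch.mm(torch.randn(8, 8), torch.randn(8, 8))

# The last recorded operator carries both hashes as a tuple.
print(debug_mode.operators[-1].log["hash"])
```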

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167486
Approved by: https://github.com/xmfan
2025-11-13 17:46:09 +00:00
460c7e196c Handle only a Tensor for IntList parsing (#167606)
Fixes https://github.com/pytorch/pytorch/issues/167562
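
For illustration, a minimal example of the behavior this fixes (mirroring the `TestReductions` cases added below):
```
import torch

x = torch.ones(1, 2, 3, 4, 5)
# A 0-dim integer tensor passed as `dim` now parses like the plain int 3.
assert x.sum(dim=torch.tensor(3)).shape == x.sum(dim=3).shape == torch.Size([1, 2, 3, 5])
```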

Authored with Claude Code

Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167606
Approved by: https://github.com/colesbury
2025-11-13 17:39:38 +00:00
7aac506cdc Revert "[precompile] Integrate AOTI as a backend. (#167338)"
This reverts commit 273babeec3c6211f30b806797f35a6e9c47c737f.

Reverted https://github.com/pytorch/pytorch/pull/167338 on behalf of https://github.com/jeanschmidt due to seems to be breaking internal tests and builds, see D86919103 ([comment](https://github.com/pytorch/pytorch/pull/167338#issuecomment-3528950888))
2025-11-13 17:39:03 +00:00
18 changed files with 267 additions and 279 deletions

View File

@ -31,6 +31,8 @@ from torch.utils._debug_mode import (
_RedistributeCall,
_TritonKernelCall,
DebugMode,
hash_tensor_fn,
norm_hash_fn,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._triton import has_triton_package
@ -115,6 +117,28 @@ class TestDTensorDebugMode(TestCase):
"aten::sum(t: f32[1, 32]) # {'hash': " in debug_mode.debug_string()
)
# check tuple hash functions
with (
DebugMode() as debug_mode,
DebugMode.log_tensor_hashes(hash_fn=["norm", "hash_tensor"]),
):
mm(x_dtensor, y_dtensor)
output_hash = debug_mode.operators[-1].log["hash"]
norm_ = lambda x: norm_hash_fn(x, use_scalar=True) # noqa: E731
hash_ = lambda x: hash_tensor_fn(x, use_scalar=True) # noqa: E731
self.assertEqual(output_hash[0], norm_(eager_out))
self.assertEqual(output_hash[1], hash_(eager_out))
# some edge cases
self.assertEqual(norm_(torch.tensor(torch.nan)), torch.nan)
self.assertEqual(norm_(torch.tensor(torch.inf)), torch.inf)
self.assertEqual(norm_(torch.complex(torch.ones(4), torch.zeros(4))), 4)
self.assertEqual(hash_(torch.ones(4, dtype=torch.float8_e5m2)), 0)
self.assertEqual(hash_(torch.ones(4, dtype=torch.int8)), 0)
self.assertEqual(hash_(torch.ones(5, dtype=torch.int8)), 1)
def test_debug_string_inside_context(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

View File

@ -664,6 +664,101 @@ class TestViewOps(DTensorTestBase):
)
self.assertEqual(dist_x.placements, [Partial(), Shard(0)])
@with_comms
def test_storage_offset_slice(self):
"""
Test that storage_offset is properly tracked on DTensor when slicing
a replicated tensor.
"""
mesh = init_device_mesh(self.device_type, (self.world_size,))
# Create a replicated DTensor
tensor = torch.randn(10, device=self.device_type)
dtensor = distribute_tensor(tensor, mesh, [Replicate()])
# Perform a slice operation [1:]
with CommDebugMode() as comm_mode:
sliced_dtensor = dtensor[1:]
# Slicing should not trigger any communication
self.assertEqual(comm_mode.get_total_counts(), 0)
# Verify that the DTensor's storage_offset matches the expected value
self.assertEqual(sliced_dtensor.storage_offset(), 1)
# Verify that the local tensor also has the correct storage_offset
self.assertEqual(sliced_dtensor.to_local().storage_offset(), 1)
# Verify the shape is correct
self.assertEqual(sliced_dtensor.shape, torch.Size([9]))
# Verify the values are correct
expected = tensor[1:]
self.assertEqual(sliced_dtensor.full_tensor(), expected)
@with_comms
def test_storage_offset_shard_dim0_slice_dim1(self):
"""
Test that storage_offset is properly tracked when tensor is sharded on dim 0
and sliced on dim 1.
"""
mesh = init_device_mesh(self.device_type, (self.world_size,))
# Create a 2D tensor and shard on dim 0
tensor = torch.randn(12, 8, device=self.device_type)
dtensor = distribute_tensor(tensor, mesh, [Shard(0)])
# Perform a slice operation [:, 2:]
with CommDebugMode() as comm_mode:
sliced_dtensor = dtensor[:, 2:]
# Slicing should not trigger any communication
self.assertEqual(comm_mode.get_total_counts(), 0)
# The storage_offset should be 2 (skipping 2 elements in each row)
self.assertEqual(sliced_dtensor.storage_offset(), 2)
# Verify that the local tensor also has the correct storage_offset
self.assertEqual(sliced_dtensor.to_local().storage_offset(), 2)
# Verify the shape is correct
expected_shape = torch.Size([12, 6])
self.assertEqual(sliced_dtensor.shape, expected_shape)
# Verify the values are correct
expected = tensor[:, 2:]
self.assertEqual(sliced_dtensor.full_tensor(), expected)
@with_comms
def test_storage_offset_shard_dim1_slice_dim0(self):
"""
Test that storage_offset is properly tracked when tensor is sharded on dim 1
and sliced on dim 0.
"""
mesh = init_device_mesh(self.device_type, (self.world_size,))
# Create a 2D tensor and shard on dim 1
tensor = torch.randn(10, 12, device=self.device_type)
dtensor = distribute_tensor(tensor, mesh, [Shard(1)])
# Perform a slice operation [2:, :]
with CommDebugMode() as comm_mode:
sliced_dtensor = dtensor[2:, :]
# Slicing should not trigger any communication
self.assertEqual(comm_mode.get_total_counts(), 0)
local_dim1_size = 12 // self.world_size
expected_offset = 2 * local_dim1_size
self.assertEqual(sliced_dtensor.storage_offset(), expected_offset)
self.assertEqual(sliced_dtensor.to_local().storage_offset(), expected_offset)
# Verify the shape is correct
expected_shape = torch.Size([8, 12])
self.assertEqual(sliced_dtensor.shape, expected_shape)
# Verify the values are correct
expected = tensor[2:, :]
self.assertEqual(sliced_dtensor.full_tensor(), expected)
TestViewOpsWithLocalTensor = create_local_tensor_test_class(
TestViewOps,

View File

@ -1,11 +1,9 @@
# Owner(s): ["module: dynamo"]
import copy
import functools
import inspect
import os
import pickle
import unittest
from contextlib import contextmanager
from unittest.mock import patch
@ -15,16 +13,13 @@ import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.aot_compile import AOTCompiledModel, ModelInput, SerializableCallable
from torch._dynamo.aot_compile import ModelInput, SerializableCallable
from torch._dynamo.exc import PackageError, Unsupported
from torch._dynamo.package import DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.fx._graph_pickler import GraphPickler
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
TEST_CUDA,
)
from torch.testing._internal.common_utils import instantiate_parametrized_tests
MY_LAMBDA = lambda x: x + 1 # noqa: E731
@ -604,92 +599,6 @@ from user code:
actual = compiled_fn(*inputs)
self.assertEqual(expected, actual)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti(self):
with torch.device("cuda"):
from torch._dynamo.hooks import Hooks
def fn(x, y):
return x + y
def make_inputs():
return (torch.randn(3, 4), torch.randn(3, 4))
compiled_fn = torch._dynamo.aot_compile.aot_compile_fullgraph(
fn,
(make_inputs(), {}),
Hooks(),
torch._TorchCompileAOTInductorWrapper(None, None, None),
)
test_inputs = make_inputs()
expected = fn(*test_inputs)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti_module(self):
with torch.device("cuda"):
from torch._dynamo.hooks import Hooks
mod = SimpleLinearModule()
def make_inputs():
return (torch.randn(4, 3),)
compiled_mod = torch._dynamo.aot_compile.aot_compile_module(
mod,
[ModelInput(make_inputs(), {}, [])],
Hooks(),
torch._TorchCompileAOTInductorWrapper(None, None, None),
)
def get_grads(m: torch.nn.Module):
return {name: p.grad for name, p in m.named_parameters()}
original_mod = copy.deepcopy(mod)
test_inputs = make_inputs()
expected = mod(*test_inputs)
expected.sum().backward()
expected_grads = get_grads(mod)
actual = compiled_mod(*test_inputs)
self.assertEqual(expected, actual)
serialized = compiled_mod.serialize()
compiled_fn = AOTCompiledModel.deserialize(original_mod, serialized)
actual = compiled_fn(*test_inputs)
actual.sum().backward()
self.assertEqual(get_grads(original_mod), expected_grads)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_aot_compile_with_aoti_torch_compile(self):
with torch.device("cuda"):
def fn(x, y):
return x + y
def make_inputs():
return (torch.randn(3, 4), torch.randn(3, 4))
compiled_fn = torch.compile(
fn, fullgraph=True, options={"use_aoti": True}
).aot_compile((make_inputs(), {}))
test_inputs = make_inputs()
expected = fn(*test_inputs)
actual = compiled_fn(*test_inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(*test_inputs)
self.assertEqual(compiled_fn._artifacts.backend_name, "aotinductor")
self.assertEqual(expected, actual)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -4524,6 +4524,17 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
run(torch.rand(2, 10), torch.rand(2, 10))
self.assertEqual(cnt.frame_count, 2)
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
def test_unbacked_view_extra(self):
def fn(x):
i0 = x.nonzero().size(0)
y = torch.zeros((i0, 192))
return y.view([12, -1, 192])
res1 = torch.compile(fn, fullgraph=True)(torch.ones((12,)))
res2 = fn(torch.ones((12,)))
self.assertEqual(res1, res2)
instantiate_parametrized_tests(TestUnbacked)

View File

@ -3755,6 +3755,44 @@ as the input tensor excluding its innermost dimension'):
with ctx:
self.assertEqual(torch.mean(t), expected)
def test_scalar_tensor_as_dim_argument(self):
"""Tests that scalar tensors work correctly as dimension arguments.
This tests the fix for the PythonArgParser bug where scalar Tensors
passed to IntList/SymIntList parameters would be incorrectly handled.
"""
x = torch.ones(1, 2, 3, 4, 5)
# Scalar tensors should work correctly (same as passing an int)
result_tensor = x.sum(dim=torch.tensor(3))
result_int = x.sum(dim=3)
self.assertEqual(result_tensor.shape, result_int.shape)
self.assertEqual(result_tensor.shape, torch.Size([1, 2, 3, 5]))
# Test with different integer dtypes
for dtype in [torch.int32, torch.int64, torch.int16, torch.int8]:
dim_tensor = torch.tensor(1, dtype=dtype)
result = x.sum(dim=dim_tensor)
expected = x.sum(dim=1)
self.assertEqual(result.shape, expected.shape)
@skipIfTorchDynamo("Test uses random.randint which creates FakeTensors")
def test_scalar_tensor_dim_compiled_mode(self):
"""Tests that scalar FakeTensors from random.randint work correctly in compiled mode."""
def foo():
x = torch.ones(2, 2, 2)
return x.sum(dim=random.randint(0, 0))
@torch.compile
def foo_compile():
x = torch.ones(2, 2, 2)
return x.sum(dim=random.randint(0, 0))
result_eager = foo()
result_compiled = foo_compile()
self.assertEqual(result_eager.shape, result_compiled.shape)
self.assertEqual(result_eager.shape, torch.Size([2, 2]))
instantiate_device_type_tests(TestReductions, globals())
if __name__ == '__main__':

View File

@ -2439,35 +2439,6 @@ class _TorchCompileInductorWrapper:
reset_cudagraph_trees()
class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper):
compiler_name = "aotinductor"
def __init__(self, mode, options, dynamic):
super().__init__(mode, options, dynamic)
self.apply_options({"cpp_wrapper": True})
self.apply_options({"aot_inductor.package": True})
def __call__(self, model_, inputs_):
from contextlib import nullcontext
from unittest import mock
from torch._guards import detect_fake_mode
from torch._inductor.virtualized import V
fake_mode = detect_fake_mode(inputs_)
ctx = (
mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
if fake_mode
else nullcontext()
)
with (
V.set_aot_compilation(True),
ctx,
torch._inductor.config.patch("enable_autograd_for_aot", True),
):
return super().__call__(model_, inputs_)
class _TorchCompileWrapper:
def __init__(self, backend, mode, options, dynamic):
from torch._dynamo.backends.registry import lookup_backend
@ -2701,10 +2672,8 @@ def compile(
backend = bisect_backend
guard_filter_fn = None
use_aoti = False
if options and isinstance(options, dict):
guard_filter_fn = options.pop("guard_filter_fn", None)
use_aoti = options.pop("use_aoti", False)
if torch.compiler.is_exporting():
warnings.warn(
@ -2731,10 +2700,7 @@ def compile(
return export_wrapped_fn
if backend == "inductor":
if use_aoti:
backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic)
else:
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
else:
backend = _TorchCompileWrapper(backend, mode, options, dynamic)

View File

@ -53,7 +53,6 @@ class CompileArtifacts:
argdefs: Optional[tuple[Any, ...]]
source_info: "SourceInfo"
device_type: str
backend_name: str
system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)
def check_compatibility(self) -> None:
@ -274,7 +273,6 @@ def aot_compile_fullgraph(
argdefs=fn.__defaults__,
source_info=source_info,
device_type=device_type,
backend_name=getattr(backend, "compiler_name", "unknown"),
)
aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)

View File

@ -511,7 +511,6 @@ class GenericAOTAutogradResult(Generic[TForward, TBackward]):
).post_compile(
compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
)
compiled_fw_func._boxed_call = True
disable_amp = torch._C._is_any_autocast_enabled()
if needs_autograd:

View File

@ -1640,9 +1640,7 @@ class _InProcessFxCompile(FxCompile):
# pyrefly: ignore [unbound-name]
(str, list, torch.fx.GraphModule),
), type(compiled_fn)
return CompiledAOTI(
filename=compiled_fn, device_type=graph.device_type
)
return CompiledAOTI(compiled_fn)
# TODO: Hoist this above V.aot_compilation
# pyrefly: ignore [unbound-name]
@ -2715,7 +2713,7 @@ def _compile_fx_main(
or torch._guards.TracingContext(fake_mode)
)
if V.aot_compilation and not config.enable_autograd_for_aot:
if V.aot_compilation:
from .utils import is_valid_aoti_model_name
is_valid_aoti_model_name()

View File

@ -1193,8 +1193,6 @@ autotune_lookup_table: dict[str, dict[str, Any]] = {}
file_lock_timeout: int = int(os.environ.get("TORCHINDUCTOR_FILE_LOCK_TIMEOUT", "600"))
enable_autograd_for_aot: bool = False
def get_worker_log_path() -> Optional[str]:
log_loc = None

View File

@ -773,83 +773,9 @@ class CompiledAOTI(OutputCode):
"""
filename: Union[str, list[Union[str, Weights]], torch.fx.GraphModule]
device_type: str
current_callable: Optional[Callable[..., Any]] = None
_cached_files: dict[str, bytes] = dataclasses.field(default_factory=dict)
def __post_init__(self):
if not config.aot_inductor.link_libtorch:
return
if (
torch._inductor.cpp_builder._IS_MACOS
or torch._inductor.cpp_builder._IS_WINDOWS
):
return
if config.aot_inductor.cross_target_platform == "windows":
return
if config.aot_inductor.package_cpp_only:
return
if isinstance(self.filename, list):
current_callable = next(
fn for fn in self.filename if isinstance(fn, str) and fn.endswith(".so")
)
else:
current_callable = self.filename
if isinstance(current_callable, torch.fx.GraphModule):
self.current_callable = current_callable
return
if self.device_type.startswith("cuda"):
current_callable = (
torch._C._aoti.AOTIModelContainerRunnerCuda( # type: ignore[call-arg]
current_callable,
1,
self.device_type,
"",
True,
).run # type: ignore[attr-defined]
) # type: ignore[attr-defined]
elif self.device_type == "cpu":
current_callable = (
torch._C._aoti.AOTIModelContainerRunnerCpu( # type: ignore[call-arg]
current_callable, 1
).run # type: ignore[attr-defined]
) # type: ignore[attr-defined]
else:
raise RuntimeError(f"unsupported device type {self.device_type}")
self.current_callable = current_callable
self._boxed_call = True
for file in self._cached_files:
if not os.path.exists(file):
with open(file, "wb") as f:
f.write(self._cached_files[file])
def __call__(self, inputs: Sequence[Any]) -> Any:
if self.current_callable is None:
raise RuntimeError("AOTInductor compiled so is not loaded")
return self.current_callable(inputs)
def prepare_for_serialization(self) -> None:
self.current_callable = None
self._cached_files = {}
filenames: list[str] = []
if isinstance(self.filename, list):
filenames = self.filename # type: ignore[assignment]
elif isinstance(self.filename, str):
filenames = [self.filename]
for name in filenames:
with open(name, "rb") as f:
self._cached_files[name] = f.read()
def __getstate__(self):
state = self.__dict__.copy()
state["current_callable"] = None
return state
raise NotImplementedError("NYI")
def post_compile(
self,
@ -857,8 +783,10 @@ class CompiledAOTI(OutputCode):
constants: CompiledFxGraphConstants,
graph_kwargs: _CompileFxKwargs,
) -> None:
if self.current_callable is None:
self.__post_init__()
pass
def prepare_for_serialization(self) -> None:
pass
def set_triton_bundle(self, triton_bundle: Any) -> None:
pass

View File

@ -2918,7 +2918,6 @@ static void pytorch_duplicate_guard() {
abort();
}
initialized = 1;
;
}
struct call_duplicate_guard {

View File

@ -1751,7 +1751,7 @@ static PyObject* THPVariable_dtensor_new(
Tensor tensor = make_tensor_for_subclass_helper(
/*sym_sizes=*/tuple_to_symintlist(sizes.ptr()),
/*sym_strides=*/tuple_to_symintlist(stride.ptr()),
/*sym_storage_offset=*/std::nullopt,
/*sym_storage_offset=*/local_tensor.sym_storage_offset(),
options,
/*storage_size=*/std::nullopt,
extra_dispatch_keys);

View File

@ -66,12 +66,6 @@ void initAOTIRunnerBindings(PyObject* module) {
int,
const std::string&,
const std::string&>())
.def(py::init<
const std::string&,
int,
const std::string&,
const std::string&,
const bool>())
.def(
"run",
&AOTIModelContainerRunnerCuda::run,

View File

@ -565,8 +565,16 @@ inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
return std::vector<c10::SymInt>(size1, si);
}
if (size1 > 0 && THPVariable_Check(args[i])) {
return std::vector<c10::SymInt>(
size1, THPVariable_Unpack(args[i]).item().toSymInt());
}
PyObject* arg = args[i];
auto tuple = PyTuple_Check(arg);
if (!tuple) {
TORCH_INTERNAL_ASSERT(PyList_Check(arg), "expected tuple or list");
}
// NOLINTNEXTLINE(bugprone-branch-clone)
const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
std::vector<c10::SymInt> res;
@ -645,7 +653,13 @@ inline std::vector<int64_t> PythonArgs::intlistWithDefault(
if (size1 > 0 && torch::is_dynint(py::handle(arg))) {
return std::vector<int64_t>(size1, py::handle(arg).cast<int>());
}
if (size1 > 0 && THPVariable_Check(arg)) {
return std::vector<int64_t>(size1, THPVariable_Unpack(arg).item<int64_t>());
}
auto tuple = PyTuple_Check(arg);
if (!tuple) {
TORCH_INTERNAL_ASSERT(PyList_Check(arg), "expected tuple or list");
}
// NOLINTNEXTLINE(bugprone-branch-clone)
const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
std::vector<int64_t> res(size2);
@ -716,6 +730,9 @@ inline c10::OptionalArray<c10::SymInt> PythonArgs::symintlistOptional(int i) {
inline std::vector<double> PythonArgs::getDoublelist(int i) {
PyObject* arg = args[i];
auto tuple = PyTuple_Check(arg);
if (!tuple) {
TORCH_INTERNAL_ASSERT(PyList_Check(arg), "expected tuple or list");
}
// NOLINTNEXTLINE(bugprone-branch-clone)
auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
std::vector<double> res(size);
@ -889,6 +906,9 @@ inline at::Dimname PythonArgs::dimname(int i) {
inline std::vector<at::Dimname> parseDimnameList(PyObject* arg) {
auto tuple = PyTuple_Check(arg);
if (!tuple) {
TORCH_INTERNAL_ASSERT(PyList_Check(arg), "expected tuple or list");
}
// NOLINTNEXTLINE(bugprone-branch-clone)
auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
std::vector<at::Dimname> res;

View File

@ -7037,52 +7037,16 @@ class ShapeEnv:
ok = len(free_unbacked_symbols(new_var)) == 0
if ok:
self._set_replacement(free[0], new_var, "solve")
except NotImplementedError:
pass
if expr.has(Mod):
else:
# expression has mod.
mod_expr = next(iter(expr.atoms(Mod)))
try:
r = try_solve(expr, mod_expr, floordiv_inequality=False)
if r is not None and r[1] == 0:
self._add_divisible(mod_expr)
# This is a little bit of extra logic to make things like
# torch.empty(i0, q).view(c, -1, q) work out
p, q = mod_expr.args
if (
isinstance(q, sympy.Number)
and isinstance(p, sympy.Mul)
and len(p.args) == 2
):
c, i0 = p.args
# Given Mod(c * i0, q) == 0
if (
isinstance(c, sympy.Number)
and isinstance(i0, sympy.Symbol)
and self.is_unbacked_symint(i0)
):
# We have Mod(i0, q / c) == 0, which means we can
# rewrite i0 as (q / gcd(q, c)) * i1
d = q / sympy.gcd(q, c) # TODO: CleanDiv?
i1 = self.create_unbacked_symint().node.expr
# Propagate the value ranges. It doesn't really
# matter if we use truediv or floordiv, because we
# have established divisibility.
self._update_var_to_range(
i1,
SymPyValueRangeAnalysis.floordiv(
self.var_to_range[i0], ValueRanges.wrap(d)
),
)
# Propagate hints (real tensor tracing)
if i0 in self.unbacked_var_to_val:
self.set_unbacked_var_to_val(
i1, self.unbacked_var_to_val[i0] // d
)
# Propagate size-like-ness
if i0 in self.size_like:
self.size_like.add(i1)
self._set_replacement(i0, d * i1, "divisibility")
except NotImplementedError:
pass
return

View File

@ -273,9 +273,8 @@ class _KinetoProfile:
if path.endswith(".gz"):
with tempfile.NamedTemporaryFile("w+b", suffix=".json") as fp:
retvalue = self.profiler.export_chrome_trace(fp.name)
fp.seek(0)
with gzip.open(path, "wb") as fout:
fout.writelines(fp)
with open(fp.name, "rb") as fin, gzip.open(path, "wb") as fout:
fout.writelines(fin)
return retvalue
else:
return self.profiler.export_chrome_trace(path)
@ -447,7 +446,6 @@ class _KinetoProfile:
self.mem_tl.export_memory_timeline_html(path, device)
elif path.endswith(".gz"):
with tempfile.NamedTemporaryFile("w+t", suffix=".json") as fp:
fp.close()
if path.endswith("raw.json.gz"):
self.mem_tl.export_memory_timeline_raw(fp.name, device)
else:

View File

@ -39,7 +39,7 @@ import os
import traceback
import weakref
from collections.abc import Callable
from typing import Any, Optional, TYPE_CHECKING # noqa: F401
from typing import Any, Optional, TYPE_CHECKING, Union # noqa: F401
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@ -157,21 +157,25 @@ def _arg_to_str(arg, attributes, tensor_memo=None) -> str:
return str(arg)
def default_hash_fn(t: torch.Tensor, use_scalar: bool = False) -> torch.Tensor:
def norm_hash_fn(
t: torch.Tensor, use_scalar: bool = False
) -> Union[torch.Tensor, float]:
"""
from Observer. Computes a hash for a tensor by converting it to float (if needed), making it contiguous,
replacing NaN/inf values with fixed numbers, and then computing the L1 norm in float64 or complex128.
This is used to generate a deterministic summary value for tensor comparison.
"""
with torch._C._DisablePythonDispatcher(), torch._C._DisableTorchDispatch():
with torch._C._DisablePythonDispatcher():
if not (t.is_floating_point() or t.is_complex()):
t = t.float()
t = t.contiguous()
# Clean the tensor to handle NaN/inf values, then compute norm
t_clean = torch.nan_to_num(t, nan=0.0, posinf=1.0, neginf=-1.0)
dtype = torch.complex128 if t.is_complex() else torch.float64
out = t_clean.norm(p=1, dtype=dtype)
if t.is_complex():
t_float = t.to(dtype=torch.complex128)
else:
t_float = t.to(dtype=torch.float64)
out = t_float.norm(p=1)
if use_scalar:
return out.item()
return out
@ -184,6 +188,28 @@ def _compute_rel_diff(hash1, hash2):
return numerator / denominator
def hash_tensor_fn(
t: torch.Tensor, use_scalar: bool = False
) -> Union[torch.Tensor, int]:
"""
wrapper over torch.hash_tensor
"""
if isinstance(t, torch.distributed.tensor.DTensor):
t = t.to_local()
if t.is_floating_point():
t_clean = t.to(dtype=torch.float64)
elif t.is_complex():
t_clean = t.to(dtype=torch.complex128).view(torch.float64)
else:
t_clean = t.to(dtype=torch.int64)
out = torch.hash_tensor(t_clean)
if use_scalar:
return out.item() # type: ignore[attribute]
return out
def _get_stack_trace() -> str:
from torch.fx.experimental.symbolic_shapes import uninteresting_files
@ -897,20 +923,43 @@ class DebugMode(TorchDispatchMode):
@staticmethod
@contextlib.contextmanager
def log_tensor_hashes(hash_fn: Callable | None = None, hash_inputs: bool = False):
def log_tensor_hashes(
hash_fn: Union[Callable, str, list[str]] = "norm", hash_inputs: bool = False
):
"""
Installs hook for tensor hash logging.
hash_fn: optional function for custom hashing
hash_fn: One of:
- Custom-defined hash function
- String: one of ("norm", "hash_tensor")
- "norm": uses norm_hash_fn; basically tensor's L1 norm
- "hash_tensor": uses torch.hash_tensor (XOR sum reduction)
- List of strings: returns tuple of hashes from above options
hash_inputs: if True, also hashes tensors in (args, kwargs), storing them in "input_hash".
NOTE: this is currently a post-hook, so e.g. inplace ops will log the "output" hashes.
"""
if hash_fn is None:
hash_fn = functools.partial(default_hash_fn, use_scalar=True)
def hash_fn_option(hash_type):
assert isinstance(hash_type, str) and hash_type in ["norm", "hash_tensor"]
return functools.partial(
norm_hash_fn if hash_type == "norm" else hash_tensor_fn, use_scalar=True
)
if callable(hash_fn):
fn = hash_fn
elif isinstance(hash_fn, str):
fn = hash_fn_option(hash_fn)
elif isinstance(hash_fn, list):
fns = [hash_fn_option(fn) for fn in hash_fn]
fn = lambda x: tuple(fn(x) for fn in fns) # noqa: E731
else:
raise NotImplementedError(
f"log_tensor_hashes() expected hash_fn to be callable, str, or list[str], but found {type(hash_fn)}"
)
def _tree_hash(obj):
return tree_map(
lambda x: hash_fn(x) if isinstance(x, torch.Tensor) else None, obj
lambda x: fn(x) if isinstance(x, torch.Tensor) else None, obj
)
def _dispatch_hash_hook(func, types, args, kwargs, result):
@ -930,9 +979,9 @@ class DebugMode(TorchDispatchMode):
try:
if hash_inputs:
_old_input_hfn = _TRITON_INPUT_HASH_FN
_TRITON_INPUT_HASH_FN = hash_fn
_TRITON_INPUT_HASH_FN = fn
_old_output_hfn = _TRITON_OUTPUT_HASH_FN
_TRITON_OUTPUT_HASH_FN = hash_fn
_TRITON_OUTPUT_HASH_FN = fn
with DebugMode.dispatch_hooks(log_hook=_dispatch_hash_hook):
yield
finally: