mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-17 16:46:31 +08:00

Compare commits: 7 commits (ciflow/tru... ... malfet-pat...)
| SHA1 |
|---|
| 9ebed6d17c |
| fadb62f592 |
| e5eb89e111 |
| b5e0e6932a |
| 6ea779188c |
| 460c7e196c |
| 7aac506cdc |
@@ -31,6 +31,8 @@ from torch.utils._debug_mode import (
     _RedistributeCall,
     _TritonKernelCall,
     DebugMode,
+    hash_tensor_fn,
+    norm_hash_fn,
 )
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._triton import has_triton_package
@@ -115,6 +117,28 @@ class TestDTensorDebugMode(TestCase):
             "aten::sum(t: f32[1, 32]) # {'hash': " in debug_mode.debug_string()
         )
 
+        # check tuple hash functions
+        with (
+            DebugMode() as debug_mode,
+            DebugMode.log_tensor_hashes(hash_fn=["norm", "hash_tensor"]),
+        ):
+            mm(x_dtensor, y_dtensor)
+
+        output_hash = debug_mode.operators[-1].log["hash"]
+        norm_ = lambda x: norm_hash_fn(x, use_scalar=True)  # noqa: E731
+        hash_ = lambda x: hash_tensor_fn(x, use_scalar=True)  # noqa: E731
+
+        self.assertEqual(output_hash[0], norm_(eager_out))
+        self.assertEqual(output_hash[1], hash_(eager_out))
+
+        # some edge cases
+        self.assertEqual(norm_(torch.tensor(torch.nan)), torch.nan)
+        self.assertEqual(norm_(torch.tensor(torch.inf)), torch.inf)
+        self.assertEqual(norm_(torch.complex(torch.ones(4), torch.zeros(4))), 4)
+        self.assertEqual(hash_(torch.ones(4, dtype=torch.float8_e5m2)), 0)
+        self.assertEqual(hash_(torch.ones(4, dtype=torch.int8)), 0)
+        self.assertEqual(hash_(torch.ones(5, dtype=torch.int8)), 1)
+
     def test_debug_string_inside_context(self):
         mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
 
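A minimal standalone sketch of the list-valued `hash_fn` API exercised above, using plain tensors instead of DTensors for brevity (the shapes are illustrative assumptions):

```python
import torch
from torch.utils._debug_mode import DebugMode

x, y = torch.randn(8, 4), torch.randn(4, 8)

# Passing a list of hash names makes each logged "hash" a tuple,
# one entry per requested function ("norm", then "hash_tensor").
with (
    DebugMode() as debug_mode,
    DebugMode.log_tensor_hashes(hash_fn=["norm", "hash_tensor"]),
):
    torch.mm(x, y)

norm_hash, xor_hash = debug_mode.operators[-1].log["hash"]
```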
@@ -664,6 +664,101 @@ class TestViewOps(DTensorTestBase):
         )
         self.assertEqual(dist_x.placements, [Partial(), Shard(0)])
 
+    @with_comms
+    def test_storage_offset_slice(self):
+        """
+        Test that storage_offset is properly tracked on DTensor when slicing
+        a replicated tensor.
+        """
+        mesh = init_device_mesh(self.device_type, (self.world_size,))
+
+        # Create a replicated DTensor
+        tensor = torch.randn(10, device=self.device_type)
+        dtensor = distribute_tensor(tensor, mesh, [Replicate()])
+
+        # Perform a slice operation [1:]
+        with CommDebugMode() as comm_mode:
+            sliced_dtensor = dtensor[1:]
+        # Slicing should not trigger any communication
+        self.assertEqual(comm_mode.get_total_counts(), 0)
+
+        # Verify that the DTensor's storage_offset matches the expected value
+        self.assertEqual(sliced_dtensor.storage_offset(), 1)
+
+        # Verify that the local tensor also has the correct storage_offset
+        self.assertEqual(sliced_dtensor.to_local().storage_offset(), 1)
+
+        # Verify the shape is correct
+        self.assertEqual(sliced_dtensor.shape, torch.Size([9]))
+
+        # Verify the values are correct
+        expected = tensor[1:]
+        self.assertEqual(sliced_dtensor.full_tensor(), expected)
+
+    @with_comms
+    def test_storage_offset_shard_dim0_slice_dim1(self):
+        """
+        Test that storage_offset is properly tracked when tensor is sharded on dim 0
+        and sliced on dim 1.
+        """
+        mesh = init_device_mesh(self.device_type, (self.world_size,))
+
+        # Create a 2D tensor and shard on dim 0
+        tensor = torch.randn(12, 8, device=self.device_type)
+        dtensor = distribute_tensor(tensor, mesh, [Shard(0)])
+
+        # Perform a slice operation [:, 2:]
+        with CommDebugMode() as comm_mode:
+            sliced_dtensor = dtensor[:, 2:]
+        # Slicing should not trigger any communication
+        self.assertEqual(comm_mode.get_total_counts(), 0)
+
+        # The storage_offset should be 2 (skipping 2 elements in each row)
+        self.assertEqual(sliced_dtensor.storage_offset(), 2)
+
+        # Verify that the local tensor also has the correct storage_offset
+        self.assertEqual(sliced_dtensor.to_local().storage_offset(), 2)
+
+        # Verify the shape is correct
+        expected_shape = torch.Size([12, 6])
+        self.assertEqual(sliced_dtensor.shape, expected_shape)
+
+        # Verify the values are correct
+        expected = tensor[:, 2:]
+        self.assertEqual(sliced_dtensor.full_tensor(), expected)
+
+    @with_comms
+    def test_storage_offset_shard_dim1_slice_dim0(self):
+        """
+        Test that storage_offset is properly tracked when tensor is sharded on dim 1
+        and sliced on dim 0.
+        """
+        mesh = init_device_mesh(self.device_type, (self.world_size,))
+
+        # Create a 2D tensor and shard on dim 1
+        tensor = torch.randn(10, 12, device=self.device_type)
+        dtensor = distribute_tensor(tensor, mesh, [Shard(1)])
+
+        # Perform a slice operation [2:, :]
+        with CommDebugMode() as comm_mode:
+            sliced_dtensor = dtensor[2:, :]
+        # Slicing should not trigger any communication
+        self.assertEqual(comm_mode.get_total_counts(), 0)
+
+        local_dim1_size = 12 // self.world_size
+        expected_offset = 2 * local_dim1_size
+        self.assertEqual(sliced_dtensor.storage_offset(), expected_offset)
+        self.assertEqual(sliced_dtensor.to_local().storage_offset(), expected_offset)
+
+        # Verify the shape is correct
+        expected_shape = torch.Size([8, 12])
+        self.assertEqual(sliced_dtensor.shape, expected_shape)
+
+        # Verify the values are correct
+        expected = tensor[2:, :]
+        self.assertEqual(sliced_dtensor.full_tensor(), expected)
+
 
 TestViewOpsWithLocalTensor = create_local_tensor_test_class(
     TestViewOps,
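The expected offsets in these tests follow from ordinary `torch.Tensor` view semantics; a quick eager sketch with the same shapes as the tests:

```python
import torch

t = torch.randn(10)
assert t[1:].storage_offset() == 1       # slicing off one element skips one

t2 = torch.randn(10, 12)
assert t2[2:, :].storage_offset() == 2 * 12  # skip two full rows of length 12

# When dim 1 is sharded across world_size ranks, each local shard has row
# length 12 // world_size, so slicing [2:, :] skips 2 * (12 // world_size)
# local elements -- which is the expected_offset computed in the test above.
```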
@@ -1,11 +1,9 @@
 # Owner(s): ["module: dynamo"]
 
-import copy
 import functools
 import inspect
 import os
 import pickle
-import unittest
 from contextlib import contextmanager
 from unittest.mock import patch
 
@@ -15,16 +13,13 @@ import torch._inductor.config
 import torch._inductor.test_case
 import torch.onnx.operators
 import torch.utils.cpp_extension
-from torch._dynamo.aot_compile import AOTCompiledModel, ModelInput, SerializableCallable
+from torch._dynamo.aot_compile import ModelInput, SerializableCallable
 from torch._dynamo.exc import PackageError, Unsupported
 from torch._dynamo.package import DynamoCache
 from torch._dynamo.precompile_context import PrecompileContext
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch.fx._graph_pickler import GraphPickler
-from torch.testing._internal.common_utils import (
-    instantiate_parametrized_tests,
-    TEST_CUDA,
-)
+from torch.testing._internal.common_utils import instantiate_parametrized_tests
 
 
 MY_LAMBDA = lambda x: x + 1  # noqa: E731
@@ -604,92 +599,6 @@ from user code:
         actual = compiled_fn(*inputs)
         self.assertEqual(expected, actual)
 
-    @unittest.skipIf(not TEST_CUDA, "requires cuda")
-    def test_aot_compile_with_aoti(self):
-        with torch.device("cuda"):
-            from torch._dynamo.hooks import Hooks
-
-            def fn(x, y):
-                return x + y
-
-            def make_inputs():
-                return (torch.randn(3, 4), torch.randn(3, 4))
-
-            compiled_fn = torch._dynamo.aot_compile.aot_compile_fullgraph(
-                fn,
-                (make_inputs(), {}),
-                Hooks(),
-                torch._TorchCompileAOTInductorWrapper(None, None, None),
-            )
-
-            test_inputs = make_inputs()
-            expected = fn(*test_inputs)
-            actual = compiled_fn(*test_inputs)
-            self.assertEqual(expected, actual)
-            compiled_fn.save_compiled_function(self.path())
-            with open(self.path(), "rb") as f:
-                compiled_fn = torch.compiler.load_compiled_function(f)
-            actual = compiled_fn(*test_inputs)
-            self.assertEqual(expected, actual)
-
-    @unittest.skipIf(not TEST_CUDA, "requires cuda")
-    def test_aot_compile_with_aoti_module(self):
-        with torch.device("cuda"):
-            from torch._dynamo.hooks import Hooks
-
-            mod = SimpleLinearModule()
-
-            def make_inputs():
-                return (torch.randn(4, 3),)
-
-            compiled_mod = torch._dynamo.aot_compile.aot_compile_module(
-                mod,
-                [ModelInput(make_inputs(), {}, [])],
-                Hooks(),
-                torch._TorchCompileAOTInductorWrapper(None, None, None),
-            )
-
-            def get_grads(m: torch.nn.Module):
-                return {name: p.grad for name, p in m.named_parameters()}
-
-            original_mod = copy.deepcopy(mod)
-            test_inputs = make_inputs()
-            expected = mod(*test_inputs)
-            expected.sum().backward()
-            expected_grads = get_grads(mod)
-
-            actual = compiled_mod(*test_inputs)
-            self.assertEqual(expected, actual)
-            serialized = compiled_mod.serialize()
-            compiled_fn = AOTCompiledModel.deserialize(original_mod, serialized)
-            actual = compiled_fn(*test_inputs)
-            actual.sum().backward()
-            self.assertEqual(get_grads(original_mod), expected_grads)
-
-    @unittest.skipIf(not TEST_CUDA, "requires cuda")
-    def test_aot_compile_with_aoti_torch_compile(self):
-        with torch.device("cuda"):
-
-            def fn(x, y):
-                return x + y
-
-            def make_inputs():
-                return (torch.randn(3, 4), torch.randn(3, 4))
-
-            compiled_fn = torch.compile(
-                fn, fullgraph=True, options={"use_aoti": True}
-            ).aot_compile((make_inputs(), {}))
-            test_inputs = make_inputs()
-            expected = fn(*test_inputs)
-            actual = compiled_fn(*test_inputs)
-            self.assertEqual(expected, actual)
-            compiled_fn.save_compiled_function(self.path())
-            with open(self.path(), "rb") as f:
-                compiled_fn = torch.compiler.load_compiled_function(f)
-            actual = compiled_fn(*test_inputs)
-            self.assertEqual(compiled_fn._artifacts.backend_name, "aotinductor")
-            self.assertEqual(expected, actual)
-
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
@@ -4524,6 +4524,17 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
         run(torch.rand(2, 10), torch.rand(2, 10))
         self.assertEqual(cnt.frame_count, 2)
 
+    @torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
+    def test_unbacked_view_extra(self):
+        def fn(x):
+            i0 = x.nonzero().size(0)
+            y = torch.zeros((i0, 192))
+            return y.view([12, -1, 192])
+
+        res1 = torch.compile(fn, fullgraph=True)(torch.ones((12,)))
+        res2 = fn(torch.ones((12,)))
+        self.assertEqual(res1, res2)
+
 
 instantiate_parametrized_tests(TestUnbacked)
 
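Under compilation, `i0` is an unbacked symbolic size, so `view([12, -1, 192])` needs the solver to establish divisibility by 12; in eager mode the same shapes are concrete. A runnable eager sketch of what the test computes:

```python
import torch

x = torch.ones(12)
i0 = x.nonzero().size(0)            # 12 nonzero elements
y = torch.zeros((i0, 192))          # shape (12, 192)
print(y.view([12, -1, 192]).shape)  # torch.Size([12, 1, 192])
```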
@@ -3755,6 +3755,44 @@ as the input tensor excluding its innermost dimension'):
             with ctx:
                 self.assertEqual(torch.mean(t), expected)
 
+    def test_scalar_tensor_as_dim_argument(self):
+        """Tests that scalar tensors work correctly as dimension arguments.
+
+        This tests the fix for the PythonArgParser bug where scalar Tensors
+        passed to IntList/SymIntList parameters would be incorrectly handled.
+        """
+        x = torch.ones(1, 2, 3, 4, 5)
+
+        # Scalar tensors should work correctly (same as passing an int)
+        result_tensor = x.sum(dim=torch.tensor(3))
+        result_int = x.sum(dim=3)
+        self.assertEqual(result_tensor.shape, result_int.shape)
+        self.assertEqual(result_tensor.shape, torch.Size([1, 2, 3, 5]))
+
+        # Test with different integer dtypes
+        for dtype in [torch.int32, torch.int64, torch.int16, torch.int8]:
+            dim_tensor = torch.tensor(1, dtype=dtype)
+            result = x.sum(dim=dim_tensor)
+            expected = x.sum(dim=1)
+            self.assertEqual(result.shape, expected.shape)
+
+    @skipIfTorchDynamo("Test uses random.randint which creates FakeTensors")
+    def test_scalar_tensor_dim_compiled_mode(self):
+        """Tests that scalar FakeTensors from random.randint work correctly in compiled mode."""
+
+        def foo():
+            x = torch.ones(2, 2, 2)
+            return x.sum(dim=random.randint(0, 0))
+
+        @torch.compile
+        def foo_compile():
+            x = torch.ones(2, 2, 2)
+            return x.sum(dim=random.randint(0, 0))
+
+        result_eager = foo()
+        result_compiled = foo_compile()
+        self.assertEqual(result_eager.shape, result_compiled.shape)
+        self.assertEqual(result_eager.shape, torch.Size([2, 2]))
+
 
 instantiate_device_type_tests(TestReductions, globals())
 
 if __name__ == '__main__':
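A standalone eager sketch of the behavior these tests pin down (same shapes as the test):

```python
import torch

x = torch.ones(1, 2, 3, 4, 5)

# A 0-dim integer tensor is accepted wherever an int dim is,
# and both spellings must agree on the reduced shape.
assert x.sum(dim=torch.tensor(3)).shape == x.sum(dim=3).shape == torch.Size([1, 2, 3, 5])
```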
@@ -2439,35 +2439,6 @@ class _TorchCompileInductorWrapper:
         reset_cudagraph_trees()
 
 
-class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper):
-    compiler_name = "aotinductor"
-
-    def __init__(self, mode, options, dynamic):
-        super().__init__(mode, options, dynamic)
-        self.apply_options({"cpp_wrapper": True})
-        self.apply_options({"aot_inductor.package": True})
-
-    def __call__(self, model_, inputs_):
-        from contextlib import nullcontext
-        from unittest import mock
-
-        from torch._guards import detect_fake_mode
-        from torch._inductor.virtualized import V
-
-        fake_mode = detect_fake_mode(inputs_)
-        ctx = (
-            mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
-            if fake_mode
-            else nullcontext()
-        )
-        with (
-            V.set_aot_compilation(True),
-            ctx,
-            torch._inductor.config.patch("enable_autograd_for_aot", True),
-        ):
-            return super().__call__(model_, inputs_)
-
-
 class _TorchCompileWrapper:
     def __init__(self, backend, mode, options, dynamic):
         from torch._dynamo.backends.registry import lookup_backend
@@ -2701,10 +2672,8 @@ def compile(
         backend = bisect_backend
 
     guard_filter_fn = None
-    use_aoti = False
     if options and isinstance(options, dict):
         guard_filter_fn = options.pop("guard_filter_fn", None)
-        use_aoti = options.pop("use_aoti", False)
 
     if torch.compiler.is_exporting():
         warnings.warn(
@@ -2731,10 +2700,7 @@ def compile(
         return export_wrapped_fn
 
     if backend == "inductor":
-        if use_aoti:
-            backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic)
-        else:
-            backend = _TorchCompileInductorWrapper(mode, options, dynamic)
+        backend = _TorchCompileInductorWrapper(mode, options, dynamic)
     else:
         backend = _TorchCompileWrapper(backend, mode, options, dynamic)
 
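For context, the `use_aoti` path removed here was exercised roughly like this, per the tests deleted above (`fn` and `make_inputs` are from that deleted test; this option no longer exists after this change):

```python
# Removed API, shown only to illustrate what this diff takes out:
compiled_fn = torch.compile(
    fn, fullgraph=True, options={"use_aoti": True}
).aot_compile((make_inputs(), {}))
```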
@@ -53,7 +53,6 @@ class CompileArtifacts:
     argdefs: Optional[tuple[Any, ...]]
     source_info: "SourceInfo"
     device_type: str
-    backend_name: str
     system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)
 
     def check_compatibility(self) -> None:
@@ -274,7 +273,6 @@ def aot_compile_fullgraph(
         argdefs=fn.__defaults__,
         source_info=source_info,
         device_type=device_type,
-        backend_name=getattr(backend, "compiler_name", "unknown"),
     )
     aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)
 
@@ -511,7 +511,6 @@ class GenericAOTAutogradResult(Generic[TForward, TBackward]):
         ).post_compile(
             compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
         )
-        compiled_fw_func._boxed_call = True
         disable_amp = torch._C._is_any_autocast_enabled()
 
         if needs_autograd:
@@ -1640,9 +1640,7 @@ class _InProcessFxCompile(FxCompile):
                     # pyrefly: ignore [unbound-name]
                     (str, list, torch.fx.GraphModule),
                 ), type(compiled_fn)
-                return CompiledAOTI(
-                    filename=compiled_fn, device_type=graph.device_type
-                )
+                return CompiledAOTI(compiled_fn)
 
             # TODO: Hoist this above V.aot_compilation
             # pyrefly: ignore [unbound-name]
@@ -2715,7 +2713,7 @@ def _compile_fx_main(
         or torch._guards.TracingContext(fake_mode)
     )
 
-    if V.aot_compilation and not config.enable_autograd_for_aot:
+    if V.aot_compilation:
         from .utils import is_valid_aoti_model_name
 
         is_valid_aoti_model_name()
@@ -1193,8 +1193,6 @@ autotune_lookup_table: dict[str, dict[str, Any]] = {}
 
 file_lock_timeout: int = int(os.environ.get("TORCHINDUCTOR_FILE_LOCK_TIMEOUT", "600"))
 
-enable_autograd_for_aot: bool = False
-
 
 def get_worker_log_path() -> Optional[str]:
     log_loc = None
@@ -773,83 +773,9 @@ class CompiledAOTI(OutputCode):
     """
 
     filename: Union[str, list[Union[str, Weights]], torch.fx.GraphModule]
-    device_type: str
-    current_callable: Optional[Callable[..., Any]] = None
-    _cached_files: dict[str, bytes] = dataclasses.field(default_factory=dict)
-
-    def __post_init__(self):
-        if not config.aot_inductor.link_libtorch:
-            return
-
-        if (
-            torch._inductor.cpp_builder._IS_MACOS
-            or torch._inductor.cpp_builder._IS_WINDOWS
-        ):
-            return
-
-        if config.aot_inductor.cross_target_platform == "windows":
-            return
-
-        if config.aot_inductor.package_cpp_only:
-            return
-
-        if isinstance(self.filename, list):
-            current_callable = next(
-                fn for fn in self.filename if isinstance(fn, str) and fn.endswith(".so")
-            )
-        else:
-            current_callable = self.filename
-
-        if isinstance(current_callable, torch.fx.GraphModule):
-            self.current_callable = current_callable
-            return
-
-        if self.device_type.startswith("cuda"):
-            current_callable = (
-                torch._C._aoti.AOTIModelContainerRunnerCuda(  # type: ignore[call-arg]
-                    current_callable,
-                    1,
-                    self.device_type,
-                    "",
-                    True,
-                ).run  # type: ignore[attr-defined]
-            )  # type: ignore[attr-defined]
-        elif self.device_type == "cpu":
-            current_callable = (
-                torch._C._aoti.AOTIModelContainerRunnerCpu(  # type: ignore[call-arg]
-                    current_callable, 1
-                ).run  # type: ignore[attr-defined]
-            )  # type: ignore[attr-defined]
-        else:
-            raise RuntimeError(f"unsupported device type {self.device_type}")
-        self.current_callable = current_callable
-        self._boxed_call = True
-        for file in self._cached_files:
-            if not os.path.exists(file):
-                with open(file, "wb") as f:
-                    f.write(self._cached_files[file])
 
     def __call__(self, inputs: Sequence[Any]) -> Any:
-        if self.current_callable is None:
-            raise RuntimeError("AOTInductor compiled so is not loaded")
-        return self.current_callable(inputs)
-
-    def prepare_for_serialization(self) -> None:
-        self.current_callable = None
-        self._cached_files = {}
-        filenames: list[str] = []
-        if isinstance(self.filename, list):
-            filenames = self.filename  # type: ignore[assignment]
-        elif isinstance(self.filename, str):
-            filenames = [self.filename]
-        for name in filenames:
-            with open(name, "rb") as f:
-                self._cached_files[name] = f.read()
-
-    def __getstate__(self):
-        state = self.__dict__.copy()
-        state["current_callable"] = None
-        return state
+        raise NotImplementedError("NYI")
 
     def post_compile(
         self,
@@ -857,8 +783,10 @@ class CompiledAOTI(OutputCode):
         constants: CompiledFxGraphConstants,
         graph_kwargs: _CompileFxKwargs,
     ) -> None:
-        if self.current_callable is None:
-            self.__post_init__()
+        pass
+
+    def prepare_for_serialization(self) -> None:
+        pass
 
     def set_triton_bundle(self, triton_bundle: Any) -> None:
         pass
@@ -2918,7 +2918,6 @@ static void pytorch_duplicate_guard() {
     abort();
   }
   initialized = 1;
-  ;
 }
 
 struct call_duplicate_guard {
@@ -1751,7 +1751,7 @@ static PyObject* THPVariable_dtensor_new(
   Tensor tensor = make_tensor_for_subclass_helper(
       /*sym_sizes=*/tuple_to_symintlist(sizes.ptr()),
      /*sym_strides=*/tuple_to_symintlist(stride.ptr()),
-      /*sym_storage_offset=*/std::nullopt,
+      /*sym_storage_offset=*/local_tensor.sym_storage_offset(),
       options,
       /*storage_size=*/std::nullopt,
       extra_dispatch_keys);
@@ -66,12 +66,6 @@ void initAOTIRunnerBindings(PyObject* module) {
               int,
               const std::string&,
               const std::string&>())
-      .def(py::init<
-               const std::string&,
-               int,
-               const std::string&,
-               const std::string&,
-               const bool>())
       .def(
           "run",
          &AOTIModelContainerRunnerCuda::run,
@@ -565,8 +565,16 @@ inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
     return std::vector<c10::SymInt>(size1, si);
   }
 
+  if (size1 > 0 && THPVariable_Check(args[i])) {
+    return std::vector<c10::SymInt>(
+        size1, THPVariable_Unpack(args[i]).item().toSymInt());
+  }
+
   PyObject* arg = args[i];
   auto tuple = PyTuple_Check(arg);
+  if (!tuple) {
+    TORCH_INTERNAL_ASSERT(PyList_Check(arg), "expected tuple or list");
+  }
   // NOLINTNEXTLINE(bugprone-branch-clone)
   const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
   std::vector<c10::SymInt> res;
@@ -645,7 +653,13 @@ inline std::vector<int64_t> PythonArgs::intlistWithDefault(
   if (size1 > 0 && torch::is_dynint(py::handle(arg))) {
     return std::vector<int64_t>(size1, py::handle(arg).cast<int>());
   }
+  if (size1 > 0 && THPVariable_Check(arg)) {
+    return std::vector<int64_t>(size1, THPVariable_Unpack(arg).item<int64_t>());
+  }
   auto tuple = PyTuple_Check(arg);
+  if (!tuple) {
+    TORCH_INTERNAL_ASSERT(PyList_Check(arg), "expected tuple or list");
+  }
   // NOLINTNEXTLINE(bugprone-branch-clone)
   const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
   std::vector<int64_t> res(size2);
@@ -716,6 +730,9 @@ inline c10::OptionalArray<c10::SymInt> PythonArgs::symintlistOptional(int i) {
 inline std::vector<double> PythonArgs::getDoublelist(int i) {
   PyObject* arg = args[i];
   auto tuple = PyTuple_Check(arg);
+  if (!tuple) {
+    TORCH_INTERNAL_ASSERT(PyList_Check(arg), "expected tuple or list");
+  }
   // NOLINTNEXTLINE(bugprone-branch-clone)
   auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
   std::vector<double> res(size);
@@ -889,6 +906,9 @@ inline at::Dimname PythonArgs::dimname(int i) {
 
 inline std::vector<at::Dimname> parseDimnameList(PyObject* arg) {
   auto tuple = PyTuple_Check(arg);
+  if (!tuple) {
+    TORCH_INTERNAL_ASSERT(PyList_Check(arg), "expected tuple or list");
+  }
   // NOLINTNEXTLINE(bugprone-branch-clone)
   auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
   std::vector<at::Dimname> res;
@@ -7037,52 +7037,16 @@ class ShapeEnv:
                 ok = len(free_unbacked_symbols(new_var)) == 0
                 if ok:
                     self._set_replacement(free[0], new_var, "solve")
 
            except NotImplementedError:
                pass
-        if expr.has(Mod):
+        else:
+            # expression has mod.
             mod_expr = next(iter(expr.atoms(Mod)))
             try:
                 r = try_solve(expr, mod_expr, floordiv_inequality=False)
                 if r is not None and r[1] == 0:
                     self._add_divisible(mod_expr)
-                    # This is a little bit of extra logic to make things like
-                    # torch.empty(i0, q).view(c, -1, q) work out
-                    p, q = mod_expr.args
-                    if (
-                        isinstance(q, sympy.Number)
-                        and isinstance(p, sympy.Mul)
-                        and len(p.args) == 2
-                    ):
-                        c, i0 = p.args
-                        # Given Mod(c * i0, q) == 0
-                        if (
-                            isinstance(c, sympy.Number)
-                            and isinstance(i0, sympy.Symbol)
-                            and self.is_unbacked_symint(i0)
-                        ):
-                            # We have Mod(i0, q / c) == 0, which means we can
-                            # rewrite i0 as (q / gcd(q, c)) * i1
-                            d = q / sympy.gcd(q, c)  # TODO: CleanDiv?
-                            i1 = self.create_unbacked_symint().node.expr
-                            # Propagate the value ranges. It doesn't really
-                            # matter if we use truediv or floordiv, because we
-                            # have established divisibility.
-                            self._update_var_to_range(
-                                i1,
-                                SymPyValueRangeAnalysis.floordiv(
-                                    self.var_to_range[i0], ValueRanges.wrap(d)
-                                ),
-                            )
-                            # Propagate hints (real tensor tracing)
-                            if i0 in self.unbacked_var_to_val:
-                                self.set_unbacked_var_to_val(
-                                    i1, self.unbacked_var_to_val[i0] // d
-                                )
-                            # Propagate size-like-ness
-                            if i0 in self.size_like:
-                                self.size_like.add(i1)
-                            self._set_replacement(i0, d * i1, "divisibility")
-
             except NotImplementedError:
                 pass
         return
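The deleted branch implemented the rewrite Mod(c * i0, q) == 0 implies i0 = (q / gcd(q, c)) * i1 for a fresh unbacked symbol i1. A small sympy check of that identity (the concrete values c = 3, q = 12 are illustrative):

```python
import sympy

i1 = sympy.Symbol("i1", integer=True, positive=True)
c, q = 3, 12
d = q // sympy.gcd(q, c)  # q / gcd(q, c) = 4
i0 = d * i1               # the rewrite: i0 = 4 * i1

# Mod(c * i0, q) = Mod(12 * i1, 12), which vanishes for every integer i1,
# so the divisibility constraint holds by construction.
assert sympy.simplify(sympy.Mod(c * i0, q)) == 0
```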
@@ -273,9 +273,8 @@ class _KinetoProfile:
         if path.endswith(".gz"):
             with tempfile.NamedTemporaryFile("w+b", suffix=".json") as fp:
                 retvalue = self.profiler.export_chrome_trace(fp.name)
-                fp.seek(0)
-                with gzip.open(path, "wb") as fout:
-                    fout.writelines(fp)
+                with open(fp.name, "rb") as fin, gzip.open(path, "wb") as fout:
+                    fout.writelines(fin)
             return retvalue
         else:
             return self.profiler.export_chrome_trace(path)
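A usage sketch for the path changed above: a `.gz` suffix selects the compressed branch, so the trace is staged in a temporary file and then rewritten through gzip (workload is illustrative):

```python
import torch
from torch.profiler import profile

with profile() as prof:
    torch.mm(torch.randn(16, 16), torch.randn(16, 16))

# Ends with ".gz", so the Chrome trace is written via the gzip branch above.
prof.export_chrome_trace("trace.json.gz")
```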
@@ -447,7 +446,6 @@ class _KinetoProfile:
             self.mem_tl.export_memory_timeline_html(path, device)
         elif path.endswith(".gz"):
             with tempfile.NamedTemporaryFile("w+t", suffix=".json") as fp:
-                fp.close()
                 if path.endswith("raw.json.gz"):
                     self.mem_tl.export_memory_timeline_raw(fp.name, device)
                 else:
@@ -39,7 +39,7 @@ import os
 import traceback
 import weakref
 from collections.abc import Callable
-from typing import Any, Optional, TYPE_CHECKING  # noqa: F401
+from typing import Any, Optional, TYPE_CHECKING, Union  # noqa: F401
 
 import torch
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@@ -157,21 +157,25 @@ def _arg_to_str(arg, attributes, tensor_memo=None) -> str:
         return str(arg)
 
 
-def default_hash_fn(t: torch.Tensor, use_scalar: bool = False) -> torch.Tensor:
+def norm_hash_fn(
+    t: torch.Tensor, use_scalar: bool = False
+) -> Union[torch.Tensor, float]:
     """
     from Observer. Computes a hash for a tensor by converting it to float (if needed), making it contiguous,
     replacing NaN/inf values with fixed numbers, and then computing the L1 norm in float64 or complex128.
     This is used to generate a deterministic summary value for tensor comparison.
     """
-    with torch._C._DisablePythonDispatcher(), torch._C._DisableTorchDispatch():
+    with torch._C._DisablePythonDispatcher():
         if not (t.is_floating_point() or t.is_complex()):
             t = t.float()
         t = t.contiguous()
-        # Clean the tensor to handle NaN/inf values, then compute norm
-        t_clean = torch.nan_to_num(t, nan=0.0, posinf=1.0, neginf=-1.0)
 
-        dtype = torch.complex128 if t.is_complex() else torch.float64
-        out = t_clean.norm(p=1, dtype=dtype)
+        if t.is_complex():
+            t_float = t.to(dtype=torch.complex128)
+        else:
+            t_float = t.to(dtype=torch.float64)
+
+        out = t_float.norm(p=1)
         if use_scalar:
             return out.item()
         return out
@@ -184,6 +188,28 @@ def _compute_rel_diff(hash1, hash2):
     return numerator / denominator
 
 
+def hash_tensor_fn(
+    t: torch.Tensor, use_scalar: bool = False
+) -> Union[torch.Tensor, int]:
+    """
+    wrapper over torch.hash_tensor
+    """
+    if isinstance(t, torch.distributed.tensor.DTensor):
+        t = t.to_local()
+
+    if t.is_floating_point():
+        t_clean = t.to(dtype=torch.float64)
+    elif t.is_complex():
+        t_clean = t.to(dtype=torch.complex128).view(torch.float64)
+    else:
+        t_clean = t.to(dtype=torch.int64)
+
+    out = torch.hash_tensor(t_clean)
+    if use_scalar:
+        return out.item()  # type: ignore[attribute]
+    return out
+
+
 def _get_stack_trace() -> str:
     from torch.fx.experimental.symbolic_shapes import uninteresting_files
 
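The edge cases asserted in the test hunk earlier (`hash_` of four ones is 0, of five ones is 1) are consistent with `torch.hash_tensor` being an XOR reduction: identical elements cancel in pairs. A sketch, assuming that reduction semantics:

```python
import torch

# XOR reduction: pairs of identical values cancel, so an even count of
# equal elements hashes to 0 and an odd count leaves one element's bits.
even = torch.hash_tensor(torch.ones(4, dtype=torch.int64))
odd = torch.hash_tensor(torch.ones(5, dtype=torch.int64))
print(even.item(), odd.item())  # expected 0 and 1, per the tests above
```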
@@ -897,20 +923,43 @@ class DebugMode(TorchDispatchMode):
 
     @staticmethod
     @contextlib.contextmanager
-    def log_tensor_hashes(hash_fn: Callable | None = None, hash_inputs: bool = False):
+    def log_tensor_hashes(
+        hash_fn: Union[Callable, str, list[str]] = "norm", hash_inputs: bool = False
+    ):
         """
         Installs hook for tensor hash logging.
 
-        hash_fn: optional function for custom hashing
+        hash_fn: One of:
+            - Custom-defined hash function
+            - String: one of ("norm", "hash_tensor")
+                - "norm": uses norm_hash_fn; basically tensor's L1 norm
+                - "hash_tensor": uses torch.hash_tensor (XOR sum reduction)
+            - List of strings: returns tuple of hashes from above options
         hash_inputs: if True, also hashes tensors in (args, kwargs), storing them in "input_hash".
         NOTE: this is currently a post-hook, so e.g. inplace ops will log the "output" hashes.
         """
-        if hash_fn is None:
-            hash_fn = functools.partial(default_hash_fn, use_scalar=True)
+
+        def hash_fn_option(hash_type):
+            assert isinstance(hash_type, str) and hash_type in ["norm", "hash_tensor"]
+            return functools.partial(
+                norm_hash_fn if hash_type == "norm" else hash_tensor_fn, use_scalar=True
+            )
+
+        if callable(hash_fn):
+            fn = hash_fn
+        elif isinstance(hash_fn, str):
+            fn = hash_fn_option(hash_fn)
+        elif isinstance(hash_fn, list):
+            fns = [hash_fn_option(fn) for fn in hash_fn]
+            fn = lambda x: tuple(fn(x) for fn in fns)  # noqa: E731
+        else:
+            raise NotImplementedError(
+                f"log_tensor_hashes() expected hash_fn to be callable, str, or list[str], but found {type(hash_fn)}"
+            )
 
         def _tree_hash(obj):
             return tree_map(
-                lambda x: hash_fn(x) if isinstance(x, torch.Tensor) else None, obj
+                lambda x: fn(x) if isinstance(x, torch.Tensor) else None, obj
             )
 
         def _dispatch_hash_hook(func, types, args, kwargs, result):
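And a sketch of the other accepted form, a custom callable, combined with `hash_inputs=True` (the log keys follow the docstring above; the hash function itself is an arbitrary illustration):

```python
import torch
from torch.utils._debug_mode import DebugMode

# Any tensor -> value callable can serve as hash_fn.
def my_hash(t: torch.Tensor):
    return t.double().abs().sum().item()

with (
    DebugMode() as debug_mode,
    DebugMode.log_tensor_hashes(hash_fn=my_hash, hash_inputs=True),
):
    torch.ones(4) + torch.ones(4)

log = debug_mode.operators[-1].log
print(log["hash"], log["input_hash"])  # output hash, plus hashes of (args, kwargs)
```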
@@ -930,9 +979,9 @@ class DebugMode(TorchDispatchMode):
         try:
             if hash_inputs:
                 _old_input_hfn = _TRITON_INPUT_HASH_FN
-                _TRITON_INPUT_HASH_FN = hash_fn
+                _TRITON_INPUT_HASH_FN = fn
                 _old_output_hfn = _TRITON_OUTPUT_HASH_FN
-                _TRITON_OUTPUT_HASH_FN = hash_fn
+                _TRITON_OUTPUT_HASH_FN = fn
             with DebugMode.dispatch_hooks(log_hook=_dispatch_hash_hook):
                 yield
         finally: