Compare commits

...

3 Commits

SHA1        Message                                   Date
1af3a82933  Update [ghstack-poisoned]                 2025-11-13 11:46:16 -08:00
fae8ca233d  Update [ghstack-poisoned]                 2025-11-13 11:24:17 -08:00
eb24dbfce8  Update (base update) [ghstack-poisoned]   2025-11-13 11:24:17 -08:00
19 changed files with 1947 additions and 4 deletions

aaa.py (new file, 187 lines)
View File

@@ -0,0 +1,187 @@
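# Scratch/debug script: runs FlexAttention under selective activation
# checkpointing (SAC) and prints a DebugMode trace of the dispatched ops.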
import torch
from torch.utils._debug_mode import DebugMode
from torch import nn
import functools
from torch.nn.attention.flex_attention import (
_create_empty_block_mask,
_DEFAULT_SPARSE_BLOCK_SIZE,
_identity,
_mask_mod_signature,
_score_mod_signature,
_WARNINGS_SHOWN,
and_masks,
AuxOutput,
AuxRequest,
BlockMask,
create_block_mask,
flex_attention,
flex_attention_hop,
noop_mask,
or_masks,
)
"""
@torch.compile()
def f(x, y):
return x + y
with DebugMode() as debug_mode:
print(f(torch.randn(2), torch.randn(2)))
print(debug_mode.debug_string())
"""
# NEXT
"""
from torch.utils.checkpoint import (
checkpoint,
CheckpointPolicy,
create_selective_checkpoint_contexts,
)
def _get_custom_policy(no_recompute_list=None, must_recompute_list=None):
def _custom_policy(ctx, func, *args, **kwargs):
if no_recompute_list is not None and func in no_recompute_list:
return CheckpointPolicy.MUST_SAVE
if must_recompute_list is not None and func in must_recompute_list:
return CheckpointPolicy.MUST_RECOMPUTE
else:
return CheckpointPolicy.PREFER_RECOMPUTE
return _custom_policy
def context_fn_must_recompute_mm():
must_recompute_list = [
torch.ops.aten.mm.default,
]
return create_selective_checkpoint_contexts(
_get_custom_policy(
must_recompute_list=must_recompute_list,
),
)
@torch.compile(fullgraph=True)
def mm(x, y):
return torch.matmul(x, y)
def gn(x):
return torch.sigmoid(mm(x, x))
def fn(x):
return torch.utils.checkpoint.checkpoint(
gn,
x,
use_reentrant=False,
context_fn=context_fn_must_recompute_mm,
)
x = torch.randn(4, 4, requires_grad=True, device='cuda')
with DebugMode() as debug_mode:
print(fn(x))
print(debug_mode.debug_string())
"""
class FlexAttentionModule(nn.Module):
def __init__(self, hidden_size, num_heads):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
# In-projections (query, key, value)
self.q_proj = nn.Linear(hidden_size, hidden_size)
self.k_proj = nn.Linear(hidden_size, hidden_size)
self.v_proj = nn.Linear(hidden_size, hidden_size)
# Out-projection
self.out_proj = nn.Linear(hidden_size, hidden_size)
def forward(self, x):
batch_size, seq_len, _ = x.size()
# Project queries, keys, and values
q = (
self.q_proj(x)
.view(batch_size, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
k = (
self.k_proj(x)
.view(batch_size, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
v = (
self.v_proj(x)
.view(batch_size, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
# Apply flex attention
attn_output = torch.compile()(flex_attention)(
#attn_output = flex_attention(
q,
k,
v,
)
# Reshape output
attn_output = (
attn_output.transpose(1, 2)
.contiguous()
.view(batch_size, seq_len, self.hidden_size)
)
# Out projection
output = self.out_proj(attn_output)
return output
from torch.utils.checkpoint import (
checkpoint,
create_selective_checkpoint_contexts,
)
# TODO: do this
ops_to_save = [torch.ops.aten.mm.default]
context_fn = functools.partial(
create_selective_checkpoint_contexts, ops_to_save
)
# Define a model that uses FlexAttention with selective activation checkpointing
class SacModule(nn.Module):
def __init__(self, hidden_size, num_heads, context_fn):
super().__init__()
self.flex_attn = FlexAttentionModule(hidden_size, num_heads)
self.context_fn = context_fn
def forward(self, x):
def flex_attn_fn(x):
return self.flex_attn(x)
output = checkpoint(
flex_attn_fn,
x,
use_reentrant=False,
context_fn=self.context_fn,
)
return output
device='cuda'
flex_module = SacModule(hidden_size=512, num_heads=8, context_fn=context_fn).to(
device, dtype=torch.bfloat16
)
x = torch.ones(8, 1024, 512, device=device, dtype=torch.bfloat16, requires_grad=True)
with DebugMode() as debug_mode:
output_module = flex_module(x)
grad_output = torch.ones_like(output_module)
grad_module = torch.autograd.grad(
outputs=output_module, inputs=x, grad_outputs=grad_output, retain_graph=True
)[0]
print(debug_mode.debug_string())

View File

@@ -1019,6 +1019,28 @@ class DTensorMeshTest(DTensorTestBase):
except ValueError:
self.fail("Unexpected ValueError raised with run_check=False")
@with_comms
def test_as_strided_identity(self):
# Test calling as_strided with the same size/stride/offset as the input tensor
# This should be a no-op (the as_strided handler returns an alias in this case)
device_mesh = self.build_device_mesh()
placements = [Shard(0)]
local_tensor = torch.randn(3, 4, device=self.device_type)
dtensor = DTensor.from_local(local_tensor, device_mesh, placements)
# Get the current size, stride, and storage_offset
size = dtensor.size()
stride = dtensor.stride()
storage_offset = dtensor.storage_offset()
# Call as_strided with the exact same parameters
result = dtensor.as_strided(size, stride, storage_offset)
# The result should be identical to the input
self.assertEqual(result.size(), dtensor.size())
self.assertEqual(result.stride(), dtensor.stride())
self.assertEqual(result.to_local(), dtensor.to_local())
DTensorMeshTestWithLocalTensor = create_local_tensor_test_class(
DTensorMeshTest,

View File

@@ -994,6 +994,96 @@ def forward(self, primals_1):
out_dt = torch.matmul(tmp_dt, y_dt)
out_dt.sum().backward()
def test_dtensor_broadcast_to_replicate(self):
"""Test broadcast_to operation with DTensor under compilation."""
mesh = init_device_mesh(self.device_type, (self.world_size, 1))
placement = [Replicate(), Replicate()]
def fn(input):
return torch.broadcast_to(input, (888, 12, 888, 12))
big_tensor = torch.randn(888, 12)
d_tensor = distribute_tensor(big_tensor, device_mesh=mesh, placements=placement)
# Test eager execution
result = fn(d_tensor)
self.assertEqual(result.shape, torch.Size([888, 12, 888, 12]))
# Test compiled execution
jitted_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
jitted_result = jitted_fn(d_tensor)
self.assertEqual(jitted_result.shape, torch.Size([888, 12, 888, 12]))
self.assertEqual(result.full_tensor(), jitted_result.full_tensor())
def test_dtensor_reshape_replicate(self):
"""Test reshape operation with DTensor under compilation."""
mesh = init_device_mesh(self.device_type, (self.world_size, 1))
placement = [Replicate(), Replicate()]
def fn(input):
return input.reshape((1, 888, 6, 2))
big_tensor = torch.randn(888, 12)
d_tensor = distribute_tensor(big_tensor, device_mesh=mesh, placements=placement)
# Test eager execution
result = fn(d_tensor)
self.assertEqual(result.shape, torch.Size([1, 888, 6, 2]))
# Test compiled execution
jitted_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
jitted_result = jitted_fn(d_tensor)
self.assertEqual(jitted_result.shape, torch.Size([1, 888, 6, 2]))
self.assertEqual(result.full_tensor(), jitted_result.full_tensor())
def test_dtensor_indexing_replicate(self):
"""Test tensor indexing with Ellipsis with DTensor under compilation."""
mesh = init_device_mesh(self.device_type, (self.world_size, 1))
placement = [Replicate(), Replicate()]
def fn(input):
return input[(Ellipsis, 1)]
big_tensor = torch.randn(1, 888, 12)
d_tensor = distribute_tensor(big_tensor, device_mesh=mesh, placements=placement)
# Test eager execution
result = fn(d_tensor)
self.assertEqual(result.shape, torch.Size([1, 888]))
# Test compiled execution
jitted_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
jitted_result = jitted_fn(d_tensor)
self.assertEqual(jitted_result.shape, torch.Size([1, 888]))
self.assertEqual(result.full_tensor(), jitted_result.full_tensor())
def test_dtensor_tensor_split_replicate(self):
"""Test tensor_split operation with DTensor under compilation."""
mesh = init_device_mesh(self.device_type, (self.world_size, 1))
placement = [Replicate(), Replicate()]
def fn(input):
return torch.tensor_split(input, 3, dim=-1)
big_tensor = torch.randn(1, 1, 9216)
d_tensor = distribute_tensor(big_tensor, device_mesh=mesh, placements=placement)
# Test eager execution
t1, t2, t3 = fn(d_tensor)
self.assertEqual(t1.shape, torch.Size([1, 1, 3072]))
self.assertEqual(t2.shape, torch.Size([1, 1, 3072]))
self.assertEqual(t3.shape, torch.Size([1, 1, 3072]))
# Test compiled execution
jitted_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
jitted_t1, jitted_t2, jitted_t3 = jitted_fn(d_tensor)
self.assertEqual(jitted_t1.shape, torch.Size([1, 1, 3072]))
self.assertEqual(jitted_t2.shape, torch.Size([1, 1, 3072]))
self.assertEqual(jitted_t3.shape, torch.Size([1, 1, 3072]))
self.assertEqual(t1.full_tensor(), jitted_t1.full_tensor())
self.assertEqual(t2.full_tensor(), jitted_t2.full_tensor())
self.assertEqual(t3.full_tensor(), jitted_t3.full_tensor())
@unittest.skipIf(
torch._inductor.config.triton.native_matmul, "Matmul is now generated"
)

View File

@@ -664,6 +664,101 @@ class TestViewOps(DTensorTestBase):
)
self.assertEqual(dist_x.placements, [Partial(), Shard(0)])
@with_comms
def test_storage_offset_slice(self):
"""
Test that storage_offset is properly tracked on DTensor when slicing
a replicated tensor.
"""
mesh = init_device_mesh(self.device_type, (self.world_size,))
# Create a replicated DTensor
tensor = torch.randn(10, device=self.device_type)
dtensor = distribute_tensor(tensor, mesh, [Replicate()])
# Perform a slice operation [1:]
with CommDebugMode() as comm_mode:
sliced_dtensor = dtensor[1:]
# Slicing should not trigger any communication
self.assertEqual(comm_mode.get_total_counts(), 0)
# Verify that the DTensor's storage_offset matches the expected value
self.assertEqual(sliced_dtensor.storage_offset(), 1)
# Verify that the local tensor also has the correct storage_offset
self.assertEqual(sliced_dtensor.to_local().storage_offset(), 1)
# Verify the shape is correct
self.assertEqual(sliced_dtensor.shape, torch.Size([9]))
# Verify the values are correct
expected = tensor[1:]
self.assertEqual(sliced_dtensor.full_tensor(), expected)
@with_comms
def test_storage_offset_shard_dim0_slice_dim1(self):
"""
Test that storage_offset is properly tracked when tensor is sharded on dim 0
and sliced on dim 1.
"""
mesh = init_device_mesh(self.device_type, (self.world_size,))
# Create a 2D tensor and shard on dim 0
tensor = torch.randn(12, 8, device=self.device_type)
dtensor = distribute_tensor(tensor, mesh, [Shard(0)])
# Perform a slice operation [:, 2:]
with CommDebugMode() as comm_mode:
sliced_dtensor = dtensor[:, 2:]
# Slicing should not trigger any communication
self.assertEqual(comm_mode.get_total_counts(), 0)
# The storage_offset should be 2: the sliced view starts 2 elements into the underlying storage
self.assertEqual(sliced_dtensor.storage_offset(), 2)
# Verify that the local tensor also has the correct storage_offset
self.assertEqual(sliced_dtensor.to_local().storage_offset(), 2)
# Verify the shape is correct
expected_shape = torch.Size([12, 6])
self.assertEqual(sliced_dtensor.shape, expected_shape)
# Verify the values are correct
expected = tensor[:, 2:]
self.assertEqual(sliced_dtensor.full_tensor(), expected)
@with_comms
def test_storage_offset_shard_dim1_slice_dim0(self):
"""
Test that storage_offset is properly tracked when tensor is sharded on dim 1
and sliced on dim 0.
"""
mesh = init_device_mesh(self.device_type, (self.world_size,))
# Create a 2D tensor and shard on dim 1
tensor = torch.randn(10, 12, device=self.device_type)
dtensor = distribute_tensor(tensor, mesh, [Shard(1)])
# Perform a slice operation [2:, :]
with CommDebugMode() as comm_mode:
sliced_dtensor = dtensor[2:, :]
# Slicing should not trigger any communication
self.assertEqual(comm_mode.get_total_counts(), 0)
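# Each rank holds a contiguous (10, 12 // world_size) local shard, so
# slicing off the first 2 rows skips 2 * local_dim1_size local elements.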
local_dim1_size = 12 // self.world_size
expected_offset = 2 * local_dim1_size
self.assertEqual(sliced_dtensor.storage_offset(), expected_offset)
self.assertEqual(sliced_dtensor.to_local().storage_offset(), expected_offset)
# Verify the shape is correct
expected_shape = torch.Size([8, 12])
self.assertEqual(sliced_dtensor.shape, expected_shape)
# Verify the values are correct
expected = tensor[2:, :]
self.assertEqual(sliced_dtensor.full_tensor(), expected)
TestViewOpsWithLocalTensor = create_local_tensor_test_class(
TestViewOps,

test/test_as_strided.py (new file, 890 lines)
View File

@@ -0,0 +1,890 @@
# Owner(s): ["oncall: pt2"]
from collections import deque
from typing import Optional
import torch
from torch._prims_common import check_significant_strides
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._as_strided import as_strided_via_views
def get_state(t: torch.Tensor) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""Extract (sizes, strides) tuple from a tensor."""
return (tuple(t.size()), tuple(t.stride()))
def max_storage_offset(
base_numel: int, size: tuple[int, ...], stride: tuple[int, ...]
) -> int:
"""
Calculate the maximum storage offset for which a view with the given
size/stride would fit within a tensor of base_numel elements.
We conservatively require the full size*stride extent to fit inside,
which ensures there's enough working space for unflatten operations.
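For example, with base_numel=20, size=(2, 3), stride=(5, 1): the maximum
extent is max(2*5, 3*1) = 10, so the largest valid offset is 20 - 10 = 10.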
"""
if base_numel == 0:
return 0
if any(s == 0 for s in size):
# numel == 0, all offsets from 0 to base_numel-1 are technically valid
# but for simplicity we only test offset 0
return 0
# Require the full extent of the largest dimension to fit
# For each dimension, the extent is size[i] * stride[i]
max_extent = max(sz * max(st, 0) for sz, st in zip(size, stride))
max_offset = base_numel - max_extent
return max(0, max_offset)
def enumerate_reachable_states(
initial_size: int,
include_slice: bool = False,
) -> set[tuple[tuple[int, ...], tuple[int, ...]]]:
"""
Use BFS with DP to enumerate all reachable (size, stride) states from
a 1D contiguous tensor via valid view operations.
Args:
initial_size: Size of the initial 1D tensor
include_slice: If True, include slice operations with step>1
We only explore states with offset=0 (you can retroactively change the offset).
We reject states with size=0 or size=1 dimensions as they are degenerate.
"""
# Create initial 1D contiguous tensor
initial_tensor = torch.arange(initial_size)
initial_state = get_state(initial_tensor)
# Map from state to tensor for that state
state_to_tensor: dict[tuple[tuple[int, ...], tuple[int, ...]], torch.Tensor] = {
initial_state: initial_tensor
}
visited: set[tuple[tuple[int, ...], tuple[int, ...]]] = {initial_state}
queue: deque[tuple[tuple[int, ...], tuple[int, ...]]] = deque([initial_state])
while queue:
state = queue.popleft()
t = state_to_tensor[state]
sizes, strides = state
ndim = len(sizes)
def add_state(new_t: torch.Tensor) -> None:
new_state = get_state(new_t)
sizes, strides = new_state
# Skip if has size-0 or size-1 dimensions
if any(s == 0 or s == 1 for s in sizes):
return
# Only accept states where strides are in descending order
if list(strides) != sorted(strides, reverse=True):
return
if new_state not in visited:
visited.add(new_state)
queue.append(new_state)
state_to_tensor[new_state] = new_t
# 1. Unflatten: try factoring each dimension
for dim in range(ndim):
size = sizes[dim]
assert size > 1
# Try all factorizations x * y = size where both x, y >= 2
# We only need to check x up to size // 2 since when x > size // 2,
# y = size // x < 2, which we reject
for x in range(2, size // 2 + 1):
if size % x == 0:
y = size // x
add_state(t.unflatten(dim, (x, y)))
# 2. Slice/Narrow
for dim in range(ndim):
size = sizes[dim]
max_step = size + 1 if include_slice else 2
for start in range(size):
for stop in range(start + 1, size + 1):
for step in range(1, max_step):
slices = [slice(None)] * ndim
slices[dim] = slice(start, stop, step)
add_state(t[tuple(slices)])
# 3. Flatten: merge adjacent dimensions
for dim in range(ndim - 1):
add_state(t.flatten(dim, dim + 1))
return visited
class TestAsStrided(TestCase):
def assertSameView(
self,
result: torch.Tensor | type[NotImplemented],
target: torch.Tensor,
base: torch.Tensor,
msg: str = "",
) -> None:
self.assertIsNot(result, NotImplemented, msg=msg or "Got NotImplemented")
self.assertEqual(result.size(), target.size(), msg=msg or "Size mismatch")
same_strides, idx = check_significant_strides(result, target, only_cuda=False)
if not same_strides:
fail_msg = f"Stride mismatch at dim {idx}: result={result.stride()}, target={target.stride()}"
if msg:
fail_msg = f"{msg}: {fail_msg}"
self.fail(fail_msg)
self.assertTrue(result._is_view(), msg=msg or "Result is not a view")
# Check that result is a view of the base tensor using object identity
result_base = result._base if result._base is not None else result
base_base = base._base if base._base is not None else base
self.assertIs(
result_base, base_base, msg=msg or "Result is not a view of the base tensor"
)
self.assertEqual(
result.storage_offset(),
target.storage_offset(),
msg=msg or "Storage offset mismatch",
)
def check_as_strided_via_views(
self,
base: torch.Tensor,
size: tuple[int, ...],
stride: tuple[int, ...],
storage_offset: int = 0,
) -> None:
"""Helper to test as_strided_via_views matches torch.as_strided."""
target = torch.as_strided(base, size, stride, storage_offset)
result = as_strided_via_views(base, size, stride, storage_offset)
self.assertSameView(
result, target, base, f"Failed for {size=}, {stride=}, {storage_offset=}"
)
def test_size_10_exhaustive_without_slice(self) -> None:
"""Test that size 10 produces exactly 26 states without slice (step>1)."""
expected_states = {
((2,), (1,)),
((2, 2), (2, 1)),
((2, 2), (3, 1)),
((2, 2), (4, 1)),
((2, 2), (5, 1)),
((2, 2, 2), (4, 2, 1)),
((2, 2, 2), (5, 2, 1)),
((2, 3), (3, 1)),
((2, 3), (4, 1)),
((2, 3), (5, 1)),
((2, 4), (4, 1)),
((2, 4), (5, 1)),
((2, 5), (5, 1)),
((3,), (1,)),
((3, 2), (2, 1)),
((3, 2), (3, 1)),
((3, 3), (3, 1)),
((4,), (1,)),
((4, 2), (2, 1)),
((5,), (1,)),
((5, 2), (2, 1)),
((6,), (1,)),
((7,), (1,)),
((8,), (1,)),
((9,), (1,)),
((10,), (1,)),
}
actual_states = enumerate_reachable_states(10, include_slice=False)
self.assertEqual(len(actual_states), 26)
self.assertEqual(actual_states, expected_states)
def test_size_10_exhaustive_with_slice(self) -> None:
"""Test that size 10 produces exactly 54 states with slice (step>1)."""
expected_states = {
((2,), (1,)),
((2,), (2,)),
((2,), (3,)),
((2,), (4,)),
((2,), (5,)),
((2,), (6,)),
((2,), (7,)),
((2,), (8,)),
((2,), (9,)),
((2, 2), (2, 1)),
((2, 2), (3, 1)),
((2, 2), (3, 2)),
((2, 2), (4, 1)),
((2, 2), (4, 2)),
((2, 2), (4, 3)),
((2, 2), (5, 1)),
((2, 2), (5, 2)),
((2, 2), (5, 3)),
((2, 2), (5, 4)),
((2, 2), (6, 1)),
((2, 2), (6, 2)),
((2, 2), (6, 3)),
((2, 2), (8, 1)),
((2, 2, 2), (4, 2, 1)),
((2, 2, 2), (5, 2, 1)),
((2, 3), (3, 1)),
((2, 3), (4, 1)),
((2, 3), (5, 1)),
((2, 3), (5, 2)),
((2, 3), (6, 1)),
((2, 4), (4, 1)),
((2, 4), (5, 1)),
((2, 5), (5, 1)),
((3,), (1,)),
((3,), (2,)),
((3,), (3,)),
((3,), (4,)),
((3, 2), (2, 1)),
((3, 2), (3, 1)),
((3, 2), (3, 2)),
((3, 2), (4, 1)),
((3, 3), (3, 1)),
((4,), (1,)),
((4,), (2,)),
((4,), (3,)),
((4, 2), (2, 1)),
((5,), (1,)),
((5,), (2,)),
((5, 2), (2, 1)),
((6,), (1,)),
((7,), (1,)),
((8,), (1,)),
((9,), (1,)),
((10,), (1,)),
}
actual_states = enumerate_reachable_states(10, include_slice=True)
self.assertEqual(len(actual_states), 54)
self.assertEqual(actual_states, expected_states)
def test_subset_property(self) -> None:
"""
Test that for sizes 2..10, each smaller tensor results in a strict
subset of possible states compared to the next one.
"""
prev_states: Optional[set[tuple[tuple[int, ...], tuple[int, ...]]]] = None
for size in range(2, 11):
current_states = enumerate_reachable_states(size)
if prev_states is not None:
# Check that prev_states is a strict subset of current_states
self.assertTrue(
prev_states.issubset(current_states),
f"States from size {size - 1} are not a subset of size {size}",
)
# Check that it's a strict subset (not equal)
self.assertTrue(
len(prev_states) < len(current_states),
f"States from size {size - 1} should be strictly fewer than size {size}",
)
prev_states = current_states
def test_as_strided_via_views_exhaustive(self) -> None:
"""Exhaustively test on all reachable states from size 10, including all valid offsets."""
initial_tensor = torch.arange(10)
states = enumerate_reachable_states(10)
for size, stride in states:
# Test with offset=0
self.check_as_strided_via_views(
initial_tensor, size, stride, storage_offset=0
)
# Test with all valid non-zero offsets
max_offset = max_storage_offset(initial_tensor.numel(), size, stride)
for offset in range(1, max_offset + 1):
self.check_as_strided_via_views(
initial_tensor, size, stride, storage_offset=offset
)
# Tests for specific transformations
def test_permute_simple(self) -> None:
"""Test simple permute: (2, 5) with strides (5, 1) -> (5, 2) with strides (1, 5)"""
self.check_as_strided_via_views(torch.arange(10), (5, 2), (1, 5))
def test_unsqueeze_front(self) -> None:
"""Test unsqueeze at front: (2, 5) -> (1, 2, 5)"""
self.check_as_strided_via_views(torch.arange(10), (1, 2, 5), (10, 5, 1))
def test_unsqueeze_middle(self) -> None:
"""Test unsqueeze in middle: (2, 5) -> (2, 1, 5)"""
self.check_as_strided_via_views(torch.arange(10), (2, 1, 5), (5, 5, 1))
def test_unsqueeze_end(self) -> None:
"""Test unsqueeze at end: (2, 5) -> (2, 5, 1)"""
self.check_as_strided_via_views(torch.arange(10), (2, 5, 1), (5, 1, 1))
def test_narrow_simple(self) -> None:
"""Test narrow: (2, 5) -> (2, 3)"""
self.check_as_strided_via_views(torch.arange(10), (2, 3), (5, 1))
def test_permute_with_unsqueeze(self) -> None:
"""Test permute + unsqueeze: (2, 5) -> (1, 5, 2)"""
self.check_as_strided_via_views(torch.arange(10), (1, 5, 2), (10, 1, 5))
def test_multiple_unsqueeze_with_narrow(self) -> None:
"""Test multiple unsqueeze + narrow: (2, 5) -> (1, 1, 2, 3)"""
self.check_as_strided_via_views(torch.arange(10), (1, 1, 2, 3), (15, 10, 5, 1))
def test_permute_3d(self) -> None:
"""Test 3D permute: (2, 3, 4) -> (4, 2, 3)"""
self.check_as_strided_via_views(torch.arange(24), (4, 2, 3), (1, 12, 4))
def test_unsqueeze_multiple_size_one_dims(self) -> None:
"""Test multiple size-1 dims at various positions"""
self.check_as_strided_via_views(
torch.arange(6), (1, 2, 1, 3, 1), (6, 3, 3, 1, 1)
)
def test_as_strided_via_views_numel_zero(self) -> None:
"""Test numel==0 cases where all strides are insignificant."""
empty_tensor = torch.tensor([])
for size, stride in [
((0,), (1,)),
((0,), (999,)), # Arbitrary stride
((0, 5), (10, 1)),
((3, 0), (5, 1)),
((2, 0, 3), (0, 1, 0)), # Arbitrary strides
]:
self.check_as_strided_via_views(empty_tensor, size, stride)
def test_as_strided_via_views_numel_one(self) -> None:
"""Test numel==1 cases where all strides are insignificant."""
single_tensor = torch.tensor([42.0])
for size, stride in [
((1,), (1,)),
((1,), (999,)), # Arbitrary stride
((1, 1), (10, 1)),
((1, 1, 1), (100, 50, 25)), # All arbitrary strides
]:
self.check_as_strided_via_views(single_tensor, size, stride)
def test_numel_one_to_zero(self) -> None:
"""Test numel==0 output from numel==1 input."""
single_tensor = torch.tensor([42.0])
for size, stride in [
((0,), (1,)),
((0,), (999,)),
((0, 5), (10, 1)),
]:
self.check_as_strided_via_views(single_tensor, size, stride)
def test_scalar_to_zero_size(self) -> None:
"""Test 0D scalar → zero-size tensor (needs unsqueeze before narrow)."""
scalar_tensor = torch.tensor(42.0)
for size, stride in [
((0,), (1,)),
((0,), (999,)),
((0, 5), (10, 1)),
]:
self.check_as_strided_via_views(scalar_tensor, size, stride)
def test_as_strided_via_views_impossible_cases(self) -> None:
"""Test cases that should return NotImplemented."""
initial_tensor = torch.arange(10)
impossible_cases = [
((8, 3), (1, 1)), # Overlapping
((2, 2), (6, 3)), # Requires slice with step>1
]
for size, stride in impossible_cases:
result = as_strided_via_views(
initial_tensor, size, stride, storage_offset=0
)
self.assertIs(result, NotImplemented)
# Unit tests for internal helper functions
def test_squeeze_target_then_basic(self) -> None:
"""Test _squeeze_target_then with single size-1 dim."""
from torch.utils._as_strided import _squeeze_target_then
# Define a simple callback that just returns a tensor with the expected shape
def simple_cb(result, size, stride, storage_offset):
# Callback receives squeezed target: should be (2, 5)
self.assertEqual(size, (2, 5))
self.assertEqual(stride, (5, 1))
# Return a view with this shape
return result.view(2, 5)
base = torch.arange(10)
# Target has size-1 dim at position 0: (1, 2, 5)
result = _squeeze_target_then(
base, size=(1, 2, 5), stride=(10, 5, 1), storage_offset=0, cb=simple_cb
)
self.assertEqual(result.size(), (1, 2, 5))
self.assertEqual(result.stride(), (10, 5, 1))
def test_squeeze_target_then_multiple_size_one(self) -> None:
"""Test _squeeze_target_then with multiple size-1 dims."""
from torch.utils._as_strided import _squeeze_target_then
def simple_cb(result, size, stride, storage_offset):
# Callback receives squeezed target: should be (5,)
self.assertEqual(size, (5,))
self.assertEqual(stride, (1,))
return result.narrow(0, 0, 5)
base = torch.arange(10)
# Target has size-1 dims at positions 0, 2, 4: (1, 5, 1)
result = _squeeze_target_then(
base, size=(1, 5, 1), stride=(10, 1, 1), storage_offset=0, cb=simple_cb
)
self.assertEqual(result.size(), (1, 5, 1))
# Size-1 strides are insignificant, so we don't check them strictly
def test_squeeze_target_then_no_size_one(self) -> None:
"""Test _squeeze_target_then with no size-1 dims."""
from torch.utils._as_strided import _squeeze_target_then
def simple_cb(result, size, stride, storage_offset):
# Callback receives same target: (2, 5)
self.assertEqual(size, (2, 5))
self.assertEqual(stride, (5, 1))
return result.view(2, 5)
base = torch.arange(10)
result = _squeeze_target_then(
base, size=(2, 5), stride=(5, 1), storage_offset=0, cb=simple_cb
)
self.assertEqual(result.size(), (2, 5))
self.assertEqual(result.stride(), (5, 1))
def test_squeeze_target_then_returns_not_implemented(self) -> None:
"""Test _squeeze_target_then when callback returns NotImplemented."""
from torch.utils._as_strided import _squeeze_target_then
def failing_cb(result, size, stride, storage_offset):
return NotImplemented
base = torch.arange(10)
result = _squeeze_target_then(
base, size=(1, 5, 2), stride=(10, 1, 5), storage_offset=0, cb=failing_cb
)
self.assertIs(result, NotImplemented)
def test_permute_target_then_basic(self) -> None:
"""Test _permute_target_then with simple permutation."""
from torch.utils._as_strided import _permute_target_then
def simple_cb(result, size, stride, storage_offset):
# Callback receives sorted target: (2, 5) with strides (5, 1)
self.assertEqual(tuple(size), (2, 5))
self.assertEqual(tuple(stride), (5, 1))
return result.view(2, 5)
base = torch.arange(10)
# Target is (5, 2) with strides (1, 5) - needs permutation
result = _permute_target_then(
base, size=(5, 2), stride=(1, 5), storage_offset=0, cb=simple_cb
)
self.assertEqual(result.size(), (5, 2))
self.assertEqual(result.stride(), (1, 5))
def test_permute_target_then_already_sorted(self) -> None:
"""Test _permute_target_then when target is already sorted."""
from torch.utils._as_strided import _permute_target_then
def simple_cb(result, size, stride, storage_offset):
# Callback receives same target: (2, 5) with strides (5, 1)
self.assertEqual(tuple(size), (2, 5))
self.assertEqual(tuple(stride), (5, 1))
return result.view(2, 5)
base = torch.arange(10)
# Target is already sorted
result = _permute_target_then(
base, size=(2, 5), stride=(5, 1), storage_offset=0, cb=simple_cb
)
self.assertEqual(result.size(), (2, 5))
self.assertEqual(result.stride(), (5, 1))
def test_permute_target_then_3d(self) -> None:
"""Test _permute_target_then with 3D tensor."""
from torch.utils._as_strided import _permute_target_then
def simple_cb(result, size, stride, storage_offset):
# Callback receives sorted target: (2, 3, 4) with strides (12, 4, 1)
self.assertEqual(tuple(size), (2, 3, 4))
self.assertEqual(tuple(stride), (12, 4, 1))
return result.view(2, 3, 4)
base = torch.arange(24)
# Target is (4, 2, 3) with strides (1, 12, 4) - needs permutation
result = _permute_target_then(
base, size=(4, 2, 3), stride=(1, 12, 4), storage_offset=0, cb=simple_cb
)
self.assertEqual(result.size(), (4, 2, 3))
self.assertEqual(result.stride(), (1, 12, 4))
def test_permute_target_then_returns_not_implemented(self) -> None:
"""Test _permute_target_then when callback returns NotImplemented."""
from torch.utils._as_strided import _permute_target_then
def failing_cb(result, size, stride, storage_offset):
return NotImplemented
base = torch.arange(10)
result = _permute_target_then(
base, size=(5, 2), stride=(1, 5), storage_offset=0, cb=failing_cb
)
self.assertIs(result, NotImplemented)
# Tests for stride-0 dimensions (expand)
def test_expand_simple(self) -> None:
"""Test simple expand: (5,) -> (3, 5) with stride (0, 1)"""
self.check_as_strided_via_views(torch.arange(5), (3, 5), (0, 1))
def test_expand_2d_source(self) -> None:
"""Test expand from 2D: (2, 5) -> (3, 2, 5)"""
self.check_as_strided_via_views(
torch.arange(10).view(2, 5), (3, 2, 5), (0, 5, 1)
)
def test_expand_middle_dim(self) -> None:
"""Test expand in middle: (2, 5) -> (2, 3, 5)"""
self.check_as_strided_via_views(
torch.arange(10).view(2, 5), (2, 3, 5), (5, 0, 1)
)
def test_expand_last_dim(self) -> None:
"""Test expand at end: (2, 5) -> (2, 5, 3)"""
self.check_as_strided_via_views(
torch.arange(10).view(2, 5), (2, 5, 3), (5, 1, 0)
)
def test_expand_multiple_dims(self) -> None:
"""Test expand on multiple dimensions"""
self.check_as_strided_via_views(torch.arange(5), (3, 5, 4), (0, 1, 0))
def test_expand_all_dims(self) -> None:
"""Test expand on all dimensions from single element"""
self.check_as_strided_via_views(torch.tensor([42.0]), (3, 4, 5), (0, 0, 0))
def test_expand_with_size_one(self) -> None:
"""Test expand combined with size-1 dims"""
self.check_as_strided_via_views(torch.arange(5), (1, 3, 5), (15, 0, 1))
def test_source_with_expand(self) -> None:
"""Test source tensor with stride-0 dimension"""
base = torch.arange(5).expand(3, 5)
self.check_as_strided_via_views(base, (3, 3), (0, 1))
def test_source_and_target_with_expand(self) -> None:
"""Test both source and target with stride-0 dimensions"""
base = torch.arange(5).expand(3, 5)
self.check_as_strided_via_views(base, (3, 2, 5), (0, 0, 1))
def test_eliminate_last_dim(self) -> None:
"""Test eliminating the last dimension (edge case: last target dim needs unflatten+squeeze)."""
# This is the pattern from indexing: tensor[:, :, 0]
# Source: (1, 888, 12) -> Target: (1, 888)
base = torch.arange(1 * 888 * 12).view(1, 888, 12)
self.check_as_strided_via_views(base, (1, 888), (10656, 12))
# Another similar case: (10, 20, 30) -> (10, 20)
base2 = torch.arange(10 * 20 * 30).view(10, 20, 30)
self.check_as_strided_via_views(base2, (10, 20), (600, 30))
# And with offset
self.check_as_strided_via_views(base2, (5, 10), (600, 30), storage_offset=1800)
def test_unexpand_target_then_basic(self) -> None:
"""Test _unexpand_target_then with single stride-0 dim."""
from torch.utils._as_strided import _unexpand_target_then
def simple_cb(result, size, stride, storage_offset):
self.assertEqual(size, (5,))
self.assertEqual(stride, (1,))
return result.narrow(0, 0, 5)
base = torch.arange(10)
result = _unexpand_target_then(
base, size=(3, 5), stride=(0, 1), storage_offset=0, cb=simple_cb
)
self.assertEqual(result.size(), (3, 5))
self.assertEqual(result.stride(), (0, 1))
def test_unexpand_target_then_multiple_stride_zero(self) -> None:
"""Test _unexpand_target_then with multiple stride-0 dims."""
from torch.utils._as_strided import _unexpand_target_then
def simple_cb(result, size, stride, storage_offset):
self.assertEqual(size, (5,))
self.assertEqual(stride, (1,))
return result.narrow(0, 0, 5)
base = torch.arange(10)
result = _unexpand_target_then(
base, size=(3, 5, 4), stride=(0, 1, 0), storage_offset=0, cb=simple_cb
)
self.assertEqual(result.size(), (3, 5, 4))
self.assertEqual(result.stride(), (0, 1, 0))
def test_unexpand_target_then_no_stride_zero(self) -> None:
"""Test _unexpand_target_then with no stride-0 dims."""
from torch.utils._as_strided import _unexpand_target_then
def simple_cb(result, size, stride, storage_offset):
self.assertEqual(size, (2, 5))
self.assertEqual(stride, (5, 1))
return result.view(2, 5)
base = torch.arange(10)
result = _unexpand_target_then(
base, size=(2, 5), stride=(5, 1), storage_offset=0, cb=simple_cb
)
self.assertEqual(result.size(), (2, 5))
self.assertEqual(result.stride(), (5, 1))
def test_unexpand_target_then_returns_not_implemented(self) -> None:
"""Test _unexpand_target_then when callback returns NotImplemented."""
from torch.utils._as_strided import _unexpand_target_then
def failing_cb(result, size, stride, storage_offset):
return NotImplemented
base = torch.arange(10)
result = _unexpand_target_then(
base, size=(3, 5), stride=(0, 1), storage_offset=0, cb=failing_cb
)
self.assertIs(result, NotImplemented)
# Storage offset tests
def test_storage_offset_simple_1d(self) -> None:
"""Test simple 1D tensor with offset."""
self.check_as_strided_via_views(torch.arange(10), (5,), (1,), storage_offset=2)
def test_storage_offset_2d_simple(self) -> None:
"""Test 2D tensor with offset."""
self.check_as_strided_via_views(
torch.arange(20), (2, 3), (5, 1), storage_offset=7
)
def test_storage_offset_with_permute(self) -> None:
"""Test offset with permuted dimensions."""
self.check_as_strided_via_views(
torch.arange(20), (3, 2), (1, 5), storage_offset=4
)
def test_storage_offset_with_unsqueeze(self) -> None:
"""Test offset with size-1 dimensions."""
self.check_as_strided_via_views(
torch.arange(10), (1, 5, 1), (10, 1, 1), storage_offset=3
)
def test_storage_offset_max_offset(self) -> None:
"""Test with maximum possible offset."""
# For (2, 3) with stride (5, 1), max_extent = max(2*5, 3*1) = 10
# max_offset = 20 - 10 = 10
self.check_as_strided_via_views(
torch.arange(20), (2, 3), (5, 1), storage_offset=10
)
def test_storage_offset_with_expand(self) -> None:
"""Test offset with stride-0 dimension."""
self.check_as_strided_via_views(
torch.arange(10), (3, 5), (0, 1), storage_offset=2
)
def test_storage_offset_impossible(self) -> None:
"""Test that impossible offset returns NotImplemented."""
base = torch.arange(10)
# Offset too large: would access indices beyond the base tensor
result = as_strided_via_views(base, (5,), (1,), storage_offset=10)
self.assertIs(result, NotImplemented)
def test_storage_offset_negative(self) -> None:
"""Test that negative offset returns NotImplemented."""
base = torch.arange(10)
result = as_strided_via_views(base, (5,), (1,), storage_offset=-1)
self.assertIs(result, NotImplemented)
def test_storage_offset_within_flattened_dimension(self) -> None:
"""Test offset adjustment within a flattened dimension.
This tests the case where:
1. Input tensor gets flattened during canonicalization
2. Target has larger stride that's a multiple of canonical stride
3. Small offset needs to be consumed by unflattening and narrowing
Example: (1, 888, 12) -> (1, 888) with offset 1
This is equivalent to tensor[:, :, 1].
"""
# Create tensor with shape (1, 888, 12)
base = torch.arange(1 * 888 * 12).view(1, 888, 12)
# Target: select element at index 1 along last dimension, remove that dimension
# Result should be equivalent to base[:, :, 1]
self.check_as_strided_via_views(base, (1, 888), (10656, 12), storage_offset=1)
# Tests with multi-dimensional contiguous input tensors
def test_2d_contiguous_source_simple(self) -> None:
"""Test with 2D contiguous input tensor."""
base = torch.arange(24).view(4, 6) # Contiguous 2D tensor
# Simple permute
self.check_as_strided_via_views(base, (6, 4), (1, 6))
def test_2d_contiguous_source_narrow(self) -> None:
"""Test narrowing a 2D contiguous input tensor."""
base = torch.arange(24).view(4, 6)
# Narrow both dimensions
self.check_as_strided_via_views(base, (2, 3), (6, 1))
def test_2d_contiguous_source_unflatten(self) -> None:
"""Test unflattening a 2D contiguous input tensor."""
base = torch.arange(24).view(4, 6)
# Unflatten the second dimension
self.check_as_strided_via_views(base, (4, 2, 3), (6, 3, 1))
def test_2d_contiguous_source_flatten(self) -> None:
"""Test flattening a 2D contiguous input back to 1D."""
base = torch.arange(24).view(4, 6)
self.check_as_strided_via_views(base, (24,), (1,))
def test_3d_contiguous_source_permute(self) -> None:
"""Test with 3D contiguous input tensor and permute."""
base = torch.arange(60).view(3, 4, 5) # Contiguous 3D tensor
# Permute dimensions
self.check_as_strided_via_views(base, (5, 3, 4), (1, 20, 5))
def test_3d_contiguous_source_narrow_all(self) -> None:
"""Test narrowing all dimensions of 3D contiguous input."""
base = torch.arange(60).view(3, 4, 5)
# Narrow all three dimensions
self.check_as_strided_via_views(base, (2, 2, 3), (20, 5, 1))
# Tests with non-contiguous multi-dimensional input tensors
def test_2d_transposed_source(self) -> None:
"""Test with transposed (non-contiguous) 2D input tensor."""
base = (
torch.arange(24).view(4, 6).t()
) # Non-contiguous: (6, 4) with strides (1, 6)
# Simple narrow
self.check_as_strided_via_views(base, (3, 2), (1, 6))
def test_2d_transposed_source_permute_back(self) -> None:
"""Test permuting transposed tensor back to original layout."""
base = torch.arange(24).view(4, 6).t() # (6, 4) with strides (1, 6)
# Permute back to (4, 6) with strides (6, 1)
self.check_as_strided_via_views(base, (4, 6), (6, 1))
def test_2d_non_contiguous_with_gap(self) -> None:
"""Test with 2D tensor where dims have a gap (non-overlapping but not contiguous)."""
# Create tensor with stride gap: select every other row
base = torch.arange(48).view(8, 6)[::2, :] # (4, 6) with strides (12, 1)
# Narrow and reshape
self.check_as_strided_via_views(base, (2, 3), (12, 1))
def test_2d_sliced_both_dims(self) -> None:
"""Test with 2D tensor sliced on both dimensions."""
# Note: slicing would create a storage offset in the base, which the
# implementation may not handle, so we use a contiguous base tensor instead
base = torch.arange(20).view(5, 4) # (5, 4) with strides (4, 1), offset 0
# Test narrow
self.check_as_strided_via_views(base, (3, 2), (4, 1))
def test_3d_non_contiguous_permuted(self) -> None:
"""Test with 3D tensor that's been permuted (non-contiguous)."""
base = (
torch.arange(60).view(3, 4, 5).permute(2, 0, 1)
) # (5, 3, 4) with strides (1, 20, 5)
# Narrow and unflatten
self.check_as_strided_via_views(base, (3, 2, 2), (1, 20, 5))
def test_3d_select_creates_2d(self) -> None:
"""Test with 2D tensor with non-contiguous dimensions (gap between strides)."""
# Create a 2D tensor where dims are non-contiguous but non-overlapping
# We'll create this by using as_strided to avoid source offset issues
base = torch.as_strided(torch.arange(100), (3, 4), (20, 5))
# Both dims are non-contiguous with each other (gap of 15 between positions)
self.check_as_strided_via_views(base, (2, 2), (20, 5))
def test_2d_double_strided(self) -> None:
"""Test with 2D tensor where both dimensions have stride gaps."""
# Create tensor with gaps in both dimensions
base = torch.arange(100).view(10, 10)[::2, ::3] # (5, 4) with strides (20, 3)
# The two dimensions are discontiguous with each other
self.check_as_strided_via_views(base, (3, 2), (20, 3))
# Storage offset tests with multi-dimensional discontiguous sources
def test_storage_offset_2d_contiguous_source(self) -> None:
"""Test storage offset with 2D contiguous source."""
base = torch.arange(30).view(5, 6) # Contiguous (5, 6) with strides (6, 1)
# Apply offset that spans into second row
self.check_as_strided_via_views(base, (3, 4), (6, 1), storage_offset=8)
def test_storage_offset_2d_transposed_source(self) -> None:
"""Test storage offset with 2D transposed (non-contiguous) source."""
base = torch.arange(30).view(5, 6).t() # (6, 5) with strides (1, 6)
# Apply offset
self.check_as_strided_via_views(base, (4, 3), (1, 6), storage_offset=7)
def test_storage_offset_2d_discontiguous_both_dims(self) -> None:
"""Test storage offset where narrow needed on both discontiguous source dims."""
# Create tensor where both dimensions are discontiguous
base = torch.arange(100).view(10, 10)[::2, ::3] # (5, 4) with strides (20, 3)
# Offset = 23: This requires narrow on both dims to consume the offset
# 23 = 1*20 + 1*3, so we need to narrow first dim by 1 and second dim by 1
self.check_as_strided_via_views(base, (3, 2), (20, 3), storage_offset=23)
def test_storage_offset_2d_discontiguous_larger_offset(self) -> None:
"""Test larger storage offset requiring narrows on multiple discontiguous dims."""
base = torch.arange(100).view(10, 10)[::2, ::3] # (5, 4) with strides (20, 3)
# Offset = 46: 46 = 2*20 + 2*3, requires narrow on both dims
self.check_as_strided_via_views(base, (2, 2), (20, 3), storage_offset=46)
def test_storage_offset_3d_discontiguous_all_dims(self) -> None:
"""Test storage offset with 3D discontiguous tensor requiring narrows on multiple dims."""
# Create 3D tensor with gaps in all dimensions
base = torch.arange(1000).view(10, 10, 10)[
::2, ::3, ::5
] # (5, 4, 2) with strides (200, 30, 5)
# Offset = 30: just one step in the middle dimension
# Requires narrow on discontiguous dimensions
self.check_as_strided_via_views(
base, (4, 2, 2), (200, 30, 5), storage_offset=30
)
def test_storage_offset_3d_discontiguous_complex(self) -> None:
"""Test complex offset distribution across 3D discontiguous dims."""
base = torch.arange(1000).view(10, 10, 10)[
::3, ::2, ::4
] # (4, 5, 3) with strides (300, 20, 4)
# Offset = 24: 24 = 1*20 + 1*4
# Requires narrow on the smaller stride dimensions
self.check_as_strided_via_views(
base, (2, 3, 2), (300, 20, 4), storage_offset=24
)
def test_storage_offset_2d_permuted_discontiguous(self) -> None:
"""Test storage offset where source dims are in non-descending stride order."""
# Create non-contiguous tensor then permute so strides aren't descending
base_raw = torch.arange(100).view(10, 10)[
::3, ::2
] # (4, 5) with strides (30, 2)
# This already has strides in descending order, so let's transpose it
base = base_raw.t() # (5, 4) with strides (2, 30)
# Now strides are NOT in descending order (2 < 30)
# Offset = 32: 32 = 1*30 + 1*2
self.check_as_strided_via_views(base, (2, 2), (2, 30), storage_offset=32)
def test_storage_offset_barely_fits(self) -> None:
"""Test storage offset with discontiguous source requiring narrows on both dims."""
base = torch.arange(100).view(10, 10)[::3, ::4] # (4, 3) with strides (30, 4)
# Offset = 34 = 1*30 + 1*4, requires narrow on both dimensions
# For target (2, 2), extent = 34 + (2-1)*30 + (2-1)*4 = 34 + 34 = 68 < 98 (max base index)
self.check_as_strided_via_views(base, (2, 2), (30, 4), storage_offset=34)
if __name__ == "__main__":
run_tests()

View File

@@ -405,6 +405,7 @@ isolate_fails_code_str = None
# pyrefly: ignore [missing-attribute]
kernel._fn_name
if isinstance(kernel, JITFunction)
# pyrefly: ignore # missing-attribute
else kernel.fn._fn_name
)
fn_name = fn_name.split(".")[-1]

View File

@@ -264,6 +264,7 @@ def generate_ttir(
assert isinstance(kernel, JITFunction)
# pyrefly: ignore # missing-attribute
context = triton._C.libtriton.ir.context()
target = triton.runtime.driver.active.get_current_target()
backend = triton.compiler.compiler.make_backend(target)
@@ -305,6 +306,7 @@ def generate_ttir(
base_tensor = torch.empty(
[elements_per_dim] * len(block_shape), dtype=a.dtype
)
# pyrefly: ignore # bad-argument-type
ordered_args[name] = TensorDescriptor.from_tensor(base_tensor, block_shape)
elif isinstance(a, (FakeTensor, torch._inductor.ir.TensorBox)):
with torch._C._DisableTorchDispatch():
@@ -368,6 +370,7 @@ def generate_ttir(
target = triton.runtime.driver.active.get_current_target()
backend_ = triton.compiler.compiler.make_backend(target)
# pyrefly: ignore # missing-attribute
return backend_.get_attrs_descriptor(args, kernel.params)
else:
assert (
@@ -384,6 +387,7 @@ def generate_ttir(
except TypeError: # Unknown arg `specialize_extra`
# Older versions of Triton take specialize_extra as an arg to specialize_impl
specialize_impl = functools.partial(
# pyrefly: ignore # missing-argument
triton.runtime.jit.create_specialize_impl(),
specialize_extra=backend.get_arg_specialization,
)
@@ -468,6 +472,7 @@ def generate_ttir(
if i not in constexprs
}
# pyrefly: ignore # missing-attribute
triton._C.libtriton.ir.load_dialects(context)
backend.load_dialects(context)
@@ -477,22 +482,29 @@ def generate_ttir(
# backward compatibility here.
make_ir_sig_params = len(inspect.signature(src.make_ir).parameters)
get_codegen_implementation_sig_params = len(
# pyrefly: ignore # missing-attribute
inspect.signature(backend.get_codegen_implementation).parameters
)
if make_ir_sig_params == 2:
# pyrefly: ignore # missing-argument
ttir_module = src.make_ir(options, context)
elif make_ir_sig_params == 3:
# pyrefly: ignore # missing-attribute
codegen_fns = backend.get_codegen_implementation()
# pyrefly: ignore # missing-argument
ttir_module = src.make_ir(options, codegen_fns, context)
elif make_ir_sig_params == 4:
codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []
# pyrefly: ignore # missing-attribute
codegen_fns = backend.get_codegen_implementation(*codegen_args)
module_map = backend.get_module_map()
ttir_module = src.make_ir(options, codegen_fns, module_map, context)
else:
codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []
# pyrefly: ignore # missing-attribute
codegen_fns = backend.get_codegen_implementation(*codegen_args)
module_map = backend.get_module_map()
# pyrefly: ignore # bad-argument-count
ttir_module = src.make_ir(target, options, codegen_fns, module_map, context)
if not ttir_module.verify():
raise RuntimeError("Verification for TTIR module has failed")
@@ -1102,6 +1114,7 @@ def triton_kernel_wrapper_mutation_dense(
from triton.tools.tensor_descriptor import TensorDescriptor
block_shape = stable_meta[0]
# pyrefly: ignore # bad-argument-type
kwargs[k] = TensorDescriptor.from_tensor(tensor, block_shape)
# move as many positional arguments from dicts to args as we
@@ -1658,6 +1671,7 @@ class TritonHOPifier:
"Passing multiple @triton.autotune decorators is not supported. "
"Please use a single @triton.autotune decorator instead."
)
# pyrefly: ignore # missing-attribute
iter_kernel = iter_kernel.fn
# Process the @triton.heuristics decorator:
@@ -1868,6 +1882,7 @@ class TritonHOPifier:
# Both for grid's meta as well as for the kernel, we need combined
# args and kwargs combined and normalized
# pyrefly: ignore # missing-attribute
combined_args_raw = {**dict(zip(variable.kernel.arg_names, args)), **kwargs}
# precompute the grid for the kernel
@@ -2061,6 +2076,7 @@ class TraceableTritonKernelWrapper:
kernel_idx: Optional[int],
grid: Optional["TritonGridType"],
) -> None:
# pyrefly: ignore # bad-assignment
self.kernel = None
self.grid = None
tracing_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid)

View File

@@ -4,14 +4,17 @@ import itertools
import logging
from typing import Any, Optional
from torch._C import DispatchKey
import torch
import torch.utils._pytree as pytree
from torch._higher_order_ops.utils import reenter_make_fx
from torch._higher_order_ops.utils import reenter_make_fx, redirect_to_mode
from torch._logging import warning_once
from torch._ops import HigherOrderOperator
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.types import _dtype
from torch.utils._debug_mode import DebugMode
from torch.utils.checkpoint import _CachedTorchDispatchMode, _CachingTorchDispatchMode
log = logging.getLogger(__name__)
@@ -41,6 +44,27 @@ class Wrap(HigherOrderOperator):
wrap = Wrap()
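# HOP that wraps a call into an Inductor-compiled callable. Redirecting it to
# the dispatch modes below lets DebugMode and the SAC caching modes observe
# compiled-region calls instead of having them bypass dispatch entirely.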
class InductorCompiledCode(HigherOrderOperator):
def __init__(self) -> None:
super().__init__("inductor_compiled_code")
def __call__(self, func, *args, **kwargs):
return super().__call__(func, *args, **kwargs)
inductor_compiled_code = InductorCompiledCode()
inductor_compiled_code.fallthrough(DispatchKey.AutogradCPU)
inductor_compiled_code.fallthrough(DispatchKey.AutogradCUDA)
@inductor_compiled_code.py_impl(DispatchKey.CompositeExplicitAutograd)
def inductor_compiled_code_impl(func, inputs):
return func(inputs)
redirect_to_mode(inductor_compiled_code, DebugMode)
redirect_to_mode(inductor_compiled_code, _CachingTorchDispatchMode)
redirect_to_mode(inductor_compiled_code, _CachedTorchDispatchMode)
class WrapWithSetGradEnabled(HigherOrderOperator):
def __init__(self) -> None:
super().__init__("wrap_with_set_grad_enabled")

View File

@@ -289,11 +289,15 @@ def user_defined_triton_kernel_transitive_closure_source_code(kernel) -> str:
if isinstance(symbol, JITFunction):
compile_wrapper.newline()
compile_wrapper.writeline("@triton.jit")
# pyrefly: ignore # missing-attribute
compile_wrapper.splice(symbol.src, strip=True)
symbols_included.add(symbol_name)
traverse(symbol)
elif hasattr(triton, "constexpr_function") and isinstance(
symbol, triton.runtime.jit.ConstexprFunction
# pyrefly: ignore # missing-attribute
symbol,
# pyrefly: ignore # missing-attribute
triton.runtime.jit.ConstexprFunction,
):
compile_wrapper.newline()
compile_wrapper.writeline("@triton.constexpr_function")

View File

@@ -949,7 +949,9 @@ class FxConverter:
from triton.runtime import driver
log.info("Autotuning Triton kernel %s at compile time.", kernel_name)
# pyrefly: ignore # missing-attribute
device = driver.active.get_current_device()
# pyrefly: ignore # missing-attribute
stream = driver.active.get_current_stream(device)
def node_to_tuning_arg(arg: Any) -> Any:

View File

@@ -6983,6 +6983,7 @@ class UserDefinedTritonKernel(ExternKernel):
configs = kernel.configs
kernel = kernel.fn
# pyrefly: ignore # bad-return
return kernel, configs, restore_value_args, reset_to_zero_args
@override
@@ -7138,7 +7139,10 @@ class UserDefinedTritonKernel(ExternKernel):
self.mutable_args = [
kernel_args[key]
for key in identify_mutated_tensors(
kernel, {**kernel_args, **autotuned_kwargs}, tma_descriptor_metadata
# pyrefly: ignore # bad-argument-type
kernel,
{**kernel_args, **autotuned_kwargs},
tma_descriptor_metadata,
)
]

View File

@@ -52,6 +52,8 @@ from torch._inductor.utils import (
)
from torch.autograd.profiler import record_function
from torch.utils._ordered_set import OrderedSet
from torch.utils._python_dispatch import is_in_torch_dispatch_mode
from torch._higher_order_ops.wrap import inductor_compiled_code
from . import config
from .runtime.autotune_cache import AutotuneCacheBundler
@@ -616,8 +618,15 @@ class CompiledFxGraph(OutputCode):
with record_function(
f"## Call CompiledFxGraph {self._fx_graph_cache_key} ##"
):
if is_in_torch_dispatch_mode():
return inductor_compiled_code(self.current_callable, inputs)
return self.current_callable(inputs)
else:
# TODO: optimize this some more
# NB: Tensor dispatch is NOT supported
# TODO: deal with boxed calling convention
if is_in_torch_dispatch_mode():
return inductor_compiled_code(self.current_callable, inputs)
return self.current_callable(inputs)
finally:
get_runtime_metrics_context().finish()

View File

@@ -2555,11 +2555,14 @@ def get_device_tflops(dtype: torch.dtype) -> float:
return get_max_simd_tflops(torch.float32, sm_clock)
else:
if dtype in (torch.float16, torch.bfloat16) and SM80OrLater:
# pyrefly: ignore # missing-argument
return get_max_tensorcore_tflops(dtype)
if torch.backends.cuda.matmul.allow_tf32:
# pyrefly: ignore # missing-argument
return get_max_tensorcore_tflops(torch.float32)
else:
# pyrefly: ignore # missing-argument
return get_max_simd_tflops(torch.float32)
@@ -2573,6 +2576,7 @@ def get_gpu_dram_gbps() -> int:
def get_gpu_shared_memory() -> int:
from triton.runtime import driver
# pyrefly: ignore # missing-attribute
return driver.active.utils.get_device_properties(0).get("max_shared_mem", 0)

View File

@@ -967,7 +967,7 @@ static PyObject* THPVariable_dtensor_new(
Tensor tensor = make_tensor_for_subclass_helper(
/*sym_sizes=*/tuple_to_symintlist(sizes.ptr()),
/*sym_strides=*/tuple_to_symintlist(stride.ptr()),
/*sym_storage_offset=*/std::nullopt,
/*sym_storage_offset=*/local_tensor.sym_storage_offset(),
options,
/*storage_size=*/std::nullopt,
extra_dispatch_keys);

View File

@@ -9,6 +9,7 @@ import torch
import torch.distributed as dist
import torch.distributed.tensor._api as dtensor
import torch.distributed.tensor._random as random
from torch._library.utils import fill_defaults
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpInfo, OpSchema, OutputSpecType
@@ -34,6 +35,37 @@ aten = torch.ops.aten
logger = logging.getLogger(__name__)
def as_strided_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: dict[str, object],
):
args, kwargs = fill_defaults(op_call._schema, args, kwargs)
assert not kwargs
tensor, size, stride, storage_offset = args
size = tuple(size)
stride = tuple(stride)
if storage_offset is None:
storage_offset = tensor.storage_offset()
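# Fast path: identical size/stride/offset is a metadata no-op, so return an
# alias instead of reconstructing the view as a chain of view operations.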
if (
tensor.size() == size
and tensor.stride() == stride
and tensor.storage_offset() == storage_offset
):
return torch.ops.aten.alias.default(tensor)
from torch.utils._as_strided import as_strided_via_views
r = as_strided_via_views(tensor, size, stride, storage_offset)
if r is NotImplemented:
raise RuntimeError(
"as_strided not supported with DTensor for these inputs:"
f"{tuple(tensor.size())}:{tensor.stride()} offset {tensor.storage_offset()} to "
f"{size}:{stride} offset {storage_offset}"
)
return r
def is_same_size_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
@@ -121,6 +153,7 @@ class OpDispatcher:
aten.convolution.default: convolution_handler,
aten.convolution_backward.default: convolution_backward_handler,
aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler,
aten.as_strided.default: as_strided_handler,
}
# This flag is used internally to control whether we treat the torch.Tensor(non-DTensor)

View File

@@ -84,6 +84,7 @@ register_op_strategy(
aten.clone.default,
aten.contiguous.default,
aten.detach.default,
aten.alias.default,
aten.fill_.Scalar,
aten.view.dtype,
aten.zero_.default,

View File

@@ -1302,17 +1302,28 @@ def bsr_dense_addmm(
# pyrefly: ignore [unsupported-operation]
_bsr_strided_addmm_kernel[grid](
*ptr_stride_extractor(*sliced_tensors),
# pyrefly: ignore # bad-argument-count
beta,
alpha,
# pyrefly: ignore # bad-keyword-argument, bad-argument-type
beta_is_one=beta == 1,
# pyrefly: ignore # bad-keyword-argument, bad-argument-type
beta_is_nonzero=beta != 0,
# pyrefly: ignore # bad-keyword-argument, bad-argument-type
alpha_is_one=alpha == 1,
# pyrefly: ignore # bad-keyword-argument, bad-argument-type
left_alpha_is_one=left_alpha_is_one,
# pyrefly: ignore # bad-keyword-argument, bad-argument-type
right_alpha_is_one=right_alpha_is_one,
# pyrefly: ignore # bad-keyword-argument, bad-argument-type
BLOCKSIZE_ROW=BM,
# pyrefly: ignore # bad-keyword-argument, bad-argument-type
BLOCKSIZE_INNER=BK,
# pyrefly: ignore # bad-keyword-argument
BLOCKSIZE_COL=BN,
# pyrefly: ignore # bad-keyword-argument
allow_tf32=dot_out_dtype == tl.float32,
# pyrefly: ignore # bad-keyword-argument, bad-argument-type
acc_dtype=dot_out_dtype,
**meta,
)
@@ -1633,12 +1644,17 @@ if has_triton():
beta,
is_beta_zero,
*blocksize,
# pyrefly: ignore # bad-argument-count
k,
tile_k,
*ptr_stride_extractor(*sliced_tensors),
# pyrefly: ignore # bad-keyword-argument, bad-argument-type
acc_dtype=acc_dtype,
# pyrefly: ignore # bad-keyword-argument, bad-argument-type
allow_tf32=allow_tf32,
# pyrefly: ignore # unexpected-keyword
num_stages=1,
# pyrefly: ignore # unexpected-keyword
num_warps=4,
)
@@ -1923,6 +1939,7 @@ if has_triton():
def kernel(grid, *sliced_tensors):
_bsr_softmax_kernel[grid](
*ptr_stride_extractor(*sliced_tensors),
# pyrefly: ignore # bad-argument-count
row_block,
col_block,
max_row_nnz,
@@ -2096,8 +2113,11 @@ if has_triton():
if "allow_tf32" not in meta:
meta.update(allow_tf32=dot_out_dtype == tl.float32)
_scatter_mm2_kernel[grid](
# pyrefly: ignore # bad-argument-type
M,
# pyrefly: ignore # bad-argument-type
K,
# pyrefly: ignore # bad-argument-type
N,
blocks,
blocks.stride(0),
@@ -2116,7 +2136,9 @@ if has_triton():
pq_indices,
pq_indices.stride(0),
pq_indices.stride(1),
# pyrefly: ignore # bad-argument-type
dot_out_dtype=dot_out_dtype,
# pyrefly: ignore # bad-argument-type
**meta,
)
@@ -2299,6 +2321,7 @@ if has_triton():
_scatter_mm6_kernel[grid](
B,
Ms,
# pyrefly: ignore # bad-argument-type
Ks,
N,
blocks,
@@ -2317,6 +2340,7 @@ if has_triton():
r_offsets,
p_offsets,
q_offsets,
# pyrefly: ignore # bad-argument-type
dot_out_dtype=dot_out_dtype,
**meta,
)

torch/utils/_as_strided.py (new file, 529 lines)
View File

@@ -0,0 +1,529 @@
"""Utility for reconstructing a sequence of view operations that is equivalent
to an as_strided call (change a source tensor's size/stride/offset to a
target size/stride/offset).
Given an arbitrary source tensor, imagine that a user has applied a
sequence of view operations (slice, permute and view) producing a target
tensor. With only the source and the target tensor, construct a sequence
of views that takes the source tensor to the target tensor.
This problem is trivial with as_strided, but it is often helpful to be able
to construct a sequence of "normal" view operations instead of brute-forcing
it with a single as_strided call. The reason is that as_strided is extremely
flexible, making it
difficult for backends to implement: for example, using it you can generate
views with strange overlap patterns (e.g., rolling windows) or generate an
output view which is out of bounds for the original view. If you are
implementing a tensor subclass, it may be possible to implement simple views
but not as_strided (e.g., DTensor sharding propagation). These utilities help
you reconstruct view operations in some situations where it is possible.
We apply some simplifying assumptions for this algorithm:
1. The stride of a size-1 dimension doesn't matter. Actually, sometimes it
does (e.g., in the case of inferring memory layout), but our longer term
plan for addressing this is to add a special unsqueeze operation which will
let us specify exactly what stride an unsqueezed dimension should get.
2. We assume the input tensor is non-overlapping (except for stride-0 dimensions)
i.e., that for every physical location there is a unique coordinate that
maps to it. (This implies the destination tensor is non-overlapping except
for stride-0 dimensions, as we do not include unfold in the list of valid
view operations.)
3. We only account for these view operations:
- view; or equivalently, these four operations:
- squeeze
- unsqueeze
- unflatten
- flatten (output view only)
- permute (this subsumes transpose, movedim, etc.)
- narrow (slice with step=1 only; no support for nontrivial step)
- expand
# Definition of view operations
To start, it's helpful to review how these view operations affect the size/stride
of a tensor. In general, these operations (aside from permute and flatten) only affect a
single source dimension, so we will only discuss the action on a single dimension.
We will use (size):(stride) (CuTe notation) to compactly denote sizes and strides.
It can be helpful to study the rules when a = 1, as this variable is unconstrained
in all the formulas.
The important operations are these three:
- unflatten_2d(dim, x, y) # 2D case is sufficient to implement ND case
(x * y,):(a,) --> (x, y):(y * a, a)
- narrow(dim, start, length)
(x,):(a,) --> (length,):(a,)
where storage_offset += start * a
NB: slice(dim, start, stop, step) with nontrivial step is not currently
supported, though it could be added in the future.
- squeeze(dim)
(1,):(a,) --> ():()
These two are simply inverses:
- unsqueeze(dim):
(x,):(a,) --> (1, x):(x * a, a)
- flatten_2d(dim) # flatten dim and dim+1; sufficient for ND case
(x, y):(y * a, a) --> (x * y,):(a,)
And permute reorders the sizes-strides without otherwise making modifications.
# The algorithm
First, take the source tensor and put it into canonical form by:
1. Removing all size-1 and stride-0 dimensions: For each stride-0 dimension,
narrow it to size 1, then squeeze all size-1 dimensions at once.
2. Permuting the tensor so the strides are in strictly descending order.
This is always possible as all size-1 dimensions are removed and we
required the input to be non-overlapping.
3. Flattening all contiguous dims, so that every pair of adjacent
dims is non-contiguous with each other (cannot be flattened).
For any non size-1/stride-0 dimension in the target tensor, we can uniquely
assign it to the only canonical source dim that could have created it.
Specifically, for any dimension (x,):(a,), any target stride s such that
a <= s <= (x-1) * a could only have been generated from this source dim.
(x*a is excluded as it can only be generated by a size-1 dim! And it is
important to be precise on the bound here because on a canonical source
tensor (x,y):(a,b) it is not guaranteed that y*b <= a.)
So our plan is to process each canonical source dim one-by-one and translate
them into the target dimensions.
For each canonical dimension, we may have multiple target dimensions that map
to it. We process these target dimensions left to right (largest stride first).
For each target dimension (except the last one from this canonical dim):
1. We unflatten to create the target dimension and a "remainder" dimension.
The key insight: unflatten(x*y, (x, y)) gives strides (y*a, a) where a
is the canonical stride. To get the target stride s for the first dim,
we need y*a = s, so we choose y = s/a.
2. This creates a dimension with the correct stride. We then move to the next
target dimension, which will be carved out of the "remainder" dimension.
3. If the needed size (target_size * y) exceeds available size, we return
NotImplemented.
For the last target dimension from this canonical dim:
1. Its stride must match the canonical stride, otherwise we return NotImplemented.
2. Narrow to the final target size.
Limitations:
- Slice operations with nontrivial step (e.g., [::2]) are not currently
supported. This means some valid view sequences cannot be reconstructed.
"""
import functools
import types
import torch
def _canonicalize_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""
Canonicalize a tensor by:
1. Narrowing stride-0 dimensions to size 1
2. Squeezing all size-1 dimensions
3. Permuting so strides are in descending order
4. Flattening contiguous dimensions
Returns a tensor with strictly descending strides where no two adjacent
dims are contiguous with each other, and every dim has size > 1.
"""
result = tensor
# Narrow stride-0 dimensions to size 1
for dim in reversed(range(result.ndim)):
if result.stride(dim) == 0:
result = result.narrow(dim, 0, 1)
# Squeeze all size-1 dimensions
# TODO: don't squeeze if unnecessary
result = result.squeeze()
# Permute so strides are in descending order
stride_order = sorted(
range(result.ndim), key=lambda i: result.stride(i), reverse=True
)
if stride_order != list(range(result.ndim)):
result = result.permute(stride_order)
# Flatten contiguous dimensions
# TODO: do it in one go rather than lots of flatten calls
i = 0
while i < result.ndim - 1:
if result.stride(i) == result.size(i + 1) * result.stride(i + 1):
result = result.flatten(i, i + 1)
else:
i += 1
return result
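A short sketch of what canonicalization produces, assuming this module's helpers are importable:

import torch

# Permuted but contiguous: strides are re-sorted, then contiguous dims flatten.
x = torch.empty(2, 3, 4).permute(2, 0, 1)   # (4, 2, 3):(1, 12, 4)
assert _canonicalize_tensor(x).size() == (24,)

# Expanded: the stride-0 dim is narrowed to size 1 and squeezed away.
y = torch.empty(5).expand(3, 5)             # (3, 5):(0, 1)
assert _canonicalize_tensor(y).size() == (5,)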
def _equal_significant_size_strides(lsize, lstride, rsize, rstride):
    """Compare size/stride pairs, treating strides as insignificant for
    size-1 dims and for empty tensors (any size 0)."""
if lsize != rsize:
return False
if any(s == 0 for s in lsize):
return True
for s, l, r in zip(lsize, lstride, rstride):
if s == 1:
continue
if l != r:
return False
return True
def _unexpand_target_then(result: torch.Tensor, size, stride, storage_offset, cb):
"""
Remove stride-0 dimensions from target, call cb, then expand them back.
This handles expand operations which create stride-0 dimensions.
For example: (5,):(1,) -> expand -> (3, 5):(0, 1)
"""
# Identify stride-0 dimensions (from expand)
stride0_dims = tuple(i for i, s in enumerate(stride) if s == 0)
stride0_sizes = tuple(size[i] for i in stride0_dims)
# Remove stride-0 dimensions
new_size = tuple(s for i, s in enumerate(size) if i not in stride0_dims)
new_stride = tuple(s for i, s in enumerate(stride) if i not in stride0_dims)
result = cb(result, new_size, new_stride, storage_offset)
if result is NotImplemented:
return NotImplemented
assert (
_equal_significant_size_strides(
result.size(), result.stride(), new_size, new_stride
)
and result.storage_offset() == storage_offset
)
# Expand the stride-0 dimensions back
for i, orig_size in zip(stride0_dims, stride0_sizes):
# Because we expand left-to-right, result's dims get shifted so we
# need to unsqueeze at position i then expand
result = torch.unsqueeze(result, i)
# Expand the newly unsqueezed dimension to the original size
expand_size = list(result.size())
expand_size[i] = orig_size
result = result.expand(expand_size)
assert (
_equal_significant_size_strides(result.size(), result.stride(), size, stride)
and result.storage_offset() == storage_offset
)
return result
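A round-trip sketch with a trivial stand-in callback (the lambda is illustrative only; in the real pipeline cb is the next _*_target_then stage):

import torch

base = torch.arange(5)
out = _unexpand_target_then(
    base, (3, 5), (0, 1), 0,
    cb=lambda r, size, stride, offset: r.view(size),  # stand-in callback
)
assert out.size() == (3, 5) and out.stride() == (0, 1)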
def _squeeze_target_then(result: torch.Tensor, size, stride, storage_offset, cb):
    """Drop size-1 dims from the target, call cb on the reduced target, then
    unsqueeze the dropped dims back into their original positions."""
dims = tuple(i for i, s in enumerate(size) if s == 1)
new_size = tuple(s for i, s in enumerate(size) if i not in dims)
new_stride = tuple(s for i, s in enumerate(stride) if i not in dims)
result = cb(result, new_size, new_stride, storage_offset)
if result is NotImplemented:
return NotImplemented
assert (
_equal_significant_size_strides(
result.size(), result.stride(), new_size, new_stride
)
and result.storage_offset() == storage_offset
)
for i in dims:
# Because we unsqueeze left-to-right result's dims get shifted so we
# put dims in the right spot as we go
result = torch.unsqueeze(result, i)
assert (
_equal_significant_size_strides(result.size(), result.stride(), size, stride)
and result.storage_offset() == storage_offset
)
return result
def _permute_target_then(result: torch.Tensor, size, stride, storage_offset, cb):
    """Sort the target dims into descending-stride order, call cb on the
    sorted target, then apply the inverse permutation to the result."""
perm, sorted_stride = zip(
*sorted(enumerate(stride), key=lambda x: x[1], reverse=True)
)
sorted_size = tuple(size[i] for i in perm)
result = cb(result, sorted_size, sorted_stride, storage_offset)
if result is NotImplemented:
return NotImplemented
assert (
_equal_significant_size_strides(
result.size(), result.stride(), sorted_size, sorted_stride
)
and result.storage_offset() == storage_offset
)
inv = [0] * len(perm)
for i, p in enumerate(perm):
inv[p] = i
result = result.permute(inv)
assert (
_equal_significant_size_strides(result.size(), result.stride(), size, stride)
and result.storage_offset() == storage_offset
)
return result
def _process_canonical_dims(
result: torch.Tensor,
size,
stride,
storage_offset,
) -> torch.Tensor | types.NotImplementedType:
"""
Process canonical source dimensions and map them to target dimensions.
For each canonical dimension (x,):(a,), any target stride s such that
a <= s <= (x-1) * a could only have been generated from this source dim.
Since the target stride is sorted in descending order (after _permute_target_then),
we can process each canonical dim and consume target dims as we go.
Offset is consumed opportunistically during narrow operations.
"""
# Read canonical sizes/strides from result (already canonicalized)
canonical_sizes = result.size()
canonical_strides = result.stride()
# Compute remaining offset to consume
remaining_offset = storage_offset - result.storage_offset()
# Target dims are already sorted by stride descending (via _permute_target_then)
# and size-1/stride-0 dims have been removed (via _squeeze_target_then)
target_dims = [(i, size[i], stride[i]) for i in range(len(size))]
# Pointer to current position in target_dims
target_idx = 0
# Current dimension in result that we're working on
current_dim = 0
for canonical_dim in range(len(canonical_sizes)):
canonical_size = canonical_sizes[canonical_dim]
canonical_stride = canonical_strides[canonical_dim]
# Collect all target dims that map to this canonical dim
# For dimension (x,):(a,), target stride s must satisfy: a <= s <= (x-1) * a
min_stride = canonical_stride
max_stride = (canonical_size - 1) * canonical_stride
target_dims_for_this_canonical = []
while target_idx < len(target_dims):
original_pos, target_size, target_stride = target_dims[target_idx]
if min_stride <= target_stride <= max_stride:
target_dims_for_this_canonical.append(
(original_pos, target_size, target_stride)
)
target_idx += 1
else:
# Since target strides are in descending order, once we find one that
# doesn't match, we won't find any more for this canonical dim
break
if not target_dims_for_this_canonical:
# No target dims map to this canonical dim, skip it
# (This dimension gets "lost" in the transformation)
continue
# Process target_dims left to right (largest stride first, already sorted)
# Offset is consumed opportunistically during each narrow operation
# Track which result dimension we're working on within this inner loop
work_dim = current_dim
for i, (original_pos, target_size, target_stride) in enumerate(
target_dims_for_this_canonical
):
is_last = i == len(target_dims_for_this_canonical) - 1
if not is_last:
# unflatten(a*b, (a, b)) gives strides (b*canonical_stride, canonical_stride)
# To get first dim stride = target_stride: b * canonical_stride = target_stride
# So b = target_stride / canonical_stride
next_dim_size = target_stride // canonical_stride
needed_size = target_size * next_dim_size
# Opportunistically consume offset when narrowing for unflatten
narrow_start = 0
if remaining_offset > 0:
offset_in_dim = remaining_offset // canonical_stride
if offset_in_dim < result.size(work_dim):
narrow_start = offset_in_dim
remaining_offset -= offset_in_dim * canonical_stride
# Check if we have enough size
if result.size(work_dim) < needed_size + narrow_start:
return NotImplemented
# Narrow and unflatten
result = result.narrow(work_dim, narrow_start, needed_size)
result = result.unflatten(work_dim, (target_size, next_dim_size))
# After unflatten, next target dim works on the remainder at work_dim + 1
work_dim += 1
else:
# Last target dim
if target_stride == canonical_stride:
# Simple case: stride matches, just narrow to size
# Opportunistically consume offset here
narrow_start = 0
if remaining_offset > 0:
offset_in_dim = remaining_offset // canonical_stride
if offset_in_dim < result.size(work_dim):
narrow_start = offset_in_dim
remaining_offset -= offset_in_dim * canonical_stride
# Check if we have enough size
if result.size(work_dim) < target_size + narrow_start:
return NotImplemented
# Narrow to target size from the offset position
result = result.narrow(work_dim, narrow_start, target_size)
elif (
target_stride > canonical_stride
and target_stride % canonical_stride == 0
):
# Edge case: need to unflatten and squeeze away trailing dimensions
# unflatten(a*b, (a, b)) gives strides (b*canonical_stride, canonical_stride)
# We want first dim stride = target_stride, so b = target_stride / canonical_stride
remainder_size = target_stride // canonical_stride
needed_size = target_size * remainder_size
# Check if we have enough size
if result.size(work_dim) < needed_size:
return NotImplemented
# Narrow if there's excess size
if result.size(work_dim) != needed_size:
result = result.narrow(work_dim, 0, needed_size)
# Unflatten to create target dimension and remainder
result = result.unflatten(work_dim, (target_size, remainder_size))
# The remainder dimension (at work_dim + 1) needs to be eliminated
# Opportunistically consume any remaining offset here
narrow_start = 0
if remaining_offset > 0 and canonical_stride > 0:
# Can we consume offset by selecting a different position in the remainder?
offset_in_remainder = remaining_offset // canonical_stride
if offset_in_remainder < remainder_size:
narrow_start = offset_in_remainder
remaining_offset -= offset_in_remainder * canonical_stride
result = result.narrow(work_dim + 1, narrow_start, 1)
result = result.squeeze(work_dim + 1)
else:
return NotImplemented
# Advance past all result dims produced for this canonical dim; after
# unflattening, the next canonical dim sits at work_dim + 1, which is
# further right than current_dim + 1 whenever multiple target dims were
# carved out above.
current_dim = work_dim + 1
# Check if all offset was consumed
if remaining_offset != 0:
return NotImplemented
return result
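An illustrative trace of the core loop on a contiguous 1-D source, which is already in canonical form:

import torch

src = torch.arange(24)          # canonical: (24,):(1,)
out = _process_canonical_dims(src, (3, 4), (8, 1), 0)
# The loop unflattens (24,) -> (3, 8) so the first dim gets stride 8, then
# narrows the remainder to length 4; equivalent to src.view(3, 8)[:, :4].
assert out.size() == (3, 4) and out.stride() == (8, 1)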
def as_strided_via_views(
tensor: torch.Tensor,
size: tuple[int, ...],
stride: tuple[int, ...],
storage_offset: int = 0,
) -> torch.Tensor | types.NotImplementedType:
"""
Attempt to reconstruct a sequence of view operations that would produce
the target size/stride/offset from the source tensor.
Args:
tensor: Source tensor
size: Target size
stride: Target stride
storage_offset: Target storage offset (default: 0)
Returns:
A view of the tensor with the target size/stride/offset, or NotImplemented
if it cannot be achieved via simple view operations.
"""
import math
# NB: When strides are insignificant, we don't reproduce them exactly, as
# this would make the algorithm a lot more complicated for no reason
target_numel = math.prod(size)
# Easy case 1: source tensor has numel==0
if tensor.numel() == 0:
if target_numel != 0:
return NotImplemented
return tensor.view(size)
# Easy case 2: source tensor has numel==1
if tensor.numel() == 1:
if storage_offset != tensor.storage_offset():
return NotImplemented
if target_numel == 0:
# If tensor is 0D scalar, we need to unsqueeze first to get a dimension to narrow
if tensor.ndim == 0:
tensor = tensor.unsqueeze(0)
tensor = tensor.narrow(0, 0, 0)
return tensor.view(size)
if target_numel == 1:
# Simple case: numel 1 -> numel 1 (no expand needed)
return tensor.view(size)
# target_numel > 1: For numel==1 source, target is valid only if
# all dimensions with size > 1 have stride 0 (can be expanded)
requires_expand = any(sz > 1 and st != 0 for sz, st in zip(size, stride))
if requires_expand:
return NotImplemented
# Build target by unsqueezing and expanding
result = tensor.view(1) # Flatten to (1,)
# Unsqueeze to match target ndim
for _ in range(len(size) - 1):
result = result.unsqueeze(0)
# Now expand to target size (this creates stride-0 for expanded dims)
result = result.expand(size)
return result
# Step 1: Canonicalize source tensor
result = _canonicalize_tensor(tensor)
# Step 2: Calculate offset to consume
# We'll consume this opportunistically during narrow operations in _process_canonical_dims
offset_delta = storage_offset - result.storage_offset()
if offset_delta < 0:
return NotImplemented
# We're not going to take this straight to the target size/stride.
# Instead we're going to first go to a "canonicalized" target which is in
# sorted order, has no size-1 dims, no stride-0 dims (unlike source, we
# don't have to flatten contiguous dimensions; the algorithm below can
# handle it.) We need to compute this canonicalized target, and we should
# also record enough information so we can invert this transformation (so
# that we can apply the inverse on result to get to our final, desired
# result). So _unexpand_target_then/_squeeze_target_then/_permute_target_then
# take care of modifying the target and inverting it, and _process_canonical_dims
# is the main payload.
return _unexpand_target_then(
result,
size,
stride,
storage_offset,
cb=functools.partial(
_squeeze_target_then,
cb=functools.partial(_permute_target_then, cb=_process_canonical_dims),
),
)
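A usage sketch for the entry point: an ordinary strided slice reconstructs fine, while an overlapping rolling-window layout is rejected:

import torch

x = torch.arange(30).view(5, 6)
target = x.as_strided((2, 3), (6, 1), 7)    # rows 1-2, cols 1-3
y = as_strided_via_views(x, (2, 3), (6, 1), 7)
assert y is not NotImplemented
assert torch.equal(y, target) and y.storage_offset() == 7

# Rolling windows overlap, which is outside the supported view vocabulary,
# so the reconstruction returns NotImplemented instead of erroring.
w = as_strided_via_views(torch.arange(6), (4, 3), (1, 1))
assert w is NotImplemented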

View File

@ -1300,6 +1300,10 @@ SAC_IGNORED_OPS = {
class _CachingTorchDispatchMode(TorchDispatchMode):
@classmethod
def ignore_compile_internals(cls):
return True
# Used together with _CachedTorchDispatchMode to implement SAC.
def __init__(self, policy_fn, storage):
self.policy_fn = policy_fn
@ -1336,6 +1340,10 @@ class _CachingTorchDispatchMode(TorchDispatchMode):
return out
class _CachedTorchDispatchMode(TorchDispatchMode):
@classmethod
def ignore_compile_internals(cls):
return True
# Used together with _CachedTorchDispatchMode to implement SAC.
def __init__(self, policy_fn, storage, allow_cache_entry_mutation):
self.policy_fn = policy_fn
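Reading these hunks, ignore_compile_internals is a classmethod hook on TorchDispatchMode that both SAC modes now override to return True. A plausible minimal sketch of opting a custom mode in (semantics inferred from the name and the diff, not verified here):

from torch.utils._python_dispatch import TorchDispatchMode

class PrintingMode(TorchDispatchMode):
    # Mirrors the SAC modes above; presumably asks torch.compile to keep
    # this mode from interposing on compiler-internal dispatches.
    @classmethod
    def ignore_compile_internals(cls):
        return True

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)  # log every op that reaches this dispatch mode
        return func(*args, **(kwargs or {}))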