mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-19 01:54:54 +08:00

Compare commits: main...ciflow/ind (3 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 1af3a82933 | |
| | fae8ca233d | |
| | eb24dbfce8 | |
aaa.py (Normal file, 187 lines)
@ -0,0 +1,187 @@
import torch
from torch.nn.attention import flex_attention
from torch.utils._debug_mode import DebugMode
from torch import nn
import functools

from torch.nn.attention.flex_attention import (
    _create_empty_block_mask,
    _DEFAULT_SPARSE_BLOCK_SIZE,
    _identity,
    _mask_mod_signature,
    _score_mod_signature,
    _WARNINGS_SHOWN,
    and_masks,
    AuxOutput,
    AuxRequest,
    BlockMask,
    create_block_mask,
    flex_attention,
    flex_attention_hop,
    noop_mask,
    or_masks,
)

"""
@torch.compile()
def f(x, y):
    return x + y

with DebugMode() as debug_mode:
    print(f(torch.randn(2), torch.randn(2)))

print(debug_mode.debug_string())
"""

# NEXT

"""
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

def _get_custom_policy(no_recompute_list=None, must_recompute_list=None):
    def _custom_policy(ctx, func, *args, **kwargs):
        if no_recompute_list is not None and func in no_recompute_list:
            return CheckpointPolicy.MUST_SAVE
        if must_recompute_list is not None and func in must_recompute_list:
            return CheckpointPolicy.MUST_RECOMPUTE
        else:
            return CheckpointPolicy.PREFER_RECOMPUTE

    return _custom_policy

def context_fn_must_recompute_mm():
    must_recompute_list = [
        torch.ops.aten.mm.default,
    ]
    return create_selective_checkpoint_contexts(
        _get_custom_policy(
            must_recompute_list=must_recompute_list,
        ),
    )

@torch.compile(fullgraph=True)
def mm(x, y):
    return torch.matmul(x, y)

def gn(x):
    return torch.sigmoid(mm(x, x))

def fn(x):
    return torch.utils.checkpoint.checkpoint(
        gn,
        x,
        use_reentrant=False,
        context_fn=context_fn_must_recompute_mm,
    )

x = torch.randn(4, 4, requires_grad=True, device='cuda')
with DebugMode() as debug_mode:
    print(fn(x))

print(debug_mode.debug_string())
"""

class FlexAttentionModule(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # In-projections (query, key, value)
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)

        # Out-projection
        self.out_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        # Project queries, keys, and values
        q = (
            self.q_proj(x)
            .view(batch_size, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        k = (
            self.k_proj(x)
            .view(batch_size, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        v = (
            self.v_proj(x)
            .view(batch_size, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        # Apply flex attention
        attn_output = torch.compile()(flex_attention)(
            # attn_output = flex_attention(
            q,
            k,
            v,
        )

        # Reshape output
        attn_output = (
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_len, self.hidden_size)
        )

        # Out projection
        output = self.out_proj(attn_output)

        return output

from torch.utils.checkpoint import (
    checkpoint,
    create_selective_checkpoint_contexts,
)

# TODO: do this
ops_to_save = [torch.ops.aten.mm.default]
context_fn = functools.partial(
    create_selective_checkpoint_contexts, ops_to_save
)

# Define a model that uses FlexAttention with selective activation checkpointing
class SacModule(nn.Module):
    def __init__(self, hidden_size, num_heads, context_fn):
        super().__init__()
        self.flex_attn = FlexAttentionModule(hidden_size, num_heads)
        self.context_fn = context_fn

    def forward(self, x):
        def flex_attn_fn(x):
            return self.flex_attn(x)

        output = checkpoint(
            flex_attn_fn,
            x,
            use_reentrant=False,
            context_fn=self.context_fn,
        )

        return output

device = 'cuda'
flex_module = SacModule(hidden_size=512, num_heads=8, context_fn=context_fn).to(
    device, dtype=torch.bfloat16
)
x = torch.ones(8, 1024, 512, device=device, dtype=torch.bfloat16, requires_grad=True)

with DebugMode() as debug_mode:
    output_module = flex_module(x)
    grad_output = torch.ones_like(output_module)
    grad_module = torch.autograd.grad(
        outputs=output_module, inputs=x, grad_outputs=grad_output, retain_graph=True
    )[0]

print(debug_mode.debug_string())
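For orientation, the DebugMode recipe this script exercises (and repeats in its commented-out blocks) reduces to the following minimal sketch, using only the torch.utils._debug_mode API already imported above:

import torch
from torch.utils._debug_mode import DebugMode

@torch.compile()
def f(x, y):
    return x + y

with DebugMode() as debug_mode:
    f(torch.randn(2), torch.randn(2))

# debug_string() dumps the calls DebugMode recorded while f ran
print(debug_mode.debug_string())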
@ -1019,6 +1019,28 @@ class DTensorMeshTest(DTensorTestBase):
            except ValueError:
                self.fail("Unexpected ValueError raised with run_check=False")

    @with_comms
    def test_as_strided_identity(self):
        # Test calling as_strided with the same size/stride/offset as input tensor
        # This should be a no-op but currently fails
        device_mesh = self.build_device_mesh()
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 4, device=self.device_type)
        dtensor = DTensor.from_local(local_tensor, device_mesh, placements)

        # Get the current size, stride, and storage_offset
        size = dtensor.size()
        stride = dtensor.stride()
        storage_offset = dtensor.storage_offset()

        # Call as_strided with the exact same parameters
        result = dtensor.as_strided(size, stride, storage_offset)

        # The result should be identical to the input
        self.assertEqual(result.size(), dtensor.size())
        self.assertEqual(result.stride(), dtensor.stride())
        self.assertEqual(result.to_local(), dtensor.to_local())


DTensorMeshTestWithLocalTensor = create_local_tensor_test_class(
    DTensorMeshTest,
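The plain-tensor analogue of the identity this test pins down, as a quick sketch:

import torch

t = torch.randn(3, 4)
# Re-striding a tensor with its own size/stride/offset should be a pure view no-op
u = t.as_strided(t.size(), t.stride(), t.storage_offset())
assert u.size() == t.size() and u.stride() == t.stride()
assert torch.equal(u, t)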
@ -994,6 +994,96 @@ def forward(self, primals_1):
        out_dt = torch.matmul(tmp_dt, y_dt)
        out_dt.sum().backward()

    def test_dtensor_broadcast_to_replicate(self):
        """Test broadcast_to operation with DTensor under compilation."""
        mesh = init_device_mesh(self.device_type, (self.world_size, 1))
        placement = [Replicate(), Replicate()]

        def fn(input):
            return torch.broadcast_to(input, (888, 12, 888, 12))

        big_tensor = torch.randn(888, 12)
        d_tensor = distribute_tensor(big_tensor, device_mesh=mesh, placements=placement)

        # Test eager execution
        result = fn(d_tensor)
        self.assertEqual(result.shape, torch.Size([888, 12, 888, 12]))

        # Test compiled execution
        jitted_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
        jitted_result = jitted_fn(d_tensor)
        self.assertEqual(jitted_result.shape, torch.Size([888, 12, 888, 12]))
        self.assertEqual(result.full_tensor(), jitted_result.full_tensor())

    def test_dtensor_reshape_replicate(self):
        """Test reshape operation with DTensor under compilation."""
        mesh = init_device_mesh(self.device_type, (self.world_size, 1))
        placement = [Replicate(), Replicate()]

        def fn(input):
            return input.reshape((1, 888, 6, 2))

        big_tensor = torch.randn(888, 12)
        d_tensor = distribute_tensor(big_tensor, device_mesh=mesh, placements=placement)

        # Test eager execution
        result = fn(d_tensor)
        self.assertEqual(result.shape, torch.Size([1, 888, 6, 2]))

        # Test compiled execution
        jitted_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
        jitted_result = jitted_fn(d_tensor)
        self.assertEqual(jitted_result.shape, torch.Size([1, 888, 6, 2]))
        self.assertEqual(result.full_tensor(), jitted_result.full_tensor())

    def test_dtensor_indexing_replicate(self):
        """Test tensor indexing with Ellipsis with DTensor under compilation."""
        mesh = init_device_mesh(self.device_type, (self.world_size, 1))
        placement = [Replicate(), Replicate()]

        def fn(input):
            return input[(Ellipsis, 1)]

        big_tensor = torch.randn(1, 888, 12)
        d_tensor = distribute_tensor(big_tensor, device_mesh=mesh, placements=placement)

        # Test eager execution
        result = fn(d_tensor)
        self.assertEqual(result.shape, torch.Size([1, 888]))

        # Test compiled execution
        jitted_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
        jitted_result = jitted_fn(d_tensor)
        self.assertEqual(jitted_result.shape, torch.Size([1, 888]))
        self.assertEqual(result.full_tensor(), jitted_result.full_tensor())

    def test_dtensor_tensor_split_replicate(self):
        """Test tensor_split operation with DTensor under compilation."""
        mesh = init_device_mesh(self.device_type, (self.world_size, 1))
        placement = [Replicate(), Replicate()]

        def fn(input):
            return torch.tensor_split(input, 3, dim=-1)

        big_tensor = torch.randn(1, 1, 9216)
        d_tensor = distribute_tensor(big_tensor, device_mesh=mesh, placements=placement)

        # Test eager execution
        t1, t2, t3 = fn(d_tensor)
        self.assertEqual(t1.shape, torch.Size([1, 1, 3072]))
        self.assertEqual(t2.shape, torch.Size([1, 1, 3072]))
        self.assertEqual(t3.shape, torch.Size([1, 1, 3072]))

        # Test compiled execution
        jitted_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
        jitted_t1, jitted_t2, jitted_t3 = jitted_fn(d_tensor)
        self.assertEqual(jitted_t1.shape, torch.Size([1, 1, 3072]))
        self.assertEqual(jitted_t2.shape, torch.Size([1, 1, 3072]))
        self.assertEqual(jitted_t3.shape, torch.Size([1, 1, 3072]))
        self.assertEqual(t1.full_tensor(), jitted_t1.full_tensor())
        self.assertEqual(t2.full_tensor(), jitted_t2.full_tensor())
        self.assertEqual(t3.full_tensor(), jitted_t3.full_tensor())

    @unittest.skipIf(
        torch._inductor.config.triton.native_matmul, "Matmul is now generated"
    )
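Each of these new tests follows the same eager-versus-compiled comparison recipe; stripped of DTensor, the skeleton is just (plain-tensor sketch):

import torch

def fn(x):
    return torch.broadcast_to(x, (4, 3, 4, 3))

x = torch.randn(4, 3)
eager = fn(x)
# aot_eager runs the graph through AOTAutograd without Inductor codegen
compiled = torch.compile(fn, backend="aot_eager", fullgraph=True)(x)
assert torch.equal(eager, compiled)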
@ -664,6 +664,101 @@ class TestViewOps(DTensorTestBase):
        )
        self.assertEqual(dist_x.placements, [Partial(), Shard(0)])

    @with_comms
    def test_storage_offset_slice(self):
        """
        Test that storage_offset is properly tracked on DTensor when slicing
        a replicated tensor.
        """
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        # Create a replicated DTensor
        tensor = torch.randn(10, device=self.device_type)
        dtensor = distribute_tensor(tensor, mesh, [Replicate()])

        # Perform a slice operation [1:]
        with CommDebugMode() as comm_mode:
            sliced_dtensor = dtensor[1:]
        # Slicing should not trigger any communication
        self.assertEqual(comm_mode.get_total_counts(), 0)

        # Verify that the DTensor's storage_offset matches the expected value
        self.assertEqual(sliced_dtensor.storage_offset(), 1)

        # Verify that the local tensor also has the correct storage_offset
        self.assertEqual(sliced_dtensor.to_local().storage_offset(), 1)

        # Verify the shape is correct
        self.assertEqual(sliced_dtensor.shape, torch.Size([9]))

        # Verify the values are correct
        expected = tensor[1:]
        self.assertEqual(sliced_dtensor.full_tensor(), expected)

    @with_comms
    def test_storage_offset_shard_dim0_slice_dim1(self):
        """
        Test that storage_offset is properly tracked when tensor is sharded on dim 0
        and sliced on dim 1.
        """
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        # Create a 2D tensor and shard on dim 0
        tensor = torch.randn(12, 8, device=self.device_type)
        dtensor = distribute_tensor(tensor, mesh, [Shard(0)])

        # Perform a slice operation [:, 2:]
        with CommDebugMode() as comm_mode:
            sliced_dtensor = dtensor[:, 2:]
        # Slicing should not trigger any communication
        self.assertEqual(comm_mode.get_total_counts(), 0)

        # The storage_offset should be 2 (skipping 2 elements in each row)
        self.assertEqual(sliced_dtensor.storage_offset(), 2)

        # Verify that the local tensor also has the correct storage_offset
        self.assertEqual(sliced_dtensor.to_local().storage_offset(), 2)

        # Verify the shape is correct
        expected_shape = torch.Size([12, 6])
        self.assertEqual(sliced_dtensor.shape, expected_shape)

        # Verify the values are correct
        expected = tensor[:, 2:]
        self.assertEqual(sliced_dtensor.full_tensor(), expected)

    @with_comms
    def test_storage_offset_shard_dim1_slice_dim0(self):
        """
        Test that storage_offset is properly tracked when tensor is sharded on dim 1
        and sliced on dim 0.
        """
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        # Create a 2D tensor and shard on dim 1
        tensor = torch.randn(10, 12, device=self.device_type)
        dtensor = distribute_tensor(tensor, mesh, [Shard(1)])

        # Perform a slice operation [2:, :]
        with CommDebugMode() as comm_mode:
            sliced_dtensor = dtensor[2:, :]
        # Slicing should not trigger any communication
        self.assertEqual(comm_mode.get_total_counts(), 0)

        local_dim1_size = 12 // self.world_size
        expected_offset = 2 * local_dim1_size
        self.assertEqual(sliced_dtensor.storage_offset(), expected_offset)

        self.assertEqual(sliced_dtensor.to_local().storage_offset(), expected_offset)

        # Verify the shape is correct
        expected_shape = torch.Size([8, 12])
        self.assertEqual(sliced_dtensor.shape, expected_shape)

        # Verify the values are correct
        expected = tensor[2:, :]
        self.assertEqual(sliced_dtensor.full_tensor(), expected)


TestViewOpsWithLocalTensor = create_local_tensor_test_class(
    TestViewOps,
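The expected offsets in these tests mirror what plain local tensors already report; for example:

import torch

t = torch.randn(10)
assert t[1:].storage_offset() == 1      # 1D slice skips one element
m = torch.randn(12, 8)
assert m[:, 2:].storage_offset() == 2   # row-major: skip 2 elements into the first row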
test/test_as_strided.py (Normal file, 890 lines)
@ -0,0 +1,890 @@
# Owner(s): ["oncall: pt2"]

from collections import deque
from typing import Optional

import torch
from torch._prims_common import check_significant_strides
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._as_strided import as_strided_via_views


def get_state(t: torch.Tensor) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Extract (sizes, strides) tuple from a tensor."""
    return (tuple(t.size()), tuple(t.stride()))


def max_storage_offset(
    base_numel: int, size: tuple[int, ...], stride: tuple[int, ...]
) -> int:
    """
    Calculate the maximum storage offset for which a view with the given
    size/stride would fit within a tensor of base_numel elements.

    We conservatively require the full size*stride extent to fit inside,
    which ensures there's enough working space for unflatten operations.
    """
    if base_numel == 0:
        return 0

    if any(s == 0 for s in size):
        # numel == 0, all offsets from 0 to base_numel-1 are technically valid
        # but for simplicity we only test offset 0
        return 0

    # Require the full extent of the largest dimension to fit
    # For each dimension, the extent is size[i] * stride[i]
    max_extent = max(sz * max(st, 0) for sz, st in zip(size, stride))
    max_offset = base_numel - max_extent
    return max(0, max_offset)
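# Worked example (illustrative comment, not part of the diffed file): with
# base_numel=20, size=(2, 3), stride=(5, 1), max_extent = max(2*5, 3*1) = 10,
# so the largest admissible storage_offset is 20 - 10 = 10 -- the value
# exercised by test_storage_offset_max_offset below.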

def enumerate_reachable_states(
    initial_size: int,
    include_slice: bool = False,
) -> set[tuple[tuple[int, ...], tuple[int, ...]]]:
    """
    Use BFS with DP to enumerate all reachable (size, stride) states from
    a 1D contiguous tensor via valid view operations.

    Args:
        initial_size: Size of the initial 1D tensor
        include_slice: If True, include slice operations with step>1

    We only explore states with offset=0 (you can retroactively change the offset).
    We reject states with size=0 or size=1 dimensions as they are degenerate.
    """
    # Create initial 1D contiguous tensor
    initial_tensor = torch.arange(initial_size)

    initial_state = get_state(initial_tensor)

    # Map from state to tensor for that state
    state_to_tensor: dict[tuple[tuple[int, ...], tuple[int, ...]], torch.Tensor] = {
        initial_state: initial_tensor
    }
    visited: set[tuple[tuple[int, ...], tuple[int, ...]]] = {initial_state}
    queue: deque[tuple[tuple[int, ...], tuple[int, ...]]] = deque([initial_state])

    while queue:
        state = queue.popleft()
        t = state_to_tensor[state]
        sizes, strides = state
        ndim = len(sizes)

        def add_state(new_t: torch.Tensor) -> None:
            new_state = get_state(new_t)
            sizes, strides = new_state
            # Skip if has size-0 or size-1 dimensions
            if any(s == 0 or s == 1 for s in sizes):
                return
            # Only accept states where strides are in descending order
            if list(strides) != sorted(strides, reverse=True):
                return
            if new_state not in visited:
                visited.add(new_state)
                queue.append(new_state)
                state_to_tensor[new_state] = new_t

        # 1. Unflatten: try factoring each dimension
        for dim in range(ndim):
            size = sizes[dim]
            assert size > 1
            # Try all factorizations x * y = size where both x, y >= 2
            # We only need to check x up to size // 2 since when x > size // 2,
            # y = size // x < 2, which we reject
            for x in range(2, size // 2 + 1):
                if size % x == 0:
                    y = size // x
                    add_state(t.unflatten(dim, (x, y)))

        # 2. Slice/Narrow
        for dim in range(ndim):
            size = sizes[dim]
            max_step = size + 1 if include_slice else 2
            for start in range(size):
                for stop in range(start + 1, size + 1):
                    for step in range(1, max_step):
                        slices = [slice(None)] * ndim
                        slices[dim] = slice(start, stop, step)
                        add_state(t[tuple(slices)])

        # 3. Flatten: merge adjacent dimensions
        for dim in range(ndim - 1):
            add_state(t.flatten(dim, dim + 1))

    return visited
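# Example (illustrative comment): enumerate_reachable_states(10) contains states
# such as ((2, 5), (5, 1)) -- torch.arange(10) unflattened to 2x5 -- and
# ((2, 2), (5, 1)), the same view narrowed to its first two columns; the full
# set of 26 states is pinned down in test_size_10_exhaustive_without_slice below.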


class TestAsStrided(TestCase):
    def assertSameView(
        self,
        result: torch.Tensor | type[NotImplemented],
        target: torch.Tensor,
        base: torch.Tensor,
        msg: str = "",
    ) -> None:
        self.assertIsNot(result, NotImplemented, msg=msg or "Got NotImplemented")
        self.assertEqual(result.size(), target.size(), msg=msg or "Size mismatch")
        same_strides, idx = check_significant_strides(result, target, only_cuda=False)
        if not same_strides:
            fail_msg = f"Stride mismatch at dim {idx}: result={result.stride()}, target={target.stride()}"
            if msg:
                fail_msg = f"{msg}: {fail_msg}"
            self.fail(fail_msg)
        self.assertTrue(result._is_view(), msg=msg or "Result is not a view")
        # Check that result is a view of the base tensor using object identity
        result_base = result._base if result._base is not None else result
        base_base = base._base if base._base is not None else base
        self.assertIs(
            result_base, base_base, msg=msg or "Result is not a view of the base tensor"
        )
        self.assertEqual(
            result.storage_offset(),
            target.storage_offset(),
            msg=msg or "Storage offset mismatch",
        )

    def check_as_strided_via_views(
        self,
        base: torch.Tensor,
        size: tuple[int, ...],
        stride: tuple[int, ...],
        storage_offset: int = 0,
    ) -> None:
        """Helper to test as_strided_via_views matches torch.as_strided."""
        target = torch.as_strided(base, size, stride, storage_offset)
        result = as_strided_via_views(base, size, stride, storage_offset)
        self.assertSameView(
            result, target, base, f"Failed for {size=}, {stride=}, {storage_offset=}"
        )

    def test_size_10_exhaustive_without_slice(self) -> None:
        """Test that size 10 produces exactly 26 states without slice (step>1)."""
        expected_states = {
            ((2,), (1,)),
            ((2, 2), (2, 1)),
            ((2, 2), (3, 1)),
            ((2, 2), (4, 1)),
            ((2, 2), (5, 1)),
            ((2, 2, 2), (4, 2, 1)),
            ((2, 2, 2), (5, 2, 1)),
            ((2, 3), (3, 1)),
            ((2, 3), (4, 1)),
            ((2, 3), (5, 1)),
            ((2, 4), (4, 1)),
            ((2, 4), (5, 1)),
            ((2, 5), (5, 1)),
            ((3,), (1,)),
            ((3, 2), (2, 1)),
            ((3, 2), (3, 1)),
            ((3, 3), (3, 1)),
            ((4,), (1,)),
            ((4, 2), (2, 1)),
            ((5,), (1,)),
            ((5, 2), (2, 1)),
            ((6,), (1,)),
            ((7,), (1,)),
            ((8,), (1,)),
            ((9,), (1,)),
            ((10,), (1,)),
        }

        actual_states = enumerate_reachable_states(10, include_slice=False)

        self.assertEqual(len(actual_states), 26)
        self.assertEqual(actual_states, expected_states)

    def test_size_10_exhaustive_with_slice(self) -> None:
        """Test that size 10 produces exactly 54 states with slice (step>1)."""
        expected_states = {
            ((2,), (1,)),
            ((2,), (2,)),
            ((2,), (3,)),
            ((2,), (4,)),
            ((2,), (5,)),
            ((2,), (6,)),
            ((2,), (7,)),
            ((2,), (8,)),
            ((2,), (9,)),
            ((2, 2), (2, 1)),
            ((2, 2), (3, 1)),
            ((2, 2), (3, 2)),
            ((2, 2), (4, 1)),
            ((2, 2), (4, 2)),
            ((2, 2), (4, 3)),
            ((2, 2), (5, 1)),
            ((2, 2), (5, 2)),
            ((2, 2), (5, 3)),
            ((2, 2), (5, 4)),
            ((2, 2), (6, 1)),
            ((2, 2), (6, 2)),
            ((2, 2), (6, 3)),
            ((2, 2), (8, 1)),
            ((2, 2, 2), (4, 2, 1)),
            ((2, 2, 2), (5, 2, 1)),
            ((2, 3), (3, 1)),
            ((2, 3), (4, 1)),
            ((2, 3), (5, 1)),
            ((2, 3), (5, 2)),
            ((2, 3), (6, 1)),
            ((2, 4), (4, 1)),
            ((2, 4), (5, 1)),
            ((2, 5), (5, 1)),
            ((3,), (1,)),
            ((3,), (2,)),
            ((3,), (3,)),
            ((3,), (4,)),
            ((3, 2), (2, 1)),
            ((3, 2), (3, 1)),
            ((3, 2), (3, 2)),
            ((3, 2), (4, 1)),
            ((3, 3), (3, 1)),
            ((4,), (1,)),
            ((4,), (2,)),
            ((4,), (3,)),
            ((4, 2), (2, 1)),
            ((5,), (1,)),
            ((5,), (2,)),
            ((5, 2), (2, 1)),
            ((6,), (1,)),
            ((7,), (1,)),
            ((8,), (1,)),
            ((9,), (1,)),
            ((10,), (1,)),
        }

        actual_states = enumerate_reachable_states(10, include_slice=True)

        self.assertEqual(len(actual_states), 54)
        self.assertEqual(actual_states, expected_states)

    def test_subset_property(self) -> None:
        """
        Test that for sizes 2..10, each smaller tensor results in a strict
        subset of possible states compared to the next one.
        """
        prev_states: Optional[set[tuple[tuple[int, ...], tuple[int, ...]]]] = None
        for size in range(2, 11):
            current_states = enumerate_reachable_states(size)

            if prev_states is not None:
                # Check that prev_states is a strict subset of current_states
                self.assertTrue(
                    prev_states.issubset(current_states),
                    f"States from size {size - 1} are not a subset of size {size}",
                )
                # Check that it's a strict subset (not equal)
                self.assertTrue(
                    len(prev_states) < len(current_states),
                    f"States from size {size - 1} should be strictly fewer than size {size}",
                )

            prev_states = current_states

    def test_as_strided_via_views_exhaustive(self) -> None:
        """Exhaustively test on all reachable states from size 10, including all valid offsets."""
        initial_tensor = torch.arange(10)
        states = enumerate_reachable_states(10)
        for size, stride in states:
            # Test with offset=0
            self.check_as_strided_via_views(
                initial_tensor, size, stride, storage_offset=0
            )

            # Test with all valid non-zero offsets
            max_offset = max_storage_offset(initial_tensor.numel(), size, stride)
            for offset in range(1, max_offset + 1):
                self.check_as_strided_via_views(
                    initial_tensor, size, stride, storage_offset=offset
                )

    # Tests for specific transformations
    def test_permute_simple(self) -> None:
        """Test simple permute: (2, 5) with strides (5, 1) -> (5, 2) with strides (1, 5)"""
        self.check_as_strided_via_views(torch.arange(10), (5, 2), (1, 5))

    def test_unsqueeze_front(self) -> None:
        """Test unsqueeze at front: (2, 5) -> (1, 2, 5)"""
        self.check_as_strided_via_views(torch.arange(10), (1, 2, 5), (10, 5, 1))

    def test_unsqueeze_middle(self) -> None:
        """Test unsqueeze in middle: (2, 5) -> (2, 1, 5)"""
        self.check_as_strided_via_views(torch.arange(10), (2, 1, 5), (5, 5, 1))

    def test_unsqueeze_end(self) -> None:
        """Test unsqueeze at end: (2, 5) -> (2, 5, 1)"""
        self.check_as_strided_via_views(torch.arange(10), (2, 5, 1), (5, 1, 1))

    def test_narrow_simple(self) -> None:
        """Test narrow: (2, 5) -> (2, 3)"""
        self.check_as_strided_via_views(torch.arange(10), (2, 3), (5, 1))

    def test_permute_with_unsqueeze(self) -> None:
        """Test permute + unsqueeze: (2, 5) -> (1, 5, 2)"""
        self.check_as_strided_via_views(torch.arange(10), (1, 5, 2), (10, 1, 5))

    def test_multiple_unsqueeze_with_narrow(self) -> None:
        """Test multiple unsqueeze + narrow: (2, 5) -> (1, 1, 2, 3)"""
        self.check_as_strided_via_views(torch.arange(10), (1, 1, 2, 3), (15, 10, 5, 1))

    def test_permute_3d(self) -> None:
        """Test 3D permute: (2, 3, 4) -> (4, 2, 3)"""
        self.check_as_strided_via_views(torch.arange(24), (4, 2, 3), (1, 12, 4))

    def test_unsqueeze_multiple_size_one_dims(self) -> None:
        """Test multiple size-1 dims at various positions"""
        self.check_as_strided_via_views(
            torch.arange(6), (1, 2, 1, 3, 1), (6, 3, 3, 1, 1)
        )

    def test_as_strided_via_views_numel_zero(self) -> None:
        """Test numel==0 cases where all strides are insignificant."""
        empty_tensor = torch.tensor([])
        for size, stride in [
            ((0,), (1,)),
            ((0,), (999,)),  # Arbitrary stride
            ((0, 5), (10, 1)),
            ((3, 0), (5, 1)),
            ((2, 0, 3), (0, 1, 0)),  # Arbitrary strides
        ]:
            self.check_as_strided_via_views(empty_tensor, size, stride)

    def test_as_strided_via_views_numel_one(self) -> None:
        """Test numel==1 cases where all strides are insignificant."""
        single_tensor = torch.tensor([42.0])
        for size, stride in [
            ((1,), (1,)),
            ((1,), (999,)),  # Arbitrary stride
            ((1, 1), (10, 1)),
            ((1, 1, 1), (100, 50, 25)),  # All arbitrary strides
        ]:
            self.check_as_strided_via_views(single_tensor, size, stride)

    def test_numel_one_to_zero(self) -> None:
        """Test numel==0 output from numel==1 input."""
        single_tensor = torch.tensor([42.0])
        for size, stride in [
            ((0,), (1,)),
            ((0,), (999,)),
            ((0, 5), (10, 1)),
        ]:
            self.check_as_strided_via_views(single_tensor, size, stride)

    def test_scalar_to_zero_size(self) -> None:
        """Test 0D scalar → zero-size tensor (needs unsqueeze before narrow)."""
        scalar_tensor = torch.tensor(42.0)
        for size, stride in [
            ((0,), (1,)),
            ((0,), (999,)),
            ((0, 5), (10, 1)),
        ]:
            self.check_as_strided_via_views(scalar_tensor, size, stride)

    def test_as_strided_via_views_impossible_cases(self) -> None:
        """Test cases that should return NotImplemented."""
        initial_tensor = torch.arange(10)

        impossible_cases = [
            ((8, 3), (1, 1)),  # Overlapping
            ((2, 2), (6, 3)),  # Requires slice with step>1
        ]

        for size, stride in impossible_cases:
            result = as_strided_via_views(
                initial_tensor, size, stride, storage_offset=0
            )
            self.assertIs(result, NotImplemented)

    # Unit tests for internal helper functions
    def test_squeeze_target_then_basic(self) -> None:
        """Test _squeeze_target_then with single size-1 dim."""
        from torch.utils._as_strided import _squeeze_target_then

        # Define a simple callback that just returns a tensor with the expected shape
        def simple_cb(result, size, stride, storage_offset):
            # Callback receives squeezed target: should be (2, 5)
            self.assertEqual(size, (2, 5))
            self.assertEqual(stride, (5, 1))
            # Return a view with this shape
            return result.view(2, 5)

        base = torch.arange(10)
        # Target has size-1 dim at position 0: (1, 2, 5)
        result = _squeeze_target_then(
            base, size=(1, 2, 5), stride=(10, 5, 1), storage_offset=0, cb=simple_cb
        )

        self.assertEqual(result.size(), (1, 2, 5))
        self.assertEqual(result.stride(), (10, 5, 1))

    def test_squeeze_target_then_multiple_size_one(self) -> None:
        """Test _squeeze_target_then with multiple size-1 dims."""
        from torch.utils._as_strided import _squeeze_target_then

        def simple_cb(result, size, stride, storage_offset):
            # Callback receives squeezed target: should be (5,)
            self.assertEqual(size, (5,))
            self.assertEqual(stride, (1,))
            return result.narrow(0, 0, 5)

        base = torch.arange(10)
        # Target has size-1 dims at positions 0 and 2: (1, 5, 1)
        result = _squeeze_target_then(
            base, size=(1, 5, 1), stride=(10, 1, 1), storage_offset=0, cb=simple_cb
        )

        self.assertEqual(result.size(), (1, 5, 1))
        # Size-1 strides are insignificant, so we don't check them strictly

    def test_squeeze_target_then_no_size_one(self) -> None:
        """Test _squeeze_target_then with no size-1 dims."""
        from torch.utils._as_strided import _squeeze_target_then

        def simple_cb(result, size, stride, storage_offset):
            # Callback receives same target: (2, 5)
            self.assertEqual(size, (2, 5))
            self.assertEqual(stride, (5, 1))
            return result.view(2, 5)

        base = torch.arange(10)
        result = _squeeze_target_then(
            base, size=(2, 5), stride=(5, 1), storage_offset=0, cb=simple_cb
        )

        self.assertEqual(result.size(), (2, 5))
        self.assertEqual(result.stride(), (5, 1))

    def test_squeeze_target_then_returns_not_implemented(self) -> None:
        """Test _squeeze_target_then when callback returns NotImplemented."""
        from torch.utils._as_strided import _squeeze_target_then

        def failing_cb(result, size, stride, storage_offset):
            return NotImplemented

        base = torch.arange(10)
        result = _squeeze_target_then(
            base, size=(1, 5, 2), stride=(10, 1, 5), storage_offset=0, cb=failing_cb
        )

        self.assertIs(result, NotImplemented)

    def test_permute_target_then_basic(self) -> None:
        """Test _permute_target_then with simple permutation."""
        from torch.utils._as_strided import _permute_target_then

        def simple_cb(result, size, stride, storage_offset):
            # Callback receives sorted target: (2, 5) with strides (5, 1)
            self.assertEqual(tuple(size), (2, 5))
            self.assertEqual(tuple(stride), (5, 1))
            return result.view(2, 5)

        base = torch.arange(10)
        # Target is (5, 2) with strides (1, 5) - needs permutation
        result = _permute_target_then(
            base, size=(5, 2), stride=(1, 5), storage_offset=0, cb=simple_cb
        )

        self.assertEqual(result.size(), (5, 2))
        self.assertEqual(result.stride(), (1, 5))

    def test_permute_target_then_already_sorted(self) -> None:
        """Test _permute_target_then when target is already sorted."""
        from torch.utils._as_strided import _permute_target_then

        def simple_cb(result, size, stride, storage_offset):
            # Callback receives same target: (2, 5) with strides (5, 1)
            self.assertEqual(tuple(size), (2, 5))
            self.assertEqual(tuple(stride), (5, 1))
            return result.view(2, 5)

        base = torch.arange(10)
        # Target is already sorted
        result = _permute_target_then(
            base, size=(2, 5), stride=(5, 1), storage_offset=0, cb=simple_cb
        )

        self.assertEqual(result.size(), (2, 5))
        self.assertEqual(result.stride(), (5, 1))

    def test_permute_target_then_3d(self) -> None:
        """Test _permute_target_then with 3D tensor."""
        from torch.utils._as_strided import _permute_target_then

        def simple_cb(result, size, stride, storage_offset):
            # Callback receives sorted target: (2, 3, 4) with strides (12, 4, 1)
            self.assertEqual(tuple(size), (2, 3, 4))
            self.assertEqual(tuple(stride), (12, 4, 1))
            return result.view(2, 3, 4)

        base = torch.arange(24)
        # Target is (4, 2, 3) with strides (1, 12, 4) - needs permutation
        result = _permute_target_then(
            base, size=(4, 2, 3), stride=(1, 12, 4), storage_offset=0, cb=simple_cb
        )

        self.assertEqual(result.size(), (4, 2, 3))
        self.assertEqual(result.stride(), (1, 12, 4))

    def test_permute_target_then_returns_not_implemented(self) -> None:
        """Test _permute_target_then when callback returns NotImplemented."""
        from torch.utils._as_strided import _permute_target_then

        def failing_cb(result, size, stride, storage_offset):
            return NotImplemented

        base = torch.arange(10)
        result = _permute_target_then(
            base, size=(5, 2), stride=(1, 5), storage_offset=0, cb=failing_cb
        )

        self.assertIs(result, NotImplemented)

    # Tests for stride-0 dimensions (expand)
    def test_expand_simple(self) -> None:
        """Test simple expand: (5,) -> (3, 5) with stride (0, 1)"""
        self.check_as_strided_via_views(torch.arange(5), (3, 5), (0, 1))

    def test_expand_2d_source(self) -> None:
        """Test expand from 2D: (2, 5) -> (3, 2, 5)"""
        self.check_as_strided_via_views(
            torch.arange(10).view(2, 5), (3, 2, 5), (0, 5, 1)
        )

    def test_expand_middle_dim(self) -> None:
        """Test expand in middle: (2, 5) -> (2, 3, 5)"""
        self.check_as_strided_via_views(
            torch.arange(10).view(2, 5), (2, 3, 5), (5, 0, 1)
        )

    def test_expand_last_dim(self) -> None:
        """Test expand at end: (2, 5) -> (2, 5, 3)"""
        self.check_as_strided_via_views(
            torch.arange(10).view(2, 5), (2, 5, 3), (5, 1, 0)
        )

    def test_expand_multiple_dims(self) -> None:
        """Test expand on multiple dimensions"""
        self.check_as_strided_via_views(torch.arange(5), (3, 5, 4), (0, 1, 0))

    def test_expand_all_dims(self) -> None:
        """Test expand on all dimensions from single element"""
        self.check_as_strided_via_views(torch.tensor([42.0]), (3, 4, 5), (0, 0, 0))

    def test_expand_with_size_one(self) -> None:
        """Test expand combined with size-1 dims"""
        self.check_as_strided_via_views(torch.arange(5), (1, 3, 5), (15, 0, 1))

    def test_source_with_expand(self) -> None:
        """Test source tensor with stride-0 dimension"""
        base = torch.arange(5).expand(3, 5)
        self.check_as_strided_via_views(base, (3, 3), (0, 1))

    def test_source_and_target_with_expand(self) -> None:
        """Test both source and target with stride-0 dimensions"""
        base = torch.arange(5).expand(3, 5)
        self.check_as_strided_via_views(base, (3, 2, 5), (0, 0, 1))

    def test_eliminate_last_dim(self) -> None:
        """Test eliminating the last dimension (edge case: last target dim needs unflatten+squeeze)."""
        # This is the pattern from indexing: tensor[:, :, 0]
        # Source: (1, 888, 12) -> Target: (1, 888)
        base = torch.arange(1 * 888 * 12).view(1, 888, 12)
        self.check_as_strided_via_views(base, (1, 888), (10656, 12))

        # Another similar case: (10, 20, 30) -> (10, 20)
        base2 = torch.arange(10 * 20 * 30).view(10, 20, 30)
        self.check_as_strided_via_views(base2, (10, 20), (600, 30))

        # And with offset
        self.check_as_strided_via_views(base2, (5, 10), (600, 30), storage_offset=1800)

    def test_unexpand_target_then_basic(self) -> None:
        """Test _unexpand_target_then with single stride-0 dim."""
        from torch.utils._as_strided import _unexpand_target_then

        def simple_cb(result, size, stride, storage_offset):
            self.assertEqual(size, (5,))
            self.assertEqual(stride, (1,))
            return result.narrow(0, 0, 5)

        base = torch.arange(10)
        result = _unexpand_target_then(
            base, size=(3, 5), stride=(0, 1), storage_offset=0, cb=simple_cb
        )
        self.assertEqual(result.size(), (3, 5))
        self.assertEqual(result.stride(), (0, 1))

    def test_unexpand_target_then_multiple_stride_zero(self) -> None:
        """Test _unexpand_target_then with multiple stride-0 dims."""
        from torch.utils._as_strided import _unexpand_target_then

        def simple_cb(result, size, stride, storage_offset):
            self.assertEqual(size, (5,))
            self.assertEqual(stride, (1,))
            return result.narrow(0, 0, 5)

        base = torch.arange(10)
        result = _unexpand_target_then(
            base, size=(3, 5, 4), stride=(0, 1, 0), storage_offset=0, cb=simple_cb
        )
        self.assertEqual(result.size(), (3, 5, 4))
        self.assertEqual(result.stride(), (0, 1, 0))

    def test_unexpand_target_then_no_stride_zero(self) -> None:
        """Test _unexpand_target_then with no stride-0 dims."""
        from torch.utils._as_strided import _unexpand_target_then

        def simple_cb(result, size, stride, storage_offset):
            self.assertEqual(size, (2, 5))
            self.assertEqual(stride, (5, 1))
            return result.view(2, 5)

        base = torch.arange(10)
        result = _unexpand_target_then(
            base, size=(2, 5), stride=(5, 1), storage_offset=0, cb=simple_cb
        )
        self.assertEqual(result.size(), (2, 5))
        self.assertEqual(result.stride(), (5, 1))

    def test_unexpand_target_then_returns_not_implemented(self) -> None:
        """Test _unexpand_target_then when callback returns NotImplemented."""
        from torch.utils._as_strided import _unexpand_target_then

        def failing_cb(result, size, stride, storage_offset):
            return NotImplemented

        base = torch.arange(10)
        result = _unexpand_target_then(
            base, size=(3, 5), stride=(0, 1), storage_offset=0, cb=failing_cb
        )
        self.assertIs(result, NotImplemented)

    # Storage offset tests
    def test_storage_offset_simple_1d(self) -> None:
        """Test simple 1D tensor with offset."""
        self.check_as_strided_via_views(torch.arange(10), (5,), (1,), storage_offset=2)

    def test_storage_offset_2d_simple(self) -> None:
        """Test 2D tensor with offset."""
        self.check_as_strided_via_views(
            torch.arange(20), (2, 3), (5, 1), storage_offset=7
        )

    def test_storage_offset_with_permute(self) -> None:
        """Test offset with permuted dimensions."""
        self.check_as_strided_via_views(
            torch.arange(20), (3, 2), (1, 5), storage_offset=4
        )

    def test_storage_offset_with_unsqueeze(self) -> None:
        """Test offset with size-1 dimensions."""
        self.check_as_strided_via_views(
            torch.arange(10), (1, 5, 1), (10, 1, 1), storage_offset=3
        )

    def test_storage_offset_max_offset(self) -> None:
        """Test with maximum possible offset."""
        # For (2, 3) with stride (5, 1), max_extent = max(2*5, 3*1) = 10
        # max_offset = 20 - 10 = 10
        self.check_as_strided_via_views(
            torch.arange(20), (2, 3), (5, 1), storage_offset=10
        )

    def test_storage_offset_with_expand(self) -> None:
        """Test offset with stride-0 dimension."""
        self.check_as_strided_via_views(
            torch.arange(10), (3, 5), (0, 1), storage_offset=2
        )

    def test_storage_offset_impossible(self) -> None:
        """Test that impossible offset returns NotImplemented."""
        base = torch.arange(10)
        # Offset too large: would access indices beyond the base tensor
        result = as_strided_via_views(base, (5,), (1,), storage_offset=10)
        self.assertIs(result, NotImplemented)

    def test_storage_offset_negative(self) -> None:
        """Test that negative offset returns NotImplemented."""
        base = torch.arange(10)
        result = as_strided_via_views(base, (5,), (1,), storage_offset=-1)
        self.assertIs(result, NotImplemented)

    def test_storage_offset_within_flattened_dimension(self) -> None:
        """Test offset adjustment within a flattened dimension.

        This tests the case where:
        1. Input tensor gets flattened during canonicalization
        2. Target has larger stride that's a multiple of canonical stride
        3. Small offset needs to be consumed by unflattening and narrowing

        Example: (1, 888, 12) -> (1, 888) with offset 1
        This is equivalent to tensor[:, :, 1].
        """
        # Create tensor with shape (1, 888, 12)
        base = torch.arange(1 * 888 * 12).view(1, 888, 12)
        # Target: select element at index 1 along last dimension, remove that dimension
        # Result should be equivalent to base[:, :, 1]
        self.check_as_strided_via_views(base, (1, 888), (10656, 12), storage_offset=1)

    # Tests with multi-dimensional contiguous input tensors
    def test_2d_contiguous_source_simple(self) -> None:
        """Test with 2D contiguous input tensor."""
        base = torch.arange(24).view(4, 6)  # Contiguous 2D tensor
        # Simple permute
        self.check_as_strided_via_views(base, (6, 4), (1, 6))

    def test_2d_contiguous_source_narrow(self) -> None:
        """Test narrowing a 2D contiguous input tensor."""
        base = torch.arange(24).view(4, 6)
        # Narrow both dimensions
        self.check_as_strided_via_views(base, (2, 3), (6, 1))

    def test_2d_contiguous_source_unflatten(self) -> None:
        """Test unflattening a 2D contiguous input tensor."""
        base = torch.arange(24).view(4, 6)
        # Unflatten the second dimension
        self.check_as_strided_via_views(base, (4, 2, 3), (6, 3, 1))

    def test_2d_contiguous_source_flatten(self) -> None:
        """Test flattening a 2D contiguous input back to 1D."""
        base = torch.arange(24).view(4, 6)
        self.check_as_strided_via_views(base, (24,), (1,))

    def test_3d_contiguous_source_permute(self) -> None:
        """Test with 3D contiguous input tensor and permute."""
        base = torch.arange(60).view(3, 4, 5)  # Contiguous 3D tensor
        # Permute dimensions
        self.check_as_strided_via_views(base, (5, 3, 4), (1, 20, 5))

    def test_3d_contiguous_source_narrow_all(self) -> None:
        """Test narrowing all dimensions of 3D contiguous input."""
        base = torch.arange(60).view(3, 4, 5)
        # Narrow all three dimensions
        self.check_as_strided_via_views(base, (2, 2, 3), (20, 5, 1))

    # Tests with non-contiguous multi-dimensional input tensors
    def test_2d_transposed_source(self) -> None:
        """Test with transposed (non-contiguous) 2D input tensor."""
        base = (
            torch.arange(24).view(4, 6).t()
        )  # Non-contiguous: (6, 4) with strides (1, 6)
        # Simple narrow
        self.check_as_strided_via_views(base, (3, 2), (1, 6))

    def test_2d_transposed_source_permute_back(self) -> None:
        """Test permuting transposed tensor back to original layout."""
        base = torch.arange(24).view(4, 6).t()  # (6, 4) with strides (1, 6)
        # Permute back to (4, 6) with strides (6, 1)
        self.check_as_strided_via_views(base, (4, 6), (6, 1))

    def test_2d_non_contiguous_with_gap(self) -> None:
        """Test with 2D tensor where dims have a gap (non-overlapping but not contiguous)."""
        # Create tensor with stride gap: select every other row
        base = torch.arange(48).view(8, 6)[::2, :]  # (4, 6) with strides (12, 1)
        # Narrow and reshape
        self.check_as_strided_via_views(base, (2, 3), (12, 1))

    def test_2d_sliced_both_dims(self) -> None:
        """Test with 2D tensor sliced on both dimensions."""
        # Note: slicing creates storage offset in base, which implementation may not handle
        # So we create a properly contiguous base tensor instead
        base = torch.arange(20).view(5, 4)  # (5, 4) with strides (4, 1), offset 0
        # Test narrow
        self.check_as_strided_via_views(base, (3, 2), (4, 1))

    def test_3d_non_contiguous_permuted(self) -> None:
        """Test with 3D tensor that's been permuted (non-contiguous)."""
        base = (
            torch.arange(60).view(3, 4, 5).permute(2, 0, 1)
        )  # (5, 3, 4) with strides (1, 20, 5)
        # Narrow and unflatten
        self.check_as_strided_via_views(base, (3, 2, 2), (1, 20, 5))

    def test_3d_select_creates_2d(self) -> None:
        """Test with 2D tensor with non-contiguous dimensions (gap between strides)."""
        # Create a 2D tensor where dims are non-contiguous but non-overlapping
        # We'll create this by using as_strided to avoid source offset issues
        base = torch.as_strided(torch.arange(100), (3, 4), (20, 5))
        # Both dims are non-contiguous with each other (gap of 15 between positions)
        self.check_as_strided_via_views(base, (2, 2), (20, 5))

    def test_2d_double_strided(self) -> None:
        """Test with 2D tensor where both dimensions have stride gaps."""
        # Create tensor with gaps in both dimensions
        base = torch.arange(100).view(10, 10)[::2, ::3]  # (5, 4) with strides (20, 3)
        # The two dimensions are discontiguous with each other
        self.check_as_strided_via_views(base, (3, 2), (20, 3))

    # Storage offset tests with multi-dimensional discontiguous sources
    def test_storage_offset_2d_contiguous_source(self) -> None:
        """Test storage offset with 2D contiguous source."""
        base = torch.arange(30).view(5, 6)  # Contiguous (5, 6) with strides (6, 1)
        # Apply offset that spans into second row
        self.check_as_strided_via_views(base, (3, 4), (6, 1), storage_offset=8)

    def test_storage_offset_2d_transposed_source(self) -> None:
        """Test storage offset with 2D transposed (non-contiguous) source."""
        base = torch.arange(30).view(5, 6).t()  # (6, 5) with strides (1, 6)
        # Apply offset
        self.check_as_strided_via_views(base, (4, 3), (1, 6), storage_offset=7)

    def test_storage_offset_2d_discontiguous_both_dims(self) -> None:
        """Test storage offset where narrow needed on both discontiguous source dims."""
        # Create tensor where both dimensions are discontiguous
        base = torch.arange(100).view(10, 10)[::2, ::3]  # (5, 4) with strides (20, 3)
        # Offset = 23: This requires narrow on both dims to consume the offset
        # 23 = 1*20 + 1*3, so we need to narrow first dim by 1 and second dim by 1
        self.check_as_strided_via_views(base, (3, 2), (20, 3), storage_offset=23)

    def test_storage_offset_2d_discontiguous_larger_offset(self) -> None:
        """Test larger storage offset requiring narrows on multiple discontiguous dims."""
        base = torch.arange(100).view(10, 10)[::2, ::3]  # (5, 4) with strides (20, 3)
        # Offset = 46: 46 = 2*20 + 2*3, requires narrow on both dims
        self.check_as_strided_via_views(base, (2, 2), (20, 3), storage_offset=46)

    def test_storage_offset_3d_discontiguous_all_dims(self) -> None:
        """Test storage offset with 3D discontiguous tensor requiring narrows on multiple dims."""
        # Create 3D tensor with gaps in all dimensions
        base = torch.arange(1000).view(10, 10, 10)[
            ::2, ::3, ::5
        ]  # (5, 4, 2) with strides (200, 30, 5)
        # Offset = 30: just one step in the middle dimension
        # Requires narrow on discontiguous dimensions
        self.check_as_strided_via_views(
            base, (4, 2, 2), (200, 30, 5), storage_offset=30
        )

    def test_storage_offset_3d_discontiguous_complex(self) -> None:
        """Test complex offset distribution across 3D discontiguous dims."""
        base = torch.arange(1000).view(10, 10, 10)[
            ::3, ::2, ::4
        ]  # (4, 5, 3) with strides (300, 20, 4)
        # Offset = 24: 24 = 1*20 + 1*4
        # Requires narrow on the smaller stride dimensions
        self.check_as_strided_via_views(
            base, (2, 3, 2), (300, 20, 4), storage_offset=24
        )

    def test_storage_offset_2d_permuted_discontiguous(self) -> None:
        """Test storage offset where source dims are in non-descending stride order."""
        # Create non-contiguous tensor then permute so strides aren't descending
        base_raw = torch.arange(100).view(10, 10)[
            ::3, ::2
        ]  # (4, 5) with strides (30, 2)
        # This already has strides in descending order, so let's transpose it
        base = base_raw.t()  # (5, 4) with strides (2, 30)
        # Now strides are NOT in descending order (2 < 30)
        # Offset = 32: 32 = 1*30 + 1*2
        self.check_as_strided_via_views(base, (2, 2), (2, 30), storage_offset=32)

    def test_storage_offset_barely_fits(self) -> None:
        """Test storage offset with discontiguous source requiring narrows on both dims."""
        base = torch.arange(100).view(10, 10)[::3, ::4]  # (4, 3) with strides (30, 4)
        # Offset = 34 = 1*30 + 1*4, requires narrow on both dimensions
        # For target (2, 2), extent = 34 + (2-1)*30 + (2-1)*4 = 34 + 34 = 68 < 98 (max base index)
        self.check_as_strided_via_views(base, (2, 2), (30, 4), storage_offset=34)


if __name__ == "__main__":
    run_tests()
@ -405,6 +405,7 @@ isolate_fails_code_str = None
            # pyrefly: ignore [missing-attribute]
            kernel._fn_name
            if isinstance(kernel, JITFunction)
            # pyrefly: ignore # missing-attribute
            else kernel.fn._fn_name
        )
        fn_name = fn_name.split(".")[-1]
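The rest of this compare is mechanical: each hunk inserts a `# pyrefly: ignore` suppression directly above a line the pyrefly type checker flags. As a generic illustration (hedged: the object and attribute names below are made up, and the exact error codes depend on the checker version), a suppression annotates the line immediately beneath it:

# pyrefly: ignore  # missing-attribute
value = some_dynamic_object.attr_unknown_to_the_checker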
@ -264,6 +264,7 @@ def generate_ttir(
    assert isinstance(kernel, JITFunction)

    # pyrefly: ignore # missing-attribute
    context = triton._C.libtriton.ir.context()
    target = triton.runtime.driver.active.get_current_target()
    backend = triton.compiler.compiler.make_backend(target)
@ -305,6 +306,7 @@ def generate_ttir(
            base_tensor = torch.empty(
                [elements_per_dim] * len(block_shape), dtype=a.dtype
            )
            # pyrefly: ignore # bad-argument-type
            ordered_args[name] = TensorDescriptor.from_tensor(base_tensor, block_shape)
        elif isinstance(a, (FakeTensor, torch._inductor.ir.TensorBox)):
            with torch._C._DisableTorchDispatch():
@ -368,6 +370,7 @@ def generate_ttir(
        target = triton.runtime.driver.active.get_current_target()
        backend_ = triton.compiler.compiler.make_backend(target)
        # pyrefly: ignore # missing-attribute
        return backend_.get_attrs_descriptor(args, kernel.params)
    else:
        assert (
@ -384,6 +387,7 @@ def generate_ttir(
    except TypeError:  # Unknown arg `specialize_extra`
        # Older versions of Triton take specialize_extra as an arg to specialize_impl
        specialize_impl = functools.partial(
            # pyrefly: ignore # missing-argument
            triton.runtime.jit.create_specialize_impl(),
            specialize_extra=backend.get_arg_specialization,
        )
@ -468,6 +472,7 @@ def generate_ttir(
        if i not in constexprs
    }

    # pyrefly: ignore # missing-attribute
    triton._C.libtriton.ir.load_dialects(context)
    backend.load_dialects(context)
@ -477,22 +482,29 @@ def generate_ttir(
    # backward compatibility here.
    make_ir_sig_params = len(inspect.signature(src.make_ir).parameters)
    get_codegen_implementation_sig_params = len(
        # pyrefly: ignore # missing-attribute
        inspect.signature(backend.get_codegen_implementation).parameters
    )
    if make_ir_sig_params == 2:
        # pyrefly: ignore # missing-argument
        ttir_module = src.make_ir(options, context)
    elif make_ir_sig_params == 3:
        # pyrefly: ignore # missing-attribute
        codegen_fns = backend.get_codegen_implementation()
        # pyrefly: ignore # missing-argument
        ttir_module = src.make_ir(options, codegen_fns, context)
    elif make_ir_sig_params == 4:
        codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []
        # pyrefly: ignore # missing-attribute
        codegen_fns = backend.get_codegen_implementation(*codegen_args)
        module_map = backend.get_module_map()
        ttir_module = src.make_ir(options, codegen_fns, module_map, context)
    else:
        codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []
        # pyrefly: ignore # missing-attribute
        codegen_fns = backend.get_codegen_implementation(*codegen_args)
        module_map = backend.get_module_map()
        # pyrefly: ignore # bad-argument-count
        ttir_module = src.make_ir(target, options, codegen_fns, module_map, context)
    if not ttir_module.verify():
        raise RuntimeError("Verification for TTIR module has failed")
@ -1102,6 +1114,7 @@ def triton_kernel_wrapper_mutation_dense(
        from triton.tools.tensor_descriptor import TensorDescriptor

        block_shape = stable_meta[0]
        # pyrefly: ignore # bad-argument-type
        kwargs[k] = TensorDescriptor.from_tensor(tensor, block_shape)

    # move as many positional arguments from dicts to args as we
@ -1658,6 +1671,7 @@ class TritonHOPifier:
                    "Passing multiple @triton.autotune decorators is not supported. "
                    "Please use a single @triton.autotune decorator instead."
                )
            # pyrefly: ignore # missing-attribute
            iter_kernel = iter_kernel.fn

        # Process the @triton.heuristics decorator:
@ -1868,6 +1882,7 @@ class TritonHOPifier:
        # Both for grid's meta and for the kernel, we need args and kwargs
        # combined and normalized
        # pyrefly: ignore # missing-attribute
        combined_args_raw = {**dict(zip(variable.kernel.arg_names, args)), **kwargs}

        # precompute the grid for the kernel
@ -2061,6 +2076,7 @@ class TraceableTritonKernelWrapper:
        kernel_idx: Optional[int],
        grid: Optional["TritonGridType"],
    ) -> None:
        # pyrefly: ignore # bad-assignment
        self.kernel = None
        self.grid = None
        tracing_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid)
@ -4,14 +4,17 @@ import itertools
import logging
from typing import Any, Optional

from torch._C import DispatchKey
import torch
import torch.utils._pytree as pytree
-from torch._higher_order_ops.utils import reenter_make_fx
+from torch._higher_order_ops.utils import reenter_make_fx, redirect_to_mode
from torch._logging import warning_once
from torch._ops import HigherOrderOperator
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.types import _dtype
from torch.utils._debug_mode import DebugMode
from torch.utils.checkpoint import _CachedTorchDispatchMode, _CachingTorchDispatchMode


log = logging.getLogger(__name__)
@@ -41,6 +44,27 @@ class Wrap(HigherOrderOperator):
wrap = Wrap()


+class InductorCompiledCode(HigherOrderOperator):
+    def __init__(self) -> None:
+        super().__init__("inductor_compiled_code")
+
+    def __call__(self, func, *args, **kwargs):
+        return super().__call__(func, *args, **kwargs)
+
+
+inductor_compiled_code = InductorCompiledCode()
+inductor_compiled_code.fallthrough(DispatchKey.AutogradCPU)
+inductor_compiled_code.fallthrough(DispatchKey.AutogradCUDA)
+
+
+@inductor_compiled_code.py_impl(DispatchKey.CompositeExplicitAutograd)
+def inductor_compiled_code_impl(func, inputs):
+    return func(inputs)
+
+
+redirect_to_mode(inductor_compiled_code, DebugMode)
+redirect_to_mode(inductor_compiled_code, _CachingTorchDispatchMode)
+redirect_to_mode(inductor_compiled_code, _CachedTorchDispatchMode)
+
+
class WrapWithSetGradEnabled(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("wrap_with_set_grad_enabled")

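The new HOP gives TorchDispatchModes a single op to observe when Inductor runs
a compiled graph; redirect_to_mode presumably wires the op into the listed mode
classes (DebugMode and the two SAC caching modes). For intuition, a minimal
sketch of the mode side (hypothetical logger class, not part of this patch):

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class OpLogger(TorchDispatchMode):
        # Hypothetical observer: records each op that reaches __torch_dispatch__.
        def __init__(self):
            self.seen = []

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            self.seen.append(str(func))
            return func(*args, **(kwargs or {}))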
@@ -289,11 +289,15 @@ def user_defined_triton_kernel_transitive_closure_source_code(kernel) -> str:
            if isinstance(symbol, JITFunction):
                compile_wrapper.newline()
                compile_wrapper.writeline("@triton.jit")
+               # pyrefly: ignore # missing-attribute
                compile_wrapper.splice(symbol.src, strip=True)
                symbols_included.add(symbol_name)
                traverse(symbol)
            elif hasattr(triton, "constexpr_function") and isinstance(
-               symbol, triton.runtime.jit.ConstexprFunction
+               # pyrefly: ignore # missing-attribute
+               symbol,
+               # pyrefly: ignore # missing-attribute
+               triton.runtime.jit.ConstexprFunction,
            ):
                compile_wrapper.newline()
                compile_wrapper.writeline("@triton.constexpr_function")

@@ -949,7 +949,9 @@ class FxConverter:
        from triton.runtime import driver

        log.info("Autotuning Triton kernel %s at compile time.", kernel_name)
+       # pyrefly: ignore # missing-attribute
        device = driver.active.get_current_device()
+       # pyrefly: ignore # missing-attribute
        stream = driver.active.get_current_stream(device)

        def node_to_tuning_arg(arg: Any) -> Any:

@@ -6983,6 +6983,7 @@ class UserDefinedTritonKernel(ExternKernel):

        configs = kernel.configs
        kernel = kernel.fn
+       # pyrefly: ignore # bad-return
        return kernel, configs, restore_value_args, reset_to_zero_args

    @override
@@ -7138,7 +7139,10 @@ class UserDefinedTritonKernel(ExternKernel):
        self.mutable_args = [
            kernel_args[key]
            for key in identify_mutated_tensors(
-               kernel, {**kernel_args, **autotuned_kwargs}, tma_descriptor_metadata
+               # pyrefly: ignore # bad-argument-type
+               kernel,
+               {**kernel_args, **autotuned_kwargs},
+               tma_descriptor_metadata,
            )
        ]

@@ -52,6 +52,8 @@ from torch._inductor.utils import (
)
from torch.autograd.profiler import record_function
from torch.utils._ordered_set import OrderedSet
+from torch.utils._python_dispatch import is_in_torch_dispatch_mode
+from torch._higher_order_ops.wrap import inductor_compiled_code

from . import config
from .runtime.autotune_cache import AutotuneCacheBundler
@@ -616,8 +618,15 @@ class CompiledFxGraph(OutputCode):
                with record_function(
                    f"## Call CompiledFxGraph {self._fx_graph_cache_key} ##"
                ):
+                   if is_in_torch_dispatch_mode():
+                       return inductor_compiled_code(self.current_callable, inputs)
                    return self.current_callable(inputs)
            else:
                # TODO: optimize this some more
                # NB: Tensor dispatch is NOT supported
                # TODO: deal with boxed calling convention
+               if is_in_torch_dispatch_mode():
+                   return inductor_compiled_code(self.current_callable, inputs)
                return self.current_callable(inputs)
        finally:
            get_runtime_metrics_context().finish()

@@ -2555,11 +2555,14 @@ def get_device_tflops(dtype: torch.dtype) -> float:
        return get_max_simd_tflops(torch.float32, sm_clock)
    else:
        if dtype in (torch.float16, torch.bfloat16) and SM80OrLater:
+           # pyrefly: ignore # missing-argument
            return get_max_tensorcore_tflops(dtype)

        if torch.backends.cuda.matmul.allow_tf32:
+           # pyrefly: ignore # missing-argument
            return get_max_tensorcore_tflops(torch.float32)
        else:
+           # pyrefly: ignore # missing-argument
            return get_max_simd_tflops(torch.float32)

@@ -2573,6 +2576,7 @@ def get_gpu_dram_gbps() -> int:
def get_gpu_shared_memory() -> int:
    from triton.runtime import driver

+   # pyrefly: ignore # missing-attribute
    return driver.active.utils.get_device_properties(0).get("max_shared_mem", 0)

@@ -967,7 +967,7 @@ static PyObject* THPVariable_dtensor_new(
  Tensor tensor = make_tensor_for_subclass_helper(
      /*sym_sizes=*/tuple_to_symintlist(sizes.ptr()),
      /*sym_strides=*/tuple_to_symintlist(stride.ptr()),
-     /*sym_storage_offset=*/std::nullopt,
+     /*sym_storage_offset=*/local_tensor.sym_storage_offset(),
      options,
      /*storage_size=*/std::nullopt,
      extra_dispatch_keys);

@@ -9,6 +9,7 @@ import torch
import torch.distributed as dist
import torch.distributed.tensor._api as dtensor
import torch.distributed.tensor._random as random
+from torch._library.utils import fill_defaults
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpInfo, OpSchema, OutputSpecType
@@ -34,6 +35,37 @@ aten = torch.ops.aten
logger = logging.getLogger(__name__)


+def as_strided_handler(
+    op_call: torch._ops.OpOverload,
+    args: tuple[object, ...],
+    kwargs: dict[str, object],
+):
+    args, kwargs = fill_defaults(op_call._schema, args, kwargs)
+    assert not kwargs
+    tensor, size, stride, storage_offset = args
+    size = tuple(size)
+    stride = tuple(stride)
+    if storage_offset is None:
+        storage_offset = tensor.storage_offset()
+    if (
+        tensor.size() == size
+        and tensor.stride() == stride
+        and tensor.storage_offset() == storage_offset
+    ):
+        return torch.ops.aten.alias.default(tensor)
+
+    from torch.utils._as_strided import as_strided_via_views
+
+    r = as_strided_via_views(tensor, size, stride, storage_offset)
+    if r is NotImplemented:
+        raise RuntimeError(
+            "as_strided not supported with DTensor for these inputs: "
+            f"{tuple(tensor.size())}:{tensor.stride()} offset {tensor.storage_offset()} to "
+            f"{size}:{stride} offset {storage_offset}"
+        )
+    return r
+
+
def is_same_size_handler(
    op_call: torch._ops.OpOverload,
    args: tuple[object, ...],
@@ -121,6 +153,7 @@ class OpDispatcher:
            aten.convolution.default: convolution_handler,
            aten.convolution_backward.default: convolution_backward_handler,
            aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler,
+           aten.as_strided.default: as_strided_handler,
        }

        # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor)
@@ -84,6 +84,7 @@ register_op_strategy(
        aten.clone.default,
        aten.contiguous.default,
        aten.detach.default,
+       aten.alias.default,
        aten.fill_.Scalar,
        aten.view.dtype,
        aten.zero_.default,

@@ -1302,17 +1302,28 @@ def bsr_dense_addmm(
        # pyrefly: ignore [unsupported-operation]
        _bsr_strided_addmm_kernel[grid](
            *ptr_stride_extractor(*sliced_tensors),
+           # pyrefly: ignore # bad-argument-count
            beta,
            alpha,
+           # pyrefly: ignore # bad-keyword-argument, bad-argument-type
            beta_is_one=beta == 1,
+           # pyrefly: ignore # bad-keyword-argument, bad-argument-type
            beta_is_nonzero=beta != 0,
+           # pyrefly: ignore # bad-keyword-argument, bad-argument-type
            alpha_is_one=alpha == 1,
+           # pyrefly: ignore # bad-keyword-argument, bad-argument-type
            left_alpha_is_one=left_alpha_is_one,
+           # pyrefly: ignore # bad-keyword-argument, bad-argument-type
            right_alpha_is_one=right_alpha_is_one,
+           # pyrefly: ignore # bad-keyword-argument, bad-argument-type
            BLOCKSIZE_ROW=BM,
+           # pyrefly: ignore # bad-keyword-argument, bad-argument-type
            BLOCKSIZE_INNER=BK,
+           # pyrefly: ignore # bad-keyword-argument
            BLOCKSIZE_COL=BN,
+           # pyrefly: ignore # bad-keyword-argument
            allow_tf32=dot_out_dtype == tl.float32,
+           # pyrefly: ignore # bad-keyword-argument, bad-argument-type
            acc_dtype=dot_out_dtype,
            **meta,
        )
@@ -1633,12 +1644,17 @@ if has_triton():
            beta,
            is_beta_zero,
            *blocksize,
+           # pyrefly: ignore # bad-argument-count
            k,
            tile_k,
            *ptr_stride_extractor(*sliced_tensors),
+           # pyrefly: ignore # bad-keyword-argument, bad-argument-type
            acc_dtype=acc_dtype,
+           # pyrefly: ignore # bad-keyword-argument, bad-argument-type
            allow_tf32=allow_tf32,
+           # pyrefly: ignore # unexpected-keyword
            num_stages=1,
+           # pyrefly: ignore # unexpected-keyword
            num_warps=4,
        )

@@ -1923,6 +1939,7 @@ if has_triton():
    def kernel(grid, *sliced_tensors):
        _bsr_softmax_kernel[grid](
            *ptr_stride_extractor(*sliced_tensors),
+           # pyrefly: ignore # bad-argument-count
            row_block,
            col_block,
            max_row_nnz,
@@ -2096,8 +2113,11 @@ if has_triton():
        if "allow_tf32" not in meta:
            meta.update(allow_tf32=dot_out_dtype == tl.float32)
        _scatter_mm2_kernel[grid](
+           # pyrefly: ignore # bad-argument-type
            M,
+           # pyrefly: ignore # bad-argument-type
            K,
+           # pyrefly: ignore # bad-argument-type
            N,
            blocks,
            blocks.stride(0),
@@ -2116,7 +2136,9 @@ if has_triton():
            pq_indices,
            pq_indices.stride(0),
            pq_indices.stride(1),
+           # pyrefly: ignore # bad-argument-type
            dot_out_dtype=dot_out_dtype,
+           # pyrefly: ignore # bad-argument-type
            **meta,
        )

@@ -2299,6 +2321,7 @@ if has_triton():
        _scatter_mm6_kernel[grid](
            B,
            Ms,
+           # pyrefly: ignore # bad-argument-type
            Ks,
            N,
            blocks,
@@ -2317,6 +2340,7 @@ if has_triton():
            r_offsets,
            p_offsets,
            q_offsets,
+           # pyrefly: ignore # bad-argument-type
            dot_out_dtype=dot_out_dtype,
            **meta,
        )

torch/utils/_as_strided.py  529 lines  (new file)
@@ -0,0 +1,529 @@
"""Utility for reconstructing a sequence of view operations that is equivalent
to an as_strided call (change a source tensor's size/stride/offset to a
target size/stride/offset).

Given an arbitrary source tensor, imagine that a user has applied a
sequence of view operations (slice, permute and view) producing a target
tensor. With only the source and the target tensor, construct a sequence
of views that takes the source tensor to the target tensor.

This problem is trivial with as_strided, but often, it can be helpful to
construct a sequence of "normal" view operations instead of brute forcing.
The reason is that as_strided is extremely flexible, making it difficult for
backends to implement: for example, using it you can generate views with
strange overlap patterns (e.g., rolling windows) or generate an output view
which is out of bounds for the original view. If you are implementing a
tensor subclass, it may be possible to implement simple views but not
as_strided (e.g., DTensor sharding propagation). These utilities help you
reconstruct view operations in some situations where it is possible.

We apply some simplifying assumptions for this algorithm:

1. The stride of a size-1 dimension doesn't matter. Actually, sometimes it
   does (e.g., in the case of inferring memory layout), but our longer term
   plan for addressing this is to add a special unsqueeze operation which will
   let us specify exactly what stride an unsqueezed dimension should get.

2. We assume the input tensor is non-overlapping (except for stride-0
   dimensions), i.e., that for every physical location there is a unique
   coordinate that maps to it. (This implies the destination tensor is
   non-overlapping except for stride-0 dimensions, as we do not include
   unfold in the list of valid view operations.)

3. We only account for these view operations:

   - view; or equivalently, these four operations:
     - squeeze
     - unsqueeze
     - unflatten
     - flatten (output view only)
   - permute (this subsumes transpose, movedim, etc.)
   - narrow (slice with step=1 only; no support for nontrivial step)
   - expand

# Definition of view operations

To start, it's helpful to review how these view operations affect the
size/stride of a tensor. In general, these operations (aside from permute and
flatten) only affect a single source dimension, so we will only discuss the
action on a single dimension. We will use (size):(stride) (CuTe notation) to
compactly denote sizes and strides. It can be helpful to study the rules when
a = 1, as this variable is unconstrained in all the formulas.

The important operations are these three:

- unflatten_2d(dim, x, y)  # 2D case is sufficient to implement ND case

  (x * y,):(a,) --> (x, y):(y * a, a)

- narrow(dim, start, length)

  (x,):(a,) --> (length,):(a,)

  where storage_offset += start * a

  NB: slice(dim, start, stop, step) with nontrivial step is not currently
  supported, though it could be added in the future.

- squeeze(dim)

  (1,):(a,) --> ():()

These two are simply inverses:

- unsqueeze(dim)

  (x,):(a,) --> (1, x):(x * a, a)

- flatten_2d(dim)  # flatten dim and dim+1; sufficient for ND case

  (x, y):(y * a, a) --> (x * y,):(a,)

And permute reorders the sizes-strides without otherwise making modifications.
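For example, starting from (6,):(2,): unflatten_2d(0, 2, 3) gives
(2, 3):(6, 2), while narrow(0, 1, 4) instead gives (4,):(2,) with
storage_offset increased by 1 * 2 = 2.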

# The algorithm

First, take the source tensor and put it into canonical form by:

1. Removing all size-1 and stride-0 dimensions: For each stride-0 dimension,
   narrow it to size 1, then squeeze all size-1 dimensions at once.
2. Permuting the tensor so the strides are in strictly descending order.
   This is always possible as all size-1 dimensions are removed and we
   required the input to be non-overlapping.
3. Flattening all contiguous dims, so that every pair of adjacent
   dims is non-contiguous with each other (cannot be flattened).
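For example, torch.empty(2, 3, 4).permute(1, 0, 2) has layout
(3, 2, 4):(4, 12, 1); sorting the strides gives (2, 3, 4):(12, 4, 1), and
flattening the two contiguous pairs collapses it to (24,):(1,).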

For any non size-1/stride-0 dimension in the target tensor, we can uniquely
assign it to the only canonical source dim that could have created it.
Specifically, for any dimension (x,):(a,), any target stride s such that
a <= s <= (x-1) * a could only have been generated from this source dim.
(x*a is excluded as it can only be generated by a size-1 dim! And it is
important to be precise on the bound here because it is not guaranteed that
on a canonical source tensor (x, y):(a, b) we have y * b <= a.)
So our plan is to process each canonical source dim one-by-one and translate
them into the target dimensions.

For each canonical dimension, we may have multiple target dimensions that map
to it. We process these target dimensions left to right (largest stride first).

For each target dimension (except the last one from this canonical dim):
1. We unflatten to create the target dimension and a "remainder" dimension.
   The key insight: unflatten(x*y, (x, y)) gives strides (y*a, a) where a
   is the canonical stride. To get the target stride s for the first dim,
   we need y*a = s, so we choose y = s/a.
2. This creates a dimension with the correct stride. We then move to the next
   target dimension, which will be carved out of the "remainder" dimension.
3. If the needed size (target_size * y) exceeds the available size, we return
   NotImplemented.

For the last target dimension from this canonical dim:
1. Its stride must match the canonical stride, otherwise we return NotImplemented.
2. Narrow to the final target size.
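Concretely, taking canonical (24,):(1,) to target (3, 4):(4, 1): the first
target dim has s = 4, so y = 4 / 1 = 4; narrow to 3 * 4 = 12 elements and
unflatten to (3, 4):(4, 1). The last target dim's stride already equals the
canonical stride, so we finish by narrowing it to size 4.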

Limitations:
- Slice operations with nontrivial step (e.g., [::2]) are not currently
  supported. This means some valid view sequences cannot be reconstructed.
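- Overlapping (unfold-style) targets are rejected; e.g., asking for
  (3, 3):(1, 1) from a (5,):(1,) source returns NotImplemented.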
"""
|
||||
|
||||
import functools
|
||||
import types
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _canonicalize_tensor(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Canonicalize a tensor by:
|
||||
1. Narrowing stride-0 dimensions to size 1
|
||||
2. Squeezing all size-1 dimensions
|
||||
3. Permuting so strides are in descending order
|
||||
4. Flattening contiguous dimensions
|
||||
|
||||
Returns a tensor with strictly descending strides where no two adjacent
|
||||
dims are contiguous with each other, and every dim has size > 1.
|
||||
"""
|
||||
result = tensor
|
||||
|
||||
# Narrow stride-0 dimensions to size 1
|
||||
for dim in reversed(range(result.ndim)):
|
||||
if result.stride(dim) == 0:
|
||||
result = result.narrow(dim, 0, 1)
|
||||
|
||||
# Squeeze all size-1 dimensions
|
||||
# TODO: don't squeeze if unnecessary
|
||||
result = result.squeeze()
|
||||
|
||||
# Permute so strides are in descending order
|
||||
stride_order = sorted(
|
||||
range(result.ndim), key=lambda i: result.stride(i), reverse=True
|
||||
)
|
||||
if stride_order != list(range(result.ndim)):
|
||||
result = result.permute(stride_order)
|
||||
|
||||
# Flatten contiguous dimensions
|
||||
# TODO: do it in one go rather than lots of flatten calls
|
||||
i = 0
|
||||
while i < result.ndim - 1:
|
||||
if result.stride(i) == result.size(i + 1) * result.stride(i + 1):
|
||||
result = result.flatten(i, i + 1)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return result
|
||||
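
# For example (illustrative): a fresh (3, 1, 4) tensor has layout
# (3, 1, 4):(4, 4, 1); squeezing gives (3, 4):(4, 1), and flattening the
# contiguous pair yields the canonical (12,):(1,).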


def _equal_significant_size_strides(lsize, lstride, rsize, rstride):
    if lsize != rsize:
        return False
    if any(s == 0 for s in lsize):
        return True
    for s, l, r in zip(lsize, lstride, rstride):
        if s == 1:
            continue
        if l != r:
            return False
    return True


def _unexpand_target_then(result: torch.Tensor, size, stride, storage_offset, cb):
    """
    Remove stride-0 dimensions from target, call cb, then expand them back.

    This handles expand operations which create stride-0 dimensions.
    For example: (5,):(1,) -> expand -> (3, 5):(0, 1)
    """
    # Identify stride-0 dimensions (from expand)
    stride0_dims = tuple(i for i, s in enumerate(stride) if s == 0)
    stride0_sizes = tuple(size[i] for i in stride0_dims)

    # Remove stride-0 dimensions
    new_size = tuple(s for i, s in enumerate(size) if i not in stride0_dims)
    new_stride = tuple(s for i, s in enumerate(stride) if i not in stride0_dims)

    result = cb(result, new_size, new_stride, storage_offset)
    if result is NotImplemented:
        return NotImplemented
    assert (
        _equal_significant_size_strides(
            result.size(), result.stride(), new_size, new_stride
        )
        and result.storage_offset() == storage_offset
    )

    # Expand the stride-0 dimensions back
    for i, orig_size in zip(stride0_dims, stride0_sizes):
        # Because we expand left-to-right, result's dims get shifted so we
        # need to unsqueeze at position i then expand
        result = torch.unsqueeze(result, i)
        # Expand the newly unsqueezed dimension to the original size
        expand_size = list(result.size())
        expand_size[i] = orig_size
        result = result.expand(expand_size)

    assert (
        _equal_significant_size_strides(result.size(), result.stride(), size, stride)
        and result.storage_offset() == storage_offset
    )
    return result


def _squeeze_target_then(result: torch.Tensor, size, stride, storage_offset, cb):
    dims = tuple(i for i, s in enumerate(size) if s == 1)
    new_size = tuple(s for i, s in enumerate(size) if i not in dims)
    new_stride = tuple(s for i, s in enumerate(stride) if i not in dims)
    result = cb(result, new_size, new_stride, storage_offset)
    if result is NotImplemented:
        return NotImplemented
    assert (
        _equal_significant_size_strides(
            result.size(), result.stride(), new_size, new_stride
        )
        and result.storage_offset() == storage_offset
    )
    for i in dims:
        # Because we unsqueeze left-to-right result's dims get shifted so we
        # put dims in the right spot as we go
        result = torch.unsqueeze(result, i)
    assert (
        _equal_significant_size_strides(result.size(), result.stride(), size, stride)
        and result.storage_offset() == storage_offset
    )
    return result


def _permute_target_then(result: torch.Tensor, size, stride, storage_offset, cb):
    perm, sorted_stride = zip(
        *sorted(enumerate(stride), key=lambda x: x[1], reverse=True)
    )
    sorted_size = tuple(size[i] for i in perm)

    result = cb(result, sorted_size, sorted_stride, storage_offset)
    if result is NotImplemented:
        return NotImplemented
    assert (
        _equal_significant_size_strides(
            result.size(), result.stride(), sorted_size, sorted_stride
        )
        and result.storage_offset() == storage_offset
    )

    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    result = result.permute(inv)
    assert (
        _equal_significant_size_strides(result.size(), result.stride(), size, stride)
        and result.storage_offset() == storage_offset
    )
    return result


def _process_canonical_dims(
    result: torch.Tensor,
    size,
    stride,
    storage_offset,
) -> torch.Tensor | types.NotImplementedType:
    """
    Process canonical source dimensions and map them to target dimensions.

    For each canonical dimension (x,):(a,), any target stride s such that
    a <= s <= (x-1) * a could only have been generated from this source dim.

    Since the target stride is sorted in descending order (after _permute_target_then),
    we can process each canonical dim and consume target dims as we go.

    Offset is consumed opportunistically during narrow operations.
    """
    # Read canonical sizes/strides from result (already canonicalized)
    canonical_sizes = result.size()
    canonical_strides = result.stride()

    # Compute remaining offset to consume
    remaining_offset = storage_offset - result.storage_offset()

    # Target dims are already sorted by stride descending (via _permute_target_then)
    # and size-1/stride-0 dims have been removed (via _squeeze_target_then)
    target_dims = [(i, size[i], stride[i]) for i in range(len(size))]

    # Pointer to current position in target_dims
    target_idx = 0
    # Current dimension in result that we're working on
    current_dim = 0

    for canonical_dim in range(len(canonical_sizes)):
        canonical_size = canonical_sizes[canonical_dim]
        canonical_stride = canonical_strides[canonical_dim]

        # Collect all target dims that map to this canonical dim
        # For dimension (x,):(a,), target stride s must satisfy: a <= s <= (x-1) * a
        min_stride = canonical_stride
        max_stride = (result.size(current_dim) - 1) * canonical_stride

        target_dims_for_this_canonical = []
        while target_idx < len(target_dims):
            original_pos, target_size, target_stride = target_dims[target_idx]
            if min_stride <= target_stride <= max_stride:
                target_dims_for_this_canonical.append(
                    (original_pos, target_size, target_stride)
                )
                target_idx += 1
            else:
                # Since target strides are in descending order, once we find one that
                # doesn't match, we won't find any more for this canonical dim
                break

        if not target_dims_for_this_canonical:
            # No target dims map to this canonical dim, skip it
            # (This dimension gets "lost" in the transformation)
            continue

        # Process target_dims left to right (largest stride first, already sorted)
        # Offset is consumed opportunistically during each narrow operation
        # Track which result dimension we're working on within this inner loop
        work_dim = current_dim
        for i, (original_pos, target_size, target_stride) in enumerate(
            target_dims_for_this_canonical
        ):
            is_last = i == len(target_dims_for_this_canonical) - 1

            if not is_last:
                # unflatten(a*b, (a, b)) gives strides (b*canonical_stride, canonical_stride)
                # To get first dim stride = target_stride: b * canonical_stride = target_stride
                # So b = target_stride / canonical_stride
                next_dim_size = target_stride // canonical_stride

                needed_size = target_size * next_dim_size

                # Opportunistically consume offset when narrowing for unflatten
                narrow_start = 0
                if remaining_offset > 0:
                    offset_in_dim = remaining_offset // canonical_stride
                    if offset_in_dim < result.size(work_dim):
                        narrow_start = offset_in_dim
                        remaining_offset -= offset_in_dim * canonical_stride

                # Check if we have enough size
                if result.size(work_dim) < needed_size + narrow_start:
                    return NotImplemented

                # Narrow and unflatten
                result = result.narrow(work_dim, narrow_start, needed_size)

                result = result.unflatten(work_dim, (target_size, next_dim_size))
                # After unflatten, next target dim works on the remainder at work_dim + 1
                work_dim += 1

            else:
                # Last target dim
                if target_stride == canonical_stride:
                    # Simple case: stride matches, just narrow to size
                    # Opportunistically consume offset here
                    narrow_start = 0
                    if remaining_offset > 0:
                        offset_in_dim = remaining_offset // canonical_stride
                        if offset_in_dim < result.size(work_dim):
                            narrow_start = offset_in_dim
                            remaining_offset -= offset_in_dim * canonical_stride

                    # Check if we have enough size
                    if result.size(work_dim) < target_size + narrow_start:
                        return NotImplemented

                    # Narrow to target size from the offset position
                    result = result.narrow(work_dim, narrow_start, target_size)
                elif (
                    target_stride > canonical_stride
                    and target_stride % canonical_stride == 0
                ):
                    # Edge case: need to unflatten and squeeze away trailing dimensions
                    # unflatten(a*b, (a, b)) gives strides (b*canonical_stride, canonical_stride)
                    # We want first dim stride = target_stride, so b = target_stride / canonical_stride
                    remainder_size = target_stride // canonical_stride
                    needed_size = target_size * remainder_size

                    # Check if we have enough size
                    if result.size(work_dim) < needed_size:
                        return NotImplemented

                    # Narrow if there's excess size
                    if result.size(work_dim) != needed_size:
                        result = result.narrow(work_dim, 0, needed_size)

                    # Unflatten to create target dimension and remainder
                    result = result.unflatten(work_dim, (target_size, remainder_size))

                    # The remainder dimension (at work_dim + 1) needs to be eliminated
                    # Opportunistically consume any remaining offset here
                    narrow_start = 0
                    if remaining_offset > 0 and canonical_stride > 0:
                        # Can we consume offset by selecting a different position in the remainder?
                        offset_in_remainder = remaining_offset // canonical_stride
                        if offset_in_remainder < remainder_size:
                            narrow_start = offset_in_remainder
                            remaining_offset -= offset_in_remainder * canonical_stride

                    result = result.narrow(work_dim + 1, narrow_start, 1)
                    result = result.squeeze(work_dim + 1)
                else:
                    return NotImplemented

        current_dim += 1

    # Check if all offset was consumed
    if remaining_offset != 0:
        return NotImplemented

    return result


def as_strided_via_views(
    tensor: torch.Tensor,
    size: tuple[int, ...],
    stride: tuple[int, ...],
    storage_offset: int = 0,
) -> torch.Tensor | types.NotImplementedType:
    """
    Attempt to reconstruct a sequence of view operations that would produce
    the target size/stride/offset from the source tensor.

    Args:
        tensor: Source tensor
        size: Target size
        stride: Target stride
        storage_offset: Target storage offset (default: 0)

    Returns:
        A view of the tensor with the target size/stride/offset, or NotImplemented
        if it cannot be achieved via simple view operations.
    """
    import math

    # NB: When strides are insignificant, we don't reproduce them exactly, as
    # this would make the algorithm a lot more complicated for no reason

    target_numel = math.prod(size)

    # Easy case 1: source tensor has numel==0
    if tensor.numel() == 0:
        if target_numel != 0:
            return NotImplemented
        return tensor.view(size)

    # Easy case 2: source tensor has numel==1
    if tensor.numel() == 1:
        if storage_offset != tensor.storage_offset():
            return NotImplemented
        if target_numel == 0:
            # If tensor is 0D scalar, we need to unsqueeze first to get a dimension to narrow
            if tensor.ndim == 0:
                tensor = tensor.unsqueeze(0)
            tensor = tensor.narrow(0, 0, 0)
            return tensor.view(size)
        if target_numel == 1:
            # Simple case: numel 1 -> numel 1 (no expand needed)
            return tensor.view(size)
        # target_numel > 1: For numel==1 source, target is valid only if
        # all dimensions with size > 1 have stride 0 (can be expanded)
        requires_expand = any(sz > 1 and st != 0 for sz, st in zip(size, stride))
        if requires_expand:
            return NotImplemented
        # Build target by unsqueezing and expanding
        result = tensor.view(1)  # Flatten to (1,)
        # Unsqueeze to match target ndim
        for _ in range(len(size) - 1):
            result = result.unsqueeze(0)
        # Now expand to target size (this creates stride-0 for expanded dims)
        result = result.expand(size)
        return result

    # Step 1: Canonicalize source tensor
    result = _canonicalize_tensor(tensor)

    # Step 2: Calculate offset to consume
    # We'll consume this opportunistically during narrow operations in _process_canonical_dims
    offset_delta = storage_offset - result.storage_offset()
    if offset_delta < 0:
        return NotImplemented

    # We're not going to take this straight to the target size/stride.
    # Instead we're going to first go to a "canonicalized" target which is in
    # sorted order, has no size-1 dims, no stride-0 dims (unlike source, we
    # don't have to flatten contiguous dimensions; the algorithm below can
    # handle it). We need to compute this canonicalized target, and we should
    # also record enough information so we can invert this transformation (so
    # that we can apply the inverse on result to get to our final, desired
    # result). So _unexpand_target_then/_squeeze_target_then/_permute_target_then
    # take care of modifying the target and inverting it, and _process_canonical_dims
    # is the main payload.

    return _unexpand_target_then(
        result,
        size,
        stride,
        storage_offset,
        cb=functools.partial(
            _squeeze_target_then,
            cb=functools.partial(_permute_target_then, cb=_process_canonical_dims),
        ),
    )
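A quick sanity check of the new helper (illustrative; values chosen so the
reconstruction succeeds, and assuming this patch is applied):

    import torch
    from torch.utils._as_strided import as_strided_via_views

    x = torch.arange(12).reshape(3, 4)              # (3, 4):(4, 1)
    y = as_strided_via_views(x, (4, 3), (1, 4), 0)
    assert torch.equal(y, x.t())                    # a transpose is a pure view

    # Overlapping (unfold-style) targets are rejected:
    z = as_strided_via_views(torch.arange(5), (3, 3), (1, 1), 0)
    assert z is NotImplemented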
@@ -1300,6 +1300,10 @@ SAC_IGNORED_OPS = {


class _CachingTorchDispatchMode(TorchDispatchMode):
+    @classmethod
+    def ignore_compile_internals(cls):
+        return True
+
    # Used together with _CachedTorchDispatchMode to implement SAC.
    def __init__(self, policy_fn, storage):
        self.policy_fn = policy_fn
@@ -1336,6 +1340,10 @@ class _CachingTorchDispatchMode(TorchDispatchMode):
        return out


class _CachedTorchDispatchMode(TorchDispatchMode):
+    @classmethod
+    def ignore_compile_internals(cls):
+        return True
+
    # Used together with _CachedTorchDispatchMode to implement SAC.
    def __init__(self, policy_fn, storage, allow_cache_entry_mutation):
        self.policy_fn = policy_fn