mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Update base for Update on "[WIP] Support python slicing with tensor inputs."
allow things like ``` #!/usr/bin/env python import torch print("="*60) print("Testing tensor slicing with torch.compile") print("="*60) # Test 1: Simple eager mode print("\n1. Eager mode test:") x = torch.randn(10) idx = torch.tensor(4) result = x[:idx] print(f" x[:idx] where idx=4: result.shape = {result.shape}") assert result.shape[0] == 4 print(" ✓ Eager mode works!") # Test 2: With torch.compile print("\n2. Compiled mode test:") def slice_fn(x, idx): return x[:idx] try: compiled_fn = torch.compile(slice_fn) x = torch.randn(10) idx = torch.tensor(4) result = compiled_fn(x, idx) print(f" Compiled x[:idx] where idx=4: result.shape = {result.shape}") assert result.shape[0] == 4 print(" ✓ Compiled mode works!") except Exception as e: print(f" ✗ Compiled mode failed: {e}") import traceback traceback.print_exc() # Test 3: With dynamic slicing from sum print("\n3. Dynamic slicing with sum:") def dynamic_slice_fn(x, lengths): idx = lengths.sum() return x[:idx] try: compiled_fn = torch.compile(dynamic_slice_fn) x = torch.randn(10) lengths = torch.tensor([1, 1, 1, 1]) result = compiled_fn(x, lengths) print(f" Compiled x[:lengths.sum()] where sum=4: result.shape = {result.shape}") assert result.shape[0] == 4 print(" ✓ Dynamic slicing works!") except Exception as e: print(f" ✗ Dynamic slicing failed: {e}") import traceback traceback.print_exc() print("\n" + "="*60) print("SUMMARY: Check results above") print("="*60) ``` [ghstack-poisoned]
This commit is contained in: