From b87d3ddd0d991eaf1a47b714aa66536834bc9c2c Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Sun, 19 Oct 2025 17:27:06 -0700 Subject: [PATCH] Update base for Update on "[WIP] Support python slicing with tensor inputs." MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit allow things like ``` #!/usr/bin/env python import torch print("="*60) print("Testing tensor slicing with torch.compile") print("="*60) # Test 1: Simple eager mode print("\n1. Eager mode test:") x = torch.randn(10) idx = torch.tensor(4) result = x[:idx] print(f" x[:idx] where idx=4: result.shape = {result.shape}") assert result.shape[0] == 4 print(" ✓ Eager mode works!") # Test 2: With torch.compile print("\n2. Compiled mode test:") def slice_fn(x, idx): return x[:idx] try: compiled_fn = torch.compile(slice_fn) x = torch.randn(10) idx = torch.tensor(4) result = compiled_fn(x, idx) print(f" Compiled x[:idx] where idx=4: result.shape = {result.shape}") assert result.shape[0] == 4 print(" ✓ Compiled mode works!") except Exception as e: print(f" ✗ Compiled mode failed: {e}") import traceback traceback.print_exc() # Test 3: With dynamic slicing from sum print("\n3. Dynamic slicing with sum:") def dynamic_slice_fn(x, lengths): idx = lengths.sum() return x[:idx] try: compiled_fn = torch.compile(dynamic_slice_fn) x = torch.randn(10) lengths = torch.tensor([1, 1, 1, 1]) result = compiled_fn(x, lengths) print(f" Compiled x[:lengths.sum()] where sum=4: result.shape = {result.shape}") assert result.shape[0] == 4 print(" ✓ Dynamic slicing works!") except Exception as e: print(f" ✗ Dynamic slicing failed: {e}") import traceback traceback.print_exc() print("\n" + "="*60) print("SUMMARY: Check results above") print("="*60) ``` cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela [ghstack-poisoned]