mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Handle DDE in infer_size_impl (#163822)
hit this while running VLLM with unbacked for model Qwen/Qwen2-1.5B-Instruct Pull Request resolved: https://github.com/pytorch/pytorch/pull/163822 Approved by: https://github.com/bobrenjc93, https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
1cc9263f52
commit
dfcab0e7e1
@ -3889,6 +3889,20 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
|
||||
x = torch.rand(10)
|
||||
f(x, 4, 4096, 3920)
|
||||
|
||||
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_unbacked_reshape3(self):
|
||||
def func(x):
|
||||
x = x.as_strided([x.size()[0], 1536], [2048, 1])
|
||||
result1 = x.view(x.size()[0], -1, 128)
|
||||
return result1 * 10
|
||||
|
||||
compiled = torch.compile(fullgraph=True, backend="inductor")(func)
|
||||
x = torch.randn(10, 2048)
|
||||
|
||||
torch._dynamo.decorators.mark_unbacked(x, 0)
|
||||
self.assertEqual(func(x), compiled(x))
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestUnbacked)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user