Summary:
Device mismatches during tracing can usually be ignored: they are logical mismatches, not physical ones. An intermediate computation in the traced program never materializes in a compiled binary execution, so a device mismatch in the middle of the program is not real. The runtime never materializes those tensors on the CPU device during execution, because they are temporary allocations.

If users know that their tensors at graph input are all on the correct device, they can ignore these tracing errors. Users who know what they are doing should have an escape hatch to ignore any device mismatch in tracing. Users can set

```
torch._functorch.config.fake_tensor_prefer_device_type = 'mtia'
```

to forcefully override any mismatch and prefer the non-CPU device. This unblocks vLLM graph mode for MTIA.

Test Plan:
Added two unit tests.

Rollback Plan:

Differential Revision: D79698438

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159931
Approved by: https://github.com/jansel
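
Illustration (not part of the change itself): the override described in the summary can be pictured with a toy model in plain Python. This is only a sketch under stated assumptions. The helper `resolve_device` and its `prefer_device_type` argument are hypothetical names invented for this note, devices are represented as plain strings, and none of it is PyTorch's actual fake-tensor code.

```python
from typing import Iterable, Optional


def resolve_device(devices: Iterable[str], prefer_device_type: Optional[str] = None) -> str:
    """Toy model of choosing one device for an op's output during tracing.

    Hypothetical helper for illustration only; it is not PyTorch's code.
    Devices are plain strings such as "cpu" or "mtia:0".
    """
    unique = list(dict.fromkeys(devices))  # de-duplicate, keep first-seen order
    if not unique:
        raise ValueError("no devices given")
    if len(unique) == 1:
        return unique[0]  # all inputs agree, so there is no mismatch
    if prefer_device_type is not None:
        # Escape hatch: on a mismatch, pick the first device of the preferred type.
        for dev in unique:
            if dev.split(":")[0] == prefer_device_type:
                return dev
    # No preference set (or no input on the preferred type): report the mismatch.
    raise RuntimeError(f"device mismatch during tracing: {unique}")
```

In this model, `resolve_device(["cpu", "mtia:0"])` raises, while passing `prefer_device_type="mtia"` resolves the same inputs to `"mtia:0"`, which mirrors the "prefer the non-CPU device" behavior described in the summary.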