```
// The torch::stable::Tensor class is a high-level C++ header-only wrapper around
// the C shim Tensor APIs. We've modeled this class after TensorBase, as custom
// op kernels only really need to interact with Tensor metadata (think sizes,
// strides, device, dtype). Other functions on Tensor (like empty_like) should
// live like the ATen op that they are and exist outside of this struct.
//
// There are several goals of this class over AtenTensorHandle and
// RAIIAtenTensorHandle:
// 1. torch::stable::Tensor is a nicer UX much closer to torch::Tensor than the
//    C APIs with AtenTensorHandle. Under the hood we still call into these C shim
//    APIs to preserve stability.
// 2. RAIIAtenTensorHandle boils down to a unique_ptr that forces the user to pass
//    around ownership. This makes it difficult to pass one input into 2
//    different functions, e.g., doing something like c = a(t) + b(t) for
//    stable::Tensor t. Thus, we use a shared_ptr here.
```

This PR:
- implements the above
- adds test cases in libtorch_agnostic to make sure the file actually works
- includes the results of a battle with template specialization

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155367
Approved by: https://github.com/albanD
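For illustration, a minimal self-contained C++ sketch of the two goals above. It assumes a toy C-style API: `FakeTensorImpl`, the `fake_shim_*` functions, `check`, and `sketch::Tensor` are all hypothetical stand-ins, not the real C shim functions or the actual `torch::stable::Tensor` implementation. The wrapper exposes only metadata (`dim()`, `size()`), forwards every call through an error-code-returning C-style function, and stores the handle in a `std::shared_ptr` with a custom deleter, so `a(t) + b(t)` can copy `t` into both calls.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <vector>

namespace sketch {

// ---- Hypothetical C-style shim (stand-in for the real C shim APIs) ----
struct FakeTensorImpl {
  std::vector<int64_t> sizes;
};
using FakeTensorHandle = FakeTensorImpl*;

// Each call returns 0 on success and a nonzero error code otherwise.
inline int fake_shim_new_tensor(const int64_t* sizes, int64_t ndim,
                                FakeTensorHandle* out) {
  if (out == nullptr || ndim < 0) return 1;
  *out = new FakeTensorImpl{std::vector<int64_t>(sizes, sizes + ndim)};
  return 0;
}
inline int fake_shim_get_dim(FakeTensorHandle t, int64_t* out) {
  if (t == nullptr) return 1;
  *out = static_cast<int64_t>(t->sizes.size());
  return 0;
}
inline int fake_shim_get_size(FakeTensorHandle t, int64_t d, int64_t* out) {
  if (t == nullptr || d < 0 || d >= static_cast<int64_t>(t->sizes.size())) {
    return 1;
  }
  *out = t->sizes[static_cast<std::size_t>(d)];
  return 0;
}
inline int fake_shim_delete_tensor(FakeTensorHandle t) {
  delete t;  // deleting nullptr is a no-op
  return 0;
}

// Turns C-style error codes into C++ exceptions.
inline void check(int err) {
  if (err != 0) throw std::runtime_error("shim call failed");
}

// ---- Header-only C++ wrapper exposing only metadata ----
class Tensor {
 public:
  // Takes ownership of a raw handle. The deleter calls back into the shim
  // exactly once, when the last Tensor sharing the handle is destroyed.
  explicit Tensor(FakeTensorHandle h)
      : handle_(h, [](FakeTensorHandle p) { fake_shim_delete_tensor(p); }) {}

  int64_t dim() const {
    int64_t d = 0;
    check(fake_shim_get_dim(handle_.get(), &d));
    return d;
  }
  int64_t size(int64_t i) const {
    int64_t s = 0;
    check(fake_shim_get_size(handle_.get(), i, &s));
    return s;
  }
  FakeTensorHandle get() const { return handle_.get(); }
  long use_count() const { return handle_.use_count(); }

 private:
  std::shared_ptr<FakeTensorImpl> handle_;
};

// Two independent "kernels" that take the tensor by value. With a move-only
// handle (e.g. std::unique_ptr), t would have to be std::move'd into a(),
// leaving it unusable for b().
inline int64_t a(Tensor t) { return t.dim(); }
inline int64_t b(Tensor t) { return t.size(0); }

inline int64_t demo() {
  const int64_t sizes[] = {2, 3};
  FakeTensorHandle raw = nullptr;
  check(fake_shim_new_tensor(sizes, 2, &raw));
  Tensor t(raw);
  {
    Tensor alias = t;  // copy shares the same handle; nothing is re-allocated
    assert(alias.get() == t.get());
    assert(t.use_count() == 2);
  }  // alias is destroyed here, but the handle stays alive
  assert(t.use_count() == 1);
  return a(t) + b(t);  // both calls copy t: dim() 2 + size(0) 2
}

}  // namespace sketch
```

Copying a `sketch::Tensor` only updates a reference count; the shim's delete function runs once, when the last copy goes away. That is the trade-off the message describes: shared ownership instead of the explicit ownership transfer a `std::unique_ptr`-style handle forces on every caller.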