Update torch::stable::Tensor() default constructor (#159507)

Allows things like

```cpp
Tensor cu_seqlens_q;
if (...) {
   cu_seqlens_q = ...
}
...
```

Also adds `torch::stable::Tensor.defined()`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159507
Approved by: https://github.com/janeyx99
This commit is contained in:
Mikayla Gawarecki
2025-08-12 17:17:47 +00:00
committed by PyTorch MergeBot
parent f27232a213
commit 655137b678
6 changed files with 89 additions and 3 deletions

View File

@ -320,3 +320,38 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
m.impl("my_zero_", &boxed_my_zero_);
}
bool test_default_constructor(bool defined) {
Tensor out;
if (defined) {
AtenTensorHandle defined_ath;
int64_t sizes[] = {2, 3};
int64_t strides[] = {3, 1};
aoti_torch_empty_strided(
2,
sizes,
strides,
aoti_torch_dtype_float32(),
aoti_torch_device_type_cpu(),
0,
&defined_ath);
out = Tensor(defined_ath);
}
return out.defined();
}
void boxed_test_default_constructor(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
bool res = test_default_constructor(to<bool>(stack[0]));
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("test_default_constructor(bool undefined) -> bool");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("test_default_constructor", &boxed_test_default_constructor);
}

View File

@ -164,3 +164,15 @@ def fill_infinity(t) -> Tensor:
Returns: The modified tensor (same as input)
"""
return torch.ops.libtorch_agnostic.fill_infinity.default(t)
def test_default_constructor(defined) -> bool:
"""
Tests the default constructor for torch::stable::Tensor.
Args:
defined: bool - if True, tests defined tensor; if False, tests undefined tensor
Returns: bool - result of calling .defined() on the tensor
"""
return torch.ops.libtorch_agnostic.test_default_constructor.default(defined)

View File

@ -218,6 +218,20 @@ if not IS_WINDOWS:
expected = torch.full_like(t, math.inf)
self.assertEqual(out, expected)
@onlyCPU
def test_default_constructor(self):
import libtorch_agnostic
defined_tensor_is_defined = libtorch_agnostic.ops.test_default_constructor(
True
)
self.assertTrue(defined_tensor_is_defined)
undefined_tensor_is_defined = (
libtorch_agnostic.ops.test_default_constructor(False)
)
self.assertFalse(undefined_tensor_is_defined)
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
if __name__ == "__main__":

View File

@ -227,6 +227,9 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset(
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_is_contiguous(AtenTensorHandle tensor, bool* ret_is_contiguous);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_is_defined(AtenTensorHandle tensor, bool* ret_is_defined);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_new_tensor_handle(
AtenTensorHandle orig_handle,
AtenTensorHandle* new_handle);

View File

@ -402,6 +402,15 @@ AOTITorchError aoti_torch_is_contiguous(
});
}
AOTITorchError aoti_torch_is_defined(
AtenTensorHandle tensor,
bool* ret_is_defined) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
*ret_is_defined = t->defined();
});
}
AOTITorchError aoti_torch_new_tensor_handle(
AtenTensorHandle orig_handle,
AtenTensorHandle* new_handle) {
@ -1204,8 +1213,7 @@ void aoti_torch_print_tensor_handle(AtenTensorHandle self, const char* msg) {
if (msg) {
std::cout << " " << msg;
}
std::cout << " "
<< "]:" << '\n';
std::cout << " " << "]:" << '\n';
// Print exact tensor values for small size tensors
const int64_t numel = t->numel();

View File

@ -29,7 +29,15 @@ class Tensor {
std::shared_ptr<AtenTensorOpaque> ath_;
public:
Tensor() = delete;
// Construct a stable::Tensor with an uninitialized AtenTensorHandle (ATH)
// Steals ownership from the ATH
Tensor() {
AtenTensorHandle ret;
TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&ret));
ath_ = std::shared_ptr<AtenTensorOpaque>(ret, [](AtenTensorHandle ath) {
TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath));
});
}
// Construct a stable::Tensor from an AtenTensorHandle (ATH)
// Steals ownership from the ATH
@ -115,6 +123,12 @@ class Tensor {
return size;
}
bool defined() const {
bool defined;
TORCH_ERROR_CODE_CHECK(aoti_torch_is_defined(ath_.get(), &defined));
return defined;
}
// =============================================================================
// END of C-shimified TensorBase APIs
// =============================================================================