mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Update torch::stable::Tensor() default constructor (#159507)
Allows things like ```cpp Tensor cu_seqlens_q; if (...) { cu_seqlens_q = ... } ... ``` Also adds `torch::stable::Tensor.defined()` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159507 Approved by: https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
f27232a213
commit
655137b678
@ -320,3 +320,38 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
|
||||
m.impl("my_zero_", &boxed_my_zero_);
|
||||
}
|
||||
|
||||
bool test_default_constructor(bool defined) {
|
||||
Tensor out;
|
||||
if (defined) {
|
||||
AtenTensorHandle defined_ath;
|
||||
int64_t sizes[] = {2, 3};
|
||||
int64_t strides[] = {3, 1};
|
||||
aoti_torch_empty_strided(
|
||||
2,
|
||||
sizes,
|
||||
strides,
|
||||
aoti_torch_dtype_float32(),
|
||||
aoti_torch_device_type_cpu(),
|
||||
0,
|
||||
&defined_ath);
|
||||
out = Tensor(defined_ath);
|
||||
}
|
||||
return out.defined();
|
||||
}
|
||||
|
||||
void boxed_test_default_constructor(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
bool res = test_default_constructor(to<bool>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("test_default_constructor(bool undefined) -> bool");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("test_default_constructor", &boxed_test_default_constructor);
|
||||
}
|
||||
|
@ -164,3 +164,15 @@ def fill_infinity(t) -> Tensor:
|
||||
Returns: The modified tensor (same as input)
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.fill_infinity.default(t)
|
||||
|
||||
|
||||
def test_default_constructor(defined) -> bool:
|
||||
"""
|
||||
Tests the default constructor for torch::stable::Tensor.
|
||||
|
||||
Args:
|
||||
defined: bool - if True, tests defined tensor; if False, tests undefined tensor
|
||||
|
||||
Returns: bool - result of calling .defined() on the tensor
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.test_default_constructor.default(defined)
|
||||
|
@ -218,6 +218,20 @@ if not IS_WINDOWS:
|
||||
expected = torch.full_like(t, math.inf)
|
||||
self.assertEqual(out, expected)
|
||||
|
||||
@onlyCPU
|
||||
def test_default_constructor(self):
|
||||
import libtorch_agnostic
|
||||
|
||||
defined_tensor_is_defined = libtorch_agnostic.ops.test_default_constructor(
|
||||
True
|
||||
)
|
||||
self.assertTrue(defined_tensor_is_defined)
|
||||
|
||||
undefined_tensor_is_defined = (
|
||||
libtorch_agnostic.ops.test_default_constructor(False)
|
||||
)
|
||||
self.assertFalse(undefined_tensor_is_defined)
|
||||
|
||||
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -227,6 +227,9 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset(
|
||||
AOTI_TORCH_EXPORT AOTITorchError
|
||||
aoti_torch_is_contiguous(AtenTensorHandle tensor, bool* ret_is_contiguous);
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError
|
||||
aoti_torch_is_defined(AtenTensorHandle tensor, bool* ret_is_defined);
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_new_tensor_handle(
|
||||
AtenTensorHandle orig_handle,
|
||||
AtenTensorHandle* new_handle);
|
||||
|
@ -402,6 +402,15 @@ AOTITorchError aoti_torch_is_contiguous(
|
||||
});
|
||||
}
|
||||
|
||||
AOTITorchError aoti_torch_is_defined(
|
||||
AtenTensorHandle tensor,
|
||||
bool* ret_is_defined) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
|
||||
*ret_is_defined = t->defined();
|
||||
});
|
||||
}
|
||||
|
||||
AOTITorchError aoti_torch_new_tensor_handle(
|
||||
AtenTensorHandle orig_handle,
|
||||
AtenTensorHandle* new_handle) {
|
||||
@ -1204,8 +1213,7 @@ void aoti_torch_print_tensor_handle(AtenTensorHandle self, const char* msg) {
|
||||
if (msg) {
|
||||
std::cout << " " << msg;
|
||||
}
|
||||
std::cout << " "
|
||||
<< "]:" << '\n';
|
||||
std::cout << " " << "]:" << '\n';
|
||||
|
||||
// Print exact tensor values for small size tensors
|
||||
const int64_t numel = t->numel();
|
||||
|
@ -29,7 +29,15 @@ class Tensor {
|
||||
std::shared_ptr<AtenTensorOpaque> ath_;
|
||||
|
||||
public:
|
||||
Tensor() = delete;
|
||||
// Construct a stable::Tensor with an uninitialized AtenTensorHandle (ATH)
|
||||
// Steals ownership from the ATH
|
||||
Tensor() {
|
||||
AtenTensorHandle ret;
|
||||
TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&ret));
|
||||
ath_ = std::shared_ptr<AtenTensorOpaque>(ret, [](AtenTensorHandle ath) {
|
||||
TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath));
|
||||
});
|
||||
}
|
||||
|
||||
// Construct a stable::Tensor from an AtenTensorHandle (ATH)
|
||||
// Steals ownership from the ATH
|
||||
@ -115,6 +123,12 @@ class Tensor {
|
||||
return size;
|
||||
}
|
||||
|
||||
bool defined() const {
|
||||
bool defined;
|
||||
TORCH_ERROR_CODE_CHECK(aoti_torch_is_defined(ath_.get(), &defined));
|
||||
return defined;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// END of C-shimified TensorBase APIs
|
||||
// =============================================================================
|
||||
|
Reference in New Issue
Block a user