mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Update torch::stable::Tensor() default constructor (#159507)
Allows things like
```cpp
Tensor cu_seqlens_q;
if (...) {
   cu_seqlens_q = ...
}
...
```
Also adds `torch::stable::Tensor.defined()`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159507
Approved by: https://github.com/janeyx99
			
			
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							f27232a213
						
					
				
				
					commit
					655137b678
				
			| @ -320,3 +320,38 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { | ||||
| STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) { | ||||
|   m.impl("my_zero_", &boxed_my_zero_); | ||||
| } | ||||
|  | ||||
| bool test_default_constructor(bool defined) { | ||||
|   Tensor out; | ||||
|   if (defined) { | ||||
|     AtenTensorHandle defined_ath; | ||||
|     int64_t sizes[] = {2, 3}; | ||||
|     int64_t strides[] = {3, 1}; | ||||
|     aoti_torch_empty_strided( | ||||
|         2, | ||||
|         sizes, | ||||
|         strides, | ||||
|         aoti_torch_dtype_float32(), | ||||
|         aoti_torch_device_type_cpu(), | ||||
|         0, | ||||
|         &defined_ath); | ||||
|     out = Tensor(defined_ath); | ||||
|   } | ||||
|   return out.defined(); | ||||
| } | ||||
|  | ||||
| void boxed_test_default_constructor( | ||||
|     StableIValue* stack, | ||||
|     uint64_t num_args, | ||||
|     uint64_t num_outputs) { | ||||
|   bool res = test_default_constructor(to<bool>(stack[0])); | ||||
|   stack[0] = from(res); | ||||
| } | ||||
|  | ||||
| STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { | ||||
|   m.def("test_default_constructor(bool undefined) -> bool"); | ||||
| } | ||||
|  | ||||
| STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { | ||||
|   m.impl("test_default_constructor", &boxed_test_default_constructor); | ||||
| } | ||||
|  | ||||
| @ -164,3 +164,15 @@ def fill_infinity(t) -> Tensor: | ||||
|     Returns: The modified tensor (same as input) | ||||
|     """ | ||||
|     return torch.ops.libtorch_agnostic.fill_infinity.default(t) | ||||
|  | ||||
|  | ||||
| def test_default_constructor(defined) -> bool: | ||||
|     """ | ||||
|     Tests the default constructor for torch::stable::Tensor. | ||||
|  | ||||
|     Args: | ||||
|         defined: bool - if True, tests defined tensor; if False, tests undefined tensor | ||||
|  | ||||
|     Returns: bool - result of calling .defined() on the tensor | ||||
|     """ | ||||
|     return torch.ops.libtorch_agnostic.test_default_constructor.default(defined) | ||||
|  | ||||
| @ -218,6 +218,20 @@ if not IS_WINDOWS: | ||||
|             expected = torch.full_like(t, math.inf) | ||||
|             self.assertEqual(out, expected) | ||||
|  | ||||
|         @onlyCPU | ||||
|         def test_default_constructor(self): | ||||
|             import libtorch_agnostic | ||||
|  | ||||
|             defined_tensor_is_defined = libtorch_agnostic.ops.test_default_constructor( | ||||
|                 True | ||||
|             ) | ||||
|             self.assertTrue(defined_tensor_is_defined) | ||||
|  | ||||
|             undefined_tensor_is_defined = ( | ||||
|                 libtorch_agnostic.ops.test_default_constructor(False) | ||||
|             ) | ||||
|             self.assertFalse(undefined_tensor_is_defined) | ||||
|  | ||||
|     instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
| @ -227,6 +227,9 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset( | ||||
| AOTI_TORCH_EXPORT AOTITorchError | ||||
| aoti_torch_is_contiguous(AtenTensorHandle tensor, bool* ret_is_contiguous); | ||||
|  | ||||
| AOTI_TORCH_EXPORT AOTITorchError | ||||
| aoti_torch_is_defined(AtenTensorHandle tensor, bool* ret_is_defined); | ||||
|  | ||||
| AOTI_TORCH_EXPORT AOTITorchError aoti_torch_new_tensor_handle( | ||||
|     AtenTensorHandle orig_handle, | ||||
|     AtenTensorHandle* new_handle); | ||||
|  | ||||
| @ -402,6 +402,15 @@ AOTITorchError aoti_torch_is_contiguous( | ||||
|   }); | ||||
| } | ||||
|  | ||||
| AOTITorchError aoti_torch_is_defined( | ||||
|     AtenTensorHandle tensor, | ||||
|     bool* ret_is_defined) { | ||||
|   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ | ||||
|     at::Tensor* t = tensor_handle_to_tensor_pointer(tensor); | ||||
|     *ret_is_defined = t->defined(); | ||||
|   }); | ||||
| } | ||||
|  | ||||
| AOTITorchError aoti_torch_new_tensor_handle( | ||||
|     AtenTensorHandle orig_handle, | ||||
|     AtenTensorHandle* new_handle) { | ||||
| @ -1204,8 +1213,7 @@ void aoti_torch_print_tensor_handle(AtenTensorHandle self, const char* msg) { | ||||
|   if (msg) { | ||||
|     std::cout << "  " << msg; | ||||
|   } | ||||
|   std::cout << "  " | ||||
|             << "]:" << '\n'; | ||||
|   std::cout << "  " << "]:" << '\n'; | ||||
|  | ||||
|   // Print exact tensor values for small size tensors | ||||
|   const int64_t numel = t->numel(); | ||||
|  | ||||
| @ -29,7 +29,15 @@ class Tensor { | ||||
|   std::shared_ptr<AtenTensorOpaque> ath_; | ||||
|  | ||||
|  public: | ||||
|   Tensor() = delete; | ||||
|   // Construct a stable::Tensor with an uninitialized AtenTensorHandle (ATH) | ||||
|   // Steals ownership from the ATH | ||||
|   Tensor() { | ||||
|     AtenTensorHandle ret; | ||||
|     TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&ret)); | ||||
|     ath_ = std::shared_ptr<AtenTensorOpaque>(ret, [](AtenTensorHandle ath) { | ||||
|       TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath)); | ||||
|     }); | ||||
|   } | ||||
|  | ||||
|   // Construct a stable::Tensor from an AtenTensorHandle (ATH) | ||||
|   // Steals ownership from the ATH | ||||
| @ -115,6 +123,12 @@ class Tensor { | ||||
|     return size; | ||||
|   } | ||||
|  | ||||
|   bool defined() const { | ||||
|     bool defined; | ||||
|     TORCH_ERROR_CODE_CHECK(aoti_torch_is_defined(ath_.get(), &defined)); | ||||
|     return defined; | ||||
|   } | ||||
|  | ||||
|   // ============================================================================= | ||||
|   // END of C-shimified TensorBase APIs | ||||
|   // ============================================================================= | ||||
|  | ||||
		Reference in New Issue
	
	Block a user