[14/N] Use std::optional (#133417)

Follows #132527
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133417
Approved by: https://github.com/ezyang
This commit is contained in:
cyy
2024-08-16 00:48:34 +00:00
committed by PyTorch MergeBot
parent d9576c9440
commit 8f7cf796ea
5 changed files with 11 additions and 11 deletions

View File

@ -71,7 +71,7 @@ std::pair<SafePyObject, SafePyObject> SavedTensorDefaultHooks::pop_hooks() {
return hooks;
}
c10::optional<std::pair<SafePyObject, SafePyObject>> SavedTensorDefaultHooks::get_hooks() {
std::optional<std::pair<SafePyObject, SafePyObject>> SavedTensorDefaultHooks::get_hooks() {
// For tls.is_tracing, see NOTE: [Deferring tensor pack/unpack hooks until runtime]
if (!is_initialized || tls.stack.empty() || tls.is_tracing) {
return c10::nullopt;

View File

@ -128,7 +128,7 @@ void quantize_tensor_per_tensor_affine_privateuse1(
}
int64_t _fused_sdp_choice_privateuse1(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value,
const c10::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, c10::optional<double> scale, bool enable_gqa){
const std::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa){
auto backend = sdp::SDPBackend::overrideable;
return static_cast<int64_t>(backend);
}

View File

@ -7467,7 +7467,7 @@ for shape in [(1,), ()]:
out = a[:, indices]
self.assertEqual(
out.grad_fn._saved_indices, (None, indices)
) # c10::List<c10::optional<Tensor>> -> Tuple[Tensor?]
) # c10::List<std::optional<Tensor>> -> Tuple[Tensor?]
self.assertIsInstance(out.grad_fn._saved_indices[1], torch.Tensor)
self.assertIsInstance(
out.grad_fn._raw_saved_indices[1], torch._C._autograd.SavedTensor
@ -7497,24 +7497,24 @@ for shape in [(1,), ()]:
out = torch.nn.functional.interpolate(a, 4, mode="linear")
self.assertEqual(
out.grad_fn._saved_output_size, (4,)
) # c10::optional<IntArrayRef> -> int[]?
) # std::optional<IntArrayRef> -> int[]?
self.assertIsInstance(out.grad_fn._saved_output_size[0], int)
self.assertEqual(out.grad_fn._saved_align_corners, False) # bool -> bool
self.assertIsInstance(out.grad_fn._saved_align_corners, bool)
if hasattr(out.grad_fn, "_saved_scale_factors"):
self.assertIsNone(
out.grad_fn._saved_scale_factors
) # c10::optional<ArrayRef<double>> -> float[]?
) # std::optional<ArrayRef<double>> -> float[]?
else:
self.assertIsNone(
out.grad_fn._saved_scales
) # c10::optional<ArrayRef<double>> -> float[]?
) # std::optional<ArrayRef<double>> -> float[]?
a = torch.ones(1, 1, 3, 3, requires_grad=True)
out = nn.Conv2d(1, 1, 3)(a)
self.assertEqual(
out.grad_fn._saved_bias_sym_sizes_opt, (1,)
) # c10::optional<SymIntArrayRef> -> SymInt[]?
) # std::optional<SymIntArrayRef> -> SymInt[]?
out = nn.Conv2d(1, 1, 3, bias=False)(a)
# TODO: This is BAD! we converted a std::nullopt into a (0,)
self.assertEqual(out.grad_fn._saved_bias_sym_sizes_opt, (0,))
@ -7554,11 +7554,11 @@ for shape in [(1,), ()]:
out = torch.div(a, 2.0, rounding_mode="trunc")
self.assertEqual(
out.grad_fn._saved_rounding_mode, "trunc"
) # c10::optional<std::string> -> str?
) # std::optional<std::string> -> str?
out = torch.div(a, 2.0, rounding_mode=None)
self.assertIsNone(
out.grad_fn._saved_rounding_mode
) # c10::optional<std::string> -> str?
) # std::optional<std::string> -> str?
x = torch.zeros(5, requires_grad=True)
out = torch.threshold(x, threshold=(1 + 0j), value=(1 + 0j))

View File

@ -2840,7 +2840,7 @@
dummy: non_differentiable
- name: _nested_get_values(Tensor(a) self) -> Tensor(a)
self: "_nested_view_from_jagged(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_lengths(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? c10::optional<Tensor>(at::_nested_get_min_seqlen(self)) : ::std::nullopt, at::_nested_get_max_seqlen(self).defined() ? c10::optional<Tensor>(at::_nested_get_max_seqlen(self)) : ::std::nullopt)"
self: "_nested_view_from_jagged(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_lengths(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? std::optional<Tensor>(at::_nested_get_min_seqlen(self)) : ::std::nullopt, at::_nested_get_max_seqlen(self).defined() ? std::optional<Tensor>(at::_nested_get_max_seqlen(self)) : ::std::nullopt)"
# Transformer
- name: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor

View File

@ -75,7 +75,7 @@ std::vector<at::Tensor> unpack_tensors(
} else if (
*ivalue_arg.real_type() ==
*c10::getTypePtr<std::optional<at::Tensor>>()) {
// ivalue is c10::optional<at::Tensor>
// ivalue is std::optional<at::Tensor>
unpack_optional_tensor_ivalue(ivalue, device, inputs);
}
}