[TorchGen] Use std::optional in generated code (#121454)

This PR changes TorchGen to generate std::optional.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121454
Approved by: https://github.com/ezyang
This commit is contained in:
cyy
2024-03-29 14:11:09 +00:00
committed by PyTorch MergeBot
parent 375a8041ed
commit fb90b4d4b2
21 changed files with 202 additions and 192 deletions

View File

@ -484,8 +484,8 @@ c10::optional<Tensor> to_functional_tensor(const c10::optional<Tensor>& tensor)
}
return c10::nullopt;
}
c10::List<c10::optional<Tensor>> to_functional_tensor(const c10::List<c10::optional<Tensor>>& t_list) {
c10::List<c10::optional<Tensor>> outputs;
c10::List<::std::optional<Tensor>> to_functional_tensor(const c10::List<::std::optional<Tensor>>& t_list) {
c10::List<::std::optional<Tensor>> outputs;
outputs.reserve(t_list.size());
for (const auto i : c10::irange(t_list.size())) {
outputs.push_back(to_functional_tensor(t_list[i]));
@ -536,8 +536,8 @@ std::vector<Tensor> from_functional_tensor(ITensorListRef t_list) {
}
return outputs;
}
c10::List<c10::optional<Tensor>> from_functional_tensor(const c10::List<c10::optional<Tensor>>& t_list) {
c10::List<c10::optional<Tensor>> outputs;
c10::List<::std::optional<Tensor>> from_functional_tensor(const c10::List<::std::optional<Tensor>>& t_list) {
c10::List<::std::optional<Tensor>> outputs;
outputs.reserve(t_list.size());
for (const auto i : c10::irange(t_list.size())) {
outputs.push_back(from_functional_tensor(t_list[i], /*assert_functional=*/false));
@ -572,7 +572,7 @@ void sync(ITensorListRef t_list) {
sync(t);
}
}
void sync(const c10::List<c10::optional<Tensor>>& t_list) {
void sync(const c10::List<::std::optional<Tensor>>& t_list) {
for (const auto i : c10::irange(t_list.size())) {
sync(t_list[i]);
}
@ -652,7 +652,7 @@ bool isFunctionalTensor(const c10::optional<Tensor>& t) {
}
}
bool isFunctionalTensor(const c10::List<c10::optional<Tensor>>& t_list) {
bool isFunctionalTensor(const c10::List<::std::optional<Tensor>>& t_list) {
if (t_list.empty()) return false;
auto functional_count = 0;
for (const auto i : c10::irange(t_list.size())) {

View File

@ -317,10 +317,10 @@ static inline void recordTensorIndex(
(*dim_ptr)++;
};
static inline c10::List<c10::optional<Tensor>> typeConvertIndices(
static inline c10::List<::std::optional<Tensor>> typeConvertIndices(
const Tensor& /*self*/,
std::vector<Tensor>&& indices) {
c10::List<c10::optional<Tensor>> converted_inds;
c10::List<::std::optional<Tensor>> converted_inds;
converted_inds.reserve(indices.size());
for (auto&& i : std::move(indices)) {
converted_inds.push_back(std::move(i));

View File

@ -1154,15 +1154,15 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
"(int[]? a) -> int[]?");
// Test list of optional (with empty list)
testArgTypes<c10::List<c10::optional<int64_t>>>::test(
c10::List<c10::optional<int64_t>>(c10::List<c10::optional<int64_t>>({})), [] (const c10::List<c10::optional<int64_t>>& v) {EXPECT_EQ(0, v.size());},
c10::List<c10::optional<int64_t>>(c10::List<c10::optional<int64_t>>({})), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<c10::optional<int64_t>>>().size());},
testArgTypes<c10::List<::std::optional<int64_t>>>::test(
c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({})), [] (const c10::List<::std::optional<int64_t>>& v) {EXPECT_EQ(0, v.size());},
c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({})), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<::std::optional<int64_t>>>().size());},
"(int?[] a) -> int?[]");
// Test list of optional (with values)
testArgTypes<c10::List<c10::optional<int64_t>>>::test(
c10::List<c10::optional<int64_t>>(c10::List<c10::optional<int64_t>>({3, c10::nullopt, 2})), [] (const c10::List<c10::optional<int64_t>>& v) {expectListEquals<c10::optional<int64_t>>({3, c10::nullopt, 2}, v);},
c10::List<c10::optional<int64_t>>(c10::List<c10::optional<int64_t>>({3, c10::nullopt, 2})), [] (const IValue& v) {expectListEquals<c10::optional<int64_t>>({3, c10::nullopt, 2}, v.to<c10::List<c10::optional<int64_t>>>());},
testArgTypes<c10::List<::std::optional<int64_t>>>::test(
c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({3, c10::nullopt, 2})), [] (const c10::List<::std::optional<int64_t>>& v) {expectListEquals<c10::optional<int64_t>>({3, c10::nullopt, 2}, v);},
c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({3, c10::nullopt, 2})), [] (const IValue& v) {expectListEquals<c10::optional<int64_t>>({3, c10::nullopt, 2}, v.to<c10::List<::std::optional<int64_t>>>());},
"(int?[] a) -> int?[]");
// dict types
@ -1234,15 +1234,15 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
"(Dict(int, Tensor) a) -> Dict(int, Tensor)");
// weird deeply nested type
using DeeplyNestedType = c10::List<c10::Dict<std::string, c10::List<c10::optional<c10::Dict<int64_t, std::string>>>>>;
using DeeplyNestedType = c10::List<c10::Dict<std::string, c10::List<::std::optional<c10::Dict<int64_t, std::string>>>>>;
auto makeDeeplyNestedObject = [] () -> DeeplyNestedType {
c10::Dict<int64_t, std::string> inner3;
inner3.insert(1, "1");
c10::List<c10::optional<c10::Dict<int64_t, std::string>>> inner2;
c10::List<::std::optional<c10::Dict<int64_t, std::string>>> inner2;
inner2.push_back(std::move(inner3));
c10::Dict<std::string, c10::List<c10::optional<c10::Dict<int64_t, std::string>>>> inner1;
c10::Dict<std::string, c10::List<::std::optional<c10::Dict<int64_t, std::string>>>> inner1;
inner1.insert("key", std::move(inner2));
c10::List<c10::Dict<std::string, c10::List<c10::optional<c10::Dict<int64_t, std::string>>>>> result;
c10::List<c10::Dict<std::string, c10::List<::std::optional<c10::Dict<int64_t, std::string>>>>> result;
result.push_back(inner1);
return result;
};

View File

@ -85,8 +85,8 @@ inline c10::List<Tensor> to_meta(const c10::List<Tensor>& t_list) {
return outputs;
}
inline c10::List<c10::optional<Tensor>> to_meta(const c10::List<c10::optional<Tensor>>& t_list) {
c10::List<c10::optional<Tensor>> outputs;
inline c10::List<::std::optional<Tensor>> to_meta(const c10::List<::std::optional<Tensor>>& t_list) {
c10::List<::std::optional<Tensor>> outputs;
outputs.reserve(t_list.size());
for (const auto i : c10::irange(t_list.size())) {
outputs.push_back(to_meta(t_list[i]));

View File

@ -6,10 +6,10 @@ namespace caffe2 {
namespace internal {
at::Tensor index_with_uint8_handling(
const at::Tensor& self,
const torch::List<c10::optional<at::Tensor>>& indices) {
const torch::List<std::optional<at::Tensor>>& indices) {
// Support BC only for the simplest case of mask indexing
if (indices.size() == 1) {
c10::optional<at::Tensor> first = indices[0];
std::optional<at::Tensor> first = indices[0];
if (first.has_value()
&& first->scalar_type() == at::kByte) {
TORCH_WARN(

View File

@ -22,7 +22,7 @@ using at::Half; // for AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ...)
namespace internal {
TORCH_API at::Tensor index_with_uint8_handling(
const at::Tensor& self,
const torch::List<c10::optional<at::Tensor>>& indices);
const torch::List<std::optional<at::Tensor>>& indices);
}
template <class Context>
@ -94,8 +94,8 @@ private:
return results;
}
torch::List<c10::optional<at::Tensor>> peekSliceOptionals(size_t i, size_t len, size_t N) {
torch::List<c10::optional<at::Tensor>> results;
torch::List<std::optional<at::Tensor>> peekSliceOptionals(size_t i, size_t len, size_t N) {
torch::List<std::optional<at::Tensor>> results;
results.reserve(len);
for (size_t ii = i; ii < i + len; ++ii) {
results.push_back(peek(ii, N));

View File

@ -73,7 +73,7 @@ def value_is_tensor_type(v):
TENSORLIST_TYPE = [
'at::TensorList',
'const at::ITensorListRef &',
'const c10::List<c10::optional<at::Tensor>> &',
'const c10::List<::std::optional<at::Tensor>> &',
]
# for each aten type, how do we handle a return value of that type?
@ -298,7 +298,7 @@ if __name__ == '__main__':
env['statements'].append(
'auto {} = peekSlice({}, InputSize() - {}, InputSize());'
.format(arg['name'], real_inputs, static_tensor_inputs))
elif arg['type'] == 'const c10::List<c10::optional<at::Tensor>> &':
elif arg['type'] == 'const c10::List<::std::optional<at::Tensor>> &':
# NOTE: do not advance real_inputs here. After this we will
# switch to indexing the "stack" from the end
env['statements'].append(

View File

@ -638,7 +638,7 @@ def generate_tensor_like_override_tests(cls):
return instance_gen()
elif arg_type == "TensorList" or arg_type == "ITensorListRef":
return [instance_gen(), instance_gen()]
elif arg_type == "c10::List<c10::optional<Tensor>>":
elif arg_type == "c10::List<::std::optional<Tensor>>":
return [instance_gen(), instance_gen()]
elif arg_type == "IntArrayRef" or arg_type == "SymIntArrayRef":
size = arg.get("size", 2)

View File

@ -177,21 +177,21 @@ std::vector<Shape> compute_shape_abs(const at::Tensor& self) {
std::vector<Shape> compute_shape_bernoulli(
const at::Tensor& self,
c10::optional<at::Generator> generator) {
::std::optional<at::Generator> generator) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<Shape> compute_shape_bernoulli(
const at::Tensor& self,
double p,
c10::optional<at::Generator> generator) {
::std::optional<at::Generator> generator) {
return compute_shape_bernoulli(self, generator);
}
std::vector<Shape> compute_shape_binary_cross_entropy(
const at::Tensor& self,
const at::Tensor& target,
const c10::optional<at::Tensor>& weight,
const ::std::optional<at::Tensor>& weight,
int64_t reduction) {
if (reduction == at::Reduction::None) {
return {Shape(self.scalar_type(), self.sizes().vec())};
@ -203,7 +203,7 @@ std::vector<Shape> compute_shape_binary_cross_entropy_backward(
const at::Tensor& grad_output,
const at::Tensor& self,
const at::Tensor& target,
const c10::optional<at::Tensor>& weight,
const ::std::optional<at::Tensor>& weight,
int64_t reduction) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
@ -286,7 +286,7 @@ std::vector<Shape> compute_shape_convolution_backward(
std::vector<Shape> compute_shape_convolution(
const at::Tensor& input,
const at::Tensor& weight,
const c10::optional<at::Tensor>& bias,
const ::std::optional<at::Tensor>& bias,
at::IntArrayRef stride,
at::IntArrayRef padding,
at::IntArrayRef dilation,
@ -390,19 +390,19 @@ std::vector<Shape> compute_shape_embedding(
}
std::vector<Shape> compute_shape_std(const at::Tensor& self, bool unbiased) {
return compute_shape_std(self, c10::nullopt, c10::nullopt, false);
return compute_shape_std(self, ::std::nullopt, ::std::nullopt, false);
}
std::vector<Shape> compute_shape_std(
const at::Tensor& self,
at::OptionalIntArrayRef dim,
bool unbiased,
bool keepdim) {
return compute_shape_std(self, dim, c10::nullopt, keepdim);
return compute_shape_std(self, dim, ::std::nullopt, keepdim);
}
std::vector<Shape> compute_shape_std(
const at::Tensor& self,
at::OptionalIntArrayRef dim,
const c10::optional<at::Scalar>& correction,
const ::std::optional<at::Scalar>& correction,
bool keepdim) {
if (dim.has_value()) {
auto shape = at::native::shape_from_dim_mask(
@ -530,10 +530,10 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_cholesky(
std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(
const at::Tensor& input,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var,
const ::std::optional<at::Tensor>& weight,
const ::std::optional<at::Tensor>& bias,
const ::std::optional<at::Tensor>& running_mean,
const ::std::optional<at::Tensor>& running_var,
bool training,
double momentum,
double eps) {
@ -570,11 +570,11 @@ std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(
std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(
const at::Tensor& grad_out,
const at::Tensor& input,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var,
const c10::optional<at::Tensor>& save_mean,
const c10::optional<at::Tensor>& save_invstd,
const ::std::optional<at::Tensor>& weight,
const ::std::optional<at::Tensor>& running_mean,
const ::std::optional<at::Tensor>& running_var,
const ::std::optional<at::Tensor>& save_mean,
const ::std::optional<at::Tensor>& save_invstd,
bool train,
double eps,
::std::array<bool, 3> output_mask) {
@ -602,8 +602,8 @@ std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(
std::vector<Shape> compute_shape_native_layer_norm(
const at::Tensor& input,
at::IntArrayRef normalized_shape,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
const ::std::optional<at::Tensor>& weight,
const ::std::optional<at::Tensor>& bias,
double eps) {
// Copied from aten/src/ATen/native/layer_norm.cpp::layer_norm_cpu_out.
auto input_shape = input.sizes().vec();
@ -631,8 +631,8 @@ std::vector<Shape> compute_shape_native_layer_norm_backward(
at::IntArrayRef normalized_shape,
const at::Tensor& mean,
const at::Tensor& rstd,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
const ::std::optional<at::Tensor>& weight,
const ::std::optional<at::Tensor>& bias,
::std::array<bool, 3> output_mask) {
std::vector<Shape> shapes;
shapes.emplace_back(
@ -650,7 +650,7 @@ std::vector<Shape> compute_shape_native_layer_norm_backward(
std::vector<Shape> compute_shape_mean(
const at::Tensor& self,
c10::optional<at::ScalarType> dtype) {
::std::optional<at::ScalarType> dtype) {
if (dtype.has_value()) {
return {Shape(dtype.value(), {})};
}
@ -661,10 +661,10 @@ std::vector<Shape> compute_shape_new_empty_strided(
const at::Tensor& self,
at::IntArrayRef size,
at::IntArrayRef stride,
c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
::std::optional<at::ScalarType> dtype,
::std::optional<at::Layout> layout,
::std::optional<at::Device> device,
::std::optional<bool> pin_memory) {
return {Shape(dtype.has_value() ? *dtype : self.scalar_type(), size.vec())};
}
@ -677,7 +677,7 @@ std::vector<Shape> compute_shape_mv(
std::vector<Shape> compute_shape_native_dropout(
const at::Tensor& input,
double p,
c10::optional<bool> train) {
::std::optional<bool> train) {
return {
Shape(input.scalar_type(), input.sizes().vec()),
Shape(c10::ScalarType::Bool, input.sizes().vec())};
@ -692,22 +692,22 @@ std::vector<Shape> compute_shape_native_dropout_backward(
std::vector<Shape> compute_shape_random(
const at::Tensor& self,
c10::optional<at::Generator> generator) {
::std::optional<at::Generator> generator) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<Shape> compute_shape_random(
const at::Tensor& self,
int64_t to,
c10::optional<at::Generator> generator) {
::std::optional<at::Generator> generator) {
return compute_shape_random(self, generator);
}
std::vector<Shape> compute_shape_random(
const at::Tensor& self,
int64_t from,
c10::optional<int64_t> to,
c10::optional<at::Generator> generator) {
::std::optional<int64_t> to,
::std::optional<at::Generator> generator) {
return compute_shape_random(self, generator);
}
@ -717,7 +717,7 @@ std::vector<Shape> compute_shape_relu(const at::Tensor& self) {
std::vector<Shape> compute_shape_sum(
const at::Tensor& self,
c10::optional<at::ScalarType> dtype) {
::std::optional<at::ScalarType> dtype) {
if (dtype.has_value()) {
return {Shape(dtype.value(), {})};
}
@ -836,7 +836,7 @@ std::vector<Shape> compute_shape_log_sigmoid_backward(
std::vector<Shape> compute_shape_nll_loss2d_forward(
const at::Tensor& self,
const at::Tensor& target,
const c10::optional<at::Tensor>& weight,
const ::std::optional<at::Tensor>& weight,
int64_t reduction,
int64_t ignore_index) {
// Based on definition of
@ -851,7 +851,7 @@ std::vector<Shape> compute_shape_nll_loss2d_backward(
const at::Tensor& grad_output,
const at::Tensor& self,
const at::Tensor& target,
const c10::optional<at::Tensor>& weight,
const ::std::optional<at::Tensor>& weight,
int64_t reduction,
int64_t ignore_index,
const at::Tensor& total_weight) {
@ -1075,12 +1075,12 @@ std::vector<Shape> compute_shape_clamp_min(
std::vector<Shape> compute_shape__to_copy(
const at::Tensor& self,
c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory,
::std::optional<at::ScalarType> dtype,
::std::optional<at::Layout> layout,
::std::optional<at::Device> device,
::std::optional<bool> pin_memory,
bool non_blocking,
c10::optional<at::MemoryFormat> memory_format) {
::std::optional<at::MemoryFormat> memory_format) {
if (dtype) {
return {Shape(*dtype, self.sizes().vec())};
}
@ -1089,7 +1089,7 @@ std::vector<Shape> compute_shape__to_copy(
TORCH_API std::vector<Shape> compute_shape_clone(
const at::Tensor& self,
c10::optional<at::MemoryFormat> memory_format) {
::std::optional<at::MemoryFormat> memory_format) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
@ -1175,7 +1175,7 @@ std::vector<Shape> compute_shape_view(
std::vector<Shape> compute_shape_cast(
const Output& input,
const at::ScalarType& dtype,
const c10::optional<at::ScalarType>& stype) {
const ::std::optional<at::ScalarType>& stype) {
Shape shape = input.shape();
shape.set_scalar_type(dtype);
return {shape};
@ -1274,17 +1274,17 @@ std::vector<Shape> compute_shape_select_scatter(
auto self_meta = at::native::empty_strided_meta_symint(
self.sym_sizes(),
self.sym_strides(),
/*dtype=*/c10::make_optional(self.scalar_type()),
/*layout=*/c10::make_optional(self.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
/*dtype=*/::std::make_optional(self.scalar_type()),
/*layout=*/::std::make_optional(self.layout()),
/*device=*/::std::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/::std::nullopt);
auto src_meta = at::native::empty_strided_meta_symint(
src.sym_sizes(),
src.sym_strides(),
/*dtype=*/c10::make_optional(src.scalar_type()),
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
/*dtype=*/::std::make_optional(src.scalar_type()),
/*layout=*/::std::make_optional(src.layout()),
/*device=*/::std::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/::std::nullopt);
auto out_meta = at::compositeexplicitautogradnonfunctional::select_scatter(
self_meta, src_meta, dim, index);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
@ -1299,17 +1299,17 @@ std::vector<Shape> compute_shape_diagonal_scatter(
auto self_meta = at::native::empty_strided_meta_symint(
self.sym_sizes(),
self.sym_strides(),
/*dtype=*/c10::make_optional(self.scalar_type()),
/*layout=*/c10::make_optional(self.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
/*dtype=*/::std::make_optional(self.scalar_type()),
/*layout=*/::std::make_optional(self.layout()),
/*device=*/::std::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/::std::nullopt);
auto src_meta = at::native::empty_strided_meta_symint(
src.sym_sizes(),
src.sym_strides(),
/*dtype=*/c10::make_optional(src.scalar_type()),
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
/*dtype=*/::std::make_optional(src.scalar_type()),
/*layout=*/::std::make_optional(src.layout()),
/*device=*/::std::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/::std::nullopt);
auto out_meta = at::compositeexplicitautogradnonfunctional::diagonal_scatter(
self_meta, src_meta, offset, dim1, dim2);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
@ -1319,23 +1319,23 @@ std::vector<Shape> compute_shape_slice_scatter_symint(
const at::Tensor& self,
const at::Tensor& src,
int64_t dim,
c10::optional<c10::SymInt> start,
c10::optional<c10::SymInt> end,
::std::optional<c10::SymInt> start,
::std::optional<c10::SymInt> end,
c10::SymInt step) {
auto self_meta = at::native::empty_strided_meta_symint(
self.sym_sizes(),
self.sym_strides(),
/*dtype=*/c10::make_optional(self.scalar_type()),
/*layout=*/c10::make_optional(self.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
/*dtype=*/::std::make_optional(self.scalar_type()),
/*layout=*/::std::make_optional(self.layout()),
/*device=*/::std::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/::std::nullopt);
auto src_meta = at::native::empty_strided_meta_symint(
src.sym_sizes(),
src.sym_strides(),
/*dtype=*/c10::make_optional(src.scalar_type()),
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
/*dtype=*/::std::make_optional(src.scalar_type()),
/*layout=*/::std::make_optional(src.layout()),
/*device=*/::std::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/::std::nullopt);
auto out_meta =
at::compositeexplicitautogradnonfunctional::slice_scatter_symint(
self_meta, src_meta, dim, start, end, step);
@ -1347,21 +1347,21 @@ std::vector<Shape> compute_shape_as_strided_scatter_symint(
const at::Tensor& src,
at::SymIntArrayRef size,
at::SymIntArrayRef stride,
c10::optional<c10::SymInt> storage_offset) {
::std::optional<c10::SymInt> storage_offset) {
auto self_meta = at::native::empty_strided_meta_symint(
self.sym_sizes(),
self.sym_strides(),
/*dtype=*/c10::make_optional(self.scalar_type()),
/*layout=*/c10::make_optional(self.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
/*dtype=*/::std::make_optional(self.scalar_type()),
/*layout=*/::std::make_optional(self.layout()),
/*device=*/::std::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/::std::nullopt);
auto src_meta = at::native::empty_strided_meta_symint(
src.sym_sizes(),
src.sym_strides(),
/*dtype=*/c10::make_optional(src.scalar_type()),
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
/*dtype=*/::std::make_optional(src.scalar_type()),
/*layout=*/::std::make_optional(src.layout()),
/*device=*/::std::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/::std::nullopt);
auto out_meta =
at::compositeexplicitautogradnonfunctional::as_strided_scatter_symint(
self_meta, src_meta, size, stride, storage_offset);
@ -1372,7 +1372,7 @@ std::vector<Shape> compute_shape_normal_functional(
const at::Tensor& self,
double mean,
double std,
c10::optional<at::Generator> generator) {
::std::optional<at::Generator> generator) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
@ -1380,7 +1380,7 @@ std::vector<Shape> compute_shape_uniform(
const at::Tensor& self,
double from,
double to,
c10::optional<at::Generator> generator) {
::std::optional<at::Generator> generator) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}

View File

@ -24,16 +24,16 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape__adaptive_avg_pool3d(con
TORCH_API std::vector<torch::lazy::Shape> compute_shape__adaptive_avg_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_abs(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli(const at::Tensor & self, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli(const at::Tensor & self, double p, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_binary_cross_entropy(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_binary_cross_entropy_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli(const at::Tensor & self, ::std::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli(const at::Tensor & self, double p, ::std::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_binary_cross_entropy(const at::Tensor & self, const at::Tensor & target, const ::std::optional<at::Tensor> & weight, int64_t reduction);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_binary_cross_entropy_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional<at::Tensor> & weight, int64_t reduction);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_cat(at::TensorList tensors, int64_t dim);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_cholesky(const at::Tensor & self, bool upper);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_clamp_min(const at::Tensor & self, const at::Scalar & min);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_clone(const at::Tensor & self, c10::optional<at::MemoryFormat> memory_format);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_clone(const at::Tensor & self, ::std::optional<at::MemoryFormat> memory_format);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const ::std::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_convolution_backward(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalIntArrayRef bias_sizes, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array<bool,3> output_mask);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_embedding(const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_embedding_dense_backward(const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq);
@ -57,23 +57,23 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_logical_xor(const at::Te
TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_max(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_mean(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_mean(const at::Tensor & self, ::std::optional<at::ScalarType> dtype);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_min(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_mv(const at::Tensor & self, const at::Tensor & vec);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, const c10::optional<at::Tensor> & save_mean, const c10::optional<at::Tensor> & save_invstd, bool train, double eps, ::std::array<bool,3> output_mask);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout(const at::Tensor & input, double p, c10::optional<bool> train);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(const at::Tensor & input, const ::std::optional<at::Tensor> & weight, const ::std::optional<at::Tensor> & bias, const ::std::optional<at::Tensor> & running_mean, const ::std::optional<at::Tensor> & running_var, bool training, double momentum, double eps);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const ::std::optional<at::Tensor> & weight, const ::std::optional<at::Tensor> & running_mean, const ::std::optional<at::Tensor> & running_var, const ::std::optional<at::Tensor> & save_mean, const ::std::optional<at::Tensor> & save_invstd, bool train, double eps, ::std::array<bool,3> output_mask);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout(const at::Tensor & input, double p, ::std::optional<bool> train);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout_backward(const at::Tensor & grad_output, const at::Tensor & mask, double scale);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, ::std::array<bool,3> output_mask);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_new_empty_strided(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_forward(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const ::std::optional<at::Tensor> & weight, const ::std::optional<at::Tensor> & bias, double eps);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional<at::Tensor> & weight, const ::std::optional<at::Tensor> & bias, ::std::array<bool,3> output_mask);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_new_empty_strided(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_forward(const at::Tensor & self, const at::Tensor & target, const ::std::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nonzero(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_normal_functional(const at::Tensor & self, double mean, double std, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random(const at::Tensor & self, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random(const at::Tensor & self, int64_t to, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random(const at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_normal_functional(const at::Tensor & self, double mean, double std, ::std::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random(const at::Tensor & self, ::std::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random(const at::Tensor & self, int64_t to, ::std::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random(const at::Tensor & self, int64_t from, ::std::optional<int64_t> to, ::std::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_relu(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_repeat(const at::Tensor & self, at::IntArrayRef repeats);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_slogdet(const at::Tensor & self);
@ -82,9 +82,9 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_sort(const at::Tensor &
TORCH_API std::vector<torch::lazy::Shape> compute_shape_stack(at::TensorList tensors, int64_t dim);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_std(const at::Tensor & self, bool unbiased);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_std(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_std(const at::Tensor & self, at::OptionalIntArrayRef dim, const c10::optional<at::Scalar> & correction, bool keepdim);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_sum(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<torch::lazy::Shape> compute_shape__to_copy(const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, bool non_blocking, c10::optional<at::MemoryFormat> memory_format);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_std(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional<at::Scalar> & correction, bool keepdim);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_sum(const at::Tensor & self, ::std::optional<at::ScalarType> dtype);
TORCH_API std::vector<torch::lazy::Shape> compute_shape__to_copy(const at::Tensor & self, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory, bool non_blocking, ::std::optional<at::MemoryFormat> memory_format);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_take(const at::Tensor & self, const at::Tensor & index);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_trace(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_zero(const at::Tensor & self);
@ -92,13 +92,13 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_narrow_copy_symint(const
TORCH_API std::vector<torch::lazy::Shape> compute_shape_hardswish(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_hardswish_backward(const at::Tensor & grad_output, const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_selu(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_uniform(const at::Tensor & self, double from, double to, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_uniform(const at::Tensor & self, double from, double to, ::std::optional<at::Generator> generator);
// Non-Native ops
TORCH_API std::vector<Shape> compute_shape_scalar(const at::Scalar& value, const at::ScalarType& type);
TORCH_API std::vector<Shape> compute_shape_expand(const Output& input0, const std::vector<int64_t>& size, const bool& is_scalar_expand);
TORCH_API std::vector<Shape> compute_shape_view(const Output& input0, const std::vector<int64_t>& output_sizes);
TORCH_API std::vector<Shape> compute_shape_cast(const Output& input0, const at::ScalarType& dtype, const c10::optional<at::ScalarType>& stype);
TORCH_API std::vector<Shape> compute_shape_cast(const Output& input0, const at::ScalarType& dtype, const ::std::optional<at::ScalarType>& stype);
// View Ops
// (Now that functionalization pass is used, we should kill these in a later PR)
@ -117,8 +117,8 @@ TORCH_API std::vector<Shape> compute_shape_unsqueeze(const Output& input, const
TORCH_API std::vector<torch::lazy::Shape> compute_shape_select_scatter(const at::Tensor & self, const at::Tensor & src, int64_t dim, int64_t index);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_diagonal_scatter(const at::Tensor & self, const at::Tensor & src, int64_t offset, int64_t dim1, int64_t dim2);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_slice_scatter_symint(const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_as_strided_scatter_symint(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_slice_scatter_symint(const at::Tensor & self, const at::Tensor & src, int64_t dim, ::std::optional<c10::SymInt> start, ::std::optional<c10::SymInt> end, c10::SymInt step);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_as_strided_scatter_symint(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional<c10::SymInt> storage_offset);
// clang-format on
} // namespace lazy
} // namespace torch

View File

@ -30,7 +30,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
{"double", ParameterType::DOUBLE},
{"complex", ParameterType::COMPLEX},
{"TensorList", ParameterType::TENSOR_LIST},
{"c10::List<c10::optional<Tensor>>", ParameterType::TENSOR_LIST},
{"c10::List<::std::optional<Tensor>>", ParameterType::TENSOR_LIST},
{"IntArrayRef", ParameterType::INT_LIST},
{"SymIntArrayRef", ParameterType::SYM_INT_LIST},
{"ArrayRef<double>", ParameterType::FLOAT_LIST},

View File

@ -312,7 +312,7 @@ def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequenc
JIT_TO_CPP_DEFAULT = {
"False": "false",
"True": "true",
"None": "c10::nullopt", # UGH this one is type directed
"None": "::std::nullopt", # UGH this one is type directed
"Mean": "at::Reduction::Mean",
"[]": "{}",
"contiguous_format": "MemoryFormat::Contiguous",
@ -347,7 +347,7 @@ def default_expr(d: str, t: Type, *, symint: bool) -> str:
if isinstance(t, OptionalType):
if d == "None":
return "c10::nullopt"
return "::std::nullopt"
return default_expr(d, t.elem, symint=symint)

View File

@ -62,9 +62,9 @@ from torchgen.model import (
# Note: the scattered TensorOptions fields are packed into 'options'.
#
# auto dispatch_empty =
# [](IntArrayRef size, c10::optional<DimnameList> names,
# [](IntArrayRef size, std::optional<DimnameList> names,
# const TensorOptions & options,
# c10::optional<MemoryFormat> memory_format) -> Tensor {
# std::optional<MemoryFormat> memory_format) -> Tensor {
# pybind11::gil_scoped_release no_gil;
# return torch::empty(size, names, options, memory_format);
# };
@ -93,9 +93,9 @@ from torchgen.model import (
# Where does 'names' come from? It involves special local init:
#
# auto __names = _r.toDimnameListOptional(1);
# c10::optional<DimnameList> names =
# __names ? c10::make_optional(DimnameList(__names.value()))
# : c10::nullopt;
# std::optional<DimnameList> names =
# __names ? std::make_optional(DimnameList(__names.value()))
# : std::nullopt;
#
# Where does 'options' come from? It involves special local init
# for TensorOptions. Note that Python side has the additional
@ -235,6 +235,8 @@ class PythonArgument:
default = {
"nullptr": "None",
"c10::nullopt": "None",
"::std::nullopt": "None",
"std::nullopt": "None",
"{}": "None",
}.get(self.default, self.default)
return f"{type_str} {name}={default}"
@ -280,6 +282,8 @@ class PythonArgument:
default = {
"nullptr": "None",
"c10::nullopt": "None",
"::std::nullopt": "None",
"std::nullopt": "None",
"{}": "None",
"MemoryFormat::Contiguous": "contiguous_format",
"QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
@ -697,9 +701,9 @@ def argument_type_str(
return f"ScalarList[{size}]" if size is not None else "ScalarList"
elif str(t.elem) == "Tensor?":
if simple_type:
return "c10::List<c10::optional<Tensor>>"
return "c10::List<::std::optional<Tensor>>"
else:
return "const c10::List<c10::optional<Tensor>> &"
return "const c10::List<::std::optional<Tensor>> &"
elif str(t.elem) == "Dimname":
return f"DimnameList[{size}]" if size is not None else "DimnameList"
elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
@ -1308,7 +1312,13 @@ def arg_parser_unpack_method(
return "generator"
elif str(t.elem) == "Dimname[]":
return "toDimnameListOptional"
elif not has_default_init and default in (None, "None", "c10::nullopt"):
elif not has_default_init and default in (
None,
"None",
"c10::nullopt",
"::std::nullopt",
"std::nullopt",
):
# If default is None: append 'Optional' to elem's unpacking method
return (
arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
@ -1430,7 +1440,7 @@ def dispatch_lambda_exprs(
inits.extend(
[
f"auto __{name} = {arg_parser_expr};",
f"c10::optional<DimnameList> {name} = __{name} ? c10::make_optional(DimnameList(__{name}.value())) : c10::nullopt;", # noqa: B950
f"::std::optional<DimnameList> {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;", # noqa: B950
]
)
lambda_args_exprs[name] = name

View File

@ -323,7 +323,7 @@ Check this module for more information.
# If we're calling a factory op from its out= variant,
# We don't actually care about the value of pin_memory.
out_tensor = direct_solve(out_tensor_ctype)
return "c10::nullopt"
return "::std::nullopt"
# We can always do translations from value types to reference types, like vector<int> -> IntArrayRef
elif goal.type == BaseCType(intArrayRefT):
@ -347,7 +347,7 @@ Check this module for more information.
argname = direct_solve(
NamedCType(goal.name, OptionalCType(BaseCType(longT)))
)
return f"{argname}.has_value() ? c10::make_optional(c10::SymInt(*{argname})) : c10::nullopt"
return f"{argname}.has_value() ? ::std::make_optional(c10::SymInt(*{argname})) : ::std::nullopt"
elif goal.type == BaseCType(longT):
symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
return f"{symInt_type}.guard_int(__FILE__, __LINE__)"
@ -355,7 +355,7 @@ Check this module for more information.
argname = direct_solve(
NamedCType(goal.name, OptionalCType(BaseCType(SymIntT)))
)
return f"{argname}.has_value() ? c10::make_optional({argname}->guard_int(__FILE__, __LINE__)) : c10::nullopt"
return f"{argname}.has_value() ? ::std::make_optional({argname}->guard_int(__FILE__, __LINE__)) : ::std::nullopt"
elif goal.type == BaseCType(optionalIntArrayRefT):
try:
return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
@ -363,14 +363,14 @@ Check this module for more information.
argname = direct_solve(
NamedCType(goal.name, BaseCType(optionalSymIntArrayRefT))
)
return f"{argname}.has_value() ? c10::make_optional(C10_AS_INTARRAYREF_SLOW(*{argname})) : c10::nullopt"
return f"{argname}.has_value() ? ::std::make_optional(C10_AS_INTARRAYREF_SLOW(*{argname})) : ::std::nullopt"
elif goal.type == BaseCType(optionalSymIntArrayRefT):
# TODO: You might also want to solve this from longSymVec_ctype or
# an optional version of it
argname = direct_solve(
NamedCType(goal.name, BaseCType(optionalIntArrayRefT))
)
return f"{argname}.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*{argname})) : c10::nullopt"
return f"{argname}.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*{argname})) : ::std::nullopt"
elif goal.type == BaseCType(optionalScalarRefT):
return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
elif goal.type == BaseCType(optionalTensorRefT):
@ -398,22 +398,22 @@ Check this module for more information.
goal.name, BaseCType(optionalIntArrayRefT)
)
argname = direct_solve(optionalIntArrayRef_ctype)
return f"{argname}.has_value() ? c10::make_optional({argname}->vec()) : c10::nullopt"
return f"{argname}.has_value() ? ::std::make_optional({argname}->vec()) : ::std::nullopt"
elif goal.type == OptionalCType(BaseCType(scalarT)):
optionalScalarRef_ctype = NamedCType(
goal.name, BaseCType(optionalScalarRefT)
)
argname = direct_solve(optionalScalarRef_ctype)
return f"{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt"
return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
elif goal.type == OptionalCType(BaseCType(scalarT)):
optionalTensorRef_ctype = NamedCType(
goal.name, BaseCType(optionalTensorRefT)
)
argname = direct_solve(optionalTensorRef_ctype)
return f"{argname}.has_value() ? c10::make_optional({argname}) : c10::nullopt"
return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
# Technically, we also need to handle cases of C++ containers holding reference types.
# But there currently aren't any ops that require lambda capture codegen
# With arguments like std::vector<IntArrayRef>.
# With arguments like ::std::vector<IntArrayRef>.
# If that changes, we'll have to add the translation here.
# We allow const casting on tensors, since const-correctness is a bit broken for at::Tensor.

View File

@ -34,7 +34,7 @@ from .types_base import (
TENSOR_LIST_LIKE_CTYPES = [
"at::TensorList",
"const c10::List<c10::optional<at::Tensor>> &",
"const c10::List<::std::optional<at::Tensor>> &",
"const at::ITensorListRef &",
]
@ -133,10 +133,10 @@ class OptionalCType(CType):
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"c10::optional<{self.elem.cpp_type()}>"
return f"::std::optional<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"c10::optional<{self.elem.cpp_type_registration_declarations()}>"
return f"::std::optional<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> "CType":
return OptionalCType(self.elem.remove_const_ref())

View File

@ -43,8 +43,8 @@ from torchgen.model import (
# ```
# - Dimname[]? names
# ```cpp
# c10::optional<c10::IValue> names_opt = (std::move(peek(stack, 1, 7))).toOptional<c10::IValue>();
# c10::optional<at::ArrayRef<at::Dimname>> names_opt_out;
# ::std::optional<c10::IValue> names_opt = (std::move(peek(stack, 1, 7))).toOptional<c10::IValue>();
# ::std::optional<at::ArrayRef<at::Dimname>> names_opt_out;
# if (names_opt.has_value()) {
# ~~~~~~~~~~~ <-- Unwrapping optional shell
# const c10::IValue names_opt_in = names_opt.value();
@ -58,23 +58,23 @@ from torchgen.model import (
# }
# at::ArrayRef<at::Dimname> names_list_out(names_vec);
#
# names_opt_out = c10::optional<at::ArrayRef<at::Dimname>>(names_list_out);
# names_opt_out = ::std::optional<at::ArrayRef<at::Dimname>>(names_list_out);
# } else {
# names_opt_out = c10::optional<at::ArrayRef<at::Dimname>>();
# names_opt_out = ::std::optional<at::ArrayRef<at::Dimname>>();
# }
# ```
# - ScalarType? dtype (similarly for the rest of the arguments)
# ```cpp
# c10::optional<c10::IValue> dtype_opt = (std::move(peek(stack, 2, 7))).toOptional<c10::IValue>();
# c10::optional<at::ScalarType> dtype_opt_out;
# ::std::optional<c10::IValue> dtype_opt = (std::move(peek(stack, 2, 7))).toOptional<c10::IValue>();
# ::std::optional<at::ScalarType> dtype_opt_out;
# if (dtype_opt.has_value()) {
# const c10::IValue dtype_opt_in = dtype_opt.value();
# at::ScalarType dtype_base = dtype_opt_in.to<at::ScalarType>();
# ~~~~~~~~~~~~~~~~~~~~ <-- For base types, convert ivalue to it
# directly using ".to<T>()" API.
# dtype_opt_out = c10::optional<at::ScalarType>(dtype_base);
# dtype_opt_out = ::std::optional<at::ScalarType>(dtype_base);
# } else {
# dtype_opt_out = c10::optional<at::ScalarType>();
# dtype_opt_out = ::std::optional<at::ScalarType>();
# }
# ```
#
@ -184,7 +184,7 @@ def _gen_code_optional_type(
res_name, _, res_code, decl = argumenttype_ivalue_convert(t.elem, in_name)
return (
f"""
c10::optional<c10::IValue> {arg_name}_opt = {arg_name}.toOptional<c10::IValue>();
auto {arg_name}_opt = {arg_name}.toOptional<c10::IValue>();
{ctype.cpp_type(strip_ref=True)} {out_name};
if ({arg_name}_opt.has_value()) {{
const c10::IValue {in_name} = {arg_name}_opt.value();
@ -216,7 +216,7 @@ def _gen_code_list_type(
"\n"
)
)
# we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<c10::optional<at::Tensor>>
# we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
elif isinstance(t.elem, OptionalType):
code.extend(
f"""

View File

@ -59,13 +59,13 @@ def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
if arg.is_symint_or_list:
# TODO: I don't understand when you should put lazy_ in the name
# or not
return f"{arg.name} ? c10::make_optional(GetSymIntValue(*{arg.name})) : c10::nullopt"
return f"{arg.name} ? std::make_optional(GetSymIntValue(*{arg.name})) : ::std::nullopt"
elif arg.is_wrapped_scalar:
return f"node_{arg.name}"
return (
f"lazy_{arg.name} ? "
f"c10::make_optional(lazy_{arg.name}->GetIrValue()) : "
"c10::nullopt"
f"std::make_optional(lazy_{arg.name}->GetIrValue()) : "
"::std::nullopt"
)
else:
raise AssertionError(
@ -253,8 +253,8 @@ class GenLazyIR(ABC):
scalar_initializers = ",\n ".join(
[
# This code is just special casing the mapping from string_view -> strings
f"{a.name}({a.name}.has_value() ? c10::make_optional(std::string(*{a.name})) : c10::nullopt)"
if a.lazy_type.cpp_type() == "c10::optional<c10::string_view>"
f"{a.name}({a.name}.has_value() ? ::std::make_optional(std::string(*{a.name})) : ::std::nullopt)"
if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
else f"{a.name}({a.name})"
for a in scalar_args
]
@ -265,8 +265,8 @@ class GenLazyIR(ABC):
[
f"std::string {a.name};"
if a.lazy_type.cpp_type() == "c10::string_view"
else f"c10::optional<std::string> {a.name};"
if a.lazy_type.cpp_type() == "c10::optional<c10::string_view>"
else f"::std::optional<std::string> {a.name};"
if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
else f"{a.lazy_type.cpp_type()} {a.name};"
for a in scalar_args
]
@ -419,9 +419,9 @@ class GenLazyNativeFuncDefinition:
if isinstance(arg.lazy_type, OptionalCType):
lazy_tensor_decls.append(
f"""auto node_{arg.name} = {arg.name} ?
c10::make_optional(torch::lazy::LazyGraphExecutor::Get()->
std::make_optional(torch::lazy::LazyGraphExecutor::Get()->
GetIrValueForScalarFromCodegen(*{arg.name}, *common_device)):
c10::nullopt;"""
::std::nullopt;"""
)
else:
lazy_tensor_decls.append(

View File

@ -127,11 +127,11 @@ def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> List[str]:
if empty_strided_impl is None
else [
f"""
c10::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
std::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
if (out.strides() != strides) {{
return {empty_strided_impl}(sizes, strides, options);
}}
return c10::nullopt;
return std::nullopt;
}}
"""
]
@ -260,7 +260,7 @@ class RegisterDispatchKey:
if type == DeviceCheckType.NoCheck:
return " // No device check\n"
device_check = "c10::optional<Device> common_device = nullopt;\n"
device_check = "std::optional<Device> common_device = std::nullopt;\n"
device_check += "(void)common_device; // Suppress unused variable warning\n"
for arg in args:
# Only tensor like arguments are eligible
@ -688,11 +688,11 @@ resize_out(out, sizes, strides, options);
elif k is SchemaKind.inplace:
output_type = "std::reference_wrapper<Tensor>"
output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
proxy_field = f"std::array<c10::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
proxy_field = f"std::array<::std::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
elif k is SchemaKind.out:
output_type = "std::reference_wrapper<Tensor>"
output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
proxy_field = f"std::array<c10::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
proxy_field = f"std::array<::std::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
if self.backend_index.dispatch_key == DispatchKey.CUDA:
if self.rocm:

View File

@ -171,7 +171,7 @@ class Unboxing:
)
)
# pytorch codegen:
# we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<c10::optional<at::Tensor>>
# we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
elif (
isinstance(t.elem, OptionalType)
and isinstance(t.elem.elem, BaseType)
@ -180,8 +180,8 @@ class Unboxing:
code.extend(
f"""
#ifdef USE_ATEN_LIB
at::ArrayRef<c10::optional<at::Tensor>> {in_name} = {arg_name}.toListOptionalTensor();
c10::List<c10::optional<at::Tensor>> {out_name};
auto {in_name} = {arg_name}.toListOptionalTensor();
c10::List<::std::optional<at::Tensor>> {out_name};
for (auto {elem_name}: {in_name}) {{
{out_name}.push_back({elem_name});
}}

View File

@ -108,7 +108,7 @@ def convert_arg_type_and_name(typ: Type, name: str) -> Tuple[List[str], List[str
c_types[j] = c_types[j] + "*"
if aten_type.startswith("c10::ArrayRef<"):
# ArrayRef is passed as pointer + size, but no need to add "*" to the size argument
new_aten_types.append(f"c10::optional<{aten_type}>")
new_aten_types.append(f"::std::optional<{aten_type}>")
base_type = aten_type[len("c10::ArrayRef<") : -1]
new_callsite_exprs.append(
f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j+1]})"
@ -116,13 +116,13 @@ def convert_arg_type_and_name(typ: Type, name: str) -> Tuple[List[str], List[str
j += 2
elif aten_type == "c10::Device":
# Device is passed as device_type + device_index
new_aten_types.append("c10::optional<c10::Device>")
new_aten_types.append("::std::optional<c10::Device>")
new_callsite_exprs.append(
f"pointer_to_optional_device({names[j]}, {names[j+1]})"
)
j += 2
else:
new_aten_types.append(f"c10::optional<{aten_type}>")
new_aten_types.append(f"::std::optional<{aten_type}>")
new_callsite_exprs.append(
f"pointer_to_optional<{aten_type}>({names[j]})"
)
@ -152,8 +152,8 @@ def convert_arg_type_and_name(typ: Type, name: str) -> Tuple[List[str], List[str
# construct std::array<bool, N> instead
assert typ.size is not None
callsite_exprs.append(f"pointer_to_list<{typ.size}>({name})")
elif atype == "c10::optional<at::Tensor>":
# convert from std::vector<c10::optional<at::Tensor>> to c10::List<c10::optional<at::Tensor>>
elif atype == "::std::optional<at::Tensor>":
# convert from std::vector<::std::optional<at::Tensor>> to c10::List<::std::optional<at::Tensor>>
callsite_exprs.append(
f"c10::List<{atype}>(c10::ArrayRef<{atype}>(pointer_to_list<{atype}>({name}, {name}_len_)))"
)

View File

@ -153,19 +153,19 @@ at::Tensor to_meta(const at::Tensor& tensor) {
// undefined tensors can't be converted to the meta device, since they don't have sizes/strides
if (!tensor.defined()) return tensor;
auto out = at::native::empty_strided_meta_symint(tensor.sym_sizes(), tensor.sym_strides(), \
/*dtype=*/c10::make_optional(tensor.scalar_type()), /*layout=*/c10::make_optional(tensor.layout()), \
/*device=*/c10::make_optional(c10::Device(c10::kMeta)), /*pin_memory=*/c10::nullopt);
/*dtype=*/std::make_optional(tensor.scalar_type()), /*layout=*/std::make_optional(tensor.layout()), \
/*device=*/std::make_optional(c10::Device(c10::kMeta)), /*pin_memory=*/std::nullopt);
// needs to handle wrapped numbers, so dtype promotion works properly.
if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
out.unsafeGetTensorImpl()->set_wrapped_number(true);
}
return out;
}
c10::optional<at::Tensor> to_meta(const c10::optional<at::Tensor>& tensor) {
std::optional<at::Tensor> to_meta(const std::optional<at::Tensor>& tensor) {
if (tensor.has_value()) {
return to_meta(*tensor);
}
return c10::nullopt;
return std::nullopt;
}
std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {