[12/N] Use std::optional (#132361)
Follows #132396.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132361
Approved by: https://github.com/eqy
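The change is mechanical: spell out the std:: qualification instead of relying on an unqualified optional that happens to be in scope. One of the declarations touched below (from the batching helpers header) shows the before/after; the snippet is illustrative only:

#include <cstdint>
#include <optional>

// Before this patch the declaration read:
//   optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val);
// which compiles only while an unqualified `optional` is visible in the
// translation unit. After the patch the standard type is named explicitly:
std::optional<int64_t> valIfNonempty(std::optional<int64_t> maybe_empty, int64_t new_val);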
@@ -127,7 +127,7 @@ void internal_set_names_inplace(TensorImpl* impl, std::vector<Dimname>&& names,
   }
 }
 
-optional<DimnameList> get_opt_names(const TensorImpl* impl) {
+std::optional<DimnameList> get_opt_names(const TensorImpl* impl) {
   const auto* meta = get_named_tensor_meta(impl);
   if (meta == nullptr) {
     return std::nullopt;

@@ -392,7 +392,7 @@ namespace impl {
   }
 };
 template<class T, bool AllowDeprecatedTypes>
-struct ivalue_to_arg<optional<ArrayRef<T>>, AllowDeprecatedTypes> final {
+struct ivalue_to_arg<std::optional<ArrayRef<T>>, AllowDeprecatedTypes> final {
   // If an argument is std::optional<ArrayRef<T>>, convert the IValue to an std::optional<std::vector<T>> and pass that
   // to the operator. OptionalArray<T> is basically a std::optional<std::vector<T>> but implicitly convertible
   // to std::optional<ArrayRef<T>>.
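The comment in the hunk above describes the trick: OptionalArray<T> owns a std::optional<std::vector<T>> while converting implicitly to std::optional<ArrayRef<T>>, so the unboxed value can be handed to a kernel that expects the view type. A stripped-down sketch of that idea, using a hypothetical OwnedOptionalArray rather than the real c10 type:

#include <optional>
#include <vector>
#include <c10/util/ArrayRef.h>

// Hypothetical illustration of the wrapper described above: it keeps the
// storage alive in a std::optional<std::vector<T>> and hands out a
// std::optional<ArrayRef<T>> view on demand.
template <class T>
struct OwnedOptionalArray {
  std::optional<std::vector<T>> list;

  // Implicit conversion to the view type expected by the kernel signature.
  // The returned view is only valid while this object is alive.
  operator std::optional<c10::ArrayRef<T>>() const {
    if (!list.has_value()) {
      return std::nullopt;
    }
    return c10::ArrayRef<T>(*list);
  }
};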
@@ -45,7 +45,7 @@ namespace impl {
 
 TORCH_API void common_device_check_failure(Device common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName);
 
-inline void check_and_update_common_device(optional<Device>& common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
+inline void check_and_update_common_device(std::optional<Device>& common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
   // TODO: Remove this once the following issue is addressed:
   // https://github.com/pytorch/pytorch/issues/57380
   if (!tensor.defined()) {

@@ -62,19 +62,19 @@ inline void check_and_update_common_device(optional<Device>& common_device, cons
   }
 }
 
-inline void check_and_update_common_device(optional<Device>& common_device, const std::optional<at::Tensor>& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
+inline void check_and_update_common_device(std::optional<Device>& common_device, const std::optional<at::Tensor>& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
   if (tensor.has_value()) {
     check_and_update_common_device(common_device, tensor.value(), methodName, argName);
   }
 }
 
-inline void check_and_update_common_device(optional<Device>& common_device, at::ITensorListRef tensors, at::CheckedFrom methodName, at::CheckedFrom argName) {
+inline void check_and_update_common_device(std::optional<Device>& common_device, at::ITensorListRef tensors, at::CheckedFrom methodName, at::CheckedFrom argName) {
   for (const auto& tensor : tensors) {
     check_and_update_common_device(common_device, tensor, methodName, argName);
   }
 }
 
-inline void check_and_update_common_device(optional<Device>& common_device, const List<optional<at::Tensor>>& tensors, at::CheckedFrom methodName, at::CheckedFrom argName) {
+inline void check_and_update_common_device(std::optional<Device>& common_device, const List<std::optional<at::Tensor>>& tensors, at::CheckedFrom methodName, at::CheckedFrom argName) {
   for (const auto& tensor : tensors) {
     check_and_update_common_device(common_device, tensor, methodName, argName);
   }
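All of these overloads feed the same accumulation: the first defined tensor seeds common_device, and every later tensor must agree with it. A simplified, self-contained sketch of that logic with stand-in Device/Tensor types (the real helpers additionally report methodName/argName through common_device_check_failure):

#include <optional>
#include <stdexcept>
#include <vector>

// Stand-ins for at::Device / at::Tensor, just to make the flow visible.
struct Device { int index; bool operator!=(Device o) const { return index != o.index; } };
struct Tensor { bool defined; Device device; };

// Simplified version of the accumulation performed above: remember the first
// device seen and complain when a later tensor disagrees.
inline void check_and_update(std::optional<Device>& common_device, const Tensor& t) {
  if (!t.defined) {
    return;                    // undefined tensors are skipped, as in the diff
  }
  if (!common_device.has_value()) {
    common_device = t.device;  // first defined tensor wins
    return;
  }
  if (common_device.value() != t.device) {
    throw std::runtime_error("expected all tensors to be on the same device");
  }
}

inline void check_all(const std::vector<Tensor>& tensors) {
  std::optional<Device> common_device = std::nullopt;
  for (const auto& t : tensors) {
    check_and_update(common_device, t);
  }
}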
@@ -11,7 +11,7 @@
 // NB: most activation functions fit pointwise unary or binary rules.
 // These are only the ones that have special batch rules to help with organization
 namespace at::functorch {
-static std::tuple<Tensor,optional<int64_t>>
+static std::tuple<Tensor, std::optional<int64_t>>
 glu_batch_rule(const Tensor& self, std::optional<int64_t> self_bdim, int64_t dim) {
   // repeated error message from glu because 0D -> 1D when batched
   // this can't pass anyway because a 0-dimensional tensor has "size" 1, which

@@ -27,7 +27,7 @@ glu_batch_rule(const Tensor& self, std::optional<int64_t> self_bdim, int64_t dim
   return std::make_tuple(res, 0);
 }
 
-static std::tuple<Tensor,optional<int64_t>> glu_backward_batch_rule(
+static std::tuple<Tensor, std::optional<int64_t>> glu_backward_batch_rule(
     const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
     const Tensor& self, std::optional<int64_t> self_bdim, int64_t dim) {
   if (self_bdim) {
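The functorch hunks that follow all use the same return convention: a batch rule hands back std::tuple<Tensor, std::optional<int64_t>>, i.e. the physical result plus the index of the vmap batch dimension in that result, with std::nullopt meaning the output carries no batch dimension. A toy, unregistered sketch of the convention:

#include <optional>
#include <tuple>
#include <ATen/ATen.h>

// Toy batch rule for a pointwise op: if the input carries a batch dimension,
// the output carries it in the same place; otherwise the output is unbatched.
static std::tuple<at::Tensor, std::optional<int64_t>> relu_batch_rule_sketch(
    const at::Tensor& self, std::optional<int64_t> self_bdim) {
  auto result = at::relu(self);  // pointwise, so the layout is unchanged
  return std::make_tuple(std::move(result), self_bdim);
}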
@@ -14,7 +14,7 @@
 namespace at::functorch {
 
 template <typename F, F Func, typename... ExtraArgs>
-std::tuple<Tensor,optional<int64_t>> _binary_pointwise_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> _binary_pointwise_batch_rule(
     const Tensor& tensor, std::optional<int64_t> tensor_batch_dim,
     const Tensor& other, std::optional<int64_t> other_batch_dim,
     ExtraArgs... extra_args) {

@@ -33,7 +33,7 @@ struct BinaryPointwiseBatchRuleHelper;
 
 template <typename F, F Func, typename T1, typename T2, typename... T>
 struct BinaryPointwiseBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
-  static std::tuple<Tensor,optional<int64_t>> apply(
+  static std::tuple<Tensor, std::optional<int64_t>> apply(
       const Tensor& tensor, std::optional<int64_t> tensor_batch_dim,
       const Tensor& other, std::optional<int64_t> other_batch_dim,
       T... extra_args) {

@@ -120,7 +120,7 @@ void binary_pointwise_inplace_batch_rule(
 }
 
 template <typename F, F Func>
-std::tuple<Tensor,optional<int64_t>> comparison_pointwise_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> comparison_pointwise_batch_rule(
     const Tensor& tensor, std::optional<int64_t> tensor_batch_dim,
     const Tensor& other, std::optional<int64_t> other_batch_dim) {
   // compute max logical rank

@@ -142,7 +142,7 @@ std::tuple<Tensor,optional<int64_t>> comparison_pointwise_batch_rule(
   return std::make_tuple( std::move(result), 0 );
 }
 
-static std::tuple<Tensor,optional<int64_t>> where_self_batch_rule(
+static std::tuple<Tensor, std::optional<int64_t>> where_self_batch_rule(
     const Tensor& condition, std::optional<int64_t> condition_bdim,
     const Tensor& self, std::optional<int64_t> self_bdim, const Tensor& other, std::optional<int64_t> other_bdim) {
   auto condition_logical_rank = rankWithoutBatchDim(condition, condition_bdim);

@@ -177,7 +177,7 @@ static std::tuple<Tensor, std::optional<int64_t>> gelu_backward_batch_rule(
   return std::make_tuple(at::gelu_backward(grad_out_, input_, approximate), 0);
 }
 
-static std::tuple<Tensor,optional<int64_t>> masked_select_batch_rule(
+static std::tuple<Tensor, std::optional<int64_t>> masked_select_batch_rule(
     const Tensor& self, std::optional<int64_t> self_bdim,
     const Tensor& mask, std::optional<int64_t> mask_bdim) {
   TORCH_CHECK(!mask_bdim.has_value(),

@@ -196,7 +196,7 @@ static std::tuple<Tensor,optional<int64_t>> masked_select_batch_rule(
   return std::make_tuple(result, 0);
 }
 
-static std::tuple<Tensor,optional<int64_t>> masked_select_backward_batch_rule(
+static std::tuple<Tensor, std::optional<int64_t>> masked_select_backward_batch_rule(
     const Tensor& grad, std::optional<int64_t> grad_bdim,
     const Tensor& self, std::optional<int64_t> self_bdim,
     const Tensor& mask, std::optional<int64_t> mask_bdim) {

@@ -221,7 +221,7 @@ static std::tuple<Tensor,optional<int64_t>> masked_select_backward_batch_rule(
   return std::make_tuple(result, 0);
 }
 
-static std::tuple<Tensor,optional<int64_t>> cdist_backward_batch_rule(
+static std::tuple<Tensor, std::optional<int64_t>> cdist_backward_batch_rule(
     const Tensor& grad, std::optional<int64_t> grad_bdim,
     const Tensor& x1, std::optional<int64_t> x1_bdim,
     const Tensor& x2, std::optional<int64_t> x2_bdim,
@@ -16,7 +16,7 @@ namespace at::functorch {
 // PyTorch's convolution is different from JAX's conv_general_dilated:
 // we do not support batch_group_count (which is needed for convolution backwards).
 // Instead, there's a convolution_backward op that needs a batching rule.
-static std::tuple<Tensor,optional<int64_t>>
+static std::tuple<Tensor, std::optional<int64_t>>
 convolution_batch_rule(const Tensor& lhs, std::optional<int64_t> lhs_bdim, const Tensor& rhs, std::optional<int64_t> rhs_bdim, const std::optional<Tensor>& bias, std::optional<int64_t> bias_bdim, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) {
   DimVector lhs_spec(stride.size() + 2);
   std::iota(lhs_spec.begin(), lhs_spec.end(), 0);

@@ -239,7 +239,7 @@ static Tensor make_dummy(
   return tensor_.new_empty({}).expand(expand_shape);
 }
 
-static std::tuple<Tensor,optional<int64_t>>
+static std::tuple<Tensor, std::optional<int64_t>>
 convolution_backward_input_batch_rule(
     const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
     const Tensor& input, std::optional<int64_t> input_bdim,

@@ -320,7 +320,7 @@ convolution_backward_input_batch_rule(
     return std::make_tuple(std::get<0>(result), std::nullopt);
   }
 }
-static std::tuple<Tensor,optional<int64_t>>
+static std::tuple<Tensor, std::optional<int64_t>>
 convolution_backward_weight_batch_rule(
     const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
     const Tensor& input, std::optional<int64_t> input_bdim,
@@ -14,7 +14,7 @@ struct NewBlahBatchRuleHelperSymInt;
 
 template <typename F, F Func, typename A, typename B, typename... T>
 struct NewBlahBatchRuleHelperSymInt<F, Func, typelist<A, B, T...>> {
-  static std::tuple<Tensor,optional<int64_t>> apply(
+  static std::tuple<Tensor, std::optional<int64_t>> apply(
      const Tensor& tensor,
      std::optional<int64_t> batch_dim,
      SymIntArrayRef shape,

@@ -33,7 +33,7 @@ struct NewBlahBatchRuleHelper;
 
 template <typename F, F Func, typename A, typename B, typename... T>
 struct NewBlahBatchRuleHelper<F, Func, typelist<A, B, T...>> {
-  static std::tuple<Tensor,optional<int64_t>> apply(
+  static std::tuple<Tensor, std::optional<int64_t>> apply(
      const Tensor& tensor,
      std::optional<int64_t> batch_dim,
      IntArrayRef shape,

@@ -62,7 +62,7 @@ struct NewBlahBatchRuleHelper<F, Func, typelist<A, B, T...>> {
       &fn,\
       c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
 
-static std::tuple<Tensor,optional<int64_t>> _new_zeros_with_same_feature_meta_batch_rule(
+static std::tuple<Tensor, std::optional<int64_t>> _new_zeros_with_same_feature_meta_batch_rule(
     const Tensor& self, std::optional<int64_t> self_bdim,
     const Tensor& other, std::optional<int64_t> other_bdim,
     int64_t self_num_batch_dims) {

@@ -103,7 +103,7 @@ static std::tuple<Tensor,optional<int64_t>> _new_zeros_with_same_feature_meta_ba
   return std::make_tuple(result, 0);
 }
 
-static std::tuple<Tensor,optional<int64_t>> linspace_logspace_batch_rule_helper(
+static std::tuple<Tensor, std::optional<int64_t>> linspace_logspace_batch_rule_helper(
     const at::Tensor& start, std::optional<int64_t> start_bdim,
     const at::Tensor& end, std::optional<int64_t> end_bdim,
     int64_t steps,

@@ -141,7 +141,7 @@ static std::tuple<Tensor,optional<int64_t>> linspace_logspace_batch_rule_helper(
   return std::make_tuple(result, 0);
 }
 
-static std::tuple<Tensor,optional<int64_t>> linspace_Tensor_Tensor_batch_rule(
+static std::tuple<Tensor, std::optional<int64_t>> linspace_Tensor_Tensor_batch_rule(
     const at::Tensor& start, std::optional<int64_t> start_bdim,
     const at::Tensor& end, std::optional<int64_t> end_bdim,
     int64_t steps,

@@ -152,7 +152,7 @@ static std::tuple<Tensor,optional<int64_t>> linspace_Tensor_Tensor_batch_rule(
   return linspace_logspace_batch_rule_helper(start, start_bdim, end, end_bdim, steps, std::nullopt, dtype, layout, device, pin_memory);
 }
 
-static std::tuple<Tensor,optional<int64_t>> linspace_Tensor_Scalar_batch_rule(
+static std::tuple<Tensor, std::optional<int64_t>> linspace_Tensor_Scalar_batch_rule(
     const at::Tensor& start, std::optional<int64_t> start_bdim,
     const at::Scalar& end,
     int64_t steps,

@@ -165,7 +165,7 @@ static std::tuple<Tensor,optional<int64_t>> linspace_Tensor_Scalar_batch_rule(
   return linspace_logspace_batch_rule_helper(start, start_bdim, end_t, std::nullopt, steps, std::nullopt, dtype, layout, device, pin_memory);
 }
 
-static std::tuple<Tensor,optional<int64_t>> linspace_Scalar_Tensor_batch_rule(
+static std::tuple<Tensor, std::optional<int64_t>> linspace_Scalar_Tensor_batch_rule(
     const at::Scalar& start,
     const at::Tensor& end, std::optional<int64_t> end_bdim,
     int64_t steps,

@@ -178,7 +178,7 @@ static std::tuple<Tensor,optional<int64_t>> linspace_Scalar_Tensor_batch_rule(
   return linspace_logspace_batch_rule_helper(start_t, std::nullopt, end, end_bdim, steps, std::nullopt, dtype, layout, device, pin_memory);
 }
 
-static std::tuple<Tensor,optional<int64_t>> logspace_Tensor_Tensor_batch_rule(
+static std::tuple<Tensor, std::optional<int64_t>> logspace_Tensor_Tensor_batch_rule(
     const at::Tensor& start, std::optional<int64_t> start_bdim,
     const at::Tensor& end, std::optional<int64_t> end_bdim,
     int64_t steps,

@@ -190,7 +190,7 @@ static std::tuple<Tensor,optional<int64_t>> logspace_Tensor_Tensor_batch_rule(
   return linspace_logspace_batch_rule_helper(start, start_bdim, end, end_bdim, steps, std::make_optional(base), dtype, layout, device, pin_memory);
 }
 
-static std::tuple<Tensor,optional<int64_t>> logspace_Tensor_Scalar_batch_rule(
+static std::tuple<Tensor, std::optional<int64_t>> logspace_Tensor_Scalar_batch_rule(
     const at::Tensor& start, std::optional<int64_t> start_bdim,
     const at::Scalar& end,
     int64_t steps,

@@ -204,7 +204,7 @@ static std::tuple<Tensor,optional<int64_t>> logspace_Tensor_Scalar_batch_rule(
   return linspace_logspace_batch_rule_helper(start, start_bdim, end_t, std::nullopt, steps, std::make_optional(base), dtype, layout, device, pin_memory);
 }
 
-static std::tuple<Tensor,optional<int64_t>> logspace_Scalar_Tensor_batch_rule(
+static std::tuple<Tensor, std::optional<int64_t>> logspace_Scalar_Tensor_batch_rule(
     const at::Scalar& start,
     const at::Tensor& end, std::optional<int64_t> end_bdim,
     int64_t steps,
@@ -33,7 +33,7 @@ TORCH_API Tensor reshape_dim_outof_symint(int64_t src, const c10::SymInt& size1,
 Tensor moveBatchDimToFront(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim);
 int64_t rankWithoutBatchDim(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim);
 int64_t numelWithoutBatchDim(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim);
-optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val);
+std::optional<int64_t> valIfNonempty(std::optional<int64_t> maybe_empty, int64_t new_val);
 int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim);
 VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims);
 

@@ -71,7 +71,7 @@ struct BasicUnaryBatchRuleHelper;
 
 template <typename F, F Func, typename A, typename... T>
 struct BasicUnaryBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
-  static std::tuple<Tensor,optional<int64_t>> apply(
+  static std::tuple<Tensor, std::optional<int64_t>> apply(
       const Tensor& tensor,
       std::optional<int64_t> batch_dim,
       T... extra_args) {

@@ -96,7 +96,7 @@ struct VariadicBdimsBatchRuleHelper;
 
 template <typename F, F Func, typename A, typename... T>
 struct VariadicBdimsBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
-  static std::tuple<Tensor,optional<int64_t>> apply(
+  static std::tuple<Tensor, std::optional<int64_t>> apply(
      const Tensor& tensor,
      std::optional<int64_t> batch_dim,
      T... extra_args) {

@@ -201,7 +201,7 @@ inline void handle_variadic_bdims(std::vector<std::pair<Tensor, std::optional<in
 #define VARIADIC_BDIMS_BOXED(op) \
   m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_variadic_bdims), &handle_variadic_bdims>>());
 
-using UnpackedBatchedTensor = std::tuple<Tensor,optional<int64_t>>;
+using UnpackedBatchedTensor = std::tuple<Tensor, std::optional<int64_t>>;
 
 inline void find_and_unpack_tensors(
     const torch::jit::Stack* stack,

@@ -384,7 +384,7 @@ struct ExistingBdimBatchRuleHelper;
 
 template <typename F, F Func, typename A, typename... T>
 struct ExistingBdimBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
-  static std::tuple<Tensor,optional<int64_t>> apply(
+  static std::tuple<Tensor, std::optional<int64_t>> apply(
      const Tensor& self,
      std::optional<int64_t> self_bdim,
      T... extra_args) {
@@ -27,7 +27,7 @@ static at::Tensor flatten_logical(const Tensor& tensor, std::optional<int64_t> b
 
 // Useful for many loss functions
 template <typename Func>
-static std::tuple<at::Tensor,optional<int64_t>>
+static std::tuple<at::Tensor, std::optional<int64_t>>
 loss_batch_rule_helper(const at::Tensor& self, std::optional<int64_t> self_bdim, const at::Tensor& target,
                        std::optional<int64_t> target_bdim, int64_t reduction,
                        Func loss_fn) {

@@ -49,7 +49,7 @@ loss_batch_rule_helper(const at::Tensor& self, std::optional<int64_t> self_bdim,
   TORCH_INTERNAL_ASSERT(false);
 };
 
-static std::tuple<at::Tensor,optional<int64_t>>
+static std::tuple<at::Tensor, std::optional<int64_t>>
 mse_loss_batch_rule(const at::Tensor& self, std::optional<int64_t> self_bdim, const at::Tensor& target,
                     std::optional<int64_t> target_bdim, int64_t reduction) {
   return loss_batch_rule_helper(self, self_bdim, target, target_bdim,

@@ -58,7 +58,7 @@ mse_loss_batch_rule(const at::Tensor& self, std::optional<int64_t> self_bdim, co
   });
 };
 
-static std::tuple<at::Tensor,optional<int64_t>>
+static std::tuple<at::Tensor, std::optional<int64_t>>
 huber_loss_batch_rule(const at::Tensor& self, std::optional<int64_t> self_bdim, const at::Tensor& target,
                       std::optional<int64_t> target_bdim, int64_t reduction, double delta) {
   return loss_batch_rule_helper(self, self_bdim, target, target_bdim,

@@ -67,7 +67,7 @@ huber_loss_batch_rule(const at::Tensor& self, std::optional<int64_t> self_bdim,
   });
 };
 
-static std::tuple<at::Tensor,optional<int64_t>>
+static std::tuple<at::Tensor, std::optional<int64_t>>
 smooth_l1_loss_batch_rule(const at::Tensor& self, std::optional<int64_t> self_bdim, const at::Tensor& target,
                           std::optional<int64_t> target_bdim, int64_t reduction, double beta) {
   return loss_batch_rule_helper(self, self_bdim, target, target_bdim,
@@ -20,7 +20,7 @@ static Tensor getStepTensor(const Tensor& indices, const c10::SymInt& bdim_size,
   return range.view_symint(view_shape);
 }
 
-static std::tuple<Tensor,optional<int64_t>> embedding_batch_rule(
+static std::tuple<Tensor, std::optional<int64_t>> embedding_batch_rule(
     const Tensor& weight, std::optional<int64_t> weight_bdim,
     const Tensor& indices, std::optional<int64_t> indices_bdim,
     c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) {

@@ -50,7 +50,7 @@ static std::tuple<Tensor,optional<int64_t>> embedding_batch_rule(
   return std::make_tuple(std::move(result), 0);
 }
 
-static std::tuple<Tensor,optional<int64_t>>
+static std::tuple<Tensor, std::optional<int64_t>>
 embedding_dense_backward_batch_rule(
     const Tensor& grad_, std::optional<int64_t> grad_bdim,
     const Tensor& indices_, std::optional<int64_t> indices_bdim,

@@ -109,7 +109,7 @@ embedding_dense_backward_batch_rule(
  * output: (BN)CD_{out}H_{out}W_{out}
  */
 template<typename F, F Func, typename... ExtraArgs>
-std::tuple<Tensor,optional<int64_t>>
+std::tuple<Tensor, std::optional<int64_t>>
 grid_sample_batch_rule(const Tensor& input, std::optional<int64_t> input_bdim, const Tensor& grid, std::optional<int64_t> grid_bdim, ExtraArgs... extra_args) {
   std::tuple<Tensor, std::optional<int64_t>> result;
   if (input_bdim && !grid_bdim) {

@@ -256,7 +256,7 @@ struct UpsampleBackwardBatchRuleHelper;
 
 template <typename F, F Func, typename A, typename B, typename C, typename... T>
 struct UpsampleBackwardBatchRuleHelper<F, Func, typelist<A, B, C, T...>> {
-  static std::tuple<Tensor,optional<int64_t>> apply(
+  static std::tuple<Tensor, std::optional<int64_t>> apply(
       const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
       c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size,
       T... extra_args) {

@@ -282,7 +282,7 @@ struct GridSampleBatchRuleHelper;
 
 template <typename F, F Func, typename T1, typename T2, typename... T>
 struct GridSampleBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
-  static std::tuple<Tensor,optional<int64_t>> apply(
+  static std::tuple<Tensor, std::optional<int64_t>> apply(
      const Tensor& input, std::optional<int64_t> input_batch_dim,
      const Tensor& grid, std::optional<int64_t> grid_batch_dim,
      T... extra_args) {
@@ -42,7 +42,7 @@ static Tensor padRight(const Tensor& tensor, std::optional<int64_t> has_bdim, in
 }
 
 template<typename F, F Func>
-std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>,Tensor,optional<int64_t>>
+std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>>
 batch_norm_batch_rule(
     const Tensor& input, std::optional<int64_t> input_bdim,
     const std::optional<Tensor>& weight_opt, std::optional<int64_t> weight_bdim,

@@ -124,7 +124,7 @@ batch_norm_batch_rule(
 }
 
 template<typename F, F Func>
-std::tuple<at::Tensor,optional<int64_t>> batch_norm_backward_no_weight_bias_batch_rule(
+std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bias_batch_rule(
     const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
     const at::Tensor & input, std::optional<int64_t> input_bdim,
     const std::optional<at::Tensor> & running_mean_opt, std::optional<int64_t> running_mean_bdim,

@@ -337,7 +337,7 @@ static std::tuple<Tensor,Tensor,Tensor> native_group_norm_plumbing(
   return std::make_tuple(result0, mean, rstd);
 }
 
-static std::tuple<at::Tensor,optional<int64_t>> group_norm_backward_no_weight_bias_batch_rule(
+static std::tuple<at::Tensor, std::optional<int64_t>> group_norm_backward_no_weight_bias_batch_rule(
     const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
     const at::Tensor & input, std::optional<int64_t> input_bdim,
     const at::Tensor & mean, std::optional<int64_t> mean_bdim,

@@ -484,7 +484,7 @@ C10_ALWAYS_INLINE void _check_layer_norm_inputs(
   check_same_shape(bias, bias_bdim, normalized_shape, "weight");
 }
 
-static std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>,Tensor,optional<int64_t>>
+static std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>>
 native_layer_norm_batch_rule(
     const Tensor& input, std::optional<int64_t> input_bdim,
     c10::SymIntArrayRef normalized_shape,

@@ -530,7 +530,7 @@ native_layer_norm_batch_rule(
   return std::make_tuple(result0, 0, mean, stats_bdim, rstd, stats_bdim);
 }
 
-static std::tuple<at::Tensor,optional<int64_t>> native_layer_norm_backward_no_weight_bias_batch_rule(
+static std::tuple<at::Tensor, std::optional<int64_t>> native_layer_norm_backward_no_weight_bias_batch_rule(
     const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
     const at::Tensor & input, std::optional<int64_t> input_bdim,
     at::IntArrayRef normalized_shape,

@@ -651,7 +651,7 @@ static std::tuple<at::Tensor,at::Tensor,at::Tensor> native_layer_norm_backward_p
 
 template <typename F, F Func>
 struct NativeBatchNormBatchRuleHelper {
-  static std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>,Tensor,optional<int64_t>> apply(
+  static std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>> apply(
     const Tensor& input, std::optional<int64_t> input_bdim,
     const std::optional<Tensor>& weight_opt, std::optional<int64_t> weight_bdim,
     const std::optional<Tensor>& bias_opt, std::optional<int64_t> bias_bdim,

@@ -666,7 +666,7 @@ struct NativeBatchNormBatchRuleHelper {
 
 template <typename F, F Func>
 struct CudnnBatchNormBatchRuleHelper {
-  static std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>,Tensor,optional<int64_t>,Tensor,optional<int64_t>> apply(
+  static std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>> apply(
     const Tensor& input, std::optional<int64_t> input_bdim,
     const Tensor& weight_opt, std::optional<int64_t> weight_bdim,
     const std::optional<Tensor>& bias_opt, std::optional<int64_t> bias_bdim,

@@ -683,7 +683,7 @@ struct CudnnBatchNormBatchRuleHelper {
 
 template <typename F, F Func>
 struct MiopenBatchNormBatchRuleHelper {
-  static std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>,Tensor,optional<int64_t>> apply(
+  static std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>> apply(
     const Tensor& input, std::optional<int64_t> input_bdim,
     const Tensor& weight_opt, std::optional<int64_t> weight_bdim,
     const std::optional<Tensor>& bias_opt, std::optional<int64_t> bias_bdim,
@@ -12,7 +12,7 @@
 namespace at::functorch {
 
 template <typename Func>
-std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
+std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>>
 max_pool_with_indices_batch_rule_helper(
     const Tensor& self, std::optional<int64_t> self_bdim,
     IntArrayRef kernel_size, IntArrayRef stride,

@@ -37,7 +37,7 @@ max_pool_with_indices_batch_rule_helper(
       reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0);
 }
 
-static std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
+static std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>>
 max_pool3d_with_indices_batch_rule(
     const Tensor& self, std::optional<int64_t> self_bdim,
     IntArrayRef kernel_size, IntArrayRef stride,

@@ -45,7 +45,7 @@ max_pool3d_with_indices_batch_rule(
   return max_pool_with_indices_batch_rule_helper(self, self_bdim, kernel_size, stride, padding, dilation, ceil_mode, 3, at::max_pool3d_with_indices);
 }
 
-static std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
+static std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>>
 max_pool2d_with_indices_batch_rule(
     const Tensor& self, std::optional<int64_t> self_bdim,
     IntArrayRef kernel_size, IntArrayRef stride,
@@ -256,7 +256,7 @@ static std::tuple<Tensor, Tensor> expand_bdims(
       b_has_bdim ? b : b.expand_as(flagpole));
 }
 
-static std::tuple<Tensor,optional<int64_t>> _softmax_backward_batch_rule(
+static std::tuple<Tensor, std::optional<int64_t>> _softmax_backward_batch_rule(
     const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
     const Tensor& output, std::optional<int64_t> output_bdim,
     int64_t dim,

@@ -286,7 +286,7 @@ static std::tuple<Tensor,optional<int64_t>> _softmax_backward_batch_rule(
   return std::make_tuple(at::_softmax_backward_data(grad_output_, output_.contiguous(), dim, input_dtype), 0);
 }
 
-static std::tuple<Tensor,optional<int64_t>> _log_softmax_backward_batch_rule(
+static std::tuple<Tensor, std::optional<int64_t>> _log_softmax_backward_batch_rule(
     const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
     const Tensor& output, std::optional<int64_t> output_bdim,
     int64_t dim,

@@ -314,7 +314,7 @@ static std::tuple<Tensor,optional<int64_t>> _log_softmax_backward_batch_rule(
   return std::make_tuple(at::_log_softmax_backward_data(grad_output_, output_, dim, input_dtype), 0);
 }
 
-static std::tuple<Tensor,optional<int64_t>> searchsorted_batch_rule(
+static std::tuple<Tensor, std::optional<int64_t>> searchsorted_batch_rule(
     const Tensor& sorted_sequence,
     std::optional<int64_t> sorted_sequence_bdim,
     const Tensor& self,
@@ -17,7 +17,7 @@
 namespace at::functorch {
 
 namespace {
-static bool any_has_value(ArrayRef<optional<int64_t>> bdims) {
+static bool any_has_value(ArrayRef<std::optional<int64_t>> bdims) {
   for (const auto& bdim : bdims) {
     if (bdim.has_value()) {
       return true;

@@ -26,7 +26,7 @@ static bool any_has_value(ArrayRef<optional<int64_t>> bdims) {
   return false;
 }
 
-static int64_t get_num_leading_nones(ArrayRef<optional<Tensor>> indices) {
+static int64_t get_num_leading_nones(ArrayRef<std::optional<Tensor>> indices) {
   int64_t result = 0;
   for (const auto& idx : indices) {
     if (!idx.has_value() || !idx->defined()) {

@@ -39,8 +39,8 @@ static int64_t get_num_leading_nones(ArrayRef<optional<Tensor>> indices) {
 }
 
 static int64_t get_max_index_logical_dim(
-    ArrayRef<optional<Tensor>> indices,
-    ArrayRef<optional<int64_t>> indices_bdims) {
+    ArrayRef<std::optional<Tensor>> indices,
+    ArrayRef<std::optional<int64_t>> indices_bdims) {
   int64_t max_logical_dim = -1;
   TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
   TORCH_INTERNAL_ASSERT(!indices.empty());

@@ -55,9 +55,9 @@ static int64_t get_max_index_logical_dim(
   return max_logical_dim;
 }
 
-static std::vector<optional<Tensor>> batchIndices(
-  ArrayRef<optional<Tensor>> indices,
-  ArrayRef<optional<int64_t>> indices_bdims,
+static std::vector<std::optional<Tensor>> batchIndices(
+  ArrayRef<std::optional<Tensor>> indices,
+  ArrayRef<std::optional<int64_t>> indices_bdims,
   int64_t batch_size,
   std::optional<int64_t> self_bdim,
   std::optional<int64_t> values_bdim = std::nullopt) {
@@ -82,7 +82,7 @@ static std::vector<optional<Tensor>> batchIndices(
   // There is one more case worth mentioning - boolean tensor indices. If we
   // have "batched" boolean tensor indices, that is unrepresentable, as each
   // batch would result in a tensor with different values.
-  std::vector<optional<Tensor>> indices_;
+  std::vector<std::optional<Tensor>> indices_;
 
   int64_t maxLogicalRank = get_max_index_logical_dim(indices, indices_bdims);
   bool indices_batched = any_has_value(indices_bdims);

@@ -133,7 +133,7 @@ static bool is_advanced_index(const std::optional<Tensor>& idx) {
 }
 
 // See NOTE: [advanced indices adjacent] for definition
-static bool are_advanced_indices_adjacent(ArrayRef<optional<Tensor>> indices) {
+static bool are_advanced_indices_adjacent(ArrayRef<std::optional<Tensor>> indices) {
   int64_t num_advanced_indices_regions = 0;
   bool in_advanced_indices_region = false;
   for (const auto& idx : indices) {

@@ -171,11 +171,11 @@ static Tensor swap_regions(const Tensor& tensor, int64_t first_region_size, int6
   return tensor.permute(permutation);
 }
 
-std::tuple<Tensor,optional<int64_t>> index_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> index_batch_rule(
     const Tensor& self,
     std::optional<int64_t> self_bdim,
-    ArrayRef<optional<Tensor>> indices,
-    ArrayRef<optional<int64_t>> indices_bdims) {
+    ArrayRef<std::optional<Tensor>> indices,
+    ArrayRef<std::optional<int64_t>> indices_bdims) {
 
   // NOTE: [advanced indexing (index.Tensor) batch rule]
   //

@@ -240,7 +240,7 @@ std::tuple<Tensor,optional<int64_t>> index_batch_rule(
   auto max_index_dim = get_max_index_logical_dim(indices, indices_bdims);
 
   // Step 2
-  auto res = at::index(self_, List<optional<Tensor>>(batched_indices));
+  auto res = at::index(self_, List<std::optional<Tensor>>(batched_indices));
 
   // Step 3: There are three cases (these match the cases outlined in batchIndices)
   bool self_batched = self_bdim.has_value();
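at::index receives its indices as a List<std::optional<Tensor>>: a std::nullopt entry plays the role of Python's None (keep the whole dimension), while a Tensor entry is an advanced index. A small standalone usage sketch, separate from the batching machinery in this file:

#include <optional>

#include <ATen/ATen.h>
#include <ATen/core/List.h>

int main() {
  at::Tensor t = at::arange(24).reshape({2, 3, 4});

  // Equivalent of t[:, [0, 1]] in Python: a None-like entry for dim 0,
  // an advanced integer index for dim 1.
  c10::List<std::optional<at::Tensor>> indices;
  indices.push_back(std::optional<at::Tensor>());               // "None"
  indices.push_back(std::optional<at::Tensor>(at::arange(2)));  // index [0, 1]

  at::Tensor out = at::index(t, indices);  // shape [2, 2, 4]
  return 0;
}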
@@ -315,8 +315,8 @@ std::tuple<Tensor,optional<int64_t>> index_batch_rule(
   return std::make_tuple(swap_regions(res, max_index_dim, num_leading_nones), 0);
 }
 
-// plumbing done since we don't support List<optional<Tensor>> in codegen
-Tensor index_plumbing(const Tensor & self, const List<optional<Tensor>> & indices
+// plumbing done since we don't support List<std::optional<Tensor>> in codegen
+Tensor index_plumbing(const Tensor & self, const List<std::optional<Tensor>> & indices
 ) {
   c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
   auto maybe_layer = maybeCurrentDynamicLayer();

@@ -326,8 +326,8 @@ Tensor index_plumbing(const Tensor & self, const List<optional<Tensor>> & indice
     return at::index(self, indices);
   }
   auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
-  std::vector<optional<Tensor>> indices_value;
-  std::vector<optional<int64_t>> indices_bdims;
+  std::vector<std::optional<Tensor>> indices_value;
+  std::vector<std::optional<int64_t>> indices_bdims;
   for (const auto&& indRef : indices) {
     std::optional<Tensor> ind = indRef;
     std::optional<Tensor> index;

@@ -399,11 +399,11 @@ namespace {
     return compute_indexed_shape(self, indices);
   }
 
-  std::tuple<Tensor, std::vector<optional<Tensor>>, Tensor>
+  std::tuple<Tensor, std::vector<std::optional<Tensor>>, Tensor>
   index_put_batch_rule_helper(const Tensor &self,
                               std::optional<int64_t> self_bdim,
-                              ArrayRef<optional<Tensor>> indices,
-                              ArrayRef<optional<int64_t>> indices_bdims,
+                              ArrayRef<std::optional<Tensor>> indices,
+                              ArrayRef<std::optional<int64_t>> indices_bdims,
                               const Tensor &values,
                               std::optional<int64_t> values_bdim,
                               std::optional<int64_t> opt_batch_size = {}) {

@@ -420,7 +420,7 @@ namespace {
     // we've already made sure that self has bdim at 0.
     const auto indices_ = batchIndices(indices, indices_bdims, batch_size, /*self_bdim=*/0, values_bdim);
 
-    auto indexed_shape = get_indexed_shape(self_, List<optional<Tensor>>(indices_));
+    auto indexed_shape = get_indexed_shape(self_, List<std::optional<Tensor>>(indices_));
 
     // handle broadcasting support for values
     // Eg. Given `indexed_shape.size()` is 5 and

@@ -452,12 +452,12 @@ namespace {
   }
 
   auto unpackSelfAndIndicesAndValuesAtCurrentLevel(const Tensor &self,
-                                                   const List<optional<Tensor>> &indices,
+                                                   const List<std::optional<Tensor>> &indices,
                                                    const Tensor &values, int64_t cur_level)
   {
     auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
-    std::vector<optional<Tensor>> indices_value;
-    std::vector<optional<int64_t>> indices_bdims;
+    std::vector<std::optional<Tensor>> indices_value;
+    std::vector<std::optional<int64_t>> indices_bdims;
     for (const auto &&indRef : indices)
     {
       std::optional<Tensor> ind = indRef;

@@ -478,8 +478,8 @@ namespace {
 void index_put__batch_rule(
     const Tensor& self,
     std::optional<int64_t> self_bdim,
-    ArrayRef<optional<Tensor>> indices,
-    ArrayRef<optional<int64_t>> indices_bdims,
+    ArrayRef<std::optional<Tensor>> indices,
+    ArrayRef<std::optional<int64_t>> indices_bdims,
     const Tensor& values,
     std::optional<int64_t> values_bdim,
     bool accumulate) {

@@ -488,11 +488,11 @@ void index_put__batch_rule(
   }
   auto [self_, indices_, values_] = index_put_batch_rule_helper(
       self, self_bdim, indices, indices_bdims, values, values_bdim);
-  at::index_put_(self_, List<optional<Tensor>>(indices_), values_, accumulate);
+  at::index_put_(self_, List<std::optional<Tensor>>(indices_), values_, accumulate);
 }
 
-// plumbing done since we don't support List<optional<Tensor>> in codegen
-Tensor& index_put__plumbing(Tensor & self, const List<optional<Tensor>> & indices
+// plumbing done since we don't support List<std::optional<Tensor>> in codegen
+Tensor& index_put__plumbing(Tensor & self, const List<std::optional<Tensor>> & indices
 , const Tensor & values, bool accumulate) {
   c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
   auto maybe_layer = maybeCurrentDynamicLayer();

@@ -517,8 +517,8 @@ Tensor& index_put__plumbing(Tensor & self, const List<optional<Tensor>> & indice
 void _index_put_impl__batch_rule(
     const Tensor& self,
     std::optional<int64_t> self_bdim,
-    ArrayRef<optional<Tensor>> indices,
-    ArrayRef<optional<int64_t>> indices_bdims,
+    ArrayRef<std::optional<Tensor>> indices,
+    ArrayRef<std::optional<int64_t>> indices_bdims,
     const Tensor& values,
     std::optional<int64_t> values_bdim,
     bool accumulate,

@@ -528,11 +528,11 @@ void _index_put_impl__batch_rule(
   }
   auto [self_, indices_, values_] = index_put_batch_rule_helper(
       self, self_bdim, indices, indices_bdims, values, values_bdim);
-  at::_index_put_impl_(self_, List<optional<Tensor>>(indices_), values_, accumulate, unsafe);
+  at::_index_put_impl_(self_, List<std::optional<Tensor>>(indices_), values_, accumulate, unsafe);
 }
 
-// plumbing done since we don't support List<optional<Tensor>> in codegen
-Tensor &_index_put_impl__plumbing(Tensor &self, const List<optional<Tensor>> &indices,
+// plumbing done since we don't support List<std::optional<Tensor>> in codegen
+Tensor &_index_put_impl__plumbing(Tensor &self, const List<std::optional<Tensor>> &indices,
                                   const Tensor &values, bool accumulate, bool unsafe) {
   c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
   auto maybe_layer = maybeCurrentDynamicLayer();

@@ -549,8 +549,8 @@ Tensor &_index_put_impl__plumbing(Tensor &self, const List<optional<Tensor>> &in
 
 static Tensor maybe_permute_values(
     const Tensor& values,
-    ArrayRef<optional<Tensor>> orig_indices,
-    ArrayRef<optional<int64_t>> orig_indices_bdims) {
+    ArrayRef<std::optional<Tensor>> orig_indices,
+    ArrayRef<std::optional<int64_t>> orig_indices_bdims) {
   bool indices_batched = any_has_value(orig_indices_bdims);
   bool advanced_indices_are_adjacent = are_advanced_indices_adjacent(orig_indices);
   auto num_leading_nones = get_num_leading_nones(orig_indices);

@@ -602,11 +602,11 @@ static Tensor maybe_permute_values(
   return swap_regions(values, num_leading_nones, max_index_dim);
 }
 
-std::tuple<Tensor,optional<int64_t>> index_put_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> index_put_batch_rule(
     const Tensor& self,
     std::optional<int64_t> self_bdim,
-    ArrayRef<optional<Tensor>> indices,
-    ArrayRef<optional<int64_t>> indices_bdims,
+    ArrayRef<std::optional<Tensor>> indices,
+    ArrayRef<std::optional<int64_t>> indices_bdims,
     const Tensor& values,
     std::optional<int64_t> values_bdim,
bool accumulate) {
|
bool accumulate) {
|
||||||
@ -641,12 +641,12 @@ std::tuple<Tensor,optional<int64_t>> index_put_batch_rule(
|
|||||||
// and the batched `indices_` might change the "have adjacent advanced indices" property
|
// and the batched `indices_` might change the "have adjacent advanced indices" property
|
||||||
values_ = maybe_permute_values(values_, indices, indices_bdims);
|
values_ = maybe_permute_values(values_, indices, indices_bdims);
|
||||||
|
|
||||||
auto result = at::index_put(self_, List<optional<Tensor>>(indices_), values_, accumulate);
|
auto result = at::index_put(self_, List<std::optional<Tensor>>(indices_), values_, accumulate);
|
||||||
return std::make_tuple(result, 0);
|
return std::make_tuple(result, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// plumbing done since we don't support List<optional<Tensor>> in codegen
|
// plumbing done since we don't support List<std::optional<Tensor>> in codegen
|
||||||
Tensor index_put_plumbing(const Tensor & self, const List<optional<Tensor>> & indices,
|
Tensor index_put_plumbing(const Tensor & self, const List<std::optional<Tensor>> & indices,
|
||||||
const Tensor & values, bool accumulate) {
|
const Tensor & values, bool accumulate) {
|
||||||
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
|
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
|
||||||
auto maybe_layer = maybeCurrentDynamicLayer();
|
auto maybe_layer = maybeCurrentDynamicLayer();
|
||||||
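[Editor's illustrative sketch, not part of this commit: the batch rules above all funnel into the public at::index_put_ overload that takes a c10::List<std::optional<Tensor>> of indices. A minimal caller, assuming only that overload and the usual ATen factory functions, might look like this:]

#include <ATen/ATen.h>
#include <ATen/core/List.h>
#include <optional>

void index_put_sketch() {
  at::Tensor self = at::zeros({4, 3});
  at::Tensor idx = at::arange(0, 4, 2);      // rows 0 and 2
  at::Tensor values = at::ones({2, 3});

  // One optional per indexed dimension; an empty optional means "leave this
  // dimension alone", mirroring Python's `self[idx, :] = values`.
  c10::List<std::optional<at::Tensor>> indices;
  indices.push_back(std::optional<at::Tensor>(idx));
  indices.push_back(std::optional<at::Tensor>());

  at::index_put_(self, indices, values, /*accumulate=*/false);
}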
@ -671,7 +671,7 @@ Tensor index_put_plumbing(const Tensor & self, const List<optional<Tensor>> & in
namespace {

template<typename Func, typename ...Args>
-std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> scatter_batch_rule(
Func f,
const Tensor& self, std::optional<int64_t> self_bdim,
int64_t dim,
@ -703,7 +703,7 @@ std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
}

template <typename Func, typename ...Args>
-inline std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
+inline std::tuple<Tensor, std::optional<int64_t>> scatter_batch_rule(
Func f,
const Tensor& self, std::optional<int64_t> self_bdim,
int64_t dim,
@ -742,7 +742,7 @@ inline std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(

} // namespace

-std::tuple<Tensor,optional<int64_t>> scatter_value_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> scatter_value_batch_rule(
const Tensor& self, std::optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, std::optional<int64_t> index_bdim,
@ -751,7 +751,7 @@ std::tuple<Tensor,optional<int64_t>> scatter_value_batch_rule(
self, self_bdim, dim, index, index_bdim, value);
}

-std::tuple<Tensor,optional<int64_t>> scatter_src_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> scatter_src_batch_rule(
const Tensor& self, std::optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, std::optional<int64_t> index_bdim,
@ -760,7 +760,7 @@ std::tuple<Tensor,optional<int64_t>> scatter_src_batch_rule(
self, self_bdim, dim, index, index_bdim, src, src_bdim);
}

-std::tuple<Tensor,optional<int64_t>> scatter_add_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> scatter_add_batch_rule(
const Tensor& self, std::optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, std::optional<int64_t> index_bdim,
@ -769,7 +769,7 @@ std::tuple<Tensor,optional<int64_t>> scatter_add_batch_rule(
self, self_bdim, dim, index, index_bdim, src, src_bdim);
}

-std::tuple<Tensor,optional<int64_t>> scatter_reduce_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> scatter_reduce_batch_rule(
const Tensor& self, std::optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, std::optional<int64_t> index_bdim,
@ -779,7 +779,7 @@ std::tuple<Tensor,optional<int64_t>> scatter_reduce_batch_rule(
self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce);
}

-std::tuple<Tensor,optional<int64_t>> scatter_value_reduce_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> scatter_value_reduce_batch_rule(
const Tensor& self, std::optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, std::optional<int64_t> index_bdim,
@ -789,7 +789,7 @@ std::tuple<Tensor,optional<int64_t>> scatter_value_reduce_batch_rule(
self, self_bdim, dim, index, index_bdim, src, reduce);
}

-std::tuple<Tensor,optional<int64_t>> gather_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> gather_batch_rule(
const Tensor& self, std::optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, std::optional<int64_t> index_bdim,
@ -922,7 +922,7 @@ std::tuple<Tensor, std::optional<int64_t>> diagonal_scatter_batch_rule(
return std::make_tuple(at::diagonal_scatter(self_, src_, offset, dim1, dim2), 0);
}

-std::tuple<Tensor,optional<int64_t>> index_add_batch_rule_impl(
+std::tuple<Tensor, std::optional<int64_t>> index_add_batch_rule_impl(
Tensor& self, std::optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, std::optional<int64_t> index_bdim,
@ -1004,7 +1004,7 @@ void index_add__batch_rule(
other_bdim, alpha, true);
}

-std::tuple<Tensor,optional<int64_t>> index_add_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> index_add_batch_rule(
Tensor& self, std::optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, std::optional<int64_t> index_bdim,
@ -1038,7 +1038,7 @@ static std::tuple<Tensor,Tensor> binary_pointwise_align(
return std::make_tuple(tensor_, other_);
}

-std::tuple<Tensor,optional<int64_t>> masked_fill_scalar_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> masked_fill_scalar_batch_rule(
const Tensor & self,
std::optional<int64_t> self_bdim,
const Tensor & mask,
@ -1049,7 +1049,7 @@ std::tuple<Tensor,optional<int64_t>> masked_fill_scalar_batch_rule(
return std::make_tuple(result, 0);
}

-std::tuple<Tensor,optional<int64_t>> index_fill_batch_rule_helper(
+std::tuple<Tensor, std::optional<int64_t>> index_fill_batch_rule_helper(
int64_t batch_size,
int64_t self_logical_rank,
int64_t index_logical_rank,
@ -1085,7 +1085,7 @@ std::tuple<Tensor,optional<int64_t>> index_fill_batch_rule_helper(
return std::make_tuple(self_, 0);
}

-std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule_impl(
+std::tuple<Tensor, std::optional<int64_t>> index_fill_int_scalar_batch_rule_impl(
Tensor & self, std::optional<int64_t> self_bdim,
int64_t dim,
const Tensor & index, std::optional<int64_t> index_bdim,
@ -1136,7 +1136,7 @@ std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule_impl(
return index_fill_batch_rule_helper(batch_size, self_logical_rank, index_logical_rank, self_, dim, index_, value);
}

-std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule_impl(
+std::tuple<Tensor, std::optional<int64_t>> index_fill_int_tensor_batch_rule_impl(
Tensor & self, std::optional<int64_t> self_bdim,
int64_t dim,
const Tensor & index, std::optional<int64_t> index_bdim,
@ -1207,7 +1207,7 @@ void index_fill__int_tensor_batch_rule(
index_fill_int_tensor_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, value_bdim, true);
}

-std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> index_fill_int_scalar_batch_rule(
const Tensor & self, std::optional<int64_t> self_bdim,
int64_t dim,
const Tensor & index, std::optional<int64_t> index_bdim,
@ -1216,7 +1216,7 @@ std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule(
return index_fill_int_scalar_batch_rule_impl(self_, self_bdim, dim, index, index_bdim, value, false);
}

-std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> index_fill_int_tensor_batch_rule(
const Tensor & self, std::optional<int64_t> self_bdim,
int64_t dim,
const Tensor & index, std::optional<int64_t> index_bdim,

@ -10,7 +10,7 @@
namespace at::functorch {

namespace{
-std::tuple<Tensor,optional<int64_t>>
+std::tuple<Tensor, std::optional<int64_t>>
clone_batch_rule(
const Tensor& self,
std::optional<int64_t> self_bdim,
@ -48,7 +48,7 @@ clone_batch_rule(
return std::make_tuple(result, self_bdim);
}

-std::tuple<Tensor,optional<int64_t>>
+std::tuple<Tensor, std::optional<int64_t>>
view_as_complex_batch_rule(const Tensor& self, std::optional<int64_t> self_bdim) {
// guard against the user passing in a batch of scalar tensors with batch
// size equal to 2.

@ -36,7 +36,7 @@ namespace at::functorch {
// `Tensor sum(const Tensor& self, int64_t dim)`. The signature of the
// batch rule has an additional std::optional<int64_t> argument after each
// Tensor argument and return. So, in this case, the batch rule has signature
-// tuple<Tensor,optional<int64_t>> sum_batch_rule(
+// tuple<Tensor, std::optional<int64_t>> sum_batch_rule(
// const Tensor& self, std::optional<int64_t> self_bdim, int64_t dim);
//
// The vmap call above invokes the batch rule with `self = tensor`,
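[Editor's illustrative sketch, not part of this commit: the comment above describes the batch-rule calling convention. A hedged example of what such a rule could look like follows; the helpers moveBatchDimToFront and rankWithoutBatchDim are the ones used elsewhere in these files, but the body is only a sketch of the convention, not the actual rule:]

std::tuple<Tensor, std::optional<int64_t>> sum_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim, int64_t dim) {
  // Put the vmapped dimension (if any) in front, then shift the logical dim.
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto wrapped_dim = maybe_wrap_dim(dim, logical_rank);
  auto result = self_.sum(wrapped_dim + (self_bdim.has_value() ? 1 : 0));
  // The output carries its batch dimension at position 0 iff the input had one.
  std::optional<int64_t> out_bdim =
      self_bdim.has_value() ? std::optional<int64_t>(0) : std::nullopt;
  return std::make_tuple(result, out_bdim);
}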
@ -90,7 +90,7 @@ namespace at::functorch {

namespace{

-std::tuple<Tensor,optional<int64_t>> unsqueeze_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> unsqueeze_batch_rule(
const Tensor& self,
std::optional<int64_t> self_bdim,
int64_t dim) {
@ -101,7 +101,7 @@ std::tuple<Tensor,optional<int64_t>> unsqueeze_batch_rule(
}

// NB: repeat is not actually a view, but it is in this file
-std::tuple<Tensor,optional<int64_t>> repeat_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> repeat_batch_rule(
const Tensor& self,
std::optional<int64_t> self_bdim,
c10::SymIntArrayRef sizes) {
@ -116,7 +116,7 @@ std::tuple<Tensor,optional<int64_t>> repeat_batch_rule(
}


-std::tuple<Tensor,optional<int64_t>> _unsafe_view_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> _unsafe_view_batch_rule(
const Tensor& self,
std::optional<int64_t> self_bdim,
c10::SymIntArrayRef size) {
@ -137,7 +137,7 @@ std::tuple<Tensor,optional<int64_t>> _unsafe_view_batch_rule(
return std::make_tuple(at::_unsafe_view_symint(self_, view_size), 0);
}

-std::tuple<Tensor,optional<int64_t>> flip_batch_rule(const Tensor& self, std::optional<int64_t> self_bdim, IntArrayRef dims) {
+std::tuple<Tensor, std::optional<int64_t>> flip_batch_rule(const Tensor& self, std::optional<int64_t> self_bdim, IntArrayRef dims) {
auto self_ = moveBatchDimToFront(self, self_bdim);
VmapDimVector new_dims;
for (auto i: dims) {
@ -317,7 +317,7 @@ std::tuple<Tensor, std::optional<int64_t>> diagonal_batching_rule(
return std::make_tuple(std::move(result), 0);
}

-std::tuple<Tensor,optional<int64_t>> diagonal_backward_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> diagonal_backward_batch_rule(
const Tensor& grad_input, std::optional<int64_t> grad_input_bdim,
c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
@ -331,7 +331,7 @@ std::tuple<Tensor,optional<int64_t>> diagonal_backward_batch_rule(
return std::make_tuple(std::move(result), 0);
}

-std::tuple<Tensor,optional<int64_t>> slice_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> slice_batch_rule(
const Tensor& self,
std::optional<int64_t> self_bdim,
int64_t dim,
@ -349,7 +349,7 @@ static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
return dim == 0 || dim == -1;
}

-std::tuple<Tensor,optional<int64_t>>
+std::tuple<Tensor, std::optional<int64_t>>
transpose_int_batch_rule(
const Tensor& self,
std::optional<int64_t> self_bdim,
@ -389,7 +389,7 @@ std::tuple<Tensor, std::optional<int64_t>> permute_batching_rule(
return std::make_tuple(self_.permute(dims_), 0);
}

-std::tuple<Tensor,optional<int64_t>> select_backward_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> select_backward_batch_rule(
const Tensor& grad_input, std::optional<int64_t> grad_input_bdim,
c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) {
auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
@ -402,7 +402,7 @@ std::tuple<Tensor,optional<int64_t>> select_backward_batch_rule(
return std::make_tuple(std::move(result), 0);
}

-std::tuple<Tensor,optional<int64_t>> slice_backward_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> slice_backward_batch_rule(
const Tensor& grad_input, std::optional<int64_t> grad_input_bdim,
SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step) {
auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
@ -427,7 +427,7 @@ std::tuple<Tensor, std::optional<int64_t>> view_batching_rule(
return std::make_tuple(self_.view_symint(size_), 0);
}

-std::tuple<Tensor,optional<int64_t>> view_copy_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> view_copy_batch_rule(
const Tensor& self,
std::optional<int64_t> self_bdim,
c10::SymIntArrayRef size) {
@ -530,7 +530,7 @@ Tensor trace_decomp(const Tensor& tensor) {
return tensor.diagonal().sum();
}

-std::tuple<Tensor,optional<int64_t>> tril_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> tril_batch_rule(
const Tensor& self,
std::optional<int64_t> self_bdim,
int64_t diagonal = 0) {
@ -540,7 +540,7 @@ std::tuple<Tensor,optional<int64_t>> tril_batch_rule(
return std::make_tuple(std::move(result), 0);
}

-std::tuple<Tensor,optional<int64_t>> triu_batch_rule(
+std::tuple<Tensor, std::optional<int64_t>> triu_batch_rule(
const Tensor& self,
std::optional<int64_t> self_bdim,
int64_t diagonal = 0) {

@ -53,7 +53,7 @@ struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor&, std::optional
using type = Tail;
};
template <class Next, class Tail>
-struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<optional<Tensor>, std::optional<int64_t>, Next, Tail> {
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<std::optional<Tensor>, std::optional<int64_t>, Next, Tail> {
using type = Tail;
};
template <class Next, class Tail>
@ -61,7 +61,7 @@ struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<const std::optional<Te
using type = Tail;
};
template <class Next, class Tail>
-struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<optional<Tensor>&, std::optional<int64_t>, Next, Tail> {
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<std::optional<Tensor>&, std::optional<int64_t>, Next, Tail> {
using type = Tail;
};
template <class Next, class Tail>

@ -175,7 +175,7 @@ const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level) {
return dynamic_layer.interpreter().is_alive_ptr();
}

-optional<DynamicLayer> maybeCurrentDynamicLayer() {
+std::optional<DynamicLayer> maybeCurrentDynamicLayer() {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
if (dynamicLayerStack.empty()) {
return {};

@ -82,7 +82,7 @@ bool isBatchedAtLevel(const c10::List<std::optional<Tensor>>& maybe_tensors, int
return false;
}

-bool areAnyBatchedAtLevel(ArrayRef<optional<Tensor>> maybe_tensors, int64_t level) {
+bool areAnyBatchedAtLevel(ArrayRef<std::optional<Tensor>> maybe_tensors, int64_t level) {
for (const auto& maybe_tensor : maybe_tensors) {
if (isBatchedAtLevel(maybe_tensor, level)) {
return true;

@ -47,7 +47,7 @@ TORCH_API bool isBatchedAtLevel(const Tensor& tensor, int64_t level);
TORCH_API bool isBatchedAtLevel(const std::optional<Tensor>& maybe_tensor, int64_t level);

// Convenience helper. Returns true if any tensor is batched at level
-TORCH_API bool areAnyBatchedAtLevel(ArrayRef<optional<Tensor>> maybe_tensors, int64_t level);
+TORCH_API bool areAnyBatchedAtLevel(ArrayRef<std::optional<Tensor>> maybe_tensors, int64_t level);

inline bool ivalueParticipatesInCurrentLevel(const IValue& ivalue) {
if (ivalue.isTensor()) {

@ -156,7 +156,7 @@
namespace at {

namespace detail {
-static void check_linalg_norm_dtype(optional<ScalarType> opt_dtype, ScalarType self_dtype, const char* const name) {
+static void check_linalg_norm_dtype(std::optional<ScalarType> opt_dtype, ScalarType self_dtype, const char* const name) {
if (opt_dtype.has_value()) {
auto dtype = opt_dtype.value();
TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype), name, ": dtype should"

@ -220,7 +220,7 @@ static inline Device ensure_has_index(Device device) {
return impl->getDevice();
}

-static inline std::optional<Device> ensure_has_index(optional<Device> device) {
+static inline std::optional<Device> ensure_has_index(std::optional<Device> device) {
if (!device.has_value()) {
return std::nullopt;
}

@ -904,7 +904,7 @@ _flash_attention_forward(
std::optional<Tensor> out = std::nullopt;

std::optional<Tensor> seqused_k = _seqused_k;
-c10::optional<at::Tensor> block_table = std::nullopt; // we are not using the block table yet
+std::optional<at::Tensor> block_table = std::nullopt; // we are not using the block table yet
std::optional<Tensor> alibi_slopes = _alibi_slopes;

const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;

@ -547,7 +547,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
-c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
+std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,

@ -29,7 +29,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
-c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
+std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,

@ -135,7 +135,7 @@ class OptionalDeviceGuard {

/// Initialize the guard if a Device is passed; otherwise leave the
/// guard uninitialized.
-explicit OptionalDeviceGuard(optional<Device> device) : guard_(device) {}
+explicit OptionalDeviceGuard(std::optional<Device> device) : guard_(device) {}

/// Constructor for testing only.
explicit OptionalDeviceGuard(

@ -99,7 +99,7 @@ struct OptionalStreamGuard {
/// Set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream,
/// if the passed stream is not nullopt.
-explicit OptionalStreamGuard(optional<Stream> stream_opt)
+explicit OptionalStreamGuard(std::optional<Stream> stream_opt)
: guard_(stream_opt) {}

/// Copy is disallowed

@ -223,7 +223,7 @@ class InlineOptionalDeviceGuard {
{}

/// Set the current device to the passed Device, if it is not nullopt.
-explicit InlineOptionalDeviceGuard(optional<Device> device_opt)
+explicit InlineOptionalDeviceGuard(std::optional<Device> device_opt)
: guard_() { // See Note [Explicit initialization of optional fields]
if (device_opt.has_value()) {
guard_.emplace(device_opt.value());
@ -235,7 +235,8 @@ class InlineOptionalDeviceGuard {
typename U = T,
typename =
typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>>>
-explicit InlineOptionalDeviceGuard(optional<DeviceIndex> device_index_opt)
+explicit InlineOptionalDeviceGuard(
+    std::optional<DeviceIndex> device_index_opt)
: guard_() { // See Note [Explicit initialization of optional fields]
if (device_index_opt.has_value()) {
guard_.emplace(device_index_opt.value());

@ -139,7 +139,8 @@ class InlineOptionalStreamGuard {
/// Set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream,
/// if the passed stream is not nullopt.
-explicit InlineOptionalStreamGuard(optional<Stream> stream_opt) : guard_() {
+explicit InlineOptionalStreamGuard(std::optional<Stream> stream_opt)
+    : guard_() {
if (stream_opt.has_value()) {
guard_.emplace(stream_opt.value());
}

@ -76,12 +76,12 @@ struct OptionalCUDAGuard {
explicit OptionalCUDAGuard() : guard_() {}

/// Set the current CUDA device to the passed Device, if it is not nullopt.
-explicit OptionalCUDAGuard(optional<Device> device_opt)
+explicit OptionalCUDAGuard(std::optional<Device> device_opt)
: guard_(device_opt) {}

/// Set the current CUDA device to the passed device index, if it is not
/// nullopt
-explicit OptionalCUDAGuard(optional<DeviceIndex> device_index_opt)
+explicit OptionalCUDAGuard(std::optional<DeviceIndex> device_index_opt)
: guard_(device_index_opt) {}

// Copy is not allowed
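[Editor's illustrative sketch, not part of this commit: a minimal use of the guard declared above, assuming a CUDA build and <c10/cuda/CUDAGuard.h>; the only point is that the optional-device constructor is now spelled with std::optional:]

#include <c10/cuda/CUDAGuard.h>
#include <optional>

void run_with_optional_device(std::optional<c10::Device> maybe_device) {
  // If maybe_device is std::nullopt the guard stays uninitialized and the
  // current device is untouched; otherwise it switches devices and restores
  // the original device when it goes out of scope.
  c10::cuda::OptionalCUDAGuard guard(maybe_device);
  // ... launch kernels / allocate tensors here ...
}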
@ -215,7 +215,7 @@ struct OptionalCUDAStreamGuard {
/// Set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream,
/// if the passed stream is not nullopt.
-explicit OptionalCUDAStreamGuard(optional<Stream> stream_opt)
+explicit OptionalCUDAStreamGuard(std::optional<Stream> stream_opt)
: guard_(stream_opt) {}

/// Copy is disallowed

@ -37,7 +37,7 @@ constexpr T value_or_else(const std::optional<T>& v, F&& func) {
}

template <class T, class F>
-constexpr T value_or_else(optional<T>&& v, F&& func) {
+constexpr T value_or_else(std::optional<T>&& v, F&& func) {
static_assert(
std::is_convertible_v<typename std::invoke_result_t<F>, T>,
"func parameters must be a callable that returns a type convertible to the value stored in the optional");

@ -121,7 +121,7 @@ class DataLoaderBase {
/// The finished result of a job.
struct Result : Sequenced {
Result() = default;
-Result(optional<Batch>&& b, size_t sqn)
+Result(std::optional<Batch>&& b, size_t sqn)
: Sequenced(sqn), batch(std::move(b)) {}
Result(std::exception_ptr exception, size_t sqn)
: Sequenced(sqn), exception(std::move(exception)) {}
@ -166,7 +166,7 @@ class DataLoaderBase {
/// is still expected.
std::optional<BatchType> next() {
if (options_.workers > 0) {
-while (optional<Result> result = this->pop_result()) {
+while (std::optional<Result> result = this->pop_result()) {
if (result->exception) {
throw WorkerException(result->exception);
} else if (result->batch) {

@ -24,10 +24,10 @@ struct DataLoaderOptions {

/// The maximum number of jobs to enqueue for fetching by worker threads.
/// Defaults to two times the number of worker threads.
-TORCH_ARG(optional<size_t>, max_jobs);
+TORCH_ARG(std::optional<size_t>, max_jobs);

/// An optional limit on the time to wait for the next batch.
-TORCH_ARG(optional<std::chrono::milliseconds>, timeout);
+TORCH_ARG(std::optional<std::chrono::milliseconds>, timeout);

/// Whether to enforce ordering of batches when multiple are loaded
/// asynchronously by worker threads. Set to `false` for better performance if

@ -29,7 +29,7 @@ namespace detail {
template <typename T>
struct is_optional : std::false_type {};
template <typename T>
-struct is_optional<optional<T>> : std::true_type {};
+struct is_optional<std::optional<T>> : std::true_type {};
} // namespace detail

/// A dataset that can yield data only in batches.
@ -49,7 +49,8 @@ class BatchDataset {
/// Returns a batch of data given an index.
virtual Batch get_batch(BatchRequest request) = 0;

-/// Returns the size of the dataset, or an empty optional if it is unsized.
+/// Returns the size of the dataset, or an empty std::optional if it is
+/// unsized.
virtual std::optional<size_t> size() const = 0;

/// Creates a `MapDataset` that applies the given `transform` to this dataset.

@ -40,7 +40,7 @@ class Queue {
/// the queue. An optional `timeout` in seconds can be used to limit the time
/// spent waiting for an element. If the wait times out, an exception is
/// raised.
-T pop(optional<std::chrono::milliseconds> timeout = std::nullopt) {
+T pop(std::optional<std::chrono::milliseconds> timeout = std::nullopt) {
std::unique_lock<std::mutex> lock(mutex_);
if (timeout) {
if (!cv_.wait_for(

@ -12,7 +12,7 @@ namespace detail {
namespace sequencers {
namespace detail {
template <typename Result>
-bool buffer_contains_result(const std::vector<optional<Result>>& buffer) {
+bool buffer_contains_result(const std::vector<std::optional<Result>>& buffer) {
return std::any_of(
buffer.begin(), buffer.end(), [](const std::optional<Result>& result) {
return result.has_value();
@ -27,7 +27,7 @@ bool buffer_contains_result(const std::vector<optional<Result>>& buffer) {
/// buffers results internally to return them in order of their sequence number.
template <typename Result>
struct Sequencer {
-using ResultProducer = std::function<optional<Result>()>;
+using ResultProducer = std::function<std::optional<Result>()>;
virtual ~Sequencer() = default;
virtual std::optional<Result> next(ResultProducer next_result) = 0;
};
@ -105,7 +105,7 @@ struct OrderedSequencer : public Sequencer<Result> {
size_t next_sequence_number_ = 0;

/// A fixed-size buffer (after construction).
-std::vector<optional<Result>> buffer_;
+std::vector<std::optional<Result>> buffer_;
};
} // namespace sequencers
} // namespace detail

@ -41,7 +41,7 @@ struct IteratorImpl {

template <typename Batch>
struct ValidIterator : public IteratorImpl<Batch> {
-using BatchProducer = std::function<optional<Batch>()>;
+using BatchProducer = std::function<std::optional<Batch>()>;

explicit ValidIterator(BatchProducer next_batch)
: next_batch_(std::move(next_batch)) {}

@ -29,7 +29,7 @@ class Sampler {
/// Resets the `Sampler`'s internal state.
/// Typically called before a new epoch.
/// Optionally, accepts a new size when reseting the sampler.
-virtual void reset(optional<size_t> new_size) = 0;
+virtual void reset(std::optional<size_t> new_size) = 0;

/// Returns the next index if possible, or an empty optional if the
/// sampler is exhausted for this epoch.
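[Editor's illustrative sketch, not part of this commit: how the std::optional pieces of the sampler API above are consumed by a caller, assuming the C++ frontend headers under torch/data/samplers are available:]

#include <torch/data/samplers/random.h>
#include <optional>
#include <vector>

void one_epoch() {
  torch::data::samplers::RandomSampler sampler(/*size=*/100);
  sampler.reset(std::nullopt);  // keep the current size; pass a value to resize
  while (std::optional<std::vector<size_t>> batch = sampler.next(/*batch_size=*/16)) {
    // consume *batch: a vector of indices into the dataset
  }
}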
@ -78,7 +78,7 @@ class TORCH_API DistributedRandomSampler : public DistributedSampler<> {
|
|||||||
bool allow_duplicates = true);
|
bool allow_duplicates = true);
|
||||||
|
|
||||||
/// Resets the `DistributedRandomSampler` to a new set of indices.
|
/// Resets the `DistributedRandomSampler` to a new set of indices.
|
||||||
void reset(optional<size_t> new_size = std::nullopt) override;
|
void reset(std::optional<size_t> new_size = std::nullopt) override;
|
||||||
|
|
||||||
/// Returns the next batch of indices.
|
/// Returns the next batch of indices.
|
||||||
std::optional<std::vector<size_t>> next(size_t batch_size) override;
|
std::optional<std::vector<size_t>> next(size_t batch_size) override;
|
||||||
@ -111,7 +111,7 @@ class TORCH_API DistributedSequentialSampler : public DistributedSampler<> {
|
|||||||
bool allow_duplicates = true);
|
bool allow_duplicates = true);
|
||||||
|
|
||||||
/// Resets the `DistributedSequentialSampler` to a new set of indices.
|
/// Resets the `DistributedSequentialSampler` to a new set of indices.
|
||||||
void reset(optional<size_t> new_size = std::nullopt) override;
|
void reset(std::optional<size_t> new_size = std::nullopt) override;
|
||||||
|
|
||||||
/// Returns the next batch of indices.
|
/// Returns the next batch of indices.
|
||||||
std::optional<std::vector<size_t>> next(size_t batch_size) override;
|
std::optional<std::vector<size_t>> next(size_t batch_size) override;
|
||||||
|
@ -31,7 +31,7 @@ class TORCH_API RandomSampler : public Sampler<> {
|
|||||||
~RandomSampler() override;
|
~RandomSampler() override;
|
||||||
|
|
||||||
/// Resets the `RandomSampler` to a new set of indices.
|
/// Resets the `RandomSampler` to a new set of indices.
|
||||||
void reset(optional<size_t> new_size = std::nullopt) override;
|
void reset(std::optional<size_t> new_size = std::nullopt) override;
|
||||||
|
|
||||||
/// Returns the next batch of indices.
|
/// Returns the next batch of indices.
|
||||||
std::optional<std::vector<size_t>> next(size_t batch_size) override;
|
std::optional<std::vector<size_t>> next(size_t batch_size) override;
|
||||||
|
@ -26,7 +26,7 @@ class TORCH_API SequentialSampler : public Sampler<> {
|
|||||||
explicit SequentialSampler(size_t size);
|
explicit SequentialSampler(size_t size);
|
||||||
|
|
||||||
/// Resets the `SequentialSampler` to zero.
|
/// Resets the `SequentialSampler` to zero.
|
||||||
void reset(optional<size_t> new_size = std::nullopt) override;
|
void reset(std::optional<size_t> new_size = std::nullopt) override;
|
||||||
|
|
||||||
/// Returns the next batch of indices.
|
/// Returns the next batch of indices.
|
||||||
std::optional<std::vector<size_t>> next(size_t batch_size) override;
|
std::optional<std::vector<size_t>> next(size_t batch_size) override;
|
||||||
|
@ -39,7 +39,7 @@ class TORCH_API StreamSampler : public Sampler<BatchSize> {
|
|||||||
explicit StreamSampler(size_t epoch_size);
|
explicit StreamSampler(size_t epoch_size);
|
||||||
|
|
||||||
/// Resets the internal state of the sampler.
|
/// Resets the internal state of the sampler.
|
||||||
void reset(optional<size_t> new_size = std::nullopt) override;
|
void reset(std::optional<size_t> new_size = std::nullopt) override;
|
||||||
|
|
||||||
/// Returns a `BatchSize` object with the number of elements to fetch in the
|
/// Returns a `BatchSize` object with the number of elements to fetch in the
|
||||||
/// next batch. This number is the minimum of the supplied `batch_size` and
|
/// next batch. This number is the minimum of the supplied `batch_size` and
|
||||||
|
@ -136,7 +136,7 @@ class AnyModule {
|
|||||||
|
|
||||||
/// Creates a deep copy of an `AnyModule` if it contains a module, else an
|
/// Creates a deep copy of an `AnyModule` if it contains a module, else an
|
||||||
/// empty `AnyModule` if it is empty.
|
/// empty `AnyModule` if it is empty.
|
||||||
AnyModule clone(optional<Device> device = std::nullopt) const;
|
AnyModule clone(std::optional<Device> device = std::nullopt) const;
|
||||||
|
|
||||||
/// Assigns a module to the `AnyModule` (to circumvent the explicit
|
/// Assigns a module to the `AnyModule` (to circumvent the explicit
|
||||||
/// constructor).
|
/// constructor).
|
||||||
@ -253,7 +253,7 @@ inline AnyModule& AnyModule::operator=(const AnyModule& other) {
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline AnyModule AnyModule::clone(optional<Device> device) const {
|
inline AnyModule AnyModule::clone(std::optional<Device> device) const {
|
||||||
AnyModule clone;
|
AnyModule clone;
|
||||||
clone.content_ = content_ ? content_->clone_module(device) : nullptr;
|
clone.content_ = content_ ? content_->clone_module(device) : nullptr;
|
||||||
return clone;
|
return clone;
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace data {
|
namespace data {
|
||||||
@ -105,7 +104,7 @@ Example<> MNIST::get(size_t index) {
|
|||||||
return {images_[index], targets_[index]};
|
return {images_[index], targets_[index]};
|
||||||
}
|
}
|
||||||
|
|
||||||
optional<size_t> MNIST::size() const {
|
std::optional<size_t> MNIST::size() const {
|
||||||
return images_.size(0);
|
return images_.size(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -22,11 +22,10 @@ DistributedRandomSampler::DistributedRandomSampler(
|
|||||||
end_index_(0),
|
end_index_(0),
|
||||||
sample_index_(0) {
|
sample_index_(0) {
|
||||||
// shuffle first time.
|
// shuffle first time.
|
||||||
// NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
|
|
||||||
reset(size_);
|
reset(size_);
|
||||||
}
|
}
|
||||||
|
|
||||||
optional<std::vector<size_t>> DistributedRandomSampler::next(
|
std::optional<std::vector<size_t>> DistributedRandomSampler::next(
|
||||||
size_t batch_size) {
|
size_t batch_size) {
|
||||||
if (sample_index_ == end_index_) {
|
if (sample_index_ == end_index_) {
|
||||||
return nullopt;
|
return nullopt;
|
||||||
@ -43,7 +42,7 @@ optional<std::vector<size_t>> DistributedRandomSampler::next(
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
void DistributedRandomSampler::reset(optional<size_t> new_size) {
|
void DistributedRandomSampler::reset(std::optional<size_t> new_size) {
|
||||||
size_ = new_size.value_or(size_);
|
size_ = new_size.value_or(size_);
|
||||||
populate_indices();
|
populate_indices();
|
||||||
|
|
||||||
@ -107,7 +106,7 @@ DistributedSequentialSampler::DistributedSequentialSampler(
|
|||||||
populate_indices();
|
populate_indices();
|
||||||
}
|
}
|
||||||
|
|
||||||
optional<std::vector<size_t>> DistributedSequentialSampler::next(
|
std::optional<std::vector<size_t>> DistributedSequentialSampler::next(
|
||||||
size_t batch_size) {
|
size_t batch_size) {
|
||||||
if (sample_index_ == end_index_) {
|
if (sample_index_ == end_index_) {
|
||||||
return nullopt;
|
return nullopt;
|
||||||
@ -129,7 +128,7 @@ optional<std::vector<size_t>> DistributedSequentialSampler::next(
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
void DistributedSequentialSampler::reset(optional<size_t> new_size) {
|
void DistributedSequentialSampler::reset(std::optional<size_t> new_size) {
|
||||||
size_t size = new_size.value_or(size_);
|
size_t size = new_size.value_or(size_);
|
||||||
if (size != size_) {
|
if (size != size_) {
|
||||||
size_ = size;
|
size_ = size;
|
||||||
|
@ -14,7 +14,7 @@ RandomSampler::RandomSampler(int64_t size, Dtype index_dtype)
|
|||||||
|
|
||||||
RandomSampler::~RandomSampler() = default;
|
RandomSampler::~RandomSampler() = default;
|
||||||
|
|
||||||
void RandomSampler::reset(optional<size_t> new_size) {
|
void RandomSampler::reset(std::optional<size_t> new_size) {
|
||||||
// This allocates a new chunk of memory every time (just FYI). It should be
|
// This allocates a new chunk of memory every time (just FYI). It should be
|
||||||
// amortized over the entire epoch hopefully.
|
// amortized over the entire epoch hopefully.
|
||||||
const auto size = new_size.value_or(static_cast<size_t>(indices_.numel()));
|
const auto size = new_size.value_or(static_cast<size_t>(indices_.numel()));
|
||||||
|
@ -11,14 +11,14 @@ namespace data {
|
|||||||
namespace samplers {
|
namespace samplers {
|
||||||
SequentialSampler::SequentialSampler(size_t size) : size_(size) {}
|
SequentialSampler::SequentialSampler(size_t size) : size_(size) {}
|
||||||
|
|
||||||
void SequentialSampler::reset(optional<size_t> new_size) {
|
void SequentialSampler::reset(std::optional<size_t> new_size) {
|
||||||
if (new_size.has_value()) {
|
if (new_size.has_value()) {
|
||||||
size_ = *new_size;
|
size_ = *new_size;
|
||||||
}
|
}
|
||||||
index_ = 0;
|
index_ = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
optional<std::vector<size_t>> SequentialSampler::next(size_t batch_size) {
|
std::optional<std::vector<size_t>> SequentialSampler::next(size_t batch_size) {
|
||||||
const auto remaining_indices = size_ - index_;
|
const auto remaining_indices = size_ - index_;
|
||||||
if (remaining_indices == 0) {
|
if (remaining_indices == 0) {
|
||||||
return nullopt;
|
return nullopt;
|
||||||
|
@ -20,14 +20,14 @@ BatchSize::operator size_t() const noexcept {
|
|||||||
|
|
||||||
StreamSampler::StreamSampler(size_t epoch_size) : epoch_size_(epoch_size) {}
|
StreamSampler::StreamSampler(size_t epoch_size) : epoch_size_(epoch_size) {}
|
||||||
|
|
||||||
void StreamSampler::reset(optional<size_t> new_size) {
|
void StreamSampler::reset(std::optional<size_t> new_size) {
|
||||||
if (new_size.has_value()) {
|
if (new_size.has_value()) {
|
||||||
epoch_size_ = *new_size;
|
epoch_size_ = *new_size;
|
||||||
}
|
}
|
||||||
examples_retrieved_so_far_ = 0;
|
examples_retrieved_so_far_ = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
optional<BatchSize> StreamSampler::next(size_t batch_size) {
|
std::optional<BatchSize> StreamSampler::next(size_t batch_size) {
|
||||||
AT_ASSERT(examples_retrieved_so_far_ <= epoch_size_);
|
AT_ASSERT(examples_retrieved_so_far_ <= epoch_size_);
|
||||||
if (examples_retrieved_so_far_ == epoch_size_) {
|
if (examples_retrieved_so_far_ == epoch_size_) {
|
||||||
return nullopt;
|
return nullopt;
|
||||||
|
@ -258,7 +258,7 @@ PyObject* THPEngine_run_backward(
|
|||||||
for (const auto i : c10::irange(num_tensors)) {
|
for (const auto i : c10::irange(num_tensors)) {
|
||||||
PyObject* _tensor = PyTuple_GET_ITEM(tensors, i);
|
PyObject* _tensor = PyTuple_GET_ITEM(tensors, i);
|
||||||
Edge gradient_edge; // Temporary variable to hold the gradient edge
|
Edge gradient_edge; // Temporary variable to hold the gradient edge
|
||||||
c10::optional<at::Tensor> mb_output;
|
std::optional<at::Tensor> mb_output;
|
||||||
if (THPVariable_Check(_tensor)) {
|
if (THPVariable_Check(_tensor)) {
|
||||||
mb_output = THPVariable_Unpack(_tensor);
|
mb_output = THPVariable_Unpack(_tensor);
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
|
@@ -10,7 +10,7 @@
 
 namespace c10d {
 
-// callback function will be given arguments (optional<string> oldValue,
+// callback function will be given arguments (std::optional<string> oldValue,
 // std::optional<string> newValue)
 using WatchKeyCallback =
     std::function<void(std::optional<std::string>, std::optional<std::string>)>;
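
A standalone sketch of a callback matching the WatchKeyCallback shape; the alias is re-declared locally so the snippet compiles on its own, the key-watch registration itself is not shown, and the values are made up.

// Sketch: either the old or the new value may be absent (key created vs. key deleted).
#include <functional>
#include <iostream>
#include <optional>
#include <string>

using WatchKeyCallback =
    std::function<void(std::optional<std::string>, std::optional<std::string>)>;

int main() {
  WatchKeyCallback on_change = [](std::optional<std::string> old_value,
                                  std::optional<std::string> new_value) {
    std::cout << "old: " << old_value.value_or("<none>")
              << " new: " << new_value.value_or("<none>") << '\n';
  };
  on_change(std::nullopt, std::string("42"));   // key was created
  on_change(std::string("42"), std::nullopt);   // key was deleted
}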
@@ -75,7 +75,7 @@ std::vector<at::Tensor> unpack_tensors(
     unpack_optional_tensor_list_ivalue(ivalue, device, inputs);
   } else if (
       *ivalue_arg.real_type() ==
-      *c10::getTypePtr<c10::optional<at::Tensor>>()) {
+      *c10::getTypePtr<std::optional<at::Tensor>>()) {
     // ivalue is c10::optional<at::Tensor>
     unpack_optional_tensor_ivalue(ivalue, device, inputs);
   }
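
A hedged sketch of the branch above in isolation: turning an IValue that may hold None into a std::optional<at::Tensor>. extract_optional_tensor is a hypothetical helper name; only IValue's own accessors are used.

// Sketch: None maps to an empty optional, anything else is read as a tensor.
#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>
#include <optional>

std::optional<at::Tensor> extract_optional_tensor(const c10::IValue& ivalue) {
  if (ivalue.isNone()) {
    return std::nullopt;
  }
  return ivalue.toTensor();
}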
@@ -14,7 +14,7 @@ RandomSampler::RandomSampler(int64_t size, Dtype index_dtype)
 
 RandomSampler::~RandomSampler() = default;
 
-void RandomSampler::reset(optional<size_t> new_size) {
+void RandomSampler::reset(std::optional<size_t> new_size) {
   // This allocates a new chunk of memory every time (just FYI). It should be
   // amortized over the entire epoch hopefully.
   const auto size = new_size.value_or(static_cast<size_t>(indices_.numel()));
@@ -22,7 +22,7 @@ void RandomSampler::reset(optional<size_t> new_size) {
   index_ = 0;
 }
 
-optional<std::vector<size_t>> RandomSampler::next(size_t batch_size) {
+std::optional<std::vector<size_t>> RandomSampler::next(size_t batch_size) {
   AT_ASSERT(index_ <= indices_.numel());
   const size_t remaining_indices = indices_.numel() - index_;
   if (remaining_indices == 0) {
@@ -32,7 +32,7 @@ class TORCH_API RandomSampler : public torch::data::samplers::Sampler<> {
   ~RandomSampler() override;
 
   /// Resets the `RandomSampler` to a new set of indices.
-  void reset(optional<size_t> new_size = std::nullopt) override;
+  void reset(std::optional<size_t> new_size = std::nullopt) override;
 
   /// Returns the next batch of indices.
   std::optional<std::vector<size_t>> next(size_t batch_size) override;
@@ -10,7 +10,7 @@ namespace jit {
 namespace mobile {
 SequentialSampler::SequentialSampler(size_t size) : size_(size) {}
 
-void SequentialSampler::reset(optional<size_t> new_size) {
+void SequentialSampler::reset(std::optional<size_t> new_size) {
   if (new_size.has_value()) {
     size_ = *new_size;
   }
@@ -27,7 +27,7 @@ class TORCH_API SequentialSampler : public torch::data::samplers::Sampler<> {
   explicit SequentialSampler(size_t size);
 
   /// Resets the `SequentialSampler` to zero.
-  void reset(optional<size_t> new_size = std::nullopt) override;
+  void reset(std::optional<size_t> new_size = std::nullopt) override;
 
   /// Returns the next batch of indices.
   std::optional<std::vector<size_t>> next(size_t batch_size) override;
@@ -16,7 +16,7 @@ struct DebugInfo {
     auto L = parseHeader(offset);
     parseCompileUnit(L);
   }
-  unwind::optional<uint64_t> lineNumberProgramOffset() {
+  std::optional<uint64_t> lineNumberProgramOffset() {
     return line_number_program_offset_;
   }
   uint64_t nextOffset() {
@@ -88,9 +88,7 @@ enum {
   DW_RLE_start_length = 0x7
 };
 
-static torch::unwind::optional<size_t> formSize(
-    uint64_t form,
-    uint8_t sec_offset_size) {
+static std::optional<size_t> formSize(uint64_t form, uint8_t sec_offset_size) {
   switch (form) {
     case DW_FORM_addr:
       return sizeof(void*);
@@ -147,7 +147,7 @@ struct LineNumberProgram {
     uint32_t file = 1;
     int64_t line = 1;
   };
-  unwind::optional<Entry> find(uint64_t address) {
+  std::optional<Entry> find(uint64_t address) {
    auto e = program_index_.find(address);
    if (!e) {
      return std::nullopt;
@@ -3,7 +3,6 @@
 #include <algorithm>
 #include <memory>
 #include <optional>
-#include <unordered_map>
 #include <vector>
 
 namespace torch::unwind {
@@ -14,7 +13,7 @@ struct RangeTable {
     addresses_.push_back(0);
     payloads_.emplace_back(std::nullopt);
   }
-  void add(uint64_t address, unwind::optional<T> payload, bool sorted) {
+  void add(uint64_t address, std::optional<T> payload, bool sorted) {
     if (addresses_.back() > address) {
       UNWIND_CHECK(!sorted, "expected addresses to be sorted");
       sorted_ = false;
@@ -22,7 +21,7 @@ struct RangeTable {
     addresses_.push_back(address);
     payloads_.emplace_back(std::move(payload));
   }
-  unwind::optional<T> find(uint64_t address) {
+  std::optional<T> find(uint64_t address) {
    maybeSort();
    auto it = std::upper_bound(addresses_.begin(), addresses_.end(), address);
    return payloads_.at(it - addresses_.begin() - 1);
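
find() above is a classic sorted range lookup: addresses_ holds range start addresses, upper_bound picks the first start strictly greater than the query, and the entry just before it owns the address. A simplified, self-contained sketch of that idea (SimpleRangeTable is illustrative, not the real RangeTable):

// Sketch: map an address to the payload of the range it falls into.
#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>

struct SimpleRangeTable {
  std::vector<uint64_t> starts{0};                  // sentinel range starting at 0
  std::vector<std::optional<int>> payloads{std::nullopt};

  void add(uint64_t start, std::optional<int> payload) {
    starts.push_back(start);                        // assumes callers add in sorted order
    payloads.push_back(std::move(payload));
  }

  std::optional<int> find(uint64_t address) const {
    // First start strictly greater than the address; step back one entry.
    auto it = std::upper_bound(starts.begin(), starts.end(), address);
    return payloads.at(it - starts.begin() - 1);
  }
};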
@@ -56,7 +55,7 @@ struct RangeTable {
           bool(payloads_[a]) < bool(payloads_[b]));
     });
     std::vector<uint64_t> addresses;
-    std::vector<unwind::optional<T>> payloads;
+    std::vector<std::optional<T>> payloads;
     addresses.reserve(addresses_.size());
     payloads.reserve(addresses_.size());
     for (auto i : indices) {
@@ -69,6 +68,6 @@ struct RangeTable {
   }
   bool sorted_ = true;
   std::vector<uint64_t> addresses_;
-  std::vector<unwind::optional<T>> payloads_;
+  std::vector<std::optional<T>> payloads_;
 };
 } // namespace torch::unwind
@@ -77,7 +77,7 @@ struct Sections {
     return is_64bit ? data.read<uint64_t>() : data.read<uint32_t>();
   }
 
-  unwind::optional<uint64_t> findDebugInfoOffset(uint64_t address) {
+  std::optional<uint64_t> findDebugInfoOffset(uint64_t address) {
     return debug_info_offsets_.find(address);
   }
   size_t compilationUnitCount() {
@@ -26,6 +26,4 @@ struct UnwindError : public std::runtime_error {
 // #define PRINT_LINE_TABLE(...) LOG_INFO(__VA_ARGS__)
 #define PRINT_LINE_TABLE(...)
 
-using std::optional; // NOLINT
-
 } // namespace torch::unwind
@@ -126,7 +126,7 @@ class Unboxing:
             )
         return (
             f"""
-{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
+auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
 """.split(
                 "\n"
             ),
@@ -146,7 +146,7 @@ class Unboxing:
         if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor:
             code.extend(
                 f"""
-{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toTensorList();
+auto {out_name} = {arg_name}.toTensorList();
 """.split(
                     "\n"
                 )
@@ -156,7 +156,7 @@ class Unboxing:
         ):
             code.extend(
                 f"""
-{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toIntList();
+auto {out_name} = {arg_name}.toIntList();
 """.split(
                     "\n"
                 )
@@ -164,7 +164,7 @@ class Unboxing:
         elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float:
             code.extend(
                 f"""
-{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toDoubleList();
+auto {out_name} = {arg_name}.toDoubleList();
 """.split(
                     "\n"
                 )
@@ -173,7 +173,7 @@ class Unboxing:
             # handle list type with size, e.g., bool[4]
             code.extend(
                 f"""
-{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toBoolList();
+auto {out_name} = {arg_name}.toBoolList();
 """.split(
                     "\n"
                 )
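
For reference, the shape of the C++ these templates emit once the left-hand side becomes auto; example_unbox and the argument names are placeholders, and only the IValue accessors already shown in the diff are used.

// Sketch: the accessor determines the concrete type, so the generated code no
// longer needs to spell out the optional/list type on the left-hand side.
#include <ATen/core/ivalue.h>

void example_unbox(const c10::IValue& dim_iv, const c10::IValue& sizes_iv) {
  auto dim_opt = dim_iv.toOptional<int64_t>();   // an optional<int64_t>
  auto sizes = sizes_iv.toIntList();             // a c10::List<int64_t>
  (void)dim_opt;
  (void)sizes;
}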
@@ -194,7 +194,7 @@ for (auto {elem_name}: {in_name}) {{
   {out_name}.push_back({elem_name});
 }}
 #else
-torch::executor::ArrayRef<torch::executor::optional<torch::executor::Tensor>> {out_name} = {arg_name}.toListOptionalTensor();
+auto {out_name} = {arg_name}.toListOptionalTensor();
 #endif
 """.split(
                 "\n"
@@ -42,8 +42,8 @@ def unwrap_tensor(name: str, cur_level_var: str) -> list[str]:
 
 def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]:
     result = f"""\
-optional<Tensor> {name}_value;
-optional<int64_t> {name}_bdim;
+std::optional<Tensor> {name}_value;
+std::optional<int64_t> {name}_bdim;
 if ({name}) {{
   std::tie({name}_value, {name}_bdim) = unwrapTensorAtLevel({name}.value(), {cur_level_var});
 }}"""
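
Instantiated for a hypothetical optional argument named weight, the template above expands to roughly the body of plumb() below; the functorch helper is stubbed out so the sketch stands alone, and all names here are illustrative.

// Sketch of the generated plumbing shape, with a stand-in for unwrapTensorAtLevel.
#include <ATen/ATen.h>
#include <optional>
#include <tuple>
#include <utility>

using at::Tensor;

// Stand-in with the same return shape as the real functorch helper.
std::pair<Tensor, std::optional<int64_t>> unwrapTensorAtLevel(
    const Tensor& t, int64_t /*level*/) {
  return {t, std::nullopt};
}

void plumb(const std::optional<Tensor>& weight, int64_t cur_level) {
  std::optional<Tensor> weight_value;
  std::optional<int64_t> weight_bdim;
  if (weight) {
    std::tie(weight_value, weight_bdim) =
        unwrapTensorAtLevel(weight.value(), cur_level);
  }
}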