[functorch] Fix clamp_min / clamp_max

This commit is contained in:
Richard Zou
2021-05-12 09:43:24 -04:00
committed by Jon Janzen
parent bb701ef563
commit 38eb311988

View File

@ -127,15 +127,14 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT("pow.Tensor_Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::pow, const Scalar&>));
VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
// TODO: the following fails to compile using clang, figure out why...
// VMAP_SUPPORT("clamp_min.Tensor",
// SINGLE_ARG(binary_pointwise_batch_rule<TensorTensorType, static_cast<TensorTensorType>(&at::clamp_min)>));
// VMAP_SUPPORT("clamp_min",
// SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, static_cast<TensorScalarType>(&at::clamp_min), const Scalar&>));
// VMAP_SUPPORT("clamp_max.Tensor",
// SINGLE_ARG(binary_pointwise_batch_rule<TensorTensorType, static_cast<TensorTensorType>(&at::clamp_max)>));
// VMAP_SUPPORT("clamp_max",
// SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, static_cast<TensorScalarType>(&at::clamp_max), const Scalar&>));
VMAP_SUPPORT("clamp_min.Tensor",
SINGLE_ARG(binary_pointwise_batch_rule<TensorTensorType, &at::clamp_min>));
VMAP_SUPPORT("clamp_min",
SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::clamp_min, const Scalar&>));
VMAP_SUPPORT("clamp_max.Tensor",
SINGLE_ARG(binary_pointwise_batch_rule<TensorTensorType, &at::clamp_max>));
VMAP_SUPPORT("clamp_max",
SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::clamp_max, const Scalar&>));
#define COMPARISON_POINTWISE(op) \