[functorch] Roll back clamp_min / clamp_max change

It doesn't compile under clang. Going to investigate this later
This commit is contained in:
Richard Zou
2021-05-12 09:38:24 -04:00
committed by Jon Janzen
parent 256099c1a6
commit bb701ef563

View File

@ -127,14 +127,15 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT("pow.Tensor_Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::pow, const Scalar&>));
VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
VMAP_SUPPORT("clamp_min.Tensor",
SINGLE_ARG(binary_pointwise_batch_rule<TensorTensorType, static_cast<TensorTensorType>(&at::clamp_min)>));
VMAP_SUPPORT("clamp_min",
SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, static_cast<TensorScalarType>(&at::clamp_min), const Scalar&>));
VMAP_SUPPORT("clamp_max.Tensor",
SINGLE_ARG(binary_pointwise_batch_rule<TensorTensorType, static_cast<TensorTensorType>(&at::clamp_max)>));
VMAP_SUPPORT("clamp_max",
SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, static_cast<TensorScalarType>(&at::clamp_max), const Scalar&>));
// TODO: the following fails to compile using clang, figure out why...
// VMAP_SUPPORT("clamp_min.Tensor",
// SINGLE_ARG(binary_pointwise_batch_rule<TensorTensorType, static_cast<TensorTensorType>(&at::clamp_min)>));
// VMAP_SUPPORT("clamp_min",
// SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, static_cast<TensorScalarType>(&at::clamp_min), const Scalar&>));
// VMAP_SUPPORT("clamp_max.Tensor",
// SINGLE_ARG(binary_pointwise_batch_rule<TensorTensorType, static_cast<TensorTensorType>(&at::clamp_max)>));
// VMAP_SUPPORT("clamp_max",
// SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, static_cast<TensorScalarType>(&at::clamp_max), const Scalar&>));
#define COMPARISON_POINTWISE(op) \