Compare commits

...

4 Commits

Author SHA1 Message Date
69e216306d Update
[ghstack-poisoned]
2025-10-30 19:56:13 -07:00
384c3d7f71 Update (base update)
[ghstack-poisoned]
2025-10-30 19:56:13 -07:00
5f909fca76 Update
[ghstack-poisoned]
2025-10-30 19:44:19 -07:00
26349a5ab3 Update (base update)
[ghstack-poisoned]
2025-10-30 19:44:19 -07:00
15 changed files with 78 additions and 25 deletions

View File

@ -139,7 +139,7 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
}
);
} else {
AT_DISPATCH_ALL_TYPES(dtype, "smooth_l1_backward_cpu_out", [&] {
AT_DISPATCH_ALL_TYPES_AND(kHalf, dtype, "smooth_l1_backward_cpu_out", [&] {
auto norm_val = norm.to<scalar_t>();
scalar_t beta_val(beta);
auto norm_val_vec = Vectorized<scalar_t>(norm_val);

View File

@ -119,6 +119,10 @@ Tensor& relu_mps_(Tensor& self) {
TORCH_IMPL_FUNC(log_softmax_mps_out)
(const Tensor& self, const int64_t dim, const bool half_to_float, const Tensor& out) {
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()),
"log_softmax for complex is not supported for MPS");
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kBool, "log_softmax for bool is not supported for MPS");
using namespace mps;
using CachedGraph = MPSUnaryCachedGraph;
@ -162,6 +166,10 @@ TORCH_IMPL_FUNC(log_softmax_mps_out)
TORCH_IMPL_FUNC(log_softmax_backward_mps_out)
(const Tensor& grad_output, const Tensor& output, int64_t dim, ScalarType input_dtype, const Tensor& out) {
TORCH_CHECK_NOT_IMPLEMENTED(grad_output.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(grad_output.scalar_type()),
"log_softmax for complex is not supported for MPS");
TORCH_CHECK_NOT_IMPLEMENTED(grad_output.scalar_type() != kBool, "log_softmax for bool is not supported for MPS");
using namespace mps;
using CachedGraph = MPSUnaryGradCachedGraph;
@ -202,6 +210,7 @@ TORCH_IMPL_FUNC(log_softmax_backward_mps_out)
}
std::tuple<Tensor&, Tensor&> log_sigmoid_forward_out_mps(const Tensor& self, Tensor& output, Tensor& buffer) {
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
// NOTE: buffer is only used by CPU dispatch, we just ignore it here
using namespace mps;
using CachedGraph = MPSUnaryCachedGraph;
@ -704,6 +713,8 @@ static void elu_variants_out_mps(const Tensor& self,
const Scalar& input_scale,
const Tensor& result,
std::string func_name) {
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "ELU for complex is not supported for MPS");
using namespace mps;
using CachedGraph = MPSUnaryCachedGraph;
@ -793,6 +804,9 @@ TORCH_IMPL_FUNC(elu_backward_out_mps)
bool is_result,
const Tensor& self_or_result,
const Tensor& grad_input) {
TORCH_CHECK_NOT_IMPLEMENTED(grad_output.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(grad_output.scalar_type()),
"ELU for complex is not supported for MPS");
using namespace mps;
using CachedGraph = MPSUnaryGradCachedGraph;
auto gradMemFormat = grad_input.suggest_memory_format();
@ -896,6 +910,7 @@ TORCH_IMPL_FUNC(glu_out_mps)(const Tensor& self, const int64_t dim, const Tensor
if (output.numel() == 0)
return;
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
// this can't pass anyway because a 0-dimensional tensor has "size" 1, which
// can't be evenly halved, but give a nicer error message here.
TORCH_CHECK(self.dim() > 0, "glu does not support 0-dimensional tensors");
@ -1009,6 +1024,7 @@ TORCH_IMPL_FUNC(softplus_out_mps)
(const Tensor& self, const Scalar& beta, const Scalar& threshold, const Tensor& result) {
using namespace mps;
TORCH_CHECK(self.is_mps());
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "Not implemented for long");
// Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} *
// \log(1 + \exp(\beta * x))` element-wise.
// For numerical stability the implementation reverts to the linear function
@ -1159,6 +1175,8 @@ TORCH_IMPL_FUNC(mish_out_mps)
(const Tensor& self, const Tensor& result) {
using namespace mps;
TORCH_CHECK(self.is_mps());
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "Mish for complex is not supported for MPS");
if (result.numel() == 0)
return;
@ -1207,6 +1225,8 @@ TORCH_IMPL_FUNC(mish_out_mps)
Tensor mish_backward_mps(const Tensor& grad_output, const Tensor& self) {
using namespace mps;
TORCH_CHECK(self.is_mps());
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "Mish for complex is not supported for MPS");
Tensor grad_input = at::empty_like(self, self.suggest_memory_format());
if (grad_input.numel() == 0)
@ -1396,6 +1416,7 @@ TORCH_IMPL_FUNC(silu_out_mps)(const Tensor& self, const Tensor& result) {
using CachedGraph = MPSUnaryCachedGraph;
TORCH_CHECK(self.is_mps());
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
// Empty output
if (result.numel() == 0)

View File

@ -278,6 +278,9 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const
}
TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()),
"logaddexp for complex is not supported for MPS");
mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* sumTensor =
@ -290,6 +293,9 @@ TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, cons
}
TORCH_IMPL_FUNC(logaddexp2_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()),
"logaddexp2 for complex is not supported for MPS");
mps::BinaryOpBlock logaddexp2_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* sumTensor =

View File

@ -80,6 +80,7 @@ static void grid_sampler_2d_mps_impl(Tensor& output,
MPSGraphTensor* outputTensor_ = nil;
};
// Reject complex inputs up front with NotImplementedError. The message names the
// op and the unsupported dtype class so the failure is actionable for the caller.
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()),
                            "grid_sampler_2d for complex is not supported for MPS");
@autoreleasepool {
std::string key = "grid_sampler_2d_mps" + getTensorsStringKey({input, grid}) + ":" +
std::to_string(interpolation_mode) + ":" + std::to_string(padding_mode) + ":" + std::to_string(align_corners);

View File

@ -232,7 +232,7 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
bool check_errors) {
using namespace mps;
TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()),
TORCH_CHECK(A.scalar_type() == kFloat && LU.scalar_type() == kFloat,
"linalg.lu_factor(): MPS doesn't support complex types.");
TORCH_CHECK(pivot, "linalg.lu_factor(): MPS doesn't allow pivot == False.");
@ -356,8 +356,7 @@ static void linalg_solve_out_mps_impl(const Tensor& A,
const Tensor& info) {
using namespace mps;
TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()),
"linalg.lu_factor(): MPS doesn't support complex types.");
// Only float32 is accepted for both A and LU. This check lives in the linalg.solve
// implementation, so the error names that op rather than lu_factor.
TORCH_CHECK(A.scalar_type() == kFloat && LU.scalar_type() == kFloat, "linalg.solve(): MPS only supports floats.");
Tensor A_t, B_t;
// If 'left' is false, reinterpret the problem so that Ax = B becomes A^T ⋅ (x^T) = B^T
// Then we solve the normal "left" case on the transposed matrices and transpose x finally to get the output
@ -1050,7 +1049,8 @@ static Tensor& linalg_solve_triangular_mps_impl(const Tensor& A,
using namespace mps;
checkInputsSolver(A, B, left, "linalg.solve_triangular");
TORCH_CHECK(!A.is_complex() && !B.is_complex(), "linalg.solve.triangular(); Not supported for complex yet!");
TORCH_CHECK(A.scalar_type() == kFloat && B.scalar_type() == kFloat,
"linalg.solve.triangular(); Only float is supported!");
Tensor A_t, B_t;
std::tie(B_t, A_t) = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/ nullptr);
at::native::resize_output(out, B_t.sizes());

View File

@ -370,7 +370,7 @@ static void nllnd_loss_backward_impl(Tensor& grad_input_arg,
onValue:-1.0f
offValue:0.0f
name:nil];
oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, inputTensor.dataType);
oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, [inputTensor dataType]);
if (isWeightsArrayValid) {
oneHotTensor = [mpsGraph multiplicationWithPrimaryTensor:oneHotTensor
secondaryTensor:weightTensor
@ -421,6 +421,8 @@ static void nllnd_loss_forward_impl(Tensor& output,
int64_t reduction,
int64_t ignore_index,
bool is2D) {
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(output.scalar_type()),
"nlld_loss for complex is not supported for MPS");
std::vector<long long> reshapedTarget(target_arg.sizes().begin(), target_arg.sizes().end());
reshapedTarget.push_back(1);
@ -705,6 +707,7 @@ static void smooth_l1_loss_template(const Tensor& input,
TORCH_CHECK(beta >= 0, "smooth_l1_loss does not support negative values for beta.");
TORCH_CHECK(input.is_mps());
TORCH_CHECK(target.is_mps());
TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "MPS doesn't know how to do square_i64");
if ((input.numel() == 0) || (target.numel() == 0)) {
reduction == Reduction::Mean ? output.fill_(std::numeric_limits<float>::quiet_NaN()) : output.zero_();
return;
@ -771,7 +774,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:MPSDataTypeFloat32];
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:[inputTensor dataType]];
// xn - yn
MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor
secondaryTensor:targetTensor
@ -797,7 +800,8 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
name:@"lossTensor"];
MPSGraphTensor* outputTensor = lossTensor;
if (reduction == Reduction::Mean) {
MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel() dataType:MPSDataTypeFloat32];
MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel()
dataType:[lossTensor dataType]];
outputTensor = [mpsGraph divisionWithPrimaryTensor:lossTensor secondaryTensor:numelTensor name:nil];
}
MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor
@ -827,6 +831,9 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
Tensor& huber_loss_out_mps(const Tensor& input, const Tensor& target, int64_t reduction, double delta, Tensor& output) {
std::string op_name = __func__;
using namespace mps;
TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "MPS doesn't know how to do square_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()),
"huber_loss for complex is not supported for MPS");
TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta.")
TORCH_CHECK(target.is_same_size(input), op_name + ": target and input tensors must have identical shapes")
TORCH_CHECK(output.is_mps());

View File

@ -84,6 +84,9 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out(const Tensor& self,
Tensor& output,
Tensor& save_mean,
Tensor& save_var) {
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "Long batch norm is not supported with MPS");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()),
"Batch norm for complex is not supported for MPS");
using namespace at::native::mps;
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
@ -918,6 +921,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_mps(const Tensor& input,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
const int axis = input_ndim - normalized_ndim;
MPSStream* stream = getCurrentMPSStream();
TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "Not implemented for long on MPS");
@autoreleasepool {
mps::dispatch_sync_with_rethrow(stream->queue(), ^() {
// which kernel variant to use based on the normalized axis N size

View File

@ -595,6 +595,7 @@ static void avg_pool2d_template(const Tensor& input,
bool count_include_pad,
const std::optional<int64_t> divisor_override,
const std::string& op_name) {
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), "Not implemented for complex");
const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt));
const bool is_backward_pass = grad_output.defined();
const bool use_divisor = divisor_override.has_value() && divisor_override.value() != 0;
@ -913,6 +914,8 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)
bool ceil_mode,
const Tensor& output,
const Tensor& indices) {
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()),
"Max pooling for complex is not supported for MPS");
bool use_graph = use_graph_for_max_pool2d(kernel_size, stride);
if (use_graph) {
auto indices_memory_format = indices.suggest_memory_format();
@ -965,6 +968,8 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps)
bool ceil_mode,
const Tensor& indices,
const Tensor& grad_input) {
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()),
"Max pooling for complex is not supported for MPS");
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
MPSGraph* mpsGraph = cachedGraph.graph();
return [mpsGraph maxPooling2DGradientWithGradientTensor:cachedGraph.gradOutputTensor

View File

@ -269,17 +269,22 @@ static void reduction_out_mps(const Tensor& input_t,
name:nil];
castOutputTensor = [mpsGraph reductionSumWithTensor:bandPartWithTensor axes:@[ @0, @1 ] name:nil];
} else if (reduction_type == MPSReductionType::NANSUM) {
// Create a 0 tensor of the same shape as inputTensor
MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0.0 dataType:castInputTensor.dataType];
// Find NaNs
MPSGraphTensor* nanMask = [mpsGraph isNaNWithTensor:castInputTensor name:nil];
// Replace NaNs with 0
MPSGraphTensor* nanReplaced = [mpsGraph selectWithPredicateTensor:nanMask
truePredicateTensor:zeros
falsePredicateTensor:castInputTensor
name:nil];
// Sum
castOutputTensor = [mpsGraph reductionSumWithTensor:nanReplaced axes:wrappedAxes name:nil];
// Integral types cannot contain NaN, so just do regular sum
if (([castInputTensor dataType] & MPSDataTypeFloatBit) == 0) {
castOutputTensor = [mpsGraph reductionSumWithTensor:castInputTensor axes:wrappedAxes name:nil];
} else {
// Create a 0 tensor of the same shape as inputTensor
auto zeros = [mpsGraph constantWithScalar:0.0 dataType:castInputTensor.dataType];
// Find NaNs
auto nanMask = [mpsGraph isNaNWithTensor:castInputTensor name:nil];
// Replace NaNs with 0
auto nanReplaced = [mpsGraph selectWithPredicateTensor:nanMask
truePredicateTensor:zeros
falsePredicateTensor:castInputTensor
name:nil];
// Sum
castOutputTensor = [mpsGraph reductionSumWithTensor:nanReplaced axes:wrappedAxes name:nil];
}
}
MPSGraphTensor* outputTensor = castOutputTensor;
@ -442,6 +447,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
const std::optional<Scalar>& correction,
bool keepdim,
StdVarType stdVarType) {
TORCH_CHECK_NOT_IMPLEMENTED(input_t.scalar_type() != kLong, "Not implemented for MPS");
using CachedGraph = MPSUnaryCachedGraph;
IntArrayRef input_shape = input_t.sizes();
@ -1028,15 +1034,18 @@ TORCH_IMPL_FUNC(prod_out_mps)
}
// MPS structured kernel for amax: rejects complex inputs, then forwards to the
// shared reduction_out_mps helper with the AMAX reduction type.
TORCH_IMPL_FUNC(amax_out_mps)(const Tensor& input_t, IntArrayRef dim, bool keepdim, const Tensor& output_t) {
  const bool input_is_complex = c10::isComplexType(input_t.scalar_type());
  TORCH_CHECK(!input_is_complex, "amax is not defined for complex types");

  reduction_out_mps(input_t,
                    dim,
                    keepdim,
                    std::nullopt,
                    output_t,
                    MPSReductionType::AMAX,
                    "amax_out_mps");
}
// MPS structured kernel for amin: rejects complex inputs, then forwards to the
// shared reduction_out_mps helper with the AMIN reduction type.
TORCH_IMPL_FUNC(amin_out_mps)(const Tensor& input_t, IntArrayRef dim, bool keepdim, const Tensor& output_t) {
  const bool input_is_complex = c10::isComplexType(input_t.scalar_type());
  TORCH_CHECK(!input_is_complex, "amin is not defined for complex types");

  reduction_out_mps(input_t,
                    dim,
                    keepdim,
                    std::nullopt,
                    output_t,
                    MPSReductionType::AMIN,
                    "amin_out_mps");
}
TORCH_IMPL_FUNC(aminmax_out_mps)
(const Tensor& input_t, std::optional<int64_t> dim_opt, bool keepdim, const Tensor& min_t, const Tensor& max_t) {
TORCH_CHECK(!c10::isComplexType(input_t.scalar_type()), "aminmax is not defined for complex types");
reduction_out_mps(input_t,
dim_opt.has_value() ? OptionalIntArrayRef({*dim_opt}) : std::nullopt,
keepdim,

View File

@ -39,6 +39,7 @@ static void get_shapes(MPSShape* input_shape_readonly,
TORCH_IMPL_FUNC(softmax_mps_out)
(const Tensor& input_, const int64_t dim, const bool half_to_float, const Tensor& output) {
TORCH_CHECK(!half_to_float, "softmax with half to float conversion is not supported on MPS");
TORCH_CHECK(c10::isFloatingType(input_.scalar_type()), "softmax only supported for floating types");
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
if (input_.numel() == 0) {

View File

@ -31,6 +31,7 @@ void kthvalue_out_mps_impl(const Tensor& self, int64_t k, int64_t dim, Tensor& v
indices.copy_(values.toType(at::ScalarType::Long));
return;
}
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "kthvalue is not implemented for complex types");
// issue #154890, raising error to prevent crash within MPSGraph until
// workaround is implemented.
TORCH_CHECK(self.dim() - dim <= 4, "On-going issue on MPSGraph topk when ndims() - axis > 4, see issue #154890");

View File

@ -18,6 +18,8 @@ static Tensor& bincount_mps_impl(const Tensor& self, const Tensor& weights, Tens
MPSStream* stream = getCurrentMPSStream();
bool has_weights = weights.defined();
TORCH_CHECK(self.scalar_type() != kBool);
@autoreleasepool {
std::string key = "bincount_mps_impl" + getTensorsStringKey({self, weights});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {

View File

@ -2992,7 +2992,7 @@ class TestFakeTensor(TestCase):
self.assertEqual(strided_result.layout, torch.strided)
instantiate_device_type_tests(TestCommon, globals(), allow_xpu=True)
instantiate_device_type_tests(TestCommon, globals(), allow_xpu=True, allow_mps=True)
instantiate_device_type_tests(TestCompositeCompliance, globals())
instantiate_device_type_tests(TestMathBits, globals())
instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu")

View File

@ -20340,9 +20340,7 @@ op_db: list[OpInfo] = [
ref=reference_smooth_l1_loss,
sample_inputs_func=sample_inputs_smooth_l1_loss,
dtypes=floating_types_and(torch.float16, torch.bfloat16),
backward_dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
backward_dtypes=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,

View File

@ -738,8 +738,6 @@ if torch.backends.mps.is_available():
"equal": [torch.float16, torch.float32],
# 'float' object is not iterable
"item": [torch.float16, torch.float32],
# "smooth_l1_backward_cpu_out" not implemented for 'Half'
"nn.functional.smooth_l1_loss": [torch.float16],
# cpu error: grad requires non-empty inputs
"randn": [torch.float16, torch.float32],
"signal.windows.bartlett": [torch.float32],