mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] ported random registrations to boxedfallback
This commit is contained in:
@ -16,17 +16,7 @@
|
||||
namespace at {
|
||||
namespace functorch {
|
||||
|
||||
void unsupportedRandomOp2(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
||||
TORCH_CHECK(false, "vmap: We do not yet support calling random operations inside of vmap. ",
|
||||
"Please perform random operations outside of vmap as a workaround");
|
||||
}
|
||||
|
||||
template <typename... Args> Tensor unsupportedRandomOp(Args... args) {
|
||||
TORCH_CHECK(false, "vmap: We do not yet support calling random operations inside of vmap. ",
|
||||
"Please perform random operations outside of vmap as a workaround");
|
||||
}
|
||||
|
||||
template <typename... Args> Tensor& unsupportedRandomOp_(Args... args) {
|
||||
void unsupportedRandomOp(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
||||
TORCH_CHECK(false, "vmap: We do not yet support calling random operations inside of vmap. ",
|
||||
"Please perform random operations outside of vmap as a workaround");
|
||||
}
|
||||
@ -35,14 +25,14 @@ TORCH_LIBRARY_IMPL(_, FuncTorchVmapMode, m) {
|
||||
m.fallback(torch::CppFunction::makeFallthrough());
|
||||
}
|
||||
|
||||
#define TENSOROPTIONSPARAMS c10::optional<c10::ScalarType> dtype, c10::optional<c10::Layout> layout, c10::optional<c10::Device> device, c10::optional<bool> pin_memory
|
||||
|
||||
#define UNSUPPORTED_RANDOM(op) \
|
||||
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&unsupportedRandomOp2>());
|
||||
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&unsupportedRandomOp>());
|
||||
|
||||
#define UNSUPPORTED_RANDOM2(op, overload) \
|
||||
m.impl(#op"."#overload, torch::CppFunction::makeFromBoxedFunction<&unsupportedRandomOp2>());
|
||||
m.impl(#op"."#overload, torch::CppFunction::makeFromBoxedFunction<&unsupportedRandomOp>());
|
||||
|
||||
#define TENSOROPTIONSPARAMS c10::optional<c10::ScalarType> dtype, c10::optional<c10::Layout> layout, c10::optional<c10::Device> device, c10::optional<bool> pin_memory
|
||||
#define TENSOROPTIONSARGS dtype, layout, device, pin_memory
|
||||
|
||||
Tensor randn_mbatching_rule(IntArrayRef shape, TENSOROPTIONSPARAMS) {
|
||||
@ -54,82 +44,78 @@ Tensor randn_mbatching_rule(IntArrayRef shape, TENSOROPTIONSPARAMS) {
|
||||
return makeBatched(at::randn(shapeVec, TENSOROPTIONSARGS), 0, maybe_layer->layerId());
|
||||
}
|
||||
|
||||
#undef TENSOROPTIONSARGS
|
||||
#undef TENSOROPTIONSPARAMS
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
|
||||
// NB: I'd really like to register a special kernel like
|
||||
// CppFunction::makeNamedNotSupported() to avoid listing out the types of everything.
|
||||
// However, registering e.g. CppFunction::makeNamedNotSupported() as an implementation
|
||||
// only works for operators that support boxing.
|
||||
#define TENSOROPTIONS c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>
|
||||
// random operations (out-of-place)
|
||||
UNSUPPORTED_RANDOM(bernoulli);
|
||||
UNSUPPORTED_RANDOM2(bernoulli, out);
|
||||
UNSUPPORTED_RANDOM2(bernoulli, p);
|
||||
UNSUPPORTED_RANDOM2(bernoulli_, Tensor);
|
||||
m.impl("bernoulli_.float", unsupportedRandomOp_<Tensor&, double, optional<Generator>>);
|
||||
UNSUPPORTED_RANDOM(bernoulli_.float);
|
||||
|
||||
m.impl("cauchy_", unsupportedRandomOp_<Tensor&, double, double, optional<Generator>>);
|
||||
m.impl("exponential_", unsupportedRandomOp_<Tensor&, double, optional<Generator>>);
|
||||
m.impl("geometric_", unsupportedRandomOp_<Tensor&, double, optional<Generator>>);
|
||||
m.impl("log_normal_", unsupportedRandomOp_<Tensor&, double, double, optional<Generator>>);
|
||||
m.impl("multinomial", unsupportedRandomOp<const Tensor&, int64_t, bool, optional<Generator>>);
|
||||
m.impl("multinomial.out", unsupportedRandomOp_<const Tensor&, int64_t, bool, optional<Generator>, Tensor&>);
|
||||
UNSUPPORTED_RANDOM(cauchy_);
|
||||
UNSUPPORTED_RANDOM(exponential_);
|
||||
UNSUPPORTED_RANDOM(geometric_);
|
||||
UNSUPPORTED_RANDOM(log_normal_);
|
||||
UNSUPPORTED_RANDOM(multinomial);
|
||||
UNSUPPORTED_RANDOM2(multinomial, out);
|
||||
|
||||
m.impl("normal.Tensor_float", unsupportedRandomOp<const Tensor&, double, optional<Generator>>);
|
||||
m.impl("normal.Tensor_float_out", unsupportedRandomOp_<const Tensor&, double, optional<Generator>, Tensor&>);
|
||||
m.impl("normal.float_Tensor_out", unsupportedRandomOp_<double, const Tensor&, optional<Generator>, Tensor&>);
|
||||
m.impl("normal.float_Tensor", unsupportedRandomOp<double, const Tensor&, optional<Generator>>);
|
||||
m.impl("normal.Tensor_Tensor", unsupportedRandomOp<const Tensor&, const Tensor&, optional<Generator>>);
|
||||
m.impl("normal.Tensor_Tensor_out", unsupportedRandomOp_<const Tensor&, const Tensor&, optional<Generator>, Tensor&>);
|
||||
m.impl("normal.float_float", unsupportedRandomOp<double, double, IntArrayRef, optional<Generator>, TENSOROPTIONS>);
|
||||
m.impl("normal.float_float_out", unsupportedRandomOp_<double, double, IntArrayRef, optional<Generator>, Tensor&>);
|
||||
m.impl("normal_", unsupportedRandomOp_<Tensor&, double, double, optional<Generator>>);
|
||||
UNSUPPORTED_RANDOM2(normal, Tensor_float);
|
||||
UNSUPPORTED_RANDOM2(normal, Tensor_float_out);
|
||||
UNSUPPORTED_RANDOM2(normal, float_Tensor_out);
|
||||
UNSUPPORTED_RANDOM2(normal, float_Tensor);
|
||||
UNSUPPORTED_RANDOM2(normal, Tensor_Tensor);
|
||||
UNSUPPORTED_RANDOM2(normal, Tensor_Tensor_out);
|
||||
UNSUPPORTED_RANDOM2(normal, float_float);
|
||||
UNSUPPORTED_RANDOM2(normal, float_float_out);
|
||||
UNSUPPORTED_RANDOM(normal_);
|
||||
|
||||
m.impl("poisson", unsupportedRandomOp<const Tensor&, optional<Generator>>);
|
||||
UNSUPPORTED_RANDOM(poisson);
|
||||
|
||||
m.impl("random_.from", unsupportedRandomOp_<Tensor&, int64_t, optional<int64_t>, optional<Generator>>);
|
||||
m.impl("random_.to", unsupportedRandomOp_<Tensor&, int64_t, optional<Generator>>);
|
||||
m.impl("random_", unsupportedRandomOp_<Tensor&, optional<Generator>>);
|
||||
UNSUPPORTED_RANDOM2(random_, from);
|
||||
UNSUPPORTED_RANDOM2(random_, to);
|
||||
UNSUPPORTED_RANDOM(random_);
|
||||
|
||||
// m.impl("rand_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, optional<MemoryFormat>>);
|
||||
// m.impl("randn_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, optional<MemoryFormat>>);
|
||||
// UNSUPPORTED_RANDOM(rand_like);
|
||||
// UNSUPPORTED_RANDOM(randn_like);
|
||||
|
||||
m.impl("randint_like", unsupportedRandomOp<const Tensor&, int64_t, TENSOROPTIONS, optional<MemoryFormat>>);
|
||||
m.impl("randint_like.low_dtype", unsupportedRandomOp<const Tensor&, int64_t, int64_t, TENSOROPTIONS, optional<MemoryFormat>>);
|
||||
UNSUPPORTED_RANDOM(randint_like);
|
||||
UNSUPPORTED_RANDOM2(randint_like, low_dtype);
|
||||
|
||||
m.impl("rand", unsupportedRandomOp<IntArrayRef, TENSOROPTIONS>);
|
||||
m.impl("rand.generator", unsupportedRandomOp<IntArrayRef, optional<Generator>, TENSOROPTIONS>);
|
||||
m.impl("rand.names", unsupportedRandomOp<IntArrayRef, optional<DimnameList>, TENSOROPTIONS>);
|
||||
m.impl("rand.generator_with_names", unsupportedRandomOp<IntArrayRef, optional<Generator>, optional<DimnameList>, TENSOROPTIONS>);
|
||||
m.impl("rand.out", unsupportedRandomOp_<IntArrayRef, Tensor&>);
|
||||
m.impl("rand.generator_out", unsupportedRandomOp_<IntArrayRef, optional<Generator>, Tensor&>);
|
||||
UNSUPPORTED_RANDOM(rand);
|
||||
UNSUPPORTED_RANDOM2(rand, generator);
|
||||
UNSUPPORTED_RANDOM2(rand, names);
|
||||
UNSUPPORTED_RANDOM2(rand, generator_with_names);
|
||||
UNSUPPORTED_RANDOM2(rand, out);
|
||||
UNSUPPORTED_RANDOM2(rand, generator_out);
|
||||
|
||||
// m.impl("randn", unsupportedRandomOp<IntArrayRef, TENSOROPTIONS>);
|
||||
m.impl("randn.generator", unsupportedRandomOp<IntArrayRef, optional<Generator>, TENSOROPTIONS>);
|
||||
m.impl("randn.names", unsupportedRandomOp<IntArrayRef, optional<DimnameList>, TENSOROPTIONS>);
|
||||
m.impl("randn.generator_with_names", unsupportedRandomOp<IntArrayRef, optional<Generator>, optional<DimnameList>, TENSOROPTIONS>);
|
||||
m.impl("randn.out", unsupportedRandomOp_<IntArrayRef, Tensor&>);
|
||||
m.impl("randn.generator_out", unsupportedRandomOp_<IntArrayRef, optional<Generator>, Tensor&>);
|
||||
// UNSUPPORTED_RANDOM(randn);
|
||||
UNSUPPORTED_RANDOM2(randn, generator);
|
||||
UNSUPPORTED_RANDOM2(randn, names);
|
||||
UNSUPPORTED_RANDOM2(randn, generator_with_names);
|
||||
UNSUPPORTED_RANDOM2(randn, out);
|
||||
UNSUPPORTED_RANDOM2(randn, generator_out);
|
||||
|
||||
m.impl("randperm", unsupportedRandomOp<int64_t, TENSOROPTIONS>);
|
||||
m.impl("randperm.generator", unsupportedRandomOp<int64_t, optional<Generator>, TENSOROPTIONS>);
|
||||
m.impl("randperm.out", unsupportedRandomOp_<int64_t, Tensor&>);
|
||||
m.impl("randperm.generator_out", unsupportedRandomOp_<int64_t, optional<Generator>, Tensor&>);
|
||||
UNSUPPORTED_RANDOM(randperm);
|
||||
UNSUPPORTED_RANDOM2(randperm, generator);
|
||||
UNSUPPORTED_RANDOM2(randperm, out);
|
||||
UNSUPPORTED_RANDOM2(randperm, generator_out);
|
||||
|
||||
m.impl("randint", unsupportedRandomOp<int64_t, IntArrayRef, TENSOROPTIONS>);
|
||||
m.impl("randint.generator", unsupportedRandomOp<int64_t, IntArrayRef, optional<Generator>, TENSOROPTIONS>);
|
||||
m.impl("randint.low", unsupportedRandomOp<int64_t, int64_t, IntArrayRef, TENSOROPTIONS>);
|
||||
m.impl("randint.low_generator", unsupportedRandomOp<int64_t, int64_t, IntArrayRef, optional<Generator>, TENSOROPTIONS>);
|
||||
m.impl("randint.out", unsupportedRandomOp_<int64_t, IntArrayRef, Tensor&>);
|
||||
m.impl("randint.generator_out", unsupportedRandomOp_<int64_t, IntArrayRef, optional<Generator>, Tensor&>);
|
||||
m.impl("randint.low_out", unsupportedRandomOp_<int64_t, int64_t, IntArrayRef, Tensor&>);
|
||||
m.impl("randint.low_generator_out", unsupportedRandomOp_<int64_t, int64_t, IntArrayRef, optional<Generator>, Tensor&>);
|
||||
UNSUPPORTED_RANDOM(randint);
|
||||
UNSUPPORTED_RANDOM2(randint, generator);
|
||||
UNSUPPORTED_RANDOM2(randint, low);
|
||||
UNSUPPORTED_RANDOM2(randint, low_generator);
|
||||
UNSUPPORTED_RANDOM2(randint, out);
|
||||
UNSUPPORTED_RANDOM2(randint, generator_out);
|
||||
UNSUPPORTED_RANDOM2(randint, low_out);
|
||||
UNSUPPORTED_RANDOM2(randint, low_generator_out);
|
||||
|
||||
m.impl("uniform_", unsupportedRandomOp_<Tensor&, double, double, optional<Generator>>);
|
||||
UNSUPPORTED_RANDOM(uniform_);
|
||||
|
||||
#undef TENSOROPTIONS
|
||||
|
||||
m.impl("randn", randn_mbatching_rule);
|
||||
}
|
||||
|
@ -43,7 +43,9 @@ _functorch_lagging_meta = {
|
||||
('atan2', ''),
|
||||
('atanh', ''),
|
||||
('baddbmm', ''),
|
||||
('bitwise_left_shift', ''),
|
||||
('bitwise_not', ''),
|
||||
('bitwise_right_shift', ''),
|
||||
('bmm', ''),
|
||||
('broadcast_to', ''),
|
||||
('cdist', ''),
|
||||
@ -59,8 +61,10 @@ _functorch_lagging_meta = {
|
||||
('conj_physical', ''),
|
||||
('contiguous', ''),
|
||||
('copysign', ''),
|
||||
('corrcoef', ''),
|
||||
('cos', ''),
|
||||
('cosh', ''),
|
||||
('cov', ''),
|
||||
('cross', ''),
|
||||
('cummax', ''),
|
||||
('cummin', ''),
|
||||
@ -278,6 +282,8 @@ _functorch_lagging_meta = {
|
||||
('special.ndtr', ''),
|
||||
('special.ndtri', ''),
|
||||
('special.xlog1py', ''),
|
||||
('special.zeta', ''),
|
||||
('special.zeta', 'grad'),
|
||||
('split', ''),
|
||||
('split', 'list_args'),
|
||||
('split_with_sizes', ''),
|
||||
|
Reference in New Issue
Block a user