mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[functorch] Added ones_like/zeros_like/... batching rules
This commit is contained in:
21
functorch/functorch/csrc/BatchRulesFactory.cpp
Normal file
21
functorch/functorch/csrc/BatchRulesFactory.cpp
Normal file
@ -0,0 +1,21 @@
|
||||
#include <functorch/csrc/BatchRulesHelper.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace at { namespace functorch {
|
||||
|
||||
using FactoryType = Tensor (*)(const Tensor&, optional<ScalarType>, optional<Layout>, optional<Device>, optional<bool>, optional<MemoryFormat>);
|
||||
using RandIntType = Tensor (*)(const Tensor&, optional<ScalarType>, optional<Layout>, optional<Device>, optional<bool>, optional<MemoryFormat>);
|
||||
|
||||
#define SINGLE_ARG(...) __VA_ARGS__
|
||||
#define FACTORY_BATCH_RULE(op) SINGLE_ARG(basic_unary_batch_rule<FactoryType, &op, optional<ScalarType>, optional<Layout>, optional<Device>, optional<bool>, optional<MemoryFormat>>)
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
VMAP_SUPPORT("ones_like", FACTORY_BATCH_RULE(at::ones_like));
|
||||
VMAP_SUPPORT("zeros_like", FACTORY_BATCH_RULE(at::zeros_like));
|
||||
VMAP_SUPPORT("empty_like", FACTORY_BATCH_RULE(at::empty_like));
|
||||
VMAP_SUPPORT("randn_like", FACTORY_BATCH_RULE(at::randn_like));
|
||||
VMAP_SUPPORT("rand_like", FACTORY_BATCH_RULE(at::rand_like));
|
||||
// Not sure how to add the ones with irregular args to the mix cleanly (i.e. randint takes an extra int parameter)
|
||||
}
|
||||
}}
|
||||
|
Reference in New Issue
Block a user