Compare commits

...

1 Commits

Author SHA1 Message Date
791087804d Some experiments with C++ FakeTensor
Timings on my machine:
- real Tensors: 561us
- FakeTensor: 7.31ms
- meta tensors: 8.8ms
- vmap: 1.2ms
- "C++ FakeTensor": 359us
2024-06-06 10:11:28 -07:00
2 changed files with 30 additions and 0 deletions

View File

@@ -18,6 +18,7 @@
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/functorch/TensorWrapper.h>
#include <c10/core/AutogradState.h>
#include <torch/extension.h>
#include <iostream>
@@ -598,4 +599,32 @@ void initFuncTorchBindings(PyObject* module) {
&FunctionalizeInterpreterPtr::functionalizeAddBackViews);
}
// Allocate a metadata-only "fake" tensor: a fresh TensorImpl that carries
// base's dispatch key set, dtype and device, stamped with the requested
// sizes/strides. Used as the shared constructor for the Fake_* kernels below.
Tensor Fake_new(const Tensor& base, IntArrayRef sizes, IntArrayRef strides) {
  const auto key_set = base.unsafeGetTensorImpl()->key_set();
  auto out = at::detail::make_tensor<TensorImpl>(key_set, base.dtype(), base.device());
  out.unsafeGetTensorImpl()->set_sizes_and_strides(sizes, strides);
  return out;
}
// Fake kernel for t(): transpose of a 2-D tensor.
// BUG FIX: t() swaps the two dimensions, so the strides must be swapped
// along with the sizes — the original kept x's stride order, which reports
// the layout of the un-transposed input (e.g. a contiguous (m,n) input
// would yield a (n,m) fake claiming strides (n,1) instead of (1,n)).
// NOTE(review): like the original, this assumes x is 2-D — confirm callers.
Tensor Fake_t(const Tensor& x) {
  return Fake_new(
      x,
      {x.sizes()[1], x.sizes()[0]},
      {x.strides()[1], x.strides()[0]});
}
// Fake kernel for detach(): the result has exactly the input's metadata,
// so just mint a fresh fake with identical sizes and strides.
Tensor Fake_detach(const Tensor& x) {
  const auto in_sizes = x.sizes();
  const auto in_strides = x.strides();
  return Fake_new(x, in_sizes, in_strides);
}
// Fake kernel for relu(): an elementwise op, so the output metadata
// (shape and layout) mirrors the input's exactly.
Tensor Fake_relu(const Tensor& x) {
  const auto in_sizes = x.sizes();
  const auto in_strides = x.strides();
  return Fake_new(x, in_sizes, in_strides);
}
// Fake kernel for addmm(self, mat1, mat2, beta, alpha):
// result shape is (mat1.rows, mat2.cols) = (a, b).
// BUG FIX: a contiguous (a, b) matrix has strides {b, 1} — the row stride
// is the number of columns — not {a * b, 1}, which the original reported.
// beta/alpha scale values only; they never affect the output metadata.
Tensor Fake_addmm(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
  int64_t a = mat1.sizes()[0];
  int64_t b = mat2.sizes()[1];
  return Fake_new(self, {a, b}, {b, 1});
}
// Register the metadata-only kernels above on the Fake dispatch key, so
// aten ops dispatched with this key produce shape/stride results without
// touching real data. Only these four ops are covered by the experiment.
TORCH_LIBRARY_IMPL(aten, Fake, m) {
  m.impl("t", &Fake_t);
  m.impl("addmm", &Fake_addmm);
  m.impl("relu", &Fake_relu);
  m.impl("detach", &Fake_detach);
}
} // namespace torch::functorch::impl

View File

@@ -668,6 +668,7 @@ void initDispatchBindings(PyObject* module) {
DEF_ONE(CompositeExplicitAutograd)
DEF_ONE(CompositeImplicitAutogradNestedTensor)
DEF_ONE(CompositeImplicitAutograd)
DEF_ONE(Fake)
// NestedTensor is not a backend key
DEF_ONE(AutogradNestedTensor)
DEF_ONE(AutogradOther)