Mirror of https://github.com/pytorch/pytorch.git
Move functorch/functorch into `functorch` folder

- Add functorch/CMakeLists.txt that adds the `functorch` native python extension
- Modify `setup.py` to package pytorch and functorch together into a single wheel
- Modify `functorch.__version__` so it is not equal to that of `torch.__version__`
- Add a dummy `functorch/setup.py` file for the projects that still want to build it

Differential Revision: [D39058811](https://our.internmc.facebook.com/intern/diff/D39058811)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83464
Approved by: https://github.com/zou3519
69 lines · 2.9 KiB · C++
#include <functorch/csrc/FunctionalizeInterpreter.h>
#include <functorch/csrc/DynamicLayer.h>
#include <ATen/FunctionalTensorWrapper.h>

namespace at { namespace functorch {

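// sanityCheckNotFunctional: asserts that none of the tensors in the last
// `num_args` entries of the stack are still FunctionalTensorWrappers, i.e.
// that functionalization has fully unwrapped them.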
static void sanityCheckNotFunctional(const c10::OperatorHandle& op, torch::jit::Stack* stack, size_t num_args) {
  foreachTensorInplace(*stack, stack->size() - num_args, stack->size(),
      [](const Tensor& tensor) {
        TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(tensor));
        return tensor;
      });
}

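// processImpl handles an op dispatched while a functionalize() transform is
// active: it routes the call to the Functionalize dispatch key and then tags
// any functional tensor outputs with this interpreter's level.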
void FunctionalizeInterpreterPtr::processImpl(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  DispatchKeySet exclude = keysToExcludeWhenEnteringDynamicLayer(TransformType::Functionalize);

  // We always want to call the functionalization kernels if functionalize() is on the layer stack.
  // It's the responsibility of the functionalization kernel to no-op and redispatch
  // if none of the input tensors are functional.
  setup_dispatch_key_tls(exclude, DispatchKeySet(DispatchKey::Functionalize));
  auto functionalization_add_back_views = functionalizeAddBackViews();
  // We have some side-car TLS that we can set to toggle the functionalization behavior.
  // If set, then functionalization will only remove mutations, instead of
  // removing both mutations AND view operators.
  at::functionalization::impl::FunctionalizationReapplyViewsGuard functional_guard(functionalization_add_back_views);

  op.callBoxed(stack);

  auto ret_size = op.schema().returns().size();
  foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(),
    [&](const Tensor& tensor) {
      if (at::functionalization::impl::isFunctionalTensor(tensor)) {
        auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
        // Functorch is responsible for setting the level on the wrapper, since we don't
        // have that info available in core (for now).
        // We could just "propagate" the level from the input tensors inside of the functionalize kernels,
        // but unfortunately we can't do that for factory operators.
        wrapper->set_level(level());
      }
      return tensor;
    }
  );
}

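// sendToNextInterpreterImpl forwards the (now unwrapped) call to the next
// interpreter on the DynamicLayer stack, checking that no FunctionalTensors
// leak through in either the arguments or the returns.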
void FunctionalizeInterpreterPtr::sendToNextInterpreterImpl(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  // For now, we don't support nested functionalization calls.
  // This check just enforces that - after the functionalize kernel runs
  // and we hit the BackModeFallback, we'll have unwrapped our FunctionalTensors
  // so we can check that the unwrapped thing is not another (nested) FunctionalTensor.
  auto args_size = op.schema().arguments().size();
  sanityCheckNotFunctional(op, stack, args_size);

  // Re-dispatch
  if (getDynamicLayerStack().size() == 0) {
    sanityCheckStack(op, stack);
  }
  op.callBoxed(stack);

  auto ret_size = op.schema().returns().size();
  sanityCheckNotFunctional(op, stack, ret_size);
}

}} // namespace at::functorch
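// Usage sketch (an assumption for illustration, not part of this file): this
// interpreter backs functorch's functionalize() transform on the Python side.
// functionalize() pushes a Functionalize layer onto the DynamicLayer stack,
// and each op dispatched under it then flows through processImpl above.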