Revert "Move functorch/_src to torch/_functorch (#88756)"
This reverts commit 52bc5c1cfe098fd4b4b13902b4fea83b455b9773.
Reverted https://github.com/pytorch/pytorch/pull/88756 on behalf of https://github.com/clee2000 because it broke imports in tests (52bc5c1cfe, https://github.com/pytorch/pytorch/actions/runs/3574742513/jobs/6010814968); probably a landrace.
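For context, the failure mode is the usual one for a private-module move: tests that import submodules by their old dotted path fail at import time. The snippet below is a hypothetical minimal reproduction, not the actual failing test; it assumes a build containing #88756, where functorch/_src retained only a handful of forwarding shims.

# Hypothetical repro: with functorch/_src moved to torch/_functorch, submodules
# that did not get a forwarding shim (e.g. functorch._src.config, which the
# dynamo tests import) disappear, so the import itself raises.
try:
    import functorch._src.config
except ModuleNotFoundError as err:
    print(f"import failed: {err}")  # reverting restores the old module layout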
@@ -101,16 +101,6 @@ exclude_patterns = [
     'torch/csrc/**',
     'torch/_dynamo/**/*.py',
     'torch/_inductor/**/*.py',
-    'torch/_functorch/aot_autograd.py',
-    'torch/_functorch/benchmark_utils.py',
-    'torch/_functorch/compile_utils.py',
-    'torch/_functorch/compilers.py',
-    'torch/_functorch/eager_transforms.py',
-    'torch/_functorch/fx_minifier.py',
-    'torch/_functorch/partitioners.py',
-    'torch/_functorch/make_functional.py',
-    'torch/_functorch/top_operators_github_usage.py',
-    'torch/_functorch/vmap.py',
     'torch/distributed/elastic/agent/server/api.py',
     'torch/testing/_internal/**',
     'torch/distributed/fsdp/fully_sharded_data_parallel.py',
@@ -23,13 +23,13 @@ import torch
 import torch._dynamo
 import torch._dynamo.utils
 import torch.distributed
+from functorch._src.aot_autograd import set_model_name
 from scipy.stats import gmean, ttest_ind
 from torch._dynamo.optimizations import backends
 from torch._dynamo.optimizations.log_args import conv_args_analysis
 from torch._dynamo.profiler import fx_insert_profiling, Profiler
 from torch._dynamo.testing import dummy_fx_compile, format_speedup, same
 from torch._dynamo.utils import clone_inputs
-from torch._functorch.aot_autograd import set_model_name
 from torch._inductor import config as inductor_config
 from torch._inductor.utils import fresh_inductor_cache
 from torch._subclasses.fake_tensor import FakeTensorMode
@@ -8,19 +8,19 @@ from . import _C
 
 # Top-level APIs. Please think carefully before adding something to the
 # top-level namespace:
-# - private helper functions should go into torch._functorch
+# - private helper functions should go into functorch._src
 # - very experimental things should go into functorch.experimental
 # - compilation related things should go into functorch.compile
 
 # functorch transforms
-from torch._functorch.vmap import vmap
-from torch._functorch.eager_transforms import (
+from ._src.vmap import vmap
+from ._src.eager_transforms import (
     grad, grad_and_value, vjp, jacrev, jvp, jacfwd, hessian, functionalize
 )
-from torch._functorch.python_key import make_fx
+from ._src.python_key import make_fx
 
 # utilities. Maybe these should go in their own namespace in the future?
-from torch._functorch.make_functional import (
+from ._src.make_functional import (
     make_functional_with_buffers,
     make_functional,
     combine_state_for_ensemble,
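Note that only private import paths move in this hunk; the public functorch namespace that user code sees is identical before and after the revert. A minimal sketch of that public surface (illustrative example, not part of the diff):

import torch
from functorch import grad, vmap  # public API, unaffected by the _src move

def loss(w, x):
    return (x @ w).sin().sum()

w = torch.randn(3)
xs = torch.randn(5, 3)
# per-sample gradients: grad w.r.t. w, vmapped over the batch of inputs
per_sample_grads = vmap(grad(loss), in_dims=(None, 0))(w, xs)
print(per_sample_grads.shape)  # torch.Size([5, 3])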
@@ -0,0 +1,5 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
@@ -1,7 +0,0 @@
-# This file has moved. It is not public API. If you are not a PyTorch developer
-# and you are relying on the following imports, please file an issue.
-from torch._functorch.aot_autograd import (
-    aot_autograd_decompositions,
-    KNOWN_TYPES,
-    PytreeThunk,
-)
@@ -1,6 +0,0 @@
-# This file has moved. It is not public API. If you are not a PyTorch developer
-# and you are relying on the following imports, please file an issue.
-from torch._functorch.eager_transforms import (
-    _unwrap_functional_tensor,
-    _assert_wrapped_functional,
-)
@@ -1,2 +0,0 @@
-# This file has moved. Please update your imports
-from torch._functorch.make_functional import _swap_state
@@ -1,15 +0,0 @@
-# This file has moved. It is not public API. If you are not a PyTorch developer
-# and you are relying on the following imports, please file an issue.
-from torch._functorch.vmap import (
-    _add_batch_dim,
-    _broadcast_to_and_flatten,
-    _get_name,
-    _remove_batch_dim,
-    _validate_and_get_batch_size,
-    Tensor,
-    tree_flatten,
-    tree_unflatten,
-    _process_batched_inputs,
-    _create_batched_inputs,
-    _unwrap_batched,
-)
@@ -5,7 +5,7 @@ import os
 import logging
 import pandas as pd
 
-from torch._functorch.benchmark_utils import compute_utilization
+from functorch._src.benchmark_utils import compute_utilization
 
 # process the chrome traces output by the pytorch profiler
 # require the json input file's name to be in format {model_name}_chrome_trace_*.json
@@ -3,7 +3,7 @@ import torch.fx as fx
 from functorch import make_fx
 from torch.profiler import profile, ProfilerActivity
 
-from torch._functorch.compile_utils import fx_graph_cse
+from functorch._src.compile_utils import fx_graph_cse
 
 def profile_it(f, inp):
     for _ in range(5):
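fx_graph_cse, whose import path changes above, performs common-subexpression elimination on a torch.fx graph and returns a new graph. A rough usage sketch under the reverted layout (hedged; the duplicate-cos function is a made-up example, not from the benchmark):

import torch
import torch.fx as fx
from functorch import make_fx
from functorch._src.compile_utils import fx_graph_cse  # path after the revert

def f(x):
    return x.cos() + x.cos()  # duplicate subexpression for CSE to fold

inp = torch.randn(4)
traced = make_fx(f)(inp)                      # trace to an fx.GraphModule of aten ops
deduped_graph = fx_graph_cse(traced.graph)    # returns a new, deduplicated fx.Graph
deduped = fx.GraphModule(traced, deduped_graph)
torch.testing.assert_close(deduped(inp), f(inp))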
@@ -1,6 +1,6 @@
-from torch._functorch.python_key import pythonkey_decompose
-from torch._functorch.fx_minifier import minifier
-from torch._functorch.aot_autograd import (
+from .._src.python_key import pythonkey_decompose
+from .._src.fx_minifier import minifier
+from .._src.aot_autograd import (
     aot_function,
     aot_module,
     compiled_function,
@@ -12,7 +12,7 @@ from torch._functorch.aot_autograd import (
     make_boxed_func,
     make_boxed_compiler
 )
-from torch._functorch.compilers import (
+from .._src.compilers import (
     ts_compile,
     draw_graph_compile,
     nop,
@@ -22,10 +22,10 @@ from torch._functorch.compilers import (
     print_compile,
     default_decompositions
 )
-from torch._functorch.partitioners import (
+from .._src.partitioners import (
     min_cut_rematerialization_partition,
     default_partition,
     draw_graph,
     draw_joint_graph,
 )
-from torch._functorch import config
+from .._src import config
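For orientation, the functorch.compile names re-exported in these hunks wrap AOTAutograd: aot_function traces forward and backward graphs and hands each to a compiler callback. A hedged usage sketch (illustrative, based on the functorch.compile API of this era, not taken from the diff):

import torch
from functorch.compile import aot_function, nop

def f(x):
    return x.sin().cos()

# nop is the trivial compiler that returns the traced graph unchanged, so this
# just exercises the AOTAutograd tracing and runtime path.
aot_f = aot_function(f, fw_compiler=nop)
x = torch.randn(4, requires_grad=True)
aot_f(x).sum().backward()
print(x.grad.shape)  # torch.Size([4])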
@@ -1,5 +1,5 @@
 # PyTorch forward-mode is not mature yet
-from torch._functorch.eager_transforms import hessian, jacfwd, jvp
-from torch._functorch.vmap import chunk_vmap
+from .._src.eager_transforms import hessian, jacfwd, jvp
+from .._src.vmap import chunk_vmap
 from .batch_norm_replacement import replace_all_batch_norm_modules_
 from functorch import functionalize
@@ -104,7 +104,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
         y = torch.randn((), device="cpu")
         fn(x, y)
 
-    @patch("torch._functorch.config.use_functionalize", True)
+    @patch("functorch._src.config.use_functionalize", True)
     def test_mutate_input(self):
         def model(x, y):
             y.add_(3)
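The churn in these test decorators follows from how unittest.mock resolves targets: patch() looks the module up from its dotted string at call time, so every hard-coded path has to track the module's real location. A minimal sketch under the reverted layout (hypothetical check function, not from the test suite):

from unittest.mock import patch

import functorch._src.config

# patch() imports "functorch._src.config" from the string and swaps the
# attribute for the duration of the call, which is why a package move breaks
# every decorator that still spells the old path.
@patch("functorch._src.config.use_functionalize", True)
def check_flag_is_patched():
    assert functorch._src.config.use_functionalize is True

check_flag_is_patched()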
@@ -159,7 +159,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
         y = torch.randn(3, device="cuda:0", requires_grad=True)
         fn(y)
 
-    @patch("torch._functorch.config.use_functionalize", True)
+    @patch("functorch._src.config.use_functionalize", True)
     @patch_all()
     def test_mutated_metadata(self):
         # more tortured example at
@@ -180,7 +180,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
         x = torch.empty(0, device="cuda:0")
         fn(x)
 
-    @patch("torch._functorch.config.use_functionalize", True)
+    @patch("functorch._src.config.use_functionalize", True)
     @patch_all()
     def test_dead_fill(self):
         def model(x):
@@ -11,6 +11,8 @@ from copy import deepcopy
 from typing import List
 from unittest.mock import patch
 
+import functorch._src.config
+
 import numpy as np
 import torch
 
@@ -18,8 +20,6 @@ import torch._dynamo.test_case
 import torch._dynamo.testing
 import torch._dynamo.utils
 
-import torch._functorch.config
-
 try:
     from test_minifier import requires_cuda
 except ImportError:
@@ -1681,7 +1681,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         opt_fn(x)
         self.assertEqual(cnt.frame_count, 1)
 
-    @patch.object(torch._functorch.config, "use_dynamic_shapes", True)
+    @patch.object(functorch._src.config, "use_dynamic_shapes", True)
     def test_bigbird_unsqueeze_inplace(self):
         def fn(reshape_2):
             view_2 = reshape_2.clone()
@@ -3,7 +3,7 @@ import copy
 from torch.testing._internal.common_methods_invocations import op_db
 from functorch_additional_op_db import additional_op_db
 from enum import Enum
-import torch._functorch.top_operators_github_usage as top_ops
+import functorch._src.top_operators_github_usage as top_ops
 import pprint
 import unittest
 import enum
@@ -22,7 +22,7 @@ from functorch import (
     grad, vjp, vmap, jacrev,
     make_fx
 )
-from torch._functorch.aot_autograd import aot_module_simplified
+from functorch._src.aot_autograd import aot_module_simplified
 from functorch.compile import (
     nnc_jit, compiled_function, compiled_module,
     min_cut_rematerialization_partition, aot_function, aot_module,
@@ -991,7 +991,7 @@ def forward(self, primals_1, primals_2):
         inp = [torch.randn(5, requires_grad=True) for _ in range(3)]
         f(*inp).sum().backward()
 
-    @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count)
+    @patch('functorch._src.aot_autograd.AOT_COUNTER', new_callable=itertools.count)
     def test_compilation_context(self, counter):
         def f(x):
             return x.sin().sin()
@@ -1016,8 +1016,8 @@ def forward(self, primals_1, primals_2):
         x = torch.randn(3, 3, requires_grad=True)
         self.verify_aot_autograd(f, [x, x])
 
-    @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count)
-    @patch("torch._functorch.config.debug_assert", True)
+    @patch('functorch._src.aot_autograd.AOT_COUNTER', new_callable=itertools.count)
+    @patch("functorch._src.config.debug_assert", True)
     def test_invalid_dupe(self, counter):
         class F(torch.nn.Module):
             def forward(self, x, y):
@@ -1037,8 +1037,8 @@ def forward(self, primals_1, primals_2):
             """At compilation time, graph 1 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case. This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch."""  # noqa: B950
         )
 
-    @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count)
-    @patch("torch._functorch.config.debug_assert", True)
+    @patch('functorch._src.aot_autograd.AOT_COUNTER', new_callable=itertools.count)
+    @patch("functorch._src.config.debug_assert", True)
     def test_invalid_requires_grad(self, counter):
         class F(torch.nn.Module):
             def forward(self, x, y):
@@ -32,10 +32,10 @@ from functorch import (
     jvp, make_functional, make_functional_with_buffers,
     combine_state_for_ensemble, make_fx
 )
-from torch._functorch.make_functional import (
+from functorch._src.make_functional import (
     functional_init, functional_init_with_buffers,
 )
-from torch._functorch.eager_transforms import enable_fwd_grad, _slice_argnums
+from functorch._src.eager_transforms import enable_fwd_grad, _slice_argnums
 from functorch.experimental import functionalize
 from torch._ops import PyOperator
 from torch._functorch.utils import enable_autograd_function
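The private functional_init helpers move here, but the public entry point stays functorch.make_functional, which splits a module into a stateless callable plus its parameters. A short illustrative sketch (not part of the diff):

import torch
from functorch import make_functional

model = torch.nn.Linear(3, 2)
func_model, params = make_functional(model)  # stateless callable + parameter tuple
x = torch.randn(4, 3)
torch.testing.assert_close(func_model(params, x), model(x))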
@@ -6,7 +6,7 @@ import torch.fx as fx
 from functorch import make_fx
 from torch.nn import functional as F
 from functorch.compile import memory_efficient_fusion
-from torch._functorch.compile_utils import fx_graph_cse
+from functorch._src.compile_utils import fx_graph_cse
 from torch.testing._internal.common_utils import TestCase, run_tests
 import inspect
 import random
@@ -2,7 +2,7 @@
 
 import torch
 from functorch.compile import minifier
-from torch._functorch.compile_utils import get_placeholders, get_outputs
+from functorch._src.compile_utils import get_placeholders, get_outputs
 from functorch import make_fx
 from torch.testing._internal.common_utils import TestCase, run_tests
 
@@ -41,7 +41,7 @@ from torch.testing._internal.opinfo.core import SampleInput
 from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
 from functorch import grad, vjp, vmap, jacrev, jacfwd
 import torch.autograd.forward_ad as fwAD
-from torch._functorch.eager_transforms import _as_tuple, jvp
+from functorch._src.eager_transforms import _as_tuple, jvp
 
 aten = torch.ops.aten
 
@@ -49,7 +49,7 @@ import functorch
 from functorch import vmap, grad, grad_and_value, jvp, vjp, jacfwd
 from functorch.experimental import chunk_vmap
 from torch._C._functorch import reshape_dim_into, reshape_dim_outof
-from torch._functorch.make_functional import functional_init_with_buffers
+from functorch._src.make_functional import functional_init_with_buffers
 
 FALLBACK_REGEX = 'There is a performance drop'
 
@@ -5527,7 +5527,7 @@ if HAS_CUDA:
             Instead, it transforms the fx graph so that its functions are
             aten operations. It then saves this graph.
             """
-            from torch._functorch.aot_autograd import Interpreter
+            from functorch._src.aot_autograd import Interpreter
             from torch._inductor.decomposition import select_decomp_table
             from torch._subclasses import FakeTensorMode
 
@@ -150,7 +150,7 @@ class TestFunctionalization(TestCase):
 
         def g(x):
            loss = f(x).sum()
-            from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks
+            from functorch._src.aot_autograd import setup_stacktrace_preservation_hooks
            import torch.fx.traceback as fx_traceback
            setup_stacktrace_preservation_hooks([loss.grad_fn])
            with fx_traceback.override_stack_trace():
@@ -502,7 +502,7 @@ def run_fwd_maybe_bwd(gm, args, only_fwd=False):
     """
     Runs a forward and possibly backward iteration for a given mod and args.
     """
-    from torch._functorch.aot_autograd import make_boxed_func
+    from functorch._src.aot_autograd import make_boxed_func
 
     from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass
 
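make_boxed_func, re-imported here from functorch._src, adapts an ordinary compiled callable to AOTAutograd's boxed calling convention, in which the runtime passes a single list of arguments (so the list can be cleared to release inputs early). Roughly, the idea is the sketch below (simplified, not the exact implementation):

def make_boxed_func_sketch(f):
    # Boxed convention: take one list of args instead of positional args and
    # mark the wrapper so the AOTAutograd runtime knows it may call it that way.
    def g(args):
        return f(*args)

    g._boxed_call = True
    return g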
@@ -120,7 +120,7 @@ def enable_dynamic(enable: bool = True):
         yield
         return
     with patch("torch._dynamo.config.dynamic_shapes", True), patch(
-        "torch._functorch.config.use_dynamic_shapes", True
+        "functorch._src.config.use_dynamic_shapes", True
     ):
         yield
 
@@ -6,6 +6,8 @@ from functools import partial
 from importlib import import_module
 from typing import Set
 
+from functorch._src.compilers import debug_nop
+
 from functorch.compile import (
     aot_module_simplified,
     min_cut_rematerialization_partition,
@@ -14,8 +16,6 @@ from functorch.compile import (
 )
 
 import torch
-
-from torch._functorch.compilers import debug_nop
 from torch.fx import GraphModule
 from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
 from torch.multiprocessing.reductions import StorageWeakRef
@@ -1,5 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
@@ -6,10 +6,10 @@ import sys
 from typing import List
 
 import functorch
+from functorch._src.aot_autograd import make_boxed_func
 from functorch.compile import min_cut_rematerialization_partition
 
 import torch.fx
-from torch._functorch.aot_autograd import make_boxed_func
 from torch._subclasses.fake_tensor import FakeTensor
 
 from . import config, metrics, overrides
@@ -391,7 +391,7 @@ def compile_fx(
     with overrides.patch_functions():
 
         # TODO: can add logging before/after the call to create_aot_dispatcher_function
-        # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
+        # in functorch/_src/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
         # once torchdynamo is merged into pytorch
         return aot_autograd(
             fw_compiler=fw_compiler,
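In the spirit of the TODO above, logging around the AOT dispatch can be added without touching aot_autograd itself by wrapping the compiler callbacks that compile_fx hands to aot_autograd(). A hypothetical sketch (the helper name and log messages are made up, not part of the codebase):

import logging
from functools import wraps

log = logging.getLogger(__name__)

def with_compile_logging(compiler):
    # Wrap an inductor compiler callback (gm, example_inputs) so entry and exit
    # around the AOT dispatcher's work are visible in the logs.
    @wraps(compiler)
    def wrapped(gm, example_inputs):
        log.info("compiling graph with %d nodes", len(gm.graph.nodes))
        result = compiler(gm, example_inputs)
        log.info("finished compiling graph")
        return result

    return wrapped

# e.g. aot_autograd(fw_compiler=with_compile_logging(fw_compiler), ...)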