mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Apply UFMT to all non test/torch files (#106205)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/106205 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
1163800d0f
commit
e6ec0efaf8
@ -908,62 +908,6 @@ exclude_patterns = [
|
||||
'third_party/**/*.pyi',
|
||||
# These files are all grandfathered in, feel free to remove from this list
|
||||
# as necessary
|
||||
'aten/src/ATen/function_wrapper.py',
|
||||
'aten/src/ATen/native/quantized/cpu/qnnpack/configure.py',
|
||||
'aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/configure.py',
|
||||
'aten/src/ATen/native/quantized/cpu/qnnpack/generate-wrapper.py',
|
||||
'aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py',
|
||||
'aten/src/ATen/nnapi/codegen.py',
|
||||
'functorch/__init__.py',
|
||||
'functorch/_src/__init__.py',
|
||||
'functorch/_src/aot_autograd/__init__.py',
|
||||
'functorch/_src/eager_transforms/__init__.py',
|
||||
'functorch/_src/make_functional/__init__.py',
|
||||
'functorch/_src/vmap/__init__.py',
|
||||
'functorch/benchmarks/chrome_trace_parser.py',
|
||||
'functorch/benchmarks/cse.py',
|
||||
'functorch/benchmarks/operator_authoring.py',
|
||||
'functorch/benchmarks/per_sample_grads.py',
|
||||
'functorch/benchmarks/pointwise_scorecard.py',
|
||||
'functorch/benchmarks/process_scorecard.py',
|
||||
'functorch/compile/__init__.py',
|
||||
'functorch/dim/__init__.py',
|
||||
'functorch/dim/batch_tensor.py',
|
||||
'functorch/dim/delayed_mul_tensor.py',
|
||||
'functorch/dim/dim.py',
|
||||
'functorch/dim/magic_trace.py',
|
||||
'functorch/dim/op_properties.py',
|
||||
'functorch/dim/reference.py',
|
||||
'functorch/dim/tree_map.py',
|
||||
'functorch/dim/wrap_type.py',
|
||||
'functorch/docs/source/conf.py',
|
||||
'functorch/einops/__init__.py',
|
||||
'functorch/einops/_parsing.py',
|
||||
'functorch/einops/rearrange.py',
|
||||
'functorch/examples/compilation/eager_fusion.py',
|
||||
'functorch/examples/compilation/fuse_module.py',
|
||||
'functorch/examples/compilation/linear_train.py',
|
||||
'functorch/examples/compilation/simple_function.py',
|
||||
'functorch/examples/dp_cifar10/cifar10_opacus.py',
|
||||
'functorch/examples/dp_cifar10/cifar10_transforms.py',
|
||||
'functorch/examples/ensembling/parallel_train.py',
|
||||
'functorch/examples/lennard_jones/lennard_jones.py',
|
||||
'functorch/examples/maml_omniglot/maml-omniglot-higher.py',
|
||||
'functorch/examples/maml_omniglot/maml-omniglot-ptonly.py',
|
||||
'functorch/examples/maml_omniglot/maml-omniglot-transforms.py',
|
||||
'functorch/examples/maml_omniglot/support/omniglot_loaders.py',
|
||||
'functorch/examples/maml_regression/evjang.py',
|
||||
'functorch/examples/maml_regression/evjang_transforms.py',
|
||||
'functorch/examples/maml_regression/evjang_transforms_module.py',
|
||||
'functorch/experimental/__init__.py',
|
||||
'functorch/experimental/_cond.py',
|
||||
'functorch/experimental/_map.py',
|
||||
'functorch/experimental/control_flow.py',
|
||||
'functorch/experimental/ops.py',
|
||||
'functorch/notebooks/_src/plot_ensembling.py',
|
||||
'functorch/notebooks/_src/plot_jacobians_and_hessians.py',
|
||||
'functorch/notebooks/_src/plot_per_sample_gradients.py',
|
||||
'functorch/op_analysis/gen_data.py',
|
||||
'test/_nvfuser/__init__.py',
|
||||
'test/_nvfuser/test_dynamo.py',
|
||||
'test/_nvfuser/test_python_frontend.py',
|
||||
|
@ -31,7 +31,6 @@ def main(args):
|
||||
],
|
||||
extra_include_dirs="src",
|
||||
):
|
||||
|
||||
requantization_objects = [
|
||||
build.cc("requantization/precise-scalar.c"),
|
||||
build.cc("requantization/fp32-scalar.c"),
|
||||
@ -192,7 +191,6 @@ def main(args):
|
||||
},
|
||||
extra_include_dirs=["src", "test"],
|
||||
):
|
||||
|
||||
build.unittest("hgemm-test", build.cxx("hgemm.cc"))
|
||||
build.unittest("q8avgpool-test", build.cxx("q8avgpool.cc"))
|
||||
build.unittest("q8conv-test", build.cxx("q8conv.cc"))
|
||||
@ -252,7 +250,6 @@ def main(args):
|
||||
isa=benchmark_isa,
|
||||
extra_include_dirs="src",
|
||||
):
|
||||
|
||||
build.benchmark("add-bench", build.cxx("add.cc"))
|
||||
build.benchmark("average-pooling-bench", build.cxx("average-pooling.cc"))
|
||||
build.benchmark("channel-shuffle-bench", build.cxx("channel-shuffle.cc"))
|
||||
|
@ -7,6 +7,7 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import confu
|
||||
|
||||
parser = confu.standard_parser("clog configuration script")
|
||||
|
||||
|
||||
@ -19,13 +20,16 @@ def main(args):
|
||||
with build.options(source_dir="src", extra_include_dirs="src"):
|
||||
build.static_library("clog", build.cc("clog.c"))
|
||||
|
||||
with build.options(source_dir="test", deps={
|
||||
(build, build.deps.googletest): all,
|
||||
"log": build.target.is_android}):
|
||||
with build.options(
|
||||
source_dir="test",
|
||||
deps={(build, build.deps.googletest): all, "log": build.target.is_android},
|
||||
):
|
||||
build.unittest("clog-test", build.cxx("clog.cc"))
|
||||
|
||||
return build
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
main(sys.argv[1:]).generate()
|
||||
|
@ -8,12 +8,12 @@
|
||||
# Kernels are ordered (see `sort_index`), and when dispatching,
|
||||
# we select the first kernel in the list that supports the inputs
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import itertools
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, TypeVar
|
||||
import argparse
|
||||
|
||||
DTYPES = {
|
||||
"f32": "float",
|
||||
@ -303,7 +303,11 @@ T = TypeVar("T", FwdKernel, BwdKernel)
|
||||
|
||||
|
||||
def write_decl_impl(
|
||||
kernels: List[T], family_name: str, impl_file: str, autogen_dir: Path, disable_def: str = None
|
||||
kernels: List[T],
|
||||
family_name: str,
|
||||
impl_file: str,
|
||||
autogen_dir: Path,
|
||||
disable_def: str = None,
|
||||
) -> None:
|
||||
cpp_file_header = """/*
|
||||
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
@ -382,22 +386,28 @@ def main(output_dir: Optional[str]) -> None:
|
||||
FwdKernel.get_all(),
|
||||
"cutlassF",
|
||||
impl_file="<ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h>",
|
||||
autogen_dir=output_dir
|
||||
autogen_dir=output_dir,
|
||||
)
|
||||
write_decl_impl(
|
||||
BwdKernel.get_all(),
|
||||
"cutlassB",
|
||||
impl_file="<ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h>",
|
||||
autogen_dir=output_dir
|
||||
autogen_dir=output_dir,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog='generate_kernels',
|
||||
description='Generate the mem-eff kernels template instantiations')
|
||||
prog="generate_kernels",
|
||||
description="Generate the mem-eff kernels template instantiations",
|
||||
)
|
||||
# Set an optional output directory
|
||||
parser.add_argument('-o', '--output_dir', required=False, help="Where to generate the kernels "
|
||||
" will default to <ATen/native/transformers/cuda/mem_eff_attention/kernels/> ")
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output_dir",
|
||||
required=False,
|
||||
help="Where to generate the kernels "
|
||||
" will default to <ATen/native/transformers/cuda/mem_eff_attention/kernels/> ",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args.output_dir)
|
||||
|
@ -7,9 +7,9 @@ that opens libneuralnetworks.so with dlopen and finds the functions
|
||||
we need with dlsym. We also generate a "check" wrapper that checks
|
||||
return values and throws C++ exceptions on errors.
|
||||
"""
|
||||
import sys
|
||||
import re
|
||||
import pathlib
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
|
||||
|
||||
@ -36,39 +36,155 @@ PREFIX = """\
|
||||
|
||||
NNAPI_FUNCTIONS = [
|
||||
("int", "ANeuralNetworks_getDeviceCount", "uint32_t* numDevices"), # noqa: B950
|
||||
("int", "ANeuralNetworks_getDevice", "uint32_t devIndex, ANeuralNetworksDevice** device"), # noqa: B950
|
||||
("int", "ANeuralNetworksDevice_getName", "const ANeuralNetworksDevice* device, const char** name"), # noqa: B950
|
||||
("int", "ANeuralNetworksDevice_getVersion", "const ANeuralNetworksDevice* device, const char** version"), # noqa: B950
|
||||
("int", "ANeuralNetworksDevice_getFeatureLevel", "const ANeuralNetworksDevice* device, int64_t* featureLevel"), # noqa: B950
|
||||
("int", "ANeuralNetworksModel_getSupportedOperationsForDevices", " const ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, bool* supportedOps"), # noqa: B950
|
||||
("int", "ANeuralNetworksCompilation_createForDevices", "ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, ANeuralNetworksCompilation** compilation"), # noqa: B950
|
||||
("int", "ANeuralNetworksExecution_compute", "ANeuralNetworksExecution* execution"), # noqa: B950
|
||||
("int", "ANeuralNetworksMemory_createFromFd", "size_t size, int protect, int fd, size_t offset, ANeuralNetworksMemory** memory"), # noqa: B950
|
||||
("void", "ANeuralNetworksMemory_free", "ANeuralNetworksMemory* memory"), # noqa: B950
|
||||
("int", "ANeuralNetworksModel_create", "ANeuralNetworksModel** model"), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworks_getDevice",
|
||||
"uint32_t devIndex, ANeuralNetworksDevice** device",
|
||||
), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksDevice_getName",
|
||||
"const ANeuralNetworksDevice* device, const char** name",
|
||||
), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksDevice_getVersion",
|
||||
"const ANeuralNetworksDevice* device, const char** version",
|
||||
), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksDevice_getFeatureLevel",
|
||||
"const ANeuralNetworksDevice* device, int64_t* featureLevel",
|
||||
), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksModel_getSupportedOperationsForDevices",
|
||||
" const ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, bool* supportedOps",
|
||||
), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksCompilation_createForDevices",
|
||||
"ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, ANeuralNetworksCompilation** compilation", # noqa: B950
|
||||
),
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksExecution_compute",
|
||||
"ANeuralNetworksExecution* execution",
|
||||
), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksMemory_createFromFd",
|
||||
"size_t size, int protect, int fd, size_t offset, ANeuralNetworksMemory** memory",
|
||||
), # noqa: B950
|
||||
(
|
||||
"void",
|
||||
"ANeuralNetworksMemory_free",
|
||||
"ANeuralNetworksMemory* memory",
|
||||
), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksModel_create",
|
||||
"ANeuralNetworksModel** model",
|
||||
), # noqa: B950
|
||||
("void", "ANeuralNetworksModel_free", "ANeuralNetworksModel* model"), # noqa: B950
|
||||
("int", "ANeuralNetworksModel_finish", "ANeuralNetworksModel* model"), # noqa: B950
|
||||
("int", "ANeuralNetworksModel_addOperand", "ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type"), # noqa: B950
|
||||
("int", "ANeuralNetworksModel_setOperandValue", "ANeuralNetworksModel* model, int32_t index, const void* buffer, size_t length"), # noqa: B950
|
||||
("int", "ANeuralNetworksModel_setOperandValueFromMemory", "ANeuralNetworksModel* model, int32_t index, const ANeuralNetworksMemory* memory, size_t offset, size_t length"), # noqa: B950
|
||||
("int", "ANeuralNetworksModel_addOperation", "ANeuralNetworksModel* model, ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs"), # noqa: B950
|
||||
("int", "ANeuralNetworksModel_identifyInputsAndOutputs", "ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs"), # noqa: B950
|
||||
("int", "ANeuralNetworksModel_relaxComputationFloat32toFloat16", "ANeuralNetworksModel* model, bool allow"), # noqa: B950
|
||||
("int", "ANeuralNetworksCompilation_create", "ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation"), # noqa: B950
|
||||
("void", "ANeuralNetworksCompilation_free", "ANeuralNetworksCompilation* compilation"), # noqa: B950
|
||||
("int", "ANeuralNetworksCompilation_setPreference", "ANeuralNetworksCompilation* compilation, int32_t preference"), # noqa: B950
|
||||
("int", "ANeuralNetworksCompilation_finish", "ANeuralNetworksCompilation* compilation"), # noqa: B950
|
||||
("int", "ANeuralNetworksExecution_create", "ANeuralNetworksCompilation* compilation, ANeuralNetworksExecution** execution"), # noqa: B950
|
||||
("void", "ANeuralNetworksExecution_free", "ANeuralNetworksExecution* execution"), # noqa: B950
|
||||
("int", "ANeuralNetworksExecution_setInput", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const void* buffer, size_t length"), # noqa: B950
|
||||
("int", "ANeuralNetworksExecution_setInputFromMemory", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length"), # noqa: B950
|
||||
("int", "ANeuralNetworksExecution_setOutput", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, void* buffer, size_t length"), # noqa: B950
|
||||
("int", "ANeuralNetworksExecution_setOutputFromMemory", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length"), # noqa: B950
|
||||
("int", "ANeuralNetworksExecution_startCompute", "ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event"), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksModel_addOperand",
|
||||
"ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type",
|
||||
), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksModel_setOperandValue",
|
||||
"ANeuralNetworksModel* model, int32_t index, const void* buffer, size_t length",
|
||||
), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksModel_setOperandValueFromMemory",
|
||||
"ANeuralNetworksModel* model, int32_t index, const ANeuralNetworksMemory* memory, size_t offset, size_t length",
|
||||
), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksModel_addOperation",
|
||||
"ANeuralNetworksModel* model, ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs", # noqa: B950
|
||||
),
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksModel_identifyInputsAndOutputs",
|
||||
"ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs",
|
||||
), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksModel_relaxComputationFloat32toFloat16",
|
||||
"ANeuralNetworksModel* model, bool allow",
|
||||
), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksCompilation_create",
|
||||
"ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation",
|
||||
), # noqa: B950
|
||||
(
|
||||
"void",
|
||||
"ANeuralNetworksCompilation_free",
|
||||
"ANeuralNetworksCompilation* compilation",
|
||||
), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksCompilation_setPreference",
|
||||
"ANeuralNetworksCompilation* compilation, int32_t preference",
|
||||
), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksCompilation_finish",
|
||||
"ANeuralNetworksCompilation* compilation",
|
||||
), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksExecution_create",
|
||||
"ANeuralNetworksCompilation* compilation, ANeuralNetworksExecution** execution",
|
||||
), # noqa: B950
|
||||
(
|
||||
"void",
|
||||
"ANeuralNetworksExecution_free",
|
||||
"ANeuralNetworksExecution* execution",
|
||||
), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksExecution_setInput",
|
||||
"ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const void* buffer, size_t length", # noqa: B950
|
||||
),
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksExecution_setInputFromMemory",
|
||||
"ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length", # noqa: B950
|
||||
),
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksExecution_setOutput",
|
||||
"ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, void* buffer, size_t length",
|
||||
), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksExecution_setOutputFromMemory",
|
||||
"ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length", # noqa: B950
|
||||
),
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksExecution_startCompute",
|
||||
"ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event",
|
||||
), # noqa: B950
|
||||
("int", "ANeuralNetworksEvent_wait", "ANeuralNetworksEvent* event"), # noqa: B950
|
||||
("void", "ANeuralNetworksEvent_free", "ANeuralNetworksEvent* event"), # noqa: B950
|
||||
("int", "ANeuralNetworksExecution_getOutputOperandRank", "ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank"), # noqa: B950
|
||||
("int", "ANeuralNetworksExecution_getOutputOperandDimensions", "ANeuralNetworksExecution* execution, int32_t index, uint32_t* dimensions"), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksExecution_getOutputOperandRank",
|
||||
"ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank",
|
||||
), # noqa: B950
|
||||
(
|
||||
"int",
|
||||
"ANeuralNetworksExecution_getOutputOperandDimensions",
|
||||
"ANeuralNetworksExecution* execution, int32_t index, uint32_t* dimensions",
|
||||
), # noqa: B950
|
||||
]
|
||||
|
||||
|
||||
@ -82,18 +198,26 @@ def main(argv):
|
||||
|
||||
struct_members.append(f" {ret}(*{short_name})({args});")
|
||||
|
||||
load_functions.append(f' *(void**)&nnapi_.{short_name} = dlsym(handle, "{name}");')
|
||||
load_functions.append(f' check_nnapi_.{short_name} = check_{short_name};')
|
||||
load_functions.append(
|
||||
f' *(void**)&nnapi_.{short_name} = dlsym(handle, "{name}");'
|
||||
)
|
||||
load_functions.append(f" check_nnapi_.{short_name} = check_{short_name};")
|
||||
|
||||
call_args = "".join(re.findall(r"\w+(?:,|$)", args))
|
||||
if ret == "void":
|
||||
define_checks.append(textwrap.dedent(f"""\
|
||||
define_checks.append(
|
||||
textwrap.dedent(
|
||||
f"""\
|
||||
{ret} check_{short_name}({args}) {{
|
||||
CAFFE_ENFORCE(nnapi_.{short_name});
|
||||
nnapi_.{short_name}({call_args});
|
||||
}}"""))
|
||||
}}"""
|
||||
)
|
||||
)
|
||||
if ret == "int":
|
||||
define_checks.append(textwrap.dedent(f"""\
|
||||
define_checks.append(
|
||||
textwrap.dedent(
|
||||
f"""\
|
||||
{ret} check_{short_name}({args}) {{
|
||||
CAFFE_ENFORCE(nnapi_.{short_name});
|
||||
int ret = nnapi_.{short_name}({call_args});
|
||||
@ -103,13 +227,16 @@ def main(argv):
|
||||
"{short_name}", "failed with error ", ret
|
||||
);
|
||||
return ret;
|
||||
}}"""))
|
||||
}}"""
|
||||
)
|
||||
)
|
||||
|
||||
out_dir = pathlib.Path(__file__).parent
|
||||
|
||||
(out_dir / "nnapi_wrapper.h").write_text(
|
||||
PREFIX +
|
||||
textwrap.dedent("""\
|
||||
PREFIX
|
||||
+ textwrap.dedent(
|
||||
"""\
|
||||
#ifndef NNAPI_WRAPPER_H_
|
||||
#define NNAPI_WRAPPER_H_
|
||||
#include <stddef.h>
|
||||
@ -122,13 +249,14 @@ def main(argv):
|
||||
void nnapi_wrapper_load(struct nnapi_wrapper** nnapi, struct nnapi_wrapper** check_nnapi);
|
||||
#endif
|
||||
#endif
|
||||
""")
|
||||
.replace("__STRUCT_MEMBERS__", "\n".join(struct_members))
|
||||
"""
|
||||
).replace("__STRUCT_MEMBERS__", "\n".join(struct_members))
|
||||
)
|
||||
|
||||
(out_dir / "nnapi_wrapper.cpp").write_text(
|
||||
PREFIX +
|
||||
textwrap.dedent("""\
|
||||
PREFIX
|
||||
+ textwrap.dedent(
|
||||
"""\
|
||||
#ifndef _WIN32
|
||||
#include <dlfcn.h>
|
||||
#endif
|
||||
@ -157,7 +285,8 @@ def main(argv):
|
||||
*check_nnapi = &check_nnapi_;
|
||||
#endif
|
||||
}
|
||||
""")
|
||||
"""
|
||||
)
|
||||
.replace("__DEFINE_CHECK_FUNCTIONS__", "\n".join(define_checks))
|
||||
.replace("__LOAD_FUNCTIONS__", "\n".join(load_functions))
|
||||
)
|
||||
|
@ -5,6 +5,27 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import torch
|
||||
|
||||
from torch._functorch.deprecated import (
|
||||
combine_state_for_ensemble,
|
||||
functionalize,
|
||||
grad,
|
||||
grad_and_value,
|
||||
hessian,
|
||||
jacfwd,
|
||||
jacrev,
|
||||
jvp,
|
||||
make_functional,
|
||||
make_functional_with_buffers,
|
||||
vjp,
|
||||
vmap,
|
||||
)
|
||||
|
||||
# utilities. Maybe these should go in their own namespace in the future?
|
||||
from torch._functorch.make_functional import (
|
||||
FunctionalModule,
|
||||
FunctionalModuleWithBuffers,
|
||||
)
|
||||
|
||||
# Top-level APIs. Please think carefully before adding something to the
|
||||
# top-level namespace:
|
||||
# - private helper functions should go into torch._functorch
|
||||
@ -14,15 +35,4 @@ import torch
|
||||
# Was never documented
|
||||
from torch._functorch.python_key import make_fx
|
||||
|
||||
from torch._functorch.deprecated import (
|
||||
vmap, grad, grad_and_value, vjp, jvp, jacrev, jacfwd, hessian, functionalize,
|
||||
make_functional, make_functional_with_buffers, combine_state_for_ensemble,
|
||||
)
|
||||
|
||||
# utilities. Maybe these should go in their own namespace in the future?
|
||||
from torch._functorch.make_functional import (
|
||||
FunctionalModule,
|
||||
FunctionalModuleWithBuffers,
|
||||
)
|
||||
|
||||
__version__ = torch.__version__
|
||||
|
@ -2,6 +2,6 @@
|
||||
# If you are not a PyTorch developer and you are relying on the following
|
||||
# imports, please file an issue.
|
||||
from torch._functorch.eager_transforms import (
|
||||
_unwrap_functional_tensor,
|
||||
_assert_wrapped_functional,
|
||||
_unwrap_functional_tensor,
|
||||
)
|
||||
|
@ -4,13 +4,13 @@
|
||||
from torch._functorch.vmap import (
|
||||
_add_batch_dim,
|
||||
_broadcast_to_and_flatten,
|
||||
_create_batched_inputs,
|
||||
_get_name,
|
||||
_process_batched_inputs,
|
||||
_remove_batch_dim,
|
||||
_unwrap_batched,
|
||||
_validate_and_get_batch_size,
|
||||
Tensor,
|
||||
tree_flatten,
|
||||
tree_unflatten,
|
||||
_process_batched_inputs,
|
||||
_create_batched_inputs,
|
||||
_unwrap_batched,
|
||||
)
|
||||
|
@ -1,8 +1,9 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import os
|
||||
import logging
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from torch._functorch.benchmark_utils import compute_utilization
|
||||
@ -17,9 +18,10 @@ def get_model_name(filename):
|
||||
Get model name from a file in format {model_name}_chrome_trace_*.json
|
||||
"""
|
||||
_, tail = os.path.split(filename)
|
||||
modelname = tail[:tail.find("_chrome_trace")]
|
||||
modelname = tail[: tail.find("_chrome_trace")]
|
||||
return modelname
|
||||
|
||||
|
||||
def get_total_length(run_times_df, modelname):
|
||||
return float(run_times_df[run_times_df["name"] == modelname]["runtime"])
|
||||
|
||||
@ -31,14 +33,14 @@ def main():
|
||||
"--runtime", "-runf", help="file name of the runtime file", required=True
|
||||
)
|
||||
group.add_argument(
|
||||
"--filename", "-f", action="append", help="a filename of the json file to process"
|
||||
)
|
||||
group.add_argument(
|
||||
"--folder", "-fd", help="a folder of the json files to process"
|
||||
"--filename",
|
||||
"-f",
|
||||
action="append",
|
||||
help="a filename of the json file to process",
|
||||
)
|
||||
group.add_argument("--folder", "-fd", help="a folder of the json files to process")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
if args.filename:
|
||||
filenames = args.filename
|
||||
elif args.folder:
|
||||
@ -58,11 +60,14 @@ def main():
|
||||
try:
|
||||
modelname = get_model_name(filename)
|
||||
total_length = get_total_length(run_times_df, modelname) * 1e6
|
||||
utilization, mm_conv_utilization = compute_utilization(filenames, total_length)
|
||||
utilization, mm_conv_utilization = compute_utilization(
|
||||
filenames, total_length
|
||||
)
|
||||
print(f"{modelname}, {utilization}, {mm_conv_utilization}")
|
||||
except BaseException:
|
||||
logging.exception("%s, ERROR", filename)
|
||||
print(f"{filename}, ERROR")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -1,9 +1,10 @@
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
from functorch import make_fx
|
||||
from torch.profiler import profile, ProfilerActivity
|
||||
|
||||
from torch._functorch.compile_utils import fx_graph_cse
|
||||
from torch.profiler import profile, ProfilerActivity
|
||||
|
||||
|
||||
def profile_it(f, inp):
|
||||
for _ in range(5):
|
||||
@ -20,6 +21,7 @@ def profile_it(f, inp):
|
||||
cuda_time_total = cuda_time_total + e.cuda_time_total
|
||||
return cuda_time_total / itr
|
||||
|
||||
|
||||
def profile_function(name, f, inp):
|
||||
fx_g = make_fx(f)(inp)
|
||||
|
||||
@ -34,17 +36,23 @@ def profile_function(name, f, inp):
|
||||
avg_cuda_time_g = profile_it(new_g, inp)
|
||||
num_node_decrease = len(fx_g.graph.nodes) - len(new_g.graph.nodes)
|
||||
|
||||
print(f"{name}, {avg_cuda_time_f}, {avg_cuda_time_g}, {num_node_decrease}, {len(fx_g.graph.nodes)}")
|
||||
print(
|
||||
f"{name}, {avg_cuda_time_f}, {avg_cuda_time_g}, {num_node_decrease}, {len(fx_g.graph.nodes)}"
|
||||
)
|
||||
|
||||
g_gpu = torch.Generator(device='cuda')
|
||||
|
||||
g_gpu = torch.Generator(device="cuda")
|
||||
g_gpu.manual_seed(2147483647)
|
||||
inp = torch.randn(2**20, device='cuda', generator=g_gpu)
|
||||
inp = torch.randn(2**20, device="cuda", generator=g_gpu)
|
||||
|
||||
|
||||
def f1(x):
|
||||
return x.cos().cos()
|
||||
|
||||
|
||||
profile_function("f1", f1, inp)
|
||||
|
||||
|
||||
def fsum(x):
|
||||
a = x.sum()
|
||||
b = x.sum()
|
||||
@ -52,22 +60,29 @@ def fsum(x):
|
||||
d = x.sum()
|
||||
return a + b + c + d
|
||||
|
||||
|
||||
profile_function("fsum", fsum, inp)
|
||||
|
||||
|
||||
def fconcat(x):
|
||||
a = torch.cat((x, x))
|
||||
b = torch.cat((x, x))
|
||||
return a + b
|
||||
|
||||
|
||||
profile_function("fconcat", fconcat, inp)
|
||||
|
||||
|
||||
def fsum2(x):
|
||||
a = x.sum()
|
||||
for _ in range(30):
|
||||
a = a + x.sum()
|
||||
return a
|
||||
|
||||
|
||||
profile_function("fsum2", fsum2, inp)
|
||||
|
||||
|
||||
def fsummulti(x):
|
||||
a = 0
|
||||
for _ in range(3):
|
||||
@ -75,8 +90,10 @@ def fsummulti(x):
|
||||
a = a * x.sum()
|
||||
return a
|
||||
|
||||
|
||||
profile_function("fsummulti", fsummulti, inp)
|
||||
|
||||
|
||||
def fsummulti2(x):
|
||||
a = 0
|
||||
for _ in range(30):
|
||||
@ -84,20 +101,25 @@ def fsummulti2(x):
|
||||
a = a * x.sum()
|
||||
return a
|
||||
|
||||
|
||||
profile_function("fsummulti2", fsummulti2, inp)
|
||||
|
||||
|
||||
def fcos(x):
|
||||
a = 0
|
||||
for _ in range(3):
|
||||
a = a + x.cos()
|
||||
return a
|
||||
|
||||
|
||||
profile_function("fcos", fcos, inp)
|
||||
|
||||
|
||||
def fcos2(x):
|
||||
a = 0
|
||||
for _ in range(30):
|
||||
a = a + x.cos()
|
||||
return a
|
||||
|
||||
|
||||
profile_function("fcos2", fcos2, inp)
|
||||
|
@ -1,7 +1,8 @@
|
||||
import timeit
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import timeit
|
||||
import torch
|
||||
from functorch.compile import pointwise_operator
|
||||
|
||||
|
@ -1,14 +1,14 @@
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from opacus.utils.module_modification import convert_batchnorm_modules
|
||||
import time
|
||||
|
||||
from functorch import vmap, grad
|
||||
from functorch import make_functional
|
||||
from functorch import grad, make_functional, vmap
|
||||
from opacus import PrivacyEngine
|
||||
from opacus.utils.module_modification import convert_batchnorm_modules
|
||||
|
||||
device = 'cuda'
|
||||
device = "cuda"
|
||||
batch_size = 128
|
||||
torch.manual_seed(0)
|
||||
|
||||
@ -20,6 +20,7 @@ images = torch.randn(batch_size, 3, 32, 32, device=device)
|
||||
targets = torch.randint(0, 10, (batch_size,), device=device)
|
||||
func_model, weights = make_functional(model_functorch)
|
||||
|
||||
|
||||
def compute_loss(weights, image, target):
|
||||
images = image.unsqueeze(0)
|
||||
targets = target.unsqueeze(0)
|
||||
@ -27,11 +28,11 @@ def compute_loss(weights, image, target):
|
||||
loss = criterion(output, targets)
|
||||
return loss
|
||||
|
||||
|
||||
def functorch_per_sample_grad():
|
||||
compute_grad = grad(compute_loss)
|
||||
compute_per_sample_grad = vmap(compute_grad, (None, 0, 0))
|
||||
|
||||
|
||||
start = time.time()
|
||||
result = compute_per_sample_grad(weights, images, targets)
|
||||
torch.cuda.synchronize()
|
||||
@ -39,6 +40,7 @@ def functorch_per_sample_grad():
|
||||
|
||||
return result, end - start # end - start in seconds
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
model_opacus = convert_batchnorm_modules(models.resnet18(num_classes=10))
|
||||
model_opacus = model_opacus.to(device)
|
||||
@ -54,6 +56,7 @@ privacy_engine = PrivacyEngine(
|
||||
max_grad_norm=10000.0,
|
||||
)
|
||||
|
||||
|
||||
def opacus_per_sample_grad():
|
||||
start = time.time()
|
||||
output = model_opacus(images)
|
||||
@ -63,7 +66,7 @@ def opacus_per_sample_grad():
|
||||
end = time.time()
|
||||
expected = [p.grad_sample for p in model_opacus.parameters()]
|
||||
for p in model_opacus.parameters():
|
||||
delattr(p, 'grad_sample')
|
||||
delattr(p, "grad_sample")
|
||||
p.grad = None
|
||||
return expected, end - start
|
||||
|
||||
|
@ -1,14 +1,16 @@
|
||||
import sys
|
||||
import time
|
||||
import torch
|
||||
import inspect
|
||||
import itertools
|
||||
import sys
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
from functorch import pointwise_operator
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch._C._debug_set_fusion_group_inlining(False)
|
||||
|
||||
|
||||
def rand(*shape):
|
||||
return torch.rand(*shape).mul(16).add(1)
|
||||
|
||||
@ -19,105 +21,139 @@ def rand(*shape):
|
||||
def scalar():
|
||||
return (rand(1), rand(1))
|
||||
|
||||
|
||||
def small():
|
||||
return (rand(32), rand(32))
|
||||
|
||||
|
||||
def small_2d():
|
||||
return (rand(1, 32), rand(1, 32))
|
||||
|
||||
|
||||
def small_broadcast():
|
||||
return (rand(4, 32), rand(32))
|
||||
|
||||
|
||||
def medium():
|
||||
return (rand(32, 12, 64, 64), rand(32, 12, 64, 64))
|
||||
|
||||
|
||||
def medium_sliced():
|
||||
return (rand(32, 12, 64, 64)[..., ::2],
|
||||
rand(32, 12, 64, 64)[..., ::2])
|
||||
return (rand(32, 12, 64, 64)[..., ::2], rand(32, 12, 64, 64)[..., ::2])
|
||||
|
||||
|
||||
def medium_transpose():
|
||||
return (rand(32, 12, 64, 64).transpose(-1, -2),
|
||||
rand(32, 12, 64, 64).transpose(-1, -2))
|
||||
return (
|
||||
rand(32, 12, 64, 64).transpose(-1, -2),
|
||||
rand(32, 12, 64, 64).transpose(-1, -2),
|
||||
)
|
||||
|
||||
|
||||
def medium2():
|
||||
return (rand(32, 3, 224, 224), rand(32, 3, 224, 224))
|
||||
|
||||
|
||||
def medium3d():
|
||||
return (rand(16, 32, 64), rand(16, 32, 64))
|
||||
|
||||
|
||||
def medium_channels_last():
|
||||
return (rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
|
||||
rand(32, 3, 224, 224).to(memory_format=torch.channels_last))
|
||||
return (
|
||||
rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
|
||||
rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
|
||||
)
|
||||
|
||||
|
||||
def medium_broadcast():
|
||||
return (rand(32, 12, 64, 64), rand(64))
|
||||
|
||||
|
||||
def medium_broadcast_channels_last():
|
||||
return (rand(32, 3, 223, 223).to(memory_format=torch.channels_last),
|
||||
rand(3, 1, 1))
|
||||
return (rand(32, 3, 223, 223).to(memory_format=torch.channels_last), rand(3, 1, 1))
|
||||
|
||||
|
||||
def large():
|
||||
return (rand(8192, 8192), rand(8192, 8192))
|
||||
|
||||
|
||||
def large_transpose():
|
||||
return (rand(8192, 8192).transpose(0, 1),
|
||||
rand(8192, 8192).transpose(0, 1))
|
||||
return (rand(8192, 8192).transpose(0, 1), rand(8192, 8192).transpose(0, 1))
|
||||
|
||||
|
||||
def large_channels_last():
|
||||
return (rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
|
||||
rand(32, 32, 256, 256).to(memory_format=torch.channels_last))
|
||||
return (
|
||||
rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
|
||||
rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
|
||||
)
|
||||
|
||||
|
||||
def pathological_broadcast():
|
||||
return (rand(1, 32, 32, 2), rand(1024, 1, 1, 2))
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Operator test cases
|
||||
# ------------------------------------------------------------------------------
|
||||
def add(a, b):
|
||||
return a + b
|
||||
|
||||
|
||||
def sub(a, b):
|
||||
return a - b
|
||||
|
||||
|
||||
def mul(a, b):
|
||||
return a * b
|
||||
|
||||
|
||||
def div(a, b):
|
||||
return a / b
|
||||
|
||||
|
||||
def relu(a):
|
||||
return a.relu()
|
||||
|
||||
|
||||
def sigmoid(a):
|
||||
return a.sigmoid()
|
||||
|
||||
|
||||
def tanh(a):
|
||||
return a.tanh()
|
||||
|
||||
|
||||
def log(a):
|
||||
return a.log()
|
||||
|
||||
|
||||
def exp(a):
|
||||
return a.exp()
|
||||
|
||||
|
||||
def square(a):
|
||||
return a ** 2
|
||||
return a**2
|
||||
|
||||
|
||||
def fma(a, b):
|
||||
return a * b + b
|
||||
|
||||
|
||||
def hardswish(a):
|
||||
return a * (a + 3.0).clamp(0.0, 6.0) / 6.0
|
||||
|
||||
|
||||
def native_hardswish(a):
|
||||
return torch._C._nn.hardswish(a)
|
||||
|
||||
|
||||
def softplus(a):
|
||||
return (a * 1.0).exp().log1p() / 1.0
|
||||
|
||||
|
||||
def mish(a):
|
||||
return a * ((a * 1.0).exp().log1p() / 1.0).tanh()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------------------
|
||||
@ -128,6 +164,7 @@ def time_cpu(fn, args, iters):
|
||||
e = time.perf_counter()
|
||||
return e - s
|
||||
|
||||
|
||||
def time_cuda(fn, args, iters):
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
@ -138,19 +175,23 @@ def time_cuda(fn, args, iters):
|
||||
torch.cuda.synchronize()
|
||||
return start.elapsed_time(end) / 1e3
|
||||
|
||||
|
||||
def benchmark_with_timer(fn, args, timer):
|
||||
timer(fn, args, 3)
|
||||
calibration = timer(fn, args, 1)
|
||||
iters = int(1.0 / calibration)
|
||||
return timer(fn, args, iters) / iters
|
||||
|
||||
|
||||
def benchmark(fn, args):
|
||||
timer = time_cpu if args[0].device.type == "cpu" else time_cuda
|
||||
return benchmark_with_timer(fn, args, timer)
|
||||
|
||||
|
||||
def micros(s):
|
||||
return f"{s * 1e6:.1f}"
|
||||
|
||||
|
||||
shapes = [
|
||||
scalar,
|
||||
small,
|
||||
@ -211,7 +252,17 @@ for shape, operator in itertools.product(shapes, operators):
|
||||
args = shape()[:nargs]
|
||||
|
||||
result = benchmark(operator, args)
|
||||
print(",".join(["eager", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
|
||||
print(
|
||||
",".join(
|
||||
[
|
||||
"eager",
|
||||
args[0].device.type,
|
||||
operator.__name__,
|
||||
shape.__name__,
|
||||
micros(result),
|
||||
]
|
||||
)
|
||||
)
|
||||
try:
|
||||
if shape == medium_transpose:
|
||||
raise RuntimeError("pointwise_operator hangs on medium_transpose")
|
||||
@ -219,11 +270,41 @@ for shape, operator in itertools.product(shapes, operators):
|
||||
raise RuntimeError("pointwise_operator fails on medium_transpose")
|
||||
pw_op = pointwise_operator(operator)
|
||||
result = benchmark(pw_op, args)
|
||||
print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
|
||||
print(
|
||||
",".join(
|
||||
[
|
||||
"pointwise",
|
||||
args[0].device.type,
|
||||
operator.__name__,
|
||||
shape.__name__,
|
||||
micros(result),
|
||||
]
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(float("nan"))]))
|
||||
print(
|
||||
",".join(
|
||||
[
|
||||
"pointwise",
|
||||
args[0].device.type,
|
||||
operator.__name__,
|
||||
shape.__name__,
|
||||
micros(float("nan")),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
ts_op = torch.jit.script(operator)
|
||||
result = benchmark(ts_op, args)
|
||||
print(",".join(["fuser", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
|
||||
print(
|
||||
",".join(
|
||||
[
|
||||
"fuser",
|
||||
args[0].device.type,
|
||||
operator.__name__,
|
||||
shape.__name__,
|
||||
micros(result),
|
||||
]
|
||||
)
|
||||
)
|
||||
sys.stdout.flush()
|
||||
|
@ -1,11 +1,13 @@
|
||||
import pandas
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas
|
||||
|
||||
df = pandas.read_csv("perf.csv")
|
||||
|
||||
ops = pandas.unique(df["operator"])
|
||||
nops = len(ops)
|
||||
pivot_op_shape = df.pivot_table(values="time", index=["operator", "shape"], columns=["fuser"])
|
||||
pivot_op_shape = df.pivot_table(
|
||||
values="time", index=["operator", "shape"], columns=["fuser"]
|
||||
)
|
||||
pivot_speedups = (pivot_op_shape.T / pivot_op_shape["eager"]).T
|
||||
|
||||
plt.rcParams["figure.figsize"] = (20, 100)
|
||||
|
@ -1,31 +1,31 @@
|
||||
from torch._functorch.python_key import pythonkey_decompose
|
||||
from torch._functorch.fx_minifier import minifier
|
||||
from torch._functorch import config
|
||||
from torch._functorch.aot_autograd import (
|
||||
aot_function,
|
||||
aot_module,
|
||||
aot_module_simplified,
|
||||
compiled_function,
|
||||
compiled_module,
|
||||
aot_module_simplified,
|
||||
get_graph_being_compiled,
|
||||
get_aot_graph_name,
|
||||
get_aot_compilation_context,
|
||||
get_aot_graph_name,
|
||||
get_graph_being_compiled,
|
||||
make_boxed_compiler,
|
||||
make_boxed_func,
|
||||
make_boxed_compiler
|
||||
)
|
||||
from torch._functorch.compilers import (
|
||||
ts_compile,
|
||||
draw_graph_compile,
|
||||
nop,
|
||||
nnc_jit,
|
||||
memory_efficient_fusion,
|
||||
debug_compile,
|
||||
default_decompositions,
|
||||
draw_graph_compile,
|
||||
memory_efficient_fusion,
|
||||
nnc_jit,
|
||||
nop,
|
||||
print_compile,
|
||||
default_decompositions
|
||||
ts_compile,
|
||||
)
|
||||
from torch._functorch.fx_minifier import minifier
|
||||
from torch._functorch.partitioners import (
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
draw_graph,
|
||||
draw_joint_graph,
|
||||
min_cut_rematerialization_partition,
|
||||
)
|
||||
from torch._functorch import config
|
||||
from torch._functorch.python_key import pythonkey_decompose
|
||||
|
@ -1,20 +1,26 @@
|
||||
import torch
|
||||
from typing import Union, Sequence
|
||||
import inspect
|
||||
import dis
|
||||
from .tree_map import tree_flatten, tree_map
|
||||
from .wrap_type import wrap_type
|
||||
import inspect
|
||||
from typing import Sequence, Union
|
||||
|
||||
import torch
|
||||
|
||||
import functorch._C
|
||||
from functorch._C import dim as _C
|
||||
from .tree_map import tree_flatten, tree_map
|
||||
from .wrap_type import wrap_type
|
||||
|
||||
_C._patch_tensor_class()
|
||||
dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists
|
||||
|
||||
|
||||
class DimensionMismatchError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DimensionBindError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
from . import op_properties
|
||||
|
||||
# use dict to avoid writing C++ bindings for set
|
||||
@ -24,11 +30,11 @@ use_c = True
|
||||
if not use_c:
|
||||
from . import reference
|
||||
|
||||
|
||||
class _Tensor:
|
||||
# fast path around slow wrapping/unwrapping logic for simple queries used
|
||||
# by the implementation...
|
||||
|
||||
|
||||
@property
|
||||
def dims(self):
|
||||
return tuple(d for d in self._levels if isinstance(d, Dim))
|
||||
@ -47,11 +53,12 @@ class _Tensor:
|
||||
|
||||
def __repr__(self):
|
||||
tensor, levels, ndim = self._tensor, self._levels, self.ndim
|
||||
return f'{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}'
|
||||
return f"{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}"
|
||||
|
||||
|
||||
TensorLike = (_Tensor, torch.Tensor)
|
||||
|
||||
|
||||
class Dim(_C.Dim, _Tensor):
|
||||
# note that _C.Dim comes before tensor because we want the Dim API for things like size to take precedence.
|
||||
# Tensor defines format, but we want to print Dims with special formatting
|
||||
@ -69,6 +76,7 @@ def cat(tensors, dim, new_dim):
|
||||
n = dims()
|
||||
return stack(tensors, n, dim).index([n, dim], new_dim)
|
||||
|
||||
|
||||
if use_c:
|
||||
_wrap = _C._wrap
|
||||
|
||||
@ -107,41 +115,41 @@ if use_c:
|
||||
else:
|
||||
_Tensor.order = reference.positional
|
||||
|
||||
_def('mean')
|
||||
_def('sum')
|
||||
_def('all')
|
||||
_def('amax')
|
||||
_def('amin')
|
||||
_def('aminmax')
|
||||
_def('any')
|
||||
_def('count_nonzero')
|
||||
_def('logsumexp')
|
||||
_def('nanmean')
|
||||
_def('nansum')
|
||||
_def('prod')
|
||||
_def('std', keepdim_offset=2)
|
||||
_def('var', keepdim_offset=2)
|
||||
_def('max', single_dim=True)
|
||||
_def('min', single_dim=True)
|
||||
_def('argmax', single_dim=True)
|
||||
_def('argmin', single_dim=True)
|
||||
_def('kthvalue', single_dim=True)
|
||||
_def('median', single_dim=True)
|
||||
_def('nanmedian', single_dim=True)
|
||||
_def('mode', single_dim=True)
|
||||
_def('sort', reduce=False)
|
||||
_def('argsort', reduce=False)
|
||||
_def('unbind', single_dim=True)
|
||||
_def('chunk', dim_offset=1, reduce=False)
|
||||
_def('cummax', single_dim=True, reduce=False)
|
||||
_def('cummin', single_dim=True, reduce=False)
|
||||
_def('cumprod', single_dim=True, reduce=False)
|
||||
_def('cumprod_', single_dim=True, reduce=False)
|
||||
_def('cumsum', single_dim=True, reduce=False)
|
||||
_def('cumsum_', single_dim=True, reduce=False)
|
||||
_def('logcumsumexp', single_dim=True, reduce=False)
|
||||
_def('renorm', dim_offset=1, single_dim=True, reduce=False)
|
||||
_def('softmax', single_dim=True, reduce=False)
|
||||
_def("mean")
|
||||
_def("sum")
|
||||
_def("all")
|
||||
_def("amax")
|
||||
_def("amin")
|
||||
_def("aminmax")
|
||||
_def("any")
|
||||
_def("count_nonzero")
|
||||
_def("logsumexp")
|
||||
_def("nanmean")
|
||||
_def("nansum")
|
||||
_def("prod")
|
||||
_def("std", keepdim_offset=2)
|
||||
_def("var", keepdim_offset=2)
|
||||
_def("max", single_dim=True)
|
||||
_def("min", single_dim=True)
|
||||
_def("argmax", single_dim=True)
|
||||
_def("argmin", single_dim=True)
|
||||
_def("kthvalue", single_dim=True)
|
||||
_def("median", single_dim=True)
|
||||
_def("nanmedian", single_dim=True)
|
||||
_def("mode", single_dim=True)
|
||||
_def("sort", reduce=False)
|
||||
_def("argsort", reduce=False)
|
||||
_def("unbind", single_dim=True)
|
||||
_def("chunk", dim_offset=1, reduce=False)
|
||||
_def("cummax", single_dim=True, reduce=False)
|
||||
_def("cummin", single_dim=True, reduce=False)
|
||||
_def("cumprod", single_dim=True, reduce=False)
|
||||
_def("cumprod_", single_dim=True, reduce=False)
|
||||
_def("cumsum", single_dim=True, reduce=False)
|
||||
_def("cumsum_", single_dim=True, reduce=False)
|
||||
_def("logcumsumexp", single_dim=True, reduce=False)
|
||||
_def("renorm", dim_offset=1, single_dim=True, reduce=False)
|
||||
_def("softmax", single_dim=True, reduce=False)
|
||||
softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)
|
||||
|
||||
# stuff to handle in the future, because they require special
|
||||
|
@ -3,14 +3,13 @@
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from torch._C._functorch import (
|
||||
_vmap_add_layers,
|
||||
_vmap_remove_layers,
|
||||
)
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers
|
||||
|
||||
_enabled = False
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _enable_layers(dims):
|
||||
global _enabled
|
||||
|
@ -4,9 +4,11 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import torch
|
||||
|
||||
from . import _Tensor, Tensor
|
||||
from .reference import _dims, _enable_layers, llist, ltuple
|
||||
|
||||
|
||||
class DelayedMulTensor(_Tensor):
|
||||
def __init__(self, lhs, rhs):
|
||||
self._lhs, self._rhs = lhs, rhs
|
||||
@ -37,7 +39,9 @@ class DelayedMulTensor(_Tensor):
|
||||
@property
|
||||
def _tensor(self):
|
||||
if self._tensor_data is None:
|
||||
self._tensor_data = Tensor.from_batched(self._batchtensor, self._has_device)._tensor
|
||||
self._tensor_data = Tensor.from_batched(
|
||||
self._batchtensor, self._has_device
|
||||
)._tensor
|
||||
return self._tensor_data
|
||||
|
||||
@property
|
||||
@ -48,20 +52,26 @@ class DelayedMulTensor(_Tensor):
|
||||
def dims(self):
|
||||
return ltuple(super().dims)
|
||||
|
||||
|
||||
def sum(self, dim):
|
||||
dims = _dims(dim, 0, False, False)
|
||||
n = ord('a')
|
||||
n = ord("a")
|
||||
all_levels = self._levels
|
||||
|
||||
def to_char(d):
|
||||
return chr(n + all_levels.index(d))
|
||||
|
||||
plhs, levelslhs = self._lhs._tensor, self._lhs._levels
|
||||
prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
|
||||
new_dims = tuple(d for d in self.dims if d not in dims)
|
||||
new_levels = [l for l in self._levels if l not in dims]
|
||||
fmt = ''.join([*(to_char(d) for d in levelslhs), ',',
|
||||
*(to_char(d) for d in levelsrhs), '->',
|
||||
*(to_char(d) for d in new_levels)])
|
||||
fmt = "".join(
|
||||
[
|
||||
*(to_char(d) for d in levelslhs),
|
||||
",",
|
||||
*(to_char(d) for d in levelsrhs),
|
||||
"->",
|
||||
*(to_char(d) for d in new_levels),
|
||||
]
|
||||
)
|
||||
result_data = torch.einsum(fmt, (plhs, prhs))
|
||||
return Tensor.from_positional(result_data, new_levels, True)
|
||||
|
@ -4,11 +4,14 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
_vmap_levels = []
|
||||
|
||||
|
||||
@dataclass
|
||||
class LevelInfo:
|
||||
level: int
|
||||
alive: bool = True
|
||||
|
||||
|
||||
class Dim:
|
||||
def __init__(self, name: str, size: Union[None, int] = None):
|
||||
self.name = name
|
||||
@ -20,7 +23,9 @@ class Dim:
|
||||
def __del__(self):
|
||||
if self._vmap_level is not None:
|
||||
_vmap_active_levels[self._vmap_stack].alive = False
|
||||
while not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level:
|
||||
while (
|
||||
not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level
|
||||
):
|
||||
_vmap_decrement_nesting()
|
||||
_vmap_levels.pop()
|
||||
|
||||
@ -33,13 +38,14 @@ class Dim:
|
||||
def size(self, size: int):
|
||||
if self._size is None:
|
||||
self._size = size
|
||||
self._vmap_level = _vmap_increment_nesting(size, 'same')
|
||||
self._vmap_level = _vmap_increment_nesting(size, "same")
|
||||
self._vmap_stack = len(_vmap_levels)
|
||||
_vmap_levels.append(LevelInfo(self._vmap_level))
|
||||
|
||||
elif self._size != size:
|
||||
raise DimensionBindError(
|
||||
f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}")
|
||||
f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
|
||||
)
|
||||
|
||||
@property
|
||||
def is_bound(self):
|
||||
@ -50,10 +56,13 @@ class Dim:
|
||||
|
||||
|
||||
def extract_name(inst):
|
||||
assert inst.opname == 'STORE_FAST' or inst.opname == 'STORE_NAME'
|
||||
assert inst.opname == "STORE_FAST" or inst.opname == "STORE_NAME"
|
||||
return inst.argval
|
||||
|
||||
|
||||
_cache = {}
|
||||
|
||||
|
||||
def dims(lists=0):
|
||||
frame = inspect.currentframe()
|
||||
assert frame is not None
|
||||
@ -66,17 +75,22 @@ def dims(lists=0):
|
||||
instructions = list(dis.get_instructions(calling_frame.f_code))
|
||||
unpack = instructions[first]
|
||||
|
||||
if unpack.opname == 'STORE_FAST' or unpack.opname == 'STORE_NAME':
|
||||
if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME":
|
||||
# just a single dim, not a list
|
||||
name = unpack.argval
|
||||
ctor = Dim if lists == 0 else DimList
|
||||
_cache[key] = lambda: ctor(name=name)
|
||||
else:
|
||||
assert unpack.opname == 'UNPACK_SEQUENCE'
|
||||
assert unpack.opname == "UNPACK_SEQUENCE"
|
||||
ndims = unpack.argval
|
||||
names = tuple(extract_name(instructions[first + 1 + i]) for i in range(ndims))
|
||||
names = tuple(
|
||||
extract_name(instructions[first + 1 + i]) for i in range(ndims)
|
||||
)
|
||||
first_list = len(names) - lists
|
||||
_cache[key] = lambda: tuple(Dim(n) if i < first_list else DimList(name=n) for i, n in enumerate(names))
|
||||
_cache[key] = lambda: tuple(
|
||||
Dim(n) if i < first_list else DimList(name=n)
|
||||
for i, n in enumerate(names)
|
||||
)
|
||||
return _cache[key]()
|
||||
|
||||
|
||||
@ -87,6 +101,7 @@ def _dim_set(positional, arg):
|
||||
else:
|
||||
assert isinstance(a, int)
|
||||
return positional[a]
|
||||
|
||||
if arg is None:
|
||||
return positional
|
||||
elif not isinstance(arg, (Dim, int)):
|
||||
|
@ -3,25 +3,33 @@
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from contextlib import contextmanager
|
||||
import os
|
||||
import subprocess
|
||||
import signal
|
||||
import subprocess
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@contextmanager
|
||||
def magic_trace(output='trace.fxt', magic_trace_cache='/tmp/magic-trace'):
|
||||
def magic_trace(output="trace.fxt", magic_trace_cache="/tmp/magic-trace"):
|
||||
pid = os.getpid()
|
||||
if not os.path.exists(magic_trace_cache):
|
||||
print(f"Downloading magic_trace to: {magic_trace_cache}")
|
||||
subprocess.run(['wget', '-O', magic_trace_cache, '-q',
|
||||
'https://github.com/janestreet/magic-trace/releases/download/v1.0.2/magic-trace'])
|
||||
subprocess.run(['chmod', '+x', magic_trace_cache])
|
||||
args = [magic_trace_cache, 'attach', '-pid', str(pid), '-o', output]
|
||||
p = subprocess.Popen(args, stderr=subprocess.PIPE, encoding='utf-8')
|
||||
subprocess.run(
|
||||
[
|
||||
"wget",
|
||||
"-O",
|
||||
magic_trace_cache,
|
||||
"-q",
|
||||
"https://github.com/janestreet/magic-trace/releases/download/v1.0.2/magic-trace",
|
||||
]
|
||||
)
|
||||
subprocess.run(["chmod", "+x", magic_trace_cache])
|
||||
args = [magic_trace_cache, "attach", "-pid", str(pid), "-o", output]
|
||||
p = subprocess.Popen(args, stderr=subprocess.PIPE, encoding="utf-8")
|
||||
while True:
|
||||
x = p.stderr.readline()
|
||||
print(x)
|
||||
if 'Attached' in x:
|
||||
if "Attached" in x:
|
||||
break
|
||||
try:
|
||||
yield
|
||||
@ -31,4 +39,4 @@ def magic_trace(output='trace.fxt', magic_trace_cache='/tmp/magic-trace'):
|
||||
print(p.stderr.read())
|
||||
p.stderr.close()
|
||||
if r != 0:
|
||||
raise ValueError(f'magic_trace exited abnormally: {r}')
|
||||
raise ValueError(f"magic_trace exited abnormally: {r}")
|
||||
|
@ -4,29 +4,58 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import torch
|
||||
|
||||
# pointwise operators can go through a faster pathway
|
||||
|
||||
tensor_magic_methods = [
|
||||
'add',
|
||||
''
|
||||
]
|
||||
tensor_magic_methods = ["add", ""]
|
||||
pointwise_magic_methods_with_reverse = (
|
||||
'add', 'sub', 'mul', 'floordiv', 'div', 'truediv', 'mod',
|
||||
'pow', 'lshift', 'rshift', 'and', 'or', 'xor'
|
||||
"add",
|
||||
"sub",
|
||||
"mul",
|
||||
"floordiv",
|
||||
"div",
|
||||
"truediv",
|
||||
"mod",
|
||||
"pow",
|
||||
"lshift",
|
||||
"rshift",
|
||||
"and",
|
||||
"or",
|
||||
"xor",
|
||||
)
|
||||
pointwise_magic_methods = (
|
||||
*(x for m in pointwise_magic_methods_with_reverse for x in (m, 'r' + m)),
|
||||
'eq', 'gt', 'le', 'lt', 'ge', 'gt', 'ne', 'neg', 'pos',
|
||||
'abs', 'invert',
|
||||
'iadd', 'isub', 'imul', 'ifloordiv', 'idiv',
|
||||
'itruediv', 'imod', 'ipow', 'ilshift', 'irshift', 'iand',
|
||||
'ior', 'ixor',
|
||||
'int', 'long', 'float', 'complex',
|
||||
*(x for m in pointwise_magic_methods_with_reverse for x in (m, "r" + m)),
|
||||
"eq",
|
||||
"gt",
|
||||
"le",
|
||||
"lt",
|
||||
"ge",
|
||||
"gt",
|
||||
"ne",
|
||||
"neg",
|
||||
"pos",
|
||||
"abs",
|
||||
"invert",
|
||||
"iadd",
|
||||
"isub",
|
||||
"imul",
|
||||
"ifloordiv",
|
||||
"idiv",
|
||||
"itruediv",
|
||||
"imod",
|
||||
"ipow",
|
||||
"ilshift",
|
||||
"irshift",
|
||||
"iand",
|
||||
"ior",
|
||||
"ixor",
|
||||
"int",
|
||||
"long",
|
||||
"float",
|
||||
"complex",
|
||||
)
|
||||
|
||||
pointwise_methods = (
|
||||
*(f'__{m}__' for m in pointwise_magic_methods),
|
||||
)
|
||||
pointwise_methods = (*(f"__{m}__" for m in pointwise_magic_methods),)
|
||||
|
||||
pointwise = (
|
||||
*(getattr(torch.Tensor, m) for m in pointwise_methods),
|
||||
|
@ -6,23 +6,28 @@
|
||||
|
||||
# reference python implementations for C ops
|
||||
import torch
|
||||
from .tree_map import tree_flatten, tree_map
|
||||
from .batch_tensor import _enable_layers
|
||||
from . import op_properties
|
||||
|
||||
from functorch._C import dim as _C
|
||||
from . import op_properties
|
||||
from .batch_tensor import _enable_layers
|
||||
from .tree_map import tree_flatten, tree_map
|
||||
|
||||
DimList = _C.DimList
|
||||
from functools import reduce
|
||||
import operator
|
||||
from functools import reduce
|
||||
|
||||
|
||||
# use dict to avoid writing C++ bindings for set
|
||||
pointwise = set(op_properties.pointwise)
|
||||
|
||||
|
||||
def prod(x):
|
||||
return reduce(operator.mul, x, 1)
|
||||
|
||||
|
||||
def _wrap_dim(d, N, keepdim):
|
||||
from . import Dim
|
||||
|
||||
if isinstance(d, Dim):
|
||||
assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
|
||||
return d
|
||||
@ -31,40 +36,52 @@ def _wrap_dim(d, N, keepdim):
|
||||
else:
|
||||
return d
|
||||
|
||||
|
||||
def _dims(d, N, keepdim, single_dim):
|
||||
from . import Dim
|
||||
|
||||
if isinstance(d, (Dim, int)):
|
||||
return ltuple((_wrap_dim(d, N, keepdim),))
|
||||
assert not single_dim, f"expected a single dimension or int but found: {d}"
|
||||
return ltuple(_wrap_dim(x, N, keepdim) for x in d)
|
||||
|
||||
|
||||
def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
|
||||
from . import DimensionMismatchError
|
||||
|
||||
not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound)
|
||||
if len(not_bound) == 1:
|
||||
idx, d = not_bound[0]
|
||||
rhs_so_far = prod(r.size for r in rhs if r.is_bound)
|
||||
if lhs_size % rhs_so_far != 0:
|
||||
rhs_s = tuple('?' if not r.is_bound else str(r.size) for r in rhs)
|
||||
raise DimensionMismatchError(f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}")
|
||||
rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
|
||||
raise DimensionMismatchError(
|
||||
f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}"
|
||||
)
|
||||
new_size = lhs_size // rhs_so_far
|
||||
d.size = new_size
|
||||
elif len(not_bound) > 1:
|
||||
rhs_s = tuple('?' if not r.is_bound else str(r.size) for r in rhs)
|
||||
raise DimensionMismatchError(f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}")
|
||||
rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
|
||||
raise DimensionMismatchError(
|
||||
f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}"
|
||||
)
|
||||
else:
|
||||
rhs_size = prod(r.size for r in rhs)
|
||||
if lhs_size != rhs_size:
|
||||
raise DimensionMismatchError(
|
||||
f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}")
|
||||
f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}"
|
||||
)
|
||||
|
||||
|
||||
def _tensor_levels(inp):
|
||||
from . import _Tensor
|
||||
|
||||
if isinstance(inp, _Tensor):
|
||||
return inp._tensor, llist(inp._levels), inp._has_device
|
||||
else:
|
||||
return inp, llist(range(-inp.ndim, 0)), True
|
||||
|
||||
|
||||
def _match_levels(v, from_levels, to_levels):
|
||||
view = []
|
||||
permute = []
|
||||
@ -90,6 +107,7 @@ def _match_levels(v, from_levels, to_levels):
|
||||
# should not physically move if possible
|
||||
def _positional_no_permute(self, dim, expand_dim=False):
|
||||
from . import Tensor
|
||||
|
||||
ptensor, levels = self._tensor, llist(self._levels)
|
||||
try:
|
||||
idx = levels.index(dim)
|
||||
@ -107,8 +125,10 @@ def _positional_no_permute(self, dim, expand_dim=False):
|
||||
levels[idx] = -idx_batched - 1
|
||||
return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched
|
||||
|
||||
|
||||
def seq(a, b):
|
||||
from . import Dim
|
||||
|
||||
if isinstance(a, Dim) != isinstance(b, Dim):
|
||||
return False
|
||||
if isinstance(a, Dim):
|
||||
@ -116,6 +136,7 @@ def seq(a, b):
|
||||
else:
|
||||
return a == b
|
||||
|
||||
|
||||
class isin:
|
||||
def __contains__(self, item):
|
||||
for x in self:
|
||||
@ -133,18 +154,27 @@ class isin:
|
||||
class llist(isin, list):
|
||||
pass
|
||||
|
||||
|
||||
class ltuple(isin, tuple):
|
||||
pass
|
||||
|
||||
|
||||
empty_dict = {}
|
||||
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
|
||||
from . import _Tensor, TensorLike, Tensor
|
||||
from . import _Tensor, Tensor, TensorLike
|
||||
from .delayed_mul_tensor import DelayedMulTensor
|
||||
|
||||
if orig is torch.Tensor.__mul__:
|
||||
lhs, rhs = args
|
||||
if isinstance(lhs, _Tensor) and isinstance(rhs, _Tensor) and lhs.ndim == 0 and rhs.ndim == 0:
|
||||
if (
|
||||
isinstance(lhs, _Tensor)
|
||||
and isinstance(rhs, _Tensor)
|
||||
and lhs.ndim == 0
|
||||
and rhs.ndim == 0
|
||||
):
|
||||
return DelayedMulTensor(lhs, rhs)
|
||||
all_dims = llist()
|
||||
flat_args, unflatten = tree_flatten((args, kwargs))
|
||||
@ -172,7 +202,11 @@ def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
|
||||
for i, f in enumerate(flat_args):
|
||||
if isinstance(f, TensorLike):
|
||||
ptensor, levels, _ = _tensor_levels(f)
|
||||
if isinstance(f, _Tensor) and not f._has_device and device_holding_tensor is not None:
|
||||
if (
|
||||
isinstance(f, _Tensor)
|
||||
and not f._has_device
|
||||
and device_holding_tensor is not None
|
||||
):
|
||||
ptensor = ptensor.to(device=device_holding_tensor.device)
|
||||
flat_args[i] = ptensor
|
||||
for l in levels:
|
||||
@ -187,14 +221,19 @@ def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
|
||||
|
||||
def wrap(t):
|
||||
if isinstance(t, TensorLike):
|
||||
return Tensor.from_positional(t, result_levels, device_holding_tensor is not None)
|
||||
return Tensor.from_positional(
|
||||
t, result_levels, device_holding_tensor is not None
|
||||
)
|
||||
return t
|
||||
|
||||
return tree_map(wrap, result)
|
||||
else:
|
||||
|
||||
def wrap(t):
|
||||
if isinstance(t, TensorLike):
|
||||
return Tensor.from_batched(t, device_holding_tensor is not None)
|
||||
return t
|
||||
|
||||
with _enable_layers(all_dims):
|
||||
print(f"batch_tensor for {orig}")
|
||||
args, kwargs = unflatten(unwrap(f) for f in flat_args)
|
||||
@ -202,8 +241,10 @@ def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
|
||||
# print("END", orig)
|
||||
return tree_map(wrap, result)
|
||||
|
||||
|
||||
def positional(self, *dims):
|
||||
from . import Dim, Tensor
|
||||
|
||||
ptensor, levels = self._tensor, llist(self._levels)
|
||||
flat_dims = llist()
|
||||
view = []
|
||||
@ -231,7 +272,9 @@ def positional(self, *dims):
|
||||
try:
|
||||
idx = levels.index(d)
|
||||
except ValueError as e:
|
||||
raise DimensionBindError(f'tensor of dimensions {self.dims} does not contain dim {d}') from e
|
||||
raise DimensionBindError(
|
||||
f"tensor of dimensions {self.dims} does not contain dim {d}"
|
||||
) from e
|
||||
p = permute[idx]
|
||||
del levels[idx]
|
||||
del permute[idx]
|
||||
@ -245,15 +288,18 @@ def positional(self, *dims):
|
||||
levels[i] = -seen
|
||||
result = Tensor.from_positional(ptensor, levels, self._has_device)
|
||||
if needs_view:
|
||||
result = result.reshape(*view, *result.size()[len(flat_dims):])
|
||||
result = result.reshape(*view, *result.size()[len(flat_dims) :])
|
||||
return result
|
||||
|
||||
|
||||
def _contains_dim(input):
|
||||
from . import Dim
|
||||
|
||||
for i in input:
|
||||
if isinstance(i, Dim):
|
||||
return True
|
||||
|
||||
|
||||
def expand(self, *sizes):
|
||||
if not _contains_dim(sizes):
|
||||
return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
|
||||
@ -265,27 +311,36 @@ def expand(self, *sizes):
|
||||
|
||||
_not_present = object()
|
||||
|
||||
|
||||
def _getarg(name, offset, args, kwargs, default):
|
||||
if len(args) > offset:
|
||||
return args[offset]
|
||||
return kwargs.get(name, default)
|
||||
|
||||
|
||||
def _patcharg(name, offset, args, kwargs, value):
|
||||
if len(args) > offset:
|
||||
args[offset] = value
|
||||
else:
|
||||
kwargs[name] = value
|
||||
|
||||
def _wrap(orig, dim_offset=0, keepdim_offset=1, dim_name='dim', single_dim=False, reduce=True):
|
||||
from . import TensorLike, Dim, Tensor
|
||||
|
||||
def _wrap(
|
||||
orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True
|
||||
):
|
||||
from . import Dim, Tensor, TensorLike
|
||||
|
||||
def fn(self, *args, **kwargs):
|
||||
dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
|
||||
if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
|
||||
with _enable_layers(self.dims):
|
||||
print(f"dim fallback batch_tensor for {orig}")
|
||||
return Tensor.from_batched(orig(self._batchtensor, *args, **kwargs), self._has_device)
|
||||
keepdim = _getarg('keepdim', keepdim_offset, args, kwargs, False) if reduce else False
|
||||
return Tensor.from_batched(
|
||||
orig(self._batchtensor, *args, **kwargs), self._has_device
|
||||
)
|
||||
keepdim = (
|
||||
_getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False
|
||||
)
|
||||
t, levels = self._tensor, llist(self._levels)
|
||||
dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
|
||||
dim_indices = tuple(levels.index(d) for d in dims)
|
||||
@ -295,7 +350,9 @@ def _wrap(orig, dim_offset=0, keepdim_offset=1, dim_name='dim', single_dim=False
|
||||
new_levels = levels
|
||||
|
||||
if len(dim_indices) == 1:
|
||||
dim_indices = dim_indices[0] # so that dims that really only take a single argument work...
|
||||
dim_indices = dim_indices[
|
||||
0
|
||||
] # so that dims that really only take a single argument work...
|
||||
args = list(args)
|
||||
_patcharg(dim_name, dim_offset, args, kwargs, dim_indices)
|
||||
|
||||
@ -303,21 +360,27 @@ def _wrap(orig, dim_offset=0, keepdim_offset=1, dim_name='dim', single_dim=False
|
||||
if isinstance(t, TensorLike):
|
||||
return Tensor.from_positional(t, new_levels, self._has_device)
|
||||
return t
|
||||
|
||||
with _enable_layers(new_levels):
|
||||
print(f"dim used batch_tensor for {orig}")
|
||||
r = orig(t, *args, **kwargs)
|
||||
return tree_map(wrap, r)
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def _def(name, *args, **kwargs):
    """Install a ``_wrap``-ped version of ``torch.Tensor.<name>`` on ``_Tensor``.

    Args:
        name: Attribute name looked up on ``torch.Tensor`` and then set on
            ``_Tensor``.
        *args: Extra positional arguments forwarded unchanged to ``_wrap``.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_wrap``.
    """
    # Function-scope import, as in the original (presumably to avoid an
    # import cycle with the package ``__init__`` -- confirm).
    from . import _Tensor

    wrapped = _wrap(getattr(torch.Tensor, name), *args, **kwargs)
    setattr(_Tensor, name, wrapped)
|
||||
|
||||
|
||||
no_slice = slice(None)
|
||||
|
||||
_orig_getitem = torch.Tensor.__getitem__
|
||||
|
||||
|
||||
class dim_tracker:
|
||||
def __init__(self):
|
||||
self.dims = llist()
|
||||
@ -331,8 +394,10 @@ class dim_tracker:
|
||||
def __getitem__(self, d):
|
||||
return self.count[self.dims.index(d)]
|
||||
|
||||
|
||||
def t__getitem__(self, input):
|
||||
from . import Dim, DimensionBindError, _Tensor, TensorLike, DimList, Tensor
|
||||
from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike
|
||||
|
||||
# * bail to original example if we have a single non-Dim tensor, or a non-tensor
|
||||
# * locate ... or an unbound tensor list, and determine its size, bind dim list
|
||||
# (remember that None does not count to the total dim count)
|
||||
@ -345,10 +410,13 @@ def t__getitem__(self, input):
|
||||
|
||||
# this handles bool indexing, as well as some other simple cases.
|
||||
|
||||
is_simple = (not isinstance(input, Dim) and
|
||||
not isinstance(input, (tuple, list)) and
|
||||
# WAR for functorch bug where zero-dim tensors in getitem are not handled correctly.
|
||||
not (isinstance(input, TensorLike) and input.ndim == 0))
|
||||
is_simple = (
|
||||
not isinstance(input, Dim)
|
||||
and not isinstance(input, (tuple, list))
|
||||
and
|
||||
# WAR for functorch bug where zero-dim tensors in getitem are not handled correctly.
|
||||
not (isinstance(input, TensorLike) and input.ndim == 0)
|
||||
)
|
||||
|
||||
if is_simple:
|
||||
if isinstance(self, _Tensor):
|
||||
@ -368,8 +436,10 @@ def t__getitem__(self, input):
|
||||
for i, s in enumerate(input):
|
||||
if s is ... or isinstance(s, DimList) and not s.is_bound:
|
||||
if expanding_object is not None:
|
||||
msg = 'at most one ... or unbound dimension list can exist in indexing list but' \
|
||||
f' found 2 at offsets {i} and {expanding_object}'
|
||||
msg = (
|
||||
"at most one ... or unbound dimension list can exist in indexing list but"
|
||||
f" found 2 at offsets {i} and {expanding_object}"
|
||||
)
|
||||
raise DimensionBindError(msg)
|
||||
expanding_object = i
|
||||
|
||||
@ -381,17 +451,21 @@ def t__getitem__(self, input):
|
||||
|
||||
ndim = self.ndim
|
||||
if dims_indexed > ndim:
|
||||
raise IndexError(f'at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions.')
|
||||
raise IndexError(
|
||||
f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions."
|
||||
)
|
||||
if expanding_object is not None:
|
||||
expanding_ndims = ndim - dims_indexed
|
||||
obj = input[expanding_object]
|
||||
if obj is ...:
|
||||
input[expanding_object:expanding_object + 1] = [no_slice] * expanding_ndims
|
||||
input[expanding_object : expanding_object + 1] = [
|
||||
no_slice
|
||||
] * expanding_ndims
|
||||
else:
|
||||
obj.bind_len(expanding_ndims)
|
||||
# flatten the dimlists into the indexing
|
||||
for i in reversed(dimlists):
|
||||
input[i:i + 1] = input[i]
|
||||
input[i : i + 1] = input[i]
|
||||
dims_indexed = 0
|
||||
requires_view = False
|
||||
size = self.size()
|
||||
@ -420,7 +494,7 @@ def t__getitem__(self, input):
|
||||
elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim):
|
||||
for d in idx:
|
||||
dims_seen.record(idx)
|
||||
_bind_dims_to_size(sz, idx, f'offset {i}')
|
||||
_bind_dims_to_size(sz, idx, f"offset {i}")
|
||||
view_sizes.extend(d.size for d in idx)
|
||||
requires_view = True
|
||||
dim_packs.append(i)
|
||||
@ -431,7 +505,7 @@ def t__getitem__(self, input):
|
||||
if requires_view:
|
||||
self = self.view(*view_sizes)
|
||||
for i in reversed(dim_packs):
|
||||
input[i:i + 1] = input[i]
|
||||
input[i : i + 1] = input[i]
|
||||
|
||||
# currently:
|
||||
# input is flat, containing either Dim, or Tensor, or something valid for standard indexing
|
||||
@ -499,6 +573,7 @@ def t__getitem__(self, input):
|
||||
|
||||
return Tensor.from_positional(result, result_levels, has_device)
|
||||
|
||||
|
||||
# XXX - dim is optional and can be the outer-most dimension...
|
||||
def stack(tensors, new_dim, dim=0, out=None):
|
||||
if isinstance(dim, int):
|
||||
@ -517,12 +592,20 @@ def stack(tensors, new_dim, dim=0, out=None):
|
||||
pr = torch.stack(ptensors, index, out=out)
|
||||
return pr.index((index, index + 1), (new_dim, dim))
|
||||
|
||||
|
||||
_orig_split = torch.Tensor.split
|
||||
|
||||
|
||||
def split(self, split_size_or_sections, dim=0):
|
||||
from . import Dim, _Tensor
|
||||
if isinstance(split_size_or_sections, int) or any(isinstance(t, int) for t in split_size_or_sections):
|
||||
from . import _Tensor, Dim
|
||||
|
||||
if isinstance(split_size_or_sections, int) or any(
|
||||
isinstance(t, int) for t in split_size_or_sections
|
||||
):
|
||||
if isinstance(dim, Dim):
|
||||
raise ValueError('when dim is specified as a Dim object, split sizes must also be dimensions.')
|
||||
raise ValueError(
|
||||
"when dim is specified as a Dim object, split sizes must also be dimensions."
|
||||
)
|
||||
return _orig_split(self, split_size_or_sections, dim=dim)
|
||||
|
||||
if isinstance(dim, Dim):
|
||||
@ -542,8 +625,9 @@ def split(self, split_size_or_sections, dim=0):
|
||||
unbound.append(i)
|
||||
|
||||
if unbound:
|
||||
assert total_bound_size <= size, \
|
||||
f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
|
||||
assert (
|
||||
total_bound_size <= size
|
||||
), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
|
||||
remaining_size = size - total_bound_size
|
||||
chunk_size = -(-remaining_size // len(unbound))
|
||||
for u in unbound:
|
||||
@ -552,6 +636,10 @@ def split(self, split_size_or_sections, dim=0):
|
||||
sizes[u] = sz
|
||||
remaining_size -= sz
|
||||
else:
|
||||
assert total_bound_size == size, \
|
||||
f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
|
||||
return tuple(t.index(dim, d) for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim)))
|
||||
assert (
|
||||
total_bound_size == size
|
||||
), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
|
||||
return tuple(
|
||||
t.index(dim, d)
|
||||
for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim))
|
||||
)
|
||||
|
@ -5,8 +5,10 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from functorch._C import dim
|
||||
|
||||
tree_flatten = dim.tree_flatten
|
||||
|
||||
|
||||
def tree_map(fn, tree):
    """Apply ``fn`` to each value produced by flattening ``tree``.

    ``tree_flatten`` yields the flattened values together with an
    ``unflatten`` callable; the mapped values are handed to that callable
    as a lazy generator and its result is returned.
    """
    leaves, rebuild = tree_flatten(tree)
    mapped = (fn(leaf) for leaf in leaves)
    return rebuild(mapped)
|
||||
|
@ -4,22 +4,35 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from types import FunctionType, BuiltinMethodType, MethodDescriptorType, WrapperDescriptorType, GetSetDescriptorType
|
||||
from types import (
|
||||
BuiltinMethodType,
|
||||
FunctionType,
|
||||
GetSetDescriptorType,
|
||||
MethodDescriptorType,
|
||||
WrapperDescriptorType,
|
||||
)
|
||||
|
||||
from functorch._C import dim as _C
|
||||
|
||||
_wrap_method = _C._wrap_method
|
||||
|
||||
FUNC_TYPES = (FunctionType, MethodDescriptorType, BuiltinMethodType, WrapperDescriptorType)
|
||||
FUNC_TYPES = (
|
||||
FunctionType,
|
||||
MethodDescriptorType,
|
||||
BuiltinMethodType,
|
||||
WrapperDescriptorType,
|
||||
)
|
||||
PROPERTY_TYPES = (GetSetDescriptorType, property)
|
||||
|
||||
|
||||
def _py_wrap_method(orig, __torch_function__):
|
||||
def impl(*args, **kwargs):
|
||||
return __torch_function__(orig, None, args, kwargs)
|
||||
|
||||
return impl
|
||||
|
||||
|
||||
|
||||
def wrap_type(use_c, to_patch, pattern, __torch_function__):
|
||||
|
||||
if use_c:
|
||||
wrap_method = _wrap_method
|
||||
else:
|
||||
@ -29,18 +42,27 @@ def wrap_type(use_c, to_patch, pattern, __torch_function__):
|
||||
for t in reversed(pattern.mro()[:-1]): # skip object
|
||||
all.update(t.__dict__)
|
||||
|
||||
|
||||
def wrap_attr(orig):
|
||||
return property(wrap_method(orig.__get__, __torch_function__))
|
||||
|
||||
|
||||
for name, obj in all.items():
|
||||
if name in ('__dict__', '__new__', '__init__', '__repr__', '__weakref__', '__doc__', '__module__', '__dir__'):
|
||||
if name in (
|
||||
"__dict__",
|
||||
"__new__",
|
||||
"__init__",
|
||||
"__repr__",
|
||||
"__weakref__",
|
||||
"__doc__",
|
||||
"__module__",
|
||||
"__dir__",
|
||||
):
|
||||
continue
|
||||
|
||||
# skip things that have been overloaded
|
||||
# things that come from object like `__eq__` still need to be patched, however.
|
||||
if hasattr(to_patch, name) and getattr(to_patch, name) is not getattr(object, name, None):
|
||||
if hasattr(to_patch, name) and getattr(to_patch, name) is not getattr(
|
||||
object, name, None
|
||||
):
|
||||
continue
|
||||
|
||||
if isinstance(obj, FUNC_TYPES):
|
||||
|
@ -14,18 +14,21 @@
|
||||
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||
#
|
||||
import os
|
||||
|
||||
import functorch
|
||||
|
||||
# import sys
|
||||
|
||||
# source code directory, relative to this file, for sphinx-autobuild
|
||||
# sys.path.insert(0, os.path.abspath('../..'))
|
||||
|
||||
import torch
|
||||
import functorch
|
||||
|
||||
RELEASE = os.environ.get('RELEASE', False)
|
||||
RELEASE = os.environ.get("RELEASE", False)
|
||||
|
||||
import sys
|
||||
|
||||
import pytorch_sphinx_theme
|
||||
import sys
|
||||
|
||||
# -- General configuration ------------------------------------------------
|
||||
|
||||
@ -35,18 +38,18 @@ import sys
|
||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||
# ones.
|
||||
extensions = [
|
||||
'sphinx.ext.autodoc',
|
||||
'sphinx.ext.autosummary',
|
||||
'sphinx.ext.doctest',
|
||||
'sphinx.ext.intersphinx',
|
||||
'sphinx.ext.todo',
|
||||
'sphinx.ext.coverage',
|
||||
'sphinx.ext.napoleon',
|
||||
'sphinx.ext.viewcode',
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.autosummary",
|
||||
"sphinx.ext.doctest",
|
||||
"sphinx.ext.intersphinx",
|
||||
"sphinx.ext.todo",
|
||||
"sphinx.ext.coverage",
|
||||
"sphinx.ext.napoleon",
|
||||
"sphinx.ext.viewcode",
|
||||
# 'sphinxcontrib.katex',
|
||||
'sphinx.ext.autosectionlabel',
|
||||
'sphinx_copybutton',
|
||||
'myst_nb',
|
||||
"sphinx.ext.autosectionlabel",
|
||||
"sphinx_copybutton",
|
||||
"myst_nb",
|
||||
]
|
||||
|
||||
# sys.path.insert(0, os.path.abspath('./notebooks'))
|
||||
@ -75,21 +78,21 @@ napoleon_use_ivar = True
|
||||
autosummary_generate = True
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ['_templates']
|
||||
templates_path = ["_templates"]
|
||||
|
||||
# The suffix(es) of source filenames.
|
||||
# You can specify multiple suffix as a list of string:
|
||||
#
|
||||
# source_suffix = ['.rst', '.md']
|
||||
source_suffix = '.rst'
|
||||
source_suffix = ".rst"
|
||||
|
||||
# The master toctree document.
|
||||
master_doc = 'index'
|
||||
master_doc = "index"
|
||||
|
||||
# General information about the project.
|
||||
project = 'functorch'
|
||||
copyright = 'PyTorch Contributors'
|
||||
author = 'PyTorch Contributors'
|
||||
project = "functorch"
|
||||
copyright = "PyTorch Contributors"
|
||||
author = "PyTorch Contributors"
|
||||
functorch_version = str(functorch.__version__)
|
||||
|
||||
# The version info for the project you're documenting, acts as replacement for
|
||||
@ -98,16 +101,16 @@ functorch_version = str(functorch.__version__)
|
||||
#
|
||||
# The short X.Y version.
|
||||
# TODO: change to [:2] at v1.0
|
||||
version = 'nightly (' + functorch_version + ')'
|
||||
version = "nightly (" + functorch_version + ")"
|
||||
# The full version, including alpha/beta/rc tags.
|
||||
# TODO: verify this works as expected
|
||||
release = 'nightly'
|
||||
release = "nightly"
|
||||
|
||||
# Customized html_title here.
|
||||
# Default is " ".join(project, release, "documentation") if not set
|
||||
# TODO: I don't know if this flag works, please check before using it
|
||||
if RELEASE:
|
||||
raise RuntimeError('NYI')
|
||||
raise RuntimeError("NYI")
|
||||
# remove hash (start with 'a') from version number if any
|
||||
# version_end = functorch_version.find('a')
|
||||
# if version_end == -1:
|
||||
@ -128,10 +131,10 @@ language = "en"
|
||||
# List of patterns, relative to source directory, that match files and
|
||||
# directories to ignore when looking for source files.
|
||||
# These patterns also affect html_static_path and html_extra_path
|
||||
exclude_patterns = ['notebooks/colab**', 'notebooks/_src/**']
|
||||
exclude_patterns = ["notebooks/colab**", "notebooks/_src/**"]
|
||||
|
||||
# The name of the Pygments (syntax highlighting) style to use.
|
||||
pygments_style = 'sphinx'
|
||||
pygments_style = "sphinx"
|
||||
|
||||
# If true, `todo` and `todoList` produce output, else they produce nothing.
|
||||
todo_include_todos = True
|
||||
@ -140,7 +143,7 @@ todo_include_todos = True
|
||||
autodoc_inherit_docstrings = False
|
||||
|
||||
# Disable displaying type annotations, these can be very verbose
|
||||
autodoc_typehints = 'none'
|
||||
autodoc_typehints = "none"
|
||||
|
||||
# Enable overriding of function signatures in the first line of the docstring.
|
||||
autodoc_docstring_signature = True
|
||||
@ -159,7 +162,7 @@ autodoc_docstring_signature = True
|
||||
#
|
||||
#
|
||||
|
||||
html_theme = 'pytorch_sphinx_theme'
|
||||
html_theme = "pytorch_sphinx_theme"
|
||||
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
|
||||
|
||||
# Theme options are theme-specific and customize the look and feel of a theme
|
||||
@ -178,10 +181,10 @@ html_theme_options = {
|
||||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
# relative to this directory. They are copied after the builtin static files,
|
||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||
html_static_path = ['_static']
|
||||
html_static_path = ["_static"]
|
||||
|
||||
html_css_files = [
|
||||
'css/custom.css',
|
||||
"css/custom.css",
|
||||
]
|
||||
|
||||
|
||||
@ -191,19 +194,20 @@ def setup(app):
|
||||
# and can be moved outside of this function (and the setup(app) function
|
||||
# can be deleted).
|
||||
html_css_files = [
|
||||
'https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css'
|
||||
"https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css"
|
||||
]
|
||||
|
||||
# In Sphinx 1.8 it was renamed to `add_css_file`, 1.7 and prior it is
|
||||
# `add_stylesheet` (deprecated in 1.8).
|
||||
add_css = getattr(app, 'add_css_file', app.add_stylesheet)
|
||||
add_css = getattr(app, "add_css_file", app.add_stylesheet)
|
||||
for css_file in html_css_files:
|
||||
add_css(css_file)
|
||||
|
||||
|
||||
# -- Options for HTMLHelp output ------------------------------------------
|
||||
|
||||
# Output file base name for HTML help builder.
|
||||
htmlhelp_basename = 'PyTorchdoc'
|
||||
htmlhelp_basename = "PyTorchdoc"
|
||||
|
||||
|
||||
# -- Options for LaTeX output ---------------------------------------------
|
||||
@ -212,15 +216,12 @@ latex_elements = {
|
||||
# The paper size ('letterpaper' or 'a4paper').
|
||||
#
|
||||
# 'papersize': 'letterpaper',
|
||||
|
||||
# The font size ('10pt', '11pt' or '12pt').
|
||||
#
|
||||
# 'pointsize': '10pt',
|
||||
|
||||
# Additional stuff for the LaTeX preamble.
|
||||
#
|
||||
# 'preamble': '',
|
||||
|
||||
# Latex figure (float) alignment
|
||||
#
|
||||
# 'figure_align': 'htbp',
|
||||
@ -230,8 +231,13 @@ latex_elements = {
|
||||
# (source start file, target name, title,
|
||||
# author, documentclass [howto, manual, or own class]).
|
||||
latex_documents = [
|
||||
(master_doc, 'pytorch.tex', 'PyTorch Documentation',
|
||||
'Torch Contributors', 'manual'),
|
||||
(
|
||||
master_doc,
|
||||
"pytorch.tex",
|
||||
"PyTorch Documentation",
|
||||
"Torch Contributors",
|
||||
"manual",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@ -239,10 +245,7 @@ latex_documents = [
|
||||
|
||||
# One entry per manual page. List of tuples
|
||||
# (source start file, name, description, authors, manual section).
|
||||
man_pages = [
|
||||
(master_doc, 'functorch', 'functorch Documentation',
|
||||
[author], 1)
|
||||
]
|
||||
man_pages = [(master_doc, "functorch", "functorch Documentation", [author], 1)]
|
||||
|
||||
|
||||
# -- Options for Texinfo output -------------------------------------------
|
||||
@ -251,37 +254,44 @@ man_pages = [
|
||||
# (source start file, target name, title, author,
|
||||
# dir menu entry, description, category)
|
||||
texinfo_documents = [
|
||||
(master_doc, 'functorch', 'functorch Documentation',
|
||||
author, 'functorch', 'One line description of project.',
|
||||
'Miscellaneous'),
|
||||
(
|
||||
master_doc,
|
||||
"functorch",
|
||||
"functorch Documentation",
|
||||
author,
|
||||
"functorch",
|
||||
"One line description of project.",
|
||||
"Miscellaneous",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# Example configuration for intersphinx: refer to the Python standard library.
|
||||
intersphinx_mapping = {
|
||||
'python': ('https://docs.python.org/3', None),
|
||||
'numpy': ('https://numpy.org/doc/stable', None),
|
||||
"python": ("https://docs.python.org/3", None),
|
||||
"numpy": ("https://numpy.org/doc/stable", None),
|
||||
"torch": ("https://pytorch.org/docs/stable/", None),
|
||||
}
|
||||
|
||||
import sphinx.ext.doctest
|
||||
|
||||
# -- A patch that prevents Sphinx from cross-referencing ivar tags -------
|
||||
# See http://stackoverflow.com/a/41184353/3343043
|
||||
|
||||
from docutils import nodes
|
||||
from sphinx.util.docfields import TypedField
|
||||
from sphinx import addnodes
|
||||
import sphinx.ext.doctest
|
||||
from sphinx.util.docfields import TypedField
|
||||
|
||||
# Without this, doctest adds any example with a `>>>` as a test
|
||||
doctest_test_doctest_blocks = ''
|
||||
doctest_test_doctest_blocks = ""
|
||||
doctest_default_flags = sphinx.ext.doctest.doctest.ELLIPSIS
|
||||
doctest_global_setup = '''
|
||||
doctest_global_setup = """
|
||||
import torch
|
||||
try:
|
||||
import torchvision
|
||||
except ImportError:
|
||||
torchvision = None
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def patched_make_field(self, types, domain, items, **kw):
|
||||
@ -291,43 +301,51 @@ def patched_make_field(self, types, domain, items, **kw):
|
||||
# (List, unicode, Tuple) -> nodes.field
|
||||
def handle_item(fieldarg, content):
|
||||
par = nodes.paragraph()
|
||||
par += addnodes.literal_strong('', fieldarg) # Patch: this line added
|
||||
par += addnodes.literal_strong("", fieldarg) # Patch: this line added
|
||||
# par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
|
||||
# addnodes.literal_strong))
|
||||
if fieldarg in types:
|
||||
par += nodes.Text(' (')
|
||||
par += nodes.Text(" (")
|
||||
# NOTE: using .pop() here to prevent a single type node to be
|
||||
# inserted twice into the doctree, which leads to
|
||||
# inconsistencies later when references are resolved
|
||||
fieldtype = types.pop(fieldarg)
|
||||
if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text):
|
||||
typename = u''.join(n.astext() for n in fieldtype)
|
||||
typename = typename.replace('int', 'python:int')
|
||||
typename = typename.replace('long', 'python:long')
|
||||
typename = typename.replace('float', 'python:float')
|
||||
typename = typename.replace('bool', 'python:bool')
|
||||
typename = typename.replace('type', 'python:type')
|
||||
par.extend(self.make_xrefs(self.typerolename, domain, typename,
|
||||
addnodes.literal_emphasis, **kw))
|
||||
typename = "".join(n.astext() for n in fieldtype)
|
||||
typename = typename.replace("int", "python:int")
|
||||
typename = typename.replace("long", "python:long")
|
||||
typename = typename.replace("float", "python:float")
|
||||
typename = typename.replace("bool", "python:bool")
|
||||
typename = typename.replace("type", "python:type")
|
||||
par.extend(
|
||||
self.make_xrefs(
|
||||
self.typerolename,
|
||||
domain,
|
||||
typename,
|
||||
addnodes.literal_emphasis,
|
||||
**kw,
|
||||
)
|
||||
)
|
||||
else:
|
||||
par += fieldtype
|
||||
par += nodes.Text(')')
|
||||
par += nodes.Text(' -- ')
|
||||
par += nodes.Text(")")
|
||||
par += nodes.Text(" -- ")
|
||||
par += content
|
||||
return par
|
||||
|
||||
fieldname = nodes.field_name('', self.label)
|
||||
fieldname = nodes.field_name("", self.label)
|
||||
if len(items) == 1 and self.can_collapse:
|
||||
fieldarg, content = items[0]
|
||||
bodynode = handle_item(fieldarg, content)
|
||||
else:
|
||||
bodynode = self.list_type()
|
||||
for fieldarg, content in items:
|
||||
bodynode += nodes.list_item('', handle_item(fieldarg, content))
|
||||
fieldbody = nodes.field_body('', bodynode)
|
||||
return nodes.field('', fieldname, fieldbody)
|
||||
bodynode += nodes.list_item("", handle_item(fieldarg, content))
|
||||
fieldbody = nodes.field_body("", bodynode)
|
||||
return nodes.field("", fieldname, fieldbody)
|
||||
|
||||
|
||||
TypedField.make_field = patched_make_field
|
||||
|
||||
copybutton_prompt_text = r'>>> |\.\.\. '
|
||||
copybutton_prompt_text = r">>> |\.\.\. "
|
||||
copybutton_prompt_is_regexp = True
|
||||
|
@ -1,3 +1,3 @@
|
||||
from .rearrange import rearrange
|
||||
|
||||
__all__ = ['rearrange']
|
||||
__all__ = ["rearrange"]
|
||||
|
@ -40,7 +40,9 @@ class AnonymousAxis:
|
||||
def __init__(self, value: str) -> None:
|
||||
self.value = int(value)
|
||||
if self.value < 1:
|
||||
raise ValueError(f'Anonymous axis should have positive length, not {self.value}')
|
||||
raise ValueError(
|
||||
f"Anonymous axis should have positive length, not {self.value}"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.value}-axis"
|
||||
@ -49,7 +51,13 @@ class AnonymousAxis:
|
||||
class ParsedExpression:
|
||||
"""Structure containing information about one side of an `einops`-style pattern (e.g. 'b c (h w)')."""
|
||||
|
||||
def __init__(self, expression: str, *, allow_underscore: bool = False, allow_duplicates: bool = False) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
expression: str,
|
||||
*,
|
||||
allow_underscore: bool = False,
|
||||
allow_duplicates: bool = False,
|
||||
) -> None:
|
||||
"""Parse the expression and store relevant metadata.
|
||||
|
||||
Args:
|
||||
@ -66,10 +74,13 @@ class ParsedExpression:
|
||||
self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = []
|
||||
if "." in expression:
|
||||
if "..." not in expression:
|
||||
raise ValueError("Expression may contain dots only inside ellipsis (...)")
|
||||
raise ValueError(
|
||||
"Expression may contain dots only inside ellipsis (...)"
|
||||
)
|
||||
if str.count(expression, "...") != 1 or str.count(expression, ".") != 3:
|
||||
raise ValueError(
|
||||
"Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor ")
|
||||
"Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor "
|
||||
)
|
||||
expression = expression.replace("...", _ellipsis)
|
||||
self.has_ellipsis = True
|
||||
|
||||
@ -78,7 +89,9 @@ class ParsedExpression:
|
||||
def add_axis_name(x: str) -> None:
|
||||
if x in self.identifiers:
|
||||
if not (allow_underscore and x == "_") and not allow_duplicates:
|
||||
raise ValueError(f"Indexing expression contains duplicate dimension '{x}'")
|
||||
raise ValueError(
|
||||
f"Indexing expression contains duplicate dimension '{x}'"
|
||||
)
|
||||
if x == _ellipsis:
|
||||
self.identifiers.add(_ellipsis)
|
||||
if bracket_group is None:
|
||||
@ -96,10 +109,14 @@ class ParsedExpression:
|
||||
else:
|
||||
pass # no need to think about 1s inside parenthesis
|
||||
return
|
||||
is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore)
|
||||
is_axis_name, reason = self.check_axis_name_return_reason(
|
||||
x, allow_underscore=allow_underscore
|
||||
)
|
||||
if not (is_number or is_axis_name):
|
||||
raise ValueError(f"Invalid axis identifier: {x}\n{reason}")
|
||||
axis_name: Union[str, AnonymousAxis] = AnonymousAxis(x) if is_number else x
|
||||
axis_name: Union[str, AnonymousAxis] = (
|
||||
AnonymousAxis(x) if is_number else x
|
||||
)
|
||||
self.identifiers.add(axis_name)
|
||||
if is_number:
|
||||
self.has_non_unitary_anonymous_axes = True
|
||||
@ -116,7 +133,9 @@ class ParsedExpression:
|
||||
current_identifier = None
|
||||
if char == "(":
|
||||
if bracket_group is not None:
|
||||
raise ValueError("Axis composition is one-level (brackets inside brackets not allowed)")
|
||||
raise ValueError(
|
||||
"Axis composition is one-level (brackets inside brackets not allowed)"
|
||||
)
|
||||
bracket_group = []
|
||||
elif char == ")":
|
||||
if bracket_group is None:
|
||||
@ -137,7 +156,9 @@ class ParsedExpression:
|
||||
add_axis_name(current_identifier)
|
||||
|
||||
@staticmethod
|
||||
def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]:
|
||||
def check_axis_name_return_reason(
|
||||
name: str, allow_underscore: bool = False
|
||||
) -> Tuple[bool, str]:
|
||||
"""Check if the given axis name is valid, and a message explaining why if not.
|
||||
|
||||
Valid axes names are python identifiers except keywords, and should not start or end with an underscore.
|
||||
@ -157,10 +178,14 @@ class ParsedExpression:
|
||||
return False, "axis name should should not start or end with underscore"
|
||||
else:
|
||||
if keyword.iskeyword(name):
|
||||
warnings.warn(f"It is discouraged to use axes names that are keywords: {name}", RuntimeWarning)
|
||||
warnings.warn(
|
||||
f"It is discouraged to use axes names that are keywords: {name}",
|
||||
RuntimeWarning,
|
||||
)
|
||||
if name in ["axis"]:
|
||||
warnings.warn(
|
||||
"It is discouraged to use 'axis' as an axis name and will raise an error in future", FutureWarning
|
||||
"It is discouraged to use 'axis' as an axis name and will raise an error in future",
|
||||
FutureWarning,
|
||||
)
|
||||
return True, ""
|
||||
|
||||
@ -178,8 +203,9 @@ class ParsedExpression:
|
||||
return is_valid
|
||||
|
||||
|
||||
|
||||
def parse_pattern(pattern: str, axes_lengths: Mapping[str, int]) -> Tuple[ParsedExpression, ParsedExpression]:
|
||||
def parse_pattern(
|
||||
pattern: str, axes_lengths: Mapping[str, int]
|
||||
) -> Tuple[ParsedExpression, ParsedExpression]:
|
||||
"""Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object.
|
||||
|
||||
Args:
|
||||
@ -203,9 +229,13 @@ def parse_pattern(pattern: str, axes_lengths: Mapping[str, int]) -> Tuple[Parsed
|
||||
right = ParsedExpression(right_str)
|
||||
|
||||
if not left.has_ellipsis and right.has_ellipsis:
|
||||
raise ValueError(f"Ellipsis found in right side, but not left side of a pattern {pattern}")
|
||||
raise ValueError(
|
||||
f"Ellipsis found in right side, but not left side of a pattern {pattern}"
|
||||
)
|
||||
if left.has_ellipsis and left.has_ellipsis_parenthesized:
|
||||
raise ValueError(f"Ellipsis is parenthesis in the left side is not allowed: {pattern}")
|
||||
raise ValueError(
|
||||
f"Ellipsis is parenthesis in the left side is not allowed: {pattern}"
|
||||
)
|
||||
|
||||
return left, right
|
||||
|
||||
@ -222,18 +252,24 @@ def validate_rearrange_expressions(
|
||||
"""
|
||||
for length in axes_lengths.values():
|
||||
if (length_type := type(length)) is not int:
|
||||
raise TypeError(f"rearrange axis lengths must be integers, got: {length_type}")
|
||||
raise TypeError(
|
||||
f"rearrange axis lengths must be integers, got: {length_type}"
|
||||
)
|
||||
|
||||
if left.has_non_unitary_anonymous_axes or right.has_non_unitary_anonymous_axes:
|
||||
raise ValueError("rearrange only supports unnamed axes of size 1")
|
||||
|
||||
difference = set.symmetric_difference(left.identifiers, right.identifiers)
|
||||
if len(difference) > 0:
|
||||
raise ValueError(f"Identifiers only on one side of rearrange expression (should be on both): {difference}")
|
||||
raise ValueError(
|
||||
f"Identifiers only on one side of rearrange expression (should be on both): {difference}"
|
||||
)
|
||||
|
||||
unmatched_axes = axes_lengths.keys() - left.identifiers
|
||||
if len(unmatched_axes) > 0:
|
||||
raise ValueError(f"Identifiers not found in rearrange expression: {unmatched_axes}")
|
||||
raise ValueError(
|
||||
f"Identifiers not found in rearrange expression: {unmatched_axes}"
|
||||
)
|
||||
|
||||
|
||||
def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str:
|
||||
@ -259,6 +295,8 @@ def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str:
|
||||
'(d0,), (), (d1,), (d2,), (d3, d4)'
|
||||
"""
|
||||
return ", ".join(
|
||||
item if isinstance(item, str) else f"({comma_separate(item)}{',' if len(item) == 1 else ''})"
|
||||
item
|
||||
if isinstance(item, str)
|
||||
else f"({comma_separate(item)}{',' if len(item) == 1 else ''})"
|
||||
for item in collection
|
||||
)
|
||||
|
@ -4,8 +4,15 @@ import functools
|
||||
from typing import Callable, Dict, List, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from functorch._C import dim as _C
|
||||
from ._parsing import AnonymousAxis, _ellipsis, comma_separate, parse_pattern, validate_rearrange_expressions
|
||||
from ._parsing import (
|
||||
_ellipsis,
|
||||
AnonymousAxis,
|
||||
comma_separate,
|
||||
parse_pattern,
|
||||
validate_rearrange_expressions,
|
||||
)
|
||||
|
||||
__all__ = ["rearrange"]
|
||||
|
||||
@ -79,10 +86,12 @@ def _create_rearrange_callable(
|
||||
dims_i += 1
|
||||
elif dimension == _ellipsis:
|
||||
identifier = _ellipsis
|
||||
identifier_dim_map[identifier] = tuple(first_class_dims[dims_i + j] for j in range(n_ellipsis_dims))
|
||||
identifier_dim_map[identifier] = tuple(
|
||||
first_class_dims[dims_i + j] for j in range(n_ellipsis_dims)
|
||||
)
|
||||
dims_i += n_ellipsis_dims
|
||||
else:
|
||||
raise ValueError(f'Unexpected dimension: {dimension}')
|
||||
raise ValueError(f"Unexpected dimension: {dimension}")
|
||||
|
||||
def composition_to_dims(
|
||||
composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]]
|
||||
@ -92,11 +101,17 @@ def _create_rearrange_callable(
|
||||
dim_composition: List[Union[str, Tuple[str, ...]]] = []
|
||||
for dimension in composition:
|
||||
if isinstance(dimension, list):
|
||||
dim_composition.append(tuple(dim for identifier in dimension for dim in identifier_dim_map[identifier]))
|
||||
dim_composition.append(
|
||||
tuple(
|
||||
dim
|
||||
for identifier in dimension
|
||||
for dim in identifier_dim_map[identifier]
|
||||
)
|
||||
)
|
||||
elif dimension == _ellipsis:
|
||||
dim_composition.extend(identifier_dim_map[_ellipsis])
|
||||
else:
|
||||
raise ValueError(f'Unexpected dimension: {dimension}')
|
||||
raise ValueError(f"Unexpected dimension: {dimension}")
|
||||
return dim_composition
|
||||
|
||||
left_dims = composition_to_dims(left.composition)
|
||||
@ -108,16 +123,22 @@ def _create_rearrange_callable(
|
||||
|
||||
custom_rearrange_callable_name = "do_rearrange"
|
||||
custom_rearrange_callable_code = (
|
||||
f"def {custom_rearrange_callable_name}(tensor):\n"
|
||||
f" {comma_separate(first_class_dims)} = dims({n_dims})\n"
|
||||
(
|
||||
f"def {custom_rearrange_callable_name}(tensor):\n"
|
||||
f" {comma_separate(first_class_dims)} = dims({n_dims})\n"
|
||||
)
|
||||
+ (
|
||||
"".join(f" {dim}.size = {length}\n" for (dim, length) in specified_lengths)
|
||||
if specified_lengths else ""
|
||||
"".join(
|
||||
f" {dim}.size = {length}\n" for (dim, length) in specified_lengths
|
||||
)
|
||||
if specified_lengths
|
||||
else ""
|
||||
)
|
||||
+ f" tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n"
|
||||
+ (
|
||||
f" return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n"
|
||||
if anon_dims else " return tensor\n"
|
||||
if anon_dims
|
||||
else " return tensor\n"
|
||||
)
|
||||
)
|
||||
|
||||
@ -126,7 +147,9 @@ def _create_rearrange_callable(
|
||||
|
||||
|
||||
def rearrange(
|
||||
tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]], pattern: str, **axes_lengths: int
|
||||
tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
|
||||
pattern: str,
|
||||
**axes_lengths: int,
|
||||
) -> torch.Tensor:
|
||||
r"""A native implementation of `einops.rearrange`, a reader-friendly smart element reordering for multidimensional
|
||||
tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
|
||||
@ -177,6 +200,8 @@ def rearrange(
|
||||
if not isinstance(tensor, torch.Tensor):
|
||||
tensor = torch.stack(tensor)
|
||||
|
||||
rearrange_callable = _create_rearrange_callable(tensor.ndim, pattern, **axes_lengths)
|
||||
rearrange_callable = _create_rearrange_callable(
|
||||
tensor.ndim, pattern, **axes_lengths
|
||||
)
|
||||
|
||||
return rearrange_callable(tensor)
|
||||
|
@ -1,7 +1,8 @@
|
||||
from functorch.compile import aot_function, tvm_compile
|
||||
import torch
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.utils
|
||||
from functorch.compile import aot_function, tvm_compile
|
||||
|
||||
a = torch.randn(2000, 1, 4, requires_grad=True)
|
||||
b = torch.randn(1, 2000, 4)
|
||||
@ -11,8 +12,8 @@ def f(a):
|
||||
return (a * b).sum(dim=0)
|
||||
|
||||
|
||||
fw_compiler = tvm_compile(target='llvm', tuning_logfile='fw_keops')
|
||||
bw_compiler = tvm_compile(target='llvm', tuning_logfile='bw_keops')
|
||||
fw_compiler = tvm_compile(target="llvm", tuning_logfile="fw_keops")
|
||||
bw_compiler = tvm_compile(target="llvm", tuning_logfile="bw_keops")
|
||||
compiled_f = aot_function(f, fw_compiler, bw_compiler)
|
||||
|
||||
# fw_compiler = lambda x, _: x
|
||||
@ -32,13 +33,15 @@ def bench(func):
|
||||
|
||||
|
||||
def bench_jax():
|
||||
import jax.numpy as jnp
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
jax_a = jnp.array(a.detach().numpy())
|
||||
jax_b = jnp.array(b.detach().numpy())
|
||||
|
||||
def f(a):
|
||||
return jnp.sin((a * jax_b).sum(axis=[0])).sum()
|
||||
|
||||
jit_f = jax.jit(jax.grad(f))
|
||||
jit_f(jax_a)
|
||||
begin = time.time()
|
||||
|
@ -1,15 +1,16 @@
|
||||
import timeit
|
||||
from functorch.compile import compiled_module, tvm_compile
|
||||
import torch.nn as nn
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functorch.compile import compiled_module, tvm_compile
|
||||
|
||||
|
||||
def nop(f, _):
|
||||
return f
|
||||
|
||||
|
||||
fw_compiler = tvm_compile(target='llvm', tuning_logfile='fw_keops')
|
||||
bw_compiler = tvm_compile(target='llvm', tuning_logfile='bw_keops')
|
||||
fw_compiler = tvm_compile(target="llvm", tuning_logfile="fw_keops")
|
||||
bw_compiler = tvm_compile(target="llvm", tuning_logfile="bw_keops")
|
||||
fw_compiler = nop
|
||||
bw_compiler = nop
|
||||
|
||||
|
@ -4,11 +4,13 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from functorch import make_functional
|
||||
from functorch.compile import nnc_jit
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import time
|
||||
from functorch import make_functional
|
||||
from functorch.compile import nnc_jit
|
||||
|
||||
torch._C._jit_override_can_fuse_on_cpu(True)
|
||||
|
||||
|
||||
@ -30,7 +32,7 @@ class Foo(nn.Module):
|
||||
self.mod = nn.Sequential(*mods)
|
||||
|
||||
def forward(self, x):
|
||||
return (self.mod(x)**2).sum()
|
||||
return (self.mod(x) ** 2).sum()
|
||||
|
||||
|
||||
batch_size = 16
|
||||
@ -54,7 +56,9 @@ def functional_step(x, weights):
|
||||
return out, new_weights
|
||||
|
||||
|
||||
optim = torch.optim.SGD(jit_mod.parameters(), lr=lr, momentum=0, dampening=0, weight_decay=0)
|
||||
optim = torch.optim.SGD(
|
||||
jit_mod.parameters(), lr=lr, momentum=0, dampening=0, weight_decay=0
|
||||
)
|
||||
|
||||
|
||||
def jit_step(x, weights):
|
||||
|
@ -4,10 +4,11 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import time
|
||||
|
||||
import torch
|
||||
from functorch import grad, make_fx
|
||||
from functorch.compile import nnc_jit
|
||||
import torch
|
||||
import time
|
||||
|
||||
|
||||
def f(x):
|
||||
|
@ -17,8 +17,8 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
import torchvision.transforms as transforms
|
||||
from torchvision import models
|
||||
from opacus import PrivacyEngine
|
||||
from torchvision import models
|
||||
from torchvision.datasets import CIFAR10
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -52,7 +52,6 @@ def train(args, model, train_loader, optimizer, privacy_engine, epoch, device):
|
||||
top1_acc = []
|
||||
|
||||
for i, (images, target) in enumerate(tqdm(train_loader)):
|
||||
|
||||
images = images.to(device)
|
||||
target = target.to(device)
|
||||
|
||||
@ -279,6 +278,7 @@ def main():
|
||||
)
|
||||
logger.info(metrics)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
|
||||
parser.add_argument(
|
||||
@ -309,7 +309,7 @@ def parse_args():
|
||||
default=256,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="mini-batch size for test dataset (default: 256)"
|
||||
help="mini-batch size for test dataset (default: 256)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
|
@ -17,12 +17,12 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from torch.func import functional_call, grad_and_value, vmap
|
||||
from torchvision import models
|
||||
from torchvision.datasets import CIFAR10
|
||||
from tqdm import tqdm
|
||||
|
||||
from torch.func import vmap, grad_and_value, functional_call
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s:%(levelname)s:%(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
@ -44,12 +44,16 @@ def accuracy(preds, labels):
|
||||
|
||||
def compute_norms(sample_grads):
|
||||
batch_size = sample_grads[0].shape[0]
|
||||
norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
|
||||
norms = [
|
||||
sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads
|
||||
]
|
||||
norms = torch.stack(norms, dim=0).norm(2, dim=0)
|
||||
return norms, batch_size
|
||||
|
||||
|
||||
def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
|
||||
def clip_and_accumulate_and_add_noise(
|
||||
model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0
|
||||
):
|
||||
sample_grads = tuple(param.grad_sample for param in model.parameters())
|
||||
|
||||
# step 0: compute the norms
|
||||
@ -60,18 +64,21 @@ def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise
|
||||
clip_factor = clip_factor.clamp(max=1.0)
|
||||
|
||||
# step 2: clip
|
||||
grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
|
||||
for sample_grad in sample_grads)
|
||||
grads = tuple(
|
||||
torch.einsum("i,i...", clip_factor, sample_grad) for sample_grad in sample_grads
|
||||
)
|
||||
|
||||
# step 3: add gaussian noise
|
||||
stddev = max_per_sample_grad_norm * noise_multiplier
|
||||
noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
|
||||
for grad_param in grads)
|
||||
noises = tuple(
|
||||
torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
|
||||
for grad_param in grads
|
||||
)
|
||||
grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))
|
||||
|
||||
# step 4: assign the new grads, delete the sample grads
|
||||
for param, param_grad in zip(model.parameters(), grads):
|
||||
param.grad = param_grad/batch_size
|
||||
param.grad = param_grad / batch_size
|
||||
del param.grad_sample
|
||||
|
||||
|
||||
@ -84,7 +91,6 @@ def train(args, model, train_loader, optimizer, epoch, device):
|
||||
top1_acc = []
|
||||
|
||||
for i, (images, target) in enumerate(tqdm(train_loader)):
|
||||
|
||||
images = images.to(device)
|
||||
target = target.to(device)
|
||||
|
||||
@ -120,8 +126,9 @@ def train(args, model, train_loader, optimizer, epoch, device):
|
||||
# detaching weights since we don't need to track gradients outside of transforms
|
||||
# and this is more performant
|
||||
detached_weights = {k: v.detach() for k, v in weights.items()}
|
||||
sample_grads, (sample_loss, output) = \
|
||||
vmap(grads_loss_output, (None, 0, 0))(detached_weights, images, target)
|
||||
sample_grads, (sample_loss, output) = vmap(grads_loss_output, (None, 0, 0))(
|
||||
detached_weights, images, target
|
||||
)
|
||||
loss = sample_loss.mean()
|
||||
|
||||
for name, grad_sample in sample_grads.items():
|
||||
@ -129,7 +136,8 @@ def train(args, model, train_loader, optimizer, epoch, device):
|
||||
|
||||
# Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
|
||||
clip_and_accumulate_and_add_noise(
|
||||
model, args.max_per_sample_grad_norm, args.sigma)
|
||||
model, args.max_per_sample_grad_norm, args.sigma
|
||||
)
|
||||
|
||||
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
|
||||
labels = target.detach().cpu().numpy()
|
||||
@ -270,9 +278,7 @@ def main():
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["lr"] = lr
|
||||
|
||||
train_duration = train(
|
||||
args, model, train_loader, optimizer, epoch, device
|
||||
)
|
||||
train_duration = train(args, model, train_loader, optimizer, epoch, device)
|
||||
top1_acc = test(args, model, test_loader, device)
|
||||
|
||||
# remember best acc@1 and save checkpoint
|
||||
@ -308,6 +314,7 @@ def main():
|
||||
)
|
||||
logger.info(metrics)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
|
||||
parser.add_argument(
|
||||
@ -338,7 +345,7 @@ def parse_args():
|
||||
default=256,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="mini-batch size for test dataset (default: 256)"
|
||||
help="mini-batch size for test dataset (default: 256)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
|
@ -1,9 +1,10 @@
|
||||
import argparse
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.func import functional_call, grad_and_value, vmap, stack_module_state
|
||||
from torch.func import functional_call, grad_and_value, stack_module_state, vmap
|
||||
|
||||
# Adapted from http://willwhitney.com/parallel-training-jax.html , which is a
|
||||
# tutorial on Model Ensembling with JAX by Will Whitney.
|
||||
@ -33,15 +34,21 @@ DEVICE = args.device
|
||||
# Step 1: Make some spirals
|
||||
|
||||
|
||||
def make_spirals(n_samples, noise_std=0., rotations=1.):
|
||||
def make_spirals(n_samples, noise_std=0.0, rotations=1.0):
|
||||
ts = torch.linspace(0, 1, n_samples, device=DEVICE)
|
||||
rs = ts ** 0.5
|
||||
rs = ts**0.5
|
||||
thetas = rs * rotations * 2 * math.pi
|
||||
signs = torch.randint(0, 2, (n_samples,), device=DEVICE) * 2 - 1
|
||||
labels = (signs > 0).to(torch.long).to(DEVICE)
|
||||
|
||||
xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std
|
||||
ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std
|
||||
xs = (
|
||||
rs * signs * torch.cos(thetas)
|
||||
+ torch.randn(n_samples, device=DEVICE) * noise_std
|
||||
)
|
||||
ys = (
|
||||
rs * signs * torch.sin(thetas)
|
||||
+ torch.randn(n_samples, device=DEVICE) * noise_std
|
||||
)
|
||||
points = torch.stack([xs, ys], dim=1)
|
||||
return points, labels
|
||||
|
||||
@ -70,6 +77,7 @@ class MLPClassifier(nn.Module):
|
||||
loss_fn = nn.NLLLoss()
|
||||
model = MLPClassifier().to(DEVICE)
|
||||
|
||||
|
||||
def train_step_fn(weights, batch, targets, lr=0.2):
|
||||
def compute_loss(weights, batch, targets):
|
||||
output = functional_call(model, weights, batch)
|
||||
@ -109,6 +117,7 @@ def init_fn(num_models):
|
||||
params, _ = stack_module_state(models)
|
||||
return params
|
||||
|
||||
|
||||
# Step 6: Now, can we try multiple models at the same time?
|
||||
# The answer is: yes! `loss` is a 2-tuple, and we can see that the value keeps
|
||||
# on decreasing
|
||||
|
@ -4,15 +4,15 @@
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.functional import mse_loss
|
||||
from torch.func import jacrev, vmap
|
||||
from torch.nn.functional import mse_loss
|
||||
|
||||
sigma = 0.5
|
||||
epsilon = 4.
|
||||
epsilon = 4.0
|
||||
|
||||
|
||||
def lennard_jones(r):
|
||||
return epsilon * ((sigma / r)**12 - (sigma / r)**6)
|
||||
return epsilon * ((sigma / r) ** 12 - (sigma / r) ** 6)
|
||||
|
||||
|
||||
def lennard_jones_force(r):
|
||||
@ -29,7 +29,9 @@ norms = torch.norm(drs, dim=1).reshape(-1, 1)
|
||||
# Create training energies
|
||||
training_energies = torch.stack(list(map(lennard_jones, norms))).reshape(-1, 1)
|
||||
# Create forces with random direction vectors
|
||||
training_forces = torch.stack([force * dr for force, dr in zip(map(lennard_jones_force, norms), drs)])
|
||||
training_forces = torch.stack(
|
||||
[force * dr for force, dr in zip(map(lennard_jones_force, norms), drs)]
|
||||
)
|
||||
|
||||
model = nn.Sequential(
|
||||
nn.Linear(1, 16),
|
||||
@ -40,7 +42,7 @@ model = nn.Sequential(
|
||||
nn.Tanh(),
|
||||
nn.Linear(16, 16),
|
||||
nn.Tanh(),
|
||||
nn.Linear(16, 1)
|
||||
nn.Linear(16, 1),
|
||||
)
|
||||
|
||||
|
||||
@ -54,7 +56,10 @@ def make_prediction(model, drs):
|
||||
|
||||
|
||||
def loss_fn(energies, forces, predicted_energies, predicted_forces):
|
||||
return mse_loss(energies, predicted_energies) + 0.01 * mse_loss(forces, predicted_forces) / 3
|
||||
return (
|
||||
mse_loss(energies, predicted_energies)
|
||||
+ 0.01 * mse_loss(forces, predicted_forces) / 3
|
||||
)
|
||||
|
||||
|
||||
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
|
@ -27,38 +27,43 @@ Our MAML++ fork and experiments are available at:
|
||||
https://github.com/bamos/HowToTrainYourMAMLPytorch
|
||||
"""
|
||||
|
||||
from support.omniglot_loaders import OmniglotNShot
|
||||
import higher
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import higher
|
||||
import matplotlib as mpl
|
||||
mpl.use('Agg')
|
||||
plt.style.use('bmh')
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from support.omniglot_loaders import OmniglotNShot
|
||||
from torch import nn
|
||||
|
||||
mpl.use("Agg")
|
||||
plt.style.use("bmh")
|
||||
|
||||
|
||||
def main():
|
||||
argparser = argparse.ArgumentParser()
|
||||
argparser.add_argument('--n-way', '--n_way', type=int, help='n way', default=5)
|
||||
argparser.add_argument("--n-way", "--n_way", type=int, help="n way", default=5)
|
||||
argparser.add_argument(
|
||||
'--k-spt', '--k_spt', type=int, help='k shot for support set', default=5)
|
||||
"--k-spt", "--k_spt", type=int, help="k shot for support set", default=5
|
||||
)
|
||||
argparser.add_argument(
|
||||
'--k-qry', '--k_qry', type=int, help='k shot for query set', default=15)
|
||||
"--k-qry", "--k_qry", type=int, help="k shot for query set", default=15
|
||||
)
|
||||
argparser.add_argument("--device", type=str, help="device", default="cuda")
|
||||
argparser.add_argument(
|
||||
'--device', type=str, help='device', default='cuda')
|
||||
argparser.add_argument(
|
||||
'--task-num', '--task_num',
|
||||
"--task-num",
|
||||
"--task_num",
|
||||
type=int,
|
||||
help='meta batch size, namely task num',
|
||||
default=32)
|
||||
argparser.add_argument('--seed', type=int, help='random seed', default=1)
|
||||
help="meta batch size, namely task num",
|
||||
default=32,
|
||||
)
|
||||
argparser.add_argument("--seed", type=int, help="random seed", default=1)
|
||||
args = argparser.parse_args()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
@ -69,7 +74,7 @@ def main():
|
||||
# Set up the Omniglot loader.
|
||||
device = args.device
|
||||
db = OmniglotNShot(
|
||||
'/tmp/omniglot-data',
|
||||
"/tmp/omniglot-data",
|
||||
batchsz=args.task_num,
|
||||
n_way=args.n_way,
|
||||
k_shot=args.k_spt,
|
||||
@ -97,7 +102,8 @@ def main():
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(2, 2),
|
||||
Flatten(),
|
||||
nn.Linear(64, args.n_way)).to(device)
|
||||
nn.Linear(64, args.n_way),
|
||||
).to(device)
|
||||
|
||||
# We will use Adam to (meta-)optimize the initial parameters
|
||||
# to be adapted.
|
||||
@ -134,9 +140,10 @@ def train(db, net, device, meta_opt, epoch, log):
|
||||
qry_accs = []
|
||||
meta_opt.zero_grad()
|
||||
for i in range(task_num):
|
||||
with higher.innerloop_ctx(
|
||||
net, inner_opt, copy_initial_weights=False
|
||||
) as (fnet, diffopt):
|
||||
with higher.innerloop_ctx(net, inner_opt, copy_initial_weights=False) as (
|
||||
fnet,
|
||||
diffopt,
|
||||
):
|
||||
# Optimize the likelihood of the support set by taking
|
||||
# gradient steps w.r.t. the model's parameters.
|
||||
# This adapts the model's meta-parameters to the task.
|
||||
@ -153,8 +160,7 @@ def train(db, net, device, meta_opt, epoch, log):
|
||||
qry_logits = fnet(x_qry[i])
|
||||
qry_loss = F.cross_entropy(qry_logits, y_qry[i])
|
||||
qry_losses.append(qry_loss.detach())
|
||||
qry_acc = (qry_logits.argmax(
|
||||
dim=1) == y_qry[i]).sum().item() / querysz
|
||||
qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
|
||||
qry_accs.append(qry_acc)
|
||||
|
||||
# print([b.shape for b in fnet[1].buffers()])
|
||||
@ -166,21 +172,23 @@ def train(db, net, device, meta_opt, epoch, log):
|
||||
|
||||
meta_opt.step()
|
||||
qry_losses = sum(qry_losses) / task_num
|
||||
qry_accs = 100. * sum(qry_accs) / task_num
|
||||
qry_accs = 100.0 * sum(qry_accs) / task_num
|
||||
i = epoch + float(batch_idx) / n_train_iter
|
||||
iter_time = time.time() - start_time
|
||||
if batch_idx % 4 == 0:
|
||||
print(
|
||||
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
|
||||
f"[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}"
|
||||
)
|
||||
|
||||
log.append({
|
||||
'epoch': i,
|
||||
'loss': qry_losses,
|
||||
'acc': qry_accs,
|
||||
'mode': 'train',
|
||||
'time': time.time(),
|
||||
})
|
||||
log.append(
|
||||
{
|
||||
"epoch": i,
|
||||
"loss": qry_losses,
|
||||
"acc": qry_accs,
|
||||
"mode": "train",
|
||||
"time": time.time(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test(db, net, device, epoch, log):
|
||||
@ -196,7 +204,7 @@ def test(db, net, device, epoch, log):
|
||||
qry_accs = []
|
||||
|
||||
for _ in range(n_test_iter):
|
||||
x_spt, y_spt, x_qry, y_qry = db.next('test')
|
||||
x_spt, y_spt, x_qry, y_qry = db.next("test")
|
||||
|
||||
task_num, setsz, c_, h, w = x_spt.size()
|
||||
|
||||
@ -206,7 +214,10 @@ def test(db, net, device, epoch, log):
|
||||
inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
|
||||
|
||||
for i in range(task_num):
|
||||
with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (fnet, diffopt):
|
||||
with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (
|
||||
fnet,
|
||||
diffopt,
|
||||
):
|
||||
# Optimize the likelihood of the support set by taking
|
||||
# gradient steps w.r.t. the model's parameters.
|
||||
# This adapts the model's meta-parameters to the task.
|
||||
@ -217,24 +228,22 @@ def test(db, net, device, epoch, log):
|
||||
|
||||
# The query loss and acc induced by these parameters.
|
||||
qry_logits = fnet(x_qry[i]).detach()
|
||||
qry_loss = F.cross_entropy(
|
||||
qry_logits, y_qry[i], reduction='none')
|
||||
qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction="none")
|
||||
qry_losses.append(qry_loss.detach())
|
||||
qry_accs.append(
|
||||
(qry_logits.argmax(dim=1) == y_qry[i]).detach())
|
||||
qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
|
||||
|
||||
qry_losses = torch.cat(qry_losses).mean().item()
|
||||
qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
|
||||
print(
|
||||
f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
|
||||
qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
|
||||
print(f"[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}")
|
||||
log.append(
|
||||
{
|
||||
"epoch": epoch + 1,
|
||||
"loss": qry_losses,
|
||||
"acc": qry_accs,
|
||||
"mode": "test",
|
||||
"time": time.time(),
|
||||
}
|
||||
)
|
||||
log.append({
|
||||
'epoch': epoch + 1,
|
||||
'loss': qry_losses,
|
||||
'acc': qry_accs,
|
||||
'mode': 'test',
|
||||
'time': time.time(),
|
||||
})
|
||||
|
||||
|
||||
def plot(log):
|
||||
@ -243,17 +252,17 @@ def plot(log):
|
||||
df = pd.DataFrame(log)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(6, 4))
|
||||
train_df = df[df['mode'] == 'train']
|
||||
test_df = df[df['mode'] == 'test']
|
||||
ax.plot(train_df['epoch'], train_df['acc'], label='Train')
|
||||
ax.plot(test_df['epoch'], test_df['acc'], label='Test')
|
||||
ax.set_xlabel('Epoch')
|
||||
ax.set_ylabel('Accuracy')
|
||||
train_df = df[df["mode"] == "train"]
|
||||
test_df = df[df["mode"] == "test"]
|
||||
ax.plot(train_df["epoch"], train_df["acc"], label="Train")
|
||||
ax.plot(test_df["epoch"], test_df["acc"], label="Test")
|
||||
ax.set_xlabel("Epoch")
|
||||
ax.set_ylabel("Accuracy")
|
||||
ax.set_ylim(70, 100)
|
||||
fig.legend(ncol=2, loc='lower right')
|
||||
fig.legend(ncol=2, loc="lower right")
|
||||
fig.tight_layout()
|
||||
fname = 'maml-accs.png'
|
||||
print(f'--- Plotting accuracy to {fname}')
|
||||
fname = "maml-accs.png"
|
||||
print(f"--- Plotting accuracy to {fname}")
|
||||
fig.savefig(fname)
|
||||
plt.close(fig)
|
||||
|
||||
@ -265,5 +274,5 @@ class Flatten(nn.Module):
|
||||
return input.view(input.size(0), -1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -27,38 +27,43 @@ Our MAML++ fork and experiments are available at:
|
||||
https://github.com/bamos/HowToTrainYourMAMLPytorch
|
||||
"""
|
||||
|
||||
from support.omniglot_loaders import OmniglotNShot
|
||||
from functorch import make_functional_with_buffers
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib as mpl
|
||||
mpl.use('Agg')
|
||||
plt.style.use('bmh')
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from functorch import make_functional_with_buffers
|
||||
from support.omniglot_loaders import OmniglotNShot
|
||||
from torch import nn
|
||||
|
||||
mpl.use("Agg")
|
||||
plt.style.use("bmh")
|
||||
|
||||
|
||||
def main():
|
||||
argparser = argparse.ArgumentParser()
|
||||
argparser.add_argument('--n-way', '--n_way', type=int, help='n way', default=5)
|
||||
argparser.add_argument("--n-way", "--n_way", type=int, help="n way", default=5)
|
||||
argparser.add_argument(
|
||||
'--k-spt', '--k_spt', type=int, help='k shot for support set', default=5)
|
||||
"--k-spt", "--k_spt", type=int, help="k shot for support set", default=5
|
||||
)
|
||||
argparser.add_argument(
|
||||
'--k-qry', '--k_qry', type=int, help='k shot for query set', default=15)
|
||||
"--k-qry", "--k_qry", type=int, help="k shot for query set", default=15
|
||||
)
|
||||
argparser.add_argument("--device", type=str, help="device", default="cuda")
|
||||
argparser.add_argument(
|
||||
'--device', type=str, help='device', default='cuda')
|
||||
argparser.add_argument(
|
||||
'--task-num', '--task_num',
|
||||
"--task-num",
|
||||
"--task_num",
|
||||
type=int,
|
||||
help='meta batch size, namely task num',
|
||||
default=32)
|
||||
argparser.add_argument('--seed', type=int, help='random seed', default=1)
|
||||
help="meta batch size, namely task num",
|
||||
default=32,
|
||||
)
|
||||
argparser.add_argument("--seed", type=int, help="random seed", default=1)
|
||||
args = argparser.parse_args()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
@ -69,7 +74,7 @@ def main():
|
||||
# Set up the Omniglot loader.
|
||||
device = args.device
|
||||
db = OmniglotNShot(
|
||||
'/tmp/omniglot-data',
|
||||
"/tmp/omniglot-data",
|
||||
batchsz=args.task_num,
|
||||
n_way=args.n_way,
|
||||
k_shot=args.k_spt,
|
||||
@ -97,7 +102,8 @@ def main():
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(2, 2),
|
||||
Flatten(),
|
||||
nn.Linear(64, args.n_way)).to(device)
|
||||
nn.Linear(64, args.n_way),
|
||||
).to(device)
|
||||
|
||||
net.train()
|
||||
fnet, params, buffers = make_functional_with_buffers(net)
|
||||
@ -153,8 +159,7 @@ def train(db, net, device, meta_opt, epoch, log):
|
||||
qry_logits = fnet(new_params, buffers, x_qry[i])
|
||||
qry_loss = F.cross_entropy(qry_logits, y_qry[i])
|
||||
qry_losses.append(qry_loss.detach())
|
||||
qry_acc = (qry_logits.argmax(
|
||||
dim=1) == y_qry[i]).sum().item() / querysz
|
||||
qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
|
||||
qry_accs.append(qry_acc)
|
||||
|
||||
# Update the model's meta-parameters to optimize the query
|
||||
@ -164,21 +169,23 @@ def train(db, net, device, meta_opt, epoch, log):
|
||||
|
||||
meta_opt.step()
|
||||
qry_losses = sum(qry_losses) / task_num
|
||||
qry_accs = 100. * sum(qry_accs) / task_num
|
||||
qry_accs = 100.0 * sum(qry_accs) / task_num
|
||||
i = epoch + float(batch_idx) / n_train_iter
|
||||
iter_time = time.time() - start_time
|
||||
if batch_idx % 4 == 0:
|
||||
print(
|
||||
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
|
||||
f"[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}"
|
||||
)
|
||||
|
||||
log.append({
|
||||
'epoch': i,
|
||||
'loss': qry_losses,
|
||||
'acc': qry_accs,
|
||||
'mode': 'train',
|
||||
'time': time.time(),
|
||||
})
|
||||
log.append(
|
||||
{
|
||||
"epoch": i,
|
||||
"loss": qry_losses,
|
||||
"acc": qry_accs,
|
||||
"mode": "train",
|
||||
"time": time.time(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test(db, net, device, epoch, log):
|
||||
@ -194,7 +201,7 @@ def test(db, net, device, epoch, log):
|
||||
qry_accs = []
|
||||
|
||||
for batch_idx in range(n_test_iter):
|
||||
x_spt, y_spt, x_qry, y_qry = db.next('test')
|
||||
x_spt, y_spt, x_qry, y_qry = db.next("test")
|
||||
task_num, setsz, c_, h, w = x_spt.size()
|
||||
|
||||
# TODO: Maybe pull this out into a separate module so it
|
||||
@ -211,24 +218,22 @@ def test(db, net, device, epoch, log):
|
||||
|
||||
# The query loss and acc induced by these parameters.
|
||||
qry_logits = fnet(new_params, buffers, x_qry[i]).detach()
|
||||
qry_loss = F.cross_entropy(
|
||||
qry_logits, y_qry[i], reduction='none')
|
||||
qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction="none")
|
||||
qry_losses.append(qry_loss.detach())
|
||||
qry_accs.append(
|
||||
(qry_logits.argmax(dim=1) == y_qry[i]).detach())
|
||||
qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
|
||||
|
||||
qry_losses = torch.cat(qry_losses).mean().item()
|
||||
qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
|
||||
print(
|
||||
f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
|
||||
qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
|
||||
print(f"[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}")
|
||||
log.append(
|
||||
{
|
||||
"epoch": epoch + 1,
|
||||
"loss": qry_losses,
|
||||
"acc": qry_accs,
|
||||
"mode": "test",
|
||||
"time": time.time(),
|
||||
}
|
||||
)
|
||||
log.append({
|
||||
'epoch': epoch + 1,
|
||||
'loss': qry_losses,
|
||||
'acc': qry_accs,
|
||||
'mode': 'test',
|
||||
'time': time.time(),
|
||||
})
|
||||
|
||||
|
||||
def plot(log):
|
||||
@ -237,17 +242,17 @@ def plot(log):
|
||||
df = pd.DataFrame(log)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(6, 4))
|
||||
train_df = df[df['mode'] == 'train']
|
||||
test_df = df[df['mode'] == 'test']
|
||||
ax.plot(train_df['epoch'], train_df['acc'], label='Train')
|
||||
ax.plot(test_df['epoch'], test_df['acc'], label='Test')
|
||||
ax.set_xlabel('Epoch')
|
||||
ax.set_ylabel('Accuracy')
|
||||
train_df = df[df["mode"] == "train"]
|
||||
test_df = df[df["mode"] == "test"]
|
||||
ax.plot(train_df["epoch"], train_df["acc"], label="Train")
|
||||
ax.plot(test_df["epoch"], test_df["acc"], label="Test")
|
||||
ax.set_xlabel("Epoch")
|
||||
ax.set_ylabel("Accuracy")
|
||||
ax.set_ylim(70, 100)
|
||||
fig.legend(ncol=2, loc='lower right')
|
||||
fig.legend(ncol=2, loc="lower right")
|
||||
fig.tight_layout()
|
||||
fname = 'maml-accs.png'
|
||||
print(f'--- Plotting accuracy to {fname}')
|
||||
fname = "maml-accs.png"
|
||||
print(f"--- Plotting accuracy to {fname}")
|
||||
fig.savefig(fname)
|
||||
plt.close(fig)
|
||||
|
||||
@ -259,5 +264,5 @@ class Flatten(nn.Module):
|
||||
return input.view(input.size(0), -1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -27,39 +27,44 @@ Our MAML++ fork and experiments are available at:
|
||||
https://github.com/bamos/HowToTrainYourMAMLPytorch
|
||||
"""
|
||||
|
||||
from support.omniglot_loaders import OmniglotNShot
|
||||
from torch.func import vmap, grad, functional_call
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
import time
|
||||
import functools
|
||||
import time
|
||||
|
||||
import matplotlib as mpl
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib as mpl
|
||||
mpl.use('Agg')
|
||||
plt.style.use('bmh')
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from support.omniglot_loaders import OmniglotNShot
|
||||
from torch import nn
|
||||
from torch.func import functional_call, grad, vmap
|
||||
|
||||
mpl.use("Agg")
|
||||
plt.style.use("bmh")
|
||||
|
||||
|
||||
def main():
|
||||
argparser = argparse.ArgumentParser()
|
||||
argparser.add_argument('--n-way', '--n_way', type=int, help='n way', default=5)
|
||||
argparser.add_argument("--n-way", "--n_way", type=int, help="n way", default=5)
|
||||
argparser.add_argument(
|
||||
'--k-spt', '--k_spt', type=int, help='k shot for support set', default=5)
|
||||
"--k-spt", "--k_spt", type=int, help="k shot for support set", default=5
|
||||
)
|
||||
argparser.add_argument(
|
||||
'--k-qry', '--k_qry', type=int, help='k shot for query set', default=15)
|
||||
"--k-qry", "--k_qry", type=int, help="k shot for query set", default=15
|
||||
)
|
||||
argparser.add_argument("--device", type=str, help="device", default="cuda")
|
||||
argparser.add_argument(
|
||||
'--device', type=str, help='device', default='cuda')
|
||||
argparser.add_argument(
|
||||
'--task-num', '--task_num',
|
||||
"--task-num",
|
||||
"--task_num",
|
||||
type=int,
|
||||
help='meta batch size, namely task num',
|
||||
default=32)
|
||||
argparser.add_argument('--seed', type=int, help='random seed', default=1)
|
||||
help="meta batch size, namely task num",
|
||||
default=32,
|
||||
)
|
||||
argparser.add_argument("--seed", type=int, help="random seed", default=1)
|
||||
args = argparser.parse_args()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
@ -70,7 +75,7 @@ def main():
|
||||
# Set up the Omniglot loader.
|
||||
device = args.device
|
||||
db = OmniglotNShot(
|
||||
'/tmp/omniglot-data',
|
||||
"/tmp/omniglot-data",
|
||||
batchsz=args.task_num,
|
||||
n_way=args.n_way,
|
||||
k_shot=args.k_spt,
|
||||
@ -95,7 +100,8 @@ def main():
|
||||
nn.ReLU(inplace=inplace_relu),
|
||||
nn.MaxPool2d(2, 2),
|
||||
nn.Flatten(),
|
||||
nn.Linear(64, args.n_way)).to(device)
|
||||
nn.Linear(64, args.n_way),
|
||||
).to(device)
|
||||
|
||||
net.train()
|
||||
|
||||
@ -132,8 +138,7 @@ def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
|
||||
# These will be used to update the model's meta-parameters.
|
||||
qry_logits = functional_call(net, (new_params, buffers), x_qry)
|
||||
qry_loss = F.cross_entropy(qry_logits, y_qry)
|
||||
qry_acc = (qry_logits.argmax(
|
||||
dim=1) == y_qry).sum() / querysz
|
||||
qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz
|
||||
|
||||
return qry_loss, qry_acc
|
||||
|
||||
@ -163,21 +168,23 @@ def train(db, net, device, meta_opt, epoch, log):
|
||||
|
||||
meta_opt.step()
|
||||
qry_losses = qry_losses.detach().sum() / task_num
|
||||
qry_accs = 100. * qry_accs.sum() / task_num
|
||||
qry_accs = 100.0 * qry_accs.sum() / task_num
|
||||
i = epoch + float(batch_idx) / n_train_iter
|
||||
iter_time = time.time() - start_time
|
||||
if batch_idx % 4 == 0:
|
||||
print(
|
||||
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
|
||||
f"[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}"
|
||||
)
|
||||
|
||||
log.append({
|
||||
'epoch': i,
|
||||
'loss': qry_losses,
|
||||
'acc': qry_accs,
|
||||
'mode': 'train',
|
||||
'time': time.time(),
|
||||
})
|
||||
log.append(
|
||||
{
|
||||
"epoch": i,
|
||||
"loss": qry_losses,
|
||||
"acc": qry_accs,
|
||||
"mode": "train",
|
||||
"time": time.time(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test(db, net, device, epoch, log):
|
||||
@ -194,7 +201,7 @@ def test(db, net, device, epoch, log):
|
||||
qry_accs = []
|
||||
|
||||
for batch_idx in range(n_test_iter):
|
||||
x_spt, y_spt, x_qry, y_qry = db.next('test')
|
||||
x_spt, y_spt, x_qry, y_qry = db.next("test")
|
||||
task_num, setsz, c_, h, w = x_spt.size()
|
||||
|
||||
# TODO: Maybe pull this out into a separate module so it
|
||||
@ -207,28 +214,28 @@ def test(db, net, device, epoch, log):
|
||||
spt_logits = functional_call(net, (new_params, buffers), x_spt[i])
|
||||
spt_loss = F.cross_entropy(spt_logits, y_spt[i])
|
||||
grads = torch.autograd.grad(spt_loss, new_params.values())
|
||||
new_params = {k: new_params[k] - g * 1e-1 for k, g, in zip(new_params, grads)}
|
||||
new_params = {
|
||||
k: new_params[k] - g * 1e-1 for k, g, in zip(new_params, grads)
|
||||
}
|
||||
|
||||
# The query loss and acc induced by these parameters.
|
||||
qry_logits = functional_call(net, (new_params, buffers), x_qry[i]).detach()
|
||||
qry_loss = F.cross_entropy(
|
||||
qry_logits, y_qry[i], reduction='none')
|
||||
qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction="none")
|
||||
qry_losses.append(qry_loss.detach())
|
||||
qry_accs.append(
|
||||
(qry_logits.argmax(dim=1) == y_qry[i]).detach())
|
||||
qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
|
||||
|
||||
qry_losses = torch.cat(qry_losses).mean().item()
|
||||
qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
|
||||
print(
|
||||
f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
|
||||
qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
|
||||
print(f"[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}")
|
||||
log.append(
|
||||
{
|
||||
"epoch": epoch + 1,
|
||||
"loss": qry_losses,
|
||||
"acc": qry_accs,
|
||||
"mode": "test",
|
||||
"time": time.time(),
|
||||
}
|
||||
)
|
||||
log.append({
|
||||
'epoch': epoch + 1,
|
||||
'loss': qry_losses,
|
||||
'acc': qry_accs,
|
||||
'mode': 'test',
|
||||
'time': time.time(),
|
||||
})
|
||||
|
||||
|
||||
def plot(log):
|
||||
@ -237,20 +244,20 @@ def plot(log):
|
||||
df = pd.DataFrame(log)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(6, 4))
|
||||
train_df = df[df['mode'] == 'train']
|
||||
test_df = df[df['mode'] == 'test']
|
||||
ax.plot(train_df['epoch'], train_df['acc'], label='Train')
|
||||
ax.plot(test_df['epoch'], test_df['acc'], label='Test')
|
||||
ax.set_xlabel('Epoch')
|
||||
ax.set_ylabel('Accuracy')
|
||||
train_df = df[df["mode"] == "train"]
|
||||
test_df = df[df["mode"] == "test"]
|
||||
ax.plot(train_df["epoch"], train_df["acc"], label="Train")
|
||||
ax.plot(test_df["epoch"], test_df["acc"], label="Test")
|
||||
ax.set_xlabel("Epoch")
|
||||
ax.set_ylabel("Accuracy")
|
||||
ax.set_ylim(70, 100)
|
||||
fig.legend(ncol=2, loc='lower right')
|
||||
fig.legend(ncol=2, loc="lower right")
|
||||
fig.tight_layout()
|
||||
fname = 'maml-accs.png'
|
||||
print(f'--- Plotting accuracy to {fname}')
|
||||
fname = "maml-accs.png"
|
||||
print(f"--- Plotting accuracy to {fname}")
|
||||
fig.savefig(fname)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -17,38 +17,38 @@
|
||||
# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot.py
|
||||
# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglotNShot.py
|
||||
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
import errno
|
||||
import os
|
||||
import os.path
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
import os
|
||||
import os.path
|
||||
import errno
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class Omniglot(data.Dataset):
|
||||
urls = [
|
||||
'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
|
||||
'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
|
||||
"https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip",
|
||||
"https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip",
|
||||
]
|
||||
raw_folder = 'raw'
|
||||
processed_folder = 'processed'
|
||||
training_file = 'training.pt'
|
||||
test_file = 'test.pt'
|
||||
raw_folder = "raw"
|
||||
processed_folder = "processed"
|
||||
training_file = "training.pt"
|
||||
test_file = "test.pt"
|
||||
|
||||
'''
|
||||
"""
|
||||
The items are (filename,category). The index of all the categories can be found in self.idx_classes
|
||||
Args:
|
||||
- root: the directory where the dataset will be stored
|
||||
- transform: how to transform the input
|
||||
- target_transform: how to transform the target
|
||||
- download: need to download the dataset
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, root, transform=None, target_transform=None,
|
||||
download=False):
|
||||
def __init__(self, root, transform=None, target_transform=None, download=False):
|
||||
self.root = root
|
||||
self.transform = transform
|
||||
self.target_transform = target_transform
|
||||
@ -57,14 +57,16 @@ class Omniglot(data.Dataset):
|
||||
if download:
|
||||
self.download()
|
||||
else:
|
||||
raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')
|
||||
raise RuntimeError(
|
||||
"Dataset not found." + " You can use download=True to download it"
|
||||
)
|
||||
|
||||
self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
|
||||
self.idx_classes = index_classes(self.all_items)
|
||||
|
||||
def __getitem__(self, index):
|
||||
filename = self.all_items[index][0]
|
||||
img = str.join('/', [self.all_items[index][2], filename])
|
||||
img = str.join("/", [self.all_items[index][2], filename])
|
||||
|
||||
target = self.idx_classes[self.all_items[index][1]]
|
||||
if self.transform is not None:
|
||||
@ -78,8 +80,11 @@ class Omniglot(data.Dataset):
|
||||
return len(self.all_items)
|
||||
|
||||
def _check_exists(self):
|
||||
return os.path.exists(os.path.join(self.root, self.processed_folder, "images_evaluation")) and \
|
||||
os.path.exists(os.path.join(self.root, self.processed_folder, "images_background"))
|
||||
return os.path.exists(
|
||||
os.path.join(self.root, self.processed_folder, "images_evaluation")
|
||||
) and os.path.exists(
|
||||
os.path.join(self.root, self.processed_folder, "images_background")
|
||||
)
|
||||
|
||||
def download(self):
|
||||
import urllib
|
||||
@ -99,15 +104,15 @@ class Omniglot(data.Dataset):
|
||||
raise
|
||||
|
||||
for url in self.urls:
|
||||
print('== Downloading ' + url)
|
||||
print("== Downloading " + url)
|
||||
data = urllib.request.urlopen(url)
|
||||
filename = url.rpartition('/')[2]
|
||||
filename = url.rpartition("/")[2]
|
||||
file_path = os.path.join(self.root, self.raw_folder, filename)
|
||||
with open(file_path, 'wb') as f:
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(data.read())
|
||||
file_processed = os.path.join(self.root, self.processed_folder)
|
||||
print("== Unzip from " + file_path + " to " + file_processed)
|
||||
zip_ref = zipfile.ZipFile(file_path, 'r')
|
||||
zip_ref = zipfile.ZipFile(file_path, "r")
|
||||
zip_ref.extractall(file_processed)
|
||||
zip_ref.close()
|
||||
print("Download finished.")
|
||||
@ -115,10 +120,10 @@ class Omniglot(data.Dataset):
|
||||
|
||||
def find_classes(root_dir):
|
||||
retour = []
|
||||
for (root, dirs, files) in os.walk(root_dir):
|
||||
for root, dirs, files in os.walk(root_dir):
|
||||
for f in files:
|
||||
if (f.endswith("png")):
|
||||
r = root.split('/')
|
||||
if f.endswith("png"):
|
||||
r = root.split("/")
|
||||
lr = len(r)
|
||||
retour.append((f, r[lr - 2] + "/" + r[lr - 1], root))
|
||||
print(f"== Found {len(retour)} items ")
|
||||
@ -135,7 +140,6 @@ def index_classes(items):
|
||||
|
||||
|
||||
class OmniglotNShot:
|
||||
|
||||
def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, device=None):
|
||||
"""
|
||||
Different from mnistNShot, the
|
||||
@ -149,41 +153,52 @@ class OmniglotNShot:
|
||||
|
||||
self.resize = imgsz
|
||||
self.device = device
|
||||
if not os.path.isfile(os.path.join(root, 'omniglot.npy')):
|
||||
if not os.path.isfile(os.path.join(root, "omniglot.npy")):
|
||||
# if root/data.npy does not exist, just download it
|
||||
self.x = Omniglot(
|
||||
root, download=True,
|
||||
root,
|
||||
download=True,
|
||||
transform=transforms.Compose(
|
||||
[lambda x: Image.open(x).convert('L'),
|
||||
lambda x: x.resize((imgsz, imgsz)),
|
||||
lambda x: np.reshape(x, (imgsz, imgsz, 1)),
|
||||
lambda x: np.transpose(x, [2, 0, 1]),
|
||||
lambda x: x / 255.]),
|
||||
[
|
||||
lambda x: Image.open(x).convert("L"),
|
||||
lambda x: x.resize((imgsz, imgsz)),
|
||||
lambda x: np.reshape(x, (imgsz, imgsz, 1)),
|
||||
lambda x: np.transpose(x, [2, 0, 1]),
|
||||
lambda x: x / 255.0,
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
temp = {} # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label}
|
||||
for (img, label) in self.x:
|
||||
temp = (
|
||||
{}
|
||||
) # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label}
|
||||
for img, label in self.x:
|
||||
if label in temp.keys():
|
||||
temp[label].append(img)
|
||||
else:
|
||||
temp[label] = [img]
|
||||
|
||||
self.x = []
|
||||
for label, imgs in temp.items(): # labels info deserted , each label contains 20imgs
|
||||
for (
|
||||
label,
|
||||
imgs,
|
||||
) in temp.items(): # labels info deserted , each label contains 20imgs
|
||||
self.x.append(np.array(imgs))
|
||||
|
||||
# as different class may have different number of imgs
|
||||
self.x = np.array(self.x).astype(np.float) # [[20 imgs],..., 1623 classes in total]
|
||||
self.x = np.array(self.x).astype(
|
||||
np.float
|
||||
) # [[20 imgs],..., 1623 classes in total]
|
||||
# each character contains 20 imgs
|
||||
print('data shape:', self.x.shape) # [1623, 20, 84, 84, 1]
|
||||
print("data shape:", self.x.shape) # [1623, 20, 84, 84, 1]
|
||||
temp = [] # Free memory
|
||||
# save all dataset into npy file.
|
||||
np.save(os.path.join(root, 'omniglot.npy'), self.x)
|
||||
print('write into omniglot.npy.')
|
||||
np.save(os.path.join(root, "omniglot.npy"), self.x)
|
||||
print("write into omniglot.npy.")
|
||||
else:
|
||||
# if data.npy exists, just load it.
|
||||
self.x = np.load(os.path.join(root, 'omniglot.npy'))
|
||||
print('load from omniglot.npy.')
|
||||
self.x = np.load(os.path.join(root, "omniglot.npy"))
|
||||
print("load from omniglot.npy.")
|
||||
|
||||
# [1623, 20, 84, 84, 1]
|
||||
# TODO: can not shuffle here, we must keep training and test set distinct!
|
||||
@ -200,11 +215,18 @@ class OmniglotNShot:
|
||||
|
||||
# save pointer of current read batch in total cache
|
||||
self.indexes = {"train": 0, "test": 0}
|
||||
self.datasets = {"train": self.x_train, "test": self.x_test} # original data cached
|
||||
self.datasets = {
|
||||
"train": self.x_train,
|
||||
"test": self.x_test,
|
||||
} # original data cached
|
||||
print("DB: train", self.x_train.shape, "test", self.x_test.shape)
|
||||
|
||||
self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]), # current epoch data cached
|
||||
"test": self.load_data_cache(self.datasets["test"])}
|
||||
self.datasets_cache = {
|
||||
"train": self.load_data_cache(
|
||||
self.datasets["train"]
|
||||
), # current epoch data cached
|
||||
"test": self.load_data_cache(self.datasets["test"]),
|
||||
}
|
||||
|
||||
def normalization(self):
|
||||
"""
|
||||
@ -238,29 +260,32 @@ class OmniglotNShot:
|
||||
|
||||
# print('preload next 50 caches of batchsz of batch.')
|
||||
for sample in range(10): # num of episodes
|
||||
|
||||
x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
|
||||
for i in range(self.batchsz): # one batch means one set
|
||||
|
||||
x_spt, y_spt, x_qry, y_qry = [], [], [], []
|
||||
selected_cls = np.random.choice(data_pack.shape[0], self.n_way, False)
|
||||
|
||||
for j, cur_class in enumerate(selected_cls):
|
||||
|
||||
selected_img = np.random.choice(20, self.k_shot + self.k_query, False)
|
||||
selected_img = np.random.choice(
|
||||
20, self.k_shot + self.k_query, False
|
||||
)
|
||||
|
||||
# meta-training and meta-test
|
||||
x_spt.append(data_pack[cur_class][selected_img[:self.k_shot]])
|
||||
x_qry.append(data_pack[cur_class][selected_img[self.k_shot:]])
|
||||
x_spt.append(data_pack[cur_class][selected_img[: self.k_shot]])
|
||||
x_qry.append(data_pack[cur_class][selected_img[self.k_shot :]])
|
||||
y_spt.append([j for _ in range(self.k_shot)])
|
||||
y_qry.append([j for _ in range(self.k_query)])
|
||||
|
||||
# shuffle inside a batch
|
||||
perm = np.random.permutation(self.n_way * self.k_shot)
|
||||
x_spt = np.array(x_spt).reshape(self.n_way * self.k_shot, 1, self.resize, self.resize)[perm]
|
||||
x_spt = np.array(x_spt).reshape(
|
||||
self.n_way * self.k_shot, 1, self.resize, self.resize
|
||||
)[perm]
|
||||
y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]
|
||||
perm = np.random.permutation(self.n_way * self.k_query)
|
||||
x_qry = np.array(x_qry).reshape(self.n_way * self.k_query, 1, self.resize, self.resize)[perm]
|
||||
x_qry = np.array(x_qry).reshape(
|
||||
self.n_way * self.k_query, 1, self.resize, self.resize
|
||||
)[perm]
|
||||
y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]
|
||||
|
||||
# append [sptsz, 1, 84, 84] => [b, setsz, 1, 84, 84]
|
||||
@ -270,22 +295,30 @@ class OmniglotNShot:
|
||||
y_qrys.append(y_qry)
|
||||
|
||||
# [b, setsz, 1, 84, 84]
|
||||
x_spts = np.array(x_spts).astype(np.float32).reshape(self.batchsz, setsz, 1, self.resize, self.resize)
|
||||
x_spts = (
|
||||
np.array(x_spts)
|
||||
.astype(np.float32)
|
||||
.reshape(self.batchsz, setsz, 1, self.resize, self.resize)
|
||||
)
|
||||
y_spts = np.array(y_spts).astype(int).reshape(self.batchsz, setsz)
|
||||
# [b, qrysz, 1, 84, 84]
|
||||
x_qrys = np.array(x_qrys).astype(np.float32).reshape(self.batchsz, querysz, 1, self.resize, self.resize)
|
||||
x_qrys = (
|
||||
np.array(x_qrys)
|
||||
.astype(np.float32)
|
||||
.reshape(self.batchsz, querysz, 1, self.resize, self.resize)
|
||||
)
|
||||
y_qrys = np.array(y_qrys).astype(int).reshape(self.batchsz, querysz)
|
||||
|
||||
x_spts, y_spts, x_qrys, y_qrys = (
|
||||
torch.from_numpy(z).to(self.device) for z in
|
||||
[x_spts, y_spts, x_qrys, y_qrys]
|
||||
torch.from_numpy(z).to(self.device)
|
||||
for z in [x_spts, y_spts, x_qrys, y_qrys]
|
||||
)
|
||||
|
||||
data_cache.append([x_spts, y_spts, x_qrys, y_qrys])
|
||||
|
||||
return data_cache
|
||||
|
||||
def next(self, mode='train'):
|
||||
def next(self, mode="train"):
|
||||
"""
|
||||
Gets next batch from the dataset with name.
|
||||
:param mode: The name of the splitting (one of "train", "val", "test")
|
||||
|
@ -2,13 +2,15 @@
|
||||
# (https://github.com/ericjang/maml-jax).
|
||||
# We translated his implementation from JAX to PyTorch.
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import math
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.nn import functional as F
|
||||
|
||||
import matplotlib as mpl
|
||||
mpl.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
mpl.use("Agg")
|
||||
|
||||
|
||||
def net(x, params):
|
||||
@ -23,13 +25,15 @@ def net(x, params):
|
||||
|
||||
|
||||
params = [
|
||||
torch.Tensor(40, 1).uniform_(-1., 1.).requires_grad_(),
|
||||
torch.Tensor(40, 1).uniform_(-1.0, 1.0).requires_grad_(),
|
||||
torch.Tensor(40).zero_().requires_grad_(),
|
||||
|
||||
torch.Tensor(40, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(),
|
||||
torch.Tensor(40, 40)
|
||||
.uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40))
|
||||
.requires_grad_(),
|
||||
torch.Tensor(40).zero_().requires_grad_(),
|
||||
|
||||
torch.Tensor(1, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(),
|
||||
torch.Tensor(1, 40)
|
||||
.uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40))
|
||||
.requires_grad_(),
|
||||
torch.Tensor(1).zero_().requires_grad_(),
|
||||
]
|
||||
|
||||
@ -46,17 +50,18 @@ def sample_tasks(outer_batch_size, inner_batch_size):
|
||||
As = []
|
||||
phases = []
|
||||
for _ in range(outer_batch_size):
|
||||
As.append(np.random.uniform(low=0.1, high=.5))
|
||||
phases.append(np.random.uniform(low=0., high=np.pi))
|
||||
As.append(np.random.uniform(low=0.1, high=0.5))
|
||||
phases.append(np.random.uniform(low=0.0, high=np.pi))
|
||||
|
||||
def get_batch():
|
||||
xs, ys = [], []
|
||||
for A, phase in zip(As, phases):
|
||||
x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
|
||||
x = np.random.uniform(low=-5.0, high=5.0, size=(inner_batch_size, 1))
|
||||
y = A * np.sin(x + phase)
|
||||
xs.append(x)
|
||||
ys.append(y)
|
||||
return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)
|
||||
|
||||
x1, y1 = get_batch()
|
||||
x2, y2 = get_batch()
|
||||
return x1, y1, x2, y2
|
||||
@ -80,14 +85,17 @@ for it in range(20000):
|
||||
return F.mse_loss(v_f, y2)
|
||||
|
||||
task = sample_tasks(num_tasks, K)
|
||||
inner_losses = [get_loss_for_task(task[0][i], task[1][i], task[2][i], task[3][i]) for i in range(num_tasks)]
|
||||
inner_losses = [
|
||||
get_loss_for_task(task[0][i], task[1][i], task[2][i], task[3][i])
|
||||
for i in range(num_tasks)
|
||||
]
|
||||
loss2 = sum(inner_losses) / len(inner_losses)
|
||||
loss2.backward()
|
||||
|
||||
opt.step()
|
||||
|
||||
if it % 100 == 0:
|
||||
print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))
|
||||
print("Iteration %d -- Outer Loss: %.4f" % (it, loss2))
|
||||
losses.append(loss2.detach())
|
||||
|
||||
t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
|
||||
@ -112,11 +120,11 @@ test_y = t_A * torch.sin(test_x + t_b)
|
||||
|
||||
test_f = net(test_x, t_params)
|
||||
|
||||
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
|
||||
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
|
||||
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
|
||||
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label="sin(x)")
|
||||
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label="net(x)")
|
||||
plt.plot(t_x.data.numpy(), t_y.data.numpy(), "o", label="Examples")
|
||||
plt.legend()
|
||||
plt.savefig('maml-sine.png')
|
||||
plt.savefig("maml-sine.png")
|
||||
plt.figure()
|
||||
plt.plot(np.convolve(losses, [.05] * 20))
|
||||
plt.savefig('losses.png')
|
||||
plt.plot(np.convolve(losses, [0.05] * 20))
|
||||
plt.savefig("losses.png")
|
||||
|
@ -2,14 +2,16 @@
|
||||
# (https://github.com/ericjang/maml-jax).
|
||||
# We translated his implementation from JAX to PyTorch.
|
||||
|
||||
from torch.func import grad, vmap
|
||||
import matplotlib.pyplot as plt
|
||||
import math
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.nn import functional as F
|
||||
|
||||
import matplotlib as mpl
|
||||
mpl.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.func import grad, vmap
|
||||
from torch.nn import functional as F
|
||||
|
||||
mpl.use("Agg")
|
||||
|
||||
|
||||
def net(params, x):
|
||||
@ -24,13 +26,15 @@ def net(params, x):
|
||||
|
||||
|
||||
params = [
|
||||
torch.Tensor(40, 1).uniform_(-1., 1.).requires_grad_(),
|
||||
torch.Tensor(40, 1).uniform_(-1.0, 1.0).requires_grad_(),
|
||||
torch.Tensor(40).zero_().requires_grad_(),
|
||||
|
||||
torch.Tensor(40, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(),
|
||||
torch.Tensor(40, 40)
|
||||
.uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40))
|
||||
.requires_grad_(),
|
||||
torch.Tensor(40).zero_().requires_grad_(),
|
||||
|
||||
torch.Tensor(1, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(),
|
||||
torch.Tensor(1, 40)
|
||||
.uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40))
|
||||
.requires_grad_(),
|
||||
torch.Tensor(1).zero_().requires_grad_(),
|
||||
]
|
||||
|
||||
@ -54,17 +58,18 @@ def sample_tasks(outer_batch_size, inner_batch_size):
|
||||
As = []
|
||||
phases = []
|
||||
for _ in range(outer_batch_size):
|
||||
As.append(np.random.uniform(low=0.1, high=.5))
|
||||
phases.append(np.random.uniform(low=0., high=np.pi))
|
||||
As.append(np.random.uniform(low=0.1, high=0.5))
|
||||
phases.append(np.random.uniform(low=0.0, high=np.pi))
|
||||
|
||||
def get_batch():
|
||||
xs, ys = [], []
|
||||
for A, phase in zip(As, phases):
|
||||
x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
|
||||
x = np.random.uniform(low=-5.0, high=5.0, size=(inner_batch_size, 1))
|
||||
y = A * np.sin(x + phase)
|
||||
xs.append(x)
|
||||
ys.append(y)
|
||||
return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)
|
||||
|
||||
x1, y1 = get_batch()
|
||||
x2, y2 = get_batch()
|
||||
return x1, y1, x2, y2
|
||||
@ -94,7 +99,7 @@ for it in range(20000):
|
||||
opt.step()
|
||||
|
||||
if it % 100 == 0:
|
||||
print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))
|
||||
print("Iteration %d -- Outer Loss: %.4f" % (it, loss2))
|
||||
losses.append(loss2.detach())
|
||||
|
||||
t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
|
||||
@ -119,11 +124,11 @@ test_y = t_A * torch.sin(test_x + t_b)
|
||||
|
||||
test_f = net(t_params, test_x)
|
||||
|
||||
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
|
||||
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
|
||||
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
|
||||
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label="sin(x)")
|
||||
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label="net(x)")
|
||||
plt.plot(t_x.data.numpy(), t_y.data.numpy(), "o", label="Examples")
|
||||
plt.legend()
|
||||
plt.savefig('maml-sine.png')
|
||||
plt.savefig("maml-sine.png")
|
||||
plt.figure()
|
||||
plt.plot(np.convolve(losses, [.05] * 20))
|
||||
plt.savefig('losses.png')
|
||||
plt.plot(np.convolve(losses, [0.05] * 20))
|
||||
plt.savefig("losses.png")
|
||||
|
@ -2,15 +2,17 @@
|
||||
# (https://github.com/ericjang/maml-jax).
|
||||
# We translated his implementation from JAX to PyTorch.
|
||||
|
||||
from functorch import grad, vmap, make_functional
|
||||
import matplotlib.pyplot as plt
|
||||
import math
|
||||
import torch
|
||||
|
||||
import matplotlib as mpl
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from functorch import grad, make_functional, vmap
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
import matplotlib as mpl
|
||||
mpl.use('Agg')
|
||||
|
||||
mpl.use("Agg")
|
||||
|
||||
|
||||
class ThreeLayerNet(nn.Module):
|
||||
@ -30,6 +32,7 @@ class ThreeLayerNet(nn.Module):
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
|
||||
# TODO: Use F.mse_loss
|
||||
|
||||
|
||||
@ -51,17 +54,18 @@ def sample_tasks(outer_batch_size, inner_batch_size):
|
||||
As = []
|
||||
phases = []
|
||||
for _ in range(outer_batch_size):
|
||||
As.append(np.random.uniform(low=0.1, high=.5))
|
||||
phases.append(np.random.uniform(low=0., high=np.pi))
|
||||
As.append(np.random.uniform(low=0.1, high=0.5))
|
||||
phases.append(np.random.uniform(low=0.0, high=np.pi))
|
||||
|
||||
def get_batch():
|
||||
xs, ys = [], []
|
||||
for A, phase in zip(As, phases):
|
||||
x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
|
||||
x = np.random.uniform(low=-5.0, high=5.0, size=(inner_batch_size, 1))
|
||||
y = A * np.sin(x + phase)
|
||||
xs.append(x)
|
||||
ys.append(y)
|
||||
return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)
|
||||
|
||||
x1, y1 = get_batch()
|
||||
x2, y2 = get_batch()
|
||||
return x1, y1, x2, y2
|
||||
@ -91,7 +95,7 @@ for it in range(20000):
|
||||
opt.step()
|
||||
|
||||
if it % 100 == 0:
|
||||
print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))
|
||||
print("Iteration %d -- Outer Loss: %.4f" % (it, loss2))
|
||||
losses.append(loss2.detach())
|
||||
|
||||
t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
|
||||
@ -116,11 +120,11 @@ test_y = t_A * torch.sin(test_x + t_b)
|
||||
|
||||
test_f = net(t_params, test_x)
|
||||
|
||||
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
|
||||
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
|
||||
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
|
||||
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label="sin(x)")
|
||||
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label="net(x)")
|
||||
plt.plot(t_x.data.numpy(), t_y.data.numpy(), "o", label="Examples")
|
||||
plt.legend()
|
||||
plt.savefig('maml-sine.png')
|
||||
plt.savefig("maml-sine.png")
|
||||
plt.figure()
|
||||
plt.plot(np.convolve(losses, [.05] * 20))
|
||||
plt.savefig('losses.png')
|
||||
plt.plot(np.convolve(losses, [0.05] * 20))
|
||||
plt.savefig("losses.png")
|
||||
|
@ -1,5 +1,6 @@
|
||||
# PyTorch forward-mode is not mature yet
|
||||
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
|
||||
from torch._functorch.eager_transforms import hessian, jacfwd, jvp
|
||||
from torch._functorch.vmap import chunk_vmap
|
||||
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
|
||||
|
||||
from functorch import functionalize
|
||||
|
@ -1,26 +1,31 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
from torch._C import DispatchKey, DispatchKeySet, _ExcludeDispatchKeyGuard
|
||||
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize
|
||||
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
|
||||
from torch._dynamo.exc import CondOpArgsMismatchError
|
||||
from torch._functorch.eager_transforms import (
|
||||
_unwrap_all_tensors_from_functional,
|
||||
_wrap_all_tensors_to_functional,
|
||||
functionalize,
|
||||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
disable_proxy_modes_tracing,
|
||||
ProxyTorchDispatchMode,
|
||||
make_fx,
|
||||
ProxyTorchDispatchMode,
|
||||
track_tensor_tree,
|
||||
)
|
||||
from torch.fx.passes.shape_prop import _extract_tensor_metadata
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
from torch.utils._python_dispatch import (
|
||||
_get_current_dispatch_mode,
|
||||
_pop_mode_temporarily,
|
||||
)
|
||||
from torch.utils._pytree import tree_flatten
|
||||
from torch._dynamo.exc import CondOpArgsMismatchError
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -34,9 +39,14 @@ In order to do this, we need implementations for each of the dispatch keys.
|
||||
"""
|
||||
cond = HigherOrderOperator("cond")
|
||||
|
||||
|
||||
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
||||
assert isinstance(operands, (list, tuple)), "Cond operands must be a list or tuple of tensors"
|
||||
assert all(isinstance(o, torch.Tensor) for o in operands), "Cond operands must be a list of tensors"
|
||||
assert isinstance(
|
||||
operands, (list, tuple)
|
||||
), "Cond operands must be a list or tuple of tensors"
|
||||
assert all(
|
||||
isinstance(o, torch.Tensor) for o in operands
|
||||
), "Cond operands must be a list of tensors"
|
||||
|
||||
with disable_proxy_modes_tracing():
|
||||
true_graph = make_fx(true_fn)(*operands)
|
||||
@ -45,11 +55,11 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
||||
true_outs = []
|
||||
false_outs = []
|
||||
for node in true_graph.graph.nodes:
|
||||
if node.op == 'output':
|
||||
if node.op == "output":
|
||||
true_outs.extend(node.args)
|
||||
|
||||
for node in false_graph.graph.nodes:
|
||||
if node.op == 'output':
|
||||
if node.op == "output":
|
||||
false_outs.extend(node.args)
|
||||
|
||||
flat_true_outs, _ = pytree.tree_flatten(true_outs)
|
||||
@ -64,7 +74,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
||||
for i in range(0, len(flat_true_outs)):
|
||||
true_out = flat_true_outs[i]
|
||||
false_out = flat_false_outs[i]
|
||||
if true_out.meta['tensor_meta'] != false_out.meta['tensor_meta']:
|
||||
if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
|
||||
raise CondOpArgsMismatchError(
|
||||
f"Expected each tensor to have same metadata but got:"
|
||||
f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
|
||||
@ -85,7 +95,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
||||
|
||||
true_name = next_name
|
||||
false_name = f"false_graph_{i}"
|
||||
assert(not hasattr(proxy_mode.tracer.root, false_name))
|
||||
assert not hasattr(proxy_mode.tracer.root, false_name)
|
||||
|
||||
proxy_mode.tracer.root.register_module(true_name, true_graph)
|
||||
proxy_mode.tracer.root.register_module(false_name, false_graph)
|
||||
@ -94,8 +104,9 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
||||
|
||||
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
|
||||
|
||||
out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
|
||||
name="conditional")
|
||||
out_proxy = proxy_mode.tracer.create_proxy(
|
||||
"call_function", func_overload, proxy_args, {}, name="conditional"
|
||||
)
|
||||
|
||||
# At this point, we're *guaranteed* that whether an output came from the
|
||||
# true or false branch is indistinguishable. So, as this is just for tracing
|
||||
@ -112,7 +123,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
||||
@cond.py_impl(DispatchKey.CompositeExplicitAutograd)
|
||||
def cond_dense(pred, true_fn, false_fn, operands):
|
||||
mode = _get_current_dispatch_mode()
|
||||
assert (mode is None), "Mode should never be enabled for CPU/CUDA key"
|
||||
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
|
||||
if pred:
|
||||
return true_fn(*operands)
|
||||
else:
|
||||
@ -125,8 +136,7 @@ def cond_autograd(pred, true_fn, false_fn, *operands):
|
||||
flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands])
|
||||
|
||||
requires_grad = any(
|
||||
isinstance(arg, torch.Tensor) and arg.requires_grad
|
||||
for arg in flat_operands
|
||||
isinstance(arg, torch.Tensor) and arg.requires_grad for arg in flat_operands
|
||||
)
|
||||
|
||||
with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU)):
|
||||
@ -148,6 +158,7 @@ def cond_autograd(pred, true_fn, false_fn, *operands):
|
||||
var = var.detach()
|
||||
var.requires_grad = True
|
||||
return var
|
||||
|
||||
return err_fn(fake_requires_grad(result))
|
||||
|
||||
return result
|
||||
@ -156,7 +167,7 @@ def cond_autograd(pred, true_fn, false_fn, *operands):
|
||||
@cond.py_impl(ProxyTorchDispatchMode)
|
||||
def inner(pred, true_fn, false_fn, operands):
|
||||
mode = _get_current_dispatch_mode()
|
||||
assert (mode is not None), "Mode should always be enabled for python fallback key"
|
||||
assert mode is not None, "Mode should always be enabled for python fallback key"
|
||||
with _pop_mode_temporarily() as mode:
|
||||
if mode.enable_tracing:
|
||||
return trace_cond(mode, cond, pred, true_fn, false_fn, operands)
|
||||
@ -177,7 +188,8 @@ def cond_fake_tensor_mode(pred, true_fn, false_fn, operands):
|
||||
false_meta = _extract_tensor_metadata(false_out)
|
||||
if true_meta != false_meta:
|
||||
raise RuntimeError(
|
||||
f"Unmatched tensor metadata from cond() branches.\ntrue branch: {true_meta}, false branch: {false_meta}")
|
||||
f"Unmatched tensor metadata from cond() branches.\ntrue branch: {true_meta}, false branch: {false_meta}"
|
||||
)
|
||||
return true_outs
|
||||
|
||||
|
||||
@ -203,7 +215,10 @@ def _has_potential_branch_input_mutation(branch, inputs):
|
||||
input_nodes.add(node)
|
||||
if node.op == "call_function":
|
||||
target = node.target
|
||||
if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable:
|
||||
if (
|
||||
isinstance(target, torch._ops.OpOverload)
|
||||
and target._schema.is_mutable
|
||||
):
|
||||
for arg in node.args:
|
||||
if arg in input_nodes:
|
||||
return True
|
||||
@ -241,13 +256,15 @@ def _has_potential_branch_input_alias(branch, inputs):
|
||||
# for map operator, where num_mapped_args is a scalar
|
||||
# and doesn't have a "val" meta.
|
||||
if node.op == "placeholder" and "val" in node.meta:
|
||||
input_storages.add(StorageWeakRef(node.meta['val']._typed_storage()))
|
||||
input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
|
||||
if node.op == "output":
|
||||
|
||||
def check_alias(out):
|
||||
if out is not None and "val" in out.meta:
|
||||
out_storage = StorageWeakRef(out.meta['val']._typed_storage())
|
||||
out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
|
||||
return out_storage in input_storages
|
||||
return False
|
||||
|
||||
if any(pytree.tree_flatten(pytree.tree_map(check_alias, node.args))[0]):
|
||||
return True
|
||||
|
||||
@ -263,22 +280,30 @@ def _has_potential_branch_input_alias(branch, inputs):
|
||||
@cond.py_impl(DispatchKey.Functionalize)
|
||||
def cond_func(pred, true_fn, false_fn, inputs):
|
||||
reapply_views = torch._C._functionalization_reapply_views_tls()
|
||||
unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views)
|
||||
unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views)
|
||||
mode = 'mutations_and_views' if reapply_views else 'mutations'
|
||||
unwrapped_inputs = _unwrap_all_tensors_from_functional(
|
||||
inputs, reapply_views=reapply_views
|
||||
)
|
||||
unwrapped_pred = _unwrap_all_tensors_from_functional(
|
||||
pred, reapply_views=reapply_views
|
||||
)
|
||||
mode = "mutations_and_views" if reapply_views else "mutations"
|
||||
with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
|
||||
functional_true = functionalize(true_fn, remove=mode)
|
||||
functional_false = functionalize(false_fn, remove=mode)
|
||||
for branch in [true_fn, false_fn]:
|
||||
if _has_potential_branch_input_mutation(branch, unwrapped_inputs):
|
||||
raise UnsupportedAliasMutationException("One of torch.cond branch "
|
||||
"might be modifying the input!")
|
||||
raise UnsupportedAliasMutationException(
|
||||
"One of torch.cond branch " "might be modifying the input!"
|
||||
)
|
||||
|
||||
if _has_potential_branch_input_alias(branch, unwrapped_inputs):
|
||||
raise UnsupportedAliasMutationException("One of torch.cond branch "
|
||||
"might be aliasing the input!")
|
||||
raise UnsupportedAliasMutationException(
|
||||
"One of torch.cond branch " "might be aliasing the input!"
|
||||
)
|
||||
|
||||
cond_return = cond(unwrapped_pred, functional_true, functional_false, unwrapped_inputs)
|
||||
cond_return = cond(
|
||||
unwrapped_pred, functional_true, functional_false, unwrapped_inputs
|
||||
)
|
||||
return _wrap_all_tensors_to_functional(cond_return, level=0)
|
||||
|
||||
|
||||
@ -290,10 +315,14 @@ def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
|
||||
2. Our check for above condition is not exhaustive
|
||||
"""
|
||||
reapply_views = interpreter.functionalize_add_back_views()
|
||||
mode = 'mutations_and_views' if reapply_views else 'mutations'
|
||||
mode = "mutations_and_views" if reapply_views else "mutations"
|
||||
# At this point, we will see functionalized tensors, so need to unwrap them first
|
||||
unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views)
|
||||
unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views)
|
||||
unwrapped_inputs = _unwrap_all_tensors_from_functional(
|
||||
inputs, reapply_views=reapply_views
|
||||
)
|
||||
unwrapped_pred = _unwrap_all_tensors_from_functional(
|
||||
pred, reapply_views=reapply_views
|
||||
)
|
||||
|
||||
functional_true_fn = functionalize(true_fn, remove=mode)
|
||||
functional_false_fn = functionalize(false_fn, remove=mode)
|
||||
@ -301,16 +330,21 @@ def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
|
||||
with interpreter.lower():
|
||||
for branch in [functional_true_fn, functional_false_fn]:
|
||||
if _has_potential_branch_input_mutation(branch, unwrapped_inputs):
|
||||
raise UnsupportedAliasMutationException("One of torch.cond branch "
|
||||
"might be modifying the input!")
|
||||
raise UnsupportedAliasMutationException(
|
||||
"One of torch.cond branch " "might be modifying the input!"
|
||||
)
|
||||
for branch in [true_fn, false_fn]:
|
||||
if _has_potential_branch_input_alias(branch, unwrapped_inputs):
|
||||
raise UnsupportedAliasMutationException("One of torch.cond branch "
|
||||
"might be aliasing the input!")
|
||||
raise UnsupportedAliasMutationException(
|
||||
"One of torch.cond branch " "might be aliasing the input!"
|
||||
)
|
||||
|
||||
cond_return = cond(unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs)
|
||||
cond_return = cond(
|
||||
unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs
|
||||
)
|
||||
return _wrap_all_tensors_to_functional(cond_return, level=interpreter.level())
|
||||
|
||||
|
||||
# TODO(voz): Make this automatic for keys, this is very ugly atm
|
||||
cond.fallthrough(DispatchKey.PythonDispatcher)
|
||||
cond.fallthrough(DispatchKey.PythonTLSSnapshot)
|
||||
|
@ -1,23 +1,32 @@
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._C import DispatchKey, DispatchKeySet, _ExcludeDispatchKeyGuard
|
||||
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize
|
||||
from torch._functorch.aot_autograd import create_joint, AOTConfig
|
||||
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
|
||||
from torch._dispatch.python import suspend_functionalization
|
||||
from torch._functorch.aot_autograd import AOTConfig, create_joint
|
||||
from torch._functorch.eager_transforms import (
|
||||
_unwrap_all_tensors_from_functional,
|
||||
_wrap_all_tensors_to_functional,
|
||||
functionalize,
|
||||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
disable_proxy_modes_tracing,
|
||||
make_fx,
|
||||
ProxyTorchDispatchMode,
|
||||
track_tensor_tree,
|
||||
)
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
from torch.utils._python_dispatch import (
|
||||
_get_current_dispatch_mode,
|
||||
_pop_mode_temporarily,
|
||||
)
|
||||
from torch._dispatch.python import suspend_functionalization
|
||||
from ._cond import _has_potential_branch_input_alias, _has_potential_branch_input_mutation, UnsupportedAliasMutationException
|
||||
|
||||
from ._cond import (
|
||||
_has_potential_branch_input_alias,
|
||||
_has_potential_branch_input_mutation,
|
||||
UnsupportedAliasMutationException,
|
||||
)
|
||||
|
||||
|
||||
# TODO: We add this to prevent dymamo from tracing into map_wrapper,
|
||||
@ -26,16 +35,19 @@ class MapWrapper(HigherOrderOperator):
|
||||
def __call__(self, xs, *args):
|
||||
return map_wrapper(xs, *args)
|
||||
|
||||
|
||||
map = MapWrapper("map", _deprecated_global_ns=True)
|
||||
map_impl = HigherOrderOperator("map_impl", _deprecated_global_ns=True)
|
||||
|
||||
dummy_aot_config = AOTConfig(fw_compiler=None,
|
||||
bw_compiler=None,
|
||||
partition_fn=None,
|
||||
decompositions={},
|
||||
num_params_buffers=0,
|
||||
aot_id=0,
|
||||
keep_inference_input_mutations=False)
|
||||
dummy_aot_config = AOTConfig(
|
||||
fw_compiler=None,
|
||||
bw_compiler=None,
|
||||
partition_fn=None,
|
||||
decompositions={},
|
||||
num_params_buffers=0,
|
||||
aot_id=0,
|
||||
keep_inference_input_mutations=False,
|
||||
)
|
||||
|
||||
|
||||
def create_fw_bw_graph(f, num_mapped_args, *args):
|
||||
@ -59,20 +71,33 @@ def create_fw_bw_graph(f, num_mapped_args, *args):
|
||||
|
||||
with suspend_functionalization():
|
||||
with disable_proxy_modes_tracing():
|
||||
|
||||
def from_fun(t):
|
||||
if isinstance(t, torch.Tensor):
|
||||
return torch.empty_strided(t.size(), t.stride(), requires_grad=t.requires_grad)
|
||||
return torch.empty_strided(
|
||||
t.size(), t.stride(), requires_grad=t.requires_grad
|
||||
)
|
||||
return t
|
||||
|
||||
example_xs = [from_fun(xs) for xs in _unstack_pytree(mapped_xs)[0]]
|
||||
example_pos_args = [from_fun(arg) if isinstance(arg, torch.Tensor) else arg for arg in pos_args]
|
||||
example_flat_out = pytree.tree_map(from_fun, f(*example_xs, *example_pos_args))
|
||||
if any(not isinstance(out, torch.Tensor) for out in example_flat_out if out is not None):
|
||||
raise RuntimeError("Expect outputs of map only contains tensors or None. "
|
||||
f"Got types {[type(out) for out in example_flat_out]}.")
|
||||
example_pos_args = [
|
||||
from_fun(arg) if isinstance(arg, torch.Tensor) else arg
|
||||
for arg in pos_args
|
||||
]
|
||||
example_flat_out = pytree.tree_map(
|
||||
from_fun, f(*example_xs, *example_pos_args)
|
||||
)
|
||||
if any(
|
||||
not isinstance(out, torch.Tensor)
|
||||
for out in example_flat_out
|
||||
if out is not None
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Expect outputs of map only contains tensors or None. "
|
||||
f"Got types {[type(out) for out in example_flat_out]}."
|
||||
)
|
||||
example_grad = [from_fun(out) for out in example_flat_out]
|
||||
|
||||
|
||||
fw_graph = make_fx(f)(*example_xs, *example_pos_args)
|
||||
|
||||
def joint_f(*example_args):
|
||||
@ -84,20 +109,39 @@ def create_fw_bw_graph(f, num_mapped_args, *args):
|
||||
|
||||
def fw_with_masks(*args):
|
||||
fw_out = f(*args)
|
||||
return fw_out, [True if isinstance(ret, torch.Tensor) and ret.requires_grad else False for ret in fw_out]
|
||||
return fw_out, [
|
||||
True
|
||||
if isinstance(ret, torch.Tensor) and ret.requires_grad
|
||||
else False
|
||||
for ret in fw_out
|
||||
]
|
||||
|
||||
joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
|
||||
_, grads = joint(list(mapped_input) + list(args),
|
||||
[grad for grad in mapped_grads if grad is not None and grad.requires_grad])
|
||||
_, grads = joint(
|
||||
list(mapped_input) + list(args),
|
||||
[
|
||||
grad
|
||||
for grad in mapped_grads
|
||||
if grad is not None and grad.requires_grad
|
||||
],
|
||||
)
|
||||
|
||||
# In order to keep map functional for backward graph,
|
||||
# we clone outputs that are aliasing inputs
|
||||
input_storage = {StorageWeakRef(arg._typed_storage()) for arg in example_args if isinstance(arg, torch.Tensor)}
|
||||
input_storage = {
|
||||
StorageWeakRef(arg._typed_storage())
|
||||
for arg in example_args
|
||||
if isinstance(arg, torch.Tensor)
|
||||
}
|
||||
|
||||
def maybe_clone(t):
|
||||
if isinstance(t, torch.Tensor) and StorageWeakRef(t._typed_storage()) in input_storage:
|
||||
if (
|
||||
isinstance(t, torch.Tensor)
|
||||
and StorageWeakRef(t._typed_storage()) in input_storage
|
||||
):
|
||||
return t.clone()
|
||||
return t
|
||||
|
||||
return pytree.tree_map(maybe_clone, grads)
|
||||
|
||||
joint_num_mapped = len(example_grad) + len(example_xs)
|
||||
@ -114,12 +158,12 @@ def map_wrapper(f, xs, *args):
|
||||
shapes = [xs.shape for xs in flat_xs]
|
||||
leading_dim_size = shapes[0][0]
|
||||
if leading_dim_size == 0:
|
||||
raise RuntimeError(
|
||||
"Leading dimensions of mapped xs cannot be 0.")
|
||||
raise RuntimeError("Leading dimensions of mapped xs cannot be 0.")
|
||||
|
||||
if any(cur_shape[0] != leading_dim_size for cur_shape in shapes):
|
||||
raise RuntimeError(
|
||||
f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}.")
|
||||
f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}."
|
||||
)
|
||||
|
||||
out_spec = None
|
||||
|
||||
@ -131,7 +175,11 @@ def map_wrapper(f, xs, *args):
|
||||
nonlocal out_spec
|
||||
out_spec = tmp_out_spec
|
||||
return flat_out
|
||||
return pytree.tree_unflatten(map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec)
|
||||
|
||||
return pytree.tree_unflatten(
|
||||
map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec
|
||||
)
|
||||
|
||||
|
||||
class MapAutogradOp(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@ -140,17 +188,24 @@ class MapAutogradOp(torch.autograd.Function):
|
||||
ctx._joint_graph = joint_graph
|
||||
ctx._num_mapped_args = num_mapped_args
|
||||
with torch._C._AutoDispatchBelowAutograd():
|
||||
return (*map_impl(fw_graph, num_mapped_args, *flat_args), )
|
||||
return (*map_impl(fw_graph, num_mapped_args, *flat_args),)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *flat_grads):
|
||||
fw_args = ctx.saved_tensors
|
||||
fw_mapped_args = fw_args[:ctx._num_mapped_args]
|
||||
pos_args = fw_args[ctx._num_mapped_args:]
|
||||
fw_mapped_args = fw_args[: ctx._num_mapped_args]
|
||||
pos_args = fw_args[ctx._num_mapped_args :]
|
||||
|
||||
grads = map_impl(ctx._joint_graph, ctx._num_mapped_args + len(flat_grads), *fw_mapped_args, *flat_grads, *pos_args)
|
||||
grads = map_impl(
|
||||
ctx._joint_graph,
|
||||
ctx._num_mapped_args + len(flat_grads),
|
||||
*fw_mapped_args,
|
||||
*flat_grads,
|
||||
*pos_args,
|
||||
)
|
||||
return None, None, None, *grads
|
||||
|
||||
|
||||
def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
|
||||
xs = list(args[:num_mapped])
|
||||
pos_args = list(args[num_mapped:])
|
||||
@ -168,6 +223,7 @@ def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
|
||||
if isinstance(t, torch.Tensor):
|
||||
return t.expand(leading_dim_size, *t.shape)
|
||||
return t
|
||||
|
||||
expanded_outs = pytree.tree_map(expand_tensor, example_outs)
|
||||
|
||||
next_name = None
|
||||
@ -182,9 +238,13 @@ def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
|
||||
proxy_mode.tracer.root.register_module(next_name, body_graph)
|
||||
node_args = (body_graph, num_mapped, *args)
|
||||
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
|
||||
out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
|
||||
name="map_impl")
|
||||
return track_tensor_tree(expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer)
|
||||
out_proxy = proxy_mode.tracer.create_proxy(
|
||||
"call_function", func_overload, proxy_args, {}, name="map_impl"
|
||||
)
|
||||
return track_tensor_tree(
|
||||
expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer
|
||||
)
|
||||
|
||||
|
||||
def _unstack_pytree(xs):
|
||||
flat_xs, inspec = pytree.tree_flatten(xs)
|
||||
@ -192,7 +252,9 @@ def _unstack_pytree(xs):
|
||||
raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")
|
||||
|
||||
if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
|
||||
raise RuntimeError(f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}")
|
||||
raise RuntimeError(
|
||||
f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
|
||||
)
|
||||
|
||||
a = zip(*flat_xs)
|
||||
pytrees = []
|
||||
@ -200,6 +262,7 @@ def _unstack_pytree(xs):
|
||||
pytrees.append(pytree.tree_unflatten(tuple, inspec))
|
||||
return pytrees
|
||||
|
||||
|
||||
def _stack_pytree(pytrees):
|
||||
flat_out = []
|
||||
out_spec = None
|
||||
@ -220,6 +283,7 @@ def _stack_pytree(pytrees):
|
||||
raise RuntimeError(f"Cannot stack {leaves}.")
|
||||
return pytree.tree_unflatten(stacked_out, out_spec)
|
||||
|
||||
|
||||
@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
|
||||
def map_dense(f, num_mapped_args, *args):
|
||||
xs = args[:num_mapped_args]
|
||||
@ -240,7 +304,7 @@ def map_autograd(f, num_mapped_args, *args):
|
||||
@map_impl.py_impl(ProxyTorchDispatchMode)
|
||||
def map_proxy_torch_dispatch_mode(f, num_mapped, *args):
|
||||
mode = _get_current_dispatch_mode()
|
||||
assert (mode is not None), "Mode should always be enabled for python fallback key"
|
||||
assert mode is not None, "Mode should always be enabled for python fallback key"
|
||||
with _pop_mode_temporarily() as mode:
|
||||
if mode.enable_tracing:
|
||||
return trace_map(mode, map_impl, f, num_mapped, *args)
|
||||
@ -259,8 +323,10 @@ def map_func(f, num_mapped, *args):
|
||||
xs = args[:num_mapped]
|
||||
pos_args = args[num_mapped:]
|
||||
unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
|
||||
unwrapped_args = _unwrap_all_tensors_from_functional(pos_args, reapply_views=reapply_views)
|
||||
mode = 'mutations_and_views' if reapply_views else 'mutations'
|
||||
unwrapped_args = _unwrap_all_tensors_from_functional(
|
||||
pos_args, reapply_views=reapply_views
|
||||
)
|
||||
mode = "mutations_and_views" if reapply_views else "mutations"
|
||||
|
||||
with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
|
||||
functional_map_fn = functionalize(f, remove=mode)
|
||||
@ -268,18 +334,17 @@ def map_func(f, num_mapped, *args):
|
||||
example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
|
||||
|
||||
if _has_potential_branch_input_mutation(f, example_inputs):
|
||||
raise UnsupportedAliasMutationException(
|
||||
"torch.map is mutating the input!"
|
||||
)
|
||||
raise UnsupportedAliasMutationException("torch.map is mutating the input!")
|
||||
|
||||
if _has_potential_branch_input_alias(f, example_inputs):
|
||||
raise UnsupportedAliasMutationException(
|
||||
"torch.map is aliasing the input!"
|
||||
)
|
||||
raise UnsupportedAliasMutationException("torch.map is aliasing the input!")
|
||||
|
||||
map_return = map_impl(functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args)
|
||||
map_return = map_impl(
|
||||
functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args
|
||||
)
|
||||
return _wrap_all_tensors_to_functional(map_return, level=0)
|
||||
|
||||
|
||||
@map_impl.py_impl(torch._C._functorch.TransformType.Functionalize)
|
||||
def map_functionalize(interpreter, f, num_mapped, *args):
|
||||
"""
|
||||
@ -290,10 +355,12 @@ def map_functionalize(interpreter, f, num_mapped, *args):
|
||||
xs = args[:num_mapped]
|
||||
pos_args = args[num_mapped:]
|
||||
reapply_views = interpreter.functionalize_add_back_views()
|
||||
mode = 'mutations_and_views' if reapply_views else 'mutations'
|
||||
mode = "mutations_and_views" if reapply_views else "mutations"
|
||||
# At this point, we will see functionalized tensors, so need to unwrap them first
|
||||
unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
|
||||
unwrapped_args = _unwrap_all_tensors_from_functional(pos_args, reapply_views=reapply_views)
|
||||
unwrapped_args = _unwrap_all_tensors_from_functional(
|
||||
pos_args, reapply_views=reapply_views
|
||||
)
|
||||
|
||||
functional_map_fn = functionalize(f, remove=mode)
|
||||
|
||||
@ -301,18 +368,17 @@ def map_functionalize(interpreter, f, num_mapped, *args):
|
||||
with disable_proxy_modes_tracing():
|
||||
example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
|
||||
if _has_potential_branch_input_mutation(f, example_inputs):
|
||||
raise UnsupportedAliasMutationException(
|
||||
"torch.map is mutating the input!"
|
||||
)
|
||||
raise UnsupportedAliasMutationException("torch.map is mutating the input!")
|
||||
|
||||
if _has_potential_branch_input_alias(f, example_inputs):
|
||||
raise UnsupportedAliasMutationException(
|
||||
"torch.map is aliasing the input!"
|
||||
)
|
||||
raise UnsupportedAliasMutationException("torch.map is aliasing the input!")
|
||||
|
||||
map_return = map_impl(functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args)
|
||||
map_return = map_impl(
|
||||
functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args
|
||||
)
|
||||
return _wrap_all_tensors_to_functional(map_return, level=interpreter.level())
|
||||
|
||||
|
||||
# TODO(voz) Make this automatic for keys, this is very ugly atm
|
||||
map_impl.fallthrough(DispatchKey.PythonDispatcher)
|
||||
map_impl.fallthrough(DispatchKey.PythonTLSSnapshot)
|
||||
|
@ -1,2 +1,2 @@
|
||||
from ._map import map # noqa: F401
|
||||
from ._cond import cond, UnsupportedAliasMutationException # noqa: F401
|
||||
from ._map import map # noqa: F401
|
||||
|
@ -19,8 +19,10 @@ Let's demonstrate how to do this using an ensemble of simple CNNs.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
# Here's a simple CNN
|
||||
class SimpleCNN(nn.Module):
|
||||
def __init__(self):
|
||||
@ -44,11 +46,12 @@ class SimpleCNN(nn.Module):
|
||||
output = x
|
||||
return output
|
||||
|
||||
|
||||
# Let's generate some dummy data. Pretend that we're working with an MNIST dataset
|
||||
# where the images are 28 by 28.
|
||||
# Furthermore, let's say we wish to combine the predictions from 10 different
|
||||
# models.
|
||||
device = 'cuda'
|
||||
device = "cuda"
|
||||
num_models = 10
|
||||
data = torch.randn(100, 64, 1, 28, 28, device=device)
|
||||
targets = torch.randint(10, (6400,), device=device)
|
||||
@ -81,6 +84,7 @@ predictions2 = [model(minibatch) for model in models]
|
||||
# functorch offers the following convenience function to do that. It returns a
|
||||
# stateless version of the model (fmodel) and stacked parameters and buffers.
|
||||
from functorch import combine_state_for_ensemble
|
||||
|
||||
fmodel, params, buffers = combine_state_for_ensemble(models)
|
||||
[p.requires_grad_() for p in params]
|
||||
|
||||
@ -92,15 +96,20 @@ fmodel, params, buffers = combine_state_for_ensemble(models)
|
||||
print([p.size(0) for p in params])
|
||||
assert minibatches.shape == (num_models, 64, 1, 28, 28)
|
||||
from functorch import vmap
|
||||
|
||||
predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
|
||||
assert torch.allclose(predictions1_vmap, torch.stack(predictions1), atol=1e-6, rtol=1e-6)
|
||||
assert torch.allclose(
|
||||
predictions1_vmap, torch.stack(predictions1), atol=1e-6, rtol=1e-6
|
||||
)
|
||||
|
||||
# Option 2: get predictions using the same minibatch of data
|
||||
# vmap has an in_dims arg that specify which dimensions to map over.
|
||||
# Using ``None``, we tell vmap we want the same minibatch to apply for all of
|
||||
# the 10 models.
|
||||
predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
|
||||
assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-6, rtol=1e-6)
|
||||
assert torch.allclose(
|
||||
predictions2_vmap, torch.stack(predictions2), atol=1e-6, rtol=1e-6
|
||||
)
|
||||
|
||||
# A quick note: there are limitations around what types of functions can be
|
||||
# transformed by vmap. The best functions to transform are ones that are
|
||||
|
@ -8,11 +8,14 @@ deep learning models. It is difficult (or annoying) to compute these quantities
|
||||
efficiently using a standard autodiff system like PyTorch Autograd; functorch
|
||||
provides ways of computing various higher-order autodiff quantities efficiently.
|
||||
"""
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
######################################################################
|
||||
# Setup: Comparing functorch vs the naive approach
|
||||
# --------------------------------------------------------------------
|
||||
@ -21,6 +24,7 @@ torch.manual_seed(0)
|
||||
def predict(weight, bias, x):
|
||||
return F.linear(x, weight, bias).tanh()
|
||||
|
||||
|
||||
# Here's some dummy data: a weight, a bias, and a feature vector.
|
||||
D = 16
|
||||
weight = torch.randn(D, D)
|
||||
@ -34,19 +38,24 @@ x = torch.randn(D)
|
||||
xp = x.clone().requires_grad_()
|
||||
unit_vectors = torch.eye(D)
|
||||
|
||||
|
||||
def compute_jac(xp):
|
||||
jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
|
||||
for vec in unit_vectors]
|
||||
jacobian_rows = [
|
||||
torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
|
||||
for vec in unit_vectors
|
||||
]
|
||||
return torch.stack(jacobian_rows)
|
||||
|
||||
|
||||
jacobian = compute_jac(xp)
|
||||
|
||||
# Instead of computing the jacobian row-by-row, we can use ``vmap`` to get rid
|
||||
# of the for-loop and vectorize the computation. We can't directly apply vmap
|
||||
# to PyTorch Autograd; instead, functorch provides a ``vjp`` transform:
|
||||
from functorch import vmap, vjp
|
||||
from functorch import vjp, vmap
|
||||
|
||||
_, vjp_fn = vjp(partial(predict, weight, bias), x)
|
||||
ft_jacobian, = vmap(vjp_fn)(unit_vectors)
|
||||
(ft_jacobian,) = vmap(vjp_fn)(unit_vectors)
|
||||
assert torch.allclose(ft_jacobian, jacobian)
|
||||
|
||||
# In another tutorial a composition of reverse-mode AD and vmap gave us
|
||||
@ -59,6 +68,7 @@ assert torch.allclose(ft_jacobian, jacobian)
|
||||
# argument that says which argument we would like to compute Jacobians with
|
||||
# respect to.
|
||||
from functorch import jacrev
|
||||
|
||||
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
|
||||
assert torch.allclose(ft_jacobian, jacobian)
|
||||
|
||||
@ -67,6 +77,7 @@ assert torch.allclose(ft_jacobian, jacobian)
|
||||
# there are). In general, we expect that vectorization via ``vmap`` can help
|
||||
# eliminate overhead and give better utilization of your hardware.
|
||||
from torch.utils.benchmark import Timer
|
||||
|
||||
without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
|
||||
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
|
||||
print(without_vmap.timeit(500))
|
||||
@ -95,7 +106,7 @@ ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
|
||||
# In reverse-mode AD, we are computing the jacobian row-by-row, while in
|
||||
# forward-mode AD (which computes Jacobian-vector products), we are computing
|
||||
# it column-by-column. The Jacobian matrix has M rows and N columns.
|
||||
from functorch import jacrev, jacfwd
|
||||
from functorch import jacfwd, jacrev
|
||||
|
||||
# Benchmark with more inputs than outputs
|
||||
Din = 32
|
||||
@ -106,8 +117,8 @@ x = torch.randn(Din)
|
||||
|
||||
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
|
||||
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
|
||||
print(f'jacfwd time: {using_fwd.timeit(500)}')
|
||||
print(f'jacrev time: {using_bwd.timeit(500)}')
|
||||
print(f"jacfwd time: {using_fwd.timeit(500)}")
|
||||
print(f"jacrev time: {using_bwd.timeit(500)}")
|
||||
|
||||
# Benchmark with more outputs than inputs
|
||||
Din = 2048
|
||||
@ -118,8 +129,8 @@ x = torch.randn(Din)
|
||||
|
||||
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
|
||||
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
|
||||
print(f'jacfwd time: {using_fwd.timeit(500)}')
|
||||
print(f'jacrev time: {using_bwd.timeit(500)}')
|
||||
print(f"jacfwd time: {using_fwd.timeit(500)}")
|
||||
print(f"jacrev time: {using_bwd.timeit(500)}")
|
||||
|
||||
######################################################################
|
||||
# Hessian computation with functorch.hessian
|
||||
@ -132,6 +143,7 @@ print(f'jacrev time: {using_bwd.timeit(500)}')
|
||||
# Depending on your model, you may want to use ``jacfwd(jacfwd(f))`` or
|
||||
# ``jacrev(jacrev(f))`` instead to compute hessians.
|
||||
from functorch import hessian
|
||||
|
||||
# # TODO: make sure PyTorch has tanh_backward implemented for jvp!!
|
||||
# hess0 = hessian(predict, argnums=2)(weight, bias, x)
|
||||
# hess1 = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
|
||||
@ -148,9 +160,11 @@ hess2 = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
|
||||
# The easiest way to do this is to sum over the batch dimension and then
|
||||
# compute the Jacobian of that function:
|
||||
|
||||
|
||||
def predict_with_output_summed(weight, bias, x):
|
||||
return predict(weight, bias, x).sum(0)
|
||||
|
||||
|
||||
batch_size = 64
|
||||
Din = 31
|
||||
Dout = 33
|
||||
|
@ -12,8 +12,10 @@ and optimization research.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
# Here's a simple CNN
|
||||
class SimpleCNN(nn.Module):
|
||||
def __init__(self):
|
||||
@ -37,12 +39,14 @@ class SimpleCNN(nn.Module):
|
||||
output = x
|
||||
return output
|
||||
|
||||
|
||||
def loss_fn(predictions, targets):
|
||||
return F.nll_loss(predictions, targets)
|
||||
|
||||
|
||||
# Let's generate a batch of dummy data. Pretend that we're working with an
|
||||
# MNIST dataset where the images are 28 by 28 and we have a minibatch of size 64.
|
||||
device = 'cuda'
|
||||
device = "cuda"
|
||||
num_models = 10
|
||||
batch_size = 64
|
||||
data = torch.randn(batch_size, 1, 28, 28, device=device)
|
||||
@ -56,6 +60,7 @@ predictions = model(data)
|
||||
loss = loss_fn(predictions, targets)
|
||||
loss.backward()
|
||||
|
||||
|
||||
# Conceptually, per-sample-gradient computation is equivalent to: for each sample
|
||||
# of the data, perform a forward and a backward pass to get a gradient.
|
||||
def compute_grad(sample, target):
|
||||
@ -65,12 +70,14 @@ def compute_grad(sample, target):
|
||||
loss = loss_fn(prediction, target)
|
||||
return torch.autograd.grad(loss, list(model.parameters()))
|
||||
|
||||
|
||||
def compute_sample_grads(data, targets):
|
||||
sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
|
||||
sample_grads = zip(*sample_grads)
|
||||
sample_grads = [torch.stack(shards) for shards in sample_grads]
|
||||
return sample_grads
|
||||
|
||||
|
||||
per_sample_grads = compute_sample_grads(data, targets)
|
||||
|
||||
# sample_grads[0] is the per-sample-grad for model.conv1.weight
|
||||
@ -85,9 +92,11 @@ print(per_sample_grads[0].shape)
|
||||
# We can compute per-sample-gradients efficiently by using function transforms.
|
||||
# First, let's create a stateless functional version of ``model`` by using
|
||||
# ``functorch.make_functional_with_buffers``.
|
||||
from functorch import make_functional_with_buffers, vmap, grad
|
||||
from functorch import grad, make_functional_with_buffers, vmap
|
||||
|
||||
fmodel, params, buffers = make_functional_with_buffers(model)
|
||||
|
||||
|
||||
# Next, let's define a function to compute the loss of the model given a single
|
||||
# input rather than a batch of inputs. It is important that this function accepts the
|
||||
# parameters, the input, and the target, because we will be transforming over them.
|
||||
@ -100,6 +109,7 @@ def compute_loss(params, buffers, sample, target):
|
||||
loss = loss_fn(predictions, targets)
|
||||
return loss
|
||||
|
||||
|
||||
# Now, let's use ``grad`` to create a new function that computes the gradient
|
||||
# with respect to the first argument of compute_loss (i.e. the params).
|
||||
ft_compute_grad = grad(compute_loss)
|
||||
|
@ -1,8 +1,9 @@
|
||||
import yaml
|
||||
import csv
|
||||
import torch
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
|
||||
def get_ops_for_key(key):
|
||||
# Needs modified PyTorch C++ code to work
|
||||
@ -12,7 +13,7 @@ def get_ops_for_key(key):
|
||||
ops = torch._C._dispatch_get_registrations_for_dispatch_key(key)
|
||||
cleaned_ops = []
|
||||
for i in ops:
|
||||
if 'aten::' not in i:
|
||||
if "aten::" not in i:
|
||||
continue
|
||||
cleaned_ops.append(i[6:].strip())
|
||||
return set(cleaned_ops)
|
||||
@ -20,12 +21,17 @@ def get_ops_for_key(key):
|
||||
|
||||
def gen_data(special_op_lists, analysis_name):
|
||||
all_ops = get_ops_for_key(None)
|
||||
composite_ops = get_ops_for_key('CompositeImplicitAutograd')
|
||||
composite_ops = get_ops_for_key("CompositeImplicitAutograd")
|
||||
noncomposite_ops = all_ops - composite_ops
|
||||
|
||||
ops = yaml.load(open('../../aten/src/ATen/native/native_functions.yaml').read(), Loader=yaml.CLoader)
|
||||
ops = yaml.load(
|
||||
open("../../aten/src/ATen/native/native_functions.yaml").read(),
|
||||
Loader=yaml.CLoader,
|
||||
)
|
||||
|
||||
annotated_ops = {a.strip(): b.strip() for a, b in list(csv.reader(open('annotated_ops')))}
|
||||
annotated_ops = {
|
||||
a.strip(): b.strip() for a, b in list(csv.reader(open("annotated_ops")))
|
||||
}
|
||||
from collections import defaultdict
|
||||
|
||||
uniq_ops = []
|
||||
@ -33,18 +39,18 @@ def gen_data(special_op_lists, analysis_name):
|
||||
overload_types = defaultdict(list)
|
||||
cnt = 0
|
||||
for op in ops:
|
||||
func_str = op['func']
|
||||
name = func_str[:func_str.index('(')]
|
||||
if '.' in name:
|
||||
uniq_name = name[:name.index('.')]
|
||||
overload_types[name[name.index('.') + 1:]].append(name)
|
||||
func_str = op["func"]
|
||||
name = func_str[: func_str.index("(")]
|
||||
if "." in name:
|
||||
uniq_name = name[: name.index(".")]
|
||||
overload_types[name[name.index(".") + 1 :]].append(name)
|
||||
else:
|
||||
uniq_name = name
|
||||
op['name'] = uniq_name
|
||||
full_name = func_str[:func_str.index('(')]
|
||||
op['full_name'] = full_name
|
||||
ret_type = func_str[func_str.index('->') + 3:]
|
||||
op['ret_type'] = ret_type
|
||||
op["name"] = uniq_name
|
||||
full_name = func_str[: func_str.index("(")]
|
||||
op["full_name"] = full_name
|
||||
ret_type = func_str[func_str.index("->") + 3 :]
|
||||
op["ret_type"] = ret_type
|
||||
cnt += 1
|
||||
if uniq_name in uniq_names:
|
||||
continue
|
||||
@ -54,104 +60,123 @@ def gen_data(special_op_lists, analysis_name):
|
||||
def annotate_ops(ops, is_unique):
|
||||
categorization = defaultdict(int)
|
||||
for op in ops:
|
||||
if op['name'][-1] == '_':
|
||||
categorization['inplace'] += 1
|
||||
op['meta'] = 'inplace'
|
||||
if op["name"][-1] == "_":
|
||||
categorization["inplace"] += 1
|
||||
op["meta"] = "inplace"
|
||||
continue
|
||||
if not is_unique and 'a!' in op['func'].lower():
|
||||
categorization['out'] += 1
|
||||
op['meta'] = 'out'
|
||||
if not is_unique and "a!" in op["func"].lower():
|
||||
categorization["out"] += 1
|
||||
op["meta"] = "out"
|
||||
continue
|
||||
if 'conv' in op['name']:
|
||||
categorization['conv'] += 1
|
||||
op['meta'] = 'conv'
|
||||
if "conv" in op["name"]:
|
||||
categorization["conv"] += 1
|
||||
op["meta"] = "conv"
|
||||
continue
|
||||
if 'pool' in op['name']:
|
||||
categorization['pool'] += 1
|
||||
op['meta'] = 'pool'
|
||||
if "pool" in op["name"]:
|
||||
categorization["pool"] += 1
|
||||
op["meta"] = "pool"
|
||||
continue
|
||||
if 'backward' in op['name']:
|
||||
categorization['backward'] += 1
|
||||
op['meta'] = 'backward'
|
||||
if "backward" in op["name"]:
|
||||
categorization["backward"] += 1
|
||||
op["meta"] = "backward"
|
||||
continue
|
||||
if op['name'][0] == '_' and op['name'][1] != '_':
|
||||
categorization['private'] += 1
|
||||
op['meta'] = 'private'
|
||||
if op["name"][0] == "_" and op["name"][1] != "_":
|
||||
categorization["private"] += 1
|
||||
op["meta"] = "private"
|
||||
continue
|
||||
if 'batch_norm' in op['name']:
|
||||
categorization['batch_norm'] += 1
|
||||
op['meta'] = 'batch_norm'
|
||||
if "batch_norm" in op["name"]:
|
||||
categorization["batch_norm"] += 1
|
||||
op["meta"] = "batch_norm"
|
||||
continue
|
||||
if 'Tensor' not in op['func'] or 'Tensor' not in op['ret_type']:
|
||||
categorization['non_tensor'] += 1
|
||||
op['meta'] = 'non_tensor'
|
||||
if "Tensor" not in op["func"] or "Tensor" not in op["ret_type"]:
|
||||
categorization["non_tensor"] += 1
|
||||
op["meta"] = "non_tensor"
|
||||
continue
|
||||
if 'cudnn' in op['name'] or 'mkldnn' in op['name'] or 'miopen' in op['name'] or \
|
||||
'native' in op['name'] or 'thnn' in op['name'] or 'slow' in op['name']:
|
||||
categorization['backend'] += 1
|
||||
op['meta'] = 'backend'
|
||||
if (
|
||||
"cudnn" in op["name"]
|
||||
or "mkldnn" in op["name"]
|
||||
or "miopen" in op["name"]
|
||||
or "native" in op["name"]
|
||||
or "thnn" in op["name"]
|
||||
or "slow" in op["name"]
|
||||
):
|
||||
categorization["backend"] += 1
|
||||
op["meta"] = "backend"
|
||||
continue
|
||||
if op['name'] in annotated_ops:
|
||||
categorization['core'] += 1
|
||||
op['meta'] = 'core ' + annotated_ops[op['name']]
|
||||
if op["name"] in annotated_ops:
|
||||
categorization["core"] += 1
|
||||
op["meta"] = "core " + annotated_ops[op["name"]]
|
||||
continue
|
||||
categorization['core'] += 1
|
||||
op['meta'] = 'core unknown'
|
||||
categorization["core"] += 1
|
||||
op["meta"] = "core unknown"
|
||||
return categorization
|
||||
|
||||
annotate_ops(ops, is_unique=False)
|
||||
with open(f"{analysis_name}", 'w') as f:
|
||||
with open(f"{analysis_name}", "w") as f:
|
||||
for op in ops:
|
||||
info = [
|
||||
op['full_name'], op['meta'], op['full_name'] not in noncomposite_ops
|
||||
op["full_name"],
|
||||
op["meta"],
|
||||
op["full_name"] not in noncomposite_ops,
|
||||
] + [check(op) for check in special_op_lists]
|
||||
f.write(','.join([str(i) for i in info]) + '\n')
|
||||
f.write(",".join([str(i) for i in info]) + "\n")
|
||||
|
||||
|
||||
def name_check(lst):
|
||||
return lambda x: x['name'] in lst
|
||||
return lambda x: x["name"] in lst
|
||||
|
||||
|
||||
def full_name_check(lst):
|
||||
return lambda x: x['full_name'] in lst
|
||||
return lambda x: x["full_name"] in lst
|
||||
|
||||
|
||||
# Generates batching rule data
|
||||
gen_data([full_name_check(get_ops_for_key('FuncTorchBatched'))], 'vmap.txt')
|
||||
gen_data([full_name_check(get_ops_for_key("FuncTorchBatched"))], "vmap.txt")
|
||||
|
||||
|
||||
def remove_suffix(input_string, suffix):
|
||||
if suffix and input_string.endswith(suffix):
|
||||
return input_string[:-len(suffix)]
|
||||
return input_string[: -len(suffix)]
|
||||
return input_string
|
||||
|
||||
|
||||
def remove_prefix(input_string, prefix):
|
||||
if prefix and input_string.startswith(prefix):
|
||||
return input_string[len(prefix):]
|
||||
return input_string[len(prefix) :]
|
||||
return input_string
|
||||
|
||||
|
||||
if True:
|
||||
with open('run_ops.txt') as f:
|
||||
opinfo_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
|
||||
with open('count_ops.txt') as f:
|
||||
with open("run_ops.txt") as f:
|
||||
opinfo_ops = [remove_suffix(i.strip(), ".default") for i in f.readlines()]
|
||||
with open("count_ops.txt") as f:
|
||||
opinfo_counts = [i.strip() for i in f.readlines()]
|
||||
opinfo_counts = defaultdict(int, dict(zip(opinfo_ops, opinfo_counts)))
|
||||
|
||||
def count_fn(x):
|
||||
return opinfo_counts[x['full_name']]
|
||||
return opinfo_counts[x["full_name"]]
|
||||
|
||||
with open('run_decompositions.txt') as f:
|
||||
decomposed_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
|
||||
with open("run_decompositions.txt") as f:
|
||||
decomposed_ops = [remove_suffix(i.strip(), ".default") for i in f.readlines()]
|
||||
|
||||
with open('public_api') as f:
|
||||
with open("public_api") as f:
|
||||
ref_api = [i.strip() for i in f.readlines()]
|
||||
|
||||
def has_ref_impl(x):
|
||||
name = x['name']
|
||||
name = x["name"]
|
||||
for prefix in ["linalg_", "special_"]:
|
||||
name = remove_prefix(name, prefix)
|
||||
prefixes = ['nn.functional', 'fft', 'special', 'linalg']
|
||||
return any(f"{prefix}.{name}" in ref_api for prefix in prefixes) or name in ref_api
|
||||
prefixes = ["nn.functional", "fft", "special", "linalg"]
|
||||
return (
|
||||
any(f"{prefix}.{name}" in ref_api for prefix in prefixes) or name in ref_api
|
||||
)
|
||||
|
||||
gen_data([full_name_check(opinfo_ops), full_name_check(decomposed_ops), count_fn, has_ref_impl], 'decompositions.txt')
|
||||
gen_data(
|
||||
[
|
||||
full_name_check(opinfo_ops),
|
||||
full_name_check(decomposed_ops),
|
||||
count_fn,
|
||||
has_ref_impl,
|
||||
],
|
||||
"decompositions.txt",
|
||||
)
|
||||
|
Reference in New Issue
Block a user