diff --git a/.clang-format b/.clang-format index 9f90836e1..38790238f 100755 --- a/.clang-format +++ b/.clang-format @@ -1,155 +1,155 @@ ---- -# Refer to the following link for the explanation of each params: -# http://releases.llvm.org/8.0.0/tools/clang/docs/ClangFormatStyleOptions.html -Language: Cpp -# BasedOnStyle: Google -AccessModifierOffset: -4 -AlignAfterOpenBracket: Align -AlignConsecutiveAssignments: false -AlignConsecutiveDeclarations: false -AlignEscapedNewlines: Left -AlignOperands: true -AlignTrailingComments: true -AllowAllParametersOfDeclarationOnNextLine: false -AllowShortBlocksOnASingleLine: true -AllowShortCaseLabelsOnASingleLine: true -AllowShortFunctionsOnASingleLine: All -AllowShortIfStatementsOnASingleLine: true -AllowShortLoopsOnASingleLine: true -# This is deprecated -AlwaysBreakAfterDefinitionReturnType: None -AlwaysBreakAfterReturnType: None -AlwaysBreakBeforeMultilineStrings: true -AlwaysBreakTemplateDeclarations: true -BinPackArguments: false -BinPackParameters: false -BraceWrapping: - AfterClass: false - AfterControlStatement: false - AfterEnum: false - AfterFunction: false - AfterNamespace: false - AfterObjCDeclaration: false - AfterStruct: false - AfterUnion: false - AfterExternBlock: false - BeforeCatch: false - BeforeElse: false - IndentBraces: false - # disabling the below splits, else, they'll just add to the vertical length of source files! - SplitEmptyFunction: false - SplitEmptyRecord: false - SplitEmptyNamespace: false -BreakBeforeBinaryOperators: None -BreakBeforeBraces: WebKit -BreakBeforeInheritanceComma: false -BreakInheritanceList: BeforeColon -BreakBeforeTernaryOperators: true -BreakConstructorInitializersBeforeComma: false -BreakConstructorInitializers: BeforeColon -BreakAfterJavaFieldAnnotations: false -BreakStringLiterals: true -ColumnLimit: 100 -CommentPragmas: '^ IWYU pragma:' -CompactNamespaces: false -ConstructorInitializerAllOnOneLineOrOnePerLine: true -# Kept the below 2 to be the same as `IndentWidth` to keep everything uniform -ConstructorInitializerIndentWidth: 4 -ContinuationIndentWidth: 4 -Cpp11BracedListStyle: true -DerivePointerAlignment: false -DisableFormat: false -ExperimentalAutoDetectBinPacking: false -FixNamespaceComments: true -ForEachMacros: - - foreach - - Q_FOREACH - - BOOST_FOREACH -IncludeBlocks: Preserve -IncludeCategories: - - Regex: '^' - Priority: 2 - - Regex: '^<.*\.h>' - Priority: 1 - - Regex: '^<.*' - Priority: 2 - - Regex: '.*' - Priority: 3 -IncludeIsMainRegex: '([-_](test|unittest))?$' -IndentCaseLabels: true -IndentPPDirectives: None -IndentWidth: 4 -IndentWrappedFunctionNames: false -JavaScriptQuotes: Leave -JavaScriptWrapImports: true -KeepEmptyLinesAtTheStartOfBlocks: false -MacroBlockBegin: '' -MacroBlockEnd: '' -MaxEmptyLinesToKeep: 1 -NamespaceIndentation: None -ObjCBinPackProtocolList: Never -ObjCBlockIndentWidth: 4 -ObjCSpaceAfterProperty: false -ObjCSpaceBeforeProtocolList: true -PenaltyBreakAssignment: 4 -PenaltyBreakBeforeFirstCallParameter: 1 -PenaltyBreakComment: 300 -PenaltyBreakFirstLessLess: 120 -PenaltyBreakString: 1000 -PenaltyBreakTemplateDeclaration: 10 -PenaltyExcessCharacter: 1000000 -PenaltyReturnTypeOnItsOwnLine: 200 -PointerAlignment: Left -RawStringFormats: - - Language: Cpp - Delimiters: - - cc - - CC - - cpp - - Cpp - - CPP - - 'c++' - - 'C++' - CanonicalDelimiter: '' - - Language: TextProto - Delimiters: - - pb - - PB - - proto - - PROTO - EnclosingFunctions: - - EqualsProto - - EquivToProto - - PARSE_PARTIAL_TEXT_PROTO - - PARSE_TEST_PROTO - - PARSE_TEXT_PROTO - - ParseTextOrDie - - ParseTextProtoOrDie - CanonicalDelimiter: '' - BasedOnStyle: google -# Enabling comment reflow causes doxygen comments to be messed up in their formats! -ReflowComments: true -SortIncludes: true -SortUsingDeclarations: true -SpaceAfterCStyleCast: false -SpaceAfterTemplateKeyword: true -SpaceBeforeAssignmentOperators: true -SpaceBeforeCpp11BracedList: false -SpaceBeforeCtorInitializerColon: true -SpaceBeforeInheritanceColon: true -SpaceBeforeParens: ControlStatements -SpaceBeforeRangeBasedForLoopColon: true -SpaceInEmptyParentheses: false -SpacesBeforeTrailingComments: 2 -SpacesInAngles: false -SpacesInContainerLiterals: true -SpacesInCStyleCastParentheses: false -SpacesInParentheses: false -SpacesInSquareBrackets: false -Standard: Cpp11 -StatementMacros: - - Q_UNUSED - - QT_REQUIRE_VERSION -# Be consistent with indent-width, even for people who use tab for indentation! -TabWidth: 4 -UseTab: Never +--- +# Refer to the following link for the explanation of each params: +# http://releases.llvm.org/8.0.0/tools/clang/docs/ClangFormatStyleOptions.html +Language: Cpp +# BasedOnStyle: Google +AccessModifierOffset: -4 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlines: Left +AlignOperands: true +AlignTrailingComments: true +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: true +AllowShortCaseLabelsOnASingleLine: true +AllowShortFunctionsOnASingleLine: All +AllowShortIfStatementsOnASingleLine: true +AllowShortLoopsOnASingleLine: true +# This is deprecated +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: true +BinPackArguments: false +BinPackParameters: false +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + AfterExternBlock: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false + # disabling the below splits, else, they'll just add to the vertical length of source files! + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: WebKit +BreakBeforeInheritanceComma: false +BreakInheritanceList: BeforeColon +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakConstructorInitializers: BeforeColon +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: true +ColumnLimit: 100 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: true +# Kept the below 2 to be the same as `IndentWidth` to keep everything uniform +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +ForEachMacros: + - foreach + - Q_FOREACH + - BOOST_FOREACH +IncludeBlocks: Preserve +IncludeCategories: + - Regex: '^' + Priority: 2 + - Regex: '^<.*\.h>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IncludeIsMainRegex: '([-_](test|unittest))?$' +IndentCaseLabels: true +IndentPPDirectives: None +IndentWidth: 4 +IndentWrappedFunctionNames: false +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBinPackProtocolList: Never +ObjCBlockIndentWidth: 4 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PenaltyBreakAssignment: 4 +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Left +RawStringFormats: + - Language: Cpp + Delimiters: + - cc + - CC + - cpp + - Cpp + - CPP + - 'c++' + - 'C++' + CanonicalDelimiter: '' + - Language: TextProto + Delimiters: + - pb + - PB + - proto + - PROTO + EnclosingFunctions: + - EqualsProto + - EquivToProto + - PARSE_PARTIAL_TEXT_PROTO + - PARSE_TEST_PROTO + - PARSE_TEXT_PROTO + - ParseTextOrDie + - ParseTextProtoOrDie + CanonicalDelimiter: '' + BasedOnStyle: google +# Enabling comment reflow causes doxygen comments to be messed up in their formats! +ReflowComments: true +SortIncludes: true +SortUsingDeclarations: true +SpaceAfterCStyleCast: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 2 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +StatementMacros: + - Q_UNUSED + - QT_REQUIRE_VERSION +# Be consistent with indent-width, even for people who use tab for indentation! +TabWidth: 4 +UseTab: Never diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting.yml index b23e0910a..4d5628768 100644 --- a/.github/workflows/formatting.yml +++ b/.github/workflows/formatting.yml @@ -1,35 +1,35 @@ -name: Formatting - -on: - push: - branches: - - 'master' - - 'staging**' - pull_request: - branches: - '**' - -jobs: - - # formatting and basic install on cpu-only machine - formatting: - runs-on: ubuntu-20.04 - - steps: - - uses: actions/checkout@v2 - - - name: environment - run: | - which python - python --version - pip install torch==1.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html - python -c "import torch; print('torch:', torch.__version__, torch)" - - - name: Install deepspeed - run: | - pip install .[dev,autotuning] - ds_report - - - name: Formatting checks - run: | - pre-commit run --all-files +name: Formatting + +on: + push: + branches: + - 'master' + - 'staging**' + pull_request: + branches: + '**' + +jobs: + + # formatting and basic install on cpu-only machine + formatting: + runs-on: ubuntu-20.04 + + steps: + - uses: actions/checkout@v2 + + - name: environment + run: | + which python + python --version + pip install torch==1.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + python -c "import torch; print('torch:', torch.__version__, torch)" + + - name: Install deepspeed + run: | + pip install .[dev,autotuning] + ds_report + + - name: Formatting checks + run: | + pre-commit run --all-files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 571105d41..21be46d62 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,13 +4,15 @@ repos: rev: v1.2.3 hooks: - id: trailing-whitespace - exclude: "examples/" + exclude: "DeepSpeedExamples/" - id: check-yaml - exclude: "examples/" + exclude: "DeepSpeedExamples/" - id: end-of-file-fixer - exclude: "examples/" + exclude: "DeepSpeedExamples/" exclude: "docs/CNAME" - + - id: mixed-line-ending + exclude: "DeepSpeedExamples/" + args: [--fix=lf] - repo: https://github.com/pre-commit/mirrors-yapf rev: v0.29.0 diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index c72a5749c..f9ba8cf65 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,9 +1,9 @@ -# Microsoft Open Source Code of Conduct - -This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). - -Resources: - -- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) -- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) -- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns +# Microsoft Open Source Code of Conduct + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). + +Resources: + +- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) +- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns diff --git a/LICENSE b/LICENSE index 3d8b93bc7..9e841e7a2 100644 --- a/LICENSE +++ b/LICENSE @@ -1,21 +1,21 @@ - MIT License - - Copyright (c) Microsoft Corporation. - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE diff --git a/SECURITY.md b/SECURITY.md index 7ab49eb82..e0dfff56a 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,41 +1,41 @@ - - -## Security - -Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). - -If you believe you have found a security vulnerability in any Microsoft-owned repository that meets Microsoft's [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)) of a security vulnerability, please report it to us as described below. - -## Reporting Security Issues - -**Please do not report security vulnerabilities through public GitHub issues.** - -Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). - -If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). - -You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). - -Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: - - * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) - * Full paths of source file(s) related to the manifestation of the issue - * The location of the affected source code (tag/branch/commit or direct URL) - * Any special configuration required to reproduce the issue - * Step-by-step instructions to reproduce the issue - * Proof-of-concept or exploit code (if possible) - * Impact of the issue, including how an attacker might exploit the issue - -This information will help us triage your report more quickly. - -If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. - -## Preferred Languages - -We prefer all communications to be in English. - -## Policy - -Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). - - + + +## Security + +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). + +If you believe you have found a security vulnerability in any Microsoft-owned repository that meets Microsoft's [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)) of a security vulnerability, please report it to us as described below. + +## Reporting Security Issues + +**Please do not report security vulnerabilities through public GitHub issues.** + +Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). + +If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). + +You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). + +Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: + + * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) + * Full paths of source file(s) related to the manifestation of the issue + * The location of the affected source code (tag/branch/commit or direct URL) + * Any special configuration required to reproduce the issue + * Step-by-step instructions to reproduce the issue + * Proof-of-concept or exploit code (if possible) + * Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. + +## Preferred Languages + +We prefer all communications to be in English. + +## Policy + +Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). + + diff --git a/csrc/adagrad/cpu_adagrad.cpp b/csrc/adagrad/cpu_adagrad.cpp index 607072dec..4f2a9b69e 100644 --- a/csrc/adagrad/cpu_adagrad.cpp +++ b/csrc/adagrad/cpu_adagrad.cpp @@ -1,227 +1,227 @@ -#include "cpu_adagrad.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include "cublas_v2.h" -#include "cuda.h" -#include "curand.h" -#include "custom_cuda_layers.h" - -static std::unordered_map> s_optimizers; - -// C++ interface - -void Adagrad_Optimizer::Step_1(float* _params, - float* grads, - float* _exp_avg_sq, - size_t _param_size, - __half* dev_params, - bool half_precision) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<1>( - &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); -#endif - if (_param_size > rounded_size) { - float step_size = -1 * _alpha; - __half* grads_cast_h; - __half* params_cast_h; - if (half_precision) { - grads_cast_h = reinterpret_cast<__half*>(grads); - params_cast_h = reinterpret_cast<__half*>(_params); - } - for (size_t t = rounded_size; t < _param_size; t += TILE) { - size_t copy_size = TILE; - if ((t + TILE) > _param_size) copy_size = _param_size - t; - size_t offset = copy_size + t; - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } -#pragma omp parallel for - for (size_t k = t; k < offset; k++) { - float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; - float param = half_precision ? (float)params_cast_h[k] : _params[k]; - float momentum = grads[k]; - float variance = _exp_avg_sq[k]; - if (_weight_decay > 0) { grad = param * _weight_decay + grad; } - - variance += grad * grad; - - grad = sqrt(variance); - grad += _eps; - grad = momentum / grad; - param = grad * step_size + param; - if (dev_params) _doubled_buffer[_buf_index][k - t] = param; - - if (half_precision) - params_cast_h[k] = (__half)param; - else - _params[k] = param; - // STORE UPDATE TERM TO GRAD'S MEMORY - grads[k] = grad * step_size; - _exp_avg_sq[k] = variance; - } - if (dev_params) { - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]); - _buf_index = !_buf_index; - } - } - } -} - -void Adagrad_Optimizer::Step_4(float* _params, - float* grads, - float* _exp_avg_sq, - size_t _param_size, - __half* dev_params, - bool half_precision) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<4>( - &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); -#endif - if (_param_size > rounded_size) - Step_1((_params + rounded_size), - (grads + rounded_size), - (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), - half_precision); -} - -int create_adagrad_optimizer(int optimizer_id, - float alpha = 1e-2, - float eps = 1e-8, - float weight_decay = 0, - bool should_log = false) -{ - auto opt = std::make_shared(alpha, eps, weight_decay); - - s_optimizers[optimizer_id] = opt; - - if (should_log) { - std::string avx_type = ""; -#if defined(__AVX512__) - avx_type = "AVX512"; -#else -#if defined(__AVX256__) - avx_type = "AVX2"; -#else - avx_type = "scalar"; -#endif -#endif - - printf("Adagrad Optimizer #%d is created with %s arithmetic capability.\n", - optimizer_id, - avx_type.c_str()); - printf("Config: alpha=%f, weight_decay=%f\n", alpha, weight_decay); - } - - return 0; -} - -void Adagrad_Optimizer::Step_8(float* _params, - float* grads, - float* _exp_avg_sq, - size_t _param_size, - __half* dev_params, - bool half_precision) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<8>( - &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); -#endif - if (_param_size > rounded_size) - Step_4((_params + rounded_size), - (grads + rounded_size), - (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), - half_precision); -} - -int ds_adagrad_step(int optimizer_id, - size_t step, - float lr, - float epsilon, - float weight_decay, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg_sq) -{ - auto params_c = params.contiguous(); - auto grads_c = grads.contiguous(); - auto exp_avg_sq_c = exp_avg_sq.contiguous(); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - - std::shared_ptr opt = - std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->IncrementStep(step); - opt->update_state(lr, epsilon, weight_decay); - opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.size(0)); - - opt->SynchronizeStreams(); - return 0; -} - -int ds_adagrad_step_plus_copy(int optimizer_id, - size_t step, - float lr, - float epsilon, - float weight_decay, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg_sq, - torch::Tensor& gpu_params) -{ - auto params_c = params.contiguous(); - auto gpu_params_c = gpu_params.contiguous(); - auto exp_avg_sq_c = exp_avg_sq.contiguous(); - auto grads_c = grads.contiguous(); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - __half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - - std::shared_ptr opt = - std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->IncrementStep(step); - opt->update_state(lr, epsilon, weight_decay); - opt->Step_8(params_ptr, - grads_ptr, - exp_avg_sq_ptr, - params_c.size(0), - gpu_params_ptr, - (params.options().dtype() == at::kHalf)); - - opt->SynchronizeStreams(); - return 0; -} - -int destroy_adagrad_optimizer(int optimizer_id) -{ - s_optimizers.erase(optimizer_id); - - return 0; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("adagrad_update", &ds_adagrad_step, "DeepSpeed CPU Adagrad update (C++)"); - m.def("adagrad_update_copy", - &ds_adagrad_step_plus_copy, - "DeepSpeed CPU Adagrad update and param copy (C++)"); - m.def("create_adagrad", &create_adagrad_optimizer, "DeepSpeed CPU Adagrad (C++)"); - m.def("destroy_adagrad", &destroy_adagrad_optimizer, "DeepSpeed CPU Adagrad destroy (C++)"); -} +#include "cpu_adagrad.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "cublas_v2.h" +#include "cuda.h" +#include "curand.h" +#include "custom_cuda_layers.h" + +static std::unordered_map> s_optimizers; + +// C++ interface + +void Adagrad_Optimizer::Step_1(float* _params, + float* grads, + float* _exp_avg_sq, + size_t _param_size, + __half* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<1>( + &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); +#endif + if (_param_size > rounded_size) { + float step_size = -1 * _alpha; + __half* grads_cast_h; + __half* params_cast_h; + if (half_precision) { + grads_cast_h = reinterpret_cast<__half*>(grads); + params_cast_h = reinterpret_cast<__half*>(_params); + } + for (size_t t = rounded_size; t < _param_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > _param_size) copy_size = _param_size - t; + size_t offset = copy_size + t; + if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } +#pragma omp parallel for + for (size_t k = t; k < offset; k++) { + float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; + float param = half_precision ? (float)params_cast_h[k] : _params[k]; + float momentum = grads[k]; + float variance = _exp_avg_sq[k]; + if (_weight_decay > 0) { grad = param * _weight_decay + grad; } + + variance += grad * grad; + + grad = sqrt(variance); + grad += _eps; + grad = momentum / grad; + param = grad * step_size + param; + if (dev_params) _doubled_buffer[_buf_index][k - t] = param; + + if (half_precision) + params_cast_h[k] = (__half)param; + else + _params[k] = param; + // STORE UPDATE TERM TO GRAD'S MEMORY + grads[k] = grad * step_size; + _exp_avg_sq[k] = variance; + } + if (dev_params) { + launch_param_update( + _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]); + _buf_index = !_buf_index; + } + } + } +} + +void Adagrad_Optimizer::Step_4(float* _params, + float* grads, + float* _exp_avg_sq, + size_t _param_size, + __half* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<4>( + &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); +#endif + if (_param_size > rounded_size) + Step_1((_params + rounded_size), + (grads + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), + half_precision); +} + +int create_adagrad_optimizer(int optimizer_id, + float alpha = 1e-2, + float eps = 1e-8, + float weight_decay = 0, + bool should_log = false) +{ + auto opt = std::make_shared(alpha, eps, weight_decay); + + s_optimizers[optimizer_id] = opt; + + if (should_log) { + std::string avx_type = ""; +#if defined(__AVX512__) + avx_type = "AVX512"; +#else +#if defined(__AVX256__) + avx_type = "AVX2"; +#else + avx_type = "scalar"; +#endif +#endif + + printf("Adagrad Optimizer #%d is created with %s arithmetic capability.\n", + optimizer_id, + avx_type.c_str()); + printf("Config: alpha=%f, weight_decay=%f\n", alpha, weight_decay); + } + + return 0; +} + +void Adagrad_Optimizer::Step_8(float* _params, + float* grads, + float* _exp_avg_sq, + size_t _param_size, + __half* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<8>( + &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); +#endif + if (_param_size > rounded_size) + Step_4((_params + rounded_size), + (grads + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), + half_precision); +} + +int ds_adagrad_step(int optimizer_id, + size_t step, + float lr, + float epsilon, + float weight_decay, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg_sq) +{ + auto params_c = params.contiguous(); + auto grads_c = grads.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + + float* params_ptr = (float*)params_c.data_ptr(); + float* grads_ptr = (float*)grads_c.data_ptr(); + float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); + + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + opt->IncrementStep(step); + opt->update_state(lr, epsilon, weight_decay); + opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.size(0)); + + opt->SynchronizeStreams(); + return 0; +} + +int ds_adagrad_step_plus_copy(int optimizer_id, + size_t step, + float lr, + float epsilon, + float weight_decay, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg_sq, + torch::Tensor& gpu_params) +{ + auto params_c = params.contiguous(); + auto gpu_params_c = gpu_params.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + auto grads_c = grads.contiguous(); + + float* params_ptr = (float*)params_c.data_ptr(); + float* grads_ptr = (float*)grads_c.data_ptr(); + __half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr(); + float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); + + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + opt->IncrementStep(step); + opt->update_state(lr, epsilon, weight_decay); + opt->Step_8(params_ptr, + grads_ptr, + exp_avg_sq_ptr, + params_c.size(0), + gpu_params_ptr, + (params.options().dtype() == at::kHalf)); + + opt->SynchronizeStreams(); + return 0; +} + +int destroy_adagrad_optimizer(int optimizer_id) +{ + s_optimizers.erase(optimizer_id); + + return 0; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("adagrad_update", &ds_adagrad_step, "DeepSpeed CPU Adagrad update (C++)"); + m.def("adagrad_update_copy", + &ds_adagrad_step_plus_copy, + "DeepSpeed CPU Adagrad update and param copy (C++)"); + m.def("create_adagrad", &create_adagrad_optimizer, "DeepSpeed CPU Adagrad (C++)"); + m.def("destroy_adagrad", &destroy_adagrad_optimizer, "DeepSpeed CPU Adagrad destroy (C++)"); +} diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp index b9d993148..727eec818 100644 --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -1,292 +1,292 @@ -#include "cpu_adam.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include "cublas_v2.h" -#include "cuda.h" -#include "curand.h" -#include "custom_cuda_layers.h" - -static std::unordered_map> s_optimizers; - -// C++ interface - -void Adam_Optimizer::Step_1(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - __half* dev_params, - bool half_precision) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<1>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - dev_params, - half_precision); -#endif - if (_param_size > rounded_size) { - float betta1_minus1 = 1 - _betta1; - float betta2_minus1 = 1 - _betta2; - - float step_size = -1 * _alpha / _bias_correction1; - float w_decay = -1 * _alpha * _weight_decay; - __half* grads_cast_h; - __half* params_cast_h; - if (half_precision) { - grads_cast_h = reinterpret_cast<__half*>(grads); - params_cast_h = reinterpret_cast<__half*>(_params); - } - - for (size_t t = rounded_size; t < _param_size; t += TILE) { - size_t copy_size = TILE; - if ((t + TILE) > _param_size) copy_size = _param_size - t; - size_t offset = copy_size + t; - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } - -#pragma omp parallel for - for (size_t k = t; k < offset; k++) { - float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; - float param = half_precision ? (float)params_cast_h[k] : _params[k]; - float momentum = _exp_avg[k]; - float variance = _exp_avg_sq[k]; - if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; } - momentum = momentum * _betta1; - momentum = grad * betta1_minus1 + momentum; - - variance = variance * _betta2; - grad = grad * grad; - variance = grad * betta2_minus1 + variance; - - grad = sqrt(variance); - grad = grad * _bias_correction2 + _eps; - grad = momentum / grad; - if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; } - param = grad * step_size + param; - if (dev_params) _doubled_buffer[_buf_index][k - t] = param; - - if (half_precision) - params_cast_h[k] = (__half)param; - else - _params[k] = param; - _exp_avg[k] = momentum; - _exp_avg_sq[k] = variance; - } - if (dev_params) { - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]); - - _buf_index = !_buf_index; - } - } - } -} - -void Adam_Optimizer::Step_4(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - __half* dev_params, - bool half_precision) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<4>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - dev_params, - half_precision); -#endif - if (_param_size > rounded_size) - Step_1((_params + rounded_size), - (grads + rounded_size), - (_exp_avg + rounded_size), - (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), - half_precision); -} - -int create_adam_optimizer(int optimizer_id, - float alpha = 1e-3, - float betta1 = 0.9, - float betta2 = 0.999, - float eps = 1e-8, - float weight_decay = 0, - bool adamw_mode = true, - bool should_log = false) -{ - auto opt = - std::make_shared(alpha, betta1, betta2, eps, weight_decay, adamw_mode); - - s_optimizers[optimizer_id] = opt; - - if (should_log) { - std::string avx_type = ""; -#if defined(__AVX512__) - avx_type = "AVX512"; -#else -#if defined(__AVX256__) - avx_type = "AVX2"; -#else - avx_type = "scalar"; -#endif -#endif - - printf("Adam Optimizer #%d is created with %s arithmetic capability.\n", - optimizer_id, - avx_type.c_str()); - printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n", - alpha, - betta1, - betta2, - weight_decay, - (int)adamw_mode); - } - - return 0; -} - -void Adam_Optimizer::Step_8(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - __half* dev_params, - bool half_precision) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<8>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - dev_params, - half_precision); -#endif - if (_param_size > rounded_size) - Step_4((_params + rounded_size), - (grads + rounded_size), - (_exp_avg + rounded_size), - (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), - half_precision); -} - -int ds_adam_step(int optimizer_id, - size_t step, - float lr, - float beta1, - float beta2, - float epsilon, - float weight_decay, - bool bias_correction, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg, - torch::Tensor& exp_avg_sq) -{ - auto params_c = params.contiguous(); - auto grads_c = grads.contiguous(); - auto exp_avg_c = exp_avg.contiguous(); - auto exp_avg_sq_c = exp_avg_sq.contiguous(); - - // assert(params.options().dtype() == grads.options().dtype()); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - - std::shared_ptr opt = - std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->IncrementStep(step, beta1, beta2); - opt->update_state(lr, epsilon, weight_decay, bias_correction); - - opt->Step_8(params_ptr, - grads_ptr, - exp_avg_ptr, - exp_avg_sq_ptr, - params_c.size(0), - nullptr, - (params.options().dtype() == at::kHalf)); - - opt->SynchronizeStreams(); - return 0; -} - -int ds_adam_step_plus_copy(int optimizer_id, - size_t step, - float lr, - float beta1, - float beta2, - float epsilon, - float weight_decay, - bool bias_correction, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg, - torch::Tensor& exp_avg_sq, - torch::Tensor& gpu_params) -{ - auto params_c = params.contiguous(); - auto gpu_params_c = gpu_params.contiguous(); - auto exp_avg_c = exp_avg.contiguous(); - auto exp_avg_sq_c = exp_avg_sq.contiguous(); - auto grads_c = grads.contiguous(); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - __half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr(); - float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - - std::shared_ptr opt = - std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->IncrementStep(step, beta1, beta2); - opt->update_state(lr, epsilon, weight_decay, bias_correction); - opt->Step_8(params_ptr, - grads_ptr, - exp_avg_ptr, - exp_avg_sq_ptr, - params_c.size(0), - gpu_params_ptr, - (params.options().dtype() == at::kHalf)); - - opt->SynchronizeStreams(); - return 0; -} - -int destroy_adam_optimizer(int optimizer_id) -{ - s_optimizers.erase(optimizer_id); - - return 0; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)"); - m.def("adam_update_copy", - &ds_adam_step_plus_copy, - "DeepSpeed CPU Adam update and param copy (C++)"); - m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)"); - m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)"); -} +#include "cpu_adam.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "cublas_v2.h" +#include "cuda.h" +#include "curand.h" +#include "custom_cuda_layers.h" + +static std::unordered_map> s_optimizers; + +// C++ interface + +void Adam_Optimizer::Step_1(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + __half* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<1>(&rounded_size, + _params, + grads, + _exp_avg, + _exp_avg_sq, + _param_size, + dev_params, + half_precision); +#endif + if (_param_size > rounded_size) { + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + + float step_size = -1 * _alpha / _bias_correction1; + float w_decay = -1 * _alpha * _weight_decay; + __half* grads_cast_h; + __half* params_cast_h; + if (half_precision) { + grads_cast_h = reinterpret_cast<__half*>(grads); + params_cast_h = reinterpret_cast<__half*>(_params); + } + + for (size_t t = rounded_size; t < _param_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > _param_size) copy_size = _param_size - t; + size_t offset = copy_size + t; + if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } + +#pragma omp parallel for + for (size_t k = t; k < offset; k++) { + float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; + float param = half_precision ? (float)params_cast_h[k] : _params[k]; + float momentum = _exp_avg[k]; + float variance = _exp_avg_sq[k]; + if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; } + momentum = momentum * _betta1; + momentum = grad * betta1_minus1 + momentum; + + variance = variance * _betta2; + grad = grad * grad; + variance = grad * betta2_minus1 + variance; + + grad = sqrt(variance); + grad = grad * _bias_correction2 + _eps; + grad = momentum / grad; + if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; } + param = grad * step_size + param; + if (dev_params) _doubled_buffer[_buf_index][k - t] = param; + + if (half_precision) + params_cast_h[k] = (__half)param; + else + _params[k] = param; + _exp_avg[k] = momentum; + _exp_avg_sq[k] = variance; + } + if (dev_params) { + launch_param_update( + _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]); + + _buf_index = !_buf_index; + } + } + } +} + +void Adam_Optimizer::Step_4(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + __half* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<4>(&rounded_size, + _params, + grads, + _exp_avg, + _exp_avg_sq, + _param_size, + dev_params, + half_precision); +#endif + if (_param_size > rounded_size) + Step_1((_params + rounded_size), + (grads + rounded_size), + (_exp_avg + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), + half_precision); +} + +int create_adam_optimizer(int optimizer_id, + float alpha = 1e-3, + float betta1 = 0.9, + float betta2 = 0.999, + float eps = 1e-8, + float weight_decay = 0, + bool adamw_mode = true, + bool should_log = false) +{ + auto opt = + std::make_shared(alpha, betta1, betta2, eps, weight_decay, adamw_mode); + + s_optimizers[optimizer_id] = opt; + + if (should_log) { + std::string avx_type = ""; +#if defined(__AVX512__) + avx_type = "AVX512"; +#else +#if defined(__AVX256__) + avx_type = "AVX2"; +#else + avx_type = "scalar"; +#endif +#endif + + printf("Adam Optimizer #%d is created with %s arithmetic capability.\n", + optimizer_id, + avx_type.c_str()); + printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n", + alpha, + betta1, + betta2, + weight_decay, + (int)adamw_mode); + } + + return 0; +} + +void Adam_Optimizer::Step_8(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + __half* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<8>(&rounded_size, + _params, + grads, + _exp_avg, + _exp_avg_sq, + _param_size, + dev_params, + half_precision); +#endif + if (_param_size > rounded_size) + Step_4((_params + rounded_size), + (grads + rounded_size), + (_exp_avg + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), + half_precision); +} + +int ds_adam_step(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq) +{ + auto params_c = params.contiguous(); + auto grads_c = grads.contiguous(); + auto exp_avg_c = exp_avg.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + + // assert(params.options().dtype() == grads.options().dtype()); + + float* params_ptr = (float*)params_c.data_ptr(); + float* grads_ptr = (float*)grads_c.data_ptr(); + float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); + float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); + + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + opt->IncrementStep(step, beta1, beta2); + opt->update_state(lr, epsilon, weight_decay, bias_correction); + + opt->Step_8(params_ptr, + grads_ptr, + exp_avg_ptr, + exp_avg_sq_ptr, + params_c.size(0), + nullptr, + (params.options().dtype() == at::kHalf)); + + opt->SynchronizeStreams(); + return 0; +} + +int ds_adam_step_plus_copy(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq, + torch::Tensor& gpu_params) +{ + auto params_c = params.contiguous(); + auto gpu_params_c = gpu_params.contiguous(); + auto exp_avg_c = exp_avg.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + auto grads_c = grads.contiguous(); + + float* params_ptr = (float*)params_c.data_ptr(); + float* grads_ptr = (float*)grads_c.data_ptr(); + __half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr(); + float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); + float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); + + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + opt->IncrementStep(step, beta1, beta2); + opt->update_state(lr, epsilon, weight_decay, bias_correction); + opt->Step_8(params_ptr, + grads_ptr, + exp_avg_ptr, + exp_avg_sq_ptr, + params_c.size(0), + gpu_params_ptr, + (params.options().dtype() == at::kHalf)); + + opt->SynchronizeStreams(); + return 0; +} + +int destroy_adam_optimizer(int optimizer_id) +{ + s_optimizers.erase(optimizer_id); + + return 0; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)"); + m.def("adam_update_copy", + &ds_adam_step_plus_copy, + "DeepSpeed CPU Adam update and param copy (C++)"); + m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)"); + m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)"); +} diff --git a/csrc/aio/common/deepspeed_aio_common.cpp b/csrc/aio/common/deepspeed_aio_common.cpp index 11927969c..9e405d8e7 100644 --- a/csrc/aio/common/deepspeed_aio_common.cpp +++ b/csrc/aio/common/deepspeed_aio_common.cpp @@ -1,333 +1,333 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "deepspeed_aio_common.h" - -using namespace std; -using namespace std::chrono; - -#define DEBUG_DS_AIO_PERF 0 -#define DEBUG_DS_AIO_SUBMIT_PERF 0 - -static const std::string c_library_name = "deepspeed_aio"; - -static void _report_aio_statistics(const char* tag, - const std::vector>& latencies) - __attribute__((unused)); - -static void _report_aio_statistics(const char* tag, - const std::vector>& latencies) -{ - std::vector lat_usec; - for (auto& lat : latencies) { lat_usec.push_back(lat.count() * 1e6); } - const auto min_lat = *(std::min_element(lat_usec.begin(), lat_usec.end())); - const auto max_lat = *(std::max_element(lat_usec.begin(), lat_usec.end())); - const auto avg_lat = std::accumulate(lat_usec.begin(), lat_usec.end(), 0) / lat_usec.size(); - - std::cout << c_library_name << ": latency statistics(usec) " << tag - << " min/max/avg = " << min_lat << " " << max_lat << " " << avg_lat << std::endl; -} - -static void _get_aio_latencies(std::vector>& raw_latencies, - struct deepspeed_aio_latency_t& summary_latencies) -{ - std::vector lat_usec; - for (auto& lat : raw_latencies) { lat_usec.push_back(lat.count() * 1e6); } - summary_latencies._min_usec = *(std::min_element(lat_usec.begin(), lat_usec.end())); - summary_latencies._max_usec = *(std::max_element(lat_usec.begin(), lat_usec.end())); - summary_latencies._avg_usec = - std::accumulate(lat_usec.begin(), lat_usec.end(), 0) / lat_usec.size(); -} - -static void _do_io_submit_singles(const long long int n_iocbs, - const long long int iocb_index, - std::unique_ptr& aio_ctxt, - std::vector>& submit_times) -{ - for (auto i = 0; i < n_iocbs; ++i) { - const auto st = std::chrono::high_resolution_clock::now(); - const auto submit_ret = io_submit(aio_ctxt->_io_ctxt, 1, aio_ctxt->_iocbs.data() + i); - submit_times.push_back(std::chrono::high_resolution_clock::now() - st); -#if DEBUG_DS_AIO_SUBMIT_PERF - printf("submit(usec) %f io_index=%lld buf=%p len=%lu off=%llu \n", - submit_times.back().count() * 1e6, - iocb_index, - aio_ctxt->_iocbs[i]->u.c.buf, - aio_ctxt->_iocbs[i]->u.c.nbytes, - aio_ctxt->_iocbs[i]->u.c.offset); -#endif - assert(submit_ret > 0); - } -} - -static void _do_io_submit_block(const long long int n_iocbs, - const long long int iocb_index, - std::unique_ptr& aio_ctxt, - std::vector>& submit_times) -{ - const auto st = std::chrono::high_resolution_clock::now(); - const auto submit_ret = io_submit(aio_ctxt->_io_ctxt, n_iocbs, aio_ctxt->_iocbs.data()); - submit_times.push_back(std::chrono::high_resolution_clock::now() - st); -#if DEBUG_DS_AIO_SUBMIT_PERF - printf("submit(usec) %f io_index=%lld nr=%lld buf=%p len=%lu off=%llu \n", - submit_times.back().count() * 1e6, - iocb_index, - n_iocbs, - aio_ctxt->_iocbs[0]->u.c.buf, - aio_ctxt->_iocbs[0]->u.c.nbytes, - aio_ctxt->_iocbs[0]->u.c.offset); -#endif - assert(submit_ret > 0); -} - -static int _do_io_complete(const long long int min_completes, - const long long int max_completes, - std::unique_ptr& aio_ctxt, - std::vector>& reap_times) -{ - const auto start_time = std::chrono::high_resolution_clock::now(); - const auto n_completes = io_getevents( - aio_ctxt->_io_ctxt, min_completes, max_completes, aio_ctxt->_io_events.data(), nullptr); - reap_times.push_back(std::chrono::high_resolution_clock::now() - start_time); - - assert(n_completes >= min_completes); - return n_completes; -} - -void do_aio_operation_sequential(const bool read_op, - std::unique_ptr& aio_ctxt, - std::unique_ptr& xfer_ctxt, - deepspeed_aio_config_t* config, - deepspeed_aio_perf_t* perf) -{ - struct io_prep_context prep_ctxt(read_op, xfer_ctxt, aio_ctxt->_block_size, &aio_ctxt->_iocbs); - - const auto num_io_blocks = static_cast( - ceil(static_cast(xfer_ctxt->_num_bytes) / aio_ctxt->_block_size)); -#if DEBUG_DS_AIO_PERF - const auto io_op_name = std::string(read_op ? "read" : "write"); - std::cout << c_library_name << ": start " << io_op_name << " " << xfer_ctxt->_num_bytes - << " bytes with " << num_io_blocks << " io blocks" << std::endl; -#endif - - std::vector> submit_times; - std::vector> reap_times; - const auto max_queue_bytes = - static_cast(aio_ctxt->_queue_depth * aio_ctxt->_block_size); - - auto start = std::chrono::high_resolution_clock::now(); - for (long long iocb_index = 0; iocb_index < num_io_blocks; - iocb_index += aio_ctxt->_queue_depth) { - const auto start_offset = iocb_index * aio_ctxt->_block_size; - const auto start_buffer = (char*)xfer_ctxt->_mem_buffer + start_offset; - const auto n_iocbs = - min(static_cast(aio_ctxt->_queue_depth), (num_io_blocks - iocb_index)); - const auto num_bytes = min(max_queue_bytes, (xfer_ctxt->_num_bytes - start_offset)); - prep_ctxt.prep_iocbs(n_iocbs, num_bytes, start_buffer, start_offset); - - if (config->_single_submit) { - _do_io_submit_singles(n_iocbs, iocb_index, aio_ctxt, submit_times); - } else { - _do_io_submit_block(n_iocbs, iocb_index, aio_ctxt, submit_times); - } - - _do_io_complete(n_iocbs, n_iocbs, aio_ctxt, reap_times); - } - const std::chrono::duration elapsed = std::chrono::high_resolution_clock::now() - start; - - if (perf) { - _get_aio_latencies(submit_times, perf->_submit); - _get_aio_latencies(reap_times, perf->_complete); - perf->_e2e_usec = elapsed.count() * 1e6; - perf->_e2e_rate_GB = (xfer_ctxt->_num_bytes / elapsed.count() / 1e9); - } - -#if DEBUG_DS_AIO_PERF - _report_aio_statistics("submit", submit_times); - _report_aio_statistics("complete", reap_times); -#endif - -#if DEBUG_DS_AIO_PERF - std::cout << c_library_name << ": runtime(usec) " << elapsed.count() * 1e6 - << " rate(GB/sec) = " << (xfer_ctxt->_num_bytes / elapsed.count() / 1e9) << std::endl; -#endif - -#if DEBUG_DS_AIO_PERF - std::cout << c_library_name << ": finish " << io_op_name << " " << xfer_ctxt->_num_bytes - << " bytes " << std::endl; -#endif -} - -void do_aio_operation_overlap(const bool read_op, - std::unique_ptr& aio_ctxt, - std::unique_ptr& xfer_ctxt, - deepspeed_aio_config_t* config, - deepspeed_aio_perf_t* perf) -{ - struct io_prep_generator io_gen(read_op, xfer_ctxt, aio_ctxt->_block_size); - -#if DEBUG_DS_AIO_PERF - const auto io_op_name = std::string(read_op ? "read" : "write"); - std::cout << c_library_name << ": start " << io_op_name << " " << xfer_ctxt->_num_bytes - << " bytes with " << io_gen._num_io_blocks << " io blocks" << std::endl; -#endif - - std::vector> submit_times; - std::vector> reap_times; - - auto request_iocbs = aio_ctxt->_queue_depth; - auto n_pending_iocbs = 0; - const auto min_completes = 1; - auto start = std::chrono::high_resolution_clock::now(); - while (true) { - const auto n_iocbs = io_gen.prep_iocbs(request_iocbs - n_pending_iocbs, &aio_ctxt->_iocbs); - if (n_iocbs > 0) { - if (config->_single_submit) { - _do_io_submit_singles( - n_iocbs, (io_gen._next_iocb_index - n_iocbs), aio_ctxt, submit_times); - } else { - _do_io_submit_block( - n_iocbs, (io_gen._next_iocb_index - n_iocbs), aio_ctxt, submit_times); - } - } - - n_pending_iocbs += n_iocbs; - assert(n_pending_iocbs <= aio_ctxt->_queue_depth); - - if (n_pending_iocbs == 0) { break; } - - const auto n_complete = - _do_io_complete(min_completes, n_pending_iocbs, aio_ctxt, reap_times); - n_pending_iocbs -= n_complete; - } - - const std::chrono::duration elapsed = std::chrono::high_resolution_clock::now() - start; - - if (perf) { - _get_aio_latencies(submit_times, perf->_submit); - _get_aio_latencies(reap_times, perf->_complete); - perf->_e2e_usec = elapsed.count() * 1e6; - perf->_e2e_rate_GB = (xfer_ctxt->_num_bytes / elapsed.count() / 1e9); - } - -#if DEBUG_DS_AIO_PERF - _report_aio_statistics("submit", submit_times); - _report_aio_statistics("complete", reap_times); -#endif - -#if DEBUG_DS_AIO_PERF - std::cout << c_library_name << ": runtime(usec) " << elapsed.count() * 1e6 - << " rate(GB/sec) = " << (xfer_ctxt->_num_bytes / elapsed.count() / 1e9) << std::endl; -#endif - -#if DEBUG_DS_AIO_PERF - std::cout << c_library_name << ": finish " << io_op_name << " " << xfer_ctxt->_num_bytes - << " bytes " << std::endl; -#endif -} - -void report_file_error(const char* filename, const std::string file_op, const int error_code) -{ - std::string err_msg = file_op + std::string(" failed on ") + std::string(filename) + - " error = " + std::to_string(error_code); - std::cerr << c_library_name << ": " << err_msg << std::endl; -} - -int open_file(const char* filename, const bool read_op) -{ - const int flags = read_op ? (O_RDONLY | __O_DIRECT) : (O_WRONLY | O_CREAT | __O_DIRECT); - const int mode = 0600; - const auto fd = open(filename, flags, mode); - if (fd == -1) { - const auto error_code = errno; - const auto error_msg = read_op ? " open for read " : " open for write "; - report_file_error(filename, error_msg, error_code); - return -1; - } - return fd; -} - -int regular_read(const char* filename, std::vector& buffer) -{ - long long int num_bytes; - const auto f_size = get_file_size(filename, num_bytes); - assert(f_size != -1); - buffer.resize(num_bytes); - const auto fd = open(filename, O_RDONLY, 0600); - assert(fd != -1); - long long int read_bytes = 0; - auto r = 0; - do { - const auto buffer_ptr = buffer.data() + read_bytes; - const auto bytes_to_read = num_bytes - read_bytes; - r = read(fd, buffer_ptr, bytes_to_read); - read_bytes += r; - } while (r > 0); - - if (read_bytes != num_bytes) { - std::cerr << "read error " - << " read_bytes (read) = " << read_bytes << " num_bytes (fstat) = " << num_bytes - << std::endl; - } - assert(read_bytes == num_bytes); - close(fd); - return 0; -} - -static bool _validate_buffer(const char* filename, void* aio_buffer, const long long int num_bytes) -{ - std::vector regular_buffer; - const auto reg_ret = regular_read(filename, regular_buffer); - assert(0 == reg_ret); - std::cout << "regular read of " << filename << " returned " << regular_buffer.size() << " bytes" - << std::endl; - - if (static_cast(regular_buffer.size()) != num_bytes) { return false; } - - return (0 == memcmp(aio_buffer, regular_buffer.data(), regular_buffer.size())); -} - -bool validate_aio_operation(const bool read_op, - const char* filename, - void* aio_buffer, - const long long int num_bytes) -{ - const auto msg_suffix = std::string("deepspeed_aio_") + - std::string(read_op ? "read()" : "write()") + - std::string("using read()"); - - if (false == _validate_buffer(filename, aio_buffer, num_bytes)) { - std::cout << "Fail: correctness of " << msg_suffix << std::endl; - return false; - } - - std::cout << "Pass: correctness of " << msg_suffix << std::endl; - return true; -} +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "deepspeed_aio_common.h" + +using namespace std; +using namespace std::chrono; + +#define DEBUG_DS_AIO_PERF 0 +#define DEBUG_DS_AIO_SUBMIT_PERF 0 + +static const std::string c_library_name = "deepspeed_aio"; + +static void _report_aio_statistics(const char* tag, + const std::vector>& latencies) + __attribute__((unused)); + +static void _report_aio_statistics(const char* tag, + const std::vector>& latencies) +{ + std::vector lat_usec; + for (auto& lat : latencies) { lat_usec.push_back(lat.count() * 1e6); } + const auto min_lat = *(std::min_element(lat_usec.begin(), lat_usec.end())); + const auto max_lat = *(std::max_element(lat_usec.begin(), lat_usec.end())); + const auto avg_lat = std::accumulate(lat_usec.begin(), lat_usec.end(), 0) / lat_usec.size(); + + std::cout << c_library_name << ": latency statistics(usec) " << tag + << " min/max/avg = " << min_lat << " " << max_lat << " " << avg_lat << std::endl; +} + +static void _get_aio_latencies(std::vector>& raw_latencies, + struct deepspeed_aio_latency_t& summary_latencies) +{ + std::vector lat_usec; + for (auto& lat : raw_latencies) { lat_usec.push_back(lat.count() * 1e6); } + summary_latencies._min_usec = *(std::min_element(lat_usec.begin(), lat_usec.end())); + summary_latencies._max_usec = *(std::max_element(lat_usec.begin(), lat_usec.end())); + summary_latencies._avg_usec = + std::accumulate(lat_usec.begin(), lat_usec.end(), 0) / lat_usec.size(); +} + +static void _do_io_submit_singles(const long long int n_iocbs, + const long long int iocb_index, + std::unique_ptr& aio_ctxt, + std::vector>& submit_times) +{ + for (auto i = 0; i < n_iocbs; ++i) { + const auto st = std::chrono::high_resolution_clock::now(); + const auto submit_ret = io_submit(aio_ctxt->_io_ctxt, 1, aio_ctxt->_iocbs.data() + i); + submit_times.push_back(std::chrono::high_resolution_clock::now() - st); +#if DEBUG_DS_AIO_SUBMIT_PERF + printf("submit(usec) %f io_index=%lld buf=%p len=%lu off=%llu \n", + submit_times.back().count() * 1e6, + iocb_index, + aio_ctxt->_iocbs[i]->u.c.buf, + aio_ctxt->_iocbs[i]->u.c.nbytes, + aio_ctxt->_iocbs[i]->u.c.offset); +#endif + assert(submit_ret > 0); + } +} + +static void _do_io_submit_block(const long long int n_iocbs, + const long long int iocb_index, + std::unique_ptr& aio_ctxt, + std::vector>& submit_times) +{ + const auto st = std::chrono::high_resolution_clock::now(); + const auto submit_ret = io_submit(aio_ctxt->_io_ctxt, n_iocbs, aio_ctxt->_iocbs.data()); + submit_times.push_back(std::chrono::high_resolution_clock::now() - st); +#if DEBUG_DS_AIO_SUBMIT_PERF + printf("submit(usec) %f io_index=%lld nr=%lld buf=%p len=%lu off=%llu \n", + submit_times.back().count() * 1e6, + iocb_index, + n_iocbs, + aio_ctxt->_iocbs[0]->u.c.buf, + aio_ctxt->_iocbs[0]->u.c.nbytes, + aio_ctxt->_iocbs[0]->u.c.offset); +#endif + assert(submit_ret > 0); +} + +static int _do_io_complete(const long long int min_completes, + const long long int max_completes, + std::unique_ptr& aio_ctxt, + std::vector>& reap_times) +{ + const auto start_time = std::chrono::high_resolution_clock::now(); + const auto n_completes = io_getevents( + aio_ctxt->_io_ctxt, min_completes, max_completes, aio_ctxt->_io_events.data(), nullptr); + reap_times.push_back(std::chrono::high_resolution_clock::now() - start_time); + + assert(n_completes >= min_completes); + return n_completes; +} + +void do_aio_operation_sequential(const bool read_op, + std::unique_ptr& aio_ctxt, + std::unique_ptr& xfer_ctxt, + deepspeed_aio_config_t* config, + deepspeed_aio_perf_t* perf) +{ + struct io_prep_context prep_ctxt(read_op, xfer_ctxt, aio_ctxt->_block_size, &aio_ctxt->_iocbs); + + const auto num_io_blocks = static_cast( + ceil(static_cast(xfer_ctxt->_num_bytes) / aio_ctxt->_block_size)); +#if DEBUG_DS_AIO_PERF + const auto io_op_name = std::string(read_op ? "read" : "write"); + std::cout << c_library_name << ": start " << io_op_name << " " << xfer_ctxt->_num_bytes + << " bytes with " << num_io_blocks << " io blocks" << std::endl; +#endif + + std::vector> submit_times; + std::vector> reap_times; + const auto max_queue_bytes = + static_cast(aio_ctxt->_queue_depth * aio_ctxt->_block_size); + + auto start = std::chrono::high_resolution_clock::now(); + for (long long iocb_index = 0; iocb_index < num_io_blocks; + iocb_index += aio_ctxt->_queue_depth) { + const auto start_offset = iocb_index * aio_ctxt->_block_size; + const auto start_buffer = (char*)xfer_ctxt->_mem_buffer + start_offset; + const auto n_iocbs = + min(static_cast(aio_ctxt->_queue_depth), (num_io_blocks - iocb_index)); + const auto num_bytes = min(max_queue_bytes, (xfer_ctxt->_num_bytes - start_offset)); + prep_ctxt.prep_iocbs(n_iocbs, num_bytes, start_buffer, start_offset); + + if (config->_single_submit) { + _do_io_submit_singles(n_iocbs, iocb_index, aio_ctxt, submit_times); + } else { + _do_io_submit_block(n_iocbs, iocb_index, aio_ctxt, submit_times); + } + + _do_io_complete(n_iocbs, n_iocbs, aio_ctxt, reap_times); + } + const std::chrono::duration elapsed = std::chrono::high_resolution_clock::now() - start; + + if (perf) { + _get_aio_latencies(submit_times, perf->_submit); + _get_aio_latencies(reap_times, perf->_complete); + perf->_e2e_usec = elapsed.count() * 1e6; + perf->_e2e_rate_GB = (xfer_ctxt->_num_bytes / elapsed.count() / 1e9); + } + +#if DEBUG_DS_AIO_PERF + _report_aio_statistics("submit", submit_times); + _report_aio_statistics("complete", reap_times); +#endif + +#if DEBUG_DS_AIO_PERF + std::cout << c_library_name << ": runtime(usec) " << elapsed.count() * 1e6 + << " rate(GB/sec) = " << (xfer_ctxt->_num_bytes / elapsed.count() / 1e9) << std::endl; +#endif + +#if DEBUG_DS_AIO_PERF + std::cout << c_library_name << ": finish " << io_op_name << " " << xfer_ctxt->_num_bytes + << " bytes " << std::endl; +#endif +} + +void do_aio_operation_overlap(const bool read_op, + std::unique_ptr& aio_ctxt, + std::unique_ptr& xfer_ctxt, + deepspeed_aio_config_t* config, + deepspeed_aio_perf_t* perf) +{ + struct io_prep_generator io_gen(read_op, xfer_ctxt, aio_ctxt->_block_size); + +#if DEBUG_DS_AIO_PERF + const auto io_op_name = std::string(read_op ? "read" : "write"); + std::cout << c_library_name << ": start " << io_op_name << " " << xfer_ctxt->_num_bytes + << " bytes with " << io_gen._num_io_blocks << " io blocks" << std::endl; +#endif + + std::vector> submit_times; + std::vector> reap_times; + + auto request_iocbs = aio_ctxt->_queue_depth; + auto n_pending_iocbs = 0; + const auto min_completes = 1; + auto start = std::chrono::high_resolution_clock::now(); + while (true) { + const auto n_iocbs = io_gen.prep_iocbs(request_iocbs - n_pending_iocbs, &aio_ctxt->_iocbs); + if (n_iocbs > 0) { + if (config->_single_submit) { + _do_io_submit_singles( + n_iocbs, (io_gen._next_iocb_index - n_iocbs), aio_ctxt, submit_times); + } else { + _do_io_submit_block( + n_iocbs, (io_gen._next_iocb_index - n_iocbs), aio_ctxt, submit_times); + } + } + + n_pending_iocbs += n_iocbs; + assert(n_pending_iocbs <= aio_ctxt->_queue_depth); + + if (n_pending_iocbs == 0) { break; } + + const auto n_complete = + _do_io_complete(min_completes, n_pending_iocbs, aio_ctxt, reap_times); + n_pending_iocbs -= n_complete; + } + + const std::chrono::duration elapsed = std::chrono::high_resolution_clock::now() - start; + + if (perf) { + _get_aio_latencies(submit_times, perf->_submit); + _get_aio_latencies(reap_times, perf->_complete); + perf->_e2e_usec = elapsed.count() * 1e6; + perf->_e2e_rate_GB = (xfer_ctxt->_num_bytes / elapsed.count() / 1e9); + } + +#if DEBUG_DS_AIO_PERF + _report_aio_statistics("submit", submit_times); + _report_aio_statistics("complete", reap_times); +#endif + +#if DEBUG_DS_AIO_PERF + std::cout << c_library_name << ": runtime(usec) " << elapsed.count() * 1e6 + << " rate(GB/sec) = " << (xfer_ctxt->_num_bytes / elapsed.count() / 1e9) << std::endl; +#endif + +#if DEBUG_DS_AIO_PERF + std::cout << c_library_name << ": finish " << io_op_name << " " << xfer_ctxt->_num_bytes + << " bytes " << std::endl; +#endif +} + +void report_file_error(const char* filename, const std::string file_op, const int error_code) +{ + std::string err_msg = file_op + std::string(" failed on ") + std::string(filename) + + " error = " + std::to_string(error_code); + std::cerr << c_library_name << ": " << err_msg << std::endl; +} + +int open_file(const char* filename, const bool read_op) +{ + const int flags = read_op ? (O_RDONLY | __O_DIRECT) : (O_WRONLY | O_CREAT | __O_DIRECT); + const int mode = 0600; + const auto fd = open(filename, flags, mode); + if (fd == -1) { + const auto error_code = errno; + const auto error_msg = read_op ? " open for read " : " open for write "; + report_file_error(filename, error_msg, error_code); + return -1; + } + return fd; +} + +int regular_read(const char* filename, std::vector& buffer) +{ + long long int num_bytes; + const auto f_size = get_file_size(filename, num_bytes); + assert(f_size != -1); + buffer.resize(num_bytes); + const auto fd = open(filename, O_RDONLY, 0600); + assert(fd != -1); + long long int read_bytes = 0; + auto r = 0; + do { + const auto buffer_ptr = buffer.data() + read_bytes; + const auto bytes_to_read = num_bytes - read_bytes; + r = read(fd, buffer_ptr, bytes_to_read); + read_bytes += r; + } while (r > 0); + + if (read_bytes != num_bytes) { + std::cerr << "read error " + << " read_bytes (read) = " << read_bytes << " num_bytes (fstat) = " << num_bytes + << std::endl; + } + assert(read_bytes == num_bytes); + close(fd); + return 0; +} + +static bool _validate_buffer(const char* filename, void* aio_buffer, const long long int num_bytes) +{ + std::vector regular_buffer; + const auto reg_ret = regular_read(filename, regular_buffer); + assert(0 == reg_ret); + std::cout << "regular read of " << filename << " returned " << regular_buffer.size() << " bytes" + << std::endl; + + if (static_cast(regular_buffer.size()) != num_bytes) { return false; } + + return (0 == memcmp(aio_buffer, regular_buffer.data(), regular_buffer.size())); +} + +bool validate_aio_operation(const bool read_op, + const char* filename, + void* aio_buffer, + const long long int num_bytes) +{ + const auto msg_suffix = std::string("deepspeed_aio_") + + std::string(read_op ? "read()" : "write()") + + std::string("using read()"); + + if (false == _validate_buffer(filename, aio_buffer, num_bytes)) { + std::cout << "Fail: correctness of " << msg_suffix << std::endl; + return false; + } + + std::cout << "Pass: correctness of " << msg_suffix << std::endl; + return true; +} diff --git a/csrc/aio/common/deepspeed_aio_common.h b/csrc/aio/common/deepspeed_aio_common.h index 1f32fc8f7..cc62d3376 100644 --- a/csrc/aio/common/deepspeed_aio_common.h +++ b/csrc/aio/common/deepspeed_aio_common.h @@ -1,36 +1,36 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include -#include -#include -#include - -using namespace std; - -void do_aio_operation_sequential(const bool read_op, - std::unique_ptr& aio_ctxt, - std::unique_ptr& xfer_ctxt, - deepspeed_aio_config_t* config, - deepspeed_aio_perf_t* perf); - -void do_aio_operation_overlap(const bool read_op, - std::unique_ptr& aio_ctxt, - std::unique_ptr& xfer_ctxt, - deepspeed_aio_config_t* config, - deepspeed_aio_perf_t* perf); - -int open_file(const char* filename, const bool read_op); - -void report_file_error(const char* filename, const std::string file_op, const int error_code); - -int regular_read(const char* filename, std::vector& buffer); - -bool validate_aio_operation(const bool read_op, - const char* filename, - void* aio_buffer, - const long long int num_bytes); +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include +#include +#include + +using namespace std; + +void do_aio_operation_sequential(const bool read_op, + std::unique_ptr& aio_ctxt, + std::unique_ptr& xfer_ctxt, + deepspeed_aio_config_t* config, + deepspeed_aio_perf_t* perf); + +void do_aio_operation_overlap(const bool read_op, + std::unique_ptr& aio_ctxt, + std::unique_ptr& xfer_ctxt, + deepspeed_aio_config_t* config, + deepspeed_aio_perf_t* perf); + +int open_file(const char* filename, const bool read_op); + +void report_file_error(const char* filename, const std::string file_op, const int error_code); + +int regular_read(const char* filename, std::vector& buffer); + +bool validate_aio_operation(const bool read_op, + const char* filename, + void* aio_buffer, + const long long int num_bytes); diff --git a/csrc/aio/common/deepspeed_aio_types.cpp b/csrc/aio/common/deepspeed_aio_types.cpp index 5f717c3b5..e5811bb91 100644 --- a/csrc/aio/common/deepspeed_aio_types.cpp +++ b/csrc/aio/common/deepspeed_aio_types.cpp @@ -1,74 +1,74 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include - -#include "deepspeed_aio_utils.h" - -using namespace std; - -const int c_block_size = 128 * 1024; -const int c_io_queue_depth = 8; - -deepspeed_aio_config_t::deepspeed_aio_config_t() - : _block_size(c_block_size), - _queue_depth(c_io_queue_depth), - _single_submit(false), - _overlap_events(false), - _lock_memory(false) -{ -} - -deepspeed_aio_config_t::deepspeed_aio_config_t(const int block_size, - const int queue_depth, - const bool single_submit, - const bool overlap_events, - const bool lock_memory) - : _block_size(block_size), - _queue_depth(queue_depth), - _single_submit(single_submit), - _overlap_events(overlap_events), - _lock_memory(lock_memory) -{ -} - -void deepspeed_aio_latency_t::dump(const std::string tag) -{ - std::cout << tag << _min_usec << " " << _max_usec << " " << _avg_usec << " " << std::endl; -} - -void deepspeed_aio_latency_t::accumulate(const struct deepspeed_aio_latency_t& other) -{ - _min_usec += other._min_usec; - _max_usec += other._max_usec; - _avg_usec += other._avg_usec; -} - -void deepspeed_aio_latency_t::scale(const float scaler) -{ - _min_usec *= scaler; - _max_usec *= scaler; - _avg_usec *= scaler; -} - -aio_context::aio_context(const int block_size, const int queue_depth) -{ - _block_size = block_size; - _queue_depth = queue_depth; - for (auto i = 0; i < queue_depth; ++i) { - _iocbs.push_back((struct iocb*)calloc(1, sizeof(struct iocb))); - } - _io_events.resize(queue_depth); - io_queue_init(queue_depth, &_io_ctxt); -} - -aio_context::~aio_context() -{ - for (auto& iocb : _iocbs) { free(iocb); } - _io_events.resize(0); - io_queue_release(_io_ctxt); -} +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include + +#include "deepspeed_aio_utils.h" + +using namespace std; + +const int c_block_size = 128 * 1024; +const int c_io_queue_depth = 8; + +deepspeed_aio_config_t::deepspeed_aio_config_t() + : _block_size(c_block_size), + _queue_depth(c_io_queue_depth), + _single_submit(false), + _overlap_events(false), + _lock_memory(false) +{ +} + +deepspeed_aio_config_t::deepspeed_aio_config_t(const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const bool lock_memory) + : _block_size(block_size), + _queue_depth(queue_depth), + _single_submit(single_submit), + _overlap_events(overlap_events), + _lock_memory(lock_memory) +{ +} + +void deepspeed_aio_latency_t::dump(const std::string tag) +{ + std::cout << tag << _min_usec << " " << _max_usec << " " << _avg_usec << " " << std::endl; +} + +void deepspeed_aio_latency_t::accumulate(const struct deepspeed_aio_latency_t& other) +{ + _min_usec += other._min_usec; + _max_usec += other._max_usec; + _avg_usec += other._avg_usec; +} + +void deepspeed_aio_latency_t::scale(const float scaler) +{ + _min_usec *= scaler; + _max_usec *= scaler; + _avg_usec *= scaler; +} + +aio_context::aio_context(const int block_size, const int queue_depth) +{ + _block_size = block_size; + _queue_depth = queue_depth; + for (auto i = 0; i < queue_depth; ++i) { + _iocbs.push_back((struct iocb*)calloc(1, sizeof(struct iocb))); + } + _io_events.resize(queue_depth); + io_queue_init(queue_depth, &_io_ctxt); +} + +aio_context::~aio_context() +{ + for (auto& iocb : _iocbs) { free(iocb); } + _io_events.resize(0); + io_queue_release(_io_ctxt); +} diff --git a/csrc/aio/common/deepspeed_aio_types.h b/csrc/aio/common/deepspeed_aio_types.h index 5c5dcdf0b..be3b352d6 100644 --- a/csrc/aio/common/deepspeed_aio_types.h +++ b/csrc/aio/common/deepspeed_aio_types.h @@ -1,57 +1,57 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include -#include - -#include -#include - -using namespace std; - -struct deepspeed_aio_latency_t { - double _min_usec; - double _max_usec; - double _avg_usec; - - void dump(const std::string tag); - void accumulate(const deepspeed_aio_latency_t&); - void scale(const float value); -}; - -struct deepspeed_aio_perf_t { - deepspeed_aio_latency_t _submit; - deepspeed_aio_latency_t _complete; - double _e2e_usec; - double _e2e_rate_GB; -}; - -struct deepspeed_aio_config_t { - const int _block_size; - const int _queue_depth; - const bool _single_submit; - const bool _overlap_events; - const bool _lock_memory; - - deepspeed_aio_config_t(); - deepspeed_aio_config_t(const int block_size, - const int queue_depth, - const bool single_submit, - const bool overlap_events, - const bool lock_memory); -}; - -struct aio_context { - io_context_t _io_ctxt; - std::vector _io_events; - std::vector _iocbs; - int _block_size; - int _queue_depth; - - aio_context(const int block_size, const int queue_depth); - ~aio_context(); -}; +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include + +#include +#include + +using namespace std; + +struct deepspeed_aio_latency_t { + double _min_usec; + double _max_usec; + double _avg_usec; + + void dump(const std::string tag); + void accumulate(const deepspeed_aio_latency_t&); + void scale(const float value); +}; + +struct deepspeed_aio_perf_t { + deepspeed_aio_latency_t _submit; + deepspeed_aio_latency_t _complete; + double _e2e_usec; + double _e2e_rate_GB; +}; + +struct deepspeed_aio_config_t { + const int _block_size; + const int _queue_depth; + const bool _single_submit; + const bool _overlap_events; + const bool _lock_memory; + + deepspeed_aio_config_t(); + deepspeed_aio_config_t(const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const bool lock_memory); +}; + +struct aio_context { + io_context_t _io_ctxt; + std::vector _io_events; + std::vector _iocbs; + int _block_size; + int _queue_depth; + + aio_context(const int block_size, const int queue_depth); + ~aio_context(); +}; diff --git a/csrc/aio/common/deepspeed_aio_utils.cpp b/csrc/aio/common/deepspeed_aio_utils.cpp index a3d89be5a..200c7030f 100644 --- a/csrc/aio/common/deepspeed_aio_utils.cpp +++ b/csrc/aio/common/deepspeed_aio_utils.cpp @@ -1,123 +1,123 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include - -#include "deepspeed_aio_utils.h" - -using namespace std; - -const int c_block_size = 128 * 1024; -const int c_io_queue_depth = 8; - -io_xfer_ctxt::io_xfer_ctxt(const int fd, - const long long int file_offset, - const long long int num_bytes, - const void* buffer) - : _fd(fd), _base_offset(file_offset), _mem_buffer(buffer), _num_bytes(num_bytes) -{ -} - -io_prep_context::io_prep_context(const bool read_op, - const std::unique_ptr& xfer_ctxt, - const size_t block_size, - const std::vector* iocbs) - : _read_op(read_op), _xfer_ctxt(xfer_ctxt), _block_size(block_size), _iocbs(iocbs) -{ -} - -void io_prep_context::prep_iocbs(const int n_iocbs, - const size_t num_bytes, - const void* start_buffer, - const long long int start_offset) -{ - assert(static_cast(n_iocbs) <= _iocbs->size()); - for (auto i = 0; i < n_iocbs; ++i) { - const auto shift = i * _block_size; - const auto xfer_buffer = (char*)start_buffer + _xfer_ctxt->_base_offset + shift; - const auto xfer_offset = _xfer_ctxt->_base_offset + start_offset + shift; - auto byte_count = _block_size; - if ((shift + _block_size) > num_bytes) { byte_count = num_bytes - shift; } - - if (_read_op) { - io_prep_pread(_iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, byte_count, xfer_offset); - } else { - io_prep_pwrite(_iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, byte_count, xfer_offset); - } - } -} - -io_prep_generator::io_prep_generator(const bool read_op, - const std::unique_ptr& xfer_ctxt, - const size_t block_size) - : _read_op(read_op), - _xfer_ctxt(xfer_ctxt), - _block_size(block_size), - _remaining_bytes(xfer_ctxt->_num_bytes), - _next_iocb_index(0) -{ - _num_io_blocks = - static_cast(ceil(static_cast(xfer_ctxt->_num_bytes) / block_size)); - _remaining_io_blocks = _num_io_blocks; -} - -int io_prep_generator::prep_iocbs(const int n_iocbs, std::vector* iocbs) -{ - if ((_remaining_bytes) == 0 || (_remaining_io_blocks == 0)) { - assert(static_cast(_remaining_bytes) == _remaining_io_blocks); - return 0; - } - - assert(static_cast(n_iocbs) <= iocbs->size()); - - auto actual_n_iocbs = min(static_cast(n_iocbs), _remaining_io_blocks); - for (auto i = 0; i < actual_n_iocbs; ++i, ++_next_iocb_index) { - const auto xfer_offset = _xfer_ctxt->_base_offset + (_next_iocb_index * _block_size); - const auto xfer_buffer = (char*)_xfer_ctxt->_mem_buffer + xfer_offset; - const auto num_bytes = min(static_cast(_block_size), _remaining_bytes); - - if (_read_op) { - io_prep_pread(iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, num_bytes, xfer_offset); - } else { - io_prep_pwrite(iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, num_bytes, xfer_offset); - } - _remaining_bytes -= num_bytes; - } - _remaining_io_blocks -= actual_n_iocbs; - - return actual_n_iocbs; -} - -int get_file_size(const char* filename, long long int& size) -{ - struct stat st; - if (stat(filename, &st) == -1) { return -1; } - size = st.st_size; - return 0; -} - -void* ds_page_aligned_alloc(const size_t size, const bool lock) -{ - void* ptr; - int retval; - - retval = posix_memalign(&ptr, (size_t)sysconf(_SC_PAGESIZE), size); - if (retval) { return nullptr; } - - if (lock == false) { return ptr; } - - auto mlock_ret = mlock(ptr, size); - if (mlock_ret != 0) { - auto mlock_error = errno; - printf("mlock failed with %d %s\n", mlock_error, strerror(mlock_error)); - - free(ptr); - return nullptr; - } - - return ptr; -} +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include + +#include "deepspeed_aio_utils.h" + +using namespace std; + +const int c_block_size = 128 * 1024; +const int c_io_queue_depth = 8; + +io_xfer_ctxt::io_xfer_ctxt(const int fd, + const long long int file_offset, + const long long int num_bytes, + const void* buffer) + : _fd(fd), _base_offset(file_offset), _mem_buffer(buffer), _num_bytes(num_bytes) +{ +} + +io_prep_context::io_prep_context(const bool read_op, + const std::unique_ptr& xfer_ctxt, + const size_t block_size, + const std::vector* iocbs) + : _read_op(read_op), _xfer_ctxt(xfer_ctxt), _block_size(block_size), _iocbs(iocbs) +{ +} + +void io_prep_context::prep_iocbs(const int n_iocbs, + const size_t num_bytes, + const void* start_buffer, + const long long int start_offset) +{ + assert(static_cast(n_iocbs) <= _iocbs->size()); + for (auto i = 0; i < n_iocbs; ++i) { + const auto shift = i * _block_size; + const auto xfer_buffer = (char*)start_buffer + _xfer_ctxt->_base_offset + shift; + const auto xfer_offset = _xfer_ctxt->_base_offset + start_offset + shift; + auto byte_count = _block_size; + if ((shift + _block_size) > num_bytes) { byte_count = num_bytes - shift; } + + if (_read_op) { + io_prep_pread(_iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, byte_count, xfer_offset); + } else { + io_prep_pwrite(_iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, byte_count, xfer_offset); + } + } +} + +io_prep_generator::io_prep_generator(const bool read_op, + const std::unique_ptr& xfer_ctxt, + const size_t block_size) + : _read_op(read_op), + _xfer_ctxt(xfer_ctxt), + _block_size(block_size), + _remaining_bytes(xfer_ctxt->_num_bytes), + _next_iocb_index(0) +{ + _num_io_blocks = + static_cast(ceil(static_cast(xfer_ctxt->_num_bytes) / block_size)); + _remaining_io_blocks = _num_io_blocks; +} + +int io_prep_generator::prep_iocbs(const int n_iocbs, std::vector* iocbs) +{ + if ((_remaining_bytes) == 0 || (_remaining_io_blocks == 0)) { + assert(static_cast(_remaining_bytes) == _remaining_io_blocks); + return 0; + } + + assert(static_cast(n_iocbs) <= iocbs->size()); + + auto actual_n_iocbs = min(static_cast(n_iocbs), _remaining_io_blocks); + for (auto i = 0; i < actual_n_iocbs; ++i, ++_next_iocb_index) { + const auto xfer_offset = _xfer_ctxt->_base_offset + (_next_iocb_index * _block_size); + const auto xfer_buffer = (char*)_xfer_ctxt->_mem_buffer + xfer_offset; + const auto num_bytes = min(static_cast(_block_size), _remaining_bytes); + + if (_read_op) { + io_prep_pread(iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, num_bytes, xfer_offset); + } else { + io_prep_pwrite(iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, num_bytes, xfer_offset); + } + _remaining_bytes -= num_bytes; + } + _remaining_io_blocks -= actual_n_iocbs; + + return actual_n_iocbs; +} + +int get_file_size(const char* filename, long long int& size) +{ + struct stat st; + if (stat(filename, &st) == -1) { return -1; } + size = st.st_size; + return 0; +} + +void* ds_page_aligned_alloc(const size_t size, const bool lock) +{ + void* ptr; + int retval; + + retval = posix_memalign(&ptr, (size_t)sysconf(_SC_PAGESIZE), size); + if (retval) { return nullptr; } + + if (lock == false) { return ptr; } + + auto mlock_ret = mlock(ptr, size); + if (mlock_ret != 0) { + auto mlock_error = errno; + printf("mlock failed with %d %s\n", mlock_error, strerror(mlock_error)); + + free(ptr); + return nullptr; + } + + return ptr; +} diff --git a/csrc/aio/common/deepspeed_aio_utils.h b/csrc/aio/common/deepspeed_aio_utils.h index f37a95c51..6c5952749 100644 --- a/csrc/aio/common/deepspeed_aio_utils.h +++ b/csrc/aio/common/deepspeed_aio_utils.h @@ -1,77 +1,77 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#pragma once - -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -struct io_xfer_ctxt { - const int _fd; - const long long int _base_offset; - const void* _mem_buffer; - const long long int _num_bytes; - - io_xfer_ctxt(const int fd, - const long long int file_offset, - const long long int num_bytes, - const void* buffer); -}; - -struct io_prep_context { - const bool _read_op; - const std::unique_ptr& _xfer_ctxt; - const size_t _block_size; - const std::vector* _iocbs; - - io_prep_context(const bool read_op, - const std::unique_ptr& xfer_ctxt, - const size_t block_size, - const std::vector* iocbs); - - void prep_iocbs(const int n_iocbs, - const size_t num_bytes, - const void* start_buffer, - const long long int start_offset); -}; - -struct io_prep_generator { - const bool _read_op; - const std::unique_ptr& _xfer_ctxt; - const size_t _block_size; - - long long int _remaining_bytes; - long long int _num_io_blocks; - long long int _remaining_io_blocks; - long long int _next_iocb_index; - - io_prep_generator(const bool read_op, - const std::unique_ptr& xfer_ctxt, - const size_t block_size); - - int prep_iocbs(const int n_iocbs, std::vector* iocbs); -}; - -void* ds_page_aligned_alloc(const size_t size, const bool lock = false); - -int get_file_size(const char* filename, long long int& size); +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +struct io_xfer_ctxt { + const int _fd; + const long long int _base_offset; + const void* _mem_buffer; + const long long int _num_bytes; + + io_xfer_ctxt(const int fd, + const long long int file_offset, + const long long int num_bytes, + const void* buffer); +}; + +struct io_prep_context { + const bool _read_op; + const std::unique_ptr& _xfer_ctxt; + const size_t _block_size; + const std::vector* _iocbs; + + io_prep_context(const bool read_op, + const std::unique_ptr& xfer_ctxt, + const size_t block_size, + const std::vector* iocbs); + + void prep_iocbs(const int n_iocbs, + const size_t num_bytes, + const void* start_buffer, + const long long int start_offset); +}; + +struct io_prep_generator { + const bool _read_op; + const std::unique_ptr& _xfer_ctxt; + const size_t _block_size; + + long long int _remaining_bytes; + long long int _num_io_blocks; + long long int _remaining_io_blocks; + long long int _next_iocb_index; + + io_prep_generator(const bool read_op, + const std::unique_ptr& xfer_ctxt, + const size_t block_size); + + int prep_iocbs(const int n_iocbs, std::vector* iocbs); +}; + +void* ds_page_aligned_alloc(const size_t size, const bool lock = false); + +int get_file_size(const char* filename, long long int& size); diff --git a/csrc/aio/py_lib/deepspeed_aio_thread.cpp b/csrc/aio/py_lib/deepspeed_aio_thread.cpp index 2c7509cb3..a2670fb7b 100644 --- a/csrc/aio/py_lib/deepspeed_aio_thread.cpp +++ b/csrc/aio/py_lib/deepspeed_aio_thread.cpp @@ -1,84 +1,84 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include "deepspeed_aio_thread.h" - -using namespace std; - -io_op_desc_t::io_op_desc_t(const bool read_op, - const torch::Tensor& buffer, - const int fd, - const char* filename, - const long long int num_bytes, - const bool validate) - : _read_op(read_op), - _buffer(buffer), - _fd(fd), - _filename(filename), - _num_bytes(num_bytes), - _validate(validate) -{ - _cpu_buffer = _buffer.is_cuda() ? _buffer.to(torch::kCPU).pin_memory() : _buffer; - _contiguous_buffer = _cpu_buffer.contiguous(); -} - -char* io_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); } - -void io_op_desc_t::fini() -{ - if (_read_op && _buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); } -} - -deepspeed_aio_thread_t::deepspeed_aio_thread_t(const int tid, deepspeed_aio_config_t& aio_config) - : _tid(tid), - _aio_config(aio_config), - _aio_ctxt(new aio_context(aio_config._block_size, aio_config._queue_depth)), - _time_to_exit(false) -{ -} - -deepspeed_aio_thread_t::~deepspeed_aio_thread_t() {} - -void deepspeed_aio_thread_t::run() -{ - while (true) { - std::shared_ptr next_io_op = nullptr; - - { - std::unique_lock lock(_work_sync._mutex); - _work_sync._cond_var.wait(lock, - [this] { return (!_work_queue.empty() || _time_to_exit); }); - if (!_work_queue.empty()) { - next_io_op = _work_queue.front(); - _work_queue.pop(); - } - } - - if (next_io_op) { - const auto base_offset = next_io_op->_num_bytes * _tid; - - std::unique_ptr xfer_ctxt(new io_xfer_ctxt( - next_io_op->_fd, base_offset, next_io_op->_num_bytes, next_io_op->data_ptr())); - - if (_aio_config._overlap_events) { - do_aio_operation_overlap( - next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } else { - do_aio_operation_sequential( - next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } - - { - std::lock_guard lock(_complete_sync._mutex); - _complete_queue.push(next_io_op); - } - _complete_sync._cond_var.notify_one(); - } - - if (_time_to_exit) { break; } - } -} +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include "deepspeed_aio_thread.h" + +using namespace std; + +io_op_desc_t::io_op_desc_t(const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const long long int num_bytes, + const bool validate) + : _read_op(read_op), + _buffer(buffer), + _fd(fd), + _filename(filename), + _num_bytes(num_bytes), + _validate(validate) +{ + _cpu_buffer = _buffer.is_cuda() ? _buffer.to(torch::kCPU).pin_memory() : _buffer; + _contiguous_buffer = _cpu_buffer.contiguous(); +} + +char* io_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); } + +void io_op_desc_t::fini() +{ + if (_read_op && _buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); } +} + +deepspeed_aio_thread_t::deepspeed_aio_thread_t(const int tid, deepspeed_aio_config_t& aio_config) + : _tid(tid), + _aio_config(aio_config), + _aio_ctxt(new aio_context(aio_config._block_size, aio_config._queue_depth)), + _time_to_exit(false) +{ +} + +deepspeed_aio_thread_t::~deepspeed_aio_thread_t() {} + +void deepspeed_aio_thread_t::run() +{ + while (true) { + std::shared_ptr next_io_op = nullptr; + + { + std::unique_lock lock(_work_sync._mutex); + _work_sync._cond_var.wait(lock, + [this] { return (!_work_queue.empty() || _time_to_exit); }); + if (!_work_queue.empty()) { + next_io_op = _work_queue.front(); + _work_queue.pop(); + } + } + + if (next_io_op) { + const auto base_offset = next_io_op->_num_bytes * _tid; + + std::unique_ptr xfer_ctxt(new io_xfer_ctxt( + next_io_op->_fd, base_offset, next_io_op->_num_bytes, next_io_op->data_ptr())); + + if (_aio_config._overlap_events) { + do_aio_operation_overlap( + next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } else { + do_aio_operation_sequential( + next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } + + { + std::lock_guard lock(_complete_sync._mutex); + _complete_queue.push(next_io_op); + } + _complete_sync._cond_var.notify_one(); + } + + if (_time_to_exit) { break; } + } +} diff --git a/csrc/aio/py_lib/deepspeed_aio_thread.h b/csrc/aio/py_lib/deepspeed_aio_thread.h index ee099dd2d..d1cfcab8b 100644 --- a/csrc/aio/py_lib/deepspeed_aio_thread.h +++ b/csrc/aio/py_lib/deepspeed_aio_thread.h @@ -1,57 +1,57 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include -#include -#include -#include "deepspeed_py_aio.h" - -struct io_op_desc_t { - const bool _read_op; - torch::Tensor _buffer; - int _fd; - const std::string _filename; - const long long int _num_bytes; - torch::Tensor _cpu_buffer; - torch::Tensor _contiguous_buffer; - const bool _validate; - - io_op_desc_t(const bool read_op, - const torch::Tensor& buffer, - const int fd, - const char* filename, - const long long int num_bytes, - const bool validate); - - char* data_ptr() const; - void fini(); -}; - -struct thread_sync_t { - std::mutex _mutex; - std::condition_variable _cond_var; -}; - -struct deepspeed_aio_thread_t { - const int _tid; - deepspeed_aio_config_t& _aio_config; - - std::unique_ptr _aio_ctxt; - std::queue> _work_queue; - std::queue> _complete_queue; - - bool _time_to_exit; - - struct thread_sync_t _work_sync; - struct thread_sync_t _complete_sync; - - deepspeed_aio_thread_t(const int tid, deepspeed_aio_config_t& aio_config); - - ~deepspeed_aio_thread_t(); - - void run(); -}; +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include +#include +#include "deepspeed_py_aio.h" + +struct io_op_desc_t { + const bool _read_op; + torch::Tensor _buffer; + int _fd; + const std::string _filename; + const long long int _num_bytes; + torch::Tensor _cpu_buffer; + torch::Tensor _contiguous_buffer; + const bool _validate; + + io_op_desc_t(const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const long long int num_bytes, + const bool validate); + + char* data_ptr() const; + void fini(); +}; + +struct thread_sync_t { + std::mutex _mutex; + std::condition_variable _cond_var; +}; + +struct deepspeed_aio_thread_t { + const int _tid; + deepspeed_aio_config_t& _aio_config; + + std::unique_ptr _aio_ctxt; + std::queue> _work_queue; + std::queue> _complete_queue; + + bool _time_to_exit; + + struct thread_sync_t _work_sync; + struct thread_sync_t _complete_sync; + + deepspeed_aio_thread_t(const int tid, deepspeed_aio_config_t& aio_config); + + ~deepspeed_aio_thread_t(); + + void run(); +}; diff --git a/csrc/aio/py_lib/deepspeed_py_aio.cpp b/csrc/aio/py_lib/deepspeed_py_aio.cpp index cc2895cc7..49ff1f240 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio.cpp +++ b/csrc/aio/py_lib/deepspeed_py_aio.cpp @@ -1,121 +1,121 @@ - -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include -#include -#include - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "deepspeed_py_aio.h" - -using namespace std; -using namespace std::chrono; - -#define DEBUG_DS_AIO_READ 0 -#define DEBUG_DS_AIO_WRITE 0 - -static const std::string c_library_name = "deepspeed_aio"; - -int deepspeed_py_aio_write(const torch::Tensor& buffer, - const char* filename, - const int block_size, - const int queue_depth, - const bool single_submit, - const bool overlap_events, - const bool validate) -{ - const auto start_time = std::chrono::high_resolution_clock::now(); - deepspeed_aio_config_t config(block_size, queue_depth, single_submit, overlap_events, false); - - const auto fd = open_file(filename, false); - if (fd == -1) { return -1; } - - auto write_buffer = (char*)buffer.data_ptr(); - const auto num_write_bytes = static_cast(buffer.nbytes()); - std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer)); - std::unique_ptr aio_ctxt(new aio_context(config._block_size, config._queue_depth)); - - if (config._overlap_events) { - do_aio_operation_overlap(false, aio_ctxt, xfer_ctxt, &config, nullptr); - } else { - do_aio_operation_sequential(false, aio_ctxt, xfer_ctxt, &config, nullptr); - } - const std::chrono::duration aio_time = - std::chrono::high_resolution_clock::now() - start_time; - - close(fd); - - if (validate) { validate_aio_operation(false, filename, write_buffer, num_write_bytes); } - - const std::chrono::duration fn_time = - std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " - << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 - << std::endl; - return 0; -} - -int deepspeed_py_aio_read(torch::Tensor& buffer, - const char* filename, - const int block_size, - const int queue_depth, - const bool single_submit, - const bool overlap_events, - const bool validate) -{ - const auto start_time = std::chrono::high_resolution_clock::now(); - long long num_file_bytes; - if (-1 == get_file_size(filename, num_file_bytes)) { - const auto error_code = errno; - report_file_error(filename, " fstat for read", error_code); - return -1; - } - - deepspeed_aio_config_t config(block_size, queue_depth, single_submit, overlap_events, false); - const auto fd = open_file(filename, true); - if (fd == -1) { return -1; } - - auto read_buffer = (char*)buffer.data_ptr(); - assert(static_cast(buffer.nbytes()) == num_file_bytes); - - std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer)); - std::unique_ptr aio_ctxt(new aio_context(config._block_size, config._queue_depth)); - - if (config._overlap_events) { - do_aio_operation_overlap(true, aio_ctxt, xfer_ctxt, &config, nullptr); - } else { - do_aio_operation_sequential(true, aio_ctxt, xfer_ctxt, &config, nullptr); - } - const std::chrono::duration aio_time = - std::chrono::high_resolution_clock::now() - start_time; - - close(fd); - - if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); } - - const std::chrono::duration fn_time = - std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " - << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 - << std::endl; - return 0; -} + +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "deepspeed_py_aio.h" + +using namespace std; +using namespace std::chrono; + +#define DEBUG_DS_AIO_READ 0 +#define DEBUG_DS_AIO_WRITE 0 + +static const std::string c_library_name = "deepspeed_aio"; + +int deepspeed_py_aio_write(const torch::Tensor& buffer, + const char* filename, + const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const bool validate) +{ + const auto start_time = std::chrono::high_resolution_clock::now(); + deepspeed_aio_config_t config(block_size, queue_depth, single_submit, overlap_events, false); + + const auto fd = open_file(filename, false); + if (fd == -1) { return -1; } + + auto write_buffer = (char*)buffer.data_ptr(); + const auto num_write_bytes = static_cast(buffer.nbytes()); + std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer)); + std::unique_ptr aio_ctxt(new aio_context(config._block_size, config._queue_depth)); + + if (config._overlap_events) { + do_aio_operation_overlap(false, aio_ctxt, xfer_ctxt, &config, nullptr); + } else { + do_aio_operation_sequential(false, aio_ctxt, xfer_ctxt, &config, nullptr); + } + const std::chrono::duration aio_time = + std::chrono::high_resolution_clock::now() - start_time; + + close(fd); + + if (validate) { validate_aio_operation(false, filename, write_buffer, num_write_bytes); } + + const std::chrono::duration fn_time = + std::chrono::high_resolution_clock::now() - start_time; + std::cout << "Elapsed time(usec): " + << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 + << std::endl; + return 0; +} + +int deepspeed_py_aio_read(torch::Tensor& buffer, + const char* filename, + const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const bool validate) +{ + const auto start_time = std::chrono::high_resolution_clock::now(); + long long num_file_bytes; + if (-1 == get_file_size(filename, num_file_bytes)) { + const auto error_code = errno; + report_file_error(filename, " fstat for read", error_code); + return -1; + } + + deepspeed_aio_config_t config(block_size, queue_depth, single_submit, overlap_events, false); + const auto fd = open_file(filename, true); + if (fd == -1) { return -1; } + + auto read_buffer = (char*)buffer.data_ptr(); + assert(static_cast(buffer.nbytes()) == num_file_bytes); + + std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer)); + std::unique_ptr aio_ctxt(new aio_context(config._block_size, config._queue_depth)); + + if (config._overlap_events) { + do_aio_operation_overlap(true, aio_ctxt, xfer_ctxt, &config, nullptr); + } else { + do_aio_operation_sequential(true, aio_ctxt, xfer_ctxt, &config, nullptr); + } + const std::chrono::duration aio_time = + std::chrono::high_resolution_clock::now() - start_time; + + close(fd); + + if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); } + + const std::chrono::duration fn_time = + std::chrono::high_resolution_clock::now() - start_time; + std::cout << "Elapsed time(usec): " + << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 + << std::endl; + return 0; +} diff --git a/csrc/aio/py_lib/deepspeed_py_aio.h b/csrc/aio/py_lib/deepspeed_py_aio.h index a78d57340..230d88da9 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio.h +++ b/csrc/aio/py_lib/deepspeed_py_aio.h @@ -1,27 +1,27 @@ - -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include -#include -#include - -int deepspeed_py_aio_write(const torch::Tensor& buffer, - const char* filename, - const int block_size, - const int queue_depth, - const bool single_submit, - const bool overlap_events, - const bool validate); - -int deepspeed_py_aio_read(torch::Tensor& buffer, - const char* filename, - const int block_size, - const int queue_depth, - const bool single_submit, - const bool overlap_events, - const bool validate); + +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include +#include + +int deepspeed_py_aio_write(const torch::Tensor& buffer, + const char* filename, + const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const bool validate); + +int deepspeed_py_aio_read(torch::Tensor& buffer, + const char* filename, + const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const bool validate); diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp index 4635e751d..417319f8a 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp +++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp @@ -1,282 +1,282 @@ - -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include "deepspeed_py_aio_handle.h" - -using namespace std; - -static void _start_aio_thread(std::shared_ptr ctxt) { ctxt->run(); } - -deepspeed_aio_handle_t::deepspeed_aio_handle_t(const int block_size, - const int queue_depth, - const bool single_submit, - const bool overlap_events, - const int num_threads) - : _aio_ctxt(new aio_context(block_size, queue_depth)), - _single_submit(single_submit), - _overlap_events(overlap_events), - _num_threads(num_threads), - _aio_config(block_size, queue_depth, single_submit, overlap_events, false), - _num_pending_ops(0) -{ - for (auto i = 0; i < num_threads; ++i) { - _thread_contexts.push_back(std::make_shared(i, _aio_config)); - } - - for (auto& ctxt : _thread_contexts) { - _threads.push_back(std::thread(_start_aio_thread, ctxt)); - } -} - -deepspeed_aio_handle_t::~deepspeed_aio_handle_t() -{ - _stop_threads(); - for (auto& thr : _threads) { thr.join(); } -} - -const int deepspeed_aio_handle_t::get_block_size() const -{ - return _aio_ctxt ? _aio_ctxt->_block_size : -1; -} - -const int deepspeed_aio_handle_t::get_queue_depth() const -{ - return _aio_ctxt ? _aio_ctxt->_queue_depth : -1; -} - -const bool deepspeed_aio_handle_t::get_single_submit() const { return _single_submit; } - -const bool deepspeed_aio_handle_t::get_overlap_events() const { return _overlap_events; } - -const int deepspeed_aio_handle_t::get_thread_count() const { return _num_threads; } - -int deepspeed_aio_handle_t::read(torch::Tensor& buffer, const char* filename, const bool validate) -{ - const auto start_time = std::chrono::high_resolution_clock::now(); - - assert(_aio_ctxt); - - long long num_file_bytes; - if (-1 == get_file_size(filename, num_file_bytes)) { - const auto error_code = errno; - report_file_error(filename, " fstat for read", error_code); - return -1; - } - assert(static_cast(buffer.nbytes()) == num_file_bytes); - - const auto fd = open_file(filename, true); - if (fd == -1) { return -1; } - - auto read_buffer = (char*)buffer.data_ptr(); - std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer)); - - if (_aio_config._overlap_events) { - do_aio_operation_overlap(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } else { - do_aio_operation_sequential(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } - - close(fd); - const std::chrono::duration aio_time = - std::chrono::high_resolution_clock::now() - start_time; - - if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); } - const std::chrono::duration fn_time = - std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " - << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 - << std::endl; - return 0; -} - -int deepspeed_aio_handle_t::write(const torch::Tensor& buffer, - const char* filename, - const bool validate) -{ - assert(_aio_ctxt); - - const auto start_time = std::chrono::high_resolution_clock::now(); - - const auto fd = open_file(filename, false); - if (fd == -1) { return -1; } - - auto write_buffer = (char*)buffer.data_ptr(); - const auto num_write_bytes = static_cast(buffer.nbytes()); - std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer)); - - if (_aio_config._overlap_events) { - do_aio_operation_overlap(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } else { - do_aio_operation_sequential(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } - const std::chrono::duration aio_time = - std::chrono::high_resolution_clock::now() - start_time; - - close(fd); - - if (validate) { validate_aio_operation(false, filename, write_buffer, num_write_bytes); } - - const std::chrono::duration fn_time = - std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " - << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 - << std::endl; - return 0; -} - -void deepspeed_aio_handle_t::_schedule_aio_work(std::shared_ptr scheduled_op) -{ - for (auto& ctxt : _thread_contexts) { - { - std::lock_guard lock(ctxt->_work_sync._mutex); - ctxt->_work_queue.push(scheduled_op); - } - ctxt->_work_sync._cond_var.notify_one(); - } - _num_pending_ops++; -} - -std::shared_ptr deepspeed_aio_handle_t::_wait_for_aio_work() -{ - std::shared_ptr completed_op = nullptr; - for (auto& ctxt : _thread_contexts) { - std::unique_lock lock(ctxt->_complete_sync._mutex); - ctxt->_complete_sync._cond_var.wait(lock, - [ctxt] { return !ctxt->_complete_queue.empty(); }); - completed_op = ctxt->_complete_queue.front(); - ctxt->_complete_queue.pop(); - } - return completed_op; -} - -void deepspeed_aio_handle_t::_stop_threads() -{ - assert(0 == _num_pending_ops); - for (auto& ctxt : _thread_contexts) { - { - std::lock_guard lock(ctxt->_work_sync._mutex); - ctxt->_time_to_exit = true; - } - ctxt->_work_sync._cond_var.notify_one(); - } -} - -int deepspeed_aio_handle_t::wait() -{ - assert(_num_pending_ops > 0); - auto num_completed_ops = 0; - - while (_num_pending_ops > 0) { - auto completed_op = _wait_for_aio_work(); - - completed_op->fini(); - - close(completed_op->_fd); - - if (completed_op->_validate) { - validate_aio_operation(completed_op->_read_op, - completed_op->_filename.c_str(), - completed_op->data_ptr(), - _num_threads * completed_op->_num_bytes); - } - --_num_pending_ops; - ++num_completed_ops; - } - - return num_completed_ops; -} - -bool deepspeed_aio_handle_t::_is_valid_parallel_aio_op(const bool read_op, - const long long int num_bytes) -{ - const auto op_string = read_op ? "Read" : "Write"; - if (num_bytes % get_thread_count()) { - std::cout << "deepspeed_aio failure: parallel " << op_string << " num_bytes = " << num_bytes - << " not divisible by thread count = " << get_thread_count() << std::endl; - return false; - } - - return true; -} - -int deepspeed_aio_handle_t::pread(const torch::Tensor& buffer, - const char* filename, - const bool validate, - const bool async) -{ - long long num_file_bytes; - if (-1 == get_file_size(filename, num_file_bytes)) { - const auto error_code = errno; - report_file_error(filename, " fstat for read", error_code); - return -1; - } - const auto buffer_bytes = static_cast(buffer.nbytes()); - if (buffer_bytes != num_file_bytes) { - std::cout << filename << ": buffer nbytes != file bytes " << buffer_bytes - << " != " << num_file_bytes << std::endl; - } - assert(static_cast(buffer.nbytes()) == num_file_bytes); - assert((num_file_bytes % _num_threads) == 0); - - if (!_is_valid_parallel_aio_op(true, num_file_bytes)) { return -1; } - - const auto fd = open_file(filename, true); - if (fd == -1) { return -1; } - - auto scheduled_op = std::make_shared( - true, buffer, fd, filename, (num_file_bytes / _num_threads), validate); - - _schedule_aio_work(scheduled_op); - - if (async) { return 0; } - - return wait(); -} - -int deepspeed_aio_handle_t::pwrite(const torch::Tensor& buffer, - const char* filename, - const bool validate, - const bool async) -{ - const auto num_write_bytes = static_cast(buffer.nbytes()); - assert((num_write_bytes % _num_threads) == 0); - - if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; } - - const auto fd = open_file(filename, false); - if (fd == -1) { return -1; } - - auto scheduled_op = std::make_shared( - false, buffer, fd, filename, (num_write_bytes / _num_threads), validate); - - _schedule_aio_work(scheduled_op); - - if (async) { return 0; } - - return wait(); -} - -int deepspeed_aio_handle_t::sync_pread(torch::Tensor& buffer, const char* filename) -{ - return pread(buffer, filename, false, false); -} - -int deepspeed_aio_handle_t::sync_pwrite(const torch::Tensor& buffer, const char* filename) -{ - return pwrite(buffer, filename, false, false); -} - -int deepspeed_aio_handle_t::async_pread(torch::Tensor& buffer, const char* filename) -{ - return pread(buffer, filename, false, true); -} - -int deepspeed_aio_handle_t::async_pwrite(const torch::Tensor& buffer, const char* filename) -{ - return pwrite(buffer, filename, false, true); -} + +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include "deepspeed_py_aio_handle.h" + +using namespace std; + +static void _start_aio_thread(std::shared_ptr ctxt) { ctxt->run(); } + +deepspeed_aio_handle_t::deepspeed_aio_handle_t(const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const int num_threads) + : _aio_ctxt(new aio_context(block_size, queue_depth)), + _single_submit(single_submit), + _overlap_events(overlap_events), + _num_threads(num_threads), + _aio_config(block_size, queue_depth, single_submit, overlap_events, false), + _num_pending_ops(0) +{ + for (auto i = 0; i < num_threads; ++i) { + _thread_contexts.push_back(std::make_shared(i, _aio_config)); + } + + for (auto& ctxt : _thread_contexts) { + _threads.push_back(std::thread(_start_aio_thread, ctxt)); + } +} + +deepspeed_aio_handle_t::~deepspeed_aio_handle_t() +{ + _stop_threads(); + for (auto& thr : _threads) { thr.join(); } +} + +const int deepspeed_aio_handle_t::get_block_size() const +{ + return _aio_ctxt ? _aio_ctxt->_block_size : -1; +} + +const int deepspeed_aio_handle_t::get_queue_depth() const +{ + return _aio_ctxt ? _aio_ctxt->_queue_depth : -1; +} + +const bool deepspeed_aio_handle_t::get_single_submit() const { return _single_submit; } + +const bool deepspeed_aio_handle_t::get_overlap_events() const { return _overlap_events; } + +const int deepspeed_aio_handle_t::get_thread_count() const { return _num_threads; } + +int deepspeed_aio_handle_t::read(torch::Tensor& buffer, const char* filename, const bool validate) +{ + const auto start_time = std::chrono::high_resolution_clock::now(); + + assert(_aio_ctxt); + + long long num_file_bytes; + if (-1 == get_file_size(filename, num_file_bytes)) { + const auto error_code = errno; + report_file_error(filename, " fstat for read", error_code); + return -1; + } + assert(static_cast(buffer.nbytes()) == num_file_bytes); + + const auto fd = open_file(filename, true); + if (fd == -1) { return -1; } + + auto read_buffer = (char*)buffer.data_ptr(); + std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer)); + + if (_aio_config._overlap_events) { + do_aio_operation_overlap(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } else { + do_aio_operation_sequential(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } + + close(fd); + const std::chrono::duration aio_time = + std::chrono::high_resolution_clock::now() - start_time; + + if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); } + const std::chrono::duration fn_time = + std::chrono::high_resolution_clock::now() - start_time; + std::cout << "Elapsed time(usec): " + << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 + << std::endl; + return 0; +} + +int deepspeed_aio_handle_t::write(const torch::Tensor& buffer, + const char* filename, + const bool validate) +{ + assert(_aio_ctxt); + + const auto start_time = std::chrono::high_resolution_clock::now(); + + const auto fd = open_file(filename, false); + if (fd == -1) { return -1; } + + auto write_buffer = (char*)buffer.data_ptr(); + const auto num_write_bytes = static_cast(buffer.nbytes()); + std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer)); + + if (_aio_config._overlap_events) { + do_aio_operation_overlap(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } else { + do_aio_operation_sequential(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } + const std::chrono::duration aio_time = + std::chrono::high_resolution_clock::now() - start_time; + + close(fd); + + if (validate) { validate_aio_operation(false, filename, write_buffer, num_write_bytes); } + + const std::chrono::duration fn_time = + std::chrono::high_resolution_clock::now() - start_time; + std::cout << "Elapsed time(usec): " + << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 + << std::endl; + return 0; +} + +void deepspeed_aio_handle_t::_schedule_aio_work(std::shared_ptr scheduled_op) +{ + for (auto& ctxt : _thread_contexts) { + { + std::lock_guard lock(ctxt->_work_sync._mutex); + ctxt->_work_queue.push(scheduled_op); + } + ctxt->_work_sync._cond_var.notify_one(); + } + _num_pending_ops++; +} + +std::shared_ptr deepspeed_aio_handle_t::_wait_for_aio_work() +{ + std::shared_ptr completed_op = nullptr; + for (auto& ctxt : _thread_contexts) { + std::unique_lock lock(ctxt->_complete_sync._mutex); + ctxt->_complete_sync._cond_var.wait(lock, + [ctxt] { return !ctxt->_complete_queue.empty(); }); + completed_op = ctxt->_complete_queue.front(); + ctxt->_complete_queue.pop(); + } + return completed_op; +} + +void deepspeed_aio_handle_t::_stop_threads() +{ + assert(0 == _num_pending_ops); + for (auto& ctxt : _thread_contexts) { + { + std::lock_guard lock(ctxt->_work_sync._mutex); + ctxt->_time_to_exit = true; + } + ctxt->_work_sync._cond_var.notify_one(); + } +} + +int deepspeed_aio_handle_t::wait() +{ + assert(_num_pending_ops > 0); + auto num_completed_ops = 0; + + while (_num_pending_ops > 0) { + auto completed_op = _wait_for_aio_work(); + + completed_op->fini(); + + close(completed_op->_fd); + + if (completed_op->_validate) { + validate_aio_operation(completed_op->_read_op, + completed_op->_filename.c_str(), + completed_op->data_ptr(), + _num_threads * completed_op->_num_bytes); + } + --_num_pending_ops; + ++num_completed_ops; + } + + return num_completed_ops; +} + +bool deepspeed_aio_handle_t::_is_valid_parallel_aio_op(const bool read_op, + const long long int num_bytes) +{ + const auto op_string = read_op ? "Read" : "Write"; + if (num_bytes % get_thread_count()) { + std::cout << "deepspeed_aio failure: parallel " << op_string << " num_bytes = " << num_bytes + << " not divisible by thread count = " << get_thread_count() << std::endl; + return false; + } + + return true; +} + +int deepspeed_aio_handle_t::pread(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const bool async) +{ + long long num_file_bytes; + if (-1 == get_file_size(filename, num_file_bytes)) { + const auto error_code = errno; + report_file_error(filename, " fstat for read", error_code); + return -1; + } + const auto buffer_bytes = static_cast(buffer.nbytes()); + if (buffer_bytes != num_file_bytes) { + std::cout << filename << ": buffer nbytes != file bytes " << buffer_bytes + << " != " << num_file_bytes << std::endl; + } + assert(static_cast(buffer.nbytes()) == num_file_bytes); + assert((num_file_bytes % _num_threads) == 0); + + if (!_is_valid_parallel_aio_op(true, num_file_bytes)) { return -1; } + + const auto fd = open_file(filename, true); + if (fd == -1) { return -1; } + + auto scheduled_op = std::make_shared( + true, buffer, fd, filename, (num_file_bytes / _num_threads), validate); + + _schedule_aio_work(scheduled_op); + + if (async) { return 0; } + + return wait(); +} + +int deepspeed_aio_handle_t::pwrite(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const bool async) +{ + const auto num_write_bytes = static_cast(buffer.nbytes()); + assert((num_write_bytes % _num_threads) == 0); + + if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; } + + const auto fd = open_file(filename, false); + if (fd == -1) { return -1; } + + auto scheduled_op = std::make_shared( + false, buffer, fd, filename, (num_write_bytes / _num_threads), validate); + + _schedule_aio_work(scheduled_op); + + if (async) { return 0; } + + return wait(); +} + +int deepspeed_aio_handle_t::sync_pread(torch::Tensor& buffer, const char* filename) +{ + return pread(buffer, filename, false, false); +} + +int deepspeed_aio_handle_t::sync_pwrite(const torch::Tensor& buffer, const char* filename) +{ + return pwrite(buffer, filename, false, false); +} + +int deepspeed_aio_handle_t::async_pread(torch::Tensor& buffer, const char* filename) +{ + return pread(buffer, filename, false, true); +} + +int deepspeed_aio_handle_t::async_pwrite(const torch::Tensor& buffer, const char* filename) +{ + return pwrite(buffer, filename, false, true); +} diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.h b/csrc/aio/py_lib/deepspeed_py_aio_handle.h index 09358f4d9..22de4c396 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio_handle.h +++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.h @@ -1,68 +1,68 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include -#include -#include "deepspeed_aio_thread.h" - -struct deepspeed_aio_handle_t { - std::unique_ptr _aio_ctxt; - const bool _single_submit; - const bool _overlap_events; - const int _num_threads; - deepspeed_aio_config_t _aio_config; - - std::vector> _thread_contexts; - std::vector _threads; - int _num_pending_ops; - - deepspeed_aio_handle_t(const int block_size, - const int queue_depth, - const bool single_submit, - const bool overlap_events, - const int num_threads); - - ~deepspeed_aio_handle_t(); - - const int get_block_size() const; - const int get_queue_depth() const; - const bool get_single_submit() const; - const bool get_overlap_events() const; - const int get_thread_count() const; - - int read(torch::Tensor& buffer, const char* filename, const bool validate); - - int write(const torch::Tensor& buffer, const char* filename, const bool validate); - - int pread(const torch::Tensor& buffer, - const char* filename, - const bool validate, - const bool async); - - int pwrite(const torch::Tensor& buffer, - const char* filename, - const bool validate, - const bool async); - - int sync_pread(torch::Tensor& buffer, const char* filename); - - int sync_pwrite(const torch::Tensor& buffer, const char* filename); - - int async_pread(torch::Tensor& buffer, const char* filename); - - int async_pwrite(const torch::Tensor& buffer, const char* filename); - - int wait(); - - void _stop_threads(); - - void _schedule_aio_work(std::shared_ptr scheduled_op); - - std::shared_ptr _wait_for_aio_work(); - - bool _is_valid_parallel_aio_op(const bool read_op, const long long int num_bytes); -}; +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include +#include "deepspeed_aio_thread.h" + +struct deepspeed_aio_handle_t { + std::unique_ptr _aio_ctxt; + const bool _single_submit; + const bool _overlap_events; + const int _num_threads; + deepspeed_aio_config_t _aio_config; + + std::vector> _thread_contexts; + std::vector _threads; + int _num_pending_ops; + + deepspeed_aio_handle_t(const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const int num_threads); + + ~deepspeed_aio_handle_t(); + + const int get_block_size() const; + const int get_queue_depth() const; + const bool get_single_submit() const; + const bool get_overlap_events() const; + const int get_thread_count() const; + + int read(torch::Tensor& buffer, const char* filename, const bool validate); + + int write(const torch::Tensor& buffer, const char* filename, const bool validate); + + int pread(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const bool async); + + int pwrite(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const bool async); + + int sync_pread(torch::Tensor& buffer, const char* filename); + + int sync_pwrite(const torch::Tensor& buffer, const char* filename); + + int async_pread(torch::Tensor& buffer, const char* filename); + + int async_pwrite(const torch::Tensor& buffer, const char* filename); + + int wait(); + + void _stop_threads(); + + void _schedule_aio_work(std::shared_ptr scheduled_op); + + std::shared_ptr _wait_for_aio_work(); + + bool _is_valid_parallel_aio_op(const bool read_op, const long long int num_bytes); +}; diff --git a/csrc/aio/py_lib/deepspeed_py_copy.cpp b/csrc/aio/py_lib/deepspeed_py_copy.cpp index 3cdb5ed34..ee51147f9 100644 --- a/csrc/aio/py_lib/deepspeed_py_copy.cpp +++ b/csrc/aio/py_lib/deepspeed_py_copy.cpp @@ -1,133 +1,133 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include "deepspeed_py_copy.h" -#include - -#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) - -#if defined(__AVX512__) or defined(__AVX256__) -union AVX_Data { -#if defined(__AVX512__) - __m512 data; -#else - __m256 data; -#endif -}; -#endif - -static void helper_memcpy_1(float* dest, float* src, size_t param_size) -{ - size_t rounded_size = 0; - -#if defined(__AVX512__) or defined(__AVX256__) - - rounded_size = ROUND_DOWN(param_size, SIMD_WIDTH); - - for (size_t t = 0; t < rounded_size; t += TILE) { - size_t copy_size = TILE; - if ((t + TILE) > rounded_size) copy_size = rounded_size - t; - size_t offset = copy_size + t; -#pragma omp parallel for - for (size_t i = t; i < offset; i += SIMD_WIDTH) { - AVX_Data src_4; - src_4.data = SIMD_LOAD(src + i); - - SIMD_STORE(dest + i, src_4.data); - } - } - -#endif - - if (param_size > rounded_size) { -#pragma omp parallel for - for (size_t k = rounded_size; k < param_size; k++) { dest[k] = src[k]; } - } -} - -static void helper_memcpy_4(float* dest, float* src, size_t param_size) -{ - size_t rounded_size = 0; - -#if defined(__AVX512__) or defined(__AVX256__) - - rounded_size = ROUND_DOWN(param_size, (SIMD_WIDTH << 2)); - - for (size_t t = 0; t < rounded_size; t += TILE) { - size_t copy_size = TILE; - if ((t + TILE) > rounded_size) copy_size = rounded_size - t; - size_t offset = copy_size + t; -#pragma omp parallel for - for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) { - AVX_Data src_4[4]; - src_4[0].data = SIMD_LOAD(src + i); - src_4[1].data = SIMD_LOAD(src + i + SIMD_WIDTH); - src_4[2].data = SIMD_LOAD(src + i + (SIMD_WIDTH << 1)); - src_4[3].data = SIMD_LOAD(src + i + SIMD_WIDTH * 3); - - SIMD_STORE(dest + i, src_4[0].data); - SIMD_STORE(dest + i + SIMD_WIDTH, src_4[1].data); - SIMD_STORE(dest + i + (SIMD_WIDTH << 1), src_4[2].data); - SIMD_STORE(dest + i + SIMD_WIDTH * 3, src_4[3].data); - } - } -#endif - if (param_size > rounded_size) - helper_memcpy_1((dest + rounded_size), (src + rounded_size), (param_size - rounded_size)); -} - -static void helper_mempcy_8(float* dest, float* src, size_t param_size) -{ - size_t rounded_size = 0; - -#if defined(__AVX512__) or defined(__AVX256__) - - rounded_size = ROUND_DOWN(param_size, (SIMD_WIDTH << 2)); - - for (size_t t = 0; t < rounded_size; t += TILE) { - size_t copy_size = TILE; - if ((t + TILE) > rounded_size) copy_size = rounded_size - t; - size_t offset = copy_size + t; -#pragma omp parallel for - for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) { - AVX_Data src_4[8]; - src_4[0].data = SIMD_LOAD(src + i); - src_4[1].data = SIMD_LOAD(src + i + SIMD_WIDTH); - src_4[2].data = SIMD_LOAD(src + i + (SIMD_WIDTH << 1)); - src_4[3].data = SIMD_LOAD(src + i + SIMD_WIDTH * 3); - src_4[4].data = SIMD_LOAD(src + i + (SIMD_WIDTH << 2)); - src_4[5].data = SIMD_LOAD(src + i + SIMD_WIDTH * 5); - src_4[6].data = SIMD_LOAD(src + i + SIMD_WIDTH * 6); - src_4[7].data = SIMD_LOAD(src + i + SIMD_WIDTH * 7); - - SIMD_STORE(dest + i, src_4[0].data); - SIMD_STORE(dest + i + SIMD_WIDTH, src_4[1].data); - SIMD_STORE(dest + i + (SIMD_WIDTH << 1), src_4[2].data); - SIMD_STORE(dest + i + SIMD_WIDTH * 3, src_4[3].data); - SIMD_STORE(dest + i + (SIMD_WIDTH << 2), src_4[4].data); - SIMD_STORE(dest + i + SIMD_WIDTH * 5, src_4[5].data); - SIMD_STORE(dest + i + SIMD_WIDTH * 6, src_4[6].data); - SIMD_STORE(dest + i + SIMD_WIDTH * 7, src_4[7].data); - } - } -#endif - if (param_size > rounded_size) - helper_memcpy_4((dest + rounded_size), (src + rounded_size), (param_size - rounded_size)); -} - -int deepspeed_py_memcpy(torch::Tensor& dest, const torch::Tensor& src) -{ - auto dest_c = dest.contiguous(); - auto src_c = src.contiguous(); - - float* dest_ptr = (float*)dest_c.data_ptr(); - float* src_ptr = (float*)src_c.data_ptr(); - - helper_mempcy_8(dest_ptr, src_ptr, dest_c.size(0)); - - return 0; -} +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include "deepspeed_py_copy.h" +#include + +#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) + +#if defined(__AVX512__) or defined(__AVX256__) +union AVX_Data { +#if defined(__AVX512__) + __m512 data; +#else + __m256 data; +#endif +}; +#endif + +static void helper_memcpy_1(float* dest, float* src, size_t param_size) +{ + size_t rounded_size = 0; + +#if defined(__AVX512__) or defined(__AVX256__) + + rounded_size = ROUND_DOWN(param_size, SIMD_WIDTH); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH) { + AVX_Data src_4; + src_4.data = SIMD_LOAD(src + i); + + SIMD_STORE(dest + i, src_4.data); + } + } + +#endif + + if (param_size > rounded_size) { +#pragma omp parallel for + for (size_t k = rounded_size; k < param_size; k++) { dest[k] = src[k]; } + } +} + +static void helper_memcpy_4(float* dest, float* src, size_t param_size) +{ + size_t rounded_size = 0; + +#if defined(__AVX512__) or defined(__AVX256__) + + rounded_size = ROUND_DOWN(param_size, (SIMD_WIDTH << 2)); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; +#pragma omp parallel for + for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) { + AVX_Data src_4[4]; + src_4[0].data = SIMD_LOAD(src + i); + src_4[1].data = SIMD_LOAD(src + i + SIMD_WIDTH); + src_4[2].data = SIMD_LOAD(src + i + (SIMD_WIDTH << 1)); + src_4[3].data = SIMD_LOAD(src + i + SIMD_WIDTH * 3); + + SIMD_STORE(dest + i, src_4[0].data); + SIMD_STORE(dest + i + SIMD_WIDTH, src_4[1].data); + SIMD_STORE(dest + i + (SIMD_WIDTH << 1), src_4[2].data); + SIMD_STORE(dest + i + SIMD_WIDTH * 3, src_4[3].data); + } + } +#endif + if (param_size > rounded_size) + helper_memcpy_1((dest + rounded_size), (src + rounded_size), (param_size - rounded_size)); +} + +static void helper_mempcy_8(float* dest, float* src, size_t param_size) +{ + size_t rounded_size = 0; + +#if defined(__AVX512__) or defined(__AVX256__) + + rounded_size = ROUND_DOWN(param_size, (SIMD_WIDTH << 2)); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; +#pragma omp parallel for + for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) { + AVX_Data src_4[8]; + src_4[0].data = SIMD_LOAD(src + i); + src_4[1].data = SIMD_LOAD(src + i + SIMD_WIDTH); + src_4[2].data = SIMD_LOAD(src + i + (SIMD_WIDTH << 1)); + src_4[3].data = SIMD_LOAD(src + i + SIMD_WIDTH * 3); + src_4[4].data = SIMD_LOAD(src + i + (SIMD_WIDTH << 2)); + src_4[5].data = SIMD_LOAD(src + i + SIMD_WIDTH * 5); + src_4[6].data = SIMD_LOAD(src + i + SIMD_WIDTH * 6); + src_4[7].data = SIMD_LOAD(src + i + SIMD_WIDTH * 7); + + SIMD_STORE(dest + i, src_4[0].data); + SIMD_STORE(dest + i + SIMD_WIDTH, src_4[1].data); + SIMD_STORE(dest + i + (SIMD_WIDTH << 1), src_4[2].data); + SIMD_STORE(dest + i + SIMD_WIDTH * 3, src_4[3].data); + SIMD_STORE(dest + i + (SIMD_WIDTH << 2), src_4[4].data); + SIMD_STORE(dest + i + SIMD_WIDTH * 5, src_4[5].data); + SIMD_STORE(dest + i + SIMD_WIDTH * 6, src_4[6].data); + SIMD_STORE(dest + i + SIMD_WIDTH * 7, src_4[7].data); + } + } +#endif + if (param_size > rounded_size) + helper_memcpy_4((dest + rounded_size), (src + rounded_size), (param_size - rounded_size)); +} + +int deepspeed_py_memcpy(torch::Tensor& dest, const torch::Tensor& src) +{ + auto dest_c = dest.contiguous(); + auto src_c = src.contiguous(); + + float* dest_ptr = (float*)dest_c.data_ptr(); + float* src_ptr = (float*)src_c.data_ptr(); + + helper_mempcy_8(dest_ptr, src_ptr, dest_c.size(0)); + + return 0; +} diff --git a/csrc/aio/py_lib/deepspeed_py_copy.h b/csrc/aio/py_lib/deepspeed_py_copy.h index 819d568bb..69b044851 100644 --- a/csrc/aio/py_lib/deepspeed_py_copy.h +++ b/csrc/aio/py_lib/deepspeed_py_copy.h @@ -1,42 +1,42 @@ - -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#if (__x86_64__ || __i386__) -#include -#include -#endif - -#include -#include -#include - -#define TILE (1024 * 1024 * 1024) - -#if defined(__AVX512__) -#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) -#define SIMD_LOAD(x) _mm512_loadu_ps(x) -#define SIMD_SET(x) _mm512_set1_ps(x) -#define SIMD_MUL(x, y) _mm512_mul_ps(x, y) -#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) -#define SIMD_SQRT(x) _mm512_sqrt_ps(x) -#define SIMD_DIV(x, y) _mm512_div_ps(x, y) -#define SIMD_WIDTH 16 -#else -#if defined(__AVX256__) -#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d) -#define SIMD_LOAD(x) _mm256_loadu_ps(x) -#define SIMD_SET(x) _mm256_set1_ps(x) -#define SIMD_MUL(x, y) _mm256_mul_ps(x, y) -#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c) -#define SIMD_SQRT(x) _mm256_sqrt_ps(x) -#define SIMD_DIV(x, y) _mm256_div_ps(x, y) -#define SIMD_WIDTH 8 -#endif -#endif - -int deepspeed_py_memcpy(torch::Tensor& dest, const torch::Tensor& src); + +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#if (__x86_64__ || __i386__) +#include +#include +#endif + +#include +#include +#include + +#define TILE (1024 * 1024 * 1024) + +#if defined(__AVX512__) +#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm512_loadu_ps(x) +#define SIMD_SET(x) _mm512_set1_ps(x) +#define SIMD_MUL(x, y) _mm512_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm512_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm512_div_ps(x, y) +#define SIMD_WIDTH 16 +#else +#if defined(__AVX256__) +#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm256_loadu_ps(x) +#define SIMD_SET(x) _mm256_set1_ps(x) +#define SIMD_MUL(x, y) _mm256_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm256_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm256_div_ps(x, y) +#define SIMD_WIDTH 8 +#endif +#endif + +int deepspeed_py_memcpy(torch::Tensor& dest, const torch::Tensor& src); diff --git a/csrc/aio/py_lib/py_ds_aio.cpp b/csrc/aio/py_lib/py_ds_aio.cpp index eee2cba0a..61f95cd99 100755 --- a/csrc/aio/py_lib/py_ds_aio.cpp +++ b/csrc/aio/py_lib/py_ds_aio.cpp @@ -1,41 +1,41 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include -#include "deepspeed_py_aio_handle.h" -#include "deepspeed_py_copy.h" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("aio_read", &deepspeed_py_aio_read, "DeepSpeed Asynchornous I/O Read"); - - m.def("aio_write", &deepspeed_py_aio_write, "DeepSpeed Asynchornous I/O Write"); - - m.def("deepspeed_memcpy", &deepspeed_py_memcpy, "DeepSpeed Memory Copy"); - - py::class_(m, "aio_handle") - .def(py::init()) - - .def("get_block_size", &deepspeed_aio_handle_t::get_block_size) - .def("get_queue_depth", &deepspeed_aio_handle_t::get_queue_depth) - .def("get_single_submit", &deepspeed_aio_handle_t::get_single_submit) - .def("get_overlap_events", &deepspeed_aio_handle_t::get_overlap_events) - .def("get_thread_count", &deepspeed_aio_handle_t::get_thread_count) - - .def("read", &deepspeed_aio_handle_t::read) - .def("write", &deepspeed_aio_handle_t::write) - - .def("pread", &deepspeed_aio_handle_t::pread) - .def("pwrite", &deepspeed_aio_handle_t::pwrite) - - .def("sync_pread", &deepspeed_aio_handle_t::sync_pread) - .def("sync_pwrite", &deepspeed_aio_handle_t::sync_pwrite) - .def("async_pread", &deepspeed_aio_handle_t::async_pread) - .def("async_pwrite", &deepspeed_aio_handle_t::async_pwrite) - - .def("wait", &deepspeed_aio_handle_t::wait); -} +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include "deepspeed_py_aio_handle.h" +#include "deepspeed_py_copy.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("aio_read", &deepspeed_py_aio_read, "DeepSpeed Asynchornous I/O Read"); + + m.def("aio_write", &deepspeed_py_aio_write, "DeepSpeed Asynchornous I/O Write"); + + m.def("deepspeed_memcpy", &deepspeed_py_memcpy, "DeepSpeed Memory Copy"); + + py::class_(m, "aio_handle") + .def(py::init()) + + .def("get_block_size", &deepspeed_aio_handle_t::get_block_size) + .def("get_queue_depth", &deepspeed_aio_handle_t::get_queue_depth) + .def("get_single_submit", &deepspeed_aio_handle_t::get_single_submit) + .def("get_overlap_events", &deepspeed_aio_handle_t::get_overlap_events) + .def("get_thread_count", &deepspeed_aio_handle_t::get_thread_count) + + .def("read", &deepspeed_aio_handle_t::read) + .def("write", &deepspeed_aio_handle_t::write) + + .def("pread", &deepspeed_aio_handle_t::pread) + .def("pwrite", &deepspeed_aio_handle_t::pwrite) + + .def("sync_pread", &deepspeed_aio_handle_t::sync_pread) + .def("sync_pwrite", &deepspeed_aio_handle_t::sync_pwrite) + .def("async_pread", &deepspeed_aio_handle_t::async_pread) + .def("async_pwrite", &deepspeed_aio_handle_t::async_pwrite) + + .def("wait", &deepspeed_aio_handle_t::wait); +} diff --git a/csrc/aio/py_test/ds_aio_basic.py b/csrc/aio/py_test/ds_aio_basic.py index e9f399d50..cf70b6655 100755 --- a/csrc/aio/py_test/ds_aio_basic.py +++ b/csrc/aio/py_test/ds_aio_basic.py @@ -1,144 +1,144 @@ -""" -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality of swapping optimizer tensors to/from (NVMe) storage devices. -""" - -import torch -import os -import time -from deepspeed.ops.aio import AsyncIOBuilder -from multiprocessing import Pool, Barrier -from test_ds_aio_utils import report_results, task_log, task_barrier - - -def pre_basic(args, tid, read_op): - io_string = "Read" if read_op else "Write" - num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size - file = args.read_file if read_op else f'{args.write_file}.{tid}' - - task_log(tid, f'Allocate tensor of size {num_bytes} bytes') - buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory() - task_log( - tid, - f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}' - ) - - ctxt = {} - ctxt['file'] = file - ctxt['num_bytes'] = num_bytes - ctxt['buffer'] = buffer - ctxt['elapsed_sec'] = 0 - - return ctxt - - -def pre_basic_read(pool_params): - args, tid = pool_params - ctxt = pre_basic(args, tid, True) - return ctxt - - -def pre_basic_write(pool_params): - args, tid = pool_params - ctxt = pre_basic(args, tid, False) - return ctxt - - -def post_basic(pool_params): - _, _, ctxt = pool_params - ctxt["buffer"].detach() - ctxt["buffer"] = None - return ctxt - - -def main_basic_read(pool_params): - args, tid, ctxt = pool_params - start_time = time.time() - AsyncIOBuilder().load().aio_read(ctxt['buffer'], - ctxt['file'], - args.block_size, - args.queue_depth, - args.single_submit, - args.overlap_events, - args.validate) - end_time = time.time() - ctxt['elapsed_sec'] += end_time - start_time - - return ctxt - - -def main_basic_write(pool_params): - args, tid, ctxt = pool_params - start_time = time.time() - AsyncIOBuilder().load().aio_write(ctxt['buffer'], - ctxt['file'], - args.block_size, - args.queue_depth, - args.single_submit, - args.overlap_events, - args.validate) - end_time = time.time() - ctxt['elapsed_sec'] += end_time - start_time - - return ctxt - - -def get_schedule(args, read_op): - schedule = {} - if read_op: - schedule['pre'] = pre_basic_read - schedule['post'] = post_basic - schedule['main'] = main_basic_read - else: - schedule['pre'] = pre_basic_write - schedule['post'] = post_basic - schedule['main'] = main_basic_write - - return schedule - - -def _aio_handle_tasklet(pool_params): - args, tid, read_op = pool_params - - # Create schedule - schedule = get_schedule(args, read_op) - task_log(tid, f'schedule = {schedule}') - task_barrier(aio_barrier, args.threads) - - # Run pre task - task_log(tid, f'running pre-task') - ctxt = schedule["pre"]((args, tid)) - task_barrier(aio_barrier, args.threads) - - # Run main tasks in a loop - ctxt["main_task_sec"] = 0 - for i in range(args.loops): - task_log(tid, f'running main task {i}') - start_time = time.time() - ctxt = schedule["main"]((args, tid, ctxt)) - task_barrier(aio_barrier, args.threads) - stop_time = time.time() - ctxt["main_task_sec"] += stop_time - start_time - - # Run post task - task_log(tid, f'running post-task') - ctxt = schedule["post"]((args, tid, ctxt)) - task_barrier(aio_barrier, args.threads) - - return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops - - -def _init_tasklet(b): - global aio_barrier - aio_barrier = b - - -def aio_basic_multiprocessing(args, read_op): - b = Barrier(args.threads) - pool_params = [(args, p, read_op) for p in range(args.threads)] - with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p: - pool_results = p.map(_aio_handle_tasklet, pool_params) - - report_results(args, read_op, pool_results) +""" +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" + +import torch +import os +import time +from deepspeed.ops.aio import AsyncIOBuilder +from multiprocessing import Pool, Barrier +from test_ds_aio_utils import report_results, task_log, task_barrier + + +def pre_basic(args, tid, read_op): + io_string = "Read" if read_op else "Write" + num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size + file = args.read_file if read_op else f'{args.write_file}.{tid}' + + task_log(tid, f'Allocate tensor of size {num_bytes} bytes') + buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory() + task_log( + tid, + f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}' + ) + + ctxt = {} + ctxt['file'] = file + ctxt['num_bytes'] = num_bytes + ctxt['buffer'] = buffer + ctxt['elapsed_sec'] = 0 + + return ctxt + + +def pre_basic_read(pool_params): + args, tid = pool_params + ctxt = pre_basic(args, tid, True) + return ctxt + + +def pre_basic_write(pool_params): + args, tid = pool_params + ctxt = pre_basic(args, tid, False) + return ctxt + + +def post_basic(pool_params): + _, _, ctxt = pool_params + ctxt["buffer"].detach() + ctxt["buffer"] = None + return ctxt + + +def main_basic_read(pool_params): + args, tid, ctxt = pool_params + start_time = time.time() + AsyncIOBuilder().load().aio_read(ctxt['buffer'], + ctxt['file'], + args.block_size, + args.queue_depth, + args.single_submit, + args.overlap_events, + args.validate) + end_time = time.time() + ctxt['elapsed_sec'] += end_time - start_time + + return ctxt + + +def main_basic_write(pool_params): + args, tid, ctxt = pool_params + start_time = time.time() + AsyncIOBuilder().load().aio_write(ctxt['buffer'], + ctxt['file'], + args.block_size, + args.queue_depth, + args.single_submit, + args.overlap_events, + args.validate) + end_time = time.time() + ctxt['elapsed_sec'] += end_time - start_time + + return ctxt + + +def get_schedule(args, read_op): + schedule = {} + if read_op: + schedule['pre'] = pre_basic_read + schedule['post'] = post_basic + schedule['main'] = main_basic_read + else: + schedule['pre'] = pre_basic_write + schedule['post'] = post_basic + schedule['main'] = main_basic_write + + return schedule + + +def _aio_handle_tasklet(pool_params): + args, tid, read_op = pool_params + + # Create schedule + schedule = get_schedule(args, read_op) + task_log(tid, f'schedule = {schedule}') + task_barrier(aio_barrier, args.threads) + + # Run pre task + task_log(tid, f'running pre-task') + ctxt = schedule["pre"]((args, tid)) + task_barrier(aio_barrier, args.threads) + + # Run main tasks in a loop + ctxt["main_task_sec"] = 0 + for i in range(args.loops): + task_log(tid, f'running main task {i}') + start_time = time.time() + ctxt = schedule["main"]((args, tid, ctxt)) + task_barrier(aio_barrier, args.threads) + stop_time = time.time() + ctxt["main_task_sec"] += stop_time - start_time + + # Run post task + task_log(tid, f'running post-task') + ctxt = schedule["post"]((args, tid, ctxt)) + task_barrier(aio_barrier, args.threads) + + return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops + + +def _init_tasklet(b): + global aio_barrier + aio_barrier = b + + +def aio_basic_multiprocessing(args, read_op): + b = Barrier(args.threads) + pool_params = [(args, p, read_op) for p in range(args.threads)] + with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p: + pool_results = p.map(_aio_handle_tasklet, pool_params) + + report_results(args, read_op, pool_results) diff --git a/csrc/aio/py_test/ds_aio_handle.py b/csrc/aio/py_test/ds_aio_handle.py index 68abbe802..947ee2e6c 100755 --- a/csrc/aio/py_test/ds_aio_handle.py +++ b/csrc/aio/py_test/ds_aio_handle.py @@ -1,176 +1,176 @@ -""" -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality of swapping optimizer tensors to/from (NVMe) storage devices. -""" - -import torch -import os -import time -from multiprocessing import Pool, Barrier -from deepspeed.ops.aio import AsyncIOBuilder -from test_ds_aio_utils import report_results, task_log, task_barrier - - -def pre_handle(args, tid, read_op): - io_string = "Read" if read_op else "Write" - num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size - file = args.read_file if read_op else f'{args.write_file}.{tid}' - - task_log(tid, f'Allocate tensor of size {num_bytes} bytes') - if args.gpu: - buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cuda') - else: - buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory() - task_log( - tid, - f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}' - ) - - io_parallel = args.io_parallel if args.io_parallel else 1 - handle = AsyncIOBuilder().load().aio_handle(args.block_size, - args.queue_depth, - args.single_submit, - args.overlap_events, - io_parallel) - task_log(tid, f'created deepspeed aio handle') - - ctxt = {} - ctxt['file'] = file - ctxt['num_bytes'] = num_bytes - ctxt['handle'] = handle - ctxt['buffer'] = buffer - ctxt['elapsed_sec'] = 0 - - return ctxt - - -def pre_handle_read(pool_params): - args, tid = pool_params - ctxt = pre_handle(args, tid, True) - return ctxt - - -def pre_handle_write(pool_params): - args, tid = pool_params - ctxt = pre_handle(args, tid, False) - return ctxt - - -def post_handle(pool_params): - _, _, ctxt = pool_params - ctxt["buffer"].detach() - ctxt["buffer"] = None - return ctxt - - -def main_parallel_read(pool_params): - args, tid, ctxt = pool_params - handle = ctxt['handle'] - - start_time = time.time() - ret = handle.pread(ctxt['buffer'], ctxt['file'], args.validate, True) - assert ret != -1 - handle.wait() - end_time = time.time() - ctxt['elapsed_sec'] += end_time - start_time - - return ctxt - - -def main_parallel_write(pool_params): - args, tid, ctxt = pool_params - handle = ctxt['handle'] - start_time = time.time() - ret = handle.pwrite(ctxt['buffer'], ctxt['file'], args.validate, True) - assert ret != -1 - handle.wait() - end_time = time.time() - ctxt['elapsed_sec'] += end_time - start_time - - return ctxt - - -def main_handle_read(pool_parms): - args, tid, ctxt = pool_parms - handle = ctxt['handle'] - - start_time = time.time() - ret = handle.read(ctxt['buffer'], ctxt['file'], args.validate) - assert ret != -1 - end_time = time.time() - ctxt['elapsed_sec'] += end_time - start_time - - return ctxt - - -def main_handle_write(pool_parms): - args, tid, ctxt = pool_parms - handle = ctxt['handle'] - start_time = time.time() - ret = handle.write(ctxt['buffer'], ctxt['file'], args.validate) - assert ret != -1 - end_time = time.time() - ctxt['elapsed_sec'] += end_time - start_time - - return ctxt - - -def get_schedule(args, read_op): - schedule = {} - if read_op: - schedule['pre'] = pre_handle_read - schedule['post'] = post_handle - schedule['main'] = main_parallel_read if args.io_parallel else main_handle_read - else: - schedule['pre'] = pre_handle_write - schedule['post'] = post_handle - schedule['main'] = main_parallel_write if args.io_parallel else main_handle_write - - return schedule - - -def _aio_handle_tasklet(pool_params): - args, tid, read_op = pool_params - - # Create schedule - schedule = get_schedule(args, read_op) - task_log(tid, f'schedule = {schedule}') - task_barrier(aio_barrier, args.threads) - - # Run pre task - task_log(tid, f'running pre-task') - ctxt = schedule["pre"]((args, tid)) - task_barrier(aio_barrier, args.threads) - - # Run main tasks in a loop - ctxt["main_task_sec"] = 0 - for i in range(args.loops): - task_log(tid, f'running main task {i}') - start_time = time.time() - ctxt = schedule["main"]((args, tid, ctxt)) - task_barrier(aio_barrier, args.threads) - stop_time = time.time() - ctxt["main_task_sec"] += stop_time - start_time - - # Run post task - task_log(tid, f'running post-task') - ctxt = schedule["post"]((args, tid, ctxt)) - task_barrier(aio_barrier, args.threads) - - return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops - - -def _init_tasklet(b): - global aio_barrier - aio_barrier = b - - -def aio_handle_multiprocessing(args, read_op): - b = Barrier(args.threads) - pool_params = [(args, p, read_op) for p in range(args.threads)] - with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p: - pool_results = p.map(_aio_handle_tasklet, pool_params) - - report_results(args, read_op, pool_results) +""" +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" + +import torch +import os +import time +from multiprocessing import Pool, Barrier +from deepspeed.ops.aio import AsyncIOBuilder +from test_ds_aio_utils import report_results, task_log, task_barrier + + +def pre_handle(args, tid, read_op): + io_string = "Read" if read_op else "Write" + num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size + file = args.read_file if read_op else f'{args.write_file}.{tid}' + + task_log(tid, f'Allocate tensor of size {num_bytes} bytes') + if args.gpu: + buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cuda') + else: + buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory() + task_log( + tid, + f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}' + ) + + io_parallel = args.io_parallel if args.io_parallel else 1 + handle = AsyncIOBuilder().load().aio_handle(args.block_size, + args.queue_depth, + args.single_submit, + args.overlap_events, + io_parallel) + task_log(tid, f'created deepspeed aio handle') + + ctxt = {} + ctxt['file'] = file + ctxt['num_bytes'] = num_bytes + ctxt['handle'] = handle + ctxt['buffer'] = buffer + ctxt['elapsed_sec'] = 0 + + return ctxt + + +def pre_handle_read(pool_params): + args, tid = pool_params + ctxt = pre_handle(args, tid, True) + return ctxt + + +def pre_handle_write(pool_params): + args, tid = pool_params + ctxt = pre_handle(args, tid, False) + return ctxt + + +def post_handle(pool_params): + _, _, ctxt = pool_params + ctxt["buffer"].detach() + ctxt["buffer"] = None + return ctxt + + +def main_parallel_read(pool_params): + args, tid, ctxt = pool_params + handle = ctxt['handle'] + + start_time = time.time() + ret = handle.pread(ctxt['buffer'], ctxt['file'], args.validate, True) + assert ret != -1 + handle.wait() + end_time = time.time() + ctxt['elapsed_sec'] += end_time - start_time + + return ctxt + + +def main_parallel_write(pool_params): + args, tid, ctxt = pool_params + handle = ctxt['handle'] + start_time = time.time() + ret = handle.pwrite(ctxt['buffer'], ctxt['file'], args.validate, True) + assert ret != -1 + handle.wait() + end_time = time.time() + ctxt['elapsed_sec'] += end_time - start_time + + return ctxt + + +def main_handle_read(pool_parms): + args, tid, ctxt = pool_parms + handle = ctxt['handle'] + + start_time = time.time() + ret = handle.read(ctxt['buffer'], ctxt['file'], args.validate) + assert ret != -1 + end_time = time.time() + ctxt['elapsed_sec'] += end_time - start_time + + return ctxt + + +def main_handle_write(pool_parms): + args, tid, ctxt = pool_parms + handle = ctxt['handle'] + start_time = time.time() + ret = handle.write(ctxt['buffer'], ctxt['file'], args.validate) + assert ret != -1 + end_time = time.time() + ctxt['elapsed_sec'] += end_time - start_time + + return ctxt + + +def get_schedule(args, read_op): + schedule = {} + if read_op: + schedule['pre'] = pre_handle_read + schedule['post'] = post_handle + schedule['main'] = main_parallel_read if args.io_parallel else main_handle_read + else: + schedule['pre'] = pre_handle_write + schedule['post'] = post_handle + schedule['main'] = main_parallel_write if args.io_parallel else main_handle_write + + return schedule + + +def _aio_handle_tasklet(pool_params): + args, tid, read_op = pool_params + + # Create schedule + schedule = get_schedule(args, read_op) + task_log(tid, f'schedule = {schedule}') + task_barrier(aio_barrier, args.threads) + + # Run pre task + task_log(tid, f'running pre-task') + ctxt = schedule["pre"]((args, tid)) + task_barrier(aio_barrier, args.threads) + + # Run main tasks in a loop + ctxt["main_task_sec"] = 0 + for i in range(args.loops): + task_log(tid, f'running main task {i}') + start_time = time.time() + ctxt = schedule["main"]((args, tid, ctxt)) + task_barrier(aio_barrier, args.threads) + stop_time = time.time() + ctxt["main_task_sec"] += stop_time - start_time + + # Run post task + task_log(tid, f'running post-task') + ctxt = schedule["post"]((args, tid, ctxt)) + task_barrier(aio_barrier, args.threads) + + return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops + + +def _init_tasklet(b): + global aio_barrier + aio_barrier = b + + +def aio_handle_multiprocessing(args, read_op): + b = Barrier(args.threads) + pool_params = [(args, p, read_op) for p in range(args.threads)] + with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p: + pool_results = p.map(_aio_handle_tasklet, pool_params) + + report_results(args, read_op, pool_results) diff --git a/csrc/aio/py_test/parse_aio_stats.py b/csrc/aio/py_test/parse_aio_stats.py index 3e4600a46..1921973e4 100755 --- a/csrc/aio/py_test/parse_aio_stats.py +++ b/csrc/aio/py_test/parse_aio_stats.py @@ -1,154 +1,154 @@ -""" -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality of swapping optimizer tensors to/from (NVMe) storage devices. -""" - -import os -import argparse -import re - -READ_SPEED = 'read_speed' -WRITE_SPEED = 'write_speed' - -PERF_METRICS = [READ_SPEED, WRITE_SPEED] - -METRIC_SEARCH = {READ_SPEED: 'E2E Read Speed', WRITE_SPEED: 'E2E Write Speed'} - - -def parse_arguments(): - parser = argparse.ArgumentParser() - - parser.add_argument('--log_dir', - type=str, - required=True, - help='Folder of statistics logs') - - parser.add_argument('--metric', - type=str, - required=True, - help='Performance metric to report: [read_speed|write_speed]') - - args = parser.parse_args() - print(f'args = {args}') - - return args - - -def extract_value(key, file): - INVALID_PREFIXES = ["ds"] - for p in INVALID_PREFIXES: - if key.startswith(p): - return key - try: - if key[0] in ['t', 'd', 'p']: - return int(key[1:]) - if key.startswith("bs"): - if key.endswith('K'): - v = key[2:].split('K') - return int(v[0]) * 1024 - elif key.endswith('M'): - v = key[2:].split('M') - return int(v[0]) * 1024 * 1024 - else: - return int(key[2:]) - except: - print(f"{file}: extract_value fails on {key}") - return None - - return key - - -def get_file_key(file): - f, _ = os.path.splitext(os.path.basename(file)) - fields = f.split('_') - values = [extract_value(k, file) for k in fields] - return tuple(values) - - -def get_thread_count(file): - f, _ = os.path.splitext(os.path.basename(file)) - fields = f.split('_') - for key in fields: - if key[0] == 't': - return int(key[1:]) - return 1 - - -""" -Extract performance metric from log file. -Sample file lines are: -Task Read Latency = 0.031647682189941406 sec -Task Read Speed = 12.342926020792527 GB/sec -E2E Read Latency = 0.031697988510131836 sec -E2E Read Speed = 12.323337169333062 GB/sec - -For the above sample, -metric = "read_speed" corresponds to "E2E Read Speed", and 12.32 will be returned -""" - - -def get_metric(file, metric): - thread_count = get_thread_count(file) - with open(file) as f: - for line in f.readlines(): - if line.startswith(METRIC_SEARCH[metric]): - if metric in [READ_SPEED, WRITE_SPEED]: - fields = line.split() - return float(fields[-2]) - else: - fields = line.split('=') - return float(fields[-1]) - - return None - - -def validate_args(args): - if not args.metric in PERF_METRICS: - print(f'{args.metric} is not a valid performance metrics') - return False - - if not os.path.isdir(args.log_dir): - print(f'{args.log_dir} folder is not existent') - return False - - return True - - -def get_results(log_files, metric): - results = {} - for f in log_files: - file_key = get_file_key(f) - value = get_metric(f, metric) - results[file_key] = value - - return results - - -def get_sorted_results(log_dir, metric): - log_files = [ - f for f in os.listdir(log_dir) if os.path.isfile(os.path.join(log_dir, - f)) - ] - - log_files_path = [os.path.join(log_dir, f) for f in log_files] - results = get_results(log_files_path, metric) - result_keys = list(results.keys()) - sorted_keys = sorted(result_keys) - return sorted_keys, results - - -def main(): - print("Parsing aio statistics") - args = parse_arguments() - - if not validate_args(args): - quit() - - sorted_keys, results = get_sorted_results(args.log_dir, args.metric) - for k in sorted_keys: - print(f'{k} = {results[k]}') - - -if __name__ == "__main__": - main() +""" +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" + +import os +import argparse +import re + +READ_SPEED = 'read_speed' +WRITE_SPEED = 'write_speed' + +PERF_METRICS = [READ_SPEED, WRITE_SPEED] + +METRIC_SEARCH = {READ_SPEED: 'E2E Read Speed', WRITE_SPEED: 'E2E Write Speed'} + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument('--log_dir', + type=str, + required=True, + help='Folder of statistics logs') + + parser.add_argument('--metric', + type=str, + required=True, + help='Performance metric to report: [read_speed|write_speed]') + + args = parser.parse_args() + print(f'args = {args}') + + return args + + +def extract_value(key, file): + INVALID_PREFIXES = ["ds"] + for p in INVALID_PREFIXES: + if key.startswith(p): + return key + try: + if key[0] in ['t', 'd', 'p']: + return int(key[1:]) + if key.startswith("bs"): + if key.endswith('K'): + v = key[2:].split('K') + return int(v[0]) * 1024 + elif key.endswith('M'): + v = key[2:].split('M') + return int(v[0]) * 1024 * 1024 + else: + return int(key[2:]) + except: + print(f"{file}: extract_value fails on {key}") + return None + + return key + + +def get_file_key(file): + f, _ = os.path.splitext(os.path.basename(file)) + fields = f.split('_') + values = [extract_value(k, file) for k in fields] + return tuple(values) + + +def get_thread_count(file): + f, _ = os.path.splitext(os.path.basename(file)) + fields = f.split('_') + for key in fields: + if key[0] == 't': + return int(key[1:]) + return 1 + + +""" +Extract performance metric from log file. +Sample file lines are: +Task Read Latency = 0.031647682189941406 sec +Task Read Speed = 12.342926020792527 GB/sec +E2E Read Latency = 0.031697988510131836 sec +E2E Read Speed = 12.323337169333062 GB/sec + +For the above sample, -metric = "read_speed" corresponds to "E2E Read Speed", and 12.32 will be returned +""" + + +def get_metric(file, metric): + thread_count = get_thread_count(file) + with open(file) as f: + for line in f.readlines(): + if line.startswith(METRIC_SEARCH[metric]): + if metric in [READ_SPEED, WRITE_SPEED]: + fields = line.split() + return float(fields[-2]) + else: + fields = line.split('=') + return float(fields[-1]) + + return None + + +def validate_args(args): + if not args.metric in PERF_METRICS: + print(f'{args.metric} is not a valid performance metrics') + return False + + if not os.path.isdir(args.log_dir): + print(f'{args.log_dir} folder is not existent') + return False + + return True + + +def get_results(log_files, metric): + results = {} + for f in log_files: + file_key = get_file_key(f) + value = get_metric(f, metric) + results[file_key] = value + + return results + + +def get_sorted_results(log_dir, metric): + log_files = [ + f for f in os.listdir(log_dir) if os.path.isfile(os.path.join(log_dir, + f)) + ] + + log_files_path = [os.path.join(log_dir, f) for f in log_files] + results = get_results(log_files_path, metric) + result_keys = list(results.keys()) + sorted_keys = sorted(result_keys) + return sorted_keys, results + + +def main(): + print("Parsing aio statistics") + args = parse_arguments() + + if not validate_args(args): + quit() + + sorted_keys, results = get_sorted_results(args.log_dir, args.metric) + for k in sorted_keys: + print(f'{k} = {results[k]}') + + +if __name__ == "__main__": + main() diff --git a/csrc/aio/py_test/test_ds_aio.py b/csrc/aio/py_test/test_ds_aio.py index db7c12d7f..f97d3e676 100755 --- a/csrc/aio/py_test/test_ds_aio.py +++ b/csrc/aio/py_test/test_ds_aio.py @@ -1,101 +1,101 @@ -""" -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality of swapping optimizer tensors to/from (NVMe) storage devices. -""" - -import os -import torch -import argparse -import time -import sys -from multiprocessing import Pool -import multiprocessing as mp -from ds_aio_basic import aio_basic_multiprocessing -from ds_aio_handle import aio_handle_multiprocessing -from test_ds_aio_utils import refine_args - - -def parse_arguments(): - parser = argparse.ArgumentParser() - - parser.add_argument('--read_file', type=str, default=None, help='Read file.') - - parser.add_argument('--write_file', type=str, default=None, help='Write file.') - - parser.add_argument('--write_size', - type=str, - default=None, - help='Number of bytes to write.') - - parser.add_argument('--block_size', type=str, default='1M', help='I/O block size.') - - parser.add_argument('--queue_depth', type=int, default=32, help='I/O queue depth.') - - parser.add_argument('--threads', - type=int, - default=1, - help='Thread parallelism count.') - - parser.add_argument( - '--single_submit', - action='store_true', - help= - 'Submit I/O requests in singles (default is submit queue_depth amount at once.).' - ) - - parser.add_argument('--overlap_events', - action='store_true', - help='Overlap I/O submission and completion requests.') - - parser.add_argument('--validate', - action='store_true', - help='Perform validation in library.') - - parser.add_argument('--handle', action='store_true', help='Use AIO handle.') - - parser.add_argument('--loops', - type=int, - default=1, - help='Count of operation repetitions') - - parser.add_argument('--io_parallel', - type=int, - default=None, - help='Per iop parallelism') - - parser.add_argument('--gpu', action='store_true', help='Use GPU memory') - - args = parser.parse_args() - print(f'args = {args}') - return args - - -def validate_args(args): - if args.read_file and not os.path.isfile(args.read_file): - print(f'args validation error: {args.read_file} not found') - return False - - return True - - -def main(): - print(f'Testing deepspeed_aio python frontend') - - args = parse_arguments() - refine_args(args) - if not validate_args(args): - quit() - - mp.set_start_method('spawn') - multiprocess_function = aio_handle_multiprocessing if args.handle else aio_basic_multiprocessing - if args.read_file: - multiprocess_function(args, True) - - if args.write_file: - multiprocess_function(args, False) - - -if __name__ == "__main__": - main() +""" +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" + +import os +import torch +import argparse +import time +import sys +from multiprocessing import Pool +import multiprocessing as mp +from ds_aio_basic import aio_basic_multiprocessing +from ds_aio_handle import aio_handle_multiprocessing +from test_ds_aio_utils import refine_args + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument('--read_file', type=str, default=None, help='Read file.') + + parser.add_argument('--write_file', type=str, default=None, help='Write file.') + + parser.add_argument('--write_size', + type=str, + default=None, + help='Number of bytes to write.') + + parser.add_argument('--block_size', type=str, default='1M', help='I/O block size.') + + parser.add_argument('--queue_depth', type=int, default=32, help='I/O queue depth.') + + parser.add_argument('--threads', + type=int, + default=1, + help='Thread parallelism count.') + + parser.add_argument( + '--single_submit', + action='store_true', + help= + 'Submit I/O requests in singles (default is submit queue_depth amount at once.).' + ) + + parser.add_argument('--overlap_events', + action='store_true', + help='Overlap I/O submission and completion requests.') + + parser.add_argument('--validate', + action='store_true', + help='Perform validation in library.') + + parser.add_argument('--handle', action='store_true', help='Use AIO handle.') + + parser.add_argument('--loops', + type=int, + default=1, + help='Count of operation repetitions') + + parser.add_argument('--io_parallel', + type=int, + default=None, + help='Per iop parallelism') + + parser.add_argument('--gpu', action='store_true', help='Use GPU memory') + + args = parser.parse_args() + print(f'args = {args}') + return args + + +def validate_args(args): + if args.read_file and not os.path.isfile(args.read_file): + print(f'args validation error: {args.read_file} not found') + return False + + return True + + +def main(): + print(f'Testing deepspeed_aio python frontend') + + args = parse_arguments() + refine_args(args) + if not validate_args(args): + quit() + + mp.set_start_method('spawn') + multiprocess_function = aio_handle_multiprocessing if args.handle else aio_basic_multiprocessing + if args.read_file: + multiprocess_function(args, True) + + if args.write_file: + multiprocess_function(args, False) + + +if __name__ == "__main__": + main() diff --git a/csrc/aio/py_test/test_ds_aio_utils.py b/csrc/aio/py_test/test_ds_aio_utils.py index fa0f0f6be..c68dfdddc 100755 --- a/csrc/aio/py_test/test_ds_aio_utils.py +++ b/csrc/aio/py_test/test_ds_aio_utils.py @@ -1,59 +1,59 @@ -""" -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality of swapping optimizer tensors to/from (NVMe) storage devices. -""" - -import os - -BYTES_PER_GB = 1024**3 -LOG_TIDS = [0] - - -def task_log(tid, msg): - if tid in LOG_TIDS: - print(f'tid {tid}: {msg}') - - -def task_barrier(barrier, num_parties): - assert barrier.parties == num_parties - barrier.wait() - assert barrier.broken == False - - -def report_results(args, read_op, pool_results): - #print(f'pool_results = {pool_results}') - io_string = 'Read' if read_op else 'Write' - if None in pool_results: - print(f'Failure in one of {args.threads} {io_string} processes') - return - - total_bytes = sum([num_bytes for _, _, num_bytes in pool_results]) - - task_latency_sec = max([sec for _, sec, _ in pool_results]) - task_speed_GB = total_bytes / task_latency_sec / BYTES_PER_GB - print(f'Task {io_string} Latency = {task_latency_sec} sec') - print(f'Task {io_string} Speed = {task_speed_GB} GB/sec') - - e2e_latency_sec = max([sec for sec, _, _ in pool_results]) - e2e_speed_GB = total_bytes / e2e_latency_sec / BYTES_PER_GB - print(f'E2E {io_string} Latency = {e2e_latency_sec} sec') - print(f'E2E {io_string} Speed = {e2e_speed_GB} GB/sec') - - -def refine_integer_value(value): - unit_dict = {'K': 1024, 'M': 1024**2, 'G': 1024**3} - - if value[-1] in list(unit_dict.keys()): - int_value = int(value[:-1]) * unit_dict[value[-1]] - return int_value - return int(value) - - -def refine_args(args): - if args.write_size and type(args.write_size) == str: - args.write_size = refine_integer_value(args.write_size) - - if args.block_size and type(args.block_size) == str: - args.block_size = refine_integer_value(args.block_size) +""" +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" + +import os + +BYTES_PER_GB = 1024**3 +LOG_TIDS = [0] + + +def task_log(tid, msg): + if tid in LOG_TIDS: + print(f'tid {tid}: {msg}') + + +def task_barrier(barrier, num_parties): + assert barrier.parties == num_parties + barrier.wait() + assert barrier.broken == False + + +def report_results(args, read_op, pool_results): + #print(f'pool_results = {pool_results}') + io_string = 'Read' if read_op else 'Write' + if None in pool_results: + print(f'Failure in one of {args.threads} {io_string} processes') + return + + total_bytes = sum([num_bytes for _, _, num_bytes in pool_results]) + + task_latency_sec = max([sec for _, sec, _ in pool_results]) + task_speed_GB = total_bytes / task_latency_sec / BYTES_PER_GB + print(f'Task {io_string} Latency = {task_latency_sec} sec') + print(f'Task {io_string} Speed = {task_speed_GB} GB/sec') + + e2e_latency_sec = max([sec for sec, _, _ in pool_results]) + e2e_speed_GB = total_bytes / e2e_latency_sec / BYTES_PER_GB + print(f'E2E {io_string} Latency = {e2e_latency_sec} sec') + print(f'E2E {io_string} Speed = {e2e_speed_GB} GB/sec') + + +def refine_integer_value(value): + unit_dict = {'K': 1024, 'M': 1024**2, 'G': 1024**3} + + if value[-1] in list(unit_dict.keys()): + int_value = int(value[:-1]) * unit_dict[value[-1]] + return int_value + return int(value) + + +def refine_args(args): + if args.write_size and type(args.write_size) == str: + args.write_size = refine_integer_value(args.write_size) + + if args.block_size and type(args.block_size) == str: + args.block_size = refine_integer_value(args.block_size) diff --git a/csrc/common/custom_cuda_kernel.cu b/csrc/common/custom_cuda_kernel.cu index dee09aac5..f7a2b5d48 100644 --- a/csrc/common/custom_cuda_kernel.cu +++ b/csrc/common/custom_cuda_kernel.cu @@ -1,39 +1,39 @@ -#include "custom_cuda_layers.h" - -__global__ void param_update_kernel(const float* input, __half* output, int size) -{ - int id = blockIdx.x * blockDim.x + threadIdx.x; - - if (id < size) { output[id] = (__half)input[id]; } -} - -void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream) -{ - int threads = 1024; - - dim3 grid_dim((size - 1) / threads + 1); - dim3 block_dim(threads); - - param_update_kernel<<>>(input, output, size); -} - -__global__ void param_update_kernel_half(const float* input, __half* output, int size) -{ - int id = blockIdx.x * blockDim.x + threadIdx.x; - __half2* output_cast = reinterpret_cast<__half2*>(output); - if (id < size) { - float input_f = input[id]; - __half2* input_h = reinterpret_cast<__half2*>(&input_f); - output_cast[id] = *input_h; - } -} - -void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream) -{ - int threads = 1024; - size /= 2; - dim3 grid_dim((size - 1) / threads + 1); - dim3 block_dim(threads); - - param_update_kernel_half<<>>(input, output, size); -} +#include "custom_cuda_layers.h" + +__global__ void param_update_kernel(const float* input, __half* output, int size) +{ + int id = blockIdx.x * blockDim.x + threadIdx.x; + + if (id < size) { output[id] = (__half)input[id]; } +} + +void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream) +{ + int threads = 1024; + + dim3 grid_dim((size - 1) / threads + 1); + dim3 block_dim(threads); + + param_update_kernel<<>>(input, output, size); +} + +__global__ void param_update_kernel_half(const float* input, __half* output, int size) +{ + int id = blockIdx.x * blockDim.x + threadIdx.x; + __half2* output_cast = reinterpret_cast<__half2*>(output); + if (id < size) { + float input_f = input[id]; + __half2* input_h = reinterpret_cast<__half2*>(&input_f); + output_cast[id] = *input_h; + } +} + +void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream) +{ + int threads = 1024; + size /= 2; + dim3 grid_dim((size - 1) / threads + 1); + dim3 block_dim(threads); + + param_update_kernel_half<<>>(input, output, size); +} diff --git a/csrc/includes/Timer.h b/csrc/includes/Timer.h index 7c20854a0..efc7fff84 100644 --- a/csrc/includes/Timer.h +++ b/csrc/includes/Timer.h @@ -1,47 +1,47 @@ - -#ifndef __TIMER_H__ -#define __TIMER_H__ - -#include -#include -#include "cuda.h" - -class GPUTimer { - cudaEvent_t start, stop; - -public: - GPUTimer() - { - cudaEventCreate(&start); - cudaEventCreate(&stop); - } - ~GPUTimer() - { - cudaEventDestroy(start); - cudaEventDestroy(stop); - } - inline void Record() { cudaEventRecord(start); } - inline void Elapsed(float& time_elapsed) - { - cudaEventRecord(stop); - cudaEventSynchronize(stop); - cudaEventElapsedTime(&time_elapsed, start, stop); - } -}; - -class CPUTimer { - std::chrono::high_resolution_clock::time_point start; - -public: - CPUTimer() : start(std::chrono::high_resolution_clock::now()) {} - inline void Reset() { start = std::chrono::high_resolution_clock::now(); } - inline float Elapsed() - { - auto temp = start; - start = std::chrono::high_resolution_clock::now(); - return (float)(std::chrono::duration_cast(start - temp).count() / - 1e3); - } -}; - -#endif + +#ifndef __TIMER_H__ +#define __TIMER_H__ + +#include +#include +#include "cuda.h" + +class GPUTimer { + cudaEvent_t start, stop; + +public: + GPUTimer() + { + cudaEventCreate(&start); + cudaEventCreate(&stop); + } + ~GPUTimer() + { + cudaEventDestroy(start); + cudaEventDestroy(stop); + } + inline void Record() { cudaEventRecord(start); } + inline void Elapsed(float& time_elapsed) + { + cudaEventRecord(stop); + cudaEventSynchronize(stop); + cudaEventElapsedTime(&time_elapsed, start, stop); + } +}; + +class CPUTimer { + std::chrono::high_resolution_clock::time_point start; + +public: + CPUTimer() : start(std::chrono::high_resolution_clock::now()) {} + inline void Reset() { start = std::chrono::high_resolution_clock::now(); } + inline float Elapsed() + { + auto temp = start; + start = std::chrono::high_resolution_clock::now(); + return (float)(std::chrono::duration_cast(start - temp).count() / + 1e3); + } +}; + +#endif diff --git a/csrc/includes/cpu_adagrad.h b/csrc/includes/cpu_adagrad.h index c398246f1..6bfd09bd2 100644 --- a/csrc/includes/cpu_adagrad.h +++ b/csrc/includes/cpu_adagrad.h @@ -1,147 +1,147 @@ -#pragma once - -#include -#include -#include -#include -#include "cuda.h" -#include "custom_cuda_layers.h" -#include "simd.h" - -#define STEP(SPAN) \ - void Step_##SPAN(float* _params, \ - float* grads, \ - float* _exp_avg_sq, \ - size_t _param_size, \ - __half* dev_param = nullptr, \ - bool half_precision = false); - -class Adagrad_Optimizer { -public: - Adagrad_Optimizer(float alpha = 1e-2, float eps = 1e-8, float weight_decay = 0) - : _alpha(alpha), _eps(eps), _weight_decay(weight_decay), _buf_index(false) - { - cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); - cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); - - _streams[0] = Context::Instance().GetCurrentStream(); - _streams[1] = Context::Instance().GetNewStream(); - } - ~Adagrad_Optimizer() - { - cudaFreeHost(_doubled_buffer[0]); - cudaFreeHost(_doubled_buffer[1]); - } -#if defined(__AVX512__) or defined(__AVX256__) - template - void Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg_sq, - size_t param_size, - __half* dev_param = nullptr, - bool half_precision = false); -#endif - STEP(1) - STEP(4) - STEP(8) - inline void SynchronizeStreams() - { - for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]); - } - inline void IncrementStep(size_t step) - { - _step++; - if (_step != step) { _step = step; } - } - inline void update_state(float lr, float epsilon, float weight_decay) - { - _alpha = lr; - _eps = epsilon; - _weight_decay = weight_decay; - } - -private: - float _alpha; - float _eps; - float _weight_decay; - - float _betta1_t; - float _betta2_t; - size_t _step; - - float* _doubled_buffer[2]; - bool _buf_index; - - cudaStream_t _streams[2]; -}; - -#if defined(__AVX512__) or defined(__AVX256__) -template -void Adagrad_Optimizer::Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg_sq, - size_t _param_size, - __half* dev_params, - bool half_precision) -{ - size_t new_rounded_size = 0; - AVX_Data eps_4; - eps_4.data = SIMD_SET(_eps); - - float step_size = -1 * _alpha; - AVX_Data step_size_4; - step_size_4.data = SIMD_SET(step_size); - - AVX_Data weight_decay4; - if (_weight_decay > 0) weight_decay4.data = SIMD_SET(_weight_decay); - new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span); - for (size_t t = 0; t < new_rounded_size; t += TILE) { - size_t copy_size = TILE; - if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; - size_t offset = copy_size + t; - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } -#pragma omp parallel for - for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { - AVX_Data grad_4[span]; - simd_load(grad_4, grads + i, half_precision); - - AVX_Data momentum_4[span]; - simd_load(momentum_4, grads + i, false); - - AVX_Data variance_4[span]; - simd_load(variance_4, _exp_avg_sq + i, false); - - AVX_Data param_4[span]; - simd_load(param_4, _params + i, half_precision); - - if (_weight_decay > 0) { simd_fma(grad_4, param_4, weight_decay4, grad_4); } - - simd_fma(variance_4, grad_4, grad_4, variance_4); - simd_sqrt(grad_4, variance_4); - simd_add(grad_4, grad_4, eps_4); - simd_div(grad_4, momentum_4, grad_4); - simd_fma(param_4, grad_4, step_size_4, param_4); - - simd_store(_params + i, param_4, half_precision); - if (dev_params) { - simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision); - } - simd_store(_exp_avg_sq + i, variance_4, false); - } - - if (dev_params) { - if (half_precision) - launch_param_update_half( - _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); - else - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); - - _buf_index = !_buf_index; - } - } - *rounded_size = new_rounded_size; -} -#endif +#pragma once + +#include +#include +#include +#include +#include "cuda.h" +#include "custom_cuda_layers.h" +#include "simd.h" + +#define STEP(SPAN) \ + void Step_##SPAN(float* _params, \ + float* grads, \ + float* _exp_avg_sq, \ + size_t _param_size, \ + __half* dev_param = nullptr, \ + bool half_precision = false); + +class Adagrad_Optimizer { +public: + Adagrad_Optimizer(float alpha = 1e-2, float eps = 1e-8, float weight_decay = 0) + : _alpha(alpha), _eps(eps), _weight_decay(weight_decay), _buf_index(false) + { + cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); + cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); + + _streams[0] = Context::Instance().GetCurrentStream(); + _streams[1] = Context::Instance().GetNewStream(); + } + ~Adagrad_Optimizer() + { + cudaFreeHost(_doubled_buffer[0]); + cudaFreeHost(_doubled_buffer[1]); + } +#if defined(__AVX512__) or defined(__AVX256__) + template + void Step_AVX(size_t* rounded_size, + float* _params, + float* grads, + float* _exp_avg_sq, + size_t param_size, + __half* dev_param = nullptr, + bool half_precision = false); +#endif + STEP(1) + STEP(4) + STEP(8) + inline void SynchronizeStreams() + { + for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]); + } + inline void IncrementStep(size_t step) + { + _step++; + if (_step != step) { _step = step; } + } + inline void update_state(float lr, float epsilon, float weight_decay) + { + _alpha = lr; + _eps = epsilon; + _weight_decay = weight_decay; + } + +private: + float _alpha; + float _eps; + float _weight_decay; + + float _betta1_t; + float _betta2_t; + size_t _step; + + float* _doubled_buffer[2]; + bool _buf_index; + + cudaStream_t _streams[2]; +}; + +#if defined(__AVX512__) or defined(__AVX256__) +template +void Adagrad_Optimizer::Step_AVX(size_t* rounded_size, + float* _params, + float* grads, + float* _exp_avg_sq, + size_t _param_size, + __half* dev_params, + bool half_precision) +{ + size_t new_rounded_size = 0; + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + float step_size = -1 * _alpha; + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + AVX_Data weight_decay4; + if (_weight_decay > 0) weight_decay4.data = SIMD_SET(_weight_decay); + new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span); + for (size_t t = 0; t < new_rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; + size_t offset = copy_size + t; + if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { + AVX_Data grad_4[span]; + simd_load(grad_4, grads + i, half_precision); + + AVX_Data momentum_4[span]; + simd_load(momentum_4, grads + i, false); + + AVX_Data variance_4[span]; + simd_load(variance_4, _exp_avg_sq + i, false); + + AVX_Data param_4[span]; + simd_load(param_4, _params + i, half_precision); + + if (_weight_decay > 0) { simd_fma(grad_4, param_4, weight_decay4, grad_4); } + + simd_fma(variance_4, grad_4, grad_4, variance_4); + simd_sqrt(grad_4, variance_4); + simd_add(grad_4, grad_4, eps_4); + simd_div(grad_4, momentum_4, grad_4); + simd_fma(param_4, grad_4, step_size_4, param_4); + + simd_store(_params + i, param_4, half_precision); + if (dev_params) { + simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision); + } + simd_store(_exp_avg_sq + i, variance_4, false); + } + + if (dev_params) { + if (half_precision) + launch_param_update_half( + _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); + else + launch_param_update( + _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); + + _buf_index = !_buf_index; + } + } + *rounded_size = new_rounded_size; +} +#endif diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index 88779ef5f..9a4e80593 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -1,222 +1,222 @@ -#pragma once - -#include -#include -#include -#include -#include "cuda.h" -#include "custom_cuda_layers.h" -#include "simd.h" - -#define STEP(SPAN) \ - void Step_##SPAN(float* _params, \ - float* grads, \ - float* _exp_avg, \ - float* _exp_avg_sq, \ - size_t _param_size, \ - __half* dev_param = nullptr, \ - bool half_precision = false); - -class Adam_Optimizer { -public: - Adam_Optimizer(float alpha = 1e-3, - float betta1 = 0.9, - float betta2 = 0.999, - float eps = 1e-8, - float weight_decay = 0, - bool adamw_mode = true) - : _alpha(alpha), - _betta1(betta1), - _betta2(betta2), - _eps(eps), - _weight_decay(weight_decay), - _betta1_t(1.0), - _betta2_t(1.0), - _step(0), - _buf_index(false), - _adamw_mode(adamw_mode) - { - cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); - cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); - - _streams[0] = Context::Instance().GetCurrentStream(); - _streams[1] = Context::Instance().GetNewStream(); - } - ~Adam_Optimizer() - { - cudaFreeHost(_doubled_buffer[0]); - cudaFreeHost(_doubled_buffer[1]); - } -#if defined(__AVX512__) or defined(__AVX256__) - template - void Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t param_size, - __half* dev_param = nullptr, - bool half_precision = false); -#endif - STEP(1) - STEP(4) - STEP(8) - inline void SynchronizeStreams() - { - for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]); - } - inline void IncrementStep(size_t step, float beta1, float beta2) - { - if (beta1 != _betta1 || beta2 != _betta2) { - _step = step; - _betta1 = beta1; - _betta2 = beta2; - _betta1_t = std::pow(_betta1, step); - _betta2_t = std::pow(_betta2, step); - } else { - _step++; - if (_step != step) { - _betta1_t = std::pow(_betta1, step); - _betta2_t = std::pow(_betta2, step); - _step = step; - } else { - _betta1_t *= _betta1; - _betta2_t *= _betta2; - } - } - } - inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction) - { - _alpha = lr; - _eps = epsilon; - _weight_decay = weight_decay; - - _bias_correction1 = 1.0f; - _bias_correction2 = 1.0f; - if (bias_correction == 1) { - _bias_correction1 = 1 - _betta1_t; - _bias_correction2 = 1 / sqrt(1 - _betta2_t); - } - } - -private: - float _alpha; - float _betta1; - float _betta2; - float _eps; - float _weight_decay; - - float _betta1_t; - float _betta2_t; - size_t _step; - - float _bias_correction1; - float _bias_correction2; - - float* _doubled_buffer[2]; - bool _buf_index; - bool _adamw_mode; - - cudaStream_t _streams[2]; -}; - -#if defined(__AVX512__) or defined(__AVX256__) -template -void Adam_Optimizer::Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - __half* dev_params, - bool half_precision) -{ - size_t new_rounded_size = 0; - - AVX_Data betta1_4; - betta1_4.data = SIMD_SET(_betta1); - AVX_Data betta2_4; - betta2_4.data = SIMD_SET(_betta2); - - float betta1_minus1 = 1 - _betta1; - float betta2_minus1 = 1 - _betta2; - AVX_Data betta1_minus1_4; - betta1_minus1_4.data = SIMD_SET(betta1_minus1); - AVX_Data betta2_minus1_4; - betta2_minus1_4.data = SIMD_SET(betta2_minus1); - - AVX_Data bias2_sqrt; - bias2_sqrt.data = SIMD_SET(_bias_correction2); - - AVX_Data eps_4; - eps_4.data = SIMD_SET(_eps); - - float step_size = -1 * _alpha / _bias_correction1; - AVX_Data step_size_4; - step_size_4.data = SIMD_SET(step_size); - - float w_decay = -1 * _alpha * _weight_decay; - AVX_Data weight_decay4; - if (_weight_decay > 0) - weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); - new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span); - for (size_t t = 0; t < new_rounded_size; t += TILE) { - size_t copy_size = TILE; - if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; - size_t offset = copy_size + t; - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } -#pragma omp parallel for - for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { - AVX_Data grad_4[span]; - simd_load(grad_4, grads + i, half_precision); - - AVX_Data momentum_4[span]; - simd_load(momentum_4, _exp_avg + i, false); - - AVX_Data variance_4[span]; - simd_load(variance_4, _exp_avg_sq + i, false); - - AVX_Data param_4[span]; - simd_load(param_4, _params + i, half_precision); - - if (_weight_decay > 0 && !_adamw_mode) { - simd_fma(grad_4, param_4, weight_decay4, grad_4); - } - - simd_mul(momentum_4, momentum_4, betta1_4); - simd_fma(momentum_4, grad_4, betta1_minus1_4, momentum_4); - simd_mul(variance_4, variance_4, betta2_4); - simd_mul(grad_4, grad_4, grad_4); - simd_fma(variance_4, grad_4, betta2_minus1_4, variance_4); - simd_sqrt(grad_4, variance_4); - simd_fma(grad_4, grad_4, bias2_sqrt, eps_4); - simd_div(grad_4, momentum_4, grad_4); - - if (_weight_decay > 0 && _adamw_mode) { - simd_fma(param_4, param_4, weight_decay4, param_4); - } - - simd_fma(param_4, grad_4, step_size_4, param_4); - - simd_store(_params + i, param_4, half_precision); - if (dev_params) { - simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision); - } - simd_store(_exp_avg + i, momentum_4, false); - simd_store(_exp_avg_sq + i, variance_4, false); - } - - if (dev_params) { - if (half_precision) - launch_param_update_half( - _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); - else - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); - - _buf_index = !_buf_index; - } - } - *rounded_size = new_rounded_size; -} -#endif +#pragma once + +#include +#include +#include +#include +#include "cuda.h" +#include "custom_cuda_layers.h" +#include "simd.h" + +#define STEP(SPAN) \ + void Step_##SPAN(float* _params, \ + float* grads, \ + float* _exp_avg, \ + float* _exp_avg_sq, \ + size_t _param_size, \ + __half* dev_param = nullptr, \ + bool half_precision = false); + +class Adam_Optimizer { +public: + Adam_Optimizer(float alpha = 1e-3, + float betta1 = 0.9, + float betta2 = 0.999, + float eps = 1e-8, + float weight_decay = 0, + bool adamw_mode = true) + : _alpha(alpha), + _betta1(betta1), + _betta2(betta2), + _eps(eps), + _weight_decay(weight_decay), + _betta1_t(1.0), + _betta2_t(1.0), + _step(0), + _buf_index(false), + _adamw_mode(adamw_mode) + { + cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); + cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); + + _streams[0] = Context::Instance().GetCurrentStream(); + _streams[1] = Context::Instance().GetNewStream(); + } + ~Adam_Optimizer() + { + cudaFreeHost(_doubled_buffer[0]); + cudaFreeHost(_doubled_buffer[1]); + } +#if defined(__AVX512__) or defined(__AVX256__) + template + void Step_AVX(size_t* rounded_size, + float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t param_size, + __half* dev_param = nullptr, + bool half_precision = false); +#endif + STEP(1) + STEP(4) + STEP(8) + inline void SynchronizeStreams() + { + for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]); + } + inline void IncrementStep(size_t step, float beta1, float beta2) + { + if (beta1 != _betta1 || beta2 != _betta2) { + _step = step; + _betta1 = beta1; + _betta2 = beta2; + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + } else { + _step++; + if (_step != step) { + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + _step = step; + } else { + _betta1_t *= _betta1; + _betta2_t *= _betta2; + } + } + } + inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction) + { + _alpha = lr; + _eps = epsilon; + _weight_decay = weight_decay; + + _bias_correction1 = 1.0f; + _bias_correction2 = 1.0f; + if (bias_correction == 1) { + _bias_correction1 = 1 - _betta1_t; + _bias_correction2 = 1 / sqrt(1 - _betta2_t); + } + } + +private: + float _alpha; + float _betta1; + float _betta2; + float _eps; + float _weight_decay; + + float _betta1_t; + float _betta2_t; + size_t _step; + + float _bias_correction1; + float _bias_correction2; + + float* _doubled_buffer[2]; + bool _buf_index; + bool _adamw_mode; + + cudaStream_t _streams[2]; +}; + +#if defined(__AVX512__) or defined(__AVX256__) +template +void Adam_Optimizer::Step_AVX(size_t* rounded_size, + float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + __half* dev_params, + bool half_precision) +{ + size_t new_rounded_size = 0; + + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + float step_size = -1 * _alpha / _bias_correction1; + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + float w_decay = -1 * _alpha * _weight_decay; + AVX_Data weight_decay4; + if (_weight_decay > 0) + weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span); + for (size_t t = 0; t < new_rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; + size_t offset = copy_size + t; + if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { + AVX_Data grad_4[span]; + simd_load(grad_4, grads + i, half_precision); + + AVX_Data momentum_4[span]; + simd_load(momentum_4, _exp_avg + i, false); + + AVX_Data variance_4[span]; + simd_load(variance_4, _exp_avg_sq + i, false); + + AVX_Data param_4[span]; + simd_load(param_4, _params + i, half_precision); + + if (_weight_decay > 0 && !_adamw_mode) { + simd_fma(grad_4, param_4, weight_decay4, grad_4); + } + + simd_mul(momentum_4, momentum_4, betta1_4); + simd_fma(momentum_4, grad_4, betta1_minus1_4, momentum_4); + simd_mul(variance_4, variance_4, betta2_4); + simd_mul(grad_4, grad_4, grad_4); + simd_fma(variance_4, grad_4, betta2_minus1_4, variance_4); + simd_sqrt(grad_4, variance_4); + simd_fma(grad_4, grad_4, bias2_sqrt, eps_4); + simd_div(grad_4, momentum_4, grad_4); + + if (_weight_decay > 0 && _adamw_mode) { + simd_fma(param_4, param_4, weight_decay4, param_4); + } + + simd_fma(param_4, grad_4, step_size_4, param_4); + + simd_store(_params + i, param_4, half_precision); + if (dev_params) { + simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision); + } + simd_store(_exp_avg + i, momentum_4, false); + simd_store(_exp_avg_sq + i, variance_4, false); + } + + if (dev_params) { + if (half_precision) + launch_param_update_half( + _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); + else + launch_param_update( + _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); + + _buf_index = !_buf_index; + } + } + *rounded_size = new_rounded_size; +} +#endif diff --git a/csrc/includes/dropout.h b/csrc/includes/dropout.h index f6e32af56..a72572d08 100644 --- a/csrc/includes/dropout.h +++ b/csrc/includes/dropout.h @@ -1,76 +1,76 @@ -#pragma once - -#include -#include -#include - -template -class Dropout { -public: - struct Config { - float ratio; - uint32_t dim; - bool training; - - Config(float r, uint32_t d) : ratio(r), dim(d), training(true) {} - - float RATIO() const { return training ? ratio : 0.0; } - inline void SetDim(uint32_t d) { dim = d; } - }; - - Dropout(const Config& config) : _config(config), _mask(nullptr) {} - - virtual ~Dropout() {} - - void Forward(int bsz, T* out, const T* vals, cudaStream_t stream, bool bwd = false) - { - launch_dropout( - out, vals, _mask, bsz * _config.dim, _config.dim, _config.RATIO(), stream, bwd); - } - - void ForwardWithBias(int bsz, T* vals, const T* bias, cudaStream_t stream) - { - launch_dropout(vals, bias, _mask, bsz, _config.dim, _config.RATIO(), stream); - } - - void ForwardWithBias(int bsz, - T* out, - const T* vals, - const T* residual, - const T* bias, - cudaStream_t stream) - { - launch_dropout( - out, vals, residual, bias, _mask, bsz, _config.dim, _config.RATIO(), stream); - } - - void Backward(int bsz, T* d_vals, cudaStream_t stream) - { - launch_dropout_grad(d_vals, _mask, bsz * _config.dim, _config.RATIO(), stream); - } - - void Backward(int bsz, T* d_vals_out, const T* d_vals, cudaStream_t stream) - { - launch_dropout_grad( - d_vals_out, d_vals, _mask, bsz * _config.dim, _config.RATIO(), stream); - } - - bool HasDropout() const { return _config.RATIO() > 0.0; } - - void SetTrainingMode(bool training) { _config.training = training; } - - void SetMask(uint8_t* mask) - { - if (!mask) { throw std::runtime_error("Dropout mask is null."); } - - _mask = mask; - } - - Config GetConfig() const { return _config; } - - inline void SetDimension(uint32_t dim) { _config.SetDim(dim); } - -private: - uint8_t* _mask; - Config _config; -}; +#pragma once + +#include +#include +#include + +template +class Dropout { +public: + struct Config { + float ratio; + uint32_t dim; + bool training; + + Config(float r, uint32_t d) : ratio(r), dim(d), training(true) {} + + float RATIO() const { return training ? ratio : 0.0; } + inline void SetDim(uint32_t d) { dim = d; } + }; + + Dropout(const Config& config) : _config(config), _mask(nullptr) {} + + virtual ~Dropout() {} + + void Forward(int bsz, T* out, const T* vals, cudaStream_t stream, bool bwd = false) + { + launch_dropout( + out, vals, _mask, bsz * _config.dim, _config.dim, _config.RATIO(), stream, bwd); + } + + void ForwardWithBias(int bsz, T* vals, const T* bias, cudaStream_t stream) + { + launch_dropout(vals, bias, _mask, bsz, _config.dim, _config.RATIO(), stream); + } + + void ForwardWithBias(int bsz, + T* out, + const T* vals, + const T* residual, + const T* bias, + cudaStream_t stream) + { + launch_dropout( + out, vals, residual, bias, _mask, bsz, _config.dim, _config.RATIO(), stream); + } + + void Backward(int bsz, T* d_vals, cudaStream_t stream) + { + launch_dropout_grad(d_vals, _mask, bsz * _config.dim, _config.RATIO(), stream); + } + + void Backward(int bsz, T* d_vals_out, const T* d_vals, cudaStream_t stream) + { + launch_dropout_grad( + d_vals_out, d_vals, _mask, bsz * _config.dim, _config.RATIO(), stream); + } + + bool HasDropout() const { return _config.RATIO() > 0.0; } + + void SetTrainingMode(bool training) { _config.training = training; } + + void SetMask(uint8_t* mask) + { + if (!mask) { throw std::runtime_error("Dropout mask is null."); } + + _mask = mask; + } + + Config GetConfig() const { return _config; } + + inline void SetDimension(uint32_t dim) { _config.SetDim(dim); } + +private: + uint8_t* _mask; + Config _config; +}; diff --git a/csrc/includes/feed_forward.h b/csrc/includes/feed_forward.h index 7b7379d9b..fc4d5f90a 100644 --- a/csrc/includes/feed_forward.h +++ b/csrc/includes/feed_forward.h @@ -1,93 +1,93 @@ -#ifndef __FEEDFORWARD_H__ -#define __FEEDFORWARD_H__ - -#include -#include -#include -#include "custom_cuda_layers.h" - -template -class FeedForward { -public: - struct Config { - int batchSize, outputSize; - int inputSize; - std::array gemm_algos; - Config(int batch, int outputs, int inputs, const std::array& algos) - : batchSize(batch), outputSize(outputs), inputSize(inputs), gemm_algos(algos) - { - } - }; - - FeedForward(Config config) : config_(config) {} - - ~FeedForward() {} - - void Forward(int bsz, - const T* input_ptr, - const T* weights, - T* out, - cublasHandle_t& _cublasHandle) - { - float alpha = T(1.); - float beta = T(0.); - - cublas_gemm_ex(_cublasHandle, - CUBLAS_OP_T, - CUBLAS_OP_N, - config_.outputSize, - bsz, - config_.inputSize, - &alpha, - &beta, - weights, - input_ptr, - out, - cublasGemmAlgo_t(config_.gemm_algos[0])); - } - void Backward(int bsz, - const T* out_grad, - const T* input_ptr, - const T* weights, - T* weights_grad, - T* bias_grad, - cublasHandle_t& _cublasHandle, - cudaStream_t& stream, - T* inp_grad_out = nullptr, - T* out_grad_trans_out = nullptr) - { - float alpha = (T)1.0, beta = (T)0.0; - cublas_gemm_ex(_cublasHandle, - CUBLAS_OP_N, - CUBLAS_OP_T, - config_.inputSize, - config_.outputSize, - bsz, - &alpha, - &beta, - input_ptr, - out_grad, - weights_grad, - cublasGemmAlgo_t(config_.gemm_algos[1])); - - cublas_gemm_ex(_cublasHandle, - CUBLAS_OP_N, - CUBLAS_OP_N, - config_.inputSize, - bsz, - config_.outputSize, - &alpha, - &beta, - weights, - out_grad, - inp_grad_out, - cublasGemmAlgo_t(config_.gemm_algos[2])); - - launch_fuse_transpose_bias_kernel(out_grad, bias_grad, bsz, config_.outputSize, stream); - } - -private: - Config config_; -}; - -#endif +#ifndef __FEEDFORWARD_H__ +#define __FEEDFORWARD_H__ + +#include +#include +#include +#include "custom_cuda_layers.h" + +template +class FeedForward { +public: + struct Config { + int batchSize, outputSize; + int inputSize; + std::array gemm_algos; + Config(int batch, int outputs, int inputs, const std::array& algos) + : batchSize(batch), outputSize(outputs), inputSize(inputs), gemm_algos(algos) + { + } + }; + + FeedForward(Config config) : config_(config) {} + + ~FeedForward() {} + + void Forward(int bsz, + const T* input_ptr, + const T* weights, + T* out, + cublasHandle_t& _cublasHandle) + { + float alpha = T(1.); + float beta = T(0.); + + cublas_gemm_ex(_cublasHandle, + CUBLAS_OP_T, + CUBLAS_OP_N, + config_.outputSize, + bsz, + config_.inputSize, + &alpha, + &beta, + weights, + input_ptr, + out, + cublasGemmAlgo_t(config_.gemm_algos[0])); + } + void Backward(int bsz, + const T* out_grad, + const T* input_ptr, + const T* weights, + T* weights_grad, + T* bias_grad, + cublasHandle_t& _cublasHandle, + cudaStream_t& stream, + T* inp_grad_out = nullptr, + T* out_grad_trans_out = nullptr) + { + float alpha = (T)1.0, beta = (T)0.0; + cublas_gemm_ex(_cublasHandle, + CUBLAS_OP_N, + CUBLAS_OP_T, + config_.inputSize, + config_.outputSize, + bsz, + &alpha, + &beta, + input_ptr, + out_grad, + weights_grad, + cublasGemmAlgo_t(config_.gemm_algos[1])); + + cublas_gemm_ex(_cublasHandle, + CUBLAS_OP_N, + CUBLAS_OP_N, + config_.inputSize, + bsz, + config_.outputSize, + &alpha, + &beta, + weights, + out_grad, + inp_grad_out, + cublasGemmAlgo_t(config_.gemm_algos[2])); + + launch_fuse_transpose_bias_kernel(out_grad, bias_grad, bsz, config_.outputSize, stream); + } + +private: + Config config_; +}; + +#endif diff --git a/csrc/includes/gelu.h b/csrc/includes/gelu.h index 41cf6f2a6..560f4140e 100644 --- a/csrc/includes/gelu.h +++ b/csrc/includes/gelu.h @@ -1,36 +1,36 @@ -#pragma once - -#include -#include -#include -#include "custom_cuda_layers.h" - -template -class Gelu { -public: - struct Config { - uint32_t intermediate_size; - Config(uint32_t inter_size) : intermediate_size(inter_size) {} - }; - - Gelu(const Config& config) : _config(config) {} - - virtual ~Gelu() {} - - void ForwardWithBiasAdd(int bsz, - const T* input_buf, - const T* bias, - T* output, - cudaStream_t stream) - { - launch_bias_gelu(input_buf, bias, output, _config.intermediate_size, bsz, stream); - } - - void Backward(int bsz, T* d_output, const T* input_buf, const T* bias, cudaStream_t stream) - { - launch_d_gelu(d_output, input_buf, bias, _config.intermediate_size, bsz, stream); - } - -private: - Config _config; -}; +#pragma once + +#include +#include +#include +#include "custom_cuda_layers.h" + +template +class Gelu { +public: + struct Config { + uint32_t intermediate_size; + Config(uint32_t inter_size) : intermediate_size(inter_size) {} + }; + + Gelu(const Config& config) : _config(config) {} + + virtual ~Gelu() {} + + void ForwardWithBiasAdd(int bsz, + const T* input_buf, + const T* bias, + T* output, + cudaStream_t stream) + { + launch_bias_gelu(input_buf, bias, output, _config.intermediate_size, bsz, stream); + } + + void Backward(int bsz, T* d_output, const T* input_buf, const T* bias, cudaStream_t stream) + { + launch_d_gelu(d_output, input_buf, bias, _config.intermediate_size, bsz, stream); + } + +private: + Config _config; +}; diff --git a/csrc/includes/gemm_test.h b/csrc/includes/gemm_test.h index b920896b4..3bfeee35d 100644 --- a/csrc/includes/gemm_test.h +++ b/csrc/includes/gemm_test.h @@ -1,293 +1,293 @@ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include "StopWatch.h" -#include "cublas_wrappers.h" - -template -void check(T result, char const* const func, const char* const file, int const line) -{ - if (result) { - std::cout << (std::string("CUDA runtime error: ") + +file + ":" + std::to_string(line) + - " \n"); - } -} - -#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) - -template -class GemmTest { -public: - GemmTest(int m, int n, int k, cublasOperation_t ta, cublasOperation_t tb, cublasHandle_t h) - : M(m), N(n), K(k), transa(ta), transb(tb), handle(h) - { - check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K)); - check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N)); - check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N)); - } - - ~GemmTest() - { - check_cuda_error(cudaFree(A)); - check_cuda_error(cudaFree(B)); - check_cuda_error(cudaFree(C)); - } - - std::array TestAlgo(int loops) - { - float alpha = (T)1.0f; - float beta = (T)0.0f; - - int algo_fw = Run(loops, [=](int algo) { - cublas_gemm_ex(handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - N, - M, - K, - &alpha, - &beta, - B, - A, - C, - static_cast(algo)); - }); - - int algo_bw1 = Run(loops, [=](int algo) { - cublas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - K, - N, - M, - &alpha, - &beta, - A, - C, - B, - static_cast(algo)); - }); - - int algo_bw2 = Run(loops, [=](int algo) { - cublas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - K, - M, - N, - &alpha, - &beta, - B, - C, - A, - static_cast(algo)); - }); - - return std::array({algo_fw, algo_bw1, algo_bw2}); - } - - template - int Run(int loops, Func f) - { - float fast_latency = (std::numeric_limits::max)(); - int fast_algo = 0; - - for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; - algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; - algo++) { - int warm_up = 5; - for (int i = 0; i < warm_up; ++i) f(algo); - - cudaDeviceSynchronize(); - Stopwatch timer; - timer.Restart(); - - for (int i = 0; i < loops; ++i) f(algo); - - cudaDeviceSynchronize(); - timer.Stop(); - - float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops; - - printf("algo-%d: %.3fms\n", algo, avg_latency); - - if (avg_latency < fast_latency) { - fast_latency = avg_latency; - fast_algo = algo; - } - } - - printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency); - - return fast_algo; - } - -private: - int M, N, K; - cublasHandle_t handle; - cublasOperation_t transa, transb; - T *A, *B, *C; -}; - -template -class StridedGemmTest { -public: - StridedGemmTest(int b, - int m, - int n, - int k, - cublasOperation_t ta, - cublasOperation_t tb, - cublasHandle_t h) - : bsz(b), M(m), N(n), K(k), transa(ta), transb(tb), handle(h) - { - check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K * bsz)); - check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N * bsz)); - check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N * bsz)); - } - - ~StridedGemmTest() - { - check_cuda_error(cudaFree(A)); - check_cuda_error(cudaFree(B)); - check_cuda_error(cudaFree(C)); - } - - std::array TestAlgo(int loops) - { - float alpha = (T)1.0f; - float beta = (T)0.0f; - - int algo_fw = Run(loops, [=](int algo) { - int stride_a = M * K; - int stride_b = N * K; - int stride_c = M * N; - - cublas_strided_batched_gemm(handle, - M, - N, - K, - &alpha, - &beta, - A, - B, - C, - transa, - transb, - stride_a, - stride_b, - stride_c, - bsz, - static_cast(algo)); - }); - - int algo_bw1 = Run(loops, [=](int algo) { - int mb = (transa == CUBLAS_OP_T ? K : M); - int kb = (transa == CUBLAS_OP_T ? M : K); - - int stride_a = mb * N; - int stride_b = N * kb; - int stride_c = M * K; - - // B need to transpose. - cublasOperation_t op_b = (transb == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); - - // Calculate d_A. - cublas_strided_batched_gemm(handle, - mb, - kb, - N, - &alpha, - &beta, - (transa == CUBLAS_OP_T ? B : C), - (transa == CUBLAS_OP_T ? C : B), - A, - CUBLAS_OP_N, - op_b, - stride_a, - stride_b, - stride_c, - bsz, - static_cast(algo)); - }); - - int algo_bw2 = Run(loops, [=](int algo) { - // A need to transpose. - cublasOperation_t op_a = (transa == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); - - int stride_a = M * K; - int stride_b = M * N; - int stride_c = N * K; - - // Calculate d_B. - cublas_strided_batched_gemm(handle, - K, - N, - M, - &alpha, - &beta, - A, - C, - B, - op_a, - CUBLAS_OP_N, - stride_a, - stride_b, - stride_c, - bsz, - static_cast(algo)); - }); - - return std::array({algo_fw, algo_bw1, algo_bw2}); - } - - template - int Run(int loops, Func f) - { - float fast_latency = (std::numeric_limits::max)(); - int fast_algo = 0; - - for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; - algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; - algo++) { - int warm_up = 5; - for (int i = 0; i < warm_up; ++i) f(algo); - - cudaDeviceSynchronize(); - Stopwatch timer; - timer.Restart(); - - for (int i = 0; i < loops; ++i) f(algo); - - cudaDeviceSynchronize(); - timer.Stop(); - - float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops; - - printf("algo-%d: %.3fms\n", algo, avg_latency); - - if (avg_latency < fast_latency) { - fast_latency = avg_latency; - fast_algo = algo; - } - } - - printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency); - - return fast_algo; - } - -private: - int bsz, M, N, K; - cublasHandle_t handle; - cublasOperation_t transa, transb; - T *A, *B, *C; -}; + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include "StopWatch.h" +#include "cublas_wrappers.h" + +template +void check(T result, char const* const func, const char* const file, int const line) +{ + if (result) { + std::cout << (std::string("CUDA runtime error: ") + +file + ":" + std::to_string(line) + + " \n"); + } +} + +#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) + +template +class GemmTest { +public: + GemmTest(int m, int n, int k, cublasOperation_t ta, cublasOperation_t tb, cublasHandle_t h) + : M(m), N(n), K(k), transa(ta), transb(tb), handle(h) + { + check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K)); + check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N)); + check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N)); + } + + ~GemmTest() + { + check_cuda_error(cudaFree(A)); + check_cuda_error(cudaFree(B)); + check_cuda_error(cudaFree(C)); + } + + std::array TestAlgo(int loops) + { + float alpha = (T)1.0f; + float beta = (T)0.0f; + + int algo_fw = Run(loops, [=](int algo) { + cublas_gemm_ex(handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + N, + M, + K, + &alpha, + &beta, + B, + A, + C, + static_cast(algo)); + }); + + int algo_bw1 = Run(loops, [=](int algo) { + cublas_gemm_ex(handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + K, + N, + M, + &alpha, + &beta, + A, + C, + B, + static_cast(algo)); + }); + + int algo_bw2 = Run(loops, [=](int algo) { + cublas_gemm_ex(handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + K, + M, + N, + &alpha, + &beta, + B, + C, + A, + static_cast(algo)); + }); + + return std::array({algo_fw, algo_bw1, algo_bw2}); + } + + template + int Run(int loops, Func f) + { + float fast_latency = (std::numeric_limits::max)(); + int fast_algo = 0; + + for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; + algo++) { + int warm_up = 5; + for (int i = 0; i < warm_up; ++i) f(algo); + + cudaDeviceSynchronize(); + Stopwatch timer; + timer.Restart(); + + for (int i = 0; i < loops; ++i) f(algo); + + cudaDeviceSynchronize(); + timer.Stop(); + + float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops; + + printf("algo-%d: %.3fms\n", algo, avg_latency); + + if (avg_latency < fast_latency) { + fast_latency = avg_latency; + fast_algo = algo; + } + } + + printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency); + + return fast_algo; + } + +private: + int M, N, K; + cublasHandle_t handle; + cublasOperation_t transa, transb; + T *A, *B, *C; +}; + +template +class StridedGemmTest { +public: + StridedGemmTest(int b, + int m, + int n, + int k, + cublasOperation_t ta, + cublasOperation_t tb, + cublasHandle_t h) + : bsz(b), M(m), N(n), K(k), transa(ta), transb(tb), handle(h) + { + check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K * bsz)); + check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N * bsz)); + check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N * bsz)); + } + + ~StridedGemmTest() + { + check_cuda_error(cudaFree(A)); + check_cuda_error(cudaFree(B)); + check_cuda_error(cudaFree(C)); + } + + std::array TestAlgo(int loops) + { + float alpha = (T)1.0f; + float beta = (T)0.0f; + + int algo_fw = Run(loops, [=](int algo) { + int stride_a = M * K; + int stride_b = N * K; + int stride_c = M * N; + + cublas_strided_batched_gemm(handle, + M, + N, + K, + &alpha, + &beta, + A, + B, + C, + transa, + transb, + stride_a, + stride_b, + stride_c, + bsz, + static_cast(algo)); + }); + + int algo_bw1 = Run(loops, [=](int algo) { + int mb = (transa == CUBLAS_OP_T ? K : M); + int kb = (transa == CUBLAS_OP_T ? M : K); + + int stride_a = mb * N; + int stride_b = N * kb; + int stride_c = M * K; + + // B need to transpose. + cublasOperation_t op_b = (transb == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); + + // Calculate d_A. + cublas_strided_batched_gemm(handle, + mb, + kb, + N, + &alpha, + &beta, + (transa == CUBLAS_OP_T ? B : C), + (transa == CUBLAS_OP_T ? C : B), + A, + CUBLAS_OP_N, + op_b, + stride_a, + stride_b, + stride_c, + bsz, + static_cast(algo)); + }); + + int algo_bw2 = Run(loops, [=](int algo) { + // A need to transpose. + cublasOperation_t op_a = (transa == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); + + int stride_a = M * K; + int stride_b = M * N; + int stride_c = N * K; + + // Calculate d_B. + cublas_strided_batched_gemm(handle, + K, + N, + M, + &alpha, + &beta, + A, + C, + B, + op_a, + CUBLAS_OP_N, + stride_a, + stride_b, + stride_c, + bsz, + static_cast(algo)); + }); + + return std::array({algo_fw, algo_bw1, algo_bw2}); + } + + template + int Run(int loops, Func f) + { + float fast_latency = (std::numeric_limits::max)(); + int fast_algo = 0; + + for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; + algo++) { + int warm_up = 5; + for (int i = 0; i < warm_up; ++i) f(algo); + + cudaDeviceSynchronize(); + Stopwatch timer; + timer.Restart(); + + for (int i = 0; i < loops; ++i) f(algo); + + cudaDeviceSynchronize(); + timer.Stop(); + + float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops; + + printf("algo-%d: %.3fms\n", algo, avg_latency); + + if (avg_latency < fast_latency) { + fast_latency = avg_latency; + fast_algo = algo; + } + } + + printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency); + + return fast_algo; + } + +private: + int bsz, M, N, K; + cublasHandle_t handle; + cublasOperation_t transa, transb; + T *A, *B, *C; +}; diff --git a/csrc/includes/general_kernels.h b/csrc/includes/general_kernels.h index 588cf2aaa..90e15b770 100644 --- a/csrc/includes/general_kernels.h +++ b/csrc/includes/general_kernels.h @@ -1,47 +1,47 @@ -#include -#include -#include -#include - -#include -#include - -#include "context.h" -#include "cublas_wrappers.h" - -#define THREADS 256 -#define TILE_DIM 32 - -#define minus_infinity -1 * std::numeric_limits::infinity() - -#define FINAL_MASK 0xffffffff - -template -void launch_fused_add2(T* out, - const T* inp1, - const T* inp2, - int batch_size, - int seq_length, - int hidden_size, - cudaStream_t& stream); - -template -void launch_fused_add4(T* out, - const T* inp1, - const T* inp2, - const T* inp3, - const T* inp4, - int batch_size, - int seq_length, - int hidden_size, - cudaStream_t& stream); - -template -void launch_fused_add3(T* out, - const T* inp1, - const T* inp2, - const T* inp3, - int batch_size, - int seq_length, - int hidden_size, - cudaStream_t& stream); +#include +#include +#include +#include + +#include +#include + +#include "context.h" +#include "cublas_wrappers.h" + +#define THREADS 256 +#define TILE_DIM 32 + +#define minus_infinity -1 * std::numeric_limits::infinity() + +#define FINAL_MASK 0xffffffff + +template +void launch_fused_add2(T* out, + const T* inp1, + const T* inp2, + int batch_size, + int seq_length, + int hidden_size, + cudaStream_t& stream); + +template +void launch_fused_add4(T* out, + const T* inp1, + const T* inp2, + const T* inp3, + const T* inp4, + int batch_size, + int seq_length, + int hidden_size, + cudaStream_t& stream); + +template +void launch_fused_add3(T* out, + const T* inp1, + const T* inp2, + const T* inp3, + int batch_size, + int seq_length, + int hidden_size, + cudaStream_t& stream); diff --git a/csrc/includes/normalize_layer.h b/csrc/includes/normalize_layer.h index e18e01a33..b4d135ec4 100644 --- a/csrc/includes/normalize_layer.h +++ b/csrc/includes/normalize_layer.h @@ -1,202 +1,202 @@ -#pragma once - -#include -#include -#include -#include -#include "custom_cuda_layers.h" - -using namespace std; - -template -class Normalize_Layer { -public: - struct Config { - uint32_t batchSize; - uint32_t seqLength; - uint32_t hiddenDim; - float epsilon; - bool training; - bool useMean; - Config(uint32_t batch, - uint32_t seq, - uint32_t h, - float epsilon = 1e-12, - bool training = true, - bool useMean = true) - : batchSize(batch), - seqLength(seq), - hiddenDim(h), - epsilon(epsilon), - training(training), - useMean(useMean) - { - } - }; - - Normalize_Layer(Config config) - : config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr) - { - } - - ~Normalize_Layer() {} - - void ForwardCheckpoint(int bsz, // batch * seq - T* vals, - const T* residual, - const T* gamma, - const T* betta, - cudaStream_t& stream, - bool preLayerNorm = false) - { - launch_bias_residual_layer_norm(vals, - residual, - gamma, - betta, - config_.epsilon, - bsz, - config_.hiddenDim, - stream, - preLayerNorm, - config_.training, - vars, - means); - } - - void Forward(int bsz, - T* vals, - const T* residual, - const T* gamma, - const T* betta, - cudaStream_t& stream, - bool preLayerNorm = false) - { - launch_bias_residual_layer_norm(vals, - residual, - gamma, - betta, - config_.epsilon, - bsz, - config_.hiddenDim, - stream, - preLayerNorm, - config_.training, - vars); - } - - void Backward(int bsz, - const T* out_grad, - const T* gamma, - T* gamma_grad, - T* betta_grad, - cudaStream_t stream[2], - T* inp_grad_out, - const T* norm_in = nullptr) - { - launch_layerNorm_backward(out_grad, - norm_in, - vars, - means, - gamma, - gamma_grad, - betta_grad, - inp_grad_out, - bsz, - config_.hiddenDim, - stream); - } - - void Backward(int bsz, - const T* out_grad, - const T* gamma, - const T* betta, - T* gamma_grad, - T* betta_grad, - cudaStream_t stream[2], - T* inp_grad_out, - const T* norm_out) - { - launch_layerNorm_backward(out_grad, - norm_out, - vars, - gamma, - gamma_grad, - betta_grad, - inp_grad_out, - bsz, - config_.hiddenDim, - stream, - !config_.useMean, - betta); - } - - void BackwardFusedAdd(int bsz, - const T* out_grad1, - const T* out_grad2, - const T* gamma, - T* gamma_grad, - T* betta_grad, - cudaStream_t stream[2], - T* inp_grad_out, - const T* norm_in = nullptr) - { - launch_layerNorm_backward_fused_add(out_grad1, - out_grad2, - norm_in, - vars, - means, - gamma, - gamma_grad, - betta_grad, - inp_grad_out, - bsz, - config_.hiddenDim, - stream); - } - - void BackwardFusedAdd(int bsz, - const T* out_grad1, - const T* out_grad2, - const T* gamma, - const T* betta, - T* gamma_grad, - T* betta_grad, - cudaStream_t stream[2], - T* inp_grad_out, - const T* norm_out) - { - launch_layerNorm_backward_fused_add(out_grad1, - out_grad2, - norm_out, - vars, - gamma, - gamma_grad, - betta_grad, - inp_grad_out, - bsz, - config_.hiddenDim, - stream, - !config_.useMean, - betta); - } - - inline bool UseMean() const { return config_.useMean; } - - inline void SetVar(T* variance) - { - if (!variance) { throw std::runtime_error("Normalize variance is null."); } - vars = variance; - } - - inline void SetMean(T* mean) - { - if (!mean) { throw std::runtime_error("Normalize mean is null."); } - means = mean; - } - -private: - Config config_; - T* vars; - T* means; - T* vals_hat; -}; +#pragma once + +#include +#include +#include +#include +#include "custom_cuda_layers.h" + +using namespace std; + +template +class Normalize_Layer { +public: + struct Config { + uint32_t batchSize; + uint32_t seqLength; + uint32_t hiddenDim; + float epsilon; + bool training; + bool useMean; + Config(uint32_t batch, + uint32_t seq, + uint32_t h, + float epsilon = 1e-12, + bool training = true, + bool useMean = true) + : batchSize(batch), + seqLength(seq), + hiddenDim(h), + epsilon(epsilon), + training(training), + useMean(useMean) + { + } + }; + + Normalize_Layer(Config config) + : config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr) + { + } + + ~Normalize_Layer() {} + + void ForwardCheckpoint(int bsz, // batch * seq + T* vals, + const T* residual, + const T* gamma, + const T* betta, + cudaStream_t& stream, + bool preLayerNorm = false) + { + launch_bias_residual_layer_norm(vals, + residual, + gamma, + betta, + config_.epsilon, + bsz, + config_.hiddenDim, + stream, + preLayerNorm, + config_.training, + vars, + means); + } + + void Forward(int bsz, + T* vals, + const T* residual, + const T* gamma, + const T* betta, + cudaStream_t& stream, + bool preLayerNorm = false) + { + launch_bias_residual_layer_norm(vals, + residual, + gamma, + betta, + config_.epsilon, + bsz, + config_.hiddenDim, + stream, + preLayerNorm, + config_.training, + vars); + } + + void Backward(int bsz, + const T* out_grad, + const T* gamma, + T* gamma_grad, + T* betta_grad, + cudaStream_t stream[2], + T* inp_grad_out, + const T* norm_in = nullptr) + { + launch_layerNorm_backward(out_grad, + norm_in, + vars, + means, + gamma, + gamma_grad, + betta_grad, + inp_grad_out, + bsz, + config_.hiddenDim, + stream); + } + + void Backward(int bsz, + const T* out_grad, + const T* gamma, + const T* betta, + T* gamma_grad, + T* betta_grad, + cudaStream_t stream[2], + T* inp_grad_out, + const T* norm_out) + { + launch_layerNorm_backward(out_grad, + norm_out, + vars, + gamma, + gamma_grad, + betta_grad, + inp_grad_out, + bsz, + config_.hiddenDim, + stream, + !config_.useMean, + betta); + } + + void BackwardFusedAdd(int bsz, + const T* out_grad1, + const T* out_grad2, + const T* gamma, + T* gamma_grad, + T* betta_grad, + cudaStream_t stream[2], + T* inp_grad_out, + const T* norm_in = nullptr) + { + launch_layerNorm_backward_fused_add(out_grad1, + out_grad2, + norm_in, + vars, + means, + gamma, + gamma_grad, + betta_grad, + inp_grad_out, + bsz, + config_.hiddenDim, + stream); + } + + void BackwardFusedAdd(int bsz, + const T* out_grad1, + const T* out_grad2, + const T* gamma, + const T* betta, + T* gamma_grad, + T* betta_grad, + cudaStream_t stream[2], + T* inp_grad_out, + const T* norm_out) + { + launch_layerNorm_backward_fused_add(out_grad1, + out_grad2, + norm_out, + vars, + gamma, + gamma_grad, + betta_grad, + inp_grad_out, + bsz, + config_.hiddenDim, + stream, + !config_.useMean, + betta); + } + + inline bool UseMean() const { return config_.useMean; } + + inline void SetVar(T* variance) + { + if (!variance) { throw std::runtime_error("Normalize variance is null."); } + vars = variance; + } + + inline void SetMean(T* mean) + { + if (!mean) { throw std::runtime_error("Normalize mean is null."); } + means = mean; + } + +private: + Config config_; + T* vars; + T* means; + T* vals_hat; +}; diff --git a/csrc/includes/simd.h b/csrc/includes/simd.h index 8c6533f3d..216d281a2 100644 --- a/csrc/includes/simd.h +++ b/csrc/includes/simd.h @@ -1,137 +1,137 @@ -#pragma once - -#if (__x86_64__ || __i386__) -#include -#include -#endif - -#define TILE (128 * 1024 * 1024) -#if defined(__AVX512__) or defined(__AVX256__) - -#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) - -#if defined(__AVX512__) -#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) -#define SIMD_LOAD(x) _mm512_loadu_ps(x) -#define SIMD_SET(x) _mm512_set1_ps(x) -#define SIMD_ADD(x, y) _mm512_add_ps(x, y) -#define SIMD_MUL(x, y) _mm512_mul_ps(x, y) -#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) -#define SIMD_SQRT(x) _mm512_sqrt_ps(x) -#define SIMD_DIV(x, y) _mm512_div_ps(x, y) -#define SIMD_WIDTH 16 - -#define SIMD_LOAD2(x, h) \ - ((h) ? _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)x)) : _mm512_loadu_ps(x)) -#define SIMD_STORE2(x, d, h) \ - ((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \ - : _mm512_storeu_ps(x, d)) - -#define INTV __m256i -#elif defined(__AVX256__) -#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d) -#define SIMD_LOAD(x) _mm256_loadu_ps(x) -#define SIMD_SET(x) _mm256_set1_ps(x) -#define SIMD_ADD(x, y) _mm256_add_ps(x, y) -#define SIMD_MUL(x, y) _mm256_mul_ps(x, y) -#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c) -#define SIMD_SQRT(x) _mm256_sqrt_ps(x) -#define SIMD_DIV(x, y) _mm256_div_ps(x, y) -#define SIMD_WIDTH 8 -#define SIMD_LOAD2(x, h) \ - ((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x)) : _mm256_loadu_ps(x)) - -#define SIMD_STORE2(x, d, h) \ - ((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \ - : _mm256_storeu_ps(x, d)) - -#define INTV __m128i -#endif - -union AVX_Data { -#if defined(__AVX512__) - __m512 data; -#elif defined(__AVX256__) - __m256 data; -#endif - // float data_f[16]; -}; - -template -inline void simd_store(float* dst, AVX_Data* src, bool half_precision) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { - SIMD_STORE2(dst + SIMD_WIDTH * i, src[i].data, half_precision); - } -} -template -inline void simd_load(AVX_Data* dst, float* src, bool half_precision) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { - dst[i].data = SIMD_LOAD2(src + SIMD_WIDTH * i, half_precision); - } -} -template -inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data* src_a) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { - dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a[i].data); - } -} -template -inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data src_a) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { - dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a.data); - } -} -template -inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data* src_m_r, AVX_Data* src_a) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { - dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r[i].data, src_a[i].data); - } -} -template -inline void simd_sqrt(AVX_Data* dst, AVX_Data* src) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_SQRT(src[i].data); } -} -template -inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r.data); } -} -template -inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r[i].data); } -} -template -inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r.data); } -} -template -inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r[i].data); } -} -template -inline void simd_div(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_DIV(src_a_l[i].data, src_a_r[i].data); } -} - -#endif +#pragma once + +#if (__x86_64__ || __i386__) +#include +#include +#endif + +#define TILE (128 * 1024 * 1024) +#if defined(__AVX512__) or defined(__AVX256__) + +#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) + +#if defined(__AVX512__) +#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm512_loadu_ps(x) +#define SIMD_SET(x) _mm512_set1_ps(x) +#define SIMD_ADD(x, y) _mm512_add_ps(x, y) +#define SIMD_MUL(x, y) _mm512_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm512_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm512_div_ps(x, y) +#define SIMD_WIDTH 16 + +#define SIMD_LOAD2(x, h) \ + ((h) ? _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)x)) : _mm512_loadu_ps(x)) +#define SIMD_STORE2(x, d, h) \ + ((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \ + : _mm512_storeu_ps(x, d)) + +#define INTV __m256i +#elif defined(__AVX256__) +#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm256_loadu_ps(x) +#define SIMD_SET(x) _mm256_set1_ps(x) +#define SIMD_ADD(x, y) _mm256_add_ps(x, y) +#define SIMD_MUL(x, y) _mm256_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm256_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm256_div_ps(x, y) +#define SIMD_WIDTH 8 +#define SIMD_LOAD2(x, h) \ + ((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x)) : _mm256_loadu_ps(x)) + +#define SIMD_STORE2(x, d, h) \ + ((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \ + : _mm256_storeu_ps(x, d)) + +#define INTV __m128i +#endif + +union AVX_Data { +#if defined(__AVX512__) + __m512 data; +#elif defined(__AVX256__) + __m256 data; +#endif + // float data_f[16]; +}; + +template +inline void simd_store(float* dst, AVX_Data* src, bool half_precision) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { + SIMD_STORE2(dst + SIMD_WIDTH * i, src[i].data, half_precision); + } +} +template +inline void simd_load(AVX_Data* dst, float* src, bool half_precision) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { + dst[i].data = SIMD_LOAD2(src + SIMD_WIDTH * i, half_precision); + } +} +template +inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data* src_a) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { + dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a[i].data); + } +} +template +inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data src_a) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { + dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a.data); + } +} +template +inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data* src_m_r, AVX_Data* src_a) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { + dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r[i].data, src_a[i].data); + } +} +template +inline void simd_sqrt(AVX_Data* dst, AVX_Data* src) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_SQRT(src[i].data); } +} +template +inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r.data); } +} +template +inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r[i].data); } +} +template +inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r.data); } +} +template +inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r[i].data); } +} +template +inline void simd_div(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_DIV(src_a_l[i].data, src_a_r[i].data); } +} + +#endif diff --git a/csrc/includes/softmax.h b/csrc/includes/softmax.h index ca3cf8a8e..8d541a6fe 100755 --- a/csrc/includes/softmax.h +++ b/csrc/includes/softmax.h @@ -1,60 +1,60 @@ -#pragma once - -#include -#include -#include -#include "custom_cuda_layers.h" - -#include - -using namespace std; - -template -class Softmax { -public: - struct Config { - size_t batchSize; - size_t heads; - size_t seq_length; - size_t prob_depth; - float temperature; - bool mem_alloc; - Config(size_t batch, size_t h, size_t seq, int prob_size = 0, bool mem_alloc = false) - : batchSize(batch), - heads(h), - seq_length(seq), - prob_depth(prob_size), - temperature(1.0), - mem_alloc(mem_alloc) - { - } - }; - - Softmax(Config config) : config_(config) {} - - ~Softmax() {} - - void Forward(int bsz, T* vals, const T* attn_mask, cudaStream_t& stream) - { - launch_attn_softmax(vals, attn_mask, bsz, config_.heads, config_.seq_length, stream); - } - - void Backward(int bsz, T* out_grad, const T* soft_out, cudaStream_t stream) - { - launch_attn_softmax_backward_v2( - out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream); - } - - inline size_t GetProbDepth() const { return config_.prob_depth; } - - inline size_t GetBatchSize() const { return config_.batchSize; } - - inline size_t GetNumHeads() const { return config_.heads; } - - inline size_t GetSeqLength() const { return config_.seq_length; } - - inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; } - -private: - Config config_; -}; +#pragma once + +#include +#include +#include +#include "custom_cuda_layers.h" + +#include + +using namespace std; + +template +class Softmax { +public: + struct Config { + size_t batchSize; + size_t heads; + size_t seq_length; + size_t prob_depth; + float temperature; + bool mem_alloc; + Config(size_t batch, size_t h, size_t seq, int prob_size = 0, bool mem_alloc = false) + : batchSize(batch), + heads(h), + seq_length(seq), + prob_depth(prob_size), + temperature(1.0), + mem_alloc(mem_alloc) + { + } + }; + + Softmax(Config config) : config_(config) {} + + ~Softmax() {} + + void Forward(int bsz, T* vals, const T* attn_mask, cudaStream_t& stream) + { + launch_attn_softmax(vals, attn_mask, bsz, config_.heads, config_.seq_length, stream); + } + + void Backward(int bsz, T* out_grad, const T* soft_out, cudaStream_t stream) + { + launch_attn_softmax_backward_v2( + out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream); + } + + inline size_t GetProbDepth() const { return config_.prob_depth; } + + inline size_t GetBatchSize() const { return config_.batchSize; } + + inline size_t GetNumHeads() const { return config_.heads; } + + inline size_t GetSeqLength() const { return config_.seq_length; } + + inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; } + +private: + Config config_; +}; diff --git a/csrc/includes/strided_batch_gemm.h b/csrc/includes/strided_batch_gemm.h index 44a1b313b..3a9ad65bc 100644 --- a/csrc/includes/strided_batch_gemm.h +++ b/csrc/includes/strided_batch_gemm.h @@ -1,179 +1,179 @@ -#pragma once - -#include -#include -#include -#include "context.h" - -template -class StridedBatchGemm { -public: - struct Config { - int batch_size; - int m; - int n; - int k; - float alpha; - float beta; - cublasOperation_t op_A; - cublasOperation_t op_B; - std::array gemm_algos; - - Config(int batch, - int mm, - int nn, - int kk, - float param_alpha, - float param_beta, - cublasOperation_t opA, - cublasOperation_t opB, - const std::array& algos) - : batch_size(batch), - m(mm), - n(nn), - k(kk), - alpha(param_alpha), - beta(param_beta), - op_A(opA), - op_B(opB), - gemm_algos(algos) - { - } - void SetConfig(int mm, int nn, int kk) - { - m = mm; - n = nn; - k = kk; - } - }; - - StridedBatchGemm(const Config& config) : _config(config) {} - - virtual ~StridedBatchGemm() {} - - void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle) - { - int stride_a = _config.m * _config.k; - int stride_b = _config.n * _config.k; - int stride_c = _config.m * _config.n; - - cublas_strided_batched_gemm(handle, - _config.m, - _config.n, - _config.k, - &_config.alpha, - &_config.beta, - _buffer_a, - _buffer_b, - output, - _config.op_A, - _config.op_B, - stride_a, - stride_b, - stride_c, - bsz, - cublasGemmAlgo_t(_config.gemm_algos[0])); - } - - void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle) - { - int stride_a = _config.m * _config.k; - int stride_b = _config.n * _config.k; - int stride_c = _config.m * _config.n; - - cublas_strided_batched_gemm(handle, - _config.m, - _config.n, - _config.k, - &_config.alpha, - &_config.beta, - _buffer_a, - _buffer_b, - output, - _config.op_A, - _config.op_B, - stride_a, - stride_b, - stride_c, - _config.batch_size, - cublasGemmAlgo_t(_config.gemm_algos[0])); - - k_buf = _buffer_a; - q_buf = _buffer_b; - } - - void Backward(int bsz, - const T* d_output, - const T* _buffer_a, - const T* _buffer_b, - cublasHandle_t handle, - T* inpGradA = nullptr, - T* inpGradB = nullptr) - { - int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m); - int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k); - - int stride_a = mb * _config.n; - int stride_b = _config.n * kb; - int stride_c = _config.m * _config.k; - - // B need to transpose. - cublasOperation_t op_b = (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); - - // Calculate d_A. - cublas_strided_batched_gemm(handle, - mb, - kb, - _config.n, - &_config.alpha, - &_config.beta, - (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output), - (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), - inpGradA, - CUBLAS_OP_N, - op_b, - stride_a, - stride_b, - stride_c, - bsz, - cublasGemmAlgo_t(_config.gemm_algos[1])); - - // A need to transpose. - cublasOperation_t op_a = (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); - - stride_a = _config.m * _config.k; - stride_b = _config.m * _config.n; - stride_c = _config.n * _config.k; - - // Calculate d_B. - cublas_strided_batched_gemm(handle, - _config.k, - _config.n, - _config.m, - &_config.alpha, - &_config.beta, - _buffer_a, - d_output, - inpGradB, - op_a, - CUBLAS_OP_N, - stride_a, - stride_b, - stride_c, - bsz, - cublasGemmAlgo_t(_config.gemm_algos[2])); - } - - inline int GetN() const { return _config.k; } - - inline const T* GetBufferA() const { return k_buf; } - - inline const T* GetBufferB() const { return q_buf; } - - inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); } - -private: - Config _config; - const T* q_buf; - const T* k_buf; -}; +#pragma once + +#include +#include +#include +#include "context.h" + +template +class StridedBatchGemm { +public: + struct Config { + int batch_size; + int m; + int n; + int k; + float alpha; + float beta; + cublasOperation_t op_A; + cublasOperation_t op_B; + std::array gemm_algos; + + Config(int batch, + int mm, + int nn, + int kk, + float param_alpha, + float param_beta, + cublasOperation_t opA, + cublasOperation_t opB, + const std::array& algos) + : batch_size(batch), + m(mm), + n(nn), + k(kk), + alpha(param_alpha), + beta(param_beta), + op_A(opA), + op_B(opB), + gemm_algos(algos) + { + } + void SetConfig(int mm, int nn, int kk) + { + m = mm; + n = nn; + k = kk; + } + }; + + StridedBatchGemm(const Config& config) : _config(config) {} + + virtual ~StridedBatchGemm() {} + + void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle) + { + int stride_a = _config.m * _config.k; + int stride_b = _config.n * _config.k; + int stride_c = _config.m * _config.n; + + cublas_strided_batched_gemm(handle, + _config.m, + _config.n, + _config.k, + &_config.alpha, + &_config.beta, + _buffer_a, + _buffer_b, + output, + _config.op_A, + _config.op_B, + stride_a, + stride_b, + stride_c, + bsz, + cublasGemmAlgo_t(_config.gemm_algos[0])); + } + + void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle) + { + int stride_a = _config.m * _config.k; + int stride_b = _config.n * _config.k; + int stride_c = _config.m * _config.n; + + cublas_strided_batched_gemm(handle, + _config.m, + _config.n, + _config.k, + &_config.alpha, + &_config.beta, + _buffer_a, + _buffer_b, + output, + _config.op_A, + _config.op_B, + stride_a, + stride_b, + stride_c, + _config.batch_size, + cublasGemmAlgo_t(_config.gemm_algos[0])); + + k_buf = _buffer_a; + q_buf = _buffer_b; + } + + void Backward(int bsz, + const T* d_output, + const T* _buffer_a, + const T* _buffer_b, + cublasHandle_t handle, + T* inpGradA = nullptr, + T* inpGradB = nullptr) + { + int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m); + int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k); + + int stride_a = mb * _config.n; + int stride_b = _config.n * kb; + int stride_c = _config.m * _config.k; + + // B need to transpose. + cublasOperation_t op_b = (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); + + // Calculate d_A. + cublas_strided_batched_gemm(handle, + mb, + kb, + _config.n, + &_config.alpha, + &_config.beta, + (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output), + (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), + inpGradA, + CUBLAS_OP_N, + op_b, + stride_a, + stride_b, + stride_c, + bsz, + cublasGemmAlgo_t(_config.gemm_algos[1])); + + // A need to transpose. + cublasOperation_t op_a = (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); + + stride_a = _config.m * _config.k; + stride_b = _config.m * _config.n; + stride_c = _config.n * _config.k; + + // Calculate d_B. + cublas_strided_batched_gemm(handle, + _config.k, + _config.n, + _config.m, + &_config.alpha, + &_config.beta, + _buffer_a, + d_output, + inpGradB, + op_a, + CUBLAS_OP_N, + stride_a, + stride_b, + stride_c, + bsz, + cublasGemmAlgo_t(_config.gemm_algos[2])); + } + + inline int GetN() const { return _config.k; } + + inline const T* GetBufferA() const { return k_buf; } + + inline const T* GetBufferB() const { return q_buf; } + + inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); } + +private: + Config _config; + const T* q_buf; + const T* k_buf; +}; diff --git a/csrc/transformer/dropout_kernels.cu b/csrc/transformer/dropout_kernels.cu index 98f2ac22f..d1ba135f4 100755 --- a/csrc/transformer/dropout_kernels.cu +++ b/csrc/transformer/dropout_kernels.cu @@ -1,868 +1,868 @@ -#include "custom_cuda_layers.h" - -const int unroll_factor = 4; - -__global__ void dropout_kernel(const int N, - const float ratio, - float* out, - const float* Xdata, - uint8_t* mask, - std::pair seed) -{ - const float scale = 1. / (1. - ratio); - int idx = blockIdx.x * blockDim.x + threadIdx.x; - - curandStatePhilox4_32_10_t state; - curand_init(seed.first, idx, seed.second, &state); - - CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) - { - float4 rand = curand_uniform4(&state); - uint8_t m[unroll_factor]; - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - int i = j * unroll_factor; - - mask[i] = (uint8_t)m[0]; - mask[i + 1] = (uint8_t)m[1]; - mask[i + 2] = (uint8_t)m[2]; - mask[i + 3] = (uint8_t)m[3]; - - out[i] = Xdata[i] * scale * m[0]; - out[i + 1] = Xdata[i + 1] * scale * m[1]; - out[i + 2] = Xdata[i + 2] * scale * m[2]; - out[i + 3] = Xdata[i + 3] * scale * m[3]; - } - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - float4 rand = curand_uniform4(&state); - float* rand_data = &(rand.x); - int k = 0; - for (int i = high_index; i < N; i++) { - uint8_t m = (uint8_t)(rand_data[k++] > ratio); - out[i] = Xdata[i] * scale * m; - mask[i] = m; - } - } -} - -__global__ void dropout_kernel(const int N, - const float ratio, - __half* out, - const __half* Xdata, - uint8_t* mask, - std::pair seed) -{ - const float scale = 1. / (1. - ratio); - - int idx = blockIdx.x * blockDim.x + threadIdx.x; - - curandStatePhilox4_32_10_t state; - curand_init(seed.first, idx, seed.second, &state); - -#ifdef __STOCHASTIC_MODE__ - - const __half2 h_scale = __float2half2_rn(scale); - const float2* x_cast = reinterpret_cast(Xdata); - float2* out_cast = reinterpret_cast(out); - uint32_t* mask_cast = reinterpret_cast(mask); - - uint32_t m_32; - uint8_t* m = reinterpret_cast(&m_32); - - float2 result_f; - __half2* result_h = reinterpret_cast<__half2*>(&result_f); - __half2 mask_h[2]; - float2 mask_f[2]; - - CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) - { - float2 x_f = x_cast[j]; - __half2* x_h = reinterpret_cast<__half2*>(&x_f); - - float4 rand = curand_uniform4(&state); - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - float* mask_f_data = &mask_f[0].x; -#pragma unroll - for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]); - - mask_h[0] = __float22half2_rn(mask_f[0]); - mask_h[1] = __float22half2_rn(mask_f[1]); - - result_h[0] = x_h[0] * h_scale * mask_h[0]; - result_h[1] = x_h[1] * h_scale * mask_h[1]; - - out_cast[j] = result_f; - - mask_cast[j] = m_32; - } - -#else - - CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) - { - int i = j * unroll_factor; - - const __half2* vals_half = reinterpret_cast(Xdata + i); - float2 vals_half_f[2]; - vals_half_f[0] = __half22float2(vals_half[0]); - vals_half_f[1] = __half22float2(vals_half[1]); - - uint8_t m[unroll_factor]; - float4 rand = curand_uniform4(&state); - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - out[i] = __float2half(vals_half_f[0].x * scale * m[0]); - out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]); - out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]); - out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]); - - mask[i] = m[0]; - mask[i + 1] = m[1]; - mask[i + 2] = m[2]; - mask[i + 3] = m[3]; - } - -#endif - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - float4 rand = curand_uniform4(&state); - float* rand_data = &(rand.x); - int k = 0; - for (int i = high_index; i < N; i++) { - uint8_t m = (uint8_t)(rand_data[k++] > ratio); - out[i] = __float2half((float)Xdata[i] * scale * m); - mask[i] = m; - } - } -} - -__global__ void dropout_kernel_bwd(const int N, - const float ratio, - const float* Xdata, - float* out, - uint8_t* mask, - std::pair seed) -{ - const float scale = 1. / (1. - ratio); - CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) - { - int i = j * unroll_factor; - - out[i] = mask[i] ? Xdata[i] * scale : 0.0; - out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0; - out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0; - out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0; - } - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - for (int i = high_index; i < N; i++) { out[i] = mask[i] ? Xdata[i] * scale : 0.0; } - } -} - -__global__ void dropout_kernel_bwd(const int N, - const float ratio, - const __half* Xdata, - __half* out, - uint8_t* mask, - std::pair seed) -{ - const float scale = 1. / (1. - ratio); - -#ifdef __STOCHASTIC_MODE__ - - const __half2 h_scale = __float2half2_rn(scale); - - const float2* x_cast = reinterpret_cast(Xdata); - float2* out_cast = reinterpret_cast(out); - uint32_t* mask_cast = reinterpret_cast(mask); - - CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) - { - float2 x_f = x_cast[j]; - __half2* x_h = reinterpret_cast<__half2*>(&x_f); - - uint32_t m_32 = mask_cast[j]; - uint8_t* m = (uint8_t*)&m_32; - - __half2 mask_h[2]; - float2 mask_f[2]; - - float* mask_f_data = &mask_f[0].x; -#pragma unroll - for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]); - -#pragma unroll - for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]); - - float2 result_f; - __half2* result_h = reinterpret_cast<__half2*>(&result_f); - - result_h[0] = x_h[0] * h_scale * mask_h[0]; - result_h[1] = x_h[1] * h_scale * mask_h[1]; - - out_cast[j] = result_f; - } - -#else - - const __half h_scale = __float2half(scale); - const __half h_zero = __float2half(0.0); - - CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) - { - int i = j * unroll_factor; - - const __half2* vals_half = reinterpret_cast(Xdata + i); - - uint8_t* m = mask + i; - - float2 vals_half_f[2]; - - vals_half_f[0] = __half22float2(vals_half[0]); - vals_half_f[1] = __half22float2(vals_half[1]); - - out[i] = __float2half(vals_half_f[0].x * scale * m[0]); - out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]); - out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]); - out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]); - } - -#endif - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - for (int i = high_index; i < N; i++) { - out[i] = __float2half((float)Xdata[i] * scale * mask[i]); - } - } -} - -template -void launch_dropout(T* out, - const T* vals, - uint8_t* mask, - int total_count, - int dim, - float ratio, - cudaStream_t stream, - bool bwd) -{ - assert(unroll_factor == 4); - - dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor); - dim3 block_dim = DS_CUDA_NUM_THREADS; - - if (dim > 512) { - block_dim.x >>= 1; - grid_dim.x <<= 1; - } - uint64_t inc = total_count / grid_dim.x / block_dim.x; - std::pair seed = Context::Instance().IncrementOffset(inc); - if (bwd) - dropout_kernel_bwd<<>>( - total_count, ratio, vals, out, mask, seed); - else - dropout_kernel<<>>( - total_count, ratio, out, vals, mask, seed); -} - -template void launch_dropout(float* out, - const float* vals, - uint8_t* mask, - int total_count, - int dim, - float ratio, - cudaStream_t stream, - bool); -template void launch_dropout(__half* out, - const __half* vals, - uint8_t* mask, - int total_count, - int dim, - float ratio, - cudaStream_t stream, - bool); - -__global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask) -{ - CUDA_1D_KERNEL_LOOP(i, N) { Xdata[i] *= scale * mask[i]; } -} - -__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask) -{ - const __half2 h_scale = __float2half2_rn(scale); - float2* x_cast = reinterpret_cast(Xdata); - uint32_t* mask_cast = reinterpret_cast(mask); - - CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) - { - float2 x_data = x_cast[j]; - uint32_t m_32 = mask_cast[j]; - uint8_t* m = (uint8_t*)&m_32; - - float2 result_f; - __half2* result_h = reinterpret_cast<__half2*>(&result_f); - -#ifdef __STOCHASTIC_MODE__ - - __half2* x_data_h = reinterpret_cast<__half2*>(&x_data); - __half2 mask_h[2]; - float2 mask_f[2]; - - float* mask_f_data = &mask_f[0].x; -#pragma unroll - for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]); - - mask_h[0] = __float22half2_rn(mask_f[0]); - mask_h[1] = __float22half2_rn(mask_f[1]); - - result_h[0] = x_data_h[0] * h_scale * mask_h[0]; - result_h[1] = x_data_h[1] * h_scale * mask_h[1]; - -#else - - __half* x_data_h = reinterpret_cast<__half*>(&x_data); - float2 result[2]; - - result[0].x = (float)x_data_h[0] * scale * m[0]; - result[0].y = (float)x_data_h[1] * scale * m[1]; - result[1].x = (float)x_data_h[2] * scale * m[2]; - result[1].y = (float)x_data_h[3] * scale * m[3]; - - result_h[0] = __float22half2_rn(result[0]); - result_h[1] = __float22half2_rn(result[1]); - -#endif - x_cast[j] = result_f; - } - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - for (int i = high_index; i < N; i++) { - Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]); - } - } -} - -template -void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream) -{ - assert(unroll_factor == 4); - - const float scale = 1. / (1. - ratio); - dropout_grad_kernel<<>>(total_count, scale, vals, mask); -} - -template void launch_dropout_grad(float* vals, - uint8_t* mask, - int total_count, - float ratio, - cudaStream_t stream); -template void launch_dropout_grad(__half* vals, - uint8_t* mask, - int total_count, - float ratio, - cudaStream_t stream); - -__global__ void dropout_grad_kernel(const int N, - const float scale, - const float* Xdata, - float* out, - uint8_t* mask) -{ - CUDA_1D_KERNEL_LOOP(i, N) { out[i] = Xdata[i] * scale * mask[i]; } -} - -__global__ void dropout_grad_kernel(const int N, - const float scale, - const __half* Xdata, - __half* out, - uint8_t* mask) -{ - const float2* x_cast = reinterpret_cast(Xdata); - float2* out_cast = reinterpret_cast(out); - const uint32_t* mask_cast = reinterpret_cast(mask); - - float2 result_f; - __half2* result_h = reinterpret_cast<__half2*>(&result_f); - - CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) - { - float2 x_data = x_cast[j]; - uint32_t m_32 = mask_cast[j]; - uint8_t* m = (uint8_t*)&m_32; - - __half* x_data_h = reinterpret_cast<__half*>(&x_data); - float2 result[2]; - - result[0].x = (float)x_data_h[0] * scale * m[0]; - result[0].y = (float)x_data_h[1] * scale * m[1]; - result[1].x = (float)x_data_h[2] * scale * m[2]; - result[1].y = (float)x_data_h[3] * scale * m[3]; - - result_h[0] = __float22half2_rn(result[0]); - result_h[1] = __float22half2_rn(result[1]); - - out_cast[j] = result_f; - } - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - for (int i = high_index; i < N; i++) { - out[i] = __float2half((float)Xdata[i] * scale * mask[i]); - } - } -} - -template -void launch_dropout_grad(T* vals_out, - const T* vals, - uint8_t* mask, - int total_count, - float ratio, - cudaStream_t stream) -{ - assert(unroll_factor == 4); - - const float scale = 1. / (1. - ratio); - dropout_grad_kernel<<>>(total_count, scale, vals, vals_out, mask); -} -template void launch_dropout_grad(float*, - const float* vals, - uint8_t* mask, - int total_count, - float ratio, - cudaStream_t stream); -template void launch_dropout_grad(__half*, - const __half* vals, - uint8_t* mask, - int total_count, - float ratio, - cudaStream_t stream); - -__global__ void dropout_kernel(const int N, - const int dim, - const float ratio, - const float* bias, - float* Xdata, - uint8_t* mask, - std::pair seed) -{ - const float scale = 1. / (1. - ratio); - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int tid = threadIdx.x % (dim / unroll_factor); - - curandStatePhilox4_32_10_t state; - curand_init(seed.first, idx, seed.second, &state); - - float4* Xdata_cast = reinterpret_cast(Xdata); - uint32_t* mask_32 = reinterpret_cast(mask); - const float4* bias_cast = reinterpret_cast(bias); - - CUDA_1D_KERNEL_LOOP(j, N) - { - float4 rand = curand_uniform4(&state); - uint32_t m_32; - uint8_t* m = (uint8_t*)&m_32; - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - float4 x_data = Xdata_cast[j]; - float4 b_data = bias_cast[j % (dim / unroll_factor)]; - - x_data.x += b_data.x; - x_data.y += b_data.y; - x_data.z += b_data.z; - x_data.w += b_data.w; - - x_data.x = x_data.x * scale * m[0]; - x_data.y = x_data.y * scale * m[1]; - x_data.z = x_data.z * scale * m[2]; - x_data.w = x_data.w * scale * m[3]; - - mask_32[j] = m_32; - Xdata_cast[j] = x_data; - } - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - float4 rand = curand_uniform4(&state); - float* rand_data = &(rand.x); - int k = 0; - for (int i = high_index; i < N; i++) { - float x_data = Xdata[i] + bias[i % dim]; - uint8_t m = (uint8_t)(rand_data[k++] > ratio); - Xdata[i] = x_data * scale * m; - mask[i] = m; - } - } -} - -__global__ void dropout_kernel(const int N, - const int dim, - const float ratio, - const __half* bias, - __half* Xdata, - uint8_t* mask, - std::pair seed) -{ - const float scale = 1. / (1. - ratio); - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int tid = threadIdx.x % (dim / unroll_factor); - - curandStatePhilox4_32_10_t state; - curand_init(seed.first, idx, seed.second, &state); - - float2* Xdata_cast = reinterpret_cast(Xdata); - uint32_t* mask_32 = reinterpret_cast(mask); - const float2* bias_cast = reinterpret_cast(bias); - - CUDA_1D_KERNEL_LOOP(j, N) - { - float4 rand = curand_uniform4(&state); - - float2 data_f; - __half2* data_h = reinterpret_cast<__half2*>(&data_f); - - float2 bias_f; - __half2* bias_h = reinterpret_cast<__half2*>(&bias_f); - - data_f = Xdata_cast[j]; - bias_f = bias_cast[j % (dim / unroll_factor)]; - - float2 data_h_0 = __half22float2(data_h[0]); - float2 data_h_1 = __half22float2(data_h[1]); - - float2 bias_h_0 = __half22float2(bias_h[0]); - float2 bias_h_1 = __half22float2(bias_h[1]); - - data_h_0.x += bias_h_0.x; - data_h_0.y += bias_h_0.y; - data_h_1.x += bias_h_1.x; - data_h_1.y += bias_h_1.y; - - uint32_t m_32; - uint8_t* m = (uint8_t*)&m_32; - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - data_h_0.x = __float2half(data_h_0.x * scale * m[0]); - data_h_0.y = __float2half(data_h_0.y * scale * m[1]); - data_h_1.x = __float2half(data_h_1.x * scale * m[2]); - data_h_1.y = __float2half(data_h_1.y * scale * m[3]); - - float2 result_f; - __half2* result_h = reinterpret_cast<__half2*>(&result_f); - - result_h[0] = __float22half2_rn(data_h_0); - result_h[1] = __float22half2_rn(data_h_1); - - Xdata_cast[j] = result_f; - mask_32[j] = m_32; - } - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - float4 rand = curand_uniform4(&state); - float* rand_data = &(rand.x); - int k = 0; - for (int i = high_index; i < N; i++) { - float x_data = (float)Xdata[i] + (float)bias[i % dim]; - uint8_t m = (uint8_t)(rand_data[k++] > ratio); - Xdata[i] = __float2half(x_data * scale * m); - mask[i] = m; - } - } -} - -template -void launch_dropout(T* out, - const T* bias, - uint8_t* mask, - int batch, - int dim, - float ratio, - cudaStream_t stream) -{ - assert(unroll_factor == 4); - - int total_count = batch * dim / unroll_factor; - - dim3 grid_dim = DS_GET_BLOCKS(total_count); - dim3 block_dim = DS_CUDA_NUM_THREADS; - - uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x; - std::pair seed = Context::Instance().IncrementOffset(inc); - - dropout_kernel<<>>( - total_count, dim, ratio, bias, out, mask, seed); -} - -template void launch_dropout(float*, - const float* bias, - uint8_t* mask, - int batch, - int dim, - float ratio, - cudaStream_t stream); -template void launch_dropout(__half*, - const __half* bias, - uint8_t* mask, - int batch, - int dim, - float ratio, - cudaStream_t stream); - -__global__ void dropout_kernel(const int N, - const int dim, - const float ratio, - const float* input, - const float* residual, - const float* bias, - float* out, - uint8_t* mask, - std::pair seed) -{ - const float scale = 1. / (1. - ratio); - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int tid = threadIdx.x % (dim / unroll_factor); - - curandStatePhilox4_32_10_t state; - curand_init(seed.first, idx, seed.second, &state); - - float4* out_cast = reinterpret_cast(out); - uint32_t* mask_32 = reinterpret_cast(mask); - - const float4* bias_cast = reinterpret_cast(bias); - const float4* residual_cast = reinterpret_cast(residual); - const float4* input_cast = reinterpret_cast(input); - - CUDA_1D_KERNEL_LOOP(j, N) - { - float4 rand = curand_uniform4(&state); - - uint32_t m_32; - uint8_t* m = (uint8_t*)&m_32; - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - float4 out_data; - float4 b_data = bias_cast[j % (dim / unroll_factor)]; - float4 res_data = residual_cast[j]; - float4 inp_data = input_cast[j]; - - out_data.x = (b_data.x + inp_data.x); - out_data.y = (b_data.y + inp_data.y); - out_data.z = (b_data.z + inp_data.z); - out_data.w = (b_data.w + inp_data.w); - - out_data.x = out_data.x * scale * m[0]; - out_data.y = out_data.y * scale * m[1]; - out_data.z = out_data.z * scale * m[2]; - out_data.w = out_data.w * scale * m[3]; - - out_data.x += res_data.x; - out_data.y += res_data.y; - out_data.z += res_data.z; - out_data.w += res_data.w; - - mask_32[j] = m_32; - out_cast[j] = out_data; - } - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - float4 rand = curand_uniform4(&state); - float* rand_data = &(rand.x); - int k = 0; - for (int i = high_index; i < N; i++) { - float x_data = input[i] + bias[i % dim]; - uint8_t m = (uint8_t)(rand_data[k++] > ratio); - x_data = x_data * scale * m; - x_data += residual[i]; - - out[i] = x_data; - mask[i] = m; - } - } -} - -__global__ void dropout_kernel(const int N, - const int dim, - const float ratio, - const __half* input, - const __half* residual, - const __half* bias, - __half* out, - uint8_t* mask, - std::pair seed) -{ - const float scale = 1. / (1. - ratio); - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int tid = threadIdx.x % (dim / unroll_factor); - - curandStatePhilox4_32_10_t state; - curand_init(seed.first, idx, seed.second, &state); - - float2* out_cast = reinterpret_cast(out); - uint32_t* mask_32 = reinterpret_cast(mask); - - const float2* bias_cast = reinterpret_cast(bias); - const float2* residual_cast = reinterpret_cast(residual); - const float2* input_cast = reinterpret_cast(input); - - CUDA_1D_KERNEL_LOOP(j, N) - { - float4 rand = curand_uniform4(&state); - - float2 data_f; - __half2* data_h = reinterpret_cast<__half2*>(&data_f); - - float2 bias_f; - __half2* bias_h = reinterpret_cast<__half2*>(&bias_f); - - float2 residual_f; - __half2* residual_h = reinterpret_cast<__half2*>(&residual_f); - - float2 input_f; - __half2* input_h = reinterpret_cast<__half2*>(&input_f); - - bias_f = bias_cast[j % (dim / unroll_factor)]; - residual_f = residual_cast[j]; - input_f = input_cast[j]; - - float2 data_h_0 = __half22float2(data_h[0]); - float2 data_h_1 = __half22float2(data_h[1]); - - float2 bias_h_0 = __half22float2(bias_h[0]); - float2 bias_h_1 = __half22float2(bias_h[1]); - - float2 residual_h_0 = __half22float2(residual_h[0]); - float2 residual_h_1 = __half22float2(residual_h[1]); - - float2 input_h_0 = __half22float2(input_h[0]); - float2 input_h_1 = __half22float2(input_h[1]); - - data_h_0.x = (bias_h_0.x + input_h_0.x); - data_h_0.y = (bias_h_0.y + input_h_0.y); - data_h_1.x = (bias_h_1.x + input_h_1.x); - data_h_1.y = (bias_h_1.y + input_h_1.y); - - uint32_t m_32; - uint8_t* m = (uint8_t*)&m_32; - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - data_h_0.x = __float2half(data_h_0.x * scale * m[0]); - data_h_0.y = __float2half(data_h_0.y * scale * m[1]); - data_h_1.x = __float2half(data_h_1.x * scale * m[2]); - data_h_1.y = __float2half(data_h_1.y * scale * m[3]); - - data_h_0.x += residual_h_0.x; - data_h_0.y += residual_h_0.y; - data_h_1.x += residual_h_1.x; - data_h_1.y += residual_h_1.y; - - float2 result_f; - __half2* result_h = reinterpret_cast<__half2*>(&result_f); - - result_h[0] = __float22half2_rn(data_h_0); - result_h[1] = __float22half2_rn(data_h_1); - - out_cast[j] = result_f; - mask_32[j] = m_32; - } - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - float4 rand = curand_uniform4(&state); - float* rand_data = &(rand.x); - int k = 0; - for (int i = high_index; i < N; i++) { - float x_data = (float)input[i] + (float)bias[i % dim]; - uint8_t m = (uint8_t)(rand_data[k++] > ratio); - x_data = x_data * scale * m; - x_data += (float)residual[i]; - - out[i] = __float2half(x_data); - mask[i] = m; - } - } -} - -template -void launch_dropout(T* out, - const T* input, - const T* residual, - const T* bias, - uint8_t* mask, - int batch, - int dim, - float ratio, - cudaStream_t stream) -{ - assert(unroll_factor == 4); - - int total_count = batch * dim / unroll_factor; - dim3 grid_dim = DS_GET_BLOCKS(total_count); - dim3 block_dim = DS_CUDA_NUM_THREADS; - - uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x; - std::pair seed = Context::Instance().IncrementOffset(inc); - - dropout_kernel<<>>( - total_count, dim, ratio, input, residual, bias, out, mask, seed); -} - -template void launch_dropout(float*, - const float*, - const float* residual, - const float* bias, - uint8_t* mask, - int batch, - int dim, - float ratio, - cudaStream_t stream); -template void launch_dropout(__half*, - const __half*, - const __half* residual, - const __half* bias, - uint8_t* mask, - int batch, - int dim, - float ratio, - cudaStream_t stream); +#include "custom_cuda_layers.h" + +const int unroll_factor = 4; + +__global__ void dropout_kernel(const int N, + const float ratio, + float* out, + const float* Xdata, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float4 rand = curand_uniform4(&state); + uint8_t m[unroll_factor]; + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + int i = j * unroll_factor; + + mask[i] = (uint8_t)m[0]; + mask[i + 1] = (uint8_t)m[1]; + mask[i + 2] = (uint8_t)m[2]; + mask[i + 3] = (uint8_t)m[3]; + + out[i] = Xdata[i] * scale * m[0]; + out[i + 1] = Xdata[i + 1] * scale * m[1]; + out[i + 2] = Xdata[i + 2] * scale * m[2]; + out[i + 3] = Xdata[i + 3] * scale * m[3]; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + float4 rand = curand_uniform4(&state); + float* rand_data = &(rand.x); + int k = 0; + for (int i = high_index; i < N; i++) { + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + out[i] = Xdata[i] * scale * m; + mask[i] = m; + } + } +} + +__global__ void dropout_kernel(const int N, + const float ratio, + __half* out, + const __half* Xdata, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + +#ifdef __STOCHASTIC_MODE__ + + const __half2 h_scale = __float2half2_rn(scale); + const float2* x_cast = reinterpret_cast(Xdata); + float2* out_cast = reinterpret_cast(out); + uint32_t* mask_cast = reinterpret_cast(mask); + + uint32_t m_32; + uint8_t* m = reinterpret_cast(&m_32); + + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + __half2 mask_h[2]; + float2 mask_f[2]; + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float2 x_f = x_cast[j]; + __half2* x_h = reinterpret_cast<__half2*>(&x_f); + + float4 rand = curand_uniform4(&state); + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + float* mask_f_data = &mask_f[0].x; +#pragma unroll + for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]); + + mask_h[0] = __float22half2_rn(mask_f[0]); + mask_h[1] = __float22half2_rn(mask_f[1]); + + result_h[0] = x_h[0] * h_scale * mask_h[0]; + result_h[1] = x_h[1] * h_scale * mask_h[1]; + + out_cast[j] = result_f; + + mask_cast[j] = m_32; + } + +#else + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + int i = j * unroll_factor; + + const __half2* vals_half = reinterpret_cast(Xdata + i); + float2 vals_half_f[2]; + vals_half_f[0] = __half22float2(vals_half[0]); + vals_half_f[1] = __half22float2(vals_half[1]); + + uint8_t m[unroll_factor]; + float4 rand = curand_uniform4(&state); + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + out[i] = __float2half(vals_half_f[0].x * scale * m[0]); + out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]); + out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]); + out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]); + + mask[i] = m[0]; + mask[i + 1] = m[1]; + mask[i + 2] = m[2]; + mask[i + 3] = m[3]; + } + +#endif + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + float4 rand = curand_uniform4(&state); + float* rand_data = &(rand.x); + int k = 0; + for (int i = high_index; i < N; i++) { + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + out[i] = __float2half((float)Xdata[i] * scale * m); + mask[i] = m; + } + } +} + +__global__ void dropout_kernel_bwd(const int N, + const float ratio, + const float* Xdata, + float* out, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + int i = j * unroll_factor; + + out[i] = mask[i] ? Xdata[i] * scale : 0.0; + out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0; + out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0; + out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + for (int i = high_index; i < N; i++) { out[i] = mask[i] ? Xdata[i] * scale : 0.0; } + } +} + +__global__ void dropout_kernel_bwd(const int N, + const float ratio, + const __half* Xdata, + __half* out, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + +#ifdef __STOCHASTIC_MODE__ + + const __half2 h_scale = __float2half2_rn(scale); + + const float2* x_cast = reinterpret_cast(Xdata); + float2* out_cast = reinterpret_cast(out); + uint32_t* mask_cast = reinterpret_cast(mask); + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float2 x_f = x_cast[j]; + __half2* x_h = reinterpret_cast<__half2*>(&x_f); + + uint32_t m_32 = mask_cast[j]; + uint8_t* m = (uint8_t*)&m_32; + + __half2 mask_h[2]; + float2 mask_f[2]; + + float* mask_f_data = &mask_f[0].x; +#pragma unroll + for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]); + +#pragma unroll + for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]); + + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + + result_h[0] = x_h[0] * h_scale * mask_h[0]; + result_h[1] = x_h[1] * h_scale * mask_h[1]; + + out_cast[j] = result_f; + } + +#else + + const __half h_scale = __float2half(scale); + const __half h_zero = __float2half(0.0); + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + int i = j * unroll_factor; + + const __half2* vals_half = reinterpret_cast(Xdata + i); + + uint8_t* m = mask + i; + + float2 vals_half_f[2]; + + vals_half_f[0] = __half22float2(vals_half[0]); + vals_half_f[1] = __half22float2(vals_half[1]); + + out[i] = __float2half(vals_half_f[0].x * scale * m[0]); + out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]); + out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]); + out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]); + } + +#endif + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + for (int i = high_index; i < N; i++) { + out[i] = __float2half((float)Xdata[i] * scale * mask[i]); + } + } +} + +template +void launch_dropout(T* out, + const T* vals, + uint8_t* mask, + int total_count, + int dim, + float ratio, + cudaStream_t stream, + bool bwd) +{ + assert(unroll_factor == 4); + + dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor); + dim3 block_dim = DS_CUDA_NUM_THREADS; + + if (dim > 512) { + block_dim.x >>= 1; + grid_dim.x <<= 1; + } + uint64_t inc = total_count / grid_dim.x / block_dim.x; + std::pair seed = Context::Instance().IncrementOffset(inc); + if (bwd) + dropout_kernel_bwd<<>>( + total_count, ratio, vals, out, mask, seed); + else + dropout_kernel<<>>( + total_count, ratio, out, vals, mask, seed); +} + +template void launch_dropout(float* out, + const float* vals, + uint8_t* mask, + int total_count, + int dim, + float ratio, + cudaStream_t stream, + bool); +template void launch_dropout(__half* out, + const __half* vals, + uint8_t* mask, + int total_count, + int dim, + float ratio, + cudaStream_t stream, + bool); + +__global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask) +{ + CUDA_1D_KERNEL_LOOP(i, N) { Xdata[i] *= scale * mask[i]; } +} + +__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask) +{ + const __half2 h_scale = __float2half2_rn(scale); + float2* x_cast = reinterpret_cast(Xdata); + uint32_t* mask_cast = reinterpret_cast(mask); + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float2 x_data = x_cast[j]; + uint32_t m_32 = mask_cast[j]; + uint8_t* m = (uint8_t*)&m_32; + + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + +#ifdef __STOCHASTIC_MODE__ + + __half2* x_data_h = reinterpret_cast<__half2*>(&x_data); + __half2 mask_h[2]; + float2 mask_f[2]; + + float* mask_f_data = &mask_f[0].x; +#pragma unroll + for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]); + + mask_h[0] = __float22half2_rn(mask_f[0]); + mask_h[1] = __float22half2_rn(mask_f[1]); + + result_h[0] = x_data_h[0] * h_scale * mask_h[0]; + result_h[1] = x_data_h[1] * h_scale * mask_h[1]; + +#else + + __half* x_data_h = reinterpret_cast<__half*>(&x_data); + float2 result[2]; + + result[0].x = (float)x_data_h[0] * scale * m[0]; + result[0].y = (float)x_data_h[1] * scale * m[1]; + result[1].x = (float)x_data_h[2] * scale * m[2]; + result[1].y = (float)x_data_h[3] * scale * m[3]; + + result_h[0] = __float22half2_rn(result[0]); + result_h[1] = __float22half2_rn(result[1]); + +#endif + x_cast[j] = result_f; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + for (int i = high_index; i < N; i++) { + Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]); + } + } +} + +template +void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream) +{ + assert(unroll_factor == 4); + + const float scale = 1. / (1. - ratio); + dropout_grad_kernel<<>>(total_count, scale, vals, mask); +} + +template void launch_dropout_grad(float* vals, + uint8_t* mask, + int total_count, + float ratio, + cudaStream_t stream); +template void launch_dropout_grad(__half* vals, + uint8_t* mask, + int total_count, + float ratio, + cudaStream_t stream); + +__global__ void dropout_grad_kernel(const int N, + const float scale, + const float* Xdata, + float* out, + uint8_t* mask) +{ + CUDA_1D_KERNEL_LOOP(i, N) { out[i] = Xdata[i] * scale * mask[i]; } +} + +__global__ void dropout_grad_kernel(const int N, + const float scale, + const __half* Xdata, + __half* out, + uint8_t* mask) +{ + const float2* x_cast = reinterpret_cast(Xdata); + float2* out_cast = reinterpret_cast(out); + const uint32_t* mask_cast = reinterpret_cast(mask); + + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float2 x_data = x_cast[j]; + uint32_t m_32 = mask_cast[j]; + uint8_t* m = (uint8_t*)&m_32; + + __half* x_data_h = reinterpret_cast<__half*>(&x_data); + float2 result[2]; + + result[0].x = (float)x_data_h[0] * scale * m[0]; + result[0].y = (float)x_data_h[1] * scale * m[1]; + result[1].x = (float)x_data_h[2] * scale * m[2]; + result[1].y = (float)x_data_h[3] * scale * m[3]; + + result_h[0] = __float22half2_rn(result[0]); + result_h[1] = __float22half2_rn(result[1]); + + out_cast[j] = result_f; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + for (int i = high_index; i < N; i++) { + out[i] = __float2half((float)Xdata[i] * scale * mask[i]); + } + } +} + +template +void launch_dropout_grad(T* vals_out, + const T* vals, + uint8_t* mask, + int total_count, + float ratio, + cudaStream_t stream) +{ + assert(unroll_factor == 4); + + const float scale = 1. / (1. - ratio); + dropout_grad_kernel<<>>(total_count, scale, vals, vals_out, mask); +} +template void launch_dropout_grad(float*, + const float* vals, + uint8_t* mask, + int total_count, + float ratio, + cudaStream_t stream); +template void launch_dropout_grad(__half*, + const __half* vals, + uint8_t* mask, + int total_count, + float ratio, + cudaStream_t stream); + +__global__ void dropout_kernel(const int N, + const int dim, + const float ratio, + const float* bias, + float* Xdata, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int tid = threadIdx.x % (dim / unroll_factor); + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + float4* Xdata_cast = reinterpret_cast(Xdata); + uint32_t* mask_32 = reinterpret_cast(mask); + const float4* bias_cast = reinterpret_cast(bias); + + CUDA_1D_KERNEL_LOOP(j, N) + { + float4 rand = curand_uniform4(&state); + uint32_t m_32; + uint8_t* m = (uint8_t*)&m_32; + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + float4 x_data = Xdata_cast[j]; + float4 b_data = bias_cast[j % (dim / unroll_factor)]; + + x_data.x += b_data.x; + x_data.y += b_data.y; + x_data.z += b_data.z; + x_data.w += b_data.w; + + x_data.x = x_data.x * scale * m[0]; + x_data.y = x_data.y * scale * m[1]; + x_data.z = x_data.z * scale * m[2]; + x_data.w = x_data.w * scale * m[3]; + + mask_32[j] = m_32; + Xdata_cast[j] = x_data; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + float4 rand = curand_uniform4(&state); + float* rand_data = &(rand.x); + int k = 0; + for (int i = high_index; i < N; i++) { + float x_data = Xdata[i] + bias[i % dim]; + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + Xdata[i] = x_data * scale * m; + mask[i] = m; + } + } +} + +__global__ void dropout_kernel(const int N, + const int dim, + const float ratio, + const __half* bias, + __half* Xdata, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int tid = threadIdx.x % (dim / unroll_factor); + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + float2* Xdata_cast = reinterpret_cast(Xdata); + uint32_t* mask_32 = reinterpret_cast(mask); + const float2* bias_cast = reinterpret_cast(bias); + + CUDA_1D_KERNEL_LOOP(j, N) + { + float4 rand = curand_uniform4(&state); + + float2 data_f; + __half2* data_h = reinterpret_cast<__half2*>(&data_f); + + float2 bias_f; + __half2* bias_h = reinterpret_cast<__half2*>(&bias_f); + + data_f = Xdata_cast[j]; + bias_f = bias_cast[j % (dim / unroll_factor)]; + + float2 data_h_0 = __half22float2(data_h[0]); + float2 data_h_1 = __half22float2(data_h[1]); + + float2 bias_h_0 = __half22float2(bias_h[0]); + float2 bias_h_1 = __half22float2(bias_h[1]); + + data_h_0.x += bias_h_0.x; + data_h_0.y += bias_h_0.y; + data_h_1.x += bias_h_1.x; + data_h_1.y += bias_h_1.y; + + uint32_t m_32; + uint8_t* m = (uint8_t*)&m_32; + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + data_h_0.x = __float2half(data_h_0.x * scale * m[0]); + data_h_0.y = __float2half(data_h_0.y * scale * m[1]); + data_h_1.x = __float2half(data_h_1.x * scale * m[2]); + data_h_1.y = __float2half(data_h_1.y * scale * m[3]); + + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + + result_h[0] = __float22half2_rn(data_h_0); + result_h[1] = __float22half2_rn(data_h_1); + + Xdata_cast[j] = result_f; + mask_32[j] = m_32; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + float4 rand = curand_uniform4(&state); + float* rand_data = &(rand.x); + int k = 0; + for (int i = high_index; i < N; i++) { + float x_data = (float)Xdata[i] + (float)bias[i % dim]; + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + Xdata[i] = __float2half(x_data * scale * m); + mask[i] = m; + } + } +} + +template +void launch_dropout(T* out, + const T* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream) +{ + assert(unroll_factor == 4); + + int total_count = batch * dim / unroll_factor; + + dim3 grid_dim = DS_GET_BLOCKS(total_count); + dim3 block_dim = DS_CUDA_NUM_THREADS; + + uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x; + std::pair seed = Context::Instance().IncrementOffset(inc); + + dropout_kernel<<>>( + total_count, dim, ratio, bias, out, mask, seed); +} + +template void launch_dropout(float*, + const float* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream); +template void launch_dropout(__half*, + const __half* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream); + +__global__ void dropout_kernel(const int N, + const int dim, + const float ratio, + const float* input, + const float* residual, + const float* bias, + float* out, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int tid = threadIdx.x % (dim / unroll_factor); + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + float4* out_cast = reinterpret_cast(out); + uint32_t* mask_32 = reinterpret_cast(mask); + + const float4* bias_cast = reinterpret_cast(bias); + const float4* residual_cast = reinterpret_cast(residual); + const float4* input_cast = reinterpret_cast(input); + + CUDA_1D_KERNEL_LOOP(j, N) + { + float4 rand = curand_uniform4(&state); + + uint32_t m_32; + uint8_t* m = (uint8_t*)&m_32; + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + float4 out_data; + float4 b_data = bias_cast[j % (dim / unroll_factor)]; + float4 res_data = residual_cast[j]; + float4 inp_data = input_cast[j]; + + out_data.x = (b_data.x + inp_data.x); + out_data.y = (b_data.y + inp_data.y); + out_data.z = (b_data.z + inp_data.z); + out_data.w = (b_data.w + inp_data.w); + + out_data.x = out_data.x * scale * m[0]; + out_data.y = out_data.y * scale * m[1]; + out_data.z = out_data.z * scale * m[2]; + out_data.w = out_data.w * scale * m[3]; + + out_data.x += res_data.x; + out_data.y += res_data.y; + out_data.z += res_data.z; + out_data.w += res_data.w; + + mask_32[j] = m_32; + out_cast[j] = out_data; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + float4 rand = curand_uniform4(&state); + float* rand_data = &(rand.x); + int k = 0; + for (int i = high_index; i < N; i++) { + float x_data = input[i] + bias[i % dim]; + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + x_data = x_data * scale * m; + x_data += residual[i]; + + out[i] = x_data; + mask[i] = m; + } + } +} + +__global__ void dropout_kernel(const int N, + const int dim, + const float ratio, + const __half* input, + const __half* residual, + const __half* bias, + __half* out, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int tid = threadIdx.x % (dim / unroll_factor); + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + float2* out_cast = reinterpret_cast(out); + uint32_t* mask_32 = reinterpret_cast(mask); + + const float2* bias_cast = reinterpret_cast(bias); + const float2* residual_cast = reinterpret_cast(residual); + const float2* input_cast = reinterpret_cast(input); + + CUDA_1D_KERNEL_LOOP(j, N) + { + float4 rand = curand_uniform4(&state); + + float2 data_f; + __half2* data_h = reinterpret_cast<__half2*>(&data_f); + + float2 bias_f; + __half2* bias_h = reinterpret_cast<__half2*>(&bias_f); + + float2 residual_f; + __half2* residual_h = reinterpret_cast<__half2*>(&residual_f); + + float2 input_f; + __half2* input_h = reinterpret_cast<__half2*>(&input_f); + + bias_f = bias_cast[j % (dim / unroll_factor)]; + residual_f = residual_cast[j]; + input_f = input_cast[j]; + + float2 data_h_0 = __half22float2(data_h[0]); + float2 data_h_1 = __half22float2(data_h[1]); + + float2 bias_h_0 = __half22float2(bias_h[0]); + float2 bias_h_1 = __half22float2(bias_h[1]); + + float2 residual_h_0 = __half22float2(residual_h[0]); + float2 residual_h_1 = __half22float2(residual_h[1]); + + float2 input_h_0 = __half22float2(input_h[0]); + float2 input_h_1 = __half22float2(input_h[1]); + + data_h_0.x = (bias_h_0.x + input_h_0.x); + data_h_0.y = (bias_h_0.y + input_h_0.y); + data_h_1.x = (bias_h_1.x + input_h_1.x); + data_h_1.y = (bias_h_1.y + input_h_1.y); + + uint32_t m_32; + uint8_t* m = (uint8_t*)&m_32; + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + data_h_0.x = __float2half(data_h_0.x * scale * m[0]); + data_h_0.y = __float2half(data_h_0.y * scale * m[1]); + data_h_1.x = __float2half(data_h_1.x * scale * m[2]); + data_h_1.y = __float2half(data_h_1.y * scale * m[3]); + + data_h_0.x += residual_h_0.x; + data_h_0.y += residual_h_0.y; + data_h_1.x += residual_h_1.x; + data_h_1.y += residual_h_1.y; + + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + + result_h[0] = __float22half2_rn(data_h_0); + result_h[1] = __float22half2_rn(data_h_1); + + out_cast[j] = result_f; + mask_32[j] = m_32; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + float4 rand = curand_uniform4(&state); + float* rand_data = &(rand.x); + int k = 0; + for (int i = high_index; i < N; i++) { + float x_data = (float)input[i] + (float)bias[i % dim]; + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + x_data = x_data * scale * m; + x_data += (float)residual[i]; + + out[i] = __float2half(x_data); + mask[i] = m; + } + } +} + +template +void launch_dropout(T* out, + const T* input, + const T* residual, + const T* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream) +{ + assert(unroll_factor == 4); + + int total_count = batch * dim / unroll_factor; + dim3 grid_dim = DS_GET_BLOCKS(total_count); + dim3 block_dim = DS_CUDA_NUM_THREADS; + + uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x; + std::pair seed = Context::Instance().IncrementOffset(inc); + + dropout_kernel<<>>( + total_count, dim, ratio, input, residual, bias, out, mask, seed); +} + +template void launch_dropout(float*, + const float*, + const float* residual, + const float* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream); +template void launch_dropout(__half*, + const __half*, + const __half* residual, + const __half* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream); diff --git a/csrc/transformer/gelu_kernels.cu b/csrc/transformer/gelu_kernels.cu index 120480062..cea337b06 100644 --- a/csrc/transformer/gelu_kernels.cu +++ b/csrc/transformer/gelu_kernels.cu @@ -1,330 +1,330 @@ -#include "custom_cuda_layers.h" - -inline __device__ float gelu(const float x) -{ - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x))); -} - -inline __device__ float d_gelu(const float x) -{ - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - - float x2mul = x * x * mul_param; - float tan_h = tanhf(sqrt_param * (x + x * x2mul)); - float dg1 = 0.5f * (1.0f + tan_h); - float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); - float dg3 = dg2 * 3 * x2mul; - return (dg1 + dg2 + dg3); -} - -/* -Fused bias add with GELU - -Loads a vector of 4 elements each iteration, for stride -iterations. It was written with the intention to launch 256 thread -threadblocks, so to launch for bert-large, we would set ITERATIONS -to 4. This is currently done automatically as a heuristic, setting -the number of iterations as blocks of 1024. - -For FP16, the values are loaded from memory as __half, but converted -to FP32 for the arithmetic itself, to prevent numerous overflow on -the intermediate hyperbolic tangent, since there's no intrinsic -that computes it directly. -*/ - -__global__ void gelu_kernel(const float* input, float* vals, int row_stride, int iterations) -{ - int row = blockIdx.x; - int id = threadIdx.x; - int loop_stride = blockDim.x; - - const float4* input_cast = reinterpret_cast(input); - float4* vals_cast = reinterpret_cast(vals); - - for (int i = 0; i < iterations; i++) { - if (i * loop_stride + id < row_stride) { - float4 data = input_cast[row * row_stride + i * loop_stride + id]; - - data.x = gelu(data.x); - data.y = gelu(data.y); - data.z = gelu(data.z); - data.w = gelu(data.w); - - vals_cast[row * row_stride + i * loop_stride + id] = data; - } - } -} - -__global__ void gelu_kernel(const __half* input, __half* vals, int row_stride, int iterations) -{ -#if __CUDA_ARCH__ >= 700 - int row = blockIdx.x; - int id = threadIdx.x; - int loop_stride = blockDim.x; - - const float2* input_cast = reinterpret_cast(input); - float2* vals_cast = reinterpret_cast(vals); - - for (int i = 0; i < iterations; i++) { - if (i * loop_stride + id < row_stride) { - float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id]; - - __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); - - float2 low_data = __half22float2(vals_half[0]); - float2 high_data = __half22float2(vals_half[1]); - - low_data.x = gelu(low_data.x); - low_data.y = gelu(low_data.y); - high_data.x = gelu(high_data.x); - high_data.y = gelu(high_data.y); - - vals_half[0] = __float22half2_rn(low_data); - vals_half[1] = __float22half2_rn(high_data); - - vals_cast[row * row_stride + i * loop_stride + id] = vals_vec; - } - } -#endif -} - -__global__ void fused_bias_gelu(const float* input, - const float* bias, - float* vals, - int row_stride, - int iterations) -{ - int row = blockIdx.x; - int id = threadIdx.x; - int loop_stride = blockDim.x; - - const float4* input_cast = reinterpret_cast(input); - float4* vals_cast = reinterpret_cast(vals); - const float4* bias_cast = reinterpret_cast(bias); - - for (int i = 0; i < iterations; i++) { - if (i * loop_stride + id < row_stride) { - float4 data = input_cast[row * row_stride + i * loop_stride + id]; - float4 bias_data = bias_cast[i * loop_stride + id]; - - data.x += bias_data.x; - data.y += bias_data.y; - data.z += bias_data.z; - data.w += bias_data.w; - - data.x = gelu(data.x); - data.y = gelu(data.y); - data.z = gelu(data.z); - data.w = gelu(data.w); - - vals_cast[row * row_stride + i * loop_stride + id] = data; - } - } -} - -__global__ void fused_bias_gelu(const __half* input, - const __half* bias, - __half* vals, - int row_stride, - int iterations) -{ -#if __CUDA_ARCH__ >= 700 - int row = blockIdx.x; - int id = threadIdx.x; - int loop_stride = blockDim.x; - - const float2* input_cast = reinterpret_cast(input); - float2* vals_cast = reinterpret_cast(vals); - const float2* bias_cast = reinterpret_cast(bias); - - for (int i = 0; i < iterations; i++) { - if (i * loop_stride + id < row_stride) { - float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id]; - float2 bias_vec = bias_cast[i * loop_stride + id]; - - __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); - __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); - - float2 low_data = __half22float2(vals_half[0]); - float2 high_data = __half22float2(vals_half[1]); - - float2 low_bias = __half22float2(bias_half[0]); - float2 high_bias = __half22float2(bias_half[1]); - - low_data.x += low_bias.x; - low_data.y += low_bias.y; - high_data.x += high_bias.x; - high_data.y += high_bias.y; - - low_data.x = gelu(low_data.x); - low_data.y = gelu(low_data.y); - high_data.x = gelu(high_data.x); - high_data.y = gelu(high_data.y); - - vals_half[0] = __float22half2_rn(low_data); - vals_half[1] = __float22half2_rn(high_data); - - vals_cast[row * row_stride + i * loop_stride + id] = vals_vec; - } - } -#endif -} - -__global__ void d_gelu_func(float* d_output, - const float* gelu_input, - const float* bias, - int row_stride, - int iterations) -{ - int row = blockIdx.x; - int id = threadIdx.x; - int loop_stride = blockDim.x; - - float4* d_output_cast = reinterpret_cast(d_output); - const float4* gelu_input_cast = reinterpret_cast(gelu_input); - const float4* bias_cast = reinterpret_cast(bias); - - for (int i = 0; i < iterations; i++) { - if (i * loop_stride + id < row_stride) { - float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id]; - float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id]; - float4 bias_data = bias_cast[i * loop_stride + id]; - - gelu_input_data.x += bias_data.x; - gelu_input_data.y += bias_data.y; - gelu_input_data.z += bias_data.z; - gelu_input_data.w += bias_data.w; - - output_data.x *= d_gelu(gelu_input_data.x); - output_data.y *= d_gelu(gelu_input_data.y); - output_data.z *= d_gelu(gelu_input_data.z); - output_data.w *= d_gelu(gelu_input_data.w); - - d_output_cast[row * row_stride + i * loop_stride + id] = output_data; - } - } -} - -__global__ void d_gelu_func(__half* d_output, - const __half* gelu_input, - const __half* bias, - int row_stride, - int iterations) -{ -#if __CUDA_ARCH__ >= 700 - int row = blockIdx.x; - int id = threadIdx.x; - int loop_stride = blockDim.x; - - float2* d_output_cast = reinterpret_cast(d_output); - const float2* gelu_input_cast = reinterpret_cast(gelu_input); - const float2* bias_cast = reinterpret_cast(bias); - -#pragma unroll - for (int i = 0; i < iterations; i++) { - if (i * loop_stride + id < row_stride) { - float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id]; - float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id]; - float2 bias_vec = bias_cast[i * loop_stride + id]; - - __half2* output_data_half = reinterpret_cast<__half2*>(&output_data); - __half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data); - __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); - - float2 output_half_0 = __half22float2(output_data_half[0]); - float2 output_half_1 = __half22float2(output_data_half[1]); - - float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]); - float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]); - - float2 bias_half_0 = __half22float2(bias_half[0]); - float2 bias_half_1 = __half22float2(bias_half[1]); - - gelu_input_half_0.x += bias_half_0.x; - gelu_input_half_0.y += bias_half_0.y; - gelu_input_half_1.x += bias_half_1.x; - gelu_input_half_1.y += bias_half_1.y; - - output_half_0.x *= d_gelu(gelu_input_half_0.x); - output_half_0.y *= d_gelu(gelu_input_half_0.y); - output_half_1.x *= d_gelu(gelu_input_half_1.x); - output_half_1.y *= d_gelu(gelu_input_half_1.y); - - float2 result; - __half2* result_half2 = reinterpret_cast<__half2*>(&result); - - result_half2[0] = __float22half2_rn(output_half_0); - result_half2[1] = __float22half2_rn(output_half_1); - - d_output_cast[row * row_stride + i * loop_stride + id] = result; - } - } -#endif -} - -template -void launch_bias_gelu(const T* input, - const T* bias, - T* output, - int intermediate_size, - int batch_size, - cudaStream_t stream) -{ - int iterations = (intermediate_size + 1023) / 1024; - int threads = (intermediate_size - 1) / (iterations * 4) + 1; - dim3 block_dims(threads); - dim3 grid_dims(batch_size); - - fused_bias_gelu<<>>( - input, bias, output, intermediate_size / 4, iterations); -} - -template -void launch_gelu(const T* input, - T* output, - int intermediate_size, - int batch_size, - cudaStream_t stream) -{ - int iterations = (intermediate_size + 1023) / 1024; - int threads = (intermediate_size - 1) / (iterations * 4) + 1; - dim3 block_dims(threads); - dim3 grid_dims(batch_size); - - gelu_kernel<<>>( - input, output, intermediate_size / 4, iterations); -} - -template void launch_bias_gelu(const float*, const float*, float*, int, int, cudaStream_t); -template void launch_bias_gelu<__half>(const __half*, - const __half*, - __half*, - int, - int, - cudaStream_t); - -template void launch_gelu(const float*, float*, int, int, cudaStream_t); -template void launch_gelu<__half>(const __half*, __half*, int, int, cudaStream_t); - -template -void launch_d_gelu(T* d_output, - const T* input, - const T* bias, - int intermediate_size, - int batch_size, - cudaStream_t stream) -{ - int iterations = (intermediate_size + 1023) / 1024; - int threads = (intermediate_size - 1) / (iterations * 4) + 1; - dim3 block_dims(threads); - dim3 grid_dims(batch_size); - - d_gelu_func<<>>( - d_output, input, bias, intermediate_size / 4, iterations); -} - -template void launch_d_gelu(float*, const float*, const float*, int, int, cudaStream_t); -template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, cudaStream_t); +#include "custom_cuda_layers.h" + +inline __device__ float gelu(const float x) +{ + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x))); +} + +inline __device__ float d_gelu(const float x) +{ + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + + float x2mul = x * x * mul_param; + float tan_h = tanhf(sqrt_param * (x + x * x2mul)); + float dg1 = 0.5f * (1.0f + tan_h); + float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); + float dg3 = dg2 * 3 * x2mul; + return (dg1 + dg2 + dg3); +} + +/* +Fused bias add with GELU + +Loads a vector of 4 elements each iteration, for stride +iterations. It was written with the intention to launch 256 thread +threadblocks, so to launch for bert-large, we would set ITERATIONS +to 4. This is currently done automatically as a heuristic, setting +the number of iterations as blocks of 1024. + +For FP16, the values are loaded from memory as __half, but converted +to FP32 for the arithmetic itself, to prevent numerous overflow on +the intermediate hyperbolic tangent, since there's no intrinsic +that computes it directly. +*/ + +__global__ void gelu_kernel(const float* input, float* vals, int row_stride, int iterations) +{ + int row = blockIdx.x; + int id = threadIdx.x; + int loop_stride = blockDim.x; + + const float4* input_cast = reinterpret_cast(input); + float4* vals_cast = reinterpret_cast(vals); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float4 data = input_cast[row * row_stride + i * loop_stride + id]; + + data.x = gelu(data.x); + data.y = gelu(data.y); + data.z = gelu(data.z); + data.w = gelu(data.w); + + vals_cast[row * row_stride + i * loop_stride + id] = data; + } + } +} + +__global__ void gelu_kernel(const __half* input, __half* vals, int row_stride, int iterations) +{ +#if __CUDA_ARCH__ >= 700 + int row = blockIdx.x; + int id = threadIdx.x; + int loop_stride = blockDim.x; + + const float2* input_cast = reinterpret_cast(input); + float2* vals_cast = reinterpret_cast(vals); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id]; + + __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); + + float2 low_data = __half22float2(vals_half[0]); + float2 high_data = __half22float2(vals_half[1]); + + low_data.x = gelu(low_data.x); + low_data.y = gelu(low_data.y); + high_data.x = gelu(high_data.x); + high_data.y = gelu(high_data.y); + + vals_half[0] = __float22half2_rn(low_data); + vals_half[1] = __float22half2_rn(high_data); + + vals_cast[row * row_stride + i * loop_stride + id] = vals_vec; + } + } +#endif +} + +__global__ void fused_bias_gelu(const float* input, + const float* bias, + float* vals, + int row_stride, + int iterations) +{ + int row = blockIdx.x; + int id = threadIdx.x; + int loop_stride = blockDim.x; + + const float4* input_cast = reinterpret_cast(input); + float4* vals_cast = reinterpret_cast(vals); + const float4* bias_cast = reinterpret_cast(bias); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float4 data = input_cast[row * row_stride + i * loop_stride + id]; + float4 bias_data = bias_cast[i * loop_stride + id]; + + data.x += bias_data.x; + data.y += bias_data.y; + data.z += bias_data.z; + data.w += bias_data.w; + + data.x = gelu(data.x); + data.y = gelu(data.y); + data.z = gelu(data.z); + data.w = gelu(data.w); + + vals_cast[row * row_stride + i * loop_stride + id] = data; + } + } +} + +__global__ void fused_bias_gelu(const __half* input, + const __half* bias, + __half* vals, + int row_stride, + int iterations) +{ +#if __CUDA_ARCH__ >= 700 + int row = blockIdx.x; + int id = threadIdx.x; + int loop_stride = blockDim.x; + + const float2* input_cast = reinterpret_cast(input); + float2* vals_cast = reinterpret_cast(vals); + const float2* bias_cast = reinterpret_cast(bias); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id]; + float2 bias_vec = bias_cast[i * loop_stride + id]; + + __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); + __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); + + float2 low_data = __half22float2(vals_half[0]); + float2 high_data = __half22float2(vals_half[1]); + + float2 low_bias = __half22float2(bias_half[0]); + float2 high_bias = __half22float2(bias_half[1]); + + low_data.x += low_bias.x; + low_data.y += low_bias.y; + high_data.x += high_bias.x; + high_data.y += high_bias.y; + + low_data.x = gelu(low_data.x); + low_data.y = gelu(low_data.y); + high_data.x = gelu(high_data.x); + high_data.y = gelu(high_data.y); + + vals_half[0] = __float22half2_rn(low_data); + vals_half[1] = __float22half2_rn(high_data); + + vals_cast[row * row_stride + i * loop_stride + id] = vals_vec; + } + } +#endif +} + +__global__ void d_gelu_func(float* d_output, + const float* gelu_input, + const float* bias, + int row_stride, + int iterations) +{ + int row = blockIdx.x; + int id = threadIdx.x; + int loop_stride = blockDim.x; + + float4* d_output_cast = reinterpret_cast(d_output); + const float4* gelu_input_cast = reinterpret_cast(gelu_input); + const float4* bias_cast = reinterpret_cast(bias); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id]; + float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id]; + float4 bias_data = bias_cast[i * loop_stride + id]; + + gelu_input_data.x += bias_data.x; + gelu_input_data.y += bias_data.y; + gelu_input_data.z += bias_data.z; + gelu_input_data.w += bias_data.w; + + output_data.x *= d_gelu(gelu_input_data.x); + output_data.y *= d_gelu(gelu_input_data.y); + output_data.z *= d_gelu(gelu_input_data.z); + output_data.w *= d_gelu(gelu_input_data.w); + + d_output_cast[row * row_stride + i * loop_stride + id] = output_data; + } + } +} + +__global__ void d_gelu_func(__half* d_output, + const __half* gelu_input, + const __half* bias, + int row_stride, + int iterations) +{ +#if __CUDA_ARCH__ >= 700 + int row = blockIdx.x; + int id = threadIdx.x; + int loop_stride = blockDim.x; + + float2* d_output_cast = reinterpret_cast(d_output); + const float2* gelu_input_cast = reinterpret_cast(gelu_input); + const float2* bias_cast = reinterpret_cast(bias); + +#pragma unroll + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id]; + float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id]; + float2 bias_vec = bias_cast[i * loop_stride + id]; + + __half2* output_data_half = reinterpret_cast<__half2*>(&output_data); + __half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data); + __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); + + float2 output_half_0 = __half22float2(output_data_half[0]); + float2 output_half_1 = __half22float2(output_data_half[1]); + + float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]); + float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]); + + float2 bias_half_0 = __half22float2(bias_half[0]); + float2 bias_half_1 = __half22float2(bias_half[1]); + + gelu_input_half_0.x += bias_half_0.x; + gelu_input_half_0.y += bias_half_0.y; + gelu_input_half_1.x += bias_half_1.x; + gelu_input_half_1.y += bias_half_1.y; + + output_half_0.x *= d_gelu(gelu_input_half_0.x); + output_half_0.y *= d_gelu(gelu_input_half_0.y); + output_half_1.x *= d_gelu(gelu_input_half_1.x); + output_half_1.y *= d_gelu(gelu_input_half_1.y); + + float2 result; + __half2* result_half2 = reinterpret_cast<__half2*>(&result); + + result_half2[0] = __float22half2_rn(output_half_0); + result_half2[1] = __float22half2_rn(output_half_1); + + d_output_cast[row * row_stride + i * loop_stride + id] = result; + } + } +#endif +} + +template +void launch_bias_gelu(const T* input, + const T* bias, + T* output, + int intermediate_size, + int batch_size, + cudaStream_t stream) +{ + int iterations = (intermediate_size + 1023) / 1024; + int threads = (intermediate_size - 1) / (iterations * 4) + 1; + dim3 block_dims(threads); + dim3 grid_dims(batch_size); + + fused_bias_gelu<<>>( + input, bias, output, intermediate_size / 4, iterations); +} + +template +void launch_gelu(const T* input, + T* output, + int intermediate_size, + int batch_size, + cudaStream_t stream) +{ + int iterations = (intermediate_size + 1023) / 1024; + int threads = (intermediate_size - 1) / (iterations * 4) + 1; + dim3 block_dims(threads); + dim3 grid_dims(batch_size); + + gelu_kernel<<>>( + input, output, intermediate_size / 4, iterations); +} + +template void launch_bias_gelu(const float*, const float*, float*, int, int, cudaStream_t); +template void launch_bias_gelu<__half>(const __half*, + const __half*, + __half*, + int, + int, + cudaStream_t); + +template void launch_gelu(const float*, float*, int, int, cudaStream_t); +template void launch_gelu<__half>(const __half*, __half*, int, int, cudaStream_t); + +template +void launch_d_gelu(T* d_output, + const T* input, + const T* bias, + int intermediate_size, + int batch_size, + cudaStream_t stream) +{ + int iterations = (intermediate_size + 1023) / 1024; + int threads = (intermediate_size - 1) / (iterations * 4) + 1; + dim3 block_dims(threads); + dim3 grid_dims(batch_size); + + d_gelu_func<<>>( + d_output, input, bias, intermediate_size / 4, iterations); +} + +template void launch_d_gelu(float*, const float*, const float*, int, int, cudaStream_t); +template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, cudaStream_t); diff --git a/csrc/transformer/general_kernels.cu b/csrc/transformer/general_kernels.cu index 7d318773f..1eaa94e1e 100644 --- a/csrc/transformer/general_kernels.cu +++ b/csrc/transformer/general_kernels.cu @@ -1,411 +1,411 @@ -#include "general_kernels.h" - -namespace cg = cooperative_groups; - -template -__global__ void column_sum_reduce(const T* __restrict__ inp, - T* __restrict__ out, - int rows, - int width) -{ - __shared__ float tile[TILE_DIM][TILE_DIM + 1]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - - int y_stride = width * TILE_DIM; - - float localSum = 0; - - // Loop across matrix height - if (idx < width) { - int offset = threadIdx.y * width + idx; - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - localSum += (float)inp[offset]; - offset += y_stride; - } - } - - tile[threadIdx.x][threadIdx.y] = localSum; - - __syncthreads(); - - // Sum the shared buffer. - float sum = tile[threadIdx.y][threadIdx.x]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i); - - if (threadIdx.x == 0) { - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - if (pos < width) out[pos] = sum; - } -} - -template -void launch_fuse_transpose_bias_kernel(const T* inp, - T* out, - int rows, - int cols, - cudaStream_t stream); - -template <> -void launch_fuse_transpose_bias_kernel(const float* inp, - float* out, - int rows, - int cols, - cudaStream_t stream) -{ - // assert(rows % TILE_DIM == 0); - // assert(cols % TILE_DIM == 0); - - dim3 grid_dim((cols - 1) / TILE_DIM + 1); - dim3 block_dim(TILE_DIM, TILE_DIM); - - column_sum_reduce<<>>(inp, out, rows, cols); -} - -template <> -void launch_fuse_transpose_bias_kernel<__half>(const __half* inp, - __half* out, - int rows, - int cols, - cudaStream_t stream) -{ - // assert(rows % TILE_DIM == 0); - // assert(cols % TILE_DIM == 0); - - dim3 grid_dim((cols - 1) / TILE_DIM + 1); - dim3 block_dim(TILE_DIM, TILE_DIM); - - column_sum_reduce<__half><<>>(inp, out, rows, cols); -} - -__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2) -{ - const float4* inp1_4 = reinterpret_cast(inp1); - const float4* inp2_4 = reinterpret_cast(inp2); - float4* out_4 = reinterpret_cast(out); - - CUDA_1D_KERNEL_LOOP(j, N) - { - float4 val; - float4 inp1_reg = inp1_4[j]; - float4 inp2_reg = inp2_4[j]; - - val.x = inp1_reg.x + inp2_reg.x; - val.y = inp1_reg.y + inp2_reg.y; - val.z = inp1_reg.z + inp2_reg.z; - val.w = inp1_reg.w + inp2_reg.w; - - out_4[j] = val; - } -} - -__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2) -{ - float2 inp1_4; - float2 inp2_4; - - __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4); - __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4); - - const float2* inp1_arr = reinterpret_cast(inp1); - const float2* inp2_arr = reinterpret_cast(inp2); - - CUDA_1D_KERNEL_LOOP(j, N) - { - inp1_4 = inp1_arr[j]; - inp2_4 = inp2_arr[j]; - - float2 inp1_h_f_0 = __half22float2(inp1_h[0]); - float2 inp1_h_f_1 = __half22float2(inp1_h[1]); - - float2 inp2_h_f_0 = __half22float2(inp2_h[0]); - float2 inp2_h_f_1 = __half22float2(inp2_h[1]); - - inp1_h_f_0.x += inp2_h_f_0.x; - inp1_h_f_0.y += inp2_h_f_0.y; - inp1_h_f_1.x += inp2_h_f_1.x; - inp1_h_f_1.y += inp2_h_f_1.y; - - float2 val_f; - __half2* val_h = reinterpret_cast<__half2*>(&val_f); - - val_h[0] = __float22half2_rn(inp1_h_f_0); - val_h[1] = __float22half2_rn(inp1_h_f_1); - - float2* out_4 = reinterpret_cast(out); - out_4[j] = val_f; - } -} - -template <> -void launch_fused_add2(float* out, - const float* inp1, - const float* inp2, - int batch_size, - int seq_length, - int hidden_dim, - cudaStream_t& stream) -{ - int total_count = batch_size * seq_length * hidden_dim / 4; - dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length); - - dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); - - fused_add2_kernel<<>>(total_count, out, inp1, inp2); -} - -template <> -void launch_fused_add2<__half>(__half* out, - const __half* inp1, - const __half* inp2, - int batch_size, - int seq_length, - int hidden_dim, - cudaStream_t& stream) -{ - int total_count = batch_size * seq_length * hidden_dim / 4; - dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length); - - dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); - - fused_add2_kernel<<>>(total_count, out, inp1, inp2); -} - -__global__ void fused_add3_kernel(float* out, - const float* inp1, - const float* inp2, - const float* inp3, - int size, - int row_stride) -{ - int row = blockIdx.x; - int id = threadIdx.x; - - const float4* inp1_4 = reinterpret_cast(inp1); - const float4* inp2_4 = reinterpret_cast(inp2); - const float4* inp3_4 = reinterpret_cast(inp3); - - float4* out_4 = reinterpret_cast(out); - - float4 val; - float4 inp1_reg = inp1_4[row * row_stride + id]; - float4 inp2_reg = inp2_4[row * row_stride + id]; - float4 inp3_reg = inp3_4[row * row_stride + id]; - - val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x; - val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y; - val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z; - val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w; - - out_4[row * row_stride + id] = val; -} - -__global__ void fused_add3_kernel(__half* out, - const __half* inp1, - const __half* inp2, - const __half* inp3, - int size, - int row_stride) -{ - int row = blockIdx.x; - int id = threadIdx.x; - const float2* inp1_arr = reinterpret_cast(inp1); - const float2* inp2_arr = reinterpret_cast(inp2); - const float2* inp3_arr = reinterpret_cast(inp3); - - float2 inp1_4 = inp1_arr[row * row_stride + id]; - float2 inp2_4 = inp2_arr[row * row_stride + id]; - float2 inp3_4 = inp3_arr[row * row_stride + id]; - - __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4); - __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4); - __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4); - - float2 inp1_h_f_0 = __half22float2(inp1_h[0]); - float2 inp1_h_f_1 = __half22float2(inp1_h[1]); - - float2 inp2_h_f_0 = __half22float2(inp2_h[0]); - float2 inp2_h_f_1 = __half22float2(inp2_h[1]); - - float2 inp3_h_f_0 = __half22float2(inp3_h[0]); - float2 inp3_h_f_1 = __half22float2(inp3_h[1]); - - inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x); - inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y); - inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x); - inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y); - - float2 val_f; - __half2* val_h = reinterpret_cast<__half2*>(&val_f); - - val_h[0] = __float22half2_rn(inp1_h_f_0); - val_h[1] = __float22half2_rn(inp1_h_f_1); - - float2* out_4 = reinterpret_cast(out); - out_4[row * row_stride + id] = val_f; -} - -template <> -void launch_fused_add3(float* out, - const float* inp1, - const float* inp2, - const float* inp3, - int batch_size, - int seq_length, - int hidden_size, - cudaStream_t& stream) -{ - dim3 grid_dim(batch_size * seq_length); - - dim3 block_dim(hidden_size / 4); - - fused_add3_kernel<<>>( - out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4); -} - -template <> -void launch_fused_add3<__half>(__half* out, - const __half* inp1, - const __half* inp2, - const __half* inp3, - int batch_size, - int seq_length, - int hidden_size, - cudaStream_t& stream) -{ - dim3 grid_dim(batch_size * seq_length); - - dim3 block_dim(hidden_size / 4); - - fused_add3_kernel<<>>( - out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4); -} - -__global__ void fused_add4_kernel(float* out, - const float* inp1, - const float* inp2, - const float* inp3, - const float* inp4, - int size, - int row_stride) -{ - int row = blockIdx.x; - int id = threadIdx.x; - - const float4* inp1_4 = reinterpret_cast(inp1); - const float4* inp2_4 = reinterpret_cast(inp2); - const float4* inp3_4 = reinterpret_cast(inp3); - const float4* inp4_4 = reinterpret_cast(inp4); - float4* out_4 = reinterpret_cast(out); - - float4 val; - float4 inp1_reg = inp1_4[row * row_stride + id]; - float4 inp2_reg = inp2_4[row * row_stride + id]; - float4 inp3_reg = inp3_4[row * row_stride + id]; - float4 inp4_reg = inp4_4[row * row_stride + id]; - - val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x; - val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y; - val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z; - val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w; - - out_4[row * row_stride + id] = val; -} - -__global__ void fused_add4_kernel(__half* out, - const __half* inp1, - const __half* inp2, - const __half* inp3, - const __half* inp4, - int size, - int row_stride) -{ - int row = blockIdx.x; - int id = threadIdx.x; - const float2* inp1_arr = reinterpret_cast(inp1); - const float2* inp2_arr = reinterpret_cast(inp2); - const float2* inp3_arr = reinterpret_cast(inp3); - const float2* inp4_arr = reinterpret_cast(inp4); - - float2 inp1_4 = inp1_arr[row * row_stride + id]; - float2 inp2_4 = inp2_arr[row * row_stride + id]; - float2 inp3_4 = inp3_arr[row * row_stride + id]; - float2 inp4_4 = inp4_arr[row * row_stride + id]; - - __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4); - __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4); - __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4); - __half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4); - - float2 inp1_h_f_0 = __half22float2(inp1_h[0]); - float2 inp1_h_f_1 = __half22float2(inp1_h[1]); - - float2 inp2_h_f_0 = __half22float2(inp2_h[0]); - float2 inp2_h_f_1 = __half22float2(inp2_h[1]); - - float2 inp3_h_f_0 = __half22float2(inp3_h[0]); - float2 inp3_h_f_1 = __half22float2(inp3_h[1]); - - float2 inp4_h_f_0 = __half22float2(inp4_h[0]); - float2 inp4_h_f_1 = __half22float2(inp4_h[1]); - - inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x); - inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y); - inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x); - inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y); - - float2 val_f; - __half2* val_h = reinterpret_cast<__half2*>(&val_f); - - val_h[0] = __float22half2_rn(inp1_h_f_0); - val_h[1] = __float22half2_rn(inp1_h_f_1); - - float2* out_4 = reinterpret_cast(out); - out_4[row * row_stride + id] = val_f; -} - -template <> -void launch_fused_add4(float* out, - const float* inp1, - const float* inp2, - const float* inp3, - const float* inp4, - int batch_size, - int seq_length, - int hidden_size, - cudaStream_t& stream) -{ - dim3 grid_dim(batch_size * seq_length); - - dim3 block_dim(hidden_size / 4); - - fused_add4_kernel<<>>( - out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4); -} - -template <> -void launch_fused_add4<__half>(__half* out, - const __half* inp1, - const __half* inp2, - const __half* inp3, - const __half* inp4, - int batch_size, - int seq_length, - int hidden_size, - cudaStream_t& stream) -{ - dim3 grid_dim(batch_size * seq_length); - - dim3 block_dim(hidden_size / 4); - - fused_add4_kernel<<>>( - out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4); -} +#include "general_kernels.h" + +namespace cg = cooperative_groups; + +template +__global__ void column_sum_reduce(const T* __restrict__ inp, + T* __restrict__ out, + int rows, + int width) +{ + __shared__ float tile[TILE_DIM][TILE_DIM + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + + int y_stride = width * TILE_DIM; + + float localSum = 0; + + // Loop across matrix height + if (idx < width) { + int offset = threadIdx.y * width + idx; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + localSum += (float)inp[offset]; + offset += y_stride; + } + } + + tile[threadIdx.x][threadIdx.y] = localSum; + + __syncthreads(); + + // Sum the shared buffer. + float sum = tile[threadIdx.y][threadIdx.x]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i); + + if (threadIdx.x == 0) { + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + if (pos < width) out[pos] = sum; + } +} + +template +void launch_fuse_transpose_bias_kernel(const T* inp, + T* out, + int rows, + int cols, + cudaStream_t stream); + +template <> +void launch_fuse_transpose_bias_kernel(const float* inp, + float* out, + int rows, + int cols, + cudaStream_t stream) +{ + // assert(rows % TILE_DIM == 0); + // assert(cols % TILE_DIM == 0); + + dim3 grid_dim((cols - 1) / TILE_DIM + 1); + dim3 block_dim(TILE_DIM, TILE_DIM); + + column_sum_reduce<<>>(inp, out, rows, cols); +} + +template <> +void launch_fuse_transpose_bias_kernel<__half>(const __half* inp, + __half* out, + int rows, + int cols, + cudaStream_t stream) +{ + // assert(rows % TILE_DIM == 0); + // assert(cols % TILE_DIM == 0); + + dim3 grid_dim((cols - 1) / TILE_DIM + 1); + dim3 block_dim(TILE_DIM, TILE_DIM); + + column_sum_reduce<__half><<>>(inp, out, rows, cols); +} + +__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2) +{ + const float4* inp1_4 = reinterpret_cast(inp1); + const float4* inp2_4 = reinterpret_cast(inp2); + float4* out_4 = reinterpret_cast(out); + + CUDA_1D_KERNEL_LOOP(j, N) + { + float4 val; + float4 inp1_reg = inp1_4[j]; + float4 inp2_reg = inp2_4[j]; + + val.x = inp1_reg.x + inp2_reg.x; + val.y = inp1_reg.y + inp2_reg.y; + val.z = inp1_reg.z + inp2_reg.z; + val.w = inp1_reg.w + inp2_reg.w; + + out_4[j] = val; + } +} + +__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2) +{ + float2 inp1_4; + float2 inp2_4; + + __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4); + __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4); + + const float2* inp1_arr = reinterpret_cast(inp1); + const float2* inp2_arr = reinterpret_cast(inp2); + + CUDA_1D_KERNEL_LOOP(j, N) + { + inp1_4 = inp1_arr[j]; + inp2_4 = inp2_arr[j]; + + float2 inp1_h_f_0 = __half22float2(inp1_h[0]); + float2 inp1_h_f_1 = __half22float2(inp1_h[1]); + + float2 inp2_h_f_0 = __half22float2(inp2_h[0]); + float2 inp2_h_f_1 = __half22float2(inp2_h[1]); + + inp1_h_f_0.x += inp2_h_f_0.x; + inp1_h_f_0.y += inp2_h_f_0.y; + inp1_h_f_1.x += inp2_h_f_1.x; + inp1_h_f_1.y += inp2_h_f_1.y; + + float2 val_f; + __half2* val_h = reinterpret_cast<__half2*>(&val_f); + + val_h[0] = __float22half2_rn(inp1_h_f_0); + val_h[1] = __float22half2_rn(inp1_h_f_1); + + float2* out_4 = reinterpret_cast(out); + out_4[j] = val_f; + } +} + +template <> +void launch_fused_add2(float* out, + const float* inp1, + const float* inp2, + int batch_size, + int seq_length, + int hidden_dim, + cudaStream_t& stream) +{ + int total_count = batch_size * seq_length * hidden_dim / 4; + dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length); + + dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); + + fused_add2_kernel<<>>(total_count, out, inp1, inp2); +} + +template <> +void launch_fused_add2<__half>(__half* out, + const __half* inp1, + const __half* inp2, + int batch_size, + int seq_length, + int hidden_dim, + cudaStream_t& stream) +{ + int total_count = batch_size * seq_length * hidden_dim / 4; + dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length); + + dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); + + fused_add2_kernel<<>>(total_count, out, inp1, inp2); +} + +__global__ void fused_add3_kernel(float* out, + const float* inp1, + const float* inp2, + const float* inp3, + int size, + int row_stride) +{ + int row = blockIdx.x; + int id = threadIdx.x; + + const float4* inp1_4 = reinterpret_cast(inp1); + const float4* inp2_4 = reinterpret_cast(inp2); + const float4* inp3_4 = reinterpret_cast(inp3); + + float4* out_4 = reinterpret_cast(out); + + float4 val; + float4 inp1_reg = inp1_4[row * row_stride + id]; + float4 inp2_reg = inp2_4[row * row_stride + id]; + float4 inp3_reg = inp3_4[row * row_stride + id]; + + val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x; + val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y; + val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z; + val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w; + + out_4[row * row_stride + id] = val; +} + +__global__ void fused_add3_kernel(__half* out, + const __half* inp1, + const __half* inp2, + const __half* inp3, + int size, + int row_stride) +{ + int row = blockIdx.x; + int id = threadIdx.x; + const float2* inp1_arr = reinterpret_cast(inp1); + const float2* inp2_arr = reinterpret_cast(inp2); + const float2* inp3_arr = reinterpret_cast(inp3); + + float2 inp1_4 = inp1_arr[row * row_stride + id]; + float2 inp2_4 = inp2_arr[row * row_stride + id]; + float2 inp3_4 = inp3_arr[row * row_stride + id]; + + __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4); + __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4); + __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4); + + float2 inp1_h_f_0 = __half22float2(inp1_h[0]); + float2 inp1_h_f_1 = __half22float2(inp1_h[1]); + + float2 inp2_h_f_0 = __half22float2(inp2_h[0]); + float2 inp2_h_f_1 = __half22float2(inp2_h[1]); + + float2 inp3_h_f_0 = __half22float2(inp3_h[0]); + float2 inp3_h_f_1 = __half22float2(inp3_h[1]); + + inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x); + inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y); + inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x); + inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y); + + float2 val_f; + __half2* val_h = reinterpret_cast<__half2*>(&val_f); + + val_h[0] = __float22half2_rn(inp1_h_f_0); + val_h[1] = __float22half2_rn(inp1_h_f_1); + + float2* out_4 = reinterpret_cast(out); + out_4[row * row_stride + id] = val_f; +} + +template <> +void launch_fused_add3(float* out, + const float* inp1, + const float* inp2, + const float* inp3, + int batch_size, + int seq_length, + int hidden_size, + cudaStream_t& stream) +{ + dim3 grid_dim(batch_size * seq_length); + + dim3 block_dim(hidden_size / 4); + + fused_add3_kernel<<>>( + out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4); +} + +template <> +void launch_fused_add3<__half>(__half* out, + const __half* inp1, + const __half* inp2, + const __half* inp3, + int batch_size, + int seq_length, + int hidden_size, + cudaStream_t& stream) +{ + dim3 grid_dim(batch_size * seq_length); + + dim3 block_dim(hidden_size / 4); + + fused_add3_kernel<<>>( + out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4); +} + +__global__ void fused_add4_kernel(float* out, + const float* inp1, + const float* inp2, + const float* inp3, + const float* inp4, + int size, + int row_stride) +{ + int row = blockIdx.x; + int id = threadIdx.x; + + const float4* inp1_4 = reinterpret_cast(inp1); + const float4* inp2_4 = reinterpret_cast(inp2); + const float4* inp3_4 = reinterpret_cast(inp3); + const float4* inp4_4 = reinterpret_cast(inp4); + float4* out_4 = reinterpret_cast(out); + + float4 val; + float4 inp1_reg = inp1_4[row * row_stride + id]; + float4 inp2_reg = inp2_4[row * row_stride + id]; + float4 inp3_reg = inp3_4[row * row_stride + id]; + float4 inp4_reg = inp4_4[row * row_stride + id]; + + val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x; + val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y; + val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z; + val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w; + + out_4[row * row_stride + id] = val; +} + +__global__ void fused_add4_kernel(__half* out, + const __half* inp1, + const __half* inp2, + const __half* inp3, + const __half* inp4, + int size, + int row_stride) +{ + int row = blockIdx.x; + int id = threadIdx.x; + const float2* inp1_arr = reinterpret_cast(inp1); + const float2* inp2_arr = reinterpret_cast(inp2); + const float2* inp3_arr = reinterpret_cast(inp3); + const float2* inp4_arr = reinterpret_cast(inp4); + + float2 inp1_4 = inp1_arr[row * row_stride + id]; + float2 inp2_4 = inp2_arr[row * row_stride + id]; + float2 inp3_4 = inp3_arr[row * row_stride + id]; + float2 inp4_4 = inp4_arr[row * row_stride + id]; + + __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4); + __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4); + __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4); + __half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4); + + float2 inp1_h_f_0 = __half22float2(inp1_h[0]); + float2 inp1_h_f_1 = __half22float2(inp1_h[1]); + + float2 inp2_h_f_0 = __half22float2(inp2_h[0]); + float2 inp2_h_f_1 = __half22float2(inp2_h[1]); + + float2 inp3_h_f_0 = __half22float2(inp3_h[0]); + float2 inp3_h_f_1 = __half22float2(inp3_h[1]); + + float2 inp4_h_f_0 = __half22float2(inp4_h[0]); + float2 inp4_h_f_1 = __half22float2(inp4_h[1]); + + inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x); + inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y); + inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x); + inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y); + + float2 val_f; + __half2* val_h = reinterpret_cast<__half2*>(&val_f); + + val_h[0] = __float22half2_rn(inp1_h_f_0); + val_h[1] = __float22half2_rn(inp1_h_f_1); + + float2* out_4 = reinterpret_cast(out); + out_4[row * row_stride + id] = val_f; +} + +template <> +void launch_fused_add4(float* out, + const float* inp1, + const float* inp2, + const float* inp3, + const float* inp4, + int batch_size, + int seq_length, + int hidden_size, + cudaStream_t& stream) +{ + dim3 grid_dim(batch_size * seq_length); + + dim3 block_dim(hidden_size / 4); + + fused_add4_kernel<<>>( + out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4); +} + +template <> +void launch_fused_add4<__half>(__half* out, + const __half* inp1, + const __half* inp2, + const __half* inp3, + const __half* inp4, + int batch_size, + int seq_length, + int hidden_size, + cudaStream_t& stream) +{ + dim3 grid_dim(batch_size * seq_length); + + dim3 block_dim(hidden_size / 4); + + fused_add4_kernel<<>>( + out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4); +} diff --git a/csrc/transformer/inference/csrc/dequantize.cu b/csrc/transformer/inference/csrc/dequantize.cu index ddf7a9588..0fc15d0fb 100644 --- a/csrc/transformer/inference/csrc/dequantize.cu +++ b/csrc/transformer/inference/csrc/dequantize.cu @@ -1,110 +1,110 @@ -#include "custom_cuda_layers.h" - -#define MAX_QUANTIZE_GROUPING 1024 - -#define loop_unroll 1 -#define loop_unroll_bits 1 - -__global__ void dequantize_kernel(float* output, - const int8_t* input, - const float* qscale, - int output_size, - int hidden_dim, - int groups, - int merge_count) -{ - unsigned merge_hidden = hidden_dim >> merge_count; - unsigned quantization_stride = (merge_hidden * output_size) / groups; - - unsigned bid = blockIdx.x; - unsigned tid = threadIdx.x; - - while (tid < output_size) { - unsigned w_index = bid / merge_hidden; - unsigned q_index = tid + bid * output_size; - - auto q = input[q_index]; - - unsigned merge_hidden_total = w_index * merge_hidden; - unsigned scale_index = - ((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride) - << merge_count) + - w_index; - - float scale_data = qscale[scale_index]; - - output[q_index] = (scale_data * (float)q); - tid += blockDim.x; - } -} - -__global__ void dequantize_kernel(__half* output, - const int8_t* input, - const float* qscale, - unsigned output_size, - unsigned hidden_dim, - unsigned groups, - unsigned merge_count) -{ -#if __CUDA_ARCH__ >= 700 - - unsigned merge_hidden = hidden_dim >> merge_count; - unsigned quantization_stride = (merge_hidden * output_size) / groups; - - unsigned bid = blockIdx.x; - unsigned tid = threadIdx.x; - - while (tid < output_size) { - unsigned w_index = bid / merge_hidden; - unsigned q_index = tid + bid * output_size; - - auto q = input[q_index]; - - unsigned merge_hidden_total = w_index * merge_hidden; - unsigned scale_index = - ((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride) - << merge_count) + - w_index; - - float scale_data = qscale[scale_index]; - - output[q_index] = __float2half(scale_data * (float)q); - tid += blockDim.x; - } -#endif -} - -template -void launch_dequantize(T* output, - const int8_t* input, - const float* qscale, - unsigned output_size, - unsigned hidden_dim, - unsigned groups, - unsigned merge_count, - cudaStream_t stream) -{ - unsigned threads = 1024; - dim3 block_dims(threads); - dim3 grid_dims(hidden_dim); - - dequantize_kernel<<>>( - output, input, qscale, output_size, hidden_dim, groups, merge_count); -} - -template void launch_dequantize(float*, - const int8_t*, - const float*, - unsigned, - unsigned, - unsigned, - unsigned, - cudaStream_t); -template void launch_dequantize<__half>(__half*, - const int8_t*, - const float*, - unsigned, - unsigned, - unsigned, - unsigned, - cudaStream_t); +#include "custom_cuda_layers.h" + +#define MAX_QUANTIZE_GROUPING 1024 + +#define loop_unroll 1 +#define loop_unroll_bits 1 + +__global__ void dequantize_kernel(float* output, + const int8_t* input, + const float* qscale, + int output_size, + int hidden_dim, + int groups, + int merge_count) +{ + unsigned merge_hidden = hidden_dim >> merge_count; + unsigned quantization_stride = (merge_hidden * output_size) / groups; + + unsigned bid = blockIdx.x; + unsigned tid = threadIdx.x; + + while (tid < output_size) { + unsigned w_index = bid / merge_hidden; + unsigned q_index = tid + bid * output_size; + + auto q = input[q_index]; + + unsigned merge_hidden_total = w_index * merge_hidden; + unsigned scale_index = + ((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride) + << merge_count) + + w_index; + + float scale_data = qscale[scale_index]; + + output[q_index] = (scale_data * (float)q); + tid += blockDim.x; + } +} + +__global__ void dequantize_kernel(__half* output, + const int8_t* input, + const float* qscale, + unsigned output_size, + unsigned hidden_dim, + unsigned groups, + unsigned merge_count) +{ +#if __CUDA_ARCH__ >= 700 + + unsigned merge_hidden = hidden_dim >> merge_count; + unsigned quantization_stride = (merge_hidden * output_size) / groups; + + unsigned bid = blockIdx.x; + unsigned tid = threadIdx.x; + + while (tid < output_size) { + unsigned w_index = bid / merge_hidden; + unsigned q_index = tid + bid * output_size; + + auto q = input[q_index]; + + unsigned merge_hidden_total = w_index * merge_hidden; + unsigned scale_index = + ((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride) + << merge_count) + + w_index; + + float scale_data = qscale[scale_index]; + + output[q_index] = __float2half(scale_data * (float)q); + tid += blockDim.x; + } +#endif +} + +template +void launch_dequantize(T* output, + const int8_t* input, + const float* qscale, + unsigned output_size, + unsigned hidden_dim, + unsigned groups, + unsigned merge_count, + cudaStream_t stream) +{ + unsigned threads = 1024; + dim3 block_dims(threads); + dim3 grid_dims(hidden_dim); + + dequantize_kernel<<>>( + output, input, qscale, output_size, hidden_dim, groups, merge_count); +} + +template void launch_dequantize(float*, + const int8_t*, + const float*, + unsigned, + unsigned, + unsigned, + unsigned, + cudaStream_t); +template void launch_dequantize<__half>(__half*, + const int8_t*, + const float*, + unsigned, + unsigned, + unsigned, + unsigned, + cudaStream_t); diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index fc3faacc5..10adaa6fe 100755 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -1,266 +1,266 @@ -#include "custom_cuda_layers.h" - -inline __device__ float gelu(const float x) -{ - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x))); -} - -__global__ void fused_bias_gelu(float* input, - const float* bias, - int total_count, - int intermediate_size) -{ - float4* input_cast = reinterpret_cast(input); - const float4* bias_cast = reinterpret_cast(bias); - int offset = blockIdx.x * blockDim.x + threadIdx.x; - - if (offset < total_count) { - float4 data = input_cast[offset]; - float4 bias_data = bias_cast[offset % intermediate_size]; - - data.x += bias_data.x; - data.y += bias_data.y; - data.z += bias_data.z; - data.w += bias_data.w; - - data.x = gelu(data.x); - data.y = gelu(data.y); - data.z = gelu(data.z); - data.w = gelu(data.w); - - input_cast[offset] = data; - } -} - -__global__ void fused_bias_gelu(__half* input, - const __half* bias, - int total_count, - int intermediate_size) -{ -#if __CUDA_ARCH__ >= 700 - - float2* input_cast = reinterpret_cast(input); - const float2* bias_cast = reinterpret_cast(bias); - - int offset = blockIdx.x * blockDim.x + threadIdx.x; - - if (offset < total_count) { - float2 vals_vec = input_cast[offset]; - float2 bias_vec = bias_cast[offset % intermediate_size]; - - __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); - __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); - - float2 low_data = __half22float2(vals_half[0]); - float2 high_data = __half22float2(vals_half[1]); - - float2 low_bias = __half22float2(bias_half[0]); - float2 high_bias = __half22float2(bias_half[1]); - - low_data.x += low_bias.x; - low_data.y += low_bias.y; - high_data.x += high_bias.x; - high_data.y += high_bias.y; - - low_data.x = gelu(low_data.x); - low_data.y = gelu(low_data.y); - high_data.x = gelu(high_data.x); - high_data.y = gelu(high_data.y); - - vals_half[0] = __float22half2_rn(low_data); - vals_half[1] = __float22half2_rn(high_data); - - input_cast[offset] = vals_vec; - } -#endif -} - -template -void launch_bias_gelu(T* input, - const T* bias, - int intermediate_size, - int batch_size, - cudaStream_t stream) -{ - int total_count = batch_size * (intermediate_size / 4); - int threads = 1024; // intermediate_size / iterations / 4; - dim3 block_dims(threads); - dim3 grid_dims(((total_count - 1) / 1024 + 1)); // (batch_size); - - fused_bias_gelu<<>>( - input, bias, total_count, intermediate_size / 4); -} - -template void launch_bias_gelu(float*, const float*, int, int, cudaStream_t); -template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t); - -__global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size) -{ - float4* input_cast = reinterpret_cast(input); - const float4* bias_cast = reinterpret_cast(bias); - int offset = blockIdx.x * blockDim.x + threadIdx.x; - - if (offset < total_count) { - float4 data = input_cast[offset]; - float4 bias_data = bias_cast[offset % hidden_size]; - - data.x += bias_data.x; - data.y += bias_data.y; - data.z += bias_data.z; - data.w += bias_data.w; - - input_cast[offset] = data; - } -} - -__global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size) -{ -#if __CUDA_ARCH__ >= 700 - - float2* input_cast = reinterpret_cast(input); - const float2* bias_cast = reinterpret_cast(bias); - - int offset = blockIdx.x * blockDim.x + threadIdx.x; - - if (offset < total_count) { - float2 vals_vec = input_cast[offset]; - float2 bias_vec = bias_cast[offset % hidden_size]; - - __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); - __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); - - float2 low_data = __half22float2(vals_half[0]); - float2 high_data = __half22float2(vals_half[1]); - - float2 low_bias = __half22float2(bias_half[0]); - float2 high_bias = __half22float2(bias_half[1]); - - low_data.x += low_bias.x; - low_data.y += low_bias.y; - high_data.x += high_bias.x; - high_data.y += high_bias.y; - - vals_half[0] = __float22half2_rn(low_data); - vals_half[1] = __float22half2_rn(high_data); - - input_cast[offset] = vals_vec; - } -#endif -} - -template -void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream) -{ - int total_count = batch_size * (hidden_size / 4); - int threads = 1024; // hidden_size / iterations / 4; - dim3 block_dims(threads); - dim3 grid_dims(((total_count - 1) / threads + 1)); // (batch_size); - - fused_bias_add<<>>(input, bias, total_count, hidden_size / 4); -} - -template void launch_bias_add(float*, const float*, int, int, cudaStream_t); -template void launch_bias_add<__half>(__half*, const __half*, int, int, cudaStream_t); - -__global__ void fused_bias_residual(float* input, - const float* residual, - const float* bias, - int total_count, - int intermediate_size) -{ - float4* input_cast = reinterpret_cast(input); - const float4* residual_cast = reinterpret_cast(residual); - const float4* bias_cast = reinterpret_cast(bias); - int offset = blockIdx.x * blockDim.x + threadIdx.x; - - if (offset < total_count) { - float4 data = input_cast[offset]; - float4 res_vec = residual_cast[offset]; - float4 bias_data = bias_cast[offset % intermediate_size]; - - data.x += (res_vec.x + bias_data.x); - data.y += (res_vec.y + bias_data.y); - data.z += (res_vec.z + bias_data.z); - data.w += (res_vec.w + bias_data.w); - - input_cast[offset] = data; - } -} - -__global__ void fused_bias_residual(__half* input, - const __half* residual, - const __half* bias, - int total_count, - int intermediate_size) -{ -#if __CUDA_ARCH__ >= 700 - - float2* input_cast = reinterpret_cast(input); - const float2* residual_cast = reinterpret_cast(residual); - - const float2* bias_cast = reinterpret_cast(bias); - - int offset = blockIdx.x * blockDim.x + threadIdx.x; - - if (offset < total_count) { - float2 vals_vec = input_cast[offset]; - float2 res_vec = residual_cast[offset]; - - float2 bias_vec = bias_cast[offset % intermediate_size]; - - __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); - __half2* res_half = reinterpret_cast<__half2*>(&res_vec); - __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); - - float2 low_data = __half22float2(vals_half[0]); - float2 high_data = __half22float2(vals_half[1]); - - float2 low_res = __half22float2(res_half[0]); - float2 high_res = __half22float2(res_half[1]); - - float2 low_bias = __half22float2(bias_half[0]); - float2 high_bias = __half22float2(bias_half[1]); - - low_data.x += (low_res.x + low_bias.x); - low_data.y += (low_res.y + low_bias.y); - high_data.x += (high_res.x + high_bias.x); - high_data.y += (high_res.y + high_bias.y); - - vals_half[0] = __float22half2_rn(low_data); - vals_half[1] = __float22half2_rn(high_data); - - input_cast[offset] = vals_vec; - } -#endif -} - -template -void launch_bias_residual(T* input, - const T* residual, - const T* bias, - int batch, - int intermediate_size, - cudaStream_t stream) -{ - int total_count = batch * intermediate_size / 4; - dim3 block_dims(1024); - dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size); - - fused_bias_residual<<>>( - input, residual, bias, total_count, intermediate_size / 4); -} - -template void launch_bias_residual(float*, - const float*, - const float*, - int, - int, - cudaStream_t); -template void launch_bias_residual<__half>(__half*, - const __half*, - const __half*, - int, - int, - cudaStream_t); +#include "custom_cuda_layers.h" + +inline __device__ float gelu(const float x) +{ + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x))); +} + +__global__ void fused_bias_gelu(float* input, + const float* bias, + int total_count, + int intermediate_size) +{ + float4* input_cast = reinterpret_cast(input); + const float4* bias_cast = reinterpret_cast(bias); + int offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (offset < total_count) { + float4 data = input_cast[offset]; + float4 bias_data = bias_cast[offset % intermediate_size]; + + data.x += bias_data.x; + data.y += bias_data.y; + data.z += bias_data.z; + data.w += bias_data.w; + + data.x = gelu(data.x); + data.y = gelu(data.y); + data.z = gelu(data.z); + data.w = gelu(data.w); + + input_cast[offset] = data; + } +} + +__global__ void fused_bias_gelu(__half* input, + const __half* bias, + int total_count, + int intermediate_size) +{ +#if __CUDA_ARCH__ >= 700 + + float2* input_cast = reinterpret_cast(input); + const float2* bias_cast = reinterpret_cast(bias); + + int offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (offset < total_count) { + float2 vals_vec = input_cast[offset]; + float2 bias_vec = bias_cast[offset % intermediate_size]; + + __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); + __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); + + float2 low_data = __half22float2(vals_half[0]); + float2 high_data = __half22float2(vals_half[1]); + + float2 low_bias = __half22float2(bias_half[0]); + float2 high_bias = __half22float2(bias_half[1]); + + low_data.x += low_bias.x; + low_data.y += low_bias.y; + high_data.x += high_bias.x; + high_data.y += high_bias.y; + + low_data.x = gelu(low_data.x); + low_data.y = gelu(low_data.y); + high_data.x = gelu(high_data.x); + high_data.y = gelu(high_data.y); + + vals_half[0] = __float22half2_rn(low_data); + vals_half[1] = __float22half2_rn(high_data); + + input_cast[offset] = vals_vec; + } +#endif +} + +template +void launch_bias_gelu(T* input, + const T* bias, + int intermediate_size, + int batch_size, + cudaStream_t stream) +{ + int total_count = batch_size * (intermediate_size / 4); + int threads = 1024; // intermediate_size / iterations / 4; + dim3 block_dims(threads); + dim3 grid_dims(((total_count - 1) / 1024 + 1)); // (batch_size); + + fused_bias_gelu<<>>( + input, bias, total_count, intermediate_size / 4); +} + +template void launch_bias_gelu(float*, const float*, int, int, cudaStream_t); +template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t); + +__global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size) +{ + float4* input_cast = reinterpret_cast(input); + const float4* bias_cast = reinterpret_cast(bias); + int offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (offset < total_count) { + float4 data = input_cast[offset]; + float4 bias_data = bias_cast[offset % hidden_size]; + + data.x += bias_data.x; + data.y += bias_data.y; + data.z += bias_data.z; + data.w += bias_data.w; + + input_cast[offset] = data; + } +} + +__global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size) +{ +#if __CUDA_ARCH__ >= 700 + + float2* input_cast = reinterpret_cast(input); + const float2* bias_cast = reinterpret_cast(bias); + + int offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (offset < total_count) { + float2 vals_vec = input_cast[offset]; + float2 bias_vec = bias_cast[offset % hidden_size]; + + __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); + __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); + + float2 low_data = __half22float2(vals_half[0]); + float2 high_data = __half22float2(vals_half[1]); + + float2 low_bias = __half22float2(bias_half[0]); + float2 high_bias = __half22float2(bias_half[1]); + + low_data.x += low_bias.x; + low_data.y += low_bias.y; + high_data.x += high_bias.x; + high_data.y += high_bias.y; + + vals_half[0] = __float22half2_rn(low_data); + vals_half[1] = __float22half2_rn(high_data); + + input_cast[offset] = vals_vec; + } +#endif +} + +template +void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream) +{ + int total_count = batch_size * (hidden_size / 4); + int threads = 1024; // hidden_size / iterations / 4; + dim3 block_dims(threads); + dim3 grid_dims(((total_count - 1) / threads + 1)); // (batch_size); + + fused_bias_add<<>>(input, bias, total_count, hidden_size / 4); +} + +template void launch_bias_add(float*, const float*, int, int, cudaStream_t); +template void launch_bias_add<__half>(__half*, const __half*, int, int, cudaStream_t); + +__global__ void fused_bias_residual(float* input, + const float* residual, + const float* bias, + int total_count, + int intermediate_size) +{ + float4* input_cast = reinterpret_cast(input); + const float4* residual_cast = reinterpret_cast(residual); + const float4* bias_cast = reinterpret_cast(bias); + int offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (offset < total_count) { + float4 data = input_cast[offset]; + float4 res_vec = residual_cast[offset]; + float4 bias_data = bias_cast[offset % intermediate_size]; + + data.x += (res_vec.x + bias_data.x); + data.y += (res_vec.y + bias_data.y); + data.z += (res_vec.z + bias_data.z); + data.w += (res_vec.w + bias_data.w); + + input_cast[offset] = data; + } +} + +__global__ void fused_bias_residual(__half* input, + const __half* residual, + const __half* bias, + int total_count, + int intermediate_size) +{ +#if __CUDA_ARCH__ >= 700 + + float2* input_cast = reinterpret_cast(input); + const float2* residual_cast = reinterpret_cast(residual); + + const float2* bias_cast = reinterpret_cast(bias); + + int offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (offset < total_count) { + float2 vals_vec = input_cast[offset]; + float2 res_vec = residual_cast[offset]; + + float2 bias_vec = bias_cast[offset % intermediate_size]; + + __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); + __half2* res_half = reinterpret_cast<__half2*>(&res_vec); + __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); + + float2 low_data = __half22float2(vals_half[0]); + float2 high_data = __half22float2(vals_half[1]); + + float2 low_res = __half22float2(res_half[0]); + float2 high_res = __half22float2(res_half[1]); + + float2 low_bias = __half22float2(bias_half[0]); + float2 high_bias = __half22float2(bias_half[1]); + + low_data.x += (low_res.x + low_bias.x); + low_data.y += (low_res.y + low_bias.y); + high_data.x += (high_res.x + high_bias.x); + high_data.y += (high_res.y + high_bias.y); + + vals_half[0] = __float22half2_rn(low_data); + vals_half[1] = __float22half2_rn(high_data); + + input_cast[offset] = vals_vec; + } +#endif +} + +template +void launch_bias_residual(T* input, + const T* residual, + const T* bias, + int batch, + int intermediate_size, + cudaStream_t stream) +{ + int total_count = batch * intermediate_size / 4; + dim3 block_dims(1024); + dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size); + + fused_bias_residual<<>>( + input, residual, bias, total_count, intermediate_size / 4); +} + +template void launch_bias_residual(float*, + const float*, + const float*, + int, + int, + cudaStream_t); +template void launch_bias_residual<__half>(__half*, + const __half*, + const __half*, + int, + int, + cudaStream_t); diff --git a/csrc/transformer/inference/csrc/normalize.cu b/csrc/transformer/inference/csrc/normalize.cu index ecd73154f..23cceff37 100755 --- a/csrc/transformer/inference/csrc/normalize.cu +++ b/csrc/transformer/inference/csrc/normalize.cu @@ -1,426 +1,426 @@ -#include -#include "custom_cuda_layers.h" - -#include -#include -#include -#include - -#define NORM_REG (MAX_REGISTERS) - -namespace cg = cooperative_groups; - -__global__ void fused_bias_residual_layer_norm(float* output, - const float* vals, - const float* gamma, - const float* beta, - float epsilon, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int gid = id >> 5; - int warp_num = iteration_stride >> 5; - - float inp_reg[NORM_REG]; - - int k = 0; - float sum = 0; - int input_id = id; - while (input_id < row_stride) { - inp_reg[k] = vals[input_id + row * row_stride]; - sum += inp_reg[k++]; - input_id += iteration_stride; - } - - for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); - - __shared__ float shr[MAX_WARP_NUM]; - - if (g.thread_rank() == 0) shr[gid] = sum; - b.sync(); - - if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; - b.sync(); - - for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - - float mean = sum / (row_stride); - sum = 0.f; - for (int f = 0; f < k; f++) { - inp_reg[f] -= mean; - sum += inp_reg[f] * inp_reg[f]; - } - for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); - if (g.thread_rank() == 0) shr[gid] = sum; - b.sync(); - - if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; - b.sync(); - - for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - sum /= (row_stride); - sum += epsilon; - sum = __frsqrt_rn(sum); - for (int f = 0; f < k; f++) { - int out_id = f * iteration_stride + id; - inp_reg[f] = inp_reg[f] * sum; - inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id]; - output[out_id + row * row_stride] = inp_reg[f]; - } -} - -__global__ void fused_bias_residual_layer_norm(__half* output, - const __half* vals, - const __half* gamma, - const __half* beta, - float epsilon, - int row_stride) -{ -#if __CUDA_ARCH__ >= 700 - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int gid = id >> 5; - int warp_num = iteration_stride >> 5; - - __half2 inp_reg[NORM_REG]; - - const __half2* vals_cast = reinterpret_cast(vals); - __half2* out_cast = reinterpret_cast<__half2*>(output); - - int k = 0; - int input_id = id; - while (input_id < row_stride) { - inp_reg[k++] = vals_cast[input_id + row * row_stride]; - input_id += iteration_stride; - } - float sum = 0; - for (int f = k - 1; f >= 0; f--) { - float2 inp_f = __half22float2(inp_reg[f]); - sum += inp_f.x + inp_f.y; - } - for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); - __shared__ float shr[MAX_WARP_NUM]; - if (g.thread_rank() == 0) shr[gid] = sum; - b.sync(); - if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; - b.sync(); - for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - float mean = sum / (row_stride << 1); - sum = 0.f; - for (int f = 0; f < k; f++) { - float2 inp_f = __half22float2(inp_reg[f]); - inp_f.x -= mean; - inp_f.y -= mean; - inp_reg[f] = __float22half2_rn(inp_f); - sum += inp_f.x * inp_f.x; - sum += inp_f.y * inp_f.y; - } - for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); - if (g.thread_rank() == 0) shr[gid] = sum; - b.sync(); - if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; - b.sync(); - for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - sum /= (row_stride << 1); - sum += epsilon; - sum = __frsqrt_rn(sum); - __half2 variance_h = __float2half2_rn(sum); - const __half2* gamma_cast = reinterpret_cast(gamma); - const __half2* beta_cast = reinterpret_cast(beta); - for (int f = 0; f < k; f++) { - int out_id = f * iteration_stride + id; - inp_reg[f] = inp_reg[f] * variance_h; - inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id]; - out_cast[out_id + row * row_stride] = inp_reg[f]; - } -#endif -} - -template -void launch_layer_norm(T* out, - T* vals, - const T* gamma, - const T* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream); - -template <> -void launch_layer_norm(float* out, - float* vals, - const float* gamma, - const float* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream) -{ - constexpr int threads = 1024; - - dim3 grid_dim(batch_size); - - dim3 block_dim(threads); - - fused_bias_residual_layer_norm<<>>( - out, vals, gamma, beta, epsilon, hidden_dim); -} - -template <> -void launch_layer_norm<__half>(__half* out, - __half* vals, - const __half* gamma, - const __half* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream) -{ - constexpr int threads = 1024; - - dim3 grid_dim(batch_size); - dim3 block_dim(threads); - - fused_bias_residual_layer_norm<<>>( - out, vals, gamma, beta, epsilon, hidden_dim / 2); -} - -__global__ void fused_residual_layer_norm(float* norm, - float* res_add, - float* vals, - float* residual, - const float* bias, - const float* gamma, - const float* beta, - float epsilon, - int row_stride, - bool preLN) -{ - int iteration_stride = blockDim.x; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int gid = id >> 5; - int warp_num = iteration_stride >> 5; - - float inp_reg[NORM_REG]; - - int k = 0; - int input_id = id; - - float sum = 0; - while (input_id < row_stride) { - inp_reg[k] = vals[input_id + row * row_stride]; - float res_f = (residual[input_id + row * row_stride]); - float bias_f = (bias[input_id]); - inp_reg[k] += res_f + bias_f; - if (preLN) res_add[input_id + row * row_stride] = inp_reg[k]; - sum += inp_reg[k++]; - input_id += iteration_stride; - } - for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); - - __shared__ float shr[MAX_WARP_NUM]; - if (g.thread_rank() == 0) shr[gid] = sum; - b.sync(); - - if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; - b.sync(); - - for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - float mean = sum / (row_stride); - sum = 0.f; - for (int f = 0; f < k; f++) { - inp_reg[f] -= mean; - sum += inp_reg[f] * inp_reg[f]; - } - for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); - if (g.thread_rank() == 0) shr[gid] = sum; - b.sync(); - - if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; - b.sync(); - - for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - sum /= (row_stride); - sum += epsilon; - sum = __frsqrt_rn(sum); - - for (int f = 0; f < k; f++) { - int out_id = f * iteration_stride + id; - inp_reg[f] = inp_reg[f] * sum; - inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id]; - norm[out_id + row * row_stride] = inp_reg[f]; - } -} - -__global__ void fused_residual_layer_norm(__half* norm, - __half* res_add, - __half* vals, - __half* residual, - const __half* bias, - const __half* gamma, - const __half* beta, - float epsilon, - int row_stride, - bool preLN) -{ -#if __CUDA_ARCH__ >= 700 - int iteration_stride = blockDim.x; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int gid = id >> 5; - int warp_num = iteration_stride >> 5; - - __half2 inp_reg[NORM_REG]; - - __half2* vals_cast = reinterpret_cast<__half2*>(vals); - __half2* norm_cast = reinterpret_cast<__half2*>(norm); - __half2* res_add_cast = reinterpret_cast<__half2*>(res_add); - __half2* residual_cast = reinterpret_cast<__half2*>(residual); - const __half2* bias_cast = reinterpret_cast(bias); - - int k = 0; - int input_id = id; - - float sum = 0; - while (input_id < row_stride) { - inp_reg[k] = vals_cast[input_id + row * row_stride]; - float2 inp_f = __half22float2(inp_reg[k]); - float2 res_f = __half22float2(residual_cast[input_id + row * row_stride]); - float2 bias_f = __half22float2(bias_cast[input_id]); - inp_f.x += res_f.x + bias_f.x; - inp_f.y += res_f.y + bias_f.y; - inp_reg[k] = __float22half2_rn(inp_f); - - if (preLN) res_add_cast[input_id + row * row_stride] = inp_reg[k]; - sum += inp_f.x + inp_f.y; - input_id += iteration_stride; - k++; - } - for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); - __shared__ float shr[MAX_WARP_NUM]; - if (g.thread_rank() == 0) shr[gid] = sum; - b.sync(); - if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; - b.sync(); - for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - float mean = sum / (row_stride << 1); - sum = 0.f; - for (int f = 0; f < k; f++) { - float2 inp_f = __half22float2(inp_reg[f]); - inp_f.x -= mean; - inp_f.y -= mean; - inp_reg[f] = __float22half2_rn(inp_f); - sum += inp_f.x * inp_f.x; - sum += inp_f.y * inp_f.y; - } - for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); - if (g.thread_rank() == 0) shr[gid] = sum; - b.sync(); - if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; - b.sync(); - for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - sum /= (row_stride << 1); - sum += epsilon; - sum = __frsqrt_rn(sum); - __half2 variance_h = __float2half2_rn(sum); - const __half2* gamma_cast = reinterpret_cast(gamma); - const __half2* beta_cast = reinterpret_cast(beta); - for (int f = 0; f < k; f++) { - int out_id = f * iteration_stride + id; - inp_reg[f] = inp_reg[f] * variance_h; - inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id]; - norm_cast[out_id + row * row_stride] = inp_reg[f]; - } -#endif -} - -template -void launch_residual_layer_norm(T* norm, - T* res_add, - T* vals, - T* residual, - const T* bias, - const T* gamma, - const T* beta, - float epsilon, - int batch_size, - int hidden_dim, - bool preLN, - cudaStream_t stream); - -template <> -void launch_residual_layer_norm(float* norm, - float* res_add, - float* vals, - float* residual, - const float* bias, - const float* gamma, - const float* beta, - float epsilon, - int batch_size, - int hidden_dim, - bool preLN, - cudaStream_t stream) -{ - constexpr int threads = 1024; - - dim3 grid_dim(batch_size); - - dim3 block_dim(threads); - - fused_residual_layer_norm<<>>( - norm, res_add, vals, residual, bias, gamma, beta, epsilon, hidden_dim, preLN); -} - -template <> -void launch_residual_layer_norm<__half>(__half* norm, - __half* res_add, - __half* vals, - __half* residual, - const __half* bias, - const __half* gamma, - const __half* beta, - float epsilon, - int batch_size, - int hidden_dim, - bool preLN, - cudaStream_t stream) -{ - constexpr int threads = 1024; - - dim3 grid_dim(batch_size); - dim3 block_dim(threads); - - fused_residual_layer_norm<<>>( - norm, res_add, vals, residual, bias, gamma, beta, epsilon, hidden_dim / 2, preLN); -} +#include +#include "custom_cuda_layers.h" + +#include +#include +#include +#include + +#define NORM_REG (MAX_REGISTERS) + +namespace cg = cooperative_groups; + +__global__ void fused_bias_residual_layer_norm(float* output, + const float* vals, + const float* gamma, + const float* beta, + float epsilon, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id >> 5; + int warp_num = iteration_stride >> 5; + + float inp_reg[NORM_REG]; + + int k = 0; + float sum = 0; + int input_id = id; + while (input_id < row_stride) { + inp_reg[k] = vals[input_id + row * row_stride]; + sum += inp_reg[k++]; + input_id += iteration_stride; + } + + for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); + + __shared__ float shr[MAX_WARP_NUM]; + + if (g.thread_rank() == 0) shr[gid] = sum; + b.sync(); + + if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; + b.sync(); + + for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + + float mean = sum / (row_stride); + sum = 0.f; + for (int f = 0; f < k; f++) { + inp_reg[f] -= mean; + sum += inp_reg[f] * inp_reg[f]; + } + for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); + if (g.thread_rank() == 0) shr[gid] = sum; + b.sync(); + + if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; + b.sync(); + + for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= (row_stride); + sum += epsilon; + sum = __frsqrt_rn(sum); + for (int f = 0; f < k; f++) { + int out_id = f * iteration_stride + id; + inp_reg[f] = inp_reg[f] * sum; + inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id]; + output[out_id + row * row_stride] = inp_reg[f]; + } +} + +__global__ void fused_bias_residual_layer_norm(__half* output, + const __half* vals, + const __half* gamma, + const __half* beta, + float epsilon, + int row_stride) +{ +#if __CUDA_ARCH__ >= 700 + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id >> 5; + int warp_num = iteration_stride >> 5; + + __half2 inp_reg[NORM_REG]; + + const __half2* vals_cast = reinterpret_cast(vals); + __half2* out_cast = reinterpret_cast<__half2*>(output); + + int k = 0; + int input_id = id; + while (input_id < row_stride) { + inp_reg[k++] = vals_cast[input_id + row * row_stride]; + input_id += iteration_stride; + } + float sum = 0; + for (int f = k - 1; f >= 0; f--) { + float2 inp_f = __half22float2(inp_reg[f]); + sum += inp_f.x + inp_f.y; + } + for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); + __shared__ float shr[MAX_WARP_NUM]; + if (g.thread_rank() == 0) shr[gid] = sum; + b.sync(); + if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; + b.sync(); + for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + float mean = sum / (row_stride << 1); + sum = 0.f; + for (int f = 0; f < k; f++) { + float2 inp_f = __half22float2(inp_reg[f]); + inp_f.x -= mean; + inp_f.y -= mean; + inp_reg[f] = __float22half2_rn(inp_f); + sum += inp_f.x * inp_f.x; + sum += inp_f.y * inp_f.y; + } + for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); + if (g.thread_rank() == 0) shr[gid] = sum; + b.sync(); + if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; + b.sync(); + for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= (row_stride << 1); + sum += epsilon; + sum = __frsqrt_rn(sum); + __half2 variance_h = __float2half2_rn(sum); + const __half2* gamma_cast = reinterpret_cast(gamma); + const __half2* beta_cast = reinterpret_cast(beta); + for (int f = 0; f < k; f++) { + int out_id = f * iteration_stride + id; + inp_reg[f] = inp_reg[f] * variance_h; + inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id]; + out_cast[out_id + row * row_stride] = inp_reg[f]; + } +#endif +} + +template +void launch_layer_norm(T* out, + T* vals, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream); + +template <> +void launch_layer_norm(float* out, + float* vals, + const float* gamma, + const float* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream) +{ + constexpr int threads = 1024; + + dim3 grid_dim(batch_size); + + dim3 block_dim(threads); + + fused_bias_residual_layer_norm<<>>( + out, vals, gamma, beta, epsilon, hidden_dim); +} + +template <> +void launch_layer_norm<__half>(__half* out, + __half* vals, + const __half* gamma, + const __half* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream) +{ + constexpr int threads = 1024; + + dim3 grid_dim(batch_size); + dim3 block_dim(threads); + + fused_bias_residual_layer_norm<<>>( + out, vals, gamma, beta, epsilon, hidden_dim / 2); +} + +__global__ void fused_residual_layer_norm(float* norm, + float* res_add, + float* vals, + float* residual, + const float* bias, + const float* gamma, + const float* beta, + float epsilon, + int row_stride, + bool preLN) +{ + int iteration_stride = blockDim.x; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id >> 5; + int warp_num = iteration_stride >> 5; + + float inp_reg[NORM_REG]; + + int k = 0; + int input_id = id; + + float sum = 0; + while (input_id < row_stride) { + inp_reg[k] = vals[input_id + row * row_stride]; + float res_f = (residual[input_id + row * row_stride]); + float bias_f = (bias[input_id]); + inp_reg[k] += res_f + bias_f; + if (preLN) res_add[input_id + row * row_stride] = inp_reg[k]; + sum += inp_reg[k++]; + input_id += iteration_stride; + } + for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); + + __shared__ float shr[MAX_WARP_NUM]; + if (g.thread_rank() == 0) shr[gid] = sum; + b.sync(); + + if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; + b.sync(); + + for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + float mean = sum / (row_stride); + sum = 0.f; + for (int f = 0; f < k; f++) { + inp_reg[f] -= mean; + sum += inp_reg[f] * inp_reg[f]; + } + for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); + if (g.thread_rank() == 0) shr[gid] = sum; + b.sync(); + + if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; + b.sync(); + + for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= (row_stride); + sum += epsilon; + sum = __frsqrt_rn(sum); + + for (int f = 0; f < k; f++) { + int out_id = f * iteration_stride + id; + inp_reg[f] = inp_reg[f] * sum; + inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id]; + norm[out_id + row * row_stride] = inp_reg[f]; + } +} + +__global__ void fused_residual_layer_norm(__half* norm, + __half* res_add, + __half* vals, + __half* residual, + const __half* bias, + const __half* gamma, + const __half* beta, + float epsilon, + int row_stride, + bool preLN) +{ +#if __CUDA_ARCH__ >= 700 + int iteration_stride = blockDim.x; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id >> 5; + int warp_num = iteration_stride >> 5; + + __half2 inp_reg[NORM_REG]; + + __half2* vals_cast = reinterpret_cast<__half2*>(vals); + __half2* norm_cast = reinterpret_cast<__half2*>(norm); + __half2* res_add_cast = reinterpret_cast<__half2*>(res_add); + __half2* residual_cast = reinterpret_cast<__half2*>(residual); + const __half2* bias_cast = reinterpret_cast(bias); + + int k = 0; + int input_id = id; + + float sum = 0; + while (input_id < row_stride) { + inp_reg[k] = vals_cast[input_id + row * row_stride]; + float2 inp_f = __half22float2(inp_reg[k]); + float2 res_f = __half22float2(residual_cast[input_id + row * row_stride]); + float2 bias_f = __half22float2(bias_cast[input_id]); + inp_f.x += res_f.x + bias_f.x; + inp_f.y += res_f.y + bias_f.y; + inp_reg[k] = __float22half2_rn(inp_f); + + if (preLN) res_add_cast[input_id + row * row_stride] = inp_reg[k]; + sum += inp_f.x + inp_f.y; + input_id += iteration_stride; + k++; + } + for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); + __shared__ float shr[MAX_WARP_NUM]; + if (g.thread_rank() == 0) shr[gid] = sum; + b.sync(); + if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; + b.sync(); + for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + float mean = sum / (row_stride << 1); + sum = 0.f; + for (int f = 0; f < k; f++) { + float2 inp_f = __half22float2(inp_reg[f]); + inp_f.x -= mean; + inp_f.y -= mean; + inp_reg[f] = __float22half2_rn(inp_f); + sum += inp_f.x * inp_f.x; + sum += inp_f.y * inp_f.y; + } + for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); + if (g.thread_rank() == 0) shr[gid] = sum; + b.sync(); + if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; + b.sync(); + for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= (row_stride << 1); + sum += epsilon; + sum = __frsqrt_rn(sum); + __half2 variance_h = __float2half2_rn(sum); + const __half2* gamma_cast = reinterpret_cast(gamma); + const __half2* beta_cast = reinterpret_cast(beta); + for (int f = 0; f < k; f++) { + int out_id = f * iteration_stride + id; + inp_reg[f] = inp_reg[f] * variance_h; + inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id]; + norm_cast[out_id + row * row_stride] = inp_reg[f]; + } +#endif +} + +template +void launch_residual_layer_norm(T* norm, + T* res_add, + T* vals, + T* residual, + const T* bias, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + bool preLN, + cudaStream_t stream); + +template <> +void launch_residual_layer_norm(float* norm, + float* res_add, + float* vals, + float* residual, + const float* bias, + const float* gamma, + const float* beta, + float epsilon, + int batch_size, + int hidden_dim, + bool preLN, + cudaStream_t stream) +{ + constexpr int threads = 1024; + + dim3 grid_dim(batch_size); + + dim3 block_dim(threads); + + fused_residual_layer_norm<<>>( + norm, res_add, vals, residual, bias, gamma, beta, epsilon, hidden_dim, preLN); +} + +template <> +void launch_residual_layer_norm<__half>(__half* norm, + __half* res_add, + __half* vals, + __half* residual, + const __half* bias, + const __half* gamma, + const __half* beta, + float epsilon, + int batch_size, + int hidden_dim, + bool preLN, + cudaStream_t stream) +{ + constexpr int threads = 1024; + + dim3 grid_dim(batch_size); + dim3 block_dim(threads); + + fused_residual_layer_norm<<>>( + norm, res_add, vals, residual, bias, gamma, beta, epsilon, hidden_dim / 2, preLN); +} diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 1ebadaeb5..b587b6233 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -1,630 +1,630 @@ - - -#include -#include -#include -#include "context.h" -#include "cublas_wrappers.h" -#include "custom_cuda_layers.h" - -std::array gemm_algos = std::array({99, 99, 99}); - -template -at::Tensor ds_softmax(at::Tensor& attn_scores, - at::Tensor& attn_mask, - bool triangular, - bool recompute, - bool local_attention, - int window_size) -{ - auto attn_scores_c = attn_scores.contiguous(); - int bsz = attn_scores_c.size(0); - int seq_len = attn_scores_c.size(2); - int soft_len = attn_scores_c.size(3); - int heads = attn_scores_c.size(1); - launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(), - (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr), - triangular, - recompute, - local_attention, - window_size, - bsz, - heads, - seq_len, - soft_len, - 1.0, - at::cuda::getCurrentCUDAStream()); - - return attn_scores_c; -} - -template -void attention_unfused(at::Tensor& prev_key_cont, - at::Tensor& query_cont, - at::Tensor& attn_mask, - at::Tensor& prev_value_cont, - at::Tensor& output, - int& bsz, - int& seq_len, - int& soft_len, - int& heads, - float& norm_factor, - bool triangular, - bool recompute, - bool local_attention, - int window_size) -{ - auto options = at::TensorOptions() - .dtype(query_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - float alpha = norm_factor; - float gemm_beta = 0.0; - auto attn_score = at::zeros({bsz, heads, seq_len, soft_len}, options); - int k = prev_value_cont.size(2) / heads; - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), - soft_len, - seq_len, - k, - &alpha, - &gemm_beta, - (T*)prev_key_cont.data_ptr(), - (T*)query_cont.data_ptr(), - (T*)attn_score.data_ptr(), - CUBLAS_OP_N, - CUBLAS_OP_N, - soft_len * k, - seq_len * k, - seq_len * soft_len, - bsz * heads, - CUBLAS_GEMM_DEFAULT_TENSOR_OP); - attn_score = - ds_softmax(attn_score, attn_mask, triangular, recompute, local_attention, window_size); - alpha = 1.0; - cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), - k, - seq_len, - soft_len, - &alpha, - &gemm_beta, - (T*)prev_value_cont.data_ptr(), - (T*)attn_score.data_ptr(), - (T*)output.data_ptr(), - CUBLAS_OP_N, - CUBLAS_OP_N, - soft_len * k, - seq_len * soft_len, - seq_len * k, - bsz * heads, - CUBLAS_GEMM_DEFAULT_TENSOR_OP); -} - -template -std::vector ds_softmax_context(at::Tensor& query, - at::Tensor& prev_key, - at::Tensor& new_key, - at::Tensor& attn_mask, - at::Tensor& prev_value, - at::Tensor& new_value, - int heads, - float norm_factor, - bool merging, - bool triangular, - bool local_attention, - int window_size, - bool no_masking) -{ - auto query_cont = query.contiguous(); - auto prev_key_cont = prev_key.contiguous(); - auto prev_value_cont = prev_value.contiguous(); - - int new_size = (new_value.sizes().size() > 1 ? new_value.size(1) : 0); - - // Attn_Score [ batch Head Sequence-length Softmax-length] - - int bsz = query_cont.size(0); - int seq_len = query_cont.size(1); - int soft_len = prev_value.size(1); - - auto options = at::TensorOptions() - .dtype(query_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - - auto output = - at::empty({prev_value.size(0), heads, seq_len, prev_value.size(2) / heads}, options); - attention_unfused(prev_key_cont, - query_cont, - attn_mask, //(no_masking ? nullptr : (T*)attn_mask.data_ptr()), - prev_value_cont, - output, - bsz, - seq_len, - soft_len, - heads, - norm_factor, - (triangular && (new_size == 0)), - (new_size == 0), - local_attention, - window_size); - - return {output, prev_key, prev_value}; -} - -template -at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias) -{ - auto input_cont = input.contiguous(); - - int bsz = input_cont.size(0) * input_cont.size(1); - int intermediate_size = input_cont.size(2); - - launch_bias_gelu((T*)input_cont.data_ptr(), - (T*)bias.data_ptr(), - intermediate_size, - bsz, - Context::Instance().GetCurrentStream()); - return input_cont; -} - -template -at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias) -{ - auto input_cont = input.contiguous(); - auto residual_cont = residual.contiguous(); - - int bsz = input_cont.size(0) * input_cont.size(1); - - launch_bias_residual((T*)input_cont.data_ptr(), - (T*)residual_cont.data_ptr(), - (T*)bias.data_ptr(), - bsz, - input_cont.size(2), - Context::Instance().GetCurrentStream()); - return input_cont; -} - -template -at::Tensor ds_layernorm(at::Tensor& input_cont, at::Tensor& gamma, at::Tensor& betta, float epsilon) -{ - int bsz = input_cont.size(0) * input_cont.size(1); - auto inp_norm = at::empty_like(input_cont); - launch_layer_norm((T*)inp_norm.data_ptr(), - (T*)input_cont.data_ptr(), - (T*)gamma.data_ptr(), - (T*)betta.data_ptr(), - epsilon, - bsz, - input_cont.size(2), - Context::Instance().GetCurrentStream()); - return inp_norm; -} - -template -void qkv_unfused_cublas(at::Tensor& output, - at::Tensor& input, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - bool add_bias) -{ - auto inp_norm = ds_layernorm(input, gamma, beta, epsilon); - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - int bsz = input.size(0) * input.size(1); - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), - bsz, - input.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - (T*)inp_norm.data_ptr(), - (T*)output.data_ptr(), - CUBLAS_GEMM_DEFAULT_TENSOR_OP); - if (add_bias) - launch_bias_add((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - Context::Instance().GetCurrentStream()); -} - -template -at::Tensor ds_qkv_gemm(at::Tensor& input, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - bool add_bias) -{ - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - int bsz = input_cont.size(0) * input_cont.size(1); - qkv_unfused_cublas(output, input_cont, weight, bias, gamma, beta, epsilon, add_bias); - - return output; -} - -template -void quantized_gemm(at::Tensor& output, - at::Tensor& input, - at::Tensor& weight, - at::Tensor& qscale, - int groups, - int merge_count) -{ - int bsz = input.size(0) * input.size(1); - auto options = at::TensorOptions() - .dtype(input.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - auto weight16 = at::empty({weight.size(0), weight.size(1)}, options); - - launch_dequantize((T*)weight16.data_ptr(), - (int8_t*)weight.data_ptr(), - (float*)qscale.data_ptr(), - weight.size(1), - weight.size(0), - groups, - merge_count, - Context::Instance().GetCurrentStream()); - - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), - bsz, - input.size(2), - &alpha, - &gemm_beta, - (T*)weight16.data_ptr(), - (T*)input.data_ptr(), - (T*)output.data_ptr(), - CUBLAS_GEMM_DEFAULT_TENSOR_OP); -} - -template -at::Tensor ds_qkv_gemm_int8(at::Tensor& input, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - at::Tensor& q_scale, - int groups, - bool add_bias) -{ - int bsz = input.size(0) * input.size(1); - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - - auto inp_norm = ds_layernorm(input_cont, gamma, beta, epsilon); - - quantized_gemm(output, inp_norm, weight, q_scale, groups, 0); - if (add_bias) - launch_bias_add((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - Context::Instance().GetCurrentStream()); - - return output; -} - -template -at::Tensor ds_linear_layer(at::Tensor& input, at::Tensor& weight, at::Tensor& bias) -{ - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - int bsz = input_cont.size(0) * input_cont.size(1); - - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), - bsz, - input_cont.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - (T*)input_cont.data_ptr(), - (T*)output.data_ptr(), - CUBLAS_GEMM_DEFAULT_TENSOR_OP); - - launch_bias_add((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - Context::Instance().GetCurrentStream()); - - return output; -} - -template -at::Tensor ds_linear_layer_int8(at::Tensor& input, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& q_scale, - int groups) -{ - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - int bsz = input_cont.size(0) * input_cont.size(1); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - - quantized_gemm(output, input_cont, weight, q_scale, groups, 0); - launch_bias_add((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - Context::Instance().GetCurrentStream()); - return output; -} - -template -at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight) -{ - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - int bsz = input_cont.size(0) * input_cont.size(1); - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), - bsz, - input_cont.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - (T*)input_cont.data_ptr(), - (T*)output.data_ptr(), - CUBLAS_GEMM_DEFAULT_TENSOR_OP); - - return output; -} - -template -at::Tensor ds_vector_matmul_int8(at::Tensor& input, - at::Tensor& weight, - at::Tensor& q_scale, - int groups, - int merge_count) -{ - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - - quantized_gemm(output, input_cont, weight, q_scale, groups, merge_count); - return output; -} - -template -void mlp_unfused_cublas(at::Tensor& output, - at::Tensor& residual_add, - at::Tensor& input, - at::Tensor& residual, - at::Tensor& input_bias, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - bool preLayerNorm) -{ - int bsz = input.size(0) * input.size(1); - auto inp_norm = preLayerNorm ? at::empty_like(input) : residual_add; - - launch_residual_layer_norm((T*)inp_norm.data_ptr(), - (T*)residual_add.data_ptr(), - (T*)input.data_ptr(), - (T*)residual.data_ptr(), - (T*)input_bias.data_ptr(), - (T*)gamma.data_ptr(), - (T*)beta.data_ptr(), - epsilon, - bsz, - input.size(2), - preLayerNorm, - Context::Instance().GetCurrentStream()); - - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), - bsz, - input.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - (T*)inp_norm.data_ptr(), - (T*)output.data_ptr(), - CUBLAS_GEMM_DEFAULT_TENSOR_OP); - - launch_bias_gelu((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - Context::Instance().GetCurrentStream()); -} -template -std::vector ds_mlp_gemm(at::Tensor& input, - at::Tensor& residual, - at::Tensor& input_bias, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - bool preLayerNorm) -{ - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - auto residual_add = at::empty_like(input_cont); - int bsz = input_cont.size(0) * input_cont.size(1); - - mlp_unfused_cublas(output, - residual_add, - input, - residual, - input_bias, - weight, - bias, - gamma, - beta, - epsilon, - preLayerNorm); - - return {output, residual_add}; -} - -template -std::vector ds_mlp_gemm_int8(at::Tensor& input, - at::Tensor& residual, - at::Tensor& input_bias, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - at::Tensor& q_scale, - int groups, - bool preLayerNorm) -{ - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - - int bsz = input_cont.size(0) * input_cont.size(1); - auto inp_norm = at::empty_like(input_cont); - - auto residual_add = (preLayerNorm ? at::empty_like(input_cont) : inp_norm); - // computing the blocking across K dimension - launch_residual_layer_norm((T*)inp_norm.data_ptr(), - (T*)residual_add.data_ptr(), - (T*)input_cont.data_ptr(), - (T*)residual.data_ptr(), - (T*)input_bias.data_ptr(), - (T*)gamma.data_ptr(), - (T*)beta.data_ptr(), - epsilon, - bsz, - input_cont.size(2), - preLayerNorm, - Context::Instance().GetCurrentStream()); - - quantized_gemm(output, inp_norm, weight, q_scale, groups, 0); - launch_bias_gelu((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - Context::Instance().GetCurrentStream()); - - return {output, residual_add}; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("softmax_fp32", &ds_softmax, "DeepSpeed SoftMax with fp32 (CUDA)"); - m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp32 (CUDA)"); - m.def( - "softmax_context_fp32", &ds_softmax_context, "DeepSpeed attention with fp32 (CUDA)"); - m.def("softmax_context_fp16", - &ds_softmax_context<__half>, - "DeepSpeed attention with fp32 (CUDA)"); - m.def("bias_gelu_fp32", &ds_bias_gelu, "DeepSpeed Gelu with fp32 (CUDA)"); - m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp32 (CUDA)"); - m.def("bias_residual_fp32", - &ds_bias_residual, - "DeepSpeed residual-bias add with fp32 (CUDA)"); - m.def("bias_residual_fp16", - &ds_bias_residual<__half>, - "DeepSpeed residual-bias add with fp32 (CUDA)"); - m.def("layer_norm_fp32", &ds_layernorm, "DeepSpeed layer-norm with fp32 (CUDA)"); - m.def("layer_norm_fp16", &ds_layernorm<__half>, "DeepSpeed layer-norm with fp16 (CUDA)"); - m.def("qkv_gemm_fp32", &ds_qkv_gemm, "DeepSpeed qkv gemm with fp32 (CUDA)"); - m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)"); - m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)"); - m.def("mlp_gemm_fp32", &ds_mlp_gemm, "DeepSpeed mlp with fp32 (CUDA)"); - m.def("mlp_gemm_fp16", &ds_mlp_gemm<__half>, "DeepSpeed mlp with fp16 (CUDA)"); - m.def("mlp_gemm_int8", &ds_mlp_gemm_int8<__half>, "DeepSpeed mlp with int8 (CUDA)"); - m.def("vector_matmul_fp32", &ds_vector_matmul, "DeepSpeed vector-MM with fp32 (CUDA)"); - m.def("vector_matmul_fp16", &ds_vector_matmul<__half>, "DeepSpeed vector-MM with fp16 (CUDA)"); - m.def("vector_matmul_int8", - &ds_vector_matmul_int8<__half>, - "DeepSpeed vector-MM with int8 (CUDA)"); - m.def("linear_layer_fp32", &ds_linear_layer, "DeepSpeed linear_layer with fp32 (CUDA)"); - m.def("linear_layer_fp16", &ds_linear_layer<__half>, "DeepSpeed linear_layer with fp16 (CUDA)"); - m.def("linear_layer_int8", - &ds_linear_layer_int8<__half>, - "DeepSpeed linear_layer with int8 (CUDA)"); -} + + +#include +#include +#include +#include "context.h" +#include "cublas_wrappers.h" +#include "custom_cuda_layers.h" + +std::array gemm_algos = std::array({99, 99, 99}); + +template +at::Tensor ds_softmax(at::Tensor& attn_scores, + at::Tensor& attn_mask, + bool triangular, + bool recompute, + bool local_attention, + int window_size) +{ + auto attn_scores_c = attn_scores.contiguous(); + int bsz = attn_scores_c.size(0); + int seq_len = attn_scores_c.size(2); + int soft_len = attn_scores_c.size(3); + int heads = attn_scores_c.size(1); + launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(), + (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr), + triangular, + recompute, + local_attention, + window_size, + bsz, + heads, + seq_len, + soft_len, + 1.0, + at::cuda::getCurrentCUDAStream()); + + return attn_scores_c; +} + +template +void attention_unfused(at::Tensor& prev_key_cont, + at::Tensor& query_cont, + at::Tensor& attn_mask, + at::Tensor& prev_value_cont, + at::Tensor& output, + int& bsz, + int& seq_len, + int& soft_len, + int& heads, + float& norm_factor, + bool triangular, + bool recompute, + bool local_attention, + int window_size) +{ + auto options = at::TensorOptions() + .dtype(query_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + float alpha = norm_factor; + float gemm_beta = 0.0; + auto attn_score = at::zeros({bsz, heads, seq_len, soft_len}, options); + int k = prev_value_cont.size(2) / heads; + cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); + cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), + soft_len, + seq_len, + k, + &alpha, + &gemm_beta, + (T*)prev_key_cont.data_ptr(), + (T*)query_cont.data_ptr(), + (T*)attn_score.data_ptr(), + CUBLAS_OP_N, + CUBLAS_OP_N, + soft_len * k, + seq_len * k, + seq_len * soft_len, + bsz * heads, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); + attn_score = + ds_softmax(attn_score, attn_mask, triangular, recompute, local_attention, window_size); + alpha = 1.0; + cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), + k, + seq_len, + soft_len, + &alpha, + &gemm_beta, + (T*)prev_value_cont.data_ptr(), + (T*)attn_score.data_ptr(), + (T*)output.data_ptr(), + CUBLAS_OP_N, + CUBLAS_OP_N, + soft_len * k, + seq_len * soft_len, + seq_len * k, + bsz * heads, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +template +std::vector ds_softmax_context(at::Tensor& query, + at::Tensor& prev_key, + at::Tensor& new_key, + at::Tensor& attn_mask, + at::Tensor& prev_value, + at::Tensor& new_value, + int heads, + float norm_factor, + bool merging, + bool triangular, + bool local_attention, + int window_size, + bool no_masking) +{ + auto query_cont = query.contiguous(); + auto prev_key_cont = prev_key.contiguous(); + auto prev_value_cont = prev_value.contiguous(); + + int new_size = (new_value.sizes().size() > 1 ? new_value.size(1) : 0); + + // Attn_Score [ batch Head Sequence-length Softmax-length] + + int bsz = query_cont.size(0); + int seq_len = query_cont.size(1); + int soft_len = prev_value.size(1); + + auto options = at::TensorOptions() + .dtype(query_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = + at::empty({prev_value.size(0), heads, seq_len, prev_value.size(2) / heads}, options); + attention_unfused(prev_key_cont, + query_cont, + attn_mask, //(no_masking ? nullptr : (T*)attn_mask.data_ptr()), + prev_value_cont, + output, + bsz, + seq_len, + soft_len, + heads, + norm_factor, + (triangular && (new_size == 0)), + (new_size == 0), + local_attention, + window_size); + + return {output, prev_key, prev_value}; +} + +template +at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias) +{ + auto input_cont = input.contiguous(); + + int bsz = input_cont.size(0) * input_cont.size(1); + int intermediate_size = input_cont.size(2); + + launch_bias_gelu((T*)input_cont.data_ptr(), + (T*)bias.data_ptr(), + intermediate_size, + bsz, + Context::Instance().GetCurrentStream()); + return input_cont; +} + +template +at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias) +{ + auto input_cont = input.contiguous(); + auto residual_cont = residual.contiguous(); + + int bsz = input_cont.size(0) * input_cont.size(1); + + launch_bias_residual((T*)input_cont.data_ptr(), + (T*)residual_cont.data_ptr(), + (T*)bias.data_ptr(), + bsz, + input_cont.size(2), + Context::Instance().GetCurrentStream()); + return input_cont; +} + +template +at::Tensor ds_layernorm(at::Tensor& input_cont, at::Tensor& gamma, at::Tensor& betta, float epsilon) +{ + int bsz = input_cont.size(0) * input_cont.size(1); + auto inp_norm = at::empty_like(input_cont); + launch_layer_norm((T*)inp_norm.data_ptr(), + (T*)input_cont.data_ptr(), + (T*)gamma.data_ptr(), + (T*)betta.data_ptr(), + epsilon, + bsz, + input_cont.size(2), + Context::Instance().GetCurrentStream()); + return inp_norm; +} + +template +void qkv_unfused_cublas(at::Tensor& output, + at::Tensor& input, + at::Tensor& weight, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool add_bias) +{ + auto inp_norm = ds_layernorm(input, gamma, beta, epsilon); + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + int bsz = input.size(0) * input.size(1); + cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + weight.size(1), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)inp_norm.data_ptr(), + (T*)output.data_ptr(), + CUBLAS_GEMM_DEFAULT_TENSOR_OP); + if (add_bias) + launch_bias_add((T*)output.data_ptr(), + (T*)bias.data_ptr(), + weight.size(1), + bsz, + Context::Instance().GetCurrentStream()); +} + +template +at::Tensor ds_qkv_gemm(at::Tensor& input, + at::Tensor& weight, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool add_bias) +{ + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + int bsz = input_cont.size(0) * input_cont.size(1); + qkv_unfused_cublas(output, input_cont, weight, bias, gamma, beta, epsilon, add_bias); + + return output; +} + +template +void quantized_gemm(at::Tensor& output, + at::Tensor& input, + at::Tensor& weight, + at::Tensor& qscale, + int groups, + int merge_count) +{ + int bsz = input.size(0) * input.size(1); + auto options = at::TensorOptions() + .dtype(input.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + auto weight16 = at::empty({weight.size(0), weight.size(1)}, options); + + launch_dequantize((T*)weight16.data_ptr(), + (int8_t*)weight.data_ptr(), + (float*)qscale.data_ptr(), + weight.size(1), + weight.size(0), + groups, + merge_count, + Context::Instance().GetCurrentStream()); + + cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); + + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + weight.size(1), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight16.data_ptr(), + (T*)input.data_ptr(), + (T*)output.data_ptr(), + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +template +at::Tensor ds_qkv_gemm_int8(at::Tensor& input, + at::Tensor& weight, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + at::Tensor& q_scale, + int groups, + bool add_bias) +{ + int bsz = input.size(0) * input.size(1); + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + + auto inp_norm = ds_layernorm(input_cont, gamma, beta, epsilon); + + quantized_gemm(output, inp_norm, weight, q_scale, groups, 0); + if (add_bias) + launch_bias_add((T*)output.data_ptr(), + (T*)bias.data_ptr(), + weight.size(1), + bsz, + Context::Instance().GetCurrentStream()); + + return output; +} + +template +at::Tensor ds_linear_layer(at::Tensor& input, at::Tensor& weight, at::Tensor& bias) +{ + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + int bsz = input_cont.size(0) * input_cont.size(1); + + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); + + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + weight.size(1), + bsz, + input_cont.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)input_cont.data_ptr(), + (T*)output.data_ptr(), + CUBLAS_GEMM_DEFAULT_TENSOR_OP); + + launch_bias_add((T*)output.data_ptr(), + (T*)bias.data_ptr(), + weight.size(1), + bsz, + Context::Instance().GetCurrentStream()); + + return output; +} + +template +at::Tensor ds_linear_layer_int8(at::Tensor& input, + at::Tensor& weight, + at::Tensor& bias, + at::Tensor& q_scale, + int groups) +{ + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + int bsz = input_cont.size(0) * input_cont.size(1); + + auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + + quantized_gemm(output, input_cont, weight, q_scale, groups, 0); + launch_bias_add((T*)output.data_ptr(), + (T*)bias.data_ptr(), + weight.size(1), + bsz, + Context::Instance().GetCurrentStream()); + return output; +} + +template +at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight) +{ + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + int bsz = input_cont.size(0) * input_cont.size(1); + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + weight.size(1), + bsz, + input_cont.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)input_cont.data_ptr(), + (T*)output.data_ptr(), + CUBLAS_GEMM_DEFAULT_TENSOR_OP); + + return output; +} + +template +at::Tensor ds_vector_matmul_int8(at::Tensor& input, + at::Tensor& weight, + at::Tensor& q_scale, + int groups, + int merge_count) +{ + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + + quantized_gemm(output, input_cont, weight, q_scale, groups, merge_count); + return output; +} + +template +void mlp_unfused_cublas(at::Tensor& output, + at::Tensor& residual_add, + at::Tensor& input, + at::Tensor& residual, + at::Tensor& input_bias, + at::Tensor& weight, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool preLayerNorm) +{ + int bsz = input.size(0) * input.size(1); + auto inp_norm = preLayerNorm ? at::empty_like(input) : residual_add; + + launch_residual_layer_norm((T*)inp_norm.data_ptr(), + (T*)residual_add.data_ptr(), + (T*)input.data_ptr(), + (T*)residual.data_ptr(), + (T*)input_bias.data_ptr(), + (T*)gamma.data_ptr(), + (T*)beta.data_ptr(), + epsilon, + bsz, + input.size(2), + preLayerNorm, + Context::Instance().GetCurrentStream()); + + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + weight.size(1), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)inp_norm.data_ptr(), + (T*)output.data_ptr(), + CUBLAS_GEMM_DEFAULT_TENSOR_OP); + + launch_bias_gelu((T*)output.data_ptr(), + (T*)bias.data_ptr(), + weight.size(1), + bsz, + Context::Instance().GetCurrentStream()); +} +template +std::vector ds_mlp_gemm(at::Tensor& input, + at::Tensor& residual, + at::Tensor& input_bias, + at::Tensor& weight, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool preLayerNorm) +{ + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + auto residual_add = at::empty_like(input_cont); + int bsz = input_cont.size(0) * input_cont.size(1); + + mlp_unfused_cublas(output, + residual_add, + input, + residual, + input_bias, + weight, + bias, + gamma, + beta, + epsilon, + preLayerNorm); + + return {output, residual_add}; +} + +template +std::vector ds_mlp_gemm_int8(at::Tensor& input, + at::Tensor& residual, + at::Tensor& input_bias, + at::Tensor& weight, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + at::Tensor& q_scale, + int groups, + bool preLayerNorm) +{ + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + + int bsz = input_cont.size(0) * input_cont.size(1); + auto inp_norm = at::empty_like(input_cont); + + auto residual_add = (preLayerNorm ? at::empty_like(input_cont) : inp_norm); + // computing the blocking across K dimension + launch_residual_layer_norm((T*)inp_norm.data_ptr(), + (T*)residual_add.data_ptr(), + (T*)input_cont.data_ptr(), + (T*)residual.data_ptr(), + (T*)input_bias.data_ptr(), + (T*)gamma.data_ptr(), + (T*)beta.data_ptr(), + epsilon, + bsz, + input_cont.size(2), + preLayerNorm, + Context::Instance().GetCurrentStream()); + + quantized_gemm(output, inp_norm, weight, q_scale, groups, 0); + launch_bias_gelu((T*)output.data_ptr(), + (T*)bias.data_ptr(), + weight.size(1), + bsz, + Context::Instance().GetCurrentStream()); + + return {output, residual_add}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("softmax_fp32", &ds_softmax, "DeepSpeed SoftMax with fp32 (CUDA)"); + m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp32 (CUDA)"); + m.def( + "softmax_context_fp32", &ds_softmax_context, "DeepSpeed attention with fp32 (CUDA)"); + m.def("softmax_context_fp16", + &ds_softmax_context<__half>, + "DeepSpeed attention with fp32 (CUDA)"); + m.def("bias_gelu_fp32", &ds_bias_gelu, "DeepSpeed Gelu with fp32 (CUDA)"); + m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp32 (CUDA)"); + m.def("bias_residual_fp32", + &ds_bias_residual, + "DeepSpeed residual-bias add with fp32 (CUDA)"); + m.def("bias_residual_fp16", + &ds_bias_residual<__half>, + "DeepSpeed residual-bias add with fp32 (CUDA)"); + m.def("layer_norm_fp32", &ds_layernorm, "DeepSpeed layer-norm with fp32 (CUDA)"); + m.def("layer_norm_fp16", &ds_layernorm<__half>, "DeepSpeed layer-norm with fp16 (CUDA)"); + m.def("qkv_gemm_fp32", &ds_qkv_gemm, "DeepSpeed qkv gemm with fp32 (CUDA)"); + m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)"); + m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)"); + m.def("mlp_gemm_fp32", &ds_mlp_gemm, "DeepSpeed mlp with fp32 (CUDA)"); + m.def("mlp_gemm_fp16", &ds_mlp_gemm<__half>, "DeepSpeed mlp with fp16 (CUDA)"); + m.def("mlp_gemm_int8", &ds_mlp_gemm_int8<__half>, "DeepSpeed mlp with int8 (CUDA)"); + m.def("vector_matmul_fp32", &ds_vector_matmul, "DeepSpeed vector-MM with fp32 (CUDA)"); + m.def("vector_matmul_fp16", &ds_vector_matmul<__half>, "DeepSpeed vector-MM with fp16 (CUDA)"); + m.def("vector_matmul_int8", + &ds_vector_matmul_int8<__half>, + "DeepSpeed vector-MM with int8 (CUDA)"); + m.def("linear_layer_fp32", &ds_linear_layer, "DeepSpeed linear_layer with fp32 (CUDA)"); + m.def("linear_layer_fp16", &ds_linear_layer<__half>, "DeepSpeed linear_layer with fp16 (CUDA)"); + m.def("linear_layer_int8", + &ds_linear_layer_int8<__half>, + "DeepSpeed linear_layer with int8 (CUDA)"); +} diff --git a/csrc/transformer/inference/csrc/softmax.cu b/csrc/transformer/inference/csrc/softmax.cu index 950ae6aea..774e7ce6c 100644 --- a/csrc/transformer/inference/csrc/softmax.cu +++ b/csrc/transformer/inference/csrc/softmax.cu @@ -1,432 +1,432 @@ -#include -#include "custom_cuda_layers.h" - -#include -#include -#include -#include - -#define ATTN_THREADS 1024 -#define MAX_REG_SIZE 8 - -#define minus_infinity -10000.0 - -void CheckCudaErrorAux(const char* file, unsigned line) -{ - cudaError_t err = cudaGetLastError(); - if (err == cudaSuccess) return; - std::cerr << cudaGetErrorString(err) << "(" << err << ") at " << file << ":" << line - << std::endl; - throw std::runtime_error("CUDA ERROR!!!\n"); -} - -#define CUDA_CHECK_ERROR() CheckCudaErrorAux(__FILE__, __LINE__) - -namespace cg = cooperative_groups; - -__global__ void attn_softmax_v2(__half* vals, - __half* mask, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - int total_count, - int heads, - int sequence_length, - int num_seq, - float scale, - int iterations, - int reduceWidth) -{ -#if __CUDA_ARCH__ >= 700 - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - float2 low_data[MAX_REG_SIZE]; - float2 high_data[MAX_REG_SIZE]; - - __half2 h_scale = __float2half2_rn(scale); - - int wid = threadIdx.x >> 5; - int lane = threadIdx.x & 0x1f; - int warp_num = blockDim.x >> 5; - - int reduce_blocks = reduceWidth >> 5; - int seq_lane = threadIdx.x % reduceWidth; - - __shared__ float partialSum[MAX_WARP_NUM]; - - int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks); - - if (iter_offset < total_count) { - vals += (iter_offset * sequence_length); - - int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length); - int seq_id = iter_offset % num_seq; - int seq_id4 = seq_id >> 2; - - int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length); - int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2)) - ? (real_seq_id >> 2) - (window_size >> 2) - : 0; - int window_stride = - (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1; - - float max_val = minus_infinity; - - for (int i = 0; i < iterations; i++) { - int data_id = i * (reduceWidth << 2) + (seq_lane << 2); - if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 && - data_id < sequence_length) { - if ((sequence_length - data_id) >= 4) { - low_data[i].x = data_id > window_stride ? __half2float(vals[data_id]) - : minus_infinity; - low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) && - (data_id + 1) > window_stride) - ? __half2float(vals[data_id + 1]) - : minus_infinity; - high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) && - (data_id + 2) > window_stride) - ? __half2float(vals[data_id + 2]) - : minus_infinity; - high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) && - (data_id + 3) > window_stride) - ? __half2float(vals[data_id + 3]) - : minus_infinity; - if (mask && recompute) { - low_data[i].x += __half2float(mask[data_id + mask_offset]); - low_data[i].y += __half2float(mask[data_id + mask_offset + 1]); - high_data[i].x += __half2float(mask[data_id + mask_offset + 2]); - high_data[i].y += __half2float(mask[data_id + mask_offset + 3]); - } - } else { - low_data[i].x = data_id > window_stride ? __half2float(vals[data_id]) - : minus_infinity; - low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) && - (data_id + 1) > window_stride) && - (data_id + 1) < sequence_length) - ? __half2float(vals[data_id + 1]) - : minus_infinity; - high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) && - (data_id + 2) > window_stride) && - (data_id + 2) < sequence_length) - ? __half2float(vals[data_id + 2]) - : minus_infinity; - high_data[i].y = minus_infinity; - if (mask && recompute) { - low_data[i].x += __half2float(mask[data_id + mask_offset]); - if ((data_id + 1) < sequence_length) - low_data[i].y += __half2float(mask[data_id + mask_offset + 1]); - if ((data_id + 2) < sequence_length) - high_data[i].x += __half2float(mask[data_id + mask_offset + 2]); - } - } - // if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id); - max_val = (low_data[i].x > max_val ? low_data[i].x : max_val); - max_val = (low_data[i].y > max_val ? low_data[i].y : max_val); - max_val = (high_data[i].x > max_val ? high_data[i].x : max_val); - max_val = (high_data[i].y > max_val ? high_data[i].y : max_val); - } else { - low_data[i].x = minus_infinity; - low_data[i].y = minus_infinity; - high_data[i].x = minus_infinity; - high_data[i].y = minus_infinity; - } - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { - auto temp = g.shfl_xor(max_val, i); - max_val = (temp > max_val ? temp : max_val); - } - - if (reduceWidth > WARP_SIZE) { - if (lane == 0) partialSum[wid] = max_val; - b.sync(); - - if (lane < warp_num) max_val = partialSum[lane]; - - b.sync(); - - for (int i = 1; i < reduce_blocks; i *= 2) { - auto temp = g.shfl_xor(max_val, i); - max_val = (temp > max_val ? temp : max_val); - } - - max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE); - } - float sum = 0; - for (int i = 0; i < iterations; i++) { - low_data[i].x = __expf(low_data[i].x - max_val); - low_data[i].y = __expf(low_data[i].y - max_val); - high_data[i].x = __expf(high_data[i].x - max_val); - high_data[i].y = __expf(high_data[i].y - max_val); - - sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i); - - if (reduceWidth > WARP_SIZE) { - if (lane == 0) partialSum[wid] = sum; - b.sync(); - - if (lane < warp_num) sum = partialSum[lane]; - - b.sync(); - - for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); } - - sum = g.shfl(sum, threadIdx.x / WARP_SIZE); - } - sum += 1e-6; - for (int i = 0; i < iterations; i++) { - int data_id = i * (reduceWidth << 2) + (seq_lane << 2); - - if (data_id < sequence_length) { - if ((sequence_length - data_id) >= 4) { - vals[data_id] = low_data[i].x / sum; - vals[data_id + 1] = low_data[i].y / sum; - vals[data_id + 2] = high_data[i].x / sum; - vals[data_id + 3] = high_data[i].y / sum; - } else { - vals[data_id] = low_data[i].x / sum; - if ((data_id + 1) < sequence_length) vals[data_id + 1] = low_data[i].y / sum; - if ((data_id + 2) < sequence_length) vals[data_id + 2] = high_data[i].x / sum; - } - } - } - } -#endif -} - -__global__ void attn_softmax_v2(float* vals, - float* attn_mask, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - int total_count, - int heads, - int sequence_length, - int num_seq, - float scale, - int iterations, - int reduceWidth) -{ - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - float4 data[MAX_REG_SIZE]; - - int wid = threadIdx.x >> 5; - int lane = threadIdx.x & 0x1f; - int warp_num = blockDim.x >> 5; - - int reduce_blocks = reduceWidth >> 5; - int seq_lane = threadIdx.x % reduceWidth; - - __shared__ float partialSum[MAX_WARP_NUM]; - - int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks); - if (iter_offset < total_count) { - vals += (iter_offset * sequence_length); - - int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length); - int seq_id = iter_offset % num_seq; - int seq_id4 = seq_id >> 2; - - int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length); - int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2)) - ? (real_seq_id >> 2) - (window_size >> 2) - : 0; - int window_stride = - (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1; - - float max_val = minus_infinity; - - for (int i = 0; i < iterations; i++) { - int data_id = i * (reduceWidth << 2) + (seq_lane << 2); - if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 && - data_id < sequence_length) { - if ((sequence_length - data_id) >= 4) { - data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity); - data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) && - (data_id + 1) > window_stride) - ? vals[data_id + 1] - : minus_infinity; - data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) && - (data_id + 2) > window_stride) - ? vals[data_id + 2] - : minus_infinity; - data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) && - (data_id + 3) > window_stride) - ? vals[data_id + 3] - : minus_infinity; - if (attn_mask && recompute) { - data[i].x += attn_mask[data_id + mask_offset]; - data[i].y += attn_mask[data_id + mask_offset + 1]; - data[i].z += attn_mask[data_id + mask_offset + 2]; - data[i].w += attn_mask[data_id + mask_offset + 3]; - } - } else { - data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity; - data[i].y = (((!triangular || (data_id + 1) <= seq_id)) && - (data_id + 1) > window_stride && (data_id + 1) < sequence_length) - ? (vals[data_id + 1]) - : minus_infinity; - data[i].z = (((!triangular || (data_id + 2) <= seq_id)) && - (data_id + 2) > window_stride && (data_id + 2) < sequence_length) - ? (vals[data_id + 2]) - : minus_infinity; - data[i].w = minus_infinity; - if (attn_mask && recompute) { - data[i].x += attn_mask[data_id + mask_offset]; - if ((data_id + 1) < sequence_length) - data[i].y += attn_mask[data_id + mask_offset + 1]; - if ((data_id + 2) < sequence_length) - data[i].z += attn_mask[data_id + mask_offset + 2]; - } - } - max_val = (data[i].x > max_val ? data[i].x : max_val); - max_val = (data[i].y > max_val ? data[i].y : max_val); - max_val = (data[i].z > max_val ? data[i].z : max_val); - max_val = (data[i].w > max_val ? data[i].w : max_val); - } else { - data[i].x = minus_infinity; - data[i].y = minus_infinity; - data[i].z = minus_infinity; - data[i].w = minus_infinity; - } - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { - auto temp = g.shfl_xor(max_val, i); - max_val = (temp > max_val ? temp : max_val); - } - - if (reduceWidth > WARP_SIZE) { - if (lane == 0) partialSum[wid] = max_val; - b.sync(); - - if (lane < warp_num) max_val = partialSum[lane]; - - b.sync(); - - for (int i = 1; i < reduce_blocks; i *= 2) { - auto temp = g.shfl_xor(max_val, i); - max_val = (temp > max_val ? temp : max_val); - } - - max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE); - } - - float sum = 0; - for (int i = 0; i < iterations; i++) { - data[i].x = __expf(data[i].x - max_val); - data[i].y = __expf(data[i].y - max_val); - data[i].z = __expf(data[i].z - max_val); - data[i].w = __expf(data[i].w - max_val); - - sum += (data[i].x + data[i].y + data[i].z + data[i].w); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i); - - if (reduceWidth > WARP_SIZE) { - if (lane == 0) partialSum[wid] = sum; - b.sync(); - - if (lane < warp_num) sum = partialSum[lane]; - - b.sync(); - - for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); } - - sum = g.shfl(sum, threadIdx.x / WARP_SIZE); - } - sum += 1e-6; - - for (int i = 0; i < iterations; i++) { - int data_id = i * (reduceWidth << 2) + (seq_lane << 2); - - if (data_id < sequence_length) { - if ((sequence_length - data_id) >= 4) { - vals[data_id] = data[i].x / sum; - vals[data_id + 1] = data[i].y / sum; - vals[data_id + 2] = data[i].z / sum; - vals[data_id + 3] = data[i].w / sum; - } else { - vals[data_id] = data[i].x / sum; - if ((data_id + 1) < sequence_length) vals[data_id + 1] = data[i].y / sum; - if ((data_id + 2) < sequence_length) vals[data_id + 2] = data[i].z / sum; - } - } - } - } -} - -template -void launch_attn_softmax_v2(T* vals, - T* mask, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - int batch_size, - int heads, - int num_seq, - int sequence_length, - float scale, - cudaStream_t stream) -{ - int total_count = batch_size * heads * num_seq; - dim3 grid_dim((total_count - 1) / (WARP_SIZE / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1); - dim3 block_dim(ATTN_THREADS); - - const int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1) * WARP_SIZE; - const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1; - - if (sequence_length <= 32768) - attn_softmax_v2<<>>( - vals, - mask, - triangular, - recompute, - local_attention, - window_size, - total_count, - (triangular ? (heads * batch_size) : heads), - sequence_length, - num_seq, - scale, - iterations, - reduce_width); - else - throw std::runtime_error("Unsupport Seq_Length!"); -} - -template void launch_attn_softmax_v2(float* vals, - float* mask, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - int batch_size, - int heads, - int num_seq, - int sequence_length, - float scale, - cudaStream_t stream); -template void launch_attn_softmax_v2(__half* vals, - __half* mask, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - int batch_size, - int heads, - int num_seq, - int sequence_length, - float scale, - cudaStream_t stream); +#include +#include "custom_cuda_layers.h" + +#include +#include +#include +#include + +#define ATTN_THREADS 1024 +#define MAX_REG_SIZE 8 + +#define minus_infinity -10000.0 + +void CheckCudaErrorAux(const char* file, unsigned line) +{ + cudaError_t err = cudaGetLastError(); + if (err == cudaSuccess) return; + std::cerr << cudaGetErrorString(err) << "(" << err << ") at " << file << ":" << line + << std::endl; + throw std::runtime_error("CUDA ERROR!!!\n"); +} + +#define CUDA_CHECK_ERROR() CheckCudaErrorAux(__FILE__, __LINE__) + +namespace cg = cooperative_groups; + +__global__ void attn_softmax_v2(__half* vals, + __half* mask, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int total_count, + int heads, + int sequence_length, + int num_seq, + float scale, + int iterations, + int reduceWidth) +{ +#if __CUDA_ARCH__ >= 700 + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + float2 low_data[MAX_REG_SIZE]; + float2 high_data[MAX_REG_SIZE]; + + __half2 h_scale = __float2half2_rn(scale); + + int wid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + int warp_num = blockDim.x >> 5; + + int reduce_blocks = reduceWidth >> 5; + int seq_lane = threadIdx.x % reduceWidth; + + __shared__ float partialSum[MAX_WARP_NUM]; + + int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks); + + if (iter_offset < total_count) { + vals += (iter_offset * sequence_length); + + int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length); + int seq_id = iter_offset % num_seq; + int seq_id4 = seq_id >> 2; + + int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length); + int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2)) + ? (real_seq_id >> 2) - (window_size >> 2) + : 0; + int window_stride = + (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1; + + float max_val = minus_infinity; + + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane << 2); + if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 && + data_id < sequence_length) { + if ((sequence_length - data_id) >= 4) { + low_data[i].x = data_id > window_stride ? __half2float(vals[data_id]) + : minus_infinity; + low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) && + (data_id + 1) > window_stride) + ? __half2float(vals[data_id + 1]) + : minus_infinity; + high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) && + (data_id + 2) > window_stride) + ? __half2float(vals[data_id + 2]) + : minus_infinity; + high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) && + (data_id + 3) > window_stride) + ? __half2float(vals[data_id + 3]) + : minus_infinity; + if (mask && recompute) { + low_data[i].x += __half2float(mask[data_id + mask_offset]); + low_data[i].y += __half2float(mask[data_id + mask_offset + 1]); + high_data[i].x += __half2float(mask[data_id + mask_offset + 2]); + high_data[i].y += __half2float(mask[data_id + mask_offset + 3]); + } + } else { + low_data[i].x = data_id > window_stride ? __half2float(vals[data_id]) + : minus_infinity; + low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) && + (data_id + 1) > window_stride) && + (data_id + 1) < sequence_length) + ? __half2float(vals[data_id + 1]) + : minus_infinity; + high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) && + (data_id + 2) > window_stride) && + (data_id + 2) < sequence_length) + ? __half2float(vals[data_id + 2]) + : minus_infinity; + high_data[i].y = minus_infinity; + if (mask && recompute) { + low_data[i].x += __half2float(mask[data_id + mask_offset]); + if ((data_id + 1) < sequence_length) + low_data[i].y += __half2float(mask[data_id + mask_offset + 1]); + if ((data_id + 2) < sequence_length) + high_data[i].x += __half2float(mask[data_id + mask_offset + 2]); + } + } + // if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id); + max_val = (low_data[i].x > max_val ? low_data[i].x : max_val); + max_val = (low_data[i].y > max_val ? low_data[i].y : max_val); + max_val = (high_data[i].x > max_val ? high_data[i].x : max_val); + max_val = (high_data[i].y > max_val ? high_data[i].y : max_val); + } else { + low_data[i].x = minus_infinity; + low_data[i].y = minus_infinity; + high_data[i].x = minus_infinity; + high_data[i].y = minus_infinity; + } + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + if (reduceWidth > WARP_SIZE) { + if (lane == 0) partialSum[wid] = max_val; + b.sync(); + + if (lane < warp_num) max_val = partialSum[lane]; + + b.sync(); + + for (int i = 1; i < reduce_blocks; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE); + } + float sum = 0; + for (int i = 0; i < iterations; i++) { + low_data[i].x = __expf(low_data[i].x - max_val); + low_data[i].y = __expf(low_data[i].y - max_val); + high_data[i].x = __expf(high_data[i].x - max_val); + high_data[i].y = __expf(high_data[i].y - max_val); + + sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i); + + if (reduceWidth > WARP_SIZE) { + if (lane == 0) partialSum[wid] = sum; + b.sync(); + + if (lane < warp_num) sum = partialSum[lane]; + + b.sync(); + + for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); } + + sum = g.shfl(sum, threadIdx.x / WARP_SIZE); + } + sum += 1e-6; + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane << 2); + + if (data_id < sequence_length) { + if ((sequence_length - data_id) >= 4) { + vals[data_id] = low_data[i].x / sum; + vals[data_id + 1] = low_data[i].y / sum; + vals[data_id + 2] = high_data[i].x / sum; + vals[data_id + 3] = high_data[i].y / sum; + } else { + vals[data_id] = low_data[i].x / sum; + if ((data_id + 1) < sequence_length) vals[data_id + 1] = low_data[i].y / sum; + if ((data_id + 2) < sequence_length) vals[data_id + 2] = high_data[i].x / sum; + } + } + } + } +#endif +} + +__global__ void attn_softmax_v2(float* vals, + float* attn_mask, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int total_count, + int heads, + int sequence_length, + int num_seq, + float scale, + int iterations, + int reduceWidth) +{ + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + float4 data[MAX_REG_SIZE]; + + int wid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + int warp_num = blockDim.x >> 5; + + int reduce_blocks = reduceWidth >> 5; + int seq_lane = threadIdx.x % reduceWidth; + + __shared__ float partialSum[MAX_WARP_NUM]; + + int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks); + if (iter_offset < total_count) { + vals += (iter_offset * sequence_length); + + int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length); + int seq_id = iter_offset % num_seq; + int seq_id4 = seq_id >> 2; + + int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length); + int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2)) + ? (real_seq_id >> 2) - (window_size >> 2) + : 0; + int window_stride = + (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1; + + float max_val = minus_infinity; + + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane << 2); + if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 && + data_id < sequence_length) { + if ((sequence_length - data_id) >= 4) { + data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity); + data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) && + (data_id + 1) > window_stride) + ? vals[data_id + 1] + : minus_infinity; + data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) && + (data_id + 2) > window_stride) + ? vals[data_id + 2] + : minus_infinity; + data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) && + (data_id + 3) > window_stride) + ? vals[data_id + 3] + : minus_infinity; + if (attn_mask && recompute) { + data[i].x += attn_mask[data_id + mask_offset]; + data[i].y += attn_mask[data_id + mask_offset + 1]; + data[i].z += attn_mask[data_id + mask_offset + 2]; + data[i].w += attn_mask[data_id + mask_offset + 3]; + } + } else { + data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity; + data[i].y = (((!triangular || (data_id + 1) <= seq_id)) && + (data_id + 1) > window_stride && (data_id + 1) < sequence_length) + ? (vals[data_id + 1]) + : minus_infinity; + data[i].z = (((!triangular || (data_id + 2) <= seq_id)) && + (data_id + 2) > window_stride && (data_id + 2) < sequence_length) + ? (vals[data_id + 2]) + : minus_infinity; + data[i].w = minus_infinity; + if (attn_mask && recompute) { + data[i].x += attn_mask[data_id + mask_offset]; + if ((data_id + 1) < sequence_length) + data[i].y += attn_mask[data_id + mask_offset + 1]; + if ((data_id + 2) < sequence_length) + data[i].z += attn_mask[data_id + mask_offset + 2]; + } + } + max_val = (data[i].x > max_val ? data[i].x : max_val); + max_val = (data[i].y > max_val ? data[i].y : max_val); + max_val = (data[i].z > max_val ? data[i].z : max_val); + max_val = (data[i].w > max_val ? data[i].w : max_val); + } else { + data[i].x = minus_infinity; + data[i].y = minus_infinity; + data[i].z = minus_infinity; + data[i].w = minus_infinity; + } + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + if (reduceWidth > WARP_SIZE) { + if (lane == 0) partialSum[wid] = max_val; + b.sync(); + + if (lane < warp_num) max_val = partialSum[lane]; + + b.sync(); + + for (int i = 1; i < reduce_blocks; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE); + } + + float sum = 0; + for (int i = 0; i < iterations; i++) { + data[i].x = __expf(data[i].x - max_val); + data[i].y = __expf(data[i].y - max_val); + data[i].z = __expf(data[i].z - max_val); + data[i].w = __expf(data[i].w - max_val); + + sum += (data[i].x + data[i].y + data[i].z + data[i].w); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i); + + if (reduceWidth > WARP_SIZE) { + if (lane == 0) partialSum[wid] = sum; + b.sync(); + + if (lane < warp_num) sum = partialSum[lane]; + + b.sync(); + + for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); } + + sum = g.shfl(sum, threadIdx.x / WARP_SIZE); + } + sum += 1e-6; + + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane << 2); + + if (data_id < sequence_length) { + if ((sequence_length - data_id) >= 4) { + vals[data_id] = data[i].x / sum; + vals[data_id + 1] = data[i].y / sum; + vals[data_id + 2] = data[i].z / sum; + vals[data_id + 3] = data[i].w / sum; + } else { + vals[data_id] = data[i].x / sum; + if ((data_id + 1) < sequence_length) vals[data_id + 1] = data[i].y / sum; + if ((data_id + 2) < sequence_length) vals[data_id + 2] = data[i].z / sum; + } + } + } + } +} + +template +void launch_attn_softmax_v2(T* vals, + T* mask, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int batch_size, + int heads, + int num_seq, + int sequence_length, + float scale, + cudaStream_t stream) +{ + int total_count = batch_size * heads * num_seq; + dim3 grid_dim((total_count - 1) / (WARP_SIZE / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1); + dim3 block_dim(ATTN_THREADS); + + const int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1) * WARP_SIZE; + const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1; + + if (sequence_length <= 32768) + attn_softmax_v2<<>>( + vals, + mask, + triangular, + recompute, + local_attention, + window_size, + total_count, + (triangular ? (heads * batch_size) : heads), + sequence_length, + num_seq, + scale, + iterations, + reduce_width); + else + throw std::runtime_error("Unsupport Seq_Length!"); +} + +template void launch_attn_softmax_v2(float* vals, + float* mask, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int batch_size, + int heads, + int num_seq, + int sequence_length, + float scale, + cudaStream_t stream); +template void launch_attn_softmax_v2(__half* vals, + __half* mask, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int batch_size, + int heads, + int num_seq, + int sequence_length, + float scale, + cudaStream_t stream); diff --git a/csrc/transformer/inference/includes/context.h b/csrc/transformer/inference/includes/context.h index 65e464f57..4385bd7d5 100755 --- a/csrc/transformer/inference/includes/context.h +++ b/csrc/transformer/inference/includes/context.h @@ -1,112 +1,112 @@ -#pragma once - -#include -#include -#include -#include -#include -#include "cublas_v2.h" -#include "cuda.h" -#include "curand.h" - -#define WARP_SIZE 32 - -#define CUDA_CHECK(callstr) \ - { \ - cudaError_t error_code = callstr; \ - if (error_code != cudaSuccess) { \ - std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \ - assert(0); \ - } \ - } - -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) - -#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \ - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \ - for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y) - -#define DS_CUDA_NUM_THREADS 512 -#define DS_MAXIMUM_NUM_BLOCKS 262144 - -inline int DS_GET_BLOCKS(const int N) -{ - return std::max( - std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS), - // Use at least 1 block, since CUDA does not allow empty block - 1); -} - -class Context { -public: - Context() : _workspace(nullptr), _seed(42), _curr_offset(0) - { - curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT); - curandSetPseudoRandomGeneratorSeed(_gen, 123); - if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) { - auto message = std::string("Fail to create cublas handle."); - std::cerr << message << std::endl; - throw std::runtime_error(message); - } - cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH); - } - - virtual ~Context() - { - cublasDestroy(_cublasHandle); - cudaFree(_workspace); - } - - static Context& Instance() - { - static Context _ctx; - return _ctx; - } - - void GenWorkSpace(size_t size) - { - if (!_workspace) { - assert(_workspace == nullptr); - cudaMalloc(&_workspace, size); - } else if (_workSpaceSize < size) { - cudaFree(_workspace); - cudaMalloc(&_workspace, size); - } - - _workSpaceSize = size; - } - - void* GetWorkSpace() { return _workspace; } - - curandGenerator_t& GetRandGenerator() { return _gen; } - - cudaStream_t GetCurrentStream() - { - // get current pytorch stream. - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - return stream; - } - - cublasHandle_t GetCublasHandle() { return _cublasHandle; } - - std::pair IncrementOffset(uint64_t offset_inc) - { - uint64_t offset = _curr_offset; - _curr_offset += offset_inc; - return std::pair(_seed, offset); - } - - void SetSeed(uint64_t new_seed) { _seed = new_seed; } - - const std::vector>& GetGemmAlgos() const { return _gemm_algos; } - -private: - curandGenerator_t _gen; - cublasHandle_t _cublasHandle; - void* _workspace; - uint64_t _seed; - uint64_t _curr_offset; - size_t _workSpaceSize; - std::vector> _gemm_algos; -}; +#pragma once + +#include +#include +#include +#include +#include +#include "cublas_v2.h" +#include "cuda.h" +#include "curand.h" + +#define WARP_SIZE 32 + +#define CUDA_CHECK(callstr) \ + { \ + cudaError_t error_code = callstr; \ + if (error_code != cudaSuccess) { \ + std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \ + assert(0); \ + } \ + } + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) + +#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \ + for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y) + +#define DS_CUDA_NUM_THREADS 512 +#define DS_MAXIMUM_NUM_BLOCKS 262144 + +inline int DS_GET_BLOCKS(const int N) +{ + return std::max( + std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS), + // Use at least 1 block, since CUDA does not allow empty block + 1); +} + +class Context { +public: + Context() : _workspace(nullptr), _seed(42), _curr_offset(0) + { + curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT); + curandSetPseudoRandomGeneratorSeed(_gen, 123); + if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) { + auto message = std::string("Fail to create cublas handle."); + std::cerr << message << std::endl; + throw std::runtime_error(message); + } + cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH); + } + + virtual ~Context() + { + cublasDestroy(_cublasHandle); + cudaFree(_workspace); + } + + static Context& Instance() + { + static Context _ctx; + return _ctx; + } + + void GenWorkSpace(size_t size) + { + if (!_workspace) { + assert(_workspace == nullptr); + cudaMalloc(&_workspace, size); + } else if (_workSpaceSize < size) { + cudaFree(_workspace); + cudaMalloc(&_workspace, size); + } + + _workSpaceSize = size; + } + + void* GetWorkSpace() { return _workspace; } + + curandGenerator_t& GetRandGenerator() { return _gen; } + + cudaStream_t GetCurrentStream() + { + // get current pytorch stream. + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + return stream; + } + + cublasHandle_t GetCublasHandle() { return _cublasHandle; } + + std::pair IncrementOffset(uint64_t offset_inc) + { + uint64_t offset = _curr_offset; + _curr_offset += offset_inc; + return std::pair(_seed, offset); + } + + void SetSeed(uint64_t new_seed) { _seed = new_seed; } + + const std::vector>& GetGemmAlgos() const { return _gemm_algos; } + +private: + curandGenerator_t _gen; + cublasHandle_t _cublasHandle; + void* _workspace; + uint64_t _seed; + uint64_t _curr_offset; + size_t _workSpaceSize; + std::vector> _gemm_algos; +}; diff --git a/csrc/transformer/inference/includes/cublas_wrappers.h b/csrc/transformer/inference/includes/cublas_wrappers.h index 4be9e09a3..0b37c51ca 100755 --- a/csrc/transformer/inference/includes/cublas_wrappers.h +++ b/csrc/transformer/inference/includes/cublas_wrappers.h @@ -1,208 +1,208 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include "cublas_wrappers.h" - -int cublas_gemm_ex(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float* alpha, - const float* beta, - const float* A, - const float* B, - float* C, - cublasGemmAlgo_t algo) -{ - cublasStatus_t status = cublasGemmEx(handle, - transa, - transb, - m, - n, - k, - (const void*)alpha, - (const void*)A, - CUDA_R_32F, - (transa == CUBLAS_OP_N) ? m : k, - (const void*)B, - CUDA_R_32F, - (transb == CUBLAS_OP_N) ? k : n, - (const void*)beta, - C, - CUDA_R_32F, - m, - CUDA_R_32F, - algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", - m, - n, - k, - (int)status); - return EXIT_FAILURE; - } - return 0; -} - -int cublas_gemm_ex(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float* alpha, - const float* beta, - const __half* A, - const __half* B, - __half* C, - cublasGemmAlgo_t algo) -{ - cublasStatus_t status = cublasGemmEx(handle, - transa, - transb, - m, - n, - k, - (const void*)alpha, - (const void*)A, - CUDA_R_16F, - (transa == CUBLAS_OP_N) ? m : k, - (const void*)B, - CUDA_R_16F, - (transb == CUBLAS_OP_N) ? k : n, - (const void*)beta, - (void*)C, - CUDA_R_16F, - m, - CUDA_R_32F, - algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", - m, - n, - k, - (int)status); - return EXIT_FAILURE; - } - return 0; -} - -int cublas_strided_batched_gemm(cublasHandle_t handle, - int m, - int n, - int k, - const float* alpha, - const float* beta, - const float* A, - const float* B, - float* C, - cublasOperation_t op_A, - cublasOperation_t op_B, - int stride_A, - int stride_B, - int stride_C, - int batch, - cublasGemmAlgo_t algo) -{ - cublasStatus_t status = cublasGemmStridedBatchedEx(handle, - op_A, - op_B, - m, - n, - k, - alpha, - A, - CUDA_R_32F, - (op_A == CUBLAS_OP_N) ? m : k, - stride_A, - B, - CUDA_R_32F, - (op_B == CUBLAS_OP_N) ? k : n, - stride_B, - beta, - C, - CUDA_R_32F, - m, - stride_C, - batch, - CUDA_R_32F, - algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n", - batch, - m, - n, - k, - (int)status); - return EXIT_FAILURE; - } - return 0; -} - -int cublas_strided_batched_gemm(cublasHandle_t handle, - int m, - int n, - int k, - const float* alpha, - const float* beta, - const __half* A, - const __half* B, - __half* C, - cublasOperation_t op_A, - cublasOperation_t op_B, - int stride_A, - int stride_B, - int stride_C, - int batch, - cublasGemmAlgo_t algo) -{ - cublasStatus_t status = cublasGemmStridedBatchedEx(handle, - op_A, - op_B, - m, - n, - k, - alpha, - A, - CUDA_R_16F, - (op_A == CUBLAS_OP_N) ? m : k, - stride_A, - B, - CUDA_R_16F, - (op_B == CUBLAS_OP_N) ? k : n, - stride_B, - beta, - C, - CUDA_R_16F, - m, - stride_C, - batch, - CUDA_R_32F, - algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", - m, - n, - k, - (int)status); - return EXIT_FAILURE; - } - - return 0; -} +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "cublas_wrappers.h" + +int cublas_gemm_ex(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const float* A, + const float* B, + float* C, + cublasGemmAlgo_t algo) +{ + cublasStatus_t status = cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + (const void*)alpha, + (const void*)A, + CUDA_R_32F, + (transa == CUBLAS_OP_N) ? m : k, + (const void*)B, + CUDA_R_32F, + (transb == CUBLAS_OP_N) ? k : n, + (const void*)beta, + C, + CUDA_R_32F, + m, + CUDA_R_32F, + algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int cublas_gemm_ex(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const __half* A, + const __half* B, + __half* C, + cublasGemmAlgo_t algo) +{ + cublasStatus_t status = cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + (const void*)alpha, + (const void*)A, + CUDA_R_16F, + (transa == CUBLAS_OP_N) ? m : k, + (const void*)B, + CUDA_R_16F, + (transb == CUBLAS_OP_N) ? k : n, + (const void*)beta, + (void*)C, + CUDA_R_16F, + m, + CUDA_R_32F, + algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int cublas_strided_batched_gemm(cublasHandle_t handle, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const float* A, + const float* B, + float* C, + cublasOperation_t op_A, + cublasOperation_t op_B, + int stride_A, + int stride_B, + int stride_C, + int batch, + cublasGemmAlgo_t algo) +{ + cublasStatus_t status = cublasGemmStridedBatchedEx(handle, + op_A, + op_B, + m, + n, + k, + alpha, + A, + CUDA_R_32F, + (op_A == CUBLAS_OP_N) ? m : k, + stride_A, + B, + CUDA_R_32F, + (op_B == CUBLAS_OP_N) ? k : n, + stride_B, + beta, + C, + CUDA_R_32F, + m, + stride_C, + batch, + CUDA_R_32F, + algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n", + batch, + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int cublas_strided_batched_gemm(cublasHandle_t handle, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const __half* A, + const __half* B, + __half* C, + cublasOperation_t op_A, + cublasOperation_t op_B, + int stride_A, + int stride_B, + int stride_C, + int batch, + cublasGemmAlgo_t algo) +{ + cublasStatus_t status = cublasGemmStridedBatchedEx(handle, + op_A, + op_B, + m, + n, + k, + alpha, + A, + CUDA_R_16F, + (op_A == CUBLAS_OP_N) ? m : k, + stride_A, + B, + CUDA_R_16F, + (op_B == CUBLAS_OP_N) ? k : n, + stride_B, + beta, + C, + CUDA_R_16F, + m, + stride_C, + batch, + CUDA_R_32F, + algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + + return 0; +} diff --git a/csrc/transformer/inference/includes/custom_cuda_layers.h b/csrc/transformer/inference/includes/custom_cuda_layers.h index b544517fa..dfb43a07a 100644 --- a/csrc/transformer/inference/includes/custom_cuda_layers.h +++ b/csrc/transformer/inference/includes/custom_cuda_layers.h @@ -1,79 +1,79 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#define MAX_WARP_NUM 32 -#define WARP_SIZE 32 -#define SMs 80 - -#define MAX_REGISTERS 256 -template -void launch_attn_softmax_v2(T* vals, - T* mask, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - int batch_size, - int heads, - int num_seq, - int sequence_length, - float scale, - cudaStream_t stream); - -// Fused bias add with gelu activation -template -void launch_bias_gelu(T* input, - const T* bias, - int intermediate_size, - int batch_size, - cudaStream_t stream); -template -void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream); - -template -void launch_bias_residual(T* input, - const T* residual, - const T* bias, - int size, - int intermediate_size, - cudaStream_t stream); - -template -void launch_layer_norm(T* out, - T* vals, - const T* gamma, - const T* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream); - -template -void launch_residual_layer_norm(T* norm, - T* res_add, - T* vals, - T* residual, - const T* bias, - const T* gamma, - const T* beta, - float epsilon, - int batch_size, - int hidden_dim, - bool preLN, - cudaStream_t stream); -template -void launch_dequantize(T* output, - const int8_t* input, - const float* qscale, - unsigned output_size, - unsigned hidden_dim, - unsigned groups, - unsigned merge_count, - cudaStream_t stream); +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#define MAX_WARP_NUM 32 +#define WARP_SIZE 32 +#define SMs 80 + +#define MAX_REGISTERS 256 +template +void launch_attn_softmax_v2(T* vals, + T* mask, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int batch_size, + int heads, + int num_seq, + int sequence_length, + float scale, + cudaStream_t stream); + +// Fused bias add with gelu activation +template +void launch_bias_gelu(T* input, + const T* bias, + int intermediate_size, + int batch_size, + cudaStream_t stream); +template +void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream); + +template +void launch_bias_residual(T* input, + const T* residual, + const T* bias, + int size, + int intermediate_size, + cudaStream_t stream); + +template +void launch_layer_norm(T* out, + T* vals, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream); + +template +void launch_residual_layer_norm(T* norm, + T* res_add, + T* vals, + T* residual, + const T* bias, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + bool preLN, + cudaStream_t stream); +template +void launch_dequantize(T* output, + const int8_t* input, + const float* qscale, + unsigned output_size, + unsigned hidden_dim, + unsigned groups, + unsigned merge_count, + cudaStream_t stream); diff --git a/csrc/transformer/normalize_kernels.cu b/csrc/transformer/normalize_kernels.cu index ba8314c58..b7afa3589 100644 --- a/csrc/transformer/normalize_kernels.cu +++ b/csrc/transformer/normalize_kernels.cu @@ -1,2103 +1,2103 @@ -#include "custom_cuda_layers.h" - -namespace cg = cooperative_groups; - -/* -Fused bias add, residual (elementwise) add, and normalization layer. - -For FP16, this kernel does not promote to FP32 in order to utilize the 2x throughput for -__half2 instructions, and avoid the conversion overhead (1/8 of __hal2 arithmetic). - -For specific launch constraints, see the launch functions. -*/ - -#define NORM_REG (MAX_REGISTERS / 4) - -__global__ void fused_bias_residual_layer_norm(float* vals, - const float* residual, - const float* gamma, - const float* beta, - float epsilon, - bool preLayerNorm, - bool training, - float* vars, - float* means, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int gid = id / WARP_SIZE; - - float vals_arr[NORM_REG]; - __shared__ float shr[MAX_WARP_NUM]; - - residual += (row * row_stride); - vals += (row * row_stride); - - float sum = 0.f; - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - vals_arr[i] = residual[i * iteration_stride + id]; - sum += vals_arr[i]; - } - if (high_index < row_stride) { - vals_arr[iterations] = residual[high_index]; - sum += vals_arr[iterations]; - iterations++; - } - - for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) shr[gid] = sum; - - b.sync(); - - if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()]; - -#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700 - b.sync(); -#endif - - for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); } - - sum = g.shfl(sum, 0); - float mean = sum / row_stride; - if (training) - if (threadIdx.x == 0) means[row] = mean; - float variance = 0.f; - for (int i = 0; i < iterations; i++) { - vals_arr[i] -= mean; - variance += vals_arr[i] * vals_arr[i]; - } - - for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } - - if (g.thread_rank() == 0) shr[gid] = variance; - - b.sync(); - - if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); } - variance = g.shfl(variance, 0); - variance /= row_stride; - variance += epsilon; - if (training) - if (threadIdx.x == 0) vars[row] = variance; - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) { - vals_arr[i] = vals_arr[i] * rsqrtf(variance); - vals_arr[i] = - vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id]; - vals[i * iteration_stride + id] = vals_arr[i]; - } - if ((high_index) < row_stride) { - vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance); - vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index]; - vals[high_index] = vals_arr[iterations]; - } -} - -__global__ void fused_bias_residual_layer_norm(__half* vals, - const __half* residual, - const __half* gamma, - const __half* beta, - float epsilon, - bool preLayerNorm, - bool training, - __half* vars, - __half* means, - int row_stride) -{ -#if __CUDA_ARCH__ >= 700 - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int gid = id >> 5; - - float2 vals_f[NORM_REG]; - __shared__ float shr[MAX_WARP_NUM]; - - __half2* vals_cast = reinterpret_cast<__half2*>(vals); - const __half2* residual_cast = reinterpret_cast(residual); - - residual_cast += (row * row_stride); - vals_cast += (row * row_stride); - - float sum = 0.f; - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]); - sum += vals_f[i].x; - sum += vals_f[i].y; - } - if ((high_index) < row_stride) { - vals_f[iterations] = __half22float2(residual_cast[high_index]); - sum += vals_f[iterations].x; - sum += vals_f[iterations].y; - iterations++; - } - - for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) shr[gid] = sum; - - b.sync(); - - if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); } - sum = g.shfl(sum, 0); - float mean = sum / (row_stride * 2); - - float variance = 0.f; - for (int i = 0; i < iterations; i++) { - vals_f[i].x -= mean; - vals_f[i].y -= mean; - variance += vals_f[i].x * vals_f[i].x; - variance += vals_f[i].y * vals_f[i].y; - } - - for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } - - if (g.thread_rank() == 0) shr[gid] = variance; - - b.sync(); - - if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); } - variance = g.shfl(variance, 0); - variance /= (row_stride * 2); - variance += epsilon; - - __half2 variance_h = __float2half2_rn(variance); - const __half2* gamma_cast = reinterpret_cast(gamma); - const __half2* beta_cast = reinterpret_cast(beta); - - if (training && threadIdx.x == 0) { - vars[row] = __float2half(variance); - means[row] = __float2half(mean); - } - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) { - __half2 vals_arr = __float22half2_rn(vals_f[i]); - vals_arr = vals_arr * h2rsqrt(variance_h); - vals_arr = - vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id]; - vals_cast[i * iteration_stride + id] = vals_arr; - } - if ((high_index) < row_stride) { - __half2 vals_arr = __float22half2_rn(vals_f[iterations]); - vals_arr = vals_arr * h2rsqrt(variance_h); - vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index]; - vals_cast[high_index] = vals_arr; - } -#endif -} - -template -void launch_bias_residual_layer_norm(T* vals, - const T* residual, - const T* gamma, - const T* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream, - bool preLayerNorm, - bool training, - T* vars, - T* means); - -template <> -void launch_bias_residual_layer_norm(float* vals, - const float* residual, - const float* gamma, - const float* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream, - bool preLayerNorm, - bool training, - float* vars, - float* means) -{ - int threads = THREADS; - - dim3 grid_dim(batch_size); - - if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 1; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 2; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim(threads); - - fused_bias_residual_layer_norm<<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim); -} - -template <> -void launch_bias_residual_layer_norm<__half>(__half* vals, - const __half* residual, - const __half* gamma, - const __half* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream, - bool preLayerNorm, - bool training, - __half* vars, - __half* means) -{ - int threads = 128; - - dim3 grid_dim(batch_size); - - if (hidden_dim > 8192 && hidden_dim <= 16384) - threads <<= 1; - else if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 2; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 3; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim(threads); - - fused_bias_residual_layer_norm<<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2); -} - -__global__ void fused_bias_residual_layer_norm(float* vals, - const float* residual, - const float* gamma, - const float* beta, - float epsilon, - bool preLayerNorm, - bool training, - float* vars, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int gid = id / 32; - - float vals_arr[NORM_REG]; - __shared__ float shr[MAX_WARP_NUM]; - - residual += (row * row_stride); - vals += (row * row_stride); - - float sum = 0.f; - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - vals_arr[i] = residual[i * iteration_stride + id]; - sum += vals_arr[i]; - } - if ((high_index) < row_stride) { - vals_arr[iterations] = residual[high_index]; - sum += vals_arr[iterations]; - iterations++; - } - - for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) shr[gid] = sum; - - b.sync(); - - if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()]; - -#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700 - b.sync(); -#endif - - for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); } - - sum = g.shfl(sum, 0); - float mean = sum / row_stride; - float variance = 0.f; - for (int i = 0; i < iterations; i++) { - vals_arr[i] -= mean; - variance += vals_arr[i] * vals_arr[i]; - } - - for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } - - if (g.thread_rank() == 0) shr[gid] = variance; - - b.sync(); - - if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); } - variance = g.shfl(variance, 0); - variance /= row_stride; - variance += epsilon; - if (training) - if (threadIdx.x == 0) vars[row] = variance; - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) { - vals_arr[i] = vals_arr[i] * rsqrtf(variance); - vals_arr[i] = - vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id]; - vals[i * iteration_stride + id] = vals_arr[i]; - } - if ((high_index) < row_stride) { - vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance); - vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index]; - vals[high_index] = vals_arr[iterations]; - } -} - -__global__ void fused_bias_residual_layer_norm(__half* vals, - const __half* residual, - const __half* gamma, - const __half* beta, - float epsilon, - bool preLayerNorm, - bool training, - __half* vars, - int row_stride) -{ -#if __CUDA_ARCH__ >= 700 - - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int gid = id >> 5; - - float2 vals_f[NORM_REG]; - __shared__ float shr[MAX_WARP_NUM]; - - __half2* vals_cast = reinterpret_cast<__half2*>(vals); - const __half2* residual_cast = reinterpret_cast(residual); - - residual_cast += (row * row_stride); - vals_cast += (row * row_stride); - - float sum = 0.f; - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]); - sum += vals_f[i].x; - sum += vals_f[i].y; - } - if ((high_index) < row_stride) { - vals_f[iterations] = __half22float2(residual_cast[high_index]); - sum += vals_f[iterations].x; - sum += vals_f[iterations].y; - iterations++; - } - - for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) shr[gid] = sum; - - b.sync(); - - if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); } - sum = g.shfl(sum, 0); - float mean = sum / (row_stride * 2); - - float variance = 0.f; - for (int i = 0; i < iterations; i++) { - vals_f[i].x -= mean; - vals_f[i].y -= mean; - variance += vals_f[i].x * vals_f[i].x; - variance += vals_f[i].y * vals_f[i].y; - } - - for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } - - if (g.thread_rank() == 0) shr[gid] = variance; - - b.sync(); - - if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); } - variance = g.shfl(variance, 0); - variance /= (row_stride * 2); - variance += epsilon; - - __half2 variance_h = __float2half2_rn(variance); - const __half2* gamma_cast = reinterpret_cast(gamma); - const __half2* beta_cast = reinterpret_cast(beta); - - if (training && threadIdx.x == 0) vars[row] = __float2half(variance); - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) { - __half2 vals_arr = __float22half2_rn(vals_f[i]); - vals_arr = vals_arr * h2rsqrt(variance_h); - vals_arr = - vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id]; - vals_cast[i * iteration_stride + id] = vals_arr; - } - if ((high_index) < row_stride) { - __half2 vals_arr = __float22half2_rn(vals_f[iterations]); - vals_arr = vals_arr * h2rsqrt(variance_h); - vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index]; - vals_cast[high_index] = vals_arr; - } -#endif -} - -template -void launch_bias_residual_layer_norm(T* vals, - const T* residual, - const T* gamma, - const T* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream, - bool preLayerNorm, - bool training, - T* vars); - -/* -To tune this launch the following restrictions must be met: - -For float: -row_stride == hidden_size -threads * iterations == row_stride -threads is in [32, 64, 128, 256, 512, 1024] - -For half: -row_stride == hidden_size / 2 -threads * iterations == row_stride -threads is in [32, 64, 128, 256, 512, 1024] - -*/ - -template <> -void launch_bias_residual_layer_norm(float* vals, - const float* residual, - const float* gamma, - const float* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream, - bool preLayerNorm, - bool training, - float* vars) -{ - int threads = THREADS; - - dim3 grid_dim(batch_size); - - // There are some limitations to call below functions, now just enumerate the situations. - - if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 1; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 2; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim(threads); - - fused_bias_residual_layer_norm<<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim); -} - -template <> -void launch_bias_residual_layer_norm<__half>(__half* vals, - const __half* residual, - const __half* gamma, - const __half* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream, - bool preLayerNorm, - bool training, - __half* vars) -{ - int threads = 128; - - dim3 grid_dim(batch_size); - - // There are some limitations to call below functions, now just enumerate the situations. - - if (hidden_dim > 8192 && hidden_dim <= 16384) - threads <<= 1; - else if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 2; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 3; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim(threads); - fused_bias_residual_layer_norm<<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2); -} - -/* Normalize Gamma & Betta gradients - * Compute gradients using either X_hat or - * normalize input (invertible). - * Combine transpose with gradients computation. - */ - -template -__global__ void LayerNormBackward1(const T* __restrict__ out_grad, - const T* __restrict__ vals_hat, - const T* __restrict__ gamma, - const T* __restrict__ betta, - T* __restrict__ gamma_grad, - T* __restrict__ betta_grad, - int rows, - int width, - bool invertible) -{ - __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; - __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int offset = threadIdx.y * width + idx; - int y_stride = width * TILE_DIM; - - float betta_reg = (invertible ? (float)betta[idx] : 0.0f); - float gamma_reg = (float)gamma[idx]; - - // Loop across matrix height - float betta_tmp = 0; - float gamma_tmp = 0; - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - float grad = (float)out_grad[offset]; - float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg - : (float)vals_hat[offset]); - betta_tmp += grad; - gamma_tmp += (val * grad); - - offset += y_stride; - } - - betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; - gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; - - __syncthreads(); - - // Sum the shared buffer. - float s1 = betta_buffer[threadIdx.y][threadIdx.x]; - float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < TILE_DIM; i <<= 1) { - s1 += g.shfl_down(s1, i); - s2 += g.shfl_down(s2, i); - } - - if (threadIdx.x == 0) { - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - betta_grad[pos] = s1; - gamma_grad[pos] = s2; - } -} - -/* Normalize Gamma & Betta gradients - * Compute gradients using the input to - * the normalize. - * Combine transpose with gradients computation. - */ - -template -__global__ void LayerNormBackward1(const T* __restrict__ out_grad, - const T* __restrict__ X_data, - const T* __restrict__ vars, - const T* __restrict__ means, - T* __restrict__ gamma_grad, - T* __restrict__ betta_grad, - int rows, - int width) -{ - __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; - __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int offset = threadIdx.y * width + idx; - int y_stride = width * TILE_DIM; - - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - // Loop across matrix height - - float betta_tmp = 0; - float gamma_tmp = 0; - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - float grad = (float)out_grad[offset]; - float val = (float)X_data[offset]; - val = (val - (float)means[r]) * rsqrtf((float)vars[r]); - betta_tmp += grad; - gamma_tmp += (val * grad); - - offset += y_stride; - } - - betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; - gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; - - __syncthreads(); - - // Sum the shared buffer. - float s1 = betta_buffer[threadIdx.y][threadIdx.x]; - float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < TILE_DIM; i <<= 1) { - s1 += g.shfl_down(s1, i); - s2 += g.shfl_down(s2, i); - } - - if (threadIdx.x == 0) { - betta_grad[pos] = s1; - gamma_grad[pos] = s2; - } -} -/* - -/* Backward Normalize (Input-Gradient) - * Using the means and variances from the input - * This type of backward is invertible! - * We do the backward using the X_hat (X - u) / sqrt(variance) or the output of Normalization. - */ - -__global__ void LayerNormBackward2(const float* out_grad, - const float* vals_hat, - const float* gamma, - const float* betta, - const float* vars, - float* inp_grad, - bool invertible, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int wid = id / WARP_SIZE; - int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; - __shared__ float partialSum[MAX_WARP_NUM]; - - out_grad += (row * row_stride); - vals_hat += (row * row_stride); - inp_grad += (row * row_stride); - - float vals_arr[NORM_REG]; - float vals_hat_arr[NORM_REG]; - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - float gamma_reg = gamma[i * iteration_stride + id]; - vals_arr[i] = out_grad[i * iteration_stride + id]; - vals_arr[i] *= gamma_reg; - vals_hat_arr[i] = - (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) / - gamma_reg - : vals_hat[i * iteration_stride + id]); - } - if ((high_index) < row_stride) { - float gamma_reg = gamma[high_index]; - vals_arr[iterations] = out_grad[high_index]; - vals_arr[iterations] *= gamma_reg; - vals_hat_arr[iterations] = - (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg - : vals_hat[high_index]); - iterations++; - } - - float var_reg = vars[row]; - - float sum = 0; - for (int i = 0; i < iterations; i++) { - sum += vals_hat_arr[i] * vals_arr[i] * - sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad - vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var) - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= row_stride; - - for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); } - - sum = 0; - for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - sum /= row_stride; - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum); - if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum); -} - -__global__ void LayerNormBackward2(const __half* out_grad, - const __half* vals_hat, - const __half* gamma, - const __half* betta, - const __half* vars, - __half* inp_grad, - bool invertible, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int wid = id / WARP_SIZE; - int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE; - __shared__ float partialSum[MAX_WARP_NUM]; - - __half2 vals_arr[NORM_REG]; - float2 vals_arr_f[NORM_REG]; - __half2 vals_hat_arr[NORM_REG]; - - __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); - const __half2* out_grad_h = reinterpret_cast(out_grad); - const __half2* vals_hat_h = reinterpret_cast(vals_hat); - - inp_grad_h += (row * row_stride); - out_grad_h += (row * row_stride); - vals_hat_h += (row * row_stride); - - const __half2* gamma_h = reinterpret_cast(gamma); - const __half2* betta_h = (invertible ? reinterpret_cast(betta) : nullptr); - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - __half2 gamma_reg = gamma_h[i * iteration_stride + id]; - vals_arr[i] = out_grad_h[i * iteration_stride + id]; - vals_arr[i] *= gamma_reg; - vals_hat_arr[i] = - (invertible - ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) / - gamma_reg - : vals_hat_h[i * iteration_stride + id]); - } - if ((high_index) < row_stride) { - __half2 gamma_reg = gamma_h[high_index]; - vals_arr[iterations] = out_grad_h[high_index]; - vals_arr[iterations] *= gamma_reg; - vals_hat_arr[iterations] = - (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg - : vals_hat_h[high_index]); - iterations++; - } - __half var_h = vars[row]; - __half2 var_reg = __halves2half2(var_h, var_h); - - float sum = 0.f; - for (int i = 0; i < iterations; i++) { - __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg)); - float2 result_f = __half22float2(result_h); - sum += result_f.x; - sum += result_f.y; - vals_arr[i] *= h2rsqrt(var_reg); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= (2 * row_stride); - __half2 sum_h = __float2half2_rn(sum); - - for (int i = 0; i < iterations; i++) { - __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg)); - vals_arr_f[i] = __half22float2(vals_arr[i]); - float2 temp_f = __half22float2(temp); - vals_arr_f[i].x += temp_f.x; - vals_arr_f[i].y += temp_f.y; - } - sum = 0.f; - - for (int i = 0; i < iterations; i++) { - sum += (vals_arr_f[i].x); - sum += (vals_arr_f[i].y); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= (2 * row_stride); - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) { - vals_arr_f[i].x -= sum; - vals_arr_f[i].y -= sum; - __half2 temp = __float22half2_rn(vals_arr_f[i]); - - inp_grad_h[i * iteration_stride + id] = temp; - } - if ((high_index) < row_stride) { - vals_arr_f[iterations].x -= sum; - vals_arr_f[iterations].y -= sum; - __half2 temp = __float22half2_rn(vals_arr_f[iterations]); - - inp_grad_h[high_index] = temp; - } -} - -template <> -void launch_layerNorm_backward(const float* out_grad, - const float* vals_hat, - const float* vars, - const float* gamma, - float* gamma_grad, - float* betta_grad, - float* inp_grad, - int batch, - int hidden_dim, - cudaStream_t stream[2], - bool invertible, - const float* betta) -{ - int threads = THREADS; - - dim3 grid_dim(hidden_dim / TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - - LayerNormBackward1<<>>( - out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); - - dim3 grid_dim2(batch); - - if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 1; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 2; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim2(threads); - - LayerNormBackward2<<>>( - out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim); -} - -template <> -void launch_layerNorm_backward<__half>(const __half* out_grad, - const __half* vals_hat, - const __half* vars, - const __half* gamma, - __half* gamma_grad, - __half* betta_grad, - __half* inp_grad, - int batch, - int hidden_dim, - cudaStream_t stream[2], - bool invertible, - const __half* betta) -{ - int threads = THREADS; - - dim3 grid_dim(hidden_dim / TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - - LayerNormBackward1<__half><<>>( - out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); - - dim3 grid_dim2(batch); - - if (hidden_dim > 8192 && hidden_dim <= 16384) - threads <<= 1; - else if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 2; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 3; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim2(threads / 2); - - LayerNormBackward2<<>>( - out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2); -} - -/* Backward Normalize (Input-Gradient) - * Using the means and variances from the input - * This type of backward is not invertible! - * We do the backward using the input (X) - */ - -__global__ void LayerNormBackward2(const float* out_grad, - const float* X_vals, - const float* gamma, - const float* vars, - const float* means, - float* inp_grad, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int wid = id / WARP_SIZE; - int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; - __shared__ float partialSum[MAX_WARP_NUM]; - - out_grad += (row * row_stride); - X_vals += (row * row_stride); - inp_grad += (row * row_stride); - - float vals_arr[NORM_REG]; - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - float gamma_reg = gamma[i * iteration_stride + id]; - vals_arr[i] = out_grad[i * iteration_stride + id]; - vals_arr[i] *= gamma_reg; - } - if ((high_index) < row_stride) { - float gamma_reg = gamma[high_index]; - vals_arr[iterations] = out_grad[high_index]; - vals_arr[iterations] *= gamma_reg; - iterations++; - } - - float var_reg = vars[row]; - float mean_reg = means[row]; - - float sum = 0; - float xu[NORM_REG]; - for (int i = 0; i < iterations; i++) { - xu[i] = (X_vals[i * iteration_stride + id] - mean_reg); - sum += vals_arr[i] * xu[i]; - vals_arr[i] *= rsqrtf(var_reg); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= row_stride; - - for (int i = 0; i < iterations; i++) { - vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg)); - } - - sum = 0; - for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - sum /= row_stride; - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum); - if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum); -} - -__global__ void LayerNormBackward2(const __half* out_grad, - const __half* X_vals, - const __half* gamma, - const __half* vars, - const __half* means, - __half* inp_grad, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int wid = id / WARP_SIZE; - int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE; - - __shared__ float partialSum[MAX_WARP_NUM]; - - __half2 vals_arr[NORM_REG]; - float2 vals_arr_f[NORM_REG]; - - __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); - const __half2* out_grad_h = reinterpret_cast(out_grad); - const __half2* vals_hat_h = reinterpret_cast(X_vals); - - inp_grad_h += (row * row_stride); - out_grad_h += (row * row_stride); - vals_hat_h += (row * row_stride); - - const __half2* gamma_h = reinterpret_cast(gamma); - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - __half2 gamma_reg = gamma_h[i * iteration_stride + id]; - vals_arr[i] = out_grad_h[i * iteration_stride + id]; - vals_arr[i] *= gamma_reg; // out_grad * gamma - } - if ((high_index) < row_stride) { - __half2 gamma_reg = gamma_h[high_index]; - vals_arr[iterations] = out_grad_h[high_index]; - vals_arr[iterations] *= gamma_reg; // out_grad * gamma - iterations++; - } - __half mean_h = means[row]; - __half var_h = vars[row]; - __half2 var_reg = __halves2half2(var_h, var_h); - __half2 mean_reg = __halves2half2(mean_h, mean_h); - __half2 xu[NORM_REG]; - - float sum = 0.f; - for (int i = 0; i < iterations; i++) { - xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg); - __half2 result_h = (xu[i] * vals_arr[i]); - float2 result_f = __half22float2(result_h); - sum += result_f.x; - sum += result_f.y; - vals_arr[i] *= h2rsqrt(var_reg); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= (2 * row_stride); - __half2 sum_h = __float2half2_rn(sum); - - for (int i = 0; i < iterations; i++) { - __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg)); - vals_arr_f[i] = __half22float2(vals_arr[i]); - float2 xu_grad_f = __half22float2(xu_grad); - vals_arr_f[i].x += xu_grad_f.x; - vals_arr_f[i].y += xu_grad_f.y; - } - - sum = 0.f; - for (int i = 0; i < iterations; i++) { - sum += (vals_arr_f[i].x); - sum += (vals_arr_f[i].y); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= (2 * row_stride); - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) { - vals_arr_f[i].x -= sum; - vals_arr_f[i].y -= sum; - __half2 temp = __float22half2_rn(vals_arr_f[i]); - inp_grad_h[i * iteration_stride + id] = temp; - } - if ((high_index) < row_stride) { - vals_arr_f[iterations].x -= sum; - vals_arr_f[iterations].y -= sum; - __half2 temp = __float22half2_rn(vals_arr_f[iterations]); - inp_grad_h[high_index] = temp; - } -} - -template <> -void launch_layerNorm_backward(const float* out_grad, - const float* X_data, - const float* vars, - const float* means, - const float* gamma, - float* gamma_grad, - float* betta_grad, - float* inp_grad, - int batch, - int hidden_dim, - cudaStream_t stream[2]) -{ - int threads = THREADS; - - dim3 grid_dim(hidden_dim / TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - - LayerNormBackward1<<>>( - out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); - - dim3 grid_dim2(batch); - - if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 1; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 2; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim2(threads); - LayerNormBackward2<<>>( - out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim); -} - -template <> -void launch_layerNorm_backward<__half>(const __half* out_grad, - const __half* X_data, - const __half* vars, - const __half* means, - const __half* gamma, - __half* gamma_grad, - __half* betta_grad, - __half* inp_grad, - int batch, - int hidden_dim, - cudaStream_t stream[2]) -{ - int threads = THREADS; - - dim3 grid_dim(hidden_dim / TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - - LayerNormBackward1<__half><<>>( - out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); - - dim3 grid_dim2(batch); - - if (hidden_dim > 8192 && hidden_dim <= 16384) - threads <<= 1; - else if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 2; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 3; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim2(threads / 2); - LayerNormBackward2<<>>( - out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2); -} - -template -__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1, - const T* __restrict__ out_grad2, - const T* __restrict__ vals_hat, - const T* __restrict__ gamma, - const T* __restrict__ betta, - T* __restrict__ gamma_grad, - T* __restrict__ betta_grad, - int rows, - int width, - bool invertible) -{ - __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; - __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int offset = threadIdx.y * width + idx; - int y_stride = width * TILE_DIM; - - float betta_reg = (invertible ? (float)betta[idx] : 0.0f); - float gamma_reg = (float)gamma[idx]; - - // Loop across matrix height - float betta_tmp = 0; - float gamma_tmp = 0; - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - float grad = (float)out_grad1[offset] + (float)out_grad2[offset]; - float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg - : (float)vals_hat[offset]); - betta_tmp += grad; - gamma_tmp += (val * grad); - - offset += y_stride; - } - - betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; - gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; - - __syncthreads(); - - // Sum the shared buffer. - float s1 = betta_buffer[threadIdx.y][threadIdx.x]; - float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < TILE_DIM; i <<= 1) { - s1 += g.shfl_down(s1, i); - s2 += g.shfl_down(s2, i); - } - - if (threadIdx.x == 0) { - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - betta_grad[pos] = s1; - gamma_grad[pos] = s2; - } -} - -template -__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1, - const T* __restrict__ out_grad2, - const T* __restrict__ X_data, - const T* __restrict__ vars, - const T* __restrict__ means, - T* __restrict__ gamma_grad, - T* __restrict__ betta_grad, - int rows, - int width) -{ - __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; - __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int offset = threadIdx.y * width + idx; - int y_stride = width * TILE_DIM; - - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - // Loop across matrix height - - float betta_tmp = 0; - float gamma_tmp = 0; - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - float grad = (float)out_grad1[offset] + (float)out_grad2[offset]; - float val = (float)X_data[offset]; - val = (val - (float)means[r]) * rsqrtf((float)vars[r]); - betta_tmp += grad; - gamma_tmp += (val * grad); - - offset += y_stride; - } - - betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; - gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; - - __syncthreads(); - - // Sum the shared buffer. - float s1 = betta_buffer[threadIdx.y][threadIdx.x]; - float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < TILE_DIM; i <<= 1) { - s1 += g.shfl_down(s1, i); - s2 += g.shfl_down(s2, i); - } - - if (threadIdx.x == 0) { - betta_grad[pos] = s1; - gamma_grad[pos] = s2; - } -} - -__global__ void LayerNormBackward2_fused_add(const float* out_grad1, - const float* out_grad2, - const float* vals_hat, - const float* gamma, - const float* betta, - const float* vars, - float* inp_grad, - bool invertible, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int wid = id / WARP_SIZE; - int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; - __shared__ float partialSum[MAX_WARP_NUM]; - - out_grad1 += (row * row_stride); - out_grad2 += (row * row_stride); - vals_hat += (row * row_stride); - inp_grad += (row * row_stride); - - float vals_arr[NORM_REG]; - float vals_hat_arr[NORM_REG]; - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - float gamma_reg = gamma[i * iteration_stride + id]; - vals_arr[i] = out_grad1[i * iteration_stride + id]; - vals_arr[i] *= gamma_reg; - vals_hat_arr[i] = - (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) / - gamma_reg - : vals_hat[i * iteration_stride + id]); - } - if ((high_index) < row_stride) { - float gamma_reg = gamma[high_index]; - vals_arr[iterations] = out_grad1[high_index]; - vals_arr[iterations] *= gamma_reg; - vals_hat_arr[iterations] = - (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg - : vals_hat[high_index]); - iterations++; - } - - float var_reg = vars[row]; - - float sum = 0; - for (int i = 0; i < iterations; i++) { - sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg); - vals_arr[i] *= rsqrtf(var_reg); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= row_stride; - - for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); } - - sum = 0; - for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - sum /= row_stride; - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) - inp_grad[i * iteration_stride + id] = - (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id]; - if ((high_index) < row_stride) - inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index]; -} - -__global__ void LayerNormBackward2_fused_add(const __half* out_grad1, - const __half* out_grad2, - const __half* vals_hat, - const __half* gamma, - const __half* betta, - const __half* vars, - __half* inp_grad, - bool invertible, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int wid = id / WARP_SIZE; - int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE; - __shared__ float partialSum[MAX_WARP_NUM]; - - __half2 vals_arr[NORM_REG]; - float2 vals_arr_f[NORM_REG]; - __half2 vals_hat_arr[NORM_REG]; - - // float2 result[iterations]; - - __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); - const __half2* out_grad_h1 = reinterpret_cast(out_grad1); - const __half2* out_grad_h2 = reinterpret_cast(out_grad2); - const __half2* vals_hat_h = reinterpret_cast(vals_hat); - - inp_grad_h += (row * row_stride); - out_grad_h1 += (row * row_stride); - out_grad_h2 += (row * row_stride); - vals_hat_h += (row * row_stride); - - const __half2* gamma_h = reinterpret_cast(gamma); - const __half2* betta_h = (invertible ? reinterpret_cast(betta) : nullptr); - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - __half2 gamma_reg = gamma_h[i * iteration_stride + id]; - vals_arr[i] = out_grad_h1[i * iteration_stride + id]; - vals_arr[i] *= gamma_reg; // out_grad * gamma - vals_hat_arr[i] = - (invertible - ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) / - gamma_reg - : vals_hat_h[i * iteration_stride + id]); - } - if ((high_index) < row_stride) { - __half2 gamma_reg = gamma_h[high_index]; - vals_arr[iterations] = out_grad_h1[high_index]; - vals_arr[iterations] *= gamma_reg; // out_grad * gamma - vals_hat_arr[iterations] = - (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg - : vals_hat_h[high_index]); - iterations++; - } - __half var_h = vars[row]; - __half2 var_reg = __halves2half2(var_h, var_h); - - float sum = 0.f; - for (int i = 0; i < iterations; i++) { - __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg)); - float2 result_f = __half22float2(result_h); - sum += result_f.x; - sum += result_f.y; - vals_arr[i] *= h2rsqrt(var_reg); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= (2 * row_stride); - __half2 sum_h = __float2half2_rn(sum); - - for (int i = 0; i < iterations; i++) { - __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg)); - vals_arr_f[i] = __half22float2(vals_arr[i]); - float2 temp_f = __half22float2(temp); - vals_arr_f[i].x += temp_f.x; - vals_arr_f[i].y += temp_f.y; - } - sum = 0.f; - for (int i = 0; i < iterations; i++) { - sum += (vals_arr_f[i].x); - sum += (vals_arr_f[i].y); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= (2 * row_stride); - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) { - vals_arr_f[i].x -= sum; - vals_arr_f[i].y -= sum; - __half2 temp = __float22half2_rn(vals_arr_f[i]); - - inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id]; - } - if ((high_index) < row_stride) { - vals_arr_f[iterations].x -= sum; - vals_arr_f[iterations].y -= sum; - __half2 temp = __float22half2_rn(vals_arr_f[iterations]); - - inp_grad_h[high_index] = temp + out_grad_h2[high_index]; - } -} - -template <> -void launch_layerNorm_backward_fused_add(const float* out_grad1, - const float* out_grad2, - const float* vals_hat, - const float* vars, - const float* gamma, - float* gamma_grad, - float* betta_grad, - float* inp_grad, - int batch, - int hidden_dim, - cudaStream_t stream[2], - bool invertible, - const float* betta) -{ - int threads = THREADS; - - dim3 grid_dim(hidden_dim / TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - LayerNormBackward1<<>>( - out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); - - dim3 grid_dim2(batch); - - if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 1; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 2; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim2(threads); - LayerNormBackward2_fused_add<<>>( - out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim); -} - -template <> -void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, - const __half* out_grad2, - const __half* vals_hat, - const __half* vars, - const __half* gamma, - __half* gamma_grad, - __half* betta_grad, - __half* inp_grad, - int batch, - int hidden_dim, - cudaStream_t stream[2], - bool invertible, - const __half* betta) -{ - int threads = THREADS; - - dim3 grid_dim(hidden_dim / TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - - LayerNormBackward1<__half><<>>( - out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); - - dim3 grid_dim2(batch); - - if (hidden_dim > 8192 && hidden_dim <= 16384) - threads <<= 1; - else if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 2; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 3; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim2(threads / 2); - LayerNormBackward2_fused_add<<>>( - out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2); -} - -/* Backward Normalize (Input-Gradient) - * Using the means and variances from the input - * This type of backward is not invertible! - * We do the backward using the input (X) - */ - -__global__ void LayerNormBackward2_fused_add(const float* out_grad1, - const float* out_grad2, - const float* X_vals, - const float* gamma, - const float* vars, - const float* means, - float* inp_grad, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int wid = id / WARP_SIZE; - int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; - __shared__ float partialSum[MAX_WARP_NUM]; - - float vals_arr[NORM_REG]; - float vals_hat_arr[NORM_REG]; - - out_grad1 += (row * row_stride); - out_grad2 += (row * row_stride); - X_vals += (row * row_stride); - inp_grad += (row * row_stride); - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - float gamma_reg = gamma[i * iteration_stride + id]; - vals_arr[i] = out_grad1[i * iteration_stride + id]; - vals_arr[i] *= gamma_reg; - vals_hat_arr[i] = X_vals[i * iteration_stride + id]; - } - if ((high_index) < row_stride) { - float gamma_reg = gamma[high_index]; - vals_arr[iterations] = out_grad1[high_index]; - vals_arr[iterations] *= gamma_reg; - vals_hat_arr[iterations] = X_vals[high_index]; - iterations++; - } - - float var_reg = vars[row]; - float mean_reg = means[row]; - - float sum = 0; - float xu[NORM_REG]; - for (int i = 0; i < iterations; i++) { - xu[i] = (vals_hat_arr[i] - mean_reg); - sum += vals_arr[i] * xu[i]; - vals_arr[i] *= rsqrtf(var_reg); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= row_stride; - - for (int i = 0; i < iterations; i++) { - vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg)); - } - - sum = 0; - for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - sum /= row_stride; - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) - inp_grad[i * iteration_stride + id] = - (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id]; - if ((high_index) < row_stride) - inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index]; -} - -__global__ void LayerNormBackward2_fused_add(const __half* out_grad1, - const __half* out_grad2, - const __half* X_vals, - const __half* gamma, - const __half* vars, - const __half* means, - __half* inp_grad, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int wid = id / WARP_SIZE; - int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE; - - __shared__ float partialSum[MAX_WARP_NUM]; - - __half2 vals_arr[NORM_REG]; - float2 vals_arr_f[NORM_REG]; - __half2 vals_hat_arr[NORM_REG]; - - __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); - const __half2* out_grad_h1 = reinterpret_cast(out_grad1); - const __half2* out_grad_h2 = reinterpret_cast(out_grad2); - const __half2* vals_hat_h = reinterpret_cast(X_vals); - - out_grad_h1 += (row * row_stride); - out_grad_h2 += (row * row_stride); - inp_grad_h += (row * row_stride); - vals_hat_h += (row * row_stride); - - const __half2* gamma_h = reinterpret_cast(gamma); - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - __half2 gamma_reg = gamma_h[i * iteration_stride + id]; - vals_arr[i] = out_grad_h1[i * iteration_stride + id]; - vals_arr[i] *= gamma_reg; // out_grad * gamma - vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id]; - } - if ((high_index) < row_stride) { - __half2 gamma_reg = gamma_h[high_index]; - vals_arr[iterations] = out_grad_h1[high_index]; - vals_arr[iterations] *= gamma_reg; // out_grad * gamma - vals_hat_arr[iterations] = vals_hat_h[high_index]; - iterations++; - } - - __half mean_h = means[row]; - __half var_h = vars[row]; - __half2 var_reg = __halves2half2(var_h, var_h); - __half2 mean_reg = __halves2half2(mean_h, mean_h); - __half2 xu[NORM_REG]; - - float sum = 0.f; - for (int i = 0; i < iterations; i++) { - xu[i] = (vals_hat_arr[i] - mean_reg); - __half2 result_h = (xu[i] * vals_arr[i]); - float2 result_f = __half22float2(result_h); - sum += result_f.x; - sum += result_f.y; - vals_arr[i] *= h2rsqrt(var_reg); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= (2 * row_stride); - __half2 sum_h = __float2half2_rn(sum); - - for (int i = 0; i < iterations; i++) { - __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg)); - vals_arr_f[i] = __half22float2(vals_arr[i]); - float2 xu_grad_f = __half22float2(xu_grad); - vals_arr_f[i].x += xu_grad_f.x; - vals_arr_f[i].y += xu_grad_f.y; - } - - sum = 0.f; - for (int i = 0; i < iterations; i++) { - sum += (vals_arr_f[i].x); - sum += (vals_arr_f[i].y); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= (2 * row_stride); - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) { - vals_arr_f[i].x -= sum; - vals_arr_f[i].y -= sum; - __half2 temp = __float22half2_rn(vals_arr_f[i]); - inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id]; - } - if ((high_index) < row_stride) { - vals_arr_f[iterations].x -= sum; - vals_arr_f[iterations].y -= sum; - __half2 temp = __float22half2_rn(vals_arr_f[iterations]); - inp_grad_h[high_index] = temp + out_grad_h2[high_index]; - } -} - -template <> -void launch_layerNorm_backward_fused_add(const float* out_grad1, - const float* out_grad2, - const float* X_data, - const float* vars, - const float* means, - const float* gamma, - float* gamma_grad, - float* betta_grad, - float* inp_grad, - int batch, - int hidden_dim, - cudaStream_t stream[2]) -{ - int threads = THREADS; - - dim3 grid_dim(hidden_dim / TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - - LayerNormBackward1<<>>( - out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); - - dim3 grid_dim2(batch); - - if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 1; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 2; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim2(threads); - LayerNormBackward2_fused_add<<>>( - out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim); -} - -template <> -void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, - const __half* out_grad2, - const __half* X_data, - const __half* vars, - const __half* means, - const __half* gamma, - __half* gamma_grad, - __half* betta_grad, - __half* inp_grad, - int batch, - int hidden_dim, - cudaStream_t stream[2]) -{ - int threads = THREADS; - - dim3 grid_dim(hidden_dim / TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - - LayerNormBackward1<__half><<>>( - out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); - - dim3 grid_dim2(batch); - - if (hidden_dim > 8192 && hidden_dim <= 16384) - threads <<= 1; - else if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 2; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 3; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim2(threads / 2); - LayerNormBackward2_fused_add<<>>( - out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2); -} +#include "custom_cuda_layers.h" + +namespace cg = cooperative_groups; + +/* +Fused bias add, residual (elementwise) add, and normalization layer. + +For FP16, this kernel does not promote to FP32 in order to utilize the 2x throughput for +__half2 instructions, and avoid the conversion overhead (1/8 of __hal2 arithmetic). + +For specific launch constraints, see the launch functions. +*/ + +#define NORM_REG (MAX_REGISTERS / 4) + +__global__ void fused_bias_residual_layer_norm(float* vals, + const float* residual, + const float* gamma, + const float* beta, + float epsilon, + bool preLayerNorm, + bool training, + float* vars, + float* means, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id / WARP_SIZE; + + float vals_arr[NORM_REG]; + __shared__ float shr[MAX_WARP_NUM]; + + residual += (row * row_stride); + vals += (row * row_stride); + + float sum = 0.f; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + vals_arr[i] = residual[i * iteration_stride + id]; + sum += vals_arr[i]; + } + if (high_index < row_stride) { + vals_arr[iterations] = residual[high_index]; + sum += vals_arr[iterations]; + iterations++; + } + + for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) shr[gid] = sum; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()]; + +#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700 + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); } + + sum = g.shfl(sum, 0); + float mean = sum / row_stride; + if (training) + if (threadIdx.x == 0) means[row] = mean; + float variance = 0.f; + for (int i = 0; i < iterations; i++) { + vals_arr[i] -= mean; + variance += vals_arr[i] * vals_arr[i]; + } + + for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } + + if (g.thread_rank() == 0) shr[gid] = variance; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); } + variance = g.shfl(variance, 0); + variance /= row_stride; + variance += epsilon; + if (training) + if (threadIdx.x == 0) vars[row] = variance; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr[i] = vals_arr[i] * rsqrtf(variance); + vals_arr[i] = + vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id]; + vals[i * iteration_stride + id] = vals_arr[i]; + } + if ((high_index) < row_stride) { + vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance); + vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index]; + vals[high_index] = vals_arr[iterations]; + } +} + +__global__ void fused_bias_residual_layer_norm(__half* vals, + const __half* residual, + const __half* gamma, + const __half* beta, + float epsilon, + bool preLayerNorm, + bool training, + __half* vars, + __half* means, + int row_stride) +{ +#if __CUDA_ARCH__ >= 700 + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id >> 5; + + float2 vals_f[NORM_REG]; + __shared__ float shr[MAX_WARP_NUM]; + + __half2* vals_cast = reinterpret_cast<__half2*>(vals); + const __half2* residual_cast = reinterpret_cast(residual); + + residual_cast += (row * row_stride); + vals_cast += (row * row_stride); + + float sum = 0.f; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]); + sum += vals_f[i].x; + sum += vals_f[i].y; + } + if ((high_index) < row_stride) { + vals_f[iterations] = __half22float2(residual_cast[high_index]); + sum += vals_f[iterations].x; + sum += vals_f[iterations].y; + iterations++; + } + + for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) shr[gid] = sum; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); } + sum = g.shfl(sum, 0); + float mean = sum / (row_stride * 2); + + float variance = 0.f; + for (int i = 0; i < iterations; i++) { + vals_f[i].x -= mean; + vals_f[i].y -= mean; + variance += vals_f[i].x * vals_f[i].x; + variance += vals_f[i].y * vals_f[i].y; + } + + for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } + + if (g.thread_rank() == 0) shr[gid] = variance; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); } + variance = g.shfl(variance, 0); + variance /= (row_stride * 2); + variance += epsilon; + + __half2 variance_h = __float2half2_rn(variance); + const __half2* gamma_cast = reinterpret_cast(gamma); + const __half2* beta_cast = reinterpret_cast(beta); + + if (training && threadIdx.x == 0) { + vars[row] = __float2half(variance); + means[row] = __float2half(mean); + } + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + __half2 vals_arr = __float22half2_rn(vals_f[i]); + vals_arr = vals_arr * h2rsqrt(variance_h); + vals_arr = + vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id]; + vals_cast[i * iteration_stride + id] = vals_arr; + } + if ((high_index) < row_stride) { + __half2 vals_arr = __float22half2_rn(vals_f[iterations]); + vals_arr = vals_arr * h2rsqrt(variance_h); + vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index]; + vals_cast[high_index] = vals_arr; + } +#endif +} + +template +void launch_bias_residual_layer_norm(T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + T* vars, + T* means); + +template <> +void launch_bias_residual_layer_norm(float* vals, + const float* residual, + const float* gamma, + const float* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + float* vars, + float* means) +{ + int threads = THREADS; + + dim3 grid_dim(batch_size); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim(threads); + + fused_bias_residual_layer_norm<<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim); +} + +template <> +void launch_bias_residual_layer_norm<__half>(__half* vals, + const __half* residual, + const __half* gamma, + const __half* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + __half* vars, + __half* means) +{ + int threads = 128; + + dim3 grid_dim(batch_size); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim(threads); + + fused_bias_residual_layer_norm<<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2); +} + +__global__ void fused_bias_residual_layer_norm(float* vals, + const float* residual, + const float* gamma, + const float* beta, + float epsilon, + bool preLayerNorm, + bool training, + float* vars, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id / 32; + + float vals_arr[NORM_REG]; + __shared__ float shr[MAX_WARP_NUM]; + + residual += (row * row_stride); + vals += (row * row_stride); + + float sum = 0.f; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + vals_arr[i] = residual[i * iteration_stride + id]; + sum += vals_arr[i]; + } + if ((high_index) < row_stride) { + vals_arr[iterations] = residual[high_index]; + sum += vals_arr[iterations]; + iterations++; + } + + for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) shr[gid] = sum; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()]; + +#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700 + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); } + + sum = g.shfl(sum, 0); + float mean = sum / row_stride; + float variance = 0.f; + for (int i = 0; i < iterations; i++) { + vals_arr[i] -= mean; + variance += vals_arr[i] * vals_arr[i]; + } + + for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } + + if (g.thread_rank() == 0) shr[gid] = variance; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); } + variance = g.shfl(variance, 0); + variance /= row_stride; + variance += epsilon; + if (training) + if (threadIdx.x == 0) vars[row] = variance; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr[i] = vals_arr[i] * rsqrtf(variance); + vals_arr[i] = + vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id]; + vals[i * iteration_stride + id] = vals_arr[i]; + } + if ((high_index) < row_stride) { + vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance); + vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index]; + vals[high_index] = vals_arr[iterations]; + } +} + +__global__ void fused_bias_residual_layer_norm(__half* vals, + const __half* residual, + const __half* gamma, + const __half* beta, + float epsilon, + bool preLayerNorm, + bool training, + __half* vars, + int row_stride) +{ +#if __CUDA_ARCH__ >= 700 + + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id >> 5; + + float2 vals_f[NORM_REG]; + __shared__ float shr[MAX_WARP_NUM]; + + __half2* vals_cast = reinterpret_cast<__half2*>(vals); + const __half2* residual_cast = reinterpret_cast(residual); + + residual_cast += (row * row_stride); + vals_cast += (row * row_stride); + + float sum = 0.f; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]); + sum += vals_f[i].x; + sum += vals_f[i].y; + } + if ((high_index) < row_stride) { + vals_f[iterations] = __half22float2(residual_cast[high_index]); + sum += vals_f[iterations].x; + sum += vals_f[iterations].y; + iterations++; + } + + for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) shr[gid] = sum; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); } + sum = g.shfl(sum, 0); + float mean = sum / (row_stride * 2); + + float variance = 0.f; + for (int i = 0; i < iterations; i++) { + vals_f[i].x -= mean; + vals_f[i].y -= mean; + variance += vals_f[i].x * vals_f[i].x; + variance += vals_f[i].y * vals_f[i].y; + } + + for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } + + if (g.thread_rank() == 0) shr[gid] = variance; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); } + variance = g.shfl(variance, 0); + variance /= (row_stride * 2); + variance += epsilon; + + __half2 variance_h = __float2half2_rn(variance); + const __half2* gamma_cast = reinterpret_cast(gamma); + const __half2* beta_cast = reinterpret_cast(beta); + + if (training && threadIdx.x == 0) vars[row] = __float2half(variance); + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + __half2 vals_arr = __float22half2_rn(vals_f[i]); + vals_arr = vals_arr * h2rsqrt(variance_h); + vals_arr = + vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id]; + vals_cast[i * iteration_stride + id] = vals_arr; + } + if ((high_index) < row_stride) { + __half2 vals_arr = __float22half2_rn(vals_f[iterations]); + vals_arr = vals_arr * h2rsqrt(variance_h); + vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index]; + vals_cast[high_index] = vals_arr; + } +#endif +} + +template +void launch_bias_residual_layer_norm(T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + T* vars); + +/* +To tune this launch the following restrictions must be met: + +For float: +row_stride == hidden_size +threads * iterations == row_stride +threads is in [32, 64, 128, 256, 512, 1024] + +For half: +row_stride == hidden_size / 2 +threads * iterations == row_stride +threads is in [32, 64, 128, 256, 512, 1024] + +*/ + +template <> +void launch_bias_residual_layer_norm(float* vals, + const float* residual, + const float* gamma, + const float* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + float* vars) +{ + int threads = THREADS; + + dim3 grid_dim(batch_size); + + // There are some limitations to call below functions, now just enumerate the situations. + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim(threads); + + fused_bias_residual_layer_norm<<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim); +} + +template <> +void launch_bias_residual_layer_norm<__half>(__half* vals, + const __half* residual, + const __half* gamma, + const __half* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + __half* vars) +{ + int threads = 128; + + dim3 grid_dim(batch_size); + + // There are some limitations to call below functions, now just enumerate the situations. + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim(threads); + fused_bias_residual_layer_norm<<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2); +} + +/* Normalize Gamma & Betta gradients + * Compute gradients using either X_hat or + * normalize input (invertible). + * Combine transpose with gradients computation. + */ + +template +__global__ void LayerNormBackward1(const T* __restrict__ out_grad, + const T* __restrict__ vals_hat, + const T* __restrict__ gamma, + const T* __restrict__ betta, + T* __restrict__ gamma_grad, + T* __restrict__ betta_grad, + int rows, + int width, + bool invertible) +{ + __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + float betta_reg = (invertible ? (float)betta[idx] : 0.0f); + float gamma_reg = (float)gamma[idx]; + + // Loop across matrix height + float betta_tmp = 0; + float gamma_tmp = 0; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + float grad = (float)out_grad[offset]; + float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg + : (float)vals_hat[offset]); + betta_tmp += grad; + gamma_tmp += (val * grad); + + offset += y_stride; + } + + betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; + gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; + + __syncthreads(); + + // Sum the shared buffer. + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); + } + + if (threadIdx.x == 0) { + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} + +/* Normalize Gamma & Betta gradients + * Compute gradients using the input to + * the normalize. + * Combine transpose with gradients computation. + */ + +template +__global__ void LayerNormBackward1(const T* __restrict__ out_grad, + const T* __restrict__ X_data, + const T* __restrict__ vars, + const T* __restrict__ means, + T* __restrict__ gamma_grad, + T* __restrict__ betta_grad, + int rows, + int width) +{ + __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + // Loop across matrix height + + float betta_tmp = 0; + float gamma_tmp = 0; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + float grad = (float)out_grad[offset]; + float val = (float)X_data[offset]; + val = (val - (float)means[r]) * rsqrtf((float)vars[r]); + betta_tmp += grad; + gamma_tmp += (val * grad); + + offset += y_stride; + } + + betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; + gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; + + __syncthreads(); + + // Sum the shared buffer. + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); + } + + if (threadIdx.x == 0) { + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} +/* + +/* Backward Normalize (Input-Gradient) + * Using the means and variances from the input + * This type of backward is invertible! + * We do the backward using the X_hat (X - u) / sqrt(variance) or the output of Normalization. + */ + +__global__ void LayerNormBackward2(const float* out_grad, + const float* vals_hat, + const float* gamma, + const float* betta, + const float* vars, + float* inp_grad, + bool invertible, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; + __shared__ float partialSum[MAX_WARP_NUM]; + + out_grad += (row * row_stride); + vals_hat += (row * row_stride); + inp_grad += (row * row_stride); + + float vals_arr[NORM_REG]; + float vals_hat_arr[NORM_REG]; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + float gamma_reg = gamma[i * iteration_stride + id]; + vals_arr[i] = out_grad[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; + vals_hat_arr[i] = + (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) / + gamma_reg + : vals_hat[i * iteration_stride + id]); + } + if ((high_index) < row_stride) { + float gamma_reg = gamma[high_index]; + vals_arr[iterations] = out_grad[high_index]; + vals_arr[iterations] *= gamma_reg; + vals_hat_arr[iterations] = + (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg + : vals_hat[high_index]); + iterations++; + } + + float var_reg = vars[row]; + + float sum = 0; + for (int i = 0; i < iterations; i++) { + sum += vals_hat_arr[i] * vals_arr[i] * + sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad + vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var) + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= row_stride; + + for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); } + + sum = 0; + for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= row_stride; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum); + if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum); +} + +__global__ void LayerNormBackward2(const __half* out_grad, + const __half* vals_hat, + const __half* gamma, + const __half* betta, + const __half* vars, + __half* inp_grad, + bool invertible, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE; + __shared__ float partialSum[MAX_WARP_NUM]; + + __half2 vals_arr[NORM_REG]; + float2 vals_arr_f[NORM_REG]; + __half2 vals_hat_arr[NORM_REG]; + + __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); + const __half2* out_grad_h = reinterpret_cast(out_grad); + const __half2* vals_hat_h = reinterpret_cast(vals_hat); + + inp_grad_h += (row * row_stride); + out_grad_h += (row * row_stride); + vals_hat_h += (row * row_stride); + + const __half2* gamma_h = reinterpret_cast(gamma); + const __half2* betta_h = (invertible ? reinterpret_cast(betta) : nullptr); + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + __half2 gamma_reg = gamma_h[i * iteration_stride + id]; + vals_arr[i] = out_grad_h[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; + vals_hat_arr[i] = + (invertible + ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) / + gamma_reg + : vals_hat_h[i * iteration_stride + id]); + } + if ((high_index) < row_stride) { + __half2 gamma_reg = gamma_h[high_index]; + vals_arr[iterations] = out_grad_h[high_index]; + vals_arr[iterations] *= gamma_reg; + vals_hat_arr[iterations] = + (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg + : vals_hat_h[high_index]); + iterations++; + } + __half var_h = vars[row]; + __half2 var_reg = __halves2half2(var_h, var_h); + + float sum = 0.f; + for (int i = 0; i < iterations; i++) { + __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg)); + float2 result_f = __half22float2(result_h); + sum += result_f.x; + sum += result_f.y; + vals_arr[i] *= h2rsqrt(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + __half2 sum_h = __float2half2_rn(sum); + + for (int i = 0; i < iterations; i++) { + __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg)); + vals_arr_f[i] = __half22float2(vals_arr[i]); + float2 temp_f = __half22float2(temp); + vals_arr_f[i].x += temp_f.x; + vals_arr_f[i].y += temp_f.y; + } + sum = 0.f; + + for (int i = 0; i < iterations; i++) { + sum += (vals_arr_f[i].x); + sum += (vals_arr_f[i].y); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr_f[i].x -= sum; + vals_arr_f[i].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[i]); + + inp_grad_h[i * iteration_stride + id] = temp; + } + if ((high_index) < row_stride) { + vals_arr_f[iterations].x -= sum; + vals_arr_f[iterations].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[iterations]); + + inp_grad_h[high_index] = temp; + } +} + +template <> +void launch_layerNorm_backward(const float* out_grad, + const float* vals_hat, + const float* vars, + const float* gamma, + float* gamma_grad, + float* betta_grad, + float* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2], + bool invertible, + const float* betta) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<<>>( + out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); + + dim3 grid_dim2(batch); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads); + + LayerNormBackward2<<>>( + out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim); +} + +template <> +void launch_layerNorm_backward<__half>(const __half* out_grad, + const __half* vals_hat, + const __half* vars, + const __half* gamma, + __half* gamma_grad, + __half* betta_grad, + __half* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2], + bool invertible, + const __half* betta) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<__half><<>>( + out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); + + dim3 grid_dim2(batch); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads / 2); + + LayerNormBackward2<<>>( + out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2); +} + +/* Backward Normalize (Input-Gradient) + * Using the means and variances from the input + * This type of backward is not invertible! + * We do the backward using the input (X) + */ + +__global__ void LayerNormBackward2(const float* out_grad, + const float* X_vals, + const float* gamma, + const float* vars, + const float* means, + float* inp_grad, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; + __shared__ float partialSum[MAX_WARP_NUM]; + + out_grad += (row * row_stride); + X_vals += (row * row_stride); + inp_grad += (row * row_stride); + + float vals_arr[NORM_REG]; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + float gamma_reg = gamma[i * iteration_stride + id]; + vals_arr[i] = out_grad[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; + } + if ((high_index) < row_stride) { + float gamma_reg = gamma[high_index]; + vals_arr[iterations] = out_grad[high_index]; + vals_arr[iterations] *= gamma_reg; + iterations++; + } + + float var_reg = vars[row]; + float mean_reg = means[row]; + + float sum = 0; + float xu[NORM_REG]; + for (int i = 0; i < iterations; i++) { + xu[i] = (X_vals[i * iteration_stride + id] - mean_reg); + sum += vals_arr[i] * xu[i]; + vals_arr[i] *= rsqrtf(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= row_stride; + + for (int i = 0; i < iterations; i++) { + vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg)); + } + + sum = 0; + for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= row_stride; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum); + if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum); +} + +__global__ void LayerNormBackward2(const __half* out_grad, + const __half* X_vals, + const __half* gamma, + const __half* vars, + const __half* means, + __half* inp_grad, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE; + + __shared__ float partialSum[MAX_WARP_NUM]; + + __half2 vals_arr[NORM_REG]; + float2 vals_arr_f[NORM_REG]; + + __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); + const __half2* out_grad_h = reinterpret_cast(out_grad); + const __half2* vals_hat_h = reinterpret_cast(X_vals); + + inp_grad_h += (row * row_stride); + out_grad_h += (row * row_stride); + vals_hat_h += (row * row_stride); + + const __half2* gamma_h = reinterpret_cast(gamma); + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + __half2 gamma_reg = gamma_h[i * iteration_stride + id]; + vals_arr[i] = out_grad_h[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; // out_grad * gamma + } + if ((high_index) < row_stride) { + __half2 gamma_reg = gamma_h[high_index]; + vals_arr[iterations] = out_grad_h[high_index]; + vals_arr[iterations] *= gamma_reg; // out_grad * gamma + iterations++; + } + __half mean_h = means[row]; + __half var_h = vars[row]; + __half2 var_reg = __halves2half2(var_h, var_h); + __half2 mean_reg = __halves2half2(mean_h, mean_h); + __half2 xu[NORM_REG]; + + float sum = 0.f; + for (int i = 0; i < iterations; i++) { + xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg); + __half2 result_h = (xu[i] * vals_arr[i]); + float2 result_f = __half22float2(result_h); + sum += result_f.x; + sum += result_f.y; + vals_arr[i] *= h2rsqrt(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + __half2 sum_h = __float2half2_rn(sum); + + for (int i = 0; i < iterations; i++) { + __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg)); + vals_arr_f[i] = __half22float2(vals_arr[i]); + float2 xu_grad_f = __half22float2(xu_grad); + vals_arr_f[i].x += xu_grad_f.x; + vals_arr_f[i].y += xu_grad_f.y; + } + + sum = 0.f; + for (int i = 0; i < iterations; i++) { + sum += (vals_arr_f[i].x); + sum += (vals_arr_f[i].y); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr_f[i].x -= sum; + vals_arr_f[i].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[i]); + inp_grad_h[i * iteration_stride + id] = temp; + } + if ((high_index) < row_stride) { + vals_arr_f[iterations].x -= sum; + vals_arr_f[iterations].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[iterations]); + inp_grad_h[high_index] = temp; + } +} + +template <> +void launch_layerNorm_backward(const float* out_grad, + const float* X_data, + const float* vars, + const float* means, + const float* gamma, + float* gamma_grad, + float* betta_grad, + float* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2]) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<<>>( + out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); + + dim3 grid_dim2(batch); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads); + LayerNormBackward2<<>>( + out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim); +} + +template <> +void launch_layerNorm_backward<__half>(const __half* out_grad, + const __half* X_data, + const __half* vars, + const __half* means, + const __half* gamma, + __half* gamma_grad, + __half* betta_grad, + __half* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2]) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<__half><<>>( + out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); + + dim3 grid_dim2(batch); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads / 2); + LayerNormBackward2<<>>( + out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2); +} + +template +__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1, + const T* __restrict__ out_grad2, + const T* __restrict__ vals_hat, + const T* __restrict__ gamma, + const T* __restrict__ betta, + T* __restrict__ gamma_grad, + T* __restrict__ betta_grad, + int rows, + int width, + bool invertible) +{ + __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + float betta_reg = (invertible ? (float)betta[idx] : 0.0f); + float gamma_reg = (float)gamma[idx]; + + // Loop across matrix height + float betta_tmp = 0; + float gamma_tmp = 0; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + float grad = (float)out_grad1[offset] + (float)out_grad2[offset]; + float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg + : (float)vals_hat[offset]); + betta_tmp += grad; + gamma_tmp += (val * grad); + + offset += y_stride; + } + + betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; + gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; + + __syncthreads(); + + // Sum the shared buffer. + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); + } + + if (threadIdx.x == 0) { + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} + +template +__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1, + const T* __restrict__ out_grad2, + const T* __restrict__ X_data, + const T* __restrict__ vars, + const T* __restrict__ means, + T* __restrict__ gamma_grad, + T* __restrict__ betta_grad, + int rows, + int width) +{ + __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + // Loop across matrix height + + float betta_tmp = 0; + float gamma_tmp = 0; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + float grad = (float)out_grad1[offset] + (float)out_grad2[offset]; + float val = (float)X_data[offset]; + val = (val - (float)means[r]) * rsqrtf((float)vars[r]); + betta_tmp += grad; + gamma_tmp += (val * grad); + + offset += y_stride; + } + + betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; + gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; + + __syncthreads(); + + // Sum the shared buffer. + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); + } + + if (threadIdx.x == 0) { + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} + +__global__ void LayerNormBackward2_fused_add(const float* out_grad1, + const float* out_grad2, + const float* vals_hat, + const float* gamma, + const float* betta, + const float* vars, + float* inp_grad, + bool invertible, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; + __shared__ float partialSum[MAX_WARP_NUM]; + + out_grad1 += (row * row_stride); + out_grad2 += (row * row_stride); + vals_hat += (row * row_stride); + inp_grad += (row * row_stride); + + float vals_arr[NORM_REG]; + float vals_hat_arr[NORM_REG]; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + float gamma_reg = gamma[i * iteration_stride + id]; + vals_arr[i] = out_grad1[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; + vals_hat_arr[i] = + (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) / + gamma_reg + : vals_hat[i * iteration_stride + id]); + } + if ((high_index) < row_stride) { + float gamma_reg = gamma[high_index]; + vals_arr[iterations] = out_grad1[high_index]; + vals_arr[iterations] *= gamma_reg; + vals_hat_arr[iterations] = + (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg + : vals_hat[high_index]); + iterations++; + } + + float var_reg = vars[row]; + + float sum = 0; + for (int i = 0; i < iterations; i++) { + sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg); + vals_arr[i] *= rsqrtf(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= row_stride; + + for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); } + + sum = 0; + for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= row_stride; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) + inp_grad[i * iteration_stride + id] = + (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id]; + if ((high_index) < row_stride) + inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index]; +} + +__global__ void LayerNormBackward2_fused_add(const __half* out_grad1, + const __half* out_grad2, + const __half* vals_hat, + const __half* gamma, + const __half* betta, + const __half* vars, + __half* inp_grad, + bool invertible, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE; + __shared__ float partialSum[MAX_WARP_NUM]; + + __half2 vals_arr[NORM_REG]; + float2 vals_arr_f[NORM_REG]; + __half2 vals_hat_arr[NORM_REG]; + + // float2 result[iterations]; + + __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); + const __half2* out_grad_h1 = reinterpret_cast(out_grad1); + const __half2* out_grad_h2 = reinterpret_cast(out_grad2); + const __half2* vals_hat_h = reinterpret_cast(vals_hat); + + inp_grad_h += (row * row_stride); + out_grad_h1 += (row * row_stride); + out_grad_h2 += (row * row_stride); + vals_hat_h += (row * row_stride); + + const __half2* gamma_h = reinterpret_cast(gamma); + const __half2* betta_h = (invertible ? reinterpret_cast(betta) : nullptr); + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + __half2 gamma_reg = gamma_h[i * iteration_stride + id]; + vals_arr[i] = out_grad_h1[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; // out_grad * gamma + vals_hat_arr[i] = + (invertible + ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) / + gamma_reg + : vals_hat_h[i * iteration_stride + id]); + } + if ((high_index) < row_stride) { + __half2 gamma_reg = gamma_h[high_index]; + vals_arr[iterations] = out_grad_h1[high_index]; + vals_arr[iterations] *= gamma_reg; // out_grad * gamma + vals_hat_arr[iterations] = + (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg + : vals_hat_h[high_index]); + iterations++; + } + __half var_h = vars[row]; + __half2 var_reg = __halves2half2(var_h, var_h); + + float sum = 0.f; + for (int i = 0; i < iterations; i++) { + __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg)); + float2 result_f = __half22float2(result_h); + sum += result_f.x; + sum += result_f.y; + vals_arr[i] *= h2rsqrt(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + __half2 sum_h = __float2half2_rn(sum); + + for (int i = 0; i < iterations; i++) { + __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg)); + vals_arr_f[i] = __half22float2(vals_arr[i]); + float2 temp_f = __half22float2(temp); + vals_arr_f[i].x += temp_f.x; + vals_arr_f[i].y += temp_f.y; + } + sum = 0.f; + for (int i = 0; i < iterations; i++) { + sum += (vals_arr_f[i].x); + sum += (vals_arr_f[i].y); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr_f[i].x -= sum; + vals_arr_f[i].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[i]); + + inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id]; + } + if ((high_index) < row_stride) { + vals_arr_f[iterations].x -= sum; + vals_arr_f[iterations].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[iterations]); + + inp_grad_h[high_index] = temp + out_grad_h2[high_index]; + } +} + +template <> +void launch_layerNorm_backward_fused_add(const float* out_grad1, + const float* out_grad2, + const float* vals_hat, + const float* vars, + const float* gamma, + float* gamma_grad, + float* betta_grad, + float* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2], + bool invertible, + const float* betta) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + LayerNormBackward1<<>>( + out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); + + dim3 grid_dim2(batch); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads); + LayerNormBackward2_fused_add<<>>( + out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim); +} + +template <> +void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, + const __half* out_grad2, + const __half* vals_hat, + const __half* vars, + const __half* gamma, + __half* gamma_grad, + __half* betta_grad, + __half* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2], + bool invertible, + const __half* betta) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<__half><<>>( + out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); + + dim3 grid_dim2(batch); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads / 2); + LayerNormBackward2_fused_add<<>>( + out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2); +} + +/* Backward Normalize (Input-Gradient) + * Using the means and variances from the input + * This type of backward is not invertible! + * We do the backward using the input (X) + */ + +__global__ void LayerNormBackward2_fused_add(const float* out_grad1, + const float* out_grad2, + const float* X_vals, + const float* gamma, + const float* vars, + const float* means, + float* inp_grad, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; + __shared__ float partialSum[MAX_WARP_NUM]; + + float vals_arr[NORM_REG]; + float vals_hat_arr[NORM_REG]; + + out_grad1 += (row * row_stride); + out_grad2 += (row * row_stride); + X_vals += (row * row_stride); + inp_grad += (row * row_stride); + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + float gamma_reg = gamma[i * iteration_stride + id]; + vals_arr[i] = out_grad1[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; + vals_hat_arr[i] = X_vals[i * iteration_stride + id]; + } + if ((high_index) < row_stride) { + float gamma_reg = gamma[high_index]; + vals_arr[iterations] = out_grad1[high_index]; + vals_arr[iterations] *= gamma_reg; + vals_hat_arr[iterations] = X_vals[high_index]; + iterations++; + } + + float var_reg = vars[row]; + float mean_reg = means[row]; + + float sum = 0; + float xu[NORM_REG]; + for (int i = 0; i < iterations; i++) { + xu[i] = (vals_hat_arr[i] - mean_reg); + sum += vals_arr[i] * xu[i]; + vals_arr[i] *= rsqrtf(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= row_stride; + + for (int i = 0; i < iterations; i++) { + vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg)); + } + + sum = 0; + for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= row_stride; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) + inp_grad[i * iteration_stride + id] = + (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id]; + if ((high_index) < row_stride) + inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index]; +} + +__global__ void LayerNormBackward2_fused_add(const __half* out_grad1, + const __half* out_grad2, + const __half* X_vals, + const __half* gamma, + const __half* vars, + const __half* means, + __half* inp_grad, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE; + + __shared__ float partialSum[MAX_WARP_NUM]; + + __half2 vals_arr[NORM_REG]; + float2 vals_arr_f[NORM_REG]; + __half2 vals_hat_arr[NORM_REG]; + + __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); + const __half2* out_grad_h1 = reinterpret_cast(out_grad1); + const __half2* out_grad_h2 = reinterpret_cast(out_grad2); + const __half2* vals_hat_h = reinterpret_cast(X_vals); + + out_grad_h1 += (row * row_stride); + out_grad_h2 += (row * row_stride); + inp_grad_h += (row * row_stride); + vals_hat_h += (row * row_stride); + + const __half2* gamma_h = reinterpret_cast(gamma); + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + __half2 gamma_reg = gamma_h[i * iteration_stride + id]; + vals_arr[i] = out_grad_h1[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; // out_grad * gamma + vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id]; + } + if ((high_index) < row_stride) { + __half2 gamma_reg = gamma_h[high_index]; + vals_arr[iterations] = out_grad_h1[high_index]; + vals_arr[iterations] *= gamma_reg; // out_grad * gamma + vals_hat_arr[iterations] = vals_hat_h[high_index]; + iterations++; + } + + __half mean_h = means[row]; + __half var_h = vars[row]; + __half2 var_reg = __halves2half2(var_h, var_h); + __half2 mean_reg = __halves2half2(mean_h, mean_h); + __half2 xu[NORM_REG]; + + float sum = 0.f; + for (int i = 0; i < iterations; i++) { + xu[i] = (vals_hat_arr[i] - mean_reg); + __half2 result_h = (xu[i] * vals_arr[i]); + float2 result_f = __half22float2(result_h); + sum += result_f.x; + sum += result_f.y; + vals_arr[i] *= h2rsqrt(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + __half2 sum_h = __float2half2_rn(sum); + + for (int i = 0; i < iterations; i++) { + __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg)); + vals_arr_f[i] = __half22float2(vals_arr[i]); + float2 xu_grad_f = __half22float2(xu_grad); + vals_arr_f[i].x += xu_grad_f.x; + vals_arr_f[i].y += xu_grad_f.y; + } + + sum = 0.f; + for (int i = 0; i < iterations; i++) { + sum += (vals_arr_f[i].x); + sum += (vals_arr_f[i].y); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr_f[i].x -= sum; + vals_arr_f[i].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[i]); + inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id]; + } + if ((high_index) < row_stride) { + vals_arr_f[iterations].x -= sum; + vals_arr_f[iterations].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[iterations]); + inp_grad_h[high_index] = temp + out_grad_h2[high_index]; + } +} + +template <> +void launch_layerNorm_backward_fused_add(const float* out_grad1, + const float* out_grad2, + const float* X_data, + const float* vars, + const float* means, + const float* gamma, + float* gamma_grad, + float* betta_grad, + float* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2]) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<<>>( + out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); + + dim3 grid_dim2(batch); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads); + LayerNormBackward2_fused_add<<>>( + out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim); +} + +template <> +void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, + const __half* out_grad2, + const __half* X_data, + const __half* vars, + const __half* means, + const __half* gamma, + __half* gamma_grad, + __half* betta_grad, + __half* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2]) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<__half><<>>( + out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); + + dim3 grid_dim2(batch); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads / 2); + LayerNormBackward2_fused_add<<>>( + out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2); +} diff --git a/csrc/transformer/softmax_kernels.cu b/csrc/transformer/softmax_kernels.cu index 4c0d987e6..d0f3dca5c 100644 --- a/csrc/transformer/softmax_kernels.cu +++ b/csrc/transformer/softmax_kernels.cu @@ -1,595 +1,595 @@ -#include -#include "custom_cuda_layers.h" -#include "general_kernels.h" - -namespace cg = cooperative_groups; - -dim3 get_attn_softmax_grid(int batch_size, int heads, int sequence_length, int threads) -{ - int seq_length4 = sequence_length / 4; - int block_compute_size = - (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1); - // Note that the Y and Z dimensions are limited to 65535, while X is basically unlimited: - // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications - // The batch size is typically relatively small, while the sequence length could potentially be - // arbitrarily large. We therefore place the batch size second to avoid hitting the Y limit. - unsigned x = heads * sequence_length / block_compute_size; - unsigned y = batch_size; - return {x, y}; -} - -// Fused attention + softmax -template -__global__ void attn_softmax(float* vals, - const float* attn_mask, - int heads, - int seq_length, - int iterations) -{ - __shared__ float partialSum[MAX_WARP_NUM]; - - int warp_num = blockDim.x >> 5; - - int iteration_stride = blockDim.x; - int block_width = blockStride * seq_length; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int batch = blockIdx.y; - int row = blockIdx.x; - int max_threads_in_sequence = std::max(seq_length, tbSeq); - int seq_lane = threadIdx.x % max_threads_in_sequence; - - int data_offset = batch * (gridDim.x * block_width) + row * block_width + - (threadIdx.x / max_threads_in_sequence) * seq_length; - int mask_offset = batch * seq_length; - - int wid = threadIdx.x >> 5; - int lane = threadIdx.x & 0x1f; - - float4* val_cast = reinterpret_cast(vals); - const float4* attn_mask_cast = reinterpret_cast(attn_mask); - - float4 data[MAX_THREAD_ITERATIONS]; - - float max_val = minus_infinity; - - for (int i = 0; i < iterations; i++) { - int data_id = i * iteration_stride + seq_lane; - if (data_id < seq_length) { - float4 mask = attn_mask_cast[mask_offset + data_id]; - data[i] = val_cast[data_offset + data_id]; - - data[i].x += mask.x; - data[i].y += mask.y; - data[i].z += mask.z; - data[i].w += mask.w; - - max_val = (data[i].x > max_val ? data[i].x : max_val); - max_val = (data[i].y > max_val ? data[i].y : max_val); - max_val = (data[i].z > max_val ? data[i].z : max_val); - max_val = (data[i].w > max_val ? data[i].w : max_val); - } else { - data[i].x = minus_infinity; - data[i].y = minus_infinity; - data[i].z = minus_infinity; - data[i].w = minus_infinity; - } - } - - for (int i = 1; i < tbSize; i *= 2) { - auto temp = g.shfl_xor(max_val, i); - max_val = (temp > max_val ? temp : max_val); - } - - if (seq_length > tbSize) { - if (lane == 0) partialSum[wid] = max_val; - b.sync(); - - if (lane < warp_num) max_val = partialSum[lane]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - int iters = warp_num; - if (seq_length < iteration_stride) - iters = warp_num / (iteration_stride / max_threads_in_sequence); - - for (int i = 1; i < iters; i *= 2) { - auto temp = g.shfl_xor(max_val, i); - max_val = (temp > max_val ? temp : max_val); - } - - max_val = g.shfl(max_val, threadIdx.x / tbSize); - } - - float sum = 0; - for (int i = 0; i < iterations; i++) { - data[i].x = __expf(data[i].x - max_val); - data[i].y = __expf(data[i].y - max_val); - data[i].z = __expf(data[i].z - max_val); - data[i].w = __expf(data[i].w - max_val); - - sum += (data[i].x + data[i].y + data[i].z + data[i].w); - } - - for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); } - - if (seq_length > tbSize) { - if (lane == 0) partialSum[wid] = sum; - b.sync(); - - if (lane < warp_num) sum = partialSum[lane]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - int iters = warp_num; - if (seq_length < iteration_stride) - iters = warp_num / (iteration_stride / max_threads_in_sequence); - - for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); } - - sum = g.shfl(sum, threadIdx.x / tbSize); - } - - sum += 1e-6; - - for (int i = 0; i < iterations; i++) { - data[i].x /= sum; - data[i].y /= sum; - data[i].z /= sum; - data[i].w /= sum; - - int data_id = i * iteration_stride + seq_lane; - if (data_id < seq_length) val_cast[data_offset + data_id] = data[i]; - } -} - -template -__global__ void attn_softmax(__half* vals, - const __half* attn_mask, - int heads, - int seq_length, - int iterations) -{ -#if __CUDA_ARCH__ >= 700 - __shared__ float partialSum[MAX_WARP_NUM]; - - int warp_num = blockDim.x >> 5; - - int iteration_stride = blockDim.x; - int block_width = blockStride * seq_length; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int batch = blockIdx.y; - int row = blockIdx.x; - int max_threads_in_sequence = std::max(seq_length, tbSeq); - int seq_lane = threadIdx.x % max_threads_in_sequence; - - int data_offset = batch * (gridDim.x * block_width) + row * block_width + - (threadIdx.x / max_threads_in_sequence) * seq_length; - int mask_offset = batch * seq_length; - - int wid = threadIdx.x >> 5; - int lane = threadIdx.x & 0x1f; - - float2* val_cast = reinterpret_cast(vals); - const float2* attn_mask_cast = reinterpret_cast(attn_mask); - - val_cast += data_offset; - attn_mask_cast += mask_offset; - - float2 low_data[MAX_THREAD_ITERATIONS]; - float2 high_data[MAX_THREAD_ITERATIONS]; - - float max_val = minus_infinity; - - for (int i = 0; i < iterations; i++) { - int data_id = i * iteration_stride + seq_lane; - if (data_id < seq_length) { - float2 data = val_cast[data_id]; - float2 mask = attn_mask_cast[data_id]; - - __half2* data_arr = reinterpret_cast<__half2*>(&data); - __half2* mask_arr = reinterpret_cast<__half2*>(&mask); - - low_data[i] = __half22float2(data_arr[0]); - high_data[i] = __half22float2(data_arr[1]); - float2 low_mask = __half22float2(mask_arr[0]); - float2 high_mask = __half22float2(mask_arr[1]); - - low_data[i].x += low_mask.x; - low_data[i].y += low_mask.y; - high_data[i].x += high_mask.x; - high_data[i].y += high_mask.y; - - max_val = (low_data[i].x > max_val ? low_data[i].x : max_val); - max_val = (low_data[i].y > max_val ? low_data[i].y : max_val); - max_val = (high_data[i].x > max_val ? high_data[i].x : max_val); - max_val = (high_data[i].y > max_val ? high_data[i].y : max_val); - } - } - - for (int i = 1; i < tbSize; i *= 2) { - auto temp = g.shfl_xor(max_val, i); - max_val = (temp > max_val ? temp : max_val); - } - - if (seq_length > tbSize) { - if (lane == 0) partialSum[wid] = max_val; - b.sync(); - - if (lane < warp_num) max_val = partialSum[lane]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - int iters = warp_num; - if (seq_length < iteration_stride) - iters = warp_num / (iteration_stride / max_threads_in_sequence); - - for (int i = 1; i < iters; i *= 2) { - auto temp = g.shfl_xor(max_val, i); - max_val = (temp > max_val ? temp : max_val); - } - - max_val = g.shfl(max_val, threadIdx.x / tbSize); - } - - float sum = 0; - for (int i = 0; i < iterations; i++) { - int data_id = i * iteration_stride + seq_lane; - if (data_id < seq_length) { - low_data[i].x = __expf(low_data[i].x - max_val); - low_data[i].y = __expf(low_data[i].y - max_val); - high_data[i].x = __expf(high_data[i].x - max_val); - high_data[i].y = __expf(high_data[i].y - max_val); - - sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y); - } - } - - for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); } - - if (seq_length > tbSize) { - if (lane == 0) partialSum[wid] = sum; - b.sync(); - - if (lane < warp_num) sum = partialSum[lane]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - int iters = warp_num; - if (seq_length < iteration_stride) - iters = warp_num / (iteration_stride / max_threads_in_sequence); - - for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); } - - sum = g.shfl(sum, threadIdx.x / tbSize); - } - - sum += 1e-6; - - for (int i = 0; i < iterations; i++) { - int data_id = i * iteration_stride + seq_lane; - if (data_id < seq_length) { - float2 result_f; - __half2* result_h = reinterpret_cast<__half2*>(&result_f); - - low_data[i].x /= sum; - low_data[i].y /= sum; - high_data[i].x /= sum; - high_data[i].y /= sum; - - result_h[0] = __float22half2_rn(low_data[i]); - result_h[1] = __float22half2_rn(high_data[i]); - - val_cast[data_id] = result_f; - } - } - -#endif -} - -template -void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t); - -template <> -void launch_attn_softmax(float* vals, - const float* attn_mask, - int batch_size, - int heads, - int sequence_length, - cudaStream_t stream) -{ - const int threads = 128; - int seq_length4 = sequence_length / 4; - - dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); - - int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; - - dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / - subblock_max_workload * threads) - : threads); - int iterations = - (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads - : MAX_THREAD_ITERATIONS); - - if (sequence_length <= 8) - attn_softmax<2, (threads / 2), 2> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 16) - attn_softmax<4, (threads / 4), 4> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 32) - attn_softmax<8, (threads / 8), 8> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 64) - attn_softmax<16, (threads / 16), 16> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 128) - attn_softmax<32, (threads / 32), 32> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 256) - attn_softmax<32, (threads / 64), 64> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else { - const int threads = 256; - dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); - - int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; - - dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / - subblock_max_workload * threads) - : threads); - iterations = - (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads - : MAX_THREAD_ITERATIONS); - if (sequence_length <= 512) - attn_softmax<32, (threads / 128), 128><<>>( - vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4)) - attn_softmax<32, 1, 128><<>>( - vals, attn_mask, heads, seq_length4, iterations); - else - throw std::runtime_error( - "Unsupport Seq_Length! Check the restriction of the max_threads and " - "max_thread_iterations!"); - } -} - -template <> -void launch_attn_softmax<__half>(__half* vals, - const __half* attn_mask, - int batch_size, - int heads, - int sequence_length, - cudaStream_t stream) -{ - const int threads = 128; - int seq_length4 = sequence_length / 4; - - dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); - - int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; - - dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / - subblock_max_workload * threads) - : threads); - - int iterations = - (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads - : MAX_THREAD_ITERATIONS); - - if (sequence_length <= 8) - attn_softmax<2, (threads / 2), 2> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 16) - attn_softmax<4, (threads / 4), 4> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 32) - attn_softmax<8, (threads / 8), 8> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 64) - attn_softmax<16, (threads / 16), 16> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 128) - attn_softmax<32, (threads / 32), 32> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 256) - attn_softmax<32, (threads / 64), 64> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else { - const int threads = 256; - dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); - - int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; - - dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / - subblock_max_workload * threads) - : threads); - iterations = - (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads - : MAX_THREAD_ITERATIONS); - if (sequence_length <= 512) - attn_softmax<32, (threads / 128), 128><<>>( - vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4)) - attn_softmax<32, 1, 128><<>>( - vals, attn_mask, heads, seq_length4, iterations); - else - throw std::runtime_error( - "Unsupport Seq_Length! Check the restriction of the max_threads and " - "max_thread_iterations!"); - } -} - -template -__global__ void softmax_backward_kernel(T* out_grad, const T* soft_inp, int seq_length) -{ - __shared__ float partialSum[MAX_WARP_NUM]; - - int warp_num = blockDim.x >> 5; // warp-count = num_threads / WARP_SIZE (32) - - int iteration_stride = blockDim.x; - int block_width = blockStride * seq_length; - - int iterations = (seq_length < (MAX_THREAD_ITERATIONS * iteration_stride) - ? (seq_length + iteration_stride - 1) / iteration_stride - : MAX_THREAD_ITERATIONS); - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - - int wid = id >> 5; - int lane = id & 0x1f; - - T val_reg[MAX_THREAD_ITERATIONS]; - T soft_reg[MAX_THREAD_ITERATIONS]; - float grad_reg = 0.0f; - -#pragma unroll - for (int i = 0; i < iterations; i++) { - int data_id = i * iteration_stride + id; - if (data_id < block_width) { - val_reg[i] = out_grad[row * block_width + data_id]; - soft_reg[i] = soft_inp[row * block_width + data_id]; - - grad_reg += ((float)val_reg[i] * - (float)soft_reg[i]); // if done in half, the multiplication, we may lose - // 2% of accuracy in computation!! - } - } - for (int i = 1; i < tbSize; i *= 2) grad_reg += g.shfl_xor(grad_reg, i); - - if (seq_length > tbSize) { - if (lane == 0) partialSum[wid] = grad_reg; - b.sync(); - - if (lane < warp_num) grad_reg = partialSum[lane]; - - int iters = warp_num; - if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length); - - for (int i = 1; i < iters; i *= 2) grad_reg += g.shfl_xor(grad_reg, i); - - grad_reg = g.shfl(grad_reg, id / tbSize); - } - - for (int i = 0; i < iterations; i++) { - int data_id = i * iteration_stride + id; - if (data_id < block_width) { - float temp = (float)soft_reg[i] * ((float)val_reg[i] - grad_reg); - out_grad[row * block_width + data_id] = (T)temp; - } - } -} - -template -__global__ void softmax_backward_kernel_v2(T* grad /* input & output*/, - const T* output, - int softmax_length) -{ - int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; - int offset = batch_idx * softmax_length + threadIdx.x; - - grad += offset; - output += offset; - - T grad_reg[ITERATIONS]; - T output_reg[ITERATIONS]; - float sum = 0.0; - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) { - grad_reg[i] = grad[i * WARP_SIZE]; - output_reg[i] = output[i * WARP_SIZE]; - sum += (float)grad_reg[i] * (float)output_reg[i]; - } - } - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) - grad[i * WARP_SIZE] = (float)output_reg[i] * ((float)grad_reg[i] - sum); - } -} - -template -void launch_attn_softmax_backward_v2(T* out_grad, - const T* soft_inp, - int batch_size, - int heads, - int seq_length, - cudaStream_t stream) -{ - const int warps_per_block = 4; - dim3 grid_dim(batch_size * heads * seq_length / warps_per_block); - dim3 block_dim(WARP_SIZE, warps_per_block); - - if (seq_length <= 32) - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - else if (seq_length <= 64) - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - else if (seq_length <= 128) - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - else if (seq_length <= 256) - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - else if (seq_length <= 384) - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - else if (seq_length <= 512) - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - else if (seq_length <= 768) - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - else if (seq_length <= 1024) - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - else if (seq_length <= 2048) - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - else - throw std::runtime_error( - std::string("Special sequence length found in softmax backward, seq_length: ") + - std::to_string(seq_length)); -} - -template void launch_attn_softmax_backward_v2<__half>(__half* out_grad, - const __half* soft_inp, - int batch_size, - int heads, - int seq_length, - cudaStream_t stream); -template void launch_attn_softmax_backward_v2(float* out_grad, - const float* soft_inp, - int batch_size, - int heads, - int seq_length, - cudaStream_t stream); +#include +#include "custom_cuda_layers.h" +#include "general_kernels.h" + +namespace cg = cooperative_groups; + +dim3 get_attn_softmax_grid(int batch_size, int heads, int sequence_length, int threads) +{ + int seq_length4 = sequence_length / 4; + int block_compute_size = + (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1); + // Note that the Y and Z dimensions are limited to 65535, while X is basically unlimited: + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications + // The batch size is typically relatively small, while the sequence length could potentially be + // arbitrarily large. We therefore place the batch size second to avoid hitting the Y limit. + unsigned x = heads * sequence_length / block_compute_size; + unsigned y = batch_size; + return {x, y}; +} + +// Fused attention + softmax +template +__global__ void attn_softmax(float* vals, + const float* attn_mask, + int heads, + int seq_length, + int iterations) +{ + __shared__ float partialSum[MAX_WARP_NUM]; + + int warp_num = blockDim.x >> 5; + + int iteration_stride = blockDim.x; + int block_width = blockStride * seq_length; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int batch = blockIdx.y; + int row = blockIdx.x; + int max_threads_in_sequence = std::max(seq_length, tbSeq); + int seq_lane = threadIdx.x % max_threads_in_sequence; + + int data_offset = batch * (gridDim.x * block_width) + row * block_width + + (threadIdx.x / max_threads_in_sequence) * seq_length; + int mask_offset = batch * seq_length; + + int wid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + + float4* val_cast = reinterpret_cast(vals); + const float4* attn_mask_cast = reinterpret_cast(attn_mask); + + float4 data[MAX_THREAD_ITERATIONS]; + + float max_val = minus_infinity; + + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) { + float4 mask = attn_mask_cast[mask_offset + data_id]; + data[i] = val_cast[data_offset + data_id]; + + data[i].x += mask.x; + data[i].y += mask.y; + data[i].z += mask.z; + data[i].w += mask.w; + + max_val = (data[i].x > max_val ? data[i].x : max_val); + max_val = (data[i].y > max_val ? data[i].y : max_val); + max_val = (data[i].z > max_val ? data[i].z : max_val); + max_val = (data[i].w > max_val ? data[i].w : max_val); + } else { + data[i].x = minus_infinity; + data[i].y = minus_infinity; + data[i].z = minus_infinity; + data[i].w = minus_infinity; + } + } + + for (int i = 1; i < tbSize; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = max_val; + b.sync(); + + if (lane < warp_num) max_val = partialSum[lane]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + int iters = warp_num; + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); + + for (int i = 1; i < iters; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + max_val = g.shfl(max_val, threadIdx.x / tbSize); + } + + float sum = 0; + for (int i = 0; i < iterations; i++) { + data[i].x = __expf(data[i].x - max_val); + data[i].y = __expf(data[i].y - max_val); + data[i].z = __expf(data[i].z - max_val); + data[i].w = __expf(data[i].w - max_val); + + sum += (data[i].x + data[i].y + data[i].z + data[i].w); + } + + for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); } + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = sum; + b.sync(); + + if (lane < warp_num) sum = partialSum[lane]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + int iters = warp_num; + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); + + for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); } + + sum = g.shfl(sum, threadIdx.x / tbSize); + } + + sum += 1e-6; + + for (int i = 0; i < iterations; i++) { + data[i].x /= sum; + data[i].y /= sum; + data[i].z /= sum; + data[i].w /= sum; + + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) val_cast[data_offset + data_id] = data[i]; + } +} + +template +__global__ void attn_softmax(__half* vals, + const __half* attn_mask, + int heads, + int seq_length, + int iterations) +{ +#if __CUDA_ARCH__ >= 700 + __shared__ float partialSum[MAX_WARP_NUM]; + + int warp_num = blockDim.x >> 5; + + int iteration_stride = blockDim.x; + int block_width = blockStride * seq_length; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int batch = blockIdx.y; + int row = blockIdx.x; + int max_threads_in_sequence = std::max(seq_length, tbSeq); + int seq_lane = threadIdx.x % max_threads_in_sequence; + + int data_offset = batch * (gridDim.x * block_width) + row * block_width + + (threadIdx.x / max_threads_in_sequence) * seq_length; + int mask_offset = batch * seq_length; + + int wid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + + float2* val_cast = reinterpret_cast(vals); + const float2* attn_mask_cast = reinterpret_cast(attn_mask); + + val_cast += data_offset; + attn_mask_cast += mask_offset; + + float2 low_data[MAX_THREAD_ITERATIONS]; + float2 high_data[MAX_THREAD_ITERATIONS]; + + float max_val = minus_infinity; + + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) { + float2 data = val_cast[data_id]; + float2 mask = attn_mask_cast[data_id]; + + __half2* data_arr = reinterpret_cast<__half2*>(&data); + __half2* mask_arr = reinterpret_cast<__half2*>(&mask); + + low_data[i] = __half22float2(data_arr[0]); + high_data[i] = __half22float2(data_arr[1]); + float2 low_mask = __half22float2(mask_arr[0]); + float2 high_mask = __half22float2(mask_arr[1]); + + low_data[i].x += low_mask.x; + low_data[i].y += low_mask.y; + high_data[i].x += high_mask.x; + high_data[i].y += high_mask.y; + + max_val = (low_data[i].x > max_val ? low_data[i].x : max_val); + max_val = (low_data[i].y > max_val ? low_data[i].y : max_val); + max_val = (high_data[i].x > max_val ? high_data[i].x : max_val); + max_val = (high_data[i].y > max_val ? high_data[i].y : max_val); + } + } + + for (int i = 1; i < tbSize; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = max_val; + b.sync(); + + if (lane < warp_num) max_val = partialSum[lane]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + int iters = warp_num; + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); + + for (int i = 1; i < iters; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + max_val = g.shfl(max_val, threadIdx.x / tbSize); + } + + float sum = 0; + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) { + low_data[i].x = __expf(low_data[i].x - max_val); + low_data[i].y = __expf(low_data[i].y - max_val); + high_data[i].x = __expf(high_data[i].x - max_val); + high_data[i].y = __expf(high_data[i].y - max_val); + + sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y); + } + } + + for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); } + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = sum; + b.sync(); + + if (lane < warp_num) sum = partialSum[lane]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + int iters = warp_num; + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); + + for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); } + + sum = g.shfl(sum, threadIdx.x / tbSize); + } + + sum += 1e-6; + + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) { + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + + low_data[i].x /= sum; + low_data[i].y /= sum; + high_data[i].x /= sum; + high_data[i].y /= sum; + + result_h[0] = __float22half2_rn(low_data[i]); + result_h[1] = __float22half2_rn(high_data[i]); + + val_cast[data_id] = result_f; + } + } + +#endif +} + +template +void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t); + +template <> +void launch_attn_softmax(float* vals, + const float* attn_mask, + int batch_size, + int heads, + int sequence_length, + cudaStream_t stream) +{ + const int threads = 128; + int seq_length4 = sequence_length / 4; + + dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); + + int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; + + dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / + subblock_max_workload * threads) + : threads); + int iterations = + (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads + : MAX_THREAD_ITERATIONS); + + if (sequence_length <= 8) + attn_softmax<2, (threads / 2), 2> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 16) + attn_softmax<4, (threads / 4), 4> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 32) + attn_softmax<8, (threads / 8), 8> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 64) + attn_softmax<16, (threads / 16), 16> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 128) + attn_softmax<32, (threads / 32), 32> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 256) + attn_softmax<32, (threads / 64), 64> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else { + const int threads = 256; + dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); + + int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; + + dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / + subblock_max_workload * threads) + : threads); + iterations = + (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads + : MAX_THREAD_ITERATIONS); + if (sequence_length <= 512) + attn_softmax<32, (threads / 128), 128><<>>( + vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4)) + attn_softmax<32, 1, 128><<>>( + vals, attn_mask, heads, seq_length4, iterations); + else + throw std::runtime_error( + "Unsupport Seq_Length! Check the restriction of the max_threads and " + "max_thread_iterations!"); + } +} + +template <> +void launch_attn_softmax<__half>(__half* vals, + const __half* attn_mask, + int batch_size, + int heads, + int sequence_length, + cudaStream_t stream) +{ + const int threads = 128; + int seq_length4 = sequence_length / 4; + + dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); + + int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; + + dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / + subblock_max_workload * threads) + : threads); + + int iterations = + (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads + : MAX_THREAD_ITERATIONS); + + if (sequence_length <= 8) + attn_softmax<2, (threads / 2), 2> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 16) + attn_softmax<4, (threads / 4), 4> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 32) + attn_softmax<8, (threads / 8), 8> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 64) + attn_softmax<16, (threads / 16), 16> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 128) + attn_softmax<32, (threads / 32), 32> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 256) + attn_softmax<32, (threads / 64), 64> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else { + const int threads = 256; + dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); + + int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; + + dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / + subblock_max_workload * threads) + : threads); + iterations = + (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads + : MAX_THREAD_ITERATIONS); + if (sequence_length <= 512) + attn_softmax<32, (threads / 128), 128><<>>( + vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4)) + attn_softmax<32, 1, 128><<>>( + vals, attn_mask, heads, seq_length4, iterations); + else + throw std::runtime_error( + "Unsupport Seq_Length! Check the restriction of the max_threads and " + "max_thread_iterations!"); + } +} + +template +__global__ void softmax_backward_kernel(T* out_grad, const T* soft_inp, int seq_length) +{ + __shared__ float partialSum[MAX_WARP_NUM]; + + int warp_num = blockDim.x >> 5; // warp-count = num_threads / WARP_SIZE (32) + + int iteration_stride = blockDim.x; + int block_width = blockStride * seq_length; + + int iterations = (seq_length < (MAX_THREAD_ITERATIONS * iteration_stride) + ? (seq_length + iteration_stride - 1) / iteration_stride + : MAX_THREAD_ITERATIONS); + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + + int wid = id >> 5; + int lane = id & 0x1f; + + T val_reg[MAX_THREAD_ITERATIONS]; + T soft_reg[MAX_THREAD_ITERATIONS]; + float grad_reg = 0.0f; + +#pragma unroll + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + id; + if (data_id < block_width) { + val_reg[i] = out_grad[row * block_width + data_id]; + soft_reg[i] = soft_inp[row * block_width + data_id]; + + grad_reg += ((float)val_reg[i] * + (float)soft_reg[i]); // if done in half, the multiplication, we may lose + // 2% of accuracy in computation!! + } + } + for (int i = 1; i < tbSize; i *= 2) grad_reg += g.shfl_xor(grad_reg, i); + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = grad_reg; + b.sync(); + + if (lane < warp_num) grad_reg = partialSum[lane]; + + int iters = warp_num; + if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length); + + for (int i = 1; i < iters; i *= 2) grad_reg += g.shfl_xor(grad_reg, i); + + grad_reg = g.shfl(grad_reg, id / tbSize); + } + + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + id; + if (data_id < block_width) { + float temp = (float)soft_reg[i] * ((float)val_reg[i] - grad_reg); + out_grad[row * block_width + data_id] = (T)temp; + } + } +} + +template +__global__ void softmax_backward_kernel_v2(T* grad /* input & output*/, + const T* output, + int softmax_length) +{ + int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; + int offset = batch_idx * softmax_length + threadIdx.x; + + grad += offset; + output += offset; + + T grad_reg[ITERATIONS]; + T output_reg[ITERATIONS]; + float sum = 0.0; + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) { + grad_reg[i] = grad[i * WARP_SIZE]; + output_reg[i] = output[i * WARP_SIZE]; + sum += (float)grad_reg[i] * (float)output_reg[i]; + } + } + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) + grad[i * WARP_SIZE] = (float)output_reg[i] * ((float)grad_reg[i] - sum); + } +} + +template +void launch_attn_softmax_backward_v2(T* out_grad, + const T* soft_inp, + int batch_size, + int heads, + int seq_length, + cudaStream_t stream) +{ + const int warps_per_block = 4; + dim3 grid_dim(batch_size * heads * seq_length / warps_per_block); + dim3 block_dim(WARP_SIZE, warps_per_block); + + if (seq_length <= 32) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 64) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 128) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 256) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 384) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 512) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 768) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 1024) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 2048) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else + throw std::runtime_error( + std::string("Special sequence length found in softmax backward, seq_length: ") + + std::to_string(seq_length)); +} + +template void launch_attn_softmax_backward_v2<__half>(__half* out_grad, + const __half* soft_inp, + int batch_size, + int heads, + int seq_length, + cudaStream_t stream); +template void launch_attn_softmax_backward_v2(float* out_grad, + const float* soft_inp, + int batch_size, + int heads, + int seq_length, + cudaStream_t stream); diff --git a/csrc/transformer/transform_kernels.cu b/csrc/transformer/transform_kernels.cu index 7d8a27eee..b7924c6f2 100755 --- a/csrc/transformer/transform_kernels.cu +++ b/csrc/transformer/transform_kernels.cu @@ -1,575 +1,575 @@ -#include "custom_cuda_layers.h" - -#define rows_trans 16 -#define cols_trans 16 - -template -__global__ void Transpose_Kernel(const T* inp, T* out, int row_width, int col_width) -{ - __shared__ T data_block[rows_trans * (cols_trans + 1)]; - - int r = threadIdx.x / cols_trans; - int c = threadIdx.x % cols_trans; - - int m = row_width / cols_trans; - - int i = blockIdx.x / m * rows_trans + r; - int j = blockIdx.x % m * cols_trans + c; - - int row_stride = rows_trans / ((rows_trans * cols_trans + THREADS - 1) / THREADS); - - for (int k = 0; k < rows_trans; k += row_stride) - data_block[(k + r) * cols_trans + c] = inp[(i + k) * row_width + j]; - - __syncthreads(); - - i = blockIdx.x % m * rows_trans + r; - j = blockIdx.x / m * cols_trans + c; - - for (int k = 0; k < rows_trans; k += row_stride) - out[(i + k) * col_width + j] = data_block[c * cols_trans + r + k]; -} - -template <> -void Transpose<__half>(const __half* inp_mat, - __half* out_mat, - int rows, - int cols, - cudaStream_t stream) -{ - int threads = THREADS; - - Transpose_Kernel<__half><<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>( - inp_mat, out_mat, cols, rows); -} - -template <> -void Transpose(const float* inp_mat, float* out_mat, int rows, int cols, cudaStream_t stream) -{ - int threads = THREADS; - - Transpose_Kernel<<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>( - inp_mat, out_mat, cols, rows); -} - -template -__global__ void transform_0213(T* output, - const T* vals, - int hidden_dim, - int seq_length, - int heads, - int head_ext); - -template <> -__global__ void transform_0213(float* output, - const float* vals, - int hidden_dim, - int seq_length, - int heads, - int head_ext) -{ - int d0_stride = hidden_dim * seq_length; - int d1_stride = hidden_dim; - int d2_stride = hidden_dim / heads; - - int d0_out_stride = d0_stride; - int d1_out_stride = d2_stride; - int d2_out_stride = d2_stride * seq_length; - - int d0 = blockIdx.x; // Batch - int d1 = blockIdx.y / head_ext; // Sequence ID (0-127) - int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11) - int d3 = threadIdx.x; // Values (groups of 4) - - const float4* vals_vec = reinterpret_cast(vals); - float4* output_vec = reinterpret_cast(output); - - float4 inputs = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3]; - output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = inputs; -} - -template <> -__global__ void transform_0213<__half>(__half* output, - const __half* vals, - int hidden_dim, - int seq_length, - int heads, - int head_ext) -{ -#if __CUDA_ARCH__ >= 700 - - int d0_stride = hidden_dim * seq_length; - int d1_stride = hidden_dim; - int d2_stride = hidden_dim / heads; - - int d0_out_stride = d0_stride; - int d1_out_stride = d2_stride; - int d2_out_stride = d2_stride * seq_length; - - int d0 = blockIdx.x; // Batch - int d1 = blockIdx.y / head_ext; // Sequence ID (0-127) - int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11) - int d3 = threadIdx.x; // Values (groups of 4) - - float4 vals_arr[1]; - - const float4* vals_vec = reinterpret_cast(vals); - float4* output_vec = reinterpret_cast(output); - - vals_arr[0] = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3]; - output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = vals_arr[0]; -#endif -} - -template <> -void launch_transform_0213(float* output, - const float* vals, - int batch_size, - int seq_length, - int hidden_dim, - int heads, - cudaStream_t stream) -{ - hidden_dim >>= 2; - int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; - dim3 block_dim(hidden_dim / heads, (heads / head_ext)); - dim3 grid_dim(batch_size, (seq_length * head_ext)); - - transform_0213 - <<>>(output, vals, hidden_dim, seq_length, heads, head_ext); -} - -template <> -void launch_transform_0213<__half>(__half* output, - const __half* vals, - int batch_size, - int seq_length, - int hidden_dim, - int heads, - cudaStream_t stream) -{ - hidden_dim >>= 3; - int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; - dim3 block_dim(hidden_dim / heads, (heads / head_ext)); - dim3 grid_dim(batch_size, (seq_length * head_ext)); - transform_0213<__half> - <<>>(output, vals, hidden_dim, seq_length, heads, head_ext); -} - -// Bias add -template -__global__ void bias_add_transform_0213(T* output, - const T* vals, - const T* bias, - int hidden_dim, - int seq_length, - int heads, - int head_ext); - -template <> -__global__ void bias_add_transform_0213(float* output, - const float* vals, - const float* bias, - int hidden_dim, - int seq_length, - int heads, - int head_ext) -{ - int d0_stride = hidden_dim * seq_length; - int d1_stride = hidden_dim; - int d2_stride = hidden_dim / heads; - - int d0_out_stride = d0_stride; - int d1_out_stride = d2_stride; - int d2_out_stride = d2_stride * seq_length; - - int d0 = blockIdx.x; // Batch - int d1 = blockIdx.y; // Sequence ID (0-127) - int cnt = blockIdx.z / head_ext; // Hidden count - int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) - int d3 = threadIdx.x; // Values (groups of 4) - - const float4* vals_vec = reinterpret_cast(vals); - const float4* bias_vec = reinterpret_cast(bias); - float4* output_vec = reinterpret_cast(output); - - float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride + - d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3]; - float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3]; - - float4 outputs; - outputs.x = inputs.x + biases.x; - outputs.y = inputs.y + biases.y; - outputs.z = inputs.z + biases.z; - outputs.w = inputs.w + biases.w; - - output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride + - d2 * d2_out_stride + d3] = outputs; -} - -#define ATTN_H 3 -#define MAX_SEQ_LINE 10 - -template <> -__global__ void bias_add_transform_0213<__half>(__half* output, - const __half* vals, - const __half* bias, - int hidden_dim, - int seq_length, - int heads, - int head_ext) -{ -#if __CUDA_ARCH__ >= 700 - - int d0_stride = hidden_dim * seq_length; - int d1_stride = hidden_dim; - int d2_stride = hidden_dim / heads; - - int d2_out_stride = d2_stride * seq_length; - - int d0 = blockIdx.x; // Batch - int d1 = blockIdx.y; // Sequence ID (0-127) - int cnt = blockIdx.z / head_ext; // Hidden count - int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) - int d3 = threadIdx.x; // Values (groups of 4) - - float4 vals_arr; - float4 bias_arr; - float4 output_arr; - __half2* vals_half = reinterpret_cast<__half2*>(&vals_arr); - __half2* bias_half = reinterpret_cast<__half2*>(&bias_arr); - __half2* output_half = reinterpret_cast<__half2*>(&output_arr); - - const float4* vals_vec = reinterpret_cast(vals); - const float4* bias_vec = reinterpret_cast(bias); - float4* output_vec = reinterpret_cast(output); - - vals_vec += (d0 * d0_stride * (gridDim.z / head_ext)); - vals_vec += (d1 * d1_stride * (gridDim.z / head_ext)); - vals_vec += (cnt * d1_stride); - vals_vec += (d2 * d2_stride); - - bias_vec += (cnt * d1_stride); - bias_vec += (d2 * d2_stride); - - output_vec += (cnt * d0_stride * gridDim.x); - output_vec += (d1 * d2_stride); - output_vec += (d0 * d0_stride); - output_vec += (d2 * d2_out_stride); - - bias_arr = bias_vec[d3]; - vals_arr = vals_vec[d3]; - -#if defined(__ACC_HALF__) - output_half[0] = vals_half[0] + bias_half[0]; - output_half[1] = vals_half[1] + bias_half[1]; - output_half[2] = vals_half[2] + bias_half[2]; - output_half[3] = vals_half[3] + bias_half[3]; -#else - float2 bias_arr_f[4]; - float2 vals_arr_f[4]; -#pragma unroll - for (int l = 0; l < 4; l++) { - bias_arr_f[l] = __half22float2(bias_half[l]); - vals_arr_f[l] = __half22float2(vals_half[l]); - vals_arr_f[l].x += bias_arr_f[l].x; - vals_arr_f[l].y += bias_arr_f[l].y; - output_half[l] = __float22half2_rn(vals_arr_f[l]); - } -#endif - output_vec[d3] = output_arr; - -#endif -} - -__global__ void bias_add_transform_0213_v2(__half* output, - const __half* vals, - const __half* bias, - int hidden_dim, - int seq_length, - int heads) -{ -#if __CUDA_ARCH__ >= 700 - __shared__ float4 in_data[3072]; - - int d0_stride = hidden_dim * seq_length; - int d1_stride = hidden_dim; - int d2_stride = hidden_dim / heads; - int iteration_stride = d1_stride * blockDim.z; // Hidden * 3 / 8 - int batch_stride = d0_stride * blockDim.z; // Hidden * S * 3 / 8 - - int d0_out_stride = d0_stride; - int d1_out_stride = d2_stride; - int d2_out_stride = d2_stride * seq_length; - - int d0 = blockIdx.x; // Batch - int d1 = blockIdx.y; // Sequence ID (0-127) - int cnt = threadIdx.z; // blockIdx.z; // Hidden count - int d2 = threadIdx.y; // Head (0-11) - int d3 = threadIdx.x; // Values (groups of 4) - - float4 vals_arr[1]; - float4 bias_arr[1]; - float4 output_arr[1]; - __half2* vals_half = reinterpret_cast<__half2*>(vals_arr); - __half2* bias_half = reinterpret_cast<__half2*>(bias_arr); - __half2* output_half = reinterpret_cast<__half2*>(output_arr); - - const float4* vals_vec = reinterpret_cast(vals); - const float4* bias_vec = reinterpret_cast(bias); - float4* output_vec = reinterpret_cast(output); - - int iter_index = cnt * d1_stride + d2 * d2_stride + d3; - int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1); - bias_arr[0] = bias_vec[iter_index]; - -#pragma unroll - for (int iter = 0; iter < 2; iter++) { - int iter_id = iter * iteration_stride + iter_index; - vals_arr[0] = vals_vec[input_offset + iter_id]; - - output_half[0] = vals_half[0] + bias_half[0]; - output_half[1] = vals_half[1] + bias_half[1]; - output_half[2] = vals_half[2] + bias_half[2]; - output_half[3] = vals_half[3] + bias_half[3]; - - in_data[iter_id] = output_arr[0]; - } - __syncthreads(); - - iteration_stride = blockDim.z * (blockDim.y >> 1); - int matrix_stride = (d0_out_stride * gridDim.x); - int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1); - - int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride; - -#pragma unroll - for (int iter = 0; iter < 2; iter++) { - int iter_row = (iter * iteration_stride) + head_count; - int iter_offset = - (iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride; - output_vec[out_index + iter_offset] = - in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)]; - } -#endif -} - -// [B S C*H] - > C * [B A S N] -template <> -void launch_bias_add_transform_0213(float* output, - const float* vals, - const float* bias, - int batch_size, - int seq_length, - int hidden_dim, - int heads, - cudaStream_t stream, - int trans_count) -{ - hidden_dim >>= 2; - int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; - - dim3 block_dim(hidden_dim / heads, (heads / head_ext)); - dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext)); - - bias_add_transform_0213<<>>( - output, vals, bias, hidden_dim, seq_length, heads, head_ext); -} - -template <> -void launch_bias_add_transform_0213<__half>(__half* output, - const __half* vals, - const __half* bias, - int batch_size, - int seq_length, - int hidden_dim, - int heads, - cudaStream_t stream, - int trans_count) -{ - hidden_dim >>= 3; - if (hidden_dim > 128 || hidden_dim < 16) { - int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; - dim3 block_dim(hidden_dim / heads, (heads / head_ext)); - dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext)); - bias_add_transform_0213<__half><<>>( - output, vals, bias, hidden_dim, seq_length, heads, head_ext); - } else { - dim3 block_dim(hidden_dim / heads, heads, trans_count); - dim3 grid_dim(batch_size, seq_length / 2); - bias_add_transform_0213_v2<<>>( - output, vals, bias, hidden_dim, seq_length, heads); - } -} - -template -__global__ void transform4d_0213(T* out, - const T* in, - int heads, - int seq_length, - int hidden_dim, - int head_ext); - -template <> -__global__ void transform4d_0213(float* out, - const float* in, - int heads, - int seq_length, - int hidden_dim, - int head_ext) -{ - int d0_stride = hidden_dim * seq_length; - int d1_stride = d0_stride / heads; - int d2_stride = hidden_dim / heads; - - int d0_out_stride = d0_stride; - int d1_out_stride = d2_stride; - int d2_out_stride = hidden_dim; - - int d0 = blockIdx.x; // Batch - int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1); // Head - int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length; - int cnt = blockIdx.z; - int d3 = threadIdx.x; // Values (groups of 8) - - if (d2 < seq_length) { - const float4* in_vec = reinterpret_cast(in); - float4* out_vec = reinterpret_cast(out); - - float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride + - d2 * d2_stride + d3]; - out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride + - d2 * d2_out_stride * gridDim.z + d3] = vals_vec; - } -} - -template <> -__global__ void transform4d_0213<__half>(__half* out, - const __half* in, - int heads, - int seq_length, - int hidden_dim, - int head_ext) -{ -#if __CUDA_ARCH__ >= 700 - - int d0_stride = hidden_dim * (seq_length / head_ext); - int d1_stride = hidden_dim; - int d2_stride = hidden_dim / heads; - - int d0 = blockIdx.x; // Batch - int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head - int d2 = blockIdx.z / head_ext; // Sequence - int cnt = blockIdx.y; // Hidden count - int d3 = threadIdx.x; // Values (groups of 8) - - const float4* in_vec = reinterpret_cast(in); - float4* out_vec = reinterpret_cast(out); - - in_vec += (cnt * d0_stride * gridDim.x); - in_vec += (d0 * d0_stride); - in_vec += (d2 * d2_stride); - in_vec += (d1 * d2_stride * seq_length); - - out_vec += (cnt * d1_stride); - out_vec += (d1 * d2_stride); - out_vec += (d0 * d0_stride * gridDim.y); - out_vec += (d2 * d1_stride * gridDim.y); - - out_vec[d3] = in_vec[d3]; - -#endif -} - -__global__ void transform4d_0213_v2(__half* out, - const __half* in, - int heads, - int seq_length, - int hidden_dim) -{ -#if __CUDA_ARCH__ >= 700 - __shared__ float4 in_data[3072]; - - int d0_stride = hidden_dim * seq_length; - int d1_stride = hidden_dim; - int d2_stride = hidden_dim / heads; - - int d0 = blockIdx.x; // Batch - int d1 = threadIdx.y; // Head - int d2 = blockIdx.y; // Sequence - int cnt = threadIdx.z; // Hidden count - int d3 = threadIdx.x; // Values (groups of 8) - - const float4* in_vec = reinterpret_cast(in); - float4* out_vec = reinterpret_cast(out); - - int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride; - int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1); - int iteration_stride = blockDim.z * (blockDim.y >> 1); - int matrix_stride = (d0_stride * gridDim.x); - -#pragma unroll - for (int iter = 0; iter < 2; iter++) { - int iter_row = iter * iteration_stride + head_count; - int iter_offset = (iter_row % blockDim.y) * d2_stride; - - in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] = - in_vec[input_offset + iter_offset * seq_length + - (iter_row / blockDim.y) * matrix_stride]; - } - __syncthreads(); - - iteration_stride = d1_stride * blockDim.z; - int iter_index = cnt * d1_stride + d1 * d2_stride + d3; - int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1); - -#pragma unroll - for (int iter = 0; iter < 2; iter++) { - int iter_id = iter * iteration_stride + iter_index; - out_vec[output_offset + iter_id] = in_data[iter_id]; - } -#endif -} - -// 3 * [B A S N] - > [B S C*H] -template <> -void launch_transform4d_0213(float* out, - const float* in, - int batch_size, - int heads, - int seq_length, - int hidden_dim, - cudaStream_t stream, - int trans_count) -{ - hidden_dim >>= 2; - dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count); - dim3 block_dims(hidden_dim / heads, 8); - transform4d_0213 - <<>>(out, in, heads, seq_length, hidden_dim, 1); -} - -template <> -void launch_transform4d_0213<__half>(__half* out, - const __half* in, - int batch_size, - int heads, - int seq_length, - int hidden_dim, - cudaStream_t stream, - int trans_count) -{ - hidden_dim >>= 3; - if (hidden_dim > 128 || hidden_dim < 16) { - int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; - dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext)); - dim3 block_dims(hidden_dim / heads, (heads / head_ext)); - transform4d_0213<__half><<>>( - out, in, heads, seq_length, hidden_dim, head_ext); - } else { - dim3 grid_dims(batch_size, seq_length / 2); - dim3 block_dims(hidden_dim / heads, heads, trans_count); - transform4d_0213_v2<<>>( - out, in, heads, seq_length, hidden_dim); - } -} +#include "custom_cuda_layers.h" + +#define rows_trans 16 +#define cols_trans 16 + +template +__global__ void Transpose_Kernel(const T* inp, T* out, int row_width, int col_width) +{ + __shared__ T data_block[rows_trans * (cols_trans + 1)]; + + int r = threadIdx.x / cols_trans; + int c = threadIdx.x % cols_trans; + + int m = row_width / cols_trans; + + int i = blockIdx.x / m * rows_trans + r; + int j = blockIdx.x % m * cols_trans + c; + + int row_stride = rows_trans / ((rows_trans * cols_trans + THREADS - 1) / THREADS); + + for (int k = 0; k < rows_trans; k += row_stride) + data_block[(k + r) * cols_trans + c] = inp[(i + k) * row_width + j]; + + __syncthreads(); + + i = blockIdx.x % m * rows_trans + r; + j = blockIdx.x / m * cols_trans + c; + + for (int k = 0; k < rows_trans; k += row_stride) + out[(i + k) * col_width + j] = data_block[c * cols_trans + r + k]; +} + +template <> +void Transpose<__half>(const __half* inp_mat, + __half* out_mat, + int rows, + int cols, + cudaStream_t stream) +{ + int threads = THREADS; + + Transpose_Kernel<__half><<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>( + inp_mat, out_mat, cols, rows); +} + +template <> +void Transpose(const float* inp_mat, float* out_mat, int rows, int cols, cudaStream_t stream) +{ + int threads = THREADS; + + Transpose_Kernel<<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>( + inp_mat, out_mat, cols, rows); +} + +template +__global__ void transform_0213(T* output, + const T* vals, + int hidden_dim, + int seq_length, + int heads, + int head_ext); + +template <> +__global__ void transform_0213(float* output, + const float* vals, + int hidden_dim, + int seq_length, + int heads, + int head_ext) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y / head_ext; // Sequence ID (0-127) + int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + const float4* vals_vec = reinterpret_cast(vals); + float4* output_vec = reinterpret_cast(output); + + float4 inputs = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3]; + output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = inputs; +} + +template <> +__global__ void transform_0213<__half>(__half* output, + const __half* vals, + int hidden_dim, + int seq_length, + int heads, + int head_ext) +{ +#if __CUDA_ARCH__ >= 700 + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y / head_ext; // Sequence ID (0-127) + int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + float4 vals_arr[1]; + + const float4* vals_vec = reinterpret_cast(vals); + float4* output_vec = reinterpret_cast(output); + + vals_arr[0] = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3]; + output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = vals_arr[0]; +#endif +} + +template <> +void launch_transform_0213(float* output, + const float* vals, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + cudaStream_t stream) +{ + hidden_dim >>= 2; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + dim3 block_dim(hidden_dim / heads, (heads / head_ext)); + dim3 grid_dim(batch_size, (seq_length * head_ext)); + + transform_0213 + <<>>(output, vals, hidden_dim, seq_length, heads, head_ext); +} + +template <> +void launch_transform_0213<__half>(__half* output, + const __half* vals, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + cudaStream_t stream) +{ + hidden_dim >>= 3; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + dim3 block_dim(hidden_dim / heads, (heads / head_ext)); + dim3 grid_dim(batch_size, (seq_length * head_ext)); + transform_0213<__half> + <<>>(output, vals, hidden_dim, seq_length, heads, head_ext); +} + +// Bias add +template +__global__ void bias_add_transform_0213(T* output, + const T* vals, + const T* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext); + +template <> +__global__ void bias_add_transform_0213(float* output, + const float* vals, + const float* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = blockIdx.z / head_ext; // Hidden count + int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + const float4* vals_vec = reinterpret_cast(vals); + const float4* bias_vec = reinterpret_cast(bias); + float4* output_vec = reinterpret_cast(output); + + float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride + + d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3]; + float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3]; + + float4 outputs; + outputs.x = inputs.x + biases.x; + outputs.y = inputs.y + biases.y; + outputs.z = inputs.z + biases.z; + outputs.w = inputs.w + biases.w; + + output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride + + d2 * d2_out_stride + d3] = outputs; +} + +#define ATTN_H 3 +#define MAX_SEQ_LINE 10 + +template <> +__global__ void bias_add_transform_0213<__half>(__half* output, + const __half* vals, + const __half* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext) +{ +#if __CUDA_ARCH__ >= 700 + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = blockIdx.z / head_ext; // Hidden count + int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + float4 vals_arr; + float4 bias_arr; + float4 output_arr; + __half2* vals_half = reinterpret_cast<__half2*>(&vals_arr); + __half2* bias_half = reinterpret_cast<__half2*>(&bias_arr); + __half2* output_half = reinterpret_cast<__half2*>(&output_arr); + + const float4* vals_vec = reinterpret_cast(vals); + const float4* bias_vec = reinterpret_cast(bias); + float4* output_vec = reinterpret_cast(output); + + vals_vec += (d0 * d0_stride * (gridDim.z / head_ext)); + vals_vec += (d1 * d1_stride * (gridDim.z / head_ext)); + vals_vec += (cnt * d1_stride); + vals_vec += (d2 * d2_stride); + + bias_vec += (cnt * d1_stride); + bias_vec += (d2 * d2_stride); + + output_vec += (cnt * d0_stride * gridDim.x); + output_vec += (d1 * d2_stride); + output_vec += (d0 * d0_stride); + output_vec += (d2 * d2_out_stride); + + bias_arr = bias_vec[d3]; + vals_arr = vals_vec[d3]; + +#if defined(__ACC_HALF__) + output_half[0] = vals_half[0] + bias_half[0]; + output_half[1] = vals_half[1] + bias_half[1]; + output_half[2] = vals_half[2] + bias_half[2]; + output_half[3] = vals_half[3] + bias_half[3]; +#else + float2 bias_arr_f[4]; + float2 vals_arr_f[4]; +#pragma unroll + for (int l = 0; l < 4; l++) { + bias_arr_f[l] = __half22float2(bias_half[l]); + vals_arr_f[l] = __half22float2(vals_half[l]); + vals_arr_f[l].x += bias_arr_f[l].x; + vals_arr_f[l].y += bias_arr_f[l].y; + output_half[l] = __float22half2_rn(vals_arr_f[l]); + } +#endif + output_vec[d3] = output_arr; + +#endif +} + +__global__ void bias_add_transform_0213_v2(__half* output, + const __half* vals, + const __half* bias, + int hidden_dim, + int seq_length, + int heads) +{ +#if __CUDA_ARCH__ >= 700 + __shared__ float4 in_data[3072]; + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + int iteration_stride = d1_stride * blockDim.z; // Hidden * 3 / 8 + int batch_stride = d0_stride * blockDim.z; // Hidden * S * 3 / 8 + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = threadIdx.z; // blockIdx.z; // Hidden count + int d2 = threadIdx.y; // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + float4 vals_arr[1]; + float4 bias_arr[1]; + float4 output_arr[1]; + __half2* vals_half = reinterpret_cast<__half2*>(vals_arr); + __half2* bias_half = reinterpret_cast<__half2*>(bias_arr); + __half2* output_half = reinterpret_cast<__half2*>(output_arr); + + const float4* vals_vec = reinterpret_cast(vals); + const float4* bias_vec = reinterpret_cast(bias); + float4* output_vec = reinterpret_cast(output); + + int iter_index = cnt * d1_stride + d2 * d2_stride + d3; + int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1); + bias_arr[0] = bias_vec[iter_index]; + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_id = iter * iteration_stride + iter_index; + vals_arr[0] = vals_vec[input_offset + iter_id]; + + output_half[0] = vals_half[0] + bias_half[0]; + output_half[1] = vals_half[1] + bias_half[1]; + output_half[2] = vals_half[2] + bias_half[2]; + output_half[3] = vals_half[3] + bias_half[3]; + + in_data[iter_id] = output_arr[0]; + } + __syncthreads(); + + iteration_stride = blockDim.z * (blockDim.y >> 1); + int matrix_stride = (d0_out_stride * gridDim.x); + int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1); + + int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride; + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_row = (iter * iteration_stride) + head_count; + int iter_offset = + (iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride; + output_vec[out_index + iter_offset] = + in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)]; + } +#endif +} + +// [B S C*H] - > C * [B A S N] +template <> +void launch_bias_add_transform_0213(float* output, + const float* vals, + const float* bias, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + cudaStream_t stream, + int trans_count) +{ + hidden_dim >>= 2; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + + dim3 block_dim(hidden_dim / heads, (heads / head_ext)); + dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext)); + + bias_add_transform_0213<<>>( + output, vals, bias, hidden_dim, seq_length, heads, head_ext); +} + +template <> +void launch_bias_add_transform_0213<__half>(__half* output, + const __half* vals, + const __half* bias, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + cudaStream_t stream, + int trans_count) +{ + hidden_dim >>= 3; + if (hidden_dim > 128 || hidden_dim < 16) { + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + dim3 block_dim(hidden_dim / heads, (heads / head_ext)); + dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext)); + bias_add_transform_0213<__half><<>>( + output, vals, bias, hidden_dim, seq_length, heads, head_ext); + } else { + dim3 block_dim(hidden_dim / heads, heads, trans_count); + dim3 grid_dim(batch_size, seq_length / 2); + bias_add_transform_0213_v2<<>>( + output, vals, bias, hidden_dim, seq_length, heads); + } +} + +template +__global__ void transform4d_0213(T* out, + const T* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext); + +template <> +__global__ void transform4d_0213(float* out, + const float* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = d0_stride / heads; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = hidden_dim; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1); // Head + int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length; + int cnt = blockIdx.z; + int d3 = threadIdx.x; // Values (groups of 8) + + if (d2 < seq_length) { + const float4* in_vec = reinterpret_cast(in); + float4* out_vec = reinterpret_cast(out); + + float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride + + d2 * d2_stride + d3]; + out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride + + d2 * d2_out_stride * gridDim.z + d3] = vals_vec; + } +} + +template <> +__global__ void transform4d_0213<__half>(__half* out, + const __half* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext) +{ +#if __CUDA_ARCH__ >= 700 + + int d0_stride = hidden_dim * (seq_length / head_ext); + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0 = blockIdx.x; // Batch + int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head + int d2 = blockIdx.z / head_ext; // Sequence + int cnt = blockIdx.y; // Hidden count + int d3 = threadIdx.x; // Values (groups of 8) + + const float4* in_vec = reinterpret_cast(in); + float4* out_vec = reinterpret_cast(out); + + in_vec += (cnt * d0_stride * gridDim.x); + in_vec += (d0 * d0_stride); + in_vec += (d2 * d2_stride); + in_vec += (d1 * d2_stride * seq_length); + + out_vec += (cnt * d1_stride); + out_vec += (d1 * d2_stride); + out_vec += (d0 * d0_stride * gridDim.y); + out_vec += (d2 * d1_stride * gridDim.y); + + out_vec[d3] = in_vec[d3]; + +#endif +} + +__global__ void transform4d_0213_v2(__half* out, + const __half* in, + int heads, + int seq_length, + int hidden_dim) +{ +#if __CUDA_ARCH__ >= 700 + __shared__ float4 in_data[3072]; + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0 = blockIdx.x; // Batch + int d1 = threadIdx.y; // Head + int d2 = blockIdx.y; // Sequence + int cnt = threadIdx.z; // Hidden count + int d3 = threadIdx.x; // Values (groups of 8) + + const float4* in_vec = reinterpret_cast(in); + float4* out_vec = reinterpret_cast(out); + + int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride; + int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1); + int iteration_stride = blockDim.z * (blockDim.y >> 1); + int matrix_stride = (d0_stride * gridDim.x); + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_row = iter * iteration_stride + head_count; + int iter_offset = (iter_row % blockDim.y) * d2_stride; + + in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] = + in_vec[input_offset + iter_offset * seq_length + + (iter_row / blockDim.y) * matrix_stride]; + } + __syncthreads(); + + iteration_stride = d1_stride * blockDim.z; + int iter_index = cnt * d1_stride + d1 * d2_stride + d3; + int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1); + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_id = iter * iteration_stride + iter_index; + out_vec[output_offset + iter_id] = in_data[iter_id]; + } +#endif +} + +// 3 * [B A S N] - > [B S C*H] +template <> +void launch_transform4d_0213(float* out, + const float* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + cudaStream_t stream, + int trans_count) +{ + hidden_dim >>= 2; + dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count); + dim3 block_dims(hidden_dim / heads, 8); + transform4d_0213 + <<>>(out, in, heads, seq_length, hidden_dim, 1); +} + +template <> +void launch_transform4d_0213<__half>(__half* out, + const __half* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + cudaStream_t stream, + int trans_count) +{ + hidden_dim >>= 3; + if (hidden_dim > 128 || hidden_dim < 16) { + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext)); + dim3 block_dims(hidden_dim / heads, (heads / head_ext)); + transform4d_0213<__half><<>>( + out, in, heads, seq_length, hidden_dim, head_ext); + } else { + dim3 grid_dims(batch_size, seq_length / 2); + dim3 block_dims(hidden_dim / heads, heads, trans_count); + transform4d_0213_v2<<>>( + out, in, heads, seq_length, hidden_dim); + } +} diff --git a/deepspeed/launcher/constants.py b/deepspeed/launcher/constants.py index fd56facc4..15b262342 100644 --- a/deepspeed/launcher/constants.py +++ b/deepspeed/launcher/constants.py @@ -1,9 +1,9 @@ -# Copyright 2020 The Microsoft DeepSpeed Team - -PDSH_LAUNCHER = 'pdsh' -PDSH_MAX_FAN_OUT = 1024 - -OPENMPI_LAUNCHER = 'openmpi' - -MVAPICH_LAUNCHER = 'mvapich' -MVAPICH_TMP_HOSTFILE = '/tmp/deepspeed_mvapich_hostfile' +# Copyright 2020 The Microsoft DeepSpeed Team + +PDSH_LAUNCHER = 'pdsh' +PDSH_MAX_FAN_OUT = 1024 + +OPENMPI_LAUNCHER = 'openmpi' + +MVAPICH_LAUNCHER = 'mvapich' +MVAPICH_TMP_HOSTFILE = '/tmp/deepspeed_mvapich_hostfile' diff --git a/deepspeed/launcher/multinode_runner.py b/deepspeed/launcher/multinode_runner.py index 4c4ac3b49..a8fc6fcc1 100644 --- a/deepspeed/launcher/multinode_runner.py +++ b/deepspeed/launcher/multinode_runner.py @@ -1,227 +1,227 @@ -import os -import sys -import shutil -import subprocess -import warnings -from abc import ABC, abstractmethod - -from ..utils import logger -from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE - - -class MultiNodeRunner(ABC): - def __init__(self, args, world_info_base64): - self.args = args - self.validate_args() - self.user_arguments = self.parse_user_args() - self.user_script = args.user_script - self.world_info_base64 = world_info_base64 - self.exports = {} - - @abstractmethod - def backend_exists(self): - """Return whether the corresponding backend exists""" - - @abstractmethod - def get_cmd(self, environment, active_resources): - """Return the command to execute on node""" - - def add_export(self, key, var): - self.exports[key.strip()] = var.strip() - - def parse_user_args(self): - return self.args.user_args - - @property - def name(self): - """Return the name of the backend""" - return self.__class__.__name__ - - def validate_args(self): - """Validate self.args""" - - -class PDSHRunner(MultiNodeRunner): - def __init__(self, args, world_info_base64): - super().__init__(args, world_info_base64) - - def backend_exists(self): - return shutil.which('pdsh') - - @property - def name(self): - return "pdsh" - - def parse_user_args(self): - return list( - map(lambda x: x if x.startswith("-") else f"'{x}'", - self.args.user_args)) - - def get_cmd(self, environment, active_resources): - environment['PDSH_RCMD_TYPE'] = 'ssh' - - active_workers = ",".join(active_resources.keys()) - logger.info("Running on the following workers: %s" % active_workers) - - # PDSH flags for max node fan out and specific hosts to launch on - # See https://linux.die.net/man/1/pdsh for flag details - pdsh_cmd_args = ['pdsh', '-f', str(PDSH_MAX_FAN_OUT), '-w', active_workers] - - exports = "" - for key, val in self.exports.items(): - exports += f"export {key}={val}; " - - # https://linux.die.net/man/1/pdsh - # %n will be replaced by pdsh command - deepspeed_launch = [ - exports, - f"cd {os.path.abspath('.')};", - sys.executable, - "-u", - "-m", - "deepspeed.launcher.launch", - f'--world_info={self.world_info_base64}', - "--node_rank=%n", - f"--master_addr={self.args.master_addr}", - f"--master_port={self.args.master_port}" - ] - - return pdsh_cmd_args + deepspeed_launch + [self.user_script - ] + self.user_arguments - - -class OpenMPIRunner(MultiNodeRunner): - def __init__(self, args, world_info_base64, resource_pool): - super().__init__(args, world_info_base64) - self.resource_pool = resource_pool - self.add_export('UCX_TLS', 'tcp') - - def backend_exists(self): - #TODO: if IB is available we should suggestion mvapich - return shutil.which('ompi_info') - - @property - def name(self): - return "openmpi" - - def validate_args(self): - super().validate_args() - #TODO: Allow for include/exclude at node-level but not gpu-level - if self.args.include != "" or self.args.exclude != "": - raise ValueError( - f"{self.name} backend does not support worker include/exclusion") - if self.args.num_nodes != -1 or self.args.num_gpus != -1: - raise ValueError( - f"{self.name} backend does not support limiting num nodes/gpus") - - def get_cmd(self, environment, active_resources): - total_process_count = sum(self.resource_pool.values()) - - mpirun_cmd = [ - 'mpirun', - '-n', - f'{total_process_count}', - '-hostfile', - f'{self.args.hostfile}', - '--mca', - 'btl', - '^openib', - '--mca', - 'btl_tcp_if_include', - 'eth0', - ] - - export_cmd = [] - for k, v in self.exports.items(): - export_cmd += ['-x', f'{k}={v}'] - - python_exec = [sys.executable, "-u"] - - return mpirun_cmd + export_cmd + python_exec + [self.user_script - ] + self.user_arguments - - -class MVAPICHRunner(MultiNodeRunner): - def __init__(self, args, world_info_base64, resource_pool): - super().__init__(args, world_info_base64) - self.resource_pool = resource_pool - - # Disable the CMA kernel module, not available on Ubuntu systems - self.add_export('MV2_SMP_USE_CMA', '0') - - # If we fail this will output more verbose logging - self.add_export('MV2_DEBUG_SHOW_BACKTRACE', '1') - - # Enabled cuda-aware communication - self.add_export('MV2_USE_CUDA', '1') - - # Support deep learning frameworks: http://hidl.cse.ohio-state.edu/userguide/horovod/ - self.add_export('MV2_SUPPORT_DL', '1') - - # Support MPI_THREAD_MULTIPLE - self.add_export('MV2_ENABLE_AFFINITY', '0') - - # Performance tuning flags for allgather - self.add_export('MV2_INTER_ALLGATHER_TUNING', '5') - self.add_export('MV2_CUDA_USE_NAIVE', '0') - - def backend_exists(self): - #TODO: if IB is available we should suggestion mvapich - mpiname_exists = shutil.which('mpiname') - exists = False - if not mpiname_exists: - warnings.warn("mpiname does not exist, mvapich is not installed properly") - else: - results = subprocess.check_output('mpiname', shell=True) - mpiname_results = results.decode('utf-8').strip() - if "MVAPICH2-GDR" in mpiname_results: - exists = True - else: - warnings.warn( - f"Expected MVAPICH2-GDR as return for mpiname but received {mpiname_results}" - ) - return exists - - @property - def name(self): - return "mvapich" - - def validate_args(self): - super().validate_args() - #TODO: Allow for include/exclude at node-level but not gpu-level - if self.args.include != "" or self.args.exclude != "": - raise ValueError( - f"{self.name} backend does not support worker include/exclusion") - if self.args.num_nodes != -1 or self.args.num_gpus != -1: - raise ValueError( - f"{self.name} backend does not support limiting num nodes/gpus") - - def get_cmd(self, environment, active_resources): - devices_per_node = self.resource_pool.values() - total_process_count = sum(devices_per_node) - process_per_node = list(devices_per_node)[0] - if not all([n == process_per_node for n in devices_per_node]): - raise ValueError("mvapich requires same number of devices per node") - - with open(MVAPICH_TMP_HOSTFILE, 'w') as fd: - for host in self.resource_pool.keys(): - fd.write(f'{host}\n') - - mpirun_cmd = [ - 'mpirun', - '-np', - f'{total_process_count}', - '-ppn', - f'{process_per_node}', - '--hostfile', - f'{MVAPICH_TMP_HOSTFILE}', - ] - - export_cmd = [] - for k, v in self.exports.items(): - export_cmd += ['-env', f'{k}={v}'] - - python_exec = [sys.executable, "-u"] - - return mpirun_cmd + export_cmd + python_exec + [self.user_script - ] + self.user_arguments +import os +import sys +import shutil +import subprocess +import warnings +from abc import ABC, abstractmethod + +from ..utils import logger +from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE + + +class MultiNodeRunner(ABC): + def __init__(self, args, world_info_base64): + self.args = args + self.validate_args() + self.user_arguments = self.parse_user_args() + self.user_script = args.user_script + self.world_info_base64 = world_info_base64 + self.exports = {} + + @abstractmethod + def backend_exists(self): + """Return whether the corresponding backend exists""" + + @abstractmethod + def get_cmd(self, environment, active_resources): + """Return the command to execute on node""" + + def add_export(self, key, var): + self.exports[key.strip()] = var.strip() + + def parse_user_args(self): + return self.args.user_args + + @property + def name(self): + """Return the name of the backend""" + return self.__class__.__name__ + + def validate_args(self): + """Validate self.args""" + + +class PDSHRunner(MultiNodeRunner): + def __init__(self, args, world_info_base64): + super().__init__(args, world_info_base64) + + def backend_exists(self): + return shutil.which('pdsh') + + @property + def name(self): + return "pdsh" + + def parse_user_args(self): + return list( + map(lambda x: x if x.startswith("-") else f"'{x}'", + self.args.user_args)) + + def get_cmd(self, environment, active_resources): + environment['PDSH_RCMD_TYPE'] = 'ssh' + + active_workers = ",".join(active_resources.keys()) + logger.info("Running on the following workers: %s" % active_workers) + + # PDSH flags for max node fan out and specific hosts to launch on + # See https://linux.die.net/man/1/pdsh for flag details + pdsh_cmd_args = ['pdsh', '-f', str(PDSH_MAX_FAN_OUT), '-w', active_workers] + + exports = "" + for key, val in self.exports.items(): + exports += f"export {key}={val}; " + + # https://linux.die.net/man/1/pdsh + # %n will be replaced by pdsh command + deepspeed_launch = [ + exports, + f"cd {os.path.abspath('.')};", + sys.executable, + "-u", + "-m", + "deepspeed.launcher.launch", + f'--world_info={self.world_info_base64}', + "--node_rank=%n", + f"--master_addr={self.args.master_addr}", + f"--master_port={self.args.master_port}" + ] + + return pdsh_cmd_args + deepspeed_launch + [self.user_script + ] + self.user_arguments + + +class OpenMPIRunner(MultiNodeRunner): + def __init__(self, args, world_info_base64, resource_pool): + super().__init__(args, world_info_base64) + self.resource_pool = resource_pool + self.add_export('UCX_TLS', 'tcp') + + def backend_exists(self): + #TODO: if IB is available we should suggestion mvapich + return shutil.which('ompi_info') + + @property + def name(self): + return "openmpi" + + def validate_args(self): + super().validate_args() + #TODO: Allow for include/exclude at node-level but not gpu-level + if self.args.include != "" or self.args.exclude != "": + raise ValueError( + f"{self.name} backend does not support worker include/exclusion") + if self.args.num_nodes != -1 or self.args.num_gpus != -1: + raise ValueError( + f"{self.name} backend does not support limiting num nodes/gpus") + + def get_cmd(self, environment, active_resources): + total_process_count = sum(self.resource_pool.values()) + + mpirun_cmd = [ + 'mpirun', + '-n', + f'{total_process_count}', + '-hostfile', + f'{self.args.hostfile}', + '--mca', + 'btl', + '^openib', + '--mca', + 'btl_tcp_if_include', + 'eth0', + ] + + export_cmd = [] + for k, v in self.exports.items(): + export_cmd += ['-x', f'{k}={v}'] + + python_exec = [sys.executable, "-u"] + + return mpirun_cmd + export_cmd + python_exec + [self.user_script + ] + self.user_arguments + + +class MVAPICHRunner(MultiNodeRunner): + def __init__(self, args, world_info_base64, resource_pool): + super().__init__(args, world_info_base64) + self.resource_pool = resource_pool + + # Disable the CMA kernel module, not available on Ubuntu systems + self.add_export('MV2_SMP_USE_CMA', '0') + + # If we fail this will output more verbose logging + self.add_export('MV2_DEBUG_SHOW_BACKTRACE', '1') + + # Enabled cuda-aware communication + self.add_export('MV2_USE_CUDA', '1') + + # Support deep learning frameworks: http://hidl.cse.ohio-state.edu/userguide/horovod/ + self.add_export('MV2_SUPPORT_DL', '1') + + # Support MPI_THREAD_MULTIPLE + self.add_export('MV2_ENABLE_AFFINITY', '0') + + # Performance tuning flags for allgather + self.add_export('MV2_INTER_ALLGATHER_TUNING', '5') + self.add_export('MV2_CUDA_USE_NAIVE', '0') + + def backend_exists(self): + #TODO: if IB is available we should suggestion mvapich + mpiname_exists = shutil.which('mpiname') + exists = False + if not mpiname_exists: + warnings.warn("mpiname does not exist, mvapich is not installed properly") + else: + results = subprocess.check_output('mpiname', shell=True) + mpiname_results = results.decode('utf-8').strip() + if "MVAPICH2-GDR" in mpiname_results: + exists = True + else: + warnings.warn( + f"Expected MVAPICH2-GDR as return for mpiname but received {mpiname_results}" + ) + return exists + + @property + def name(self): + return "mvapich" + + def validate_args(self): + super().validate_args() + #TODO: Allow for include/exclude at node-level but not gpu-level + if self.args.include != "" or self.args.exclude != "": + raise ValueError( + f"{self.name} backend does not support worker include/exclusion") + if self.args.num_nodes != -1 or self.args.num_gpus != -1: + raise ValueError( + f"{self.name} backend does not support limiting num nodes/gpus") + + def get_cmd(self, environment, active_resources): + devices_per_node = self.resource_pool.values() + total_process_count = sum(devices_per_node) + process_per_node = list(devices_per_node)[0] + if not all([n == process_per_node for n in devices_per_node]): + raise ValueError("mvapich requires same number of devices per node") + + with open(MVAPICH_TMP_HOSTFILE, 'w') as fd: + for host in self.resource_pool.keys(): + fd.write(f'{host}\n') + + mpirun_cmd = [ + 'mpirun', + '-np', + f'{total_process_count}', + '-ppn', + f'{process_per_node}', + '--hostfile', + f'{MVAPICH_TMP_HOSTFILE}', + ] + + export_cmd = [] + for k, v in self.exports.items(): + export_cmd += ['-env', f'{k}={v}'] + + python_exec = [sys.executable, "-u"] + + return mpirun_cmd + export_cmd + python_exec + [self.user_script + ] + self.user_arguments diff --git a/deepspeed/module_inject/module_quantize.py b/deepspeed/module_inject/module_quantize.py index 26c542284..fde6990eb 100755 --- a/deepspeed/module_inject/module_quantize.py +++ b/deepspeed/module_inject/module_quantize.py @@ -1,80 +1,80 @@ -import copy -import torch -import deepspeed - - -def quantize_transformer_layer(orig_layer_impl, model, megatron=False, preln=False): - """ Quantize bert-style transformer layers with DeepSpeed's transformer layer - Arguments: - orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for, - e.g., transformers.modeling_bert.BertLayer. - model (torch.nn.Module): user's nn.module representing their model - - megatron (bool): megatron model-parallel implementation (this is supported for inference only) - preln (bool): does the original layer implementation do pre or post layer norm? - - Note: For Bert kind of models, we inject based on the DeepSpeed-Example models, if not setting huggingface flag. - - Returns: - Updated nn.module with quantized transformer layers - """ - def quantize_weight(weight): - return weight.to(torch.int8) - - def megatron_layer_quantize(layer): - layer.attention.query_key_value.weight.data = quantize_weight( - layer.attention.query_key_value.weight.data) - layer.attention.dense.weight.data = quantize_weight( - layer.attention.dense.weight.data) - layer.mlp.dense_h_to_4h.weight.data = quantize_weight( - layer.mlp.dense_h_to_4h.weight.data) - layer.mlp.dense_4h_to_h.weight.data = quantize_weight( - layer.mlp.dense_4h_to_h.weight.data) - - def bert_layer_quantize(layer): - layer.attention.self.query.weight.data = quantize_weight( - layer.attention.self.query.weight.data) - layer.attention.self.key.weight.data = quantize_weight( - layer.attention.self.key.weight.data) - layer.attention.self.value.weight.data = quantize_weight( - layer.attention.self.value.weight.data) - layer.attention.output.dense.weight.data = quantize_weight( - layer.attention.output.dense.weight.data) - if preln: - layer.intermediate.dense_act.weight.data = quantize_weight( - layer.intermediate.dense_act.weight.data) - else: - layer.intermediate.dense.weight.data = quantize_weight( - layer.intermediate.dense.weight.data) - layer.output.dense.weight.data = quantize_weight(layer.output.dense.weight.data) - - def quantize_fn(child): - if megatron: - # Quantize megatron GPT2 / GPT3 trained model - megatron_layer_quantize(child) - else: - # Quantize either DeepSpeed or HuggingFace trained model - bert_layer_quantize(child) - - return child - - return quantize_module(model=model, - orig_class=orig_layer_impl, - quantize_fn=quantize_fn) - - -def quantize_module(model, orig_class, quantize_fn): - policy = {orig_class: quantize_fn} - return _quantize_module(model, policy) - - -def _quantize_module(model, policies): - for name, child in model.named_children(): - if child.__class__ in policies: - orig = repr(child) - setattr(model, name, policies[child.__class__](child)) - new = getattr(model, name) - else: - _quantize_module(child, policies) - - return model +import copy +import torch +import deepspeed + + +def quantize_transformer_layer(orig_layer_impl, model, megatron=False, preln=False): + """ Quantize bert-style transformer layers with DeepSpeed's transformer layer + Arguments: + orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for, + e.g., transformers.modeling_bert.BertLayer. + model (torch.nn.Module): user's nn.module representing their model + + megatron (bool): megatron model-parallel implementation (this is supported for inference only) + preln (bool): does the original layer implementation do pre or post layer norm? + + Note: For Bert kind of models, we inject based on the DeepSpeed-Example models, if not setting huggingface flag. + + Returns: + Updated nn.module with quantized transformer layers + """ + def quantize_weight(weight): + return weight.to(torch.int8) + + def megatron_layer_quantize(layer): + layer.attention.query_key_value.weight.data = quantize_weight( + layer.attention.query_key_value.weight.data) + layer.attention.dense.weight.data = quantize_weight( + layer.attention.dense.weight.data) + layer.mlp.dense_h_to_4h.weight.data = quantize_weight( + layer.mlp.dense_h_to_4h.weight.data) + layer.mlp.dense_4h_to_h.weight.data = quantize_weight( + layer.mlp.dense_4h_to_h.weight.data) + + def bert_layer_quantize(layer): + layer.attention.self.query.weight.data = quantize_weight( + layer.attention.self.query.weight.data) + layer.attention.self.key.weight.data = quantize_weight( + layer.attention.self.key.weight.data) + layer.attention.self.value.weight.data = quantize_weight( + layer.attention.self.value.weight.data) + layer.attention.output.dense.weight.data = quantize_weight( + layer.attention.output.dense.weight.data) + if preln: + layer.intermediate.dense_act.weight.data = quantize_weight( + layer.intermediate.dense_act.weight.data) + else: + layer.intermediate.dense.weight.data = quantize_weight( + layer.intermediate.dense.weight.data) + layer.output.dense.weight.data = quantize_weight(layer.output.dense.weight.data) + + def quantize_fn(child): + if megatron: + # Quantize megatron GPT2 / GPT3 trained model + megatron_layer_quantize(child) + else: + # Quantize either DeepSpeed or HuggingFace trained model + bert_layer_quantize(child) + + return child + + return quantize_module(model=model, + orig_class=orig_layer_impl, + quantize_fn=quantize_fn) + + +def quantize_module(model, orig_class, quantize_fn): + policy = {orig_class: quantize_fn} + return _quantize_module(model, policy) + + +def _quantize_module(model, policies): + for name, child in model.named_children(): + if child.__class__ in policies: + orig = repr(child) + setattr(model, name, policies[child.__class__](child)) + new = getattr(model, name) + else: + _quantize_module(child, policies) + + return model diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index 3758ffd9b..cda2a685d 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -1,239 +1,239 @@ -from abc import ABC - -import torch - - -class DSPolicy(ABC): - def __init__(self, inference=True, linear_layer=True, scale_attention=True): - self.inference = inference - self.linear_layer = linear_layer - self.scale_attention = scale_attention - - def attention(self): - """ - Returns attention qkv and dense parameters - weight: (3*hidden, hidden) and (hidden, hidden) - bias: (3*hidden) and (hidden) - """ - raise NotImplementedError - - def get_hidden_heads(self): - """ - return hidden_size and number of heads - """ - raise NotImplementedError - - def mlp(self): - """ - Returns mlp intermediate and output - weight: (intermediate, hidden) and (hidden, intermediate) - bias: (intermediate) and (hidden) - """ - raise NotImplementedError - - def layerNorm(self): - """ - Returns LayerNorms used in transformer layer - Post-Attention and pre/post layer norm - gamma and beta with shape: (hidden) - """ - raise NotImplementedError - - -class HFBertLayerPolicy(DSPolicy): - _orig_layer_class = None - - def __init__(self, client_module, inference=False, preln=False): - super().__init__(inference) - self.client_module = client_module - self.preln = preln - if HFBertLayerPolicy._orig_layer_class is None: - try: - import transformers - HFBertLayerPolicy._orig_layer_class = transformers.models.bert.modeling_bert.BertLayer - except: - HFBertLayerPolicy._orig_layer_class = None - - def get_hidden_heads(self): - return self.client_module.attention.self.query.weight.data.shape[1], \ - self.client_module.attention.self.num_attention_heads - - def attention(self): - qw = self.client_module.attention.self.query.weight.data - qb = self.client_module.attention.self.query.bias.data - kw = self.client_module.attention.self.key.weight.data - kb = self.client_module.attention.self.key.bias.data - vw = self.client_module.attention.self.value.weight.data - vb = self.client_module.attention.self.value.bias.data - - qkvw = torch.cat((qw, kw, vw), dim=0) - qkvb = torch.cat((qb, kb, vb), dim=0) - - return self.linear_layer, \ - qkvw, \ - qkvb, \ - self.client_module.attention.output.dense.weight.data, \ - self.client_module.attention.output.dense.bias.data, \ - self.scale_attention - - def mlp(self): - if self.preln: - intermediate_ff = self.client_module.intermediate.dense_act - else: - intermediate_ff = self.client_module.intermediate.dense - - return self.linear_layer, intermediate_ff.weight.data, intermediate_ff.bias.data, \ - self.client_module.output.dense.weight.data, \ - self.client_module.output.dense.bias.data - - def layerNorm(self): - if self.preln: - attention_layernorm = self.client_module.PostAttentionLayerNorm - transformer_layernorm = self.client_module.PreAttentionLayerNorm - else: - attention_layernorm = self.client_module.attention.output.LayerNorm - transformer_layernorm = self.client_module.output.LayerNorm - return attention_layernorm.weight.data, \ - attention_layernorm.bias.data, \ - transformer_layernorm.weight.data, \ - transformer_layernorm.bias.data - - -class HFGPTNEOLayerPolicy(DSPolicy): - _orig_layer_class = None - - def __init__(self, client_module, inference=True): - super().__init__(inference, scale_attention=False) - self.client_module = client_module - try: - import transformers - HFGPTNEOLayerPolicy._orig_layer_class = transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoBlock - except: - HFGPTNEOLayerPolicy._orig_layer_class = None - - def get_hidden_heads(self): - return self.client_module.attn.attention.q_proj.weight.data.shape[1], \ - self.client_module.attn.attention.num_heads - - def attention(self): - qw = self.client_module.attn.attention.q_proj.weight.data - kw = self.client_module.attn.attention.k_proj.weight.data - vw = self.client_module.attn.attention.v_proj.weight.data - - qkvw = torch.cat((qw, kw, vw), dim=0) - - return self.linear_layer, \ - qkvw, \ - None, \ - self.client_module.attn.attention.out_proj.weight.data, \ - self.client_module.attn.attention.out_proj.bias.data, \ - self.scale_attention - - def mlp(self): - return self.linear_layer, \ - self.client_module.mlp.c_fc.weight.data, \ - self.client_module.mlp.c_fc.bias.data, \ - self.client_module.mlp.c_proj.weight.data, \ - self.client_module.mlp.c_proj.bias.data - - def layerNorm(self): - return self.client_module.ln_2.weight.data, \ - self.client_module.ln_2.bias.data, \ - self.client_module.ln_1.weight.data, \ - self.client_module.ln_1.bias.data - - -class MegatronLayerPolicy(DSPolicy): - _orig_layer_class = None - - def __init__(self, client_module, version=0, inference=True): - super().__init__(inference) - self.client_module = client_module - # we use megatron version to differentiate between the old and new - # megatron-lm source code - self.version = version - if MegatronLayerPolicy._orig_layer_class is None: - try: - import megatron - from megatron.model.transformer import ParallelTransformerLayer - MegatronLayerPolicy._orig_layer_class = ParallelTransformerLayer - except ImportError: - MegatronLayerPolicy._orig_layer_class = None - - def get_hidden_heads(self): - return self.client_module.attention.query_key_value.weight.data.shape[1], \ - self.client_module.attention.num_attention_heads - - def attention(self): - if self.inference: - if self.version == 0: - attention = self.client_module.attention - else: - attention = self.client_module.self_attention - - return self.linear_layer, \ - attention.query_key_value.weight.data, \ - attention.query_key_value.bias.data, \ - attention.dense.weight.data, \ - attention.dense.bias.data, \ - self.scale_attention - - def mlp(self): - return self.linear_layer, \ - self.client_module.mlp.dense_h_to_4h.weight.data, \ - self.client_module.mlp.dense_h_to_4h.bias.data, \ - self.client_module.mlp.dense_4h_to_h.weight.data, \ - self.client_module.mlp.dense_4h_to_h.bias.data - - def layerNorm(self): - return self.client_module.post_attention_layernorm.weight.data, \ - self.client_module.post_attention_layernorm.bias.data, \ - self.client_module.input_layernorm.weight.data, \ - self.client_module.input_layernorm.bias.data - - -class HFGPT2LayerPolicy(DSPolicy): - _orig_layer_class = None - - def __init__(self, client_module, inference=True): - # HuggingFace GPT2 uses convolutional layer instead of linear layer - super().__init__(inference, linear_layer=False) - self.client_module = client_module - try: - import transformers - HFGPT2LayerPolicy._orig_layer_class = transformers.models.gpt2.modeling_gpt2.GPT2Block - except ImportError: - HFGPT2LayerPolicy._orig_layer_class = None - - def get_hidden_heads(self): - return self.client_module.attn.embed_dim, \ - self.client_module.attn.num_heads - - def attention(self): - return self.linear_layer, \ - self.client_module.attn.c_attn.weight.data, \ - self.client_module.attn.c_attn.bias.data, \ - self.client_module.attn.c_proj.weight.data, \ - self.client_module.attn.c_proj.bias.data, \ - self.scale_attention - - def mlp(self): - return self.linear_layer, \ - self.client_module.mlp.c_fc.weight.data, \ - self.client_module.mlp.c_fc.bias.data, \ - self.client_module.mlp.c_proj.weight.data, \ - self.client_module.mlp.c_proj.bias.data - - def layerNorm(self): - return self.client_module.ln_2.weight.data, \ - self.client_module.ln_2.bias.data, \ - self.client_module.ln_1.weight.data, \ - self.client_module.ln_1.bias.data - - -replace_policies = [ - HFBertLayerPolicy, - HFGPTNEOLayerPolicy, - MegatronLayerPolicy, - HFGPT2LayerPolicy, -] +from abc import ABC + +import torch + + +class DSPolicy(ABC): + def __init__(self, inference=True, linear_layer=True, scale_attention=True): + self.inference = inference + self.linear_layer = linear_layer + self.scale_attention = scale_attention + + def attention(self): + """ + Returns attention qkv and dense parameters + weight: (3*hidden, hidden) and (hidden, hidden) + bias: (3*hidden) and (hidden) + """ + raise NotImplementedError + + def get_hidden_heads(self): + """ + return hidden_size and number of heads + """ + raise NotImplementedError + + def mlp(self): + """ + Returns mlp intermediate and output + weight: (intermediate, hidden) and (hidden, intermediate) + bias: (intermediate) and (hidden) + """ + raise NotImplementedError + + def layerNorm(self): + """ + Returns LayerNorms used in transformer layer + Post-Attention and pre/post layer norm + gamma and beta with shape: (hidden) + """ + raise NotImplementedError + + +class HFBertLayerPolicy(DSPolicy): + _orig_layer_class = None + + def __init__(self, client_module, inference=False, preln=False): + super().__init__(inference) + self.client_module = client_module + self.preln = preln + if HFBertLayerPolicy._orig_layer_class is None: + try: + import transformers + HFBertLayerPolicy._orig_layer_class = transformers.models.bert.modeling_bert.BertLayer + except: + HFBertLayerPolicy._orig_layer_class = None + + def get_hidden_heads(self): + return self.client_module.attention.self.query.weight.data.shape[1], \ + self.client_module.attention.self.num_attention_heads + + def attention(self): + qw = self.client_module.attention.self.query.weight.data + qb = self.client_module.attention.self.query.bias.data + kw = self.client_module.attention.self.key.weight.data + kb = self.client_module.attention.self.key.bias.data + vw = self.client_module.attention.self.value.weight.data + vb = self.client_module.attention.self.value.bias.data + + qkvw = torch.cat((qw, kw, vw), dim=0) + qkvb = torch.cat((qb, kb, vb), dim=0) + + return self.linear_layer, \ + qkvw, \ + qkvb, \ + self.client_module.attention.output.dense.weight.data, \ + self.client_module.attention.output.dense.bias.data, \ + self.scale_attention + + def mlp(self): + if self.preln: + intermediate_ff = self.client_module.intermediate.dense_act + else: + intermediate_ff = self.client_module.intermediate.dense + + return self.linear_layer, intermediate_ff.weight.data, intermediate_ff.bias.data, \ + self.client_module.output.dense.weight.data, \ + self.client_module.output.dense.bias.data + + def layerNorm(self): + if self.preln: + attention_layernorm = self.client_module.PostAttentionLayerNorm + transformer_layernorm = self.client_module.PreAttentionLayerNorm + else: + attention_layernorm = self.client_module.attention.output.LayerNorm + transformer_layernorm = self.client_module.output.LayerNorm + return attention_layernorm.weight.data, \ + attention_layernorm.bias.data, \ + transformer_layernorm.weight.data, \ + transformer_layernorm.bias.data + + +class HFGPTNEOLayerPolicy(DSPolicy): + _orig_layer_class = None + + def __init__(self, client_module, inference=True): + super().__init__(inference, scale_attention=False) + self.client_module = client_module + try: + import transformers + HFGPTNEOLayerPolicy._orig_layer_class = transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoBlock + except: + HFGPTNEOLayerPolicy._orig_layer_class = None + + def get_hidden_heads(self): + return self.client_module.attn.attention.q_proj.weight.data.shape[1], \ + self.client_module.attn.attention.num_heads + + def attention(self): + qw = self.client_module.attn.attention.q_proj.weight.data + kw = self.client_module.attn.attention.k_proj.weight.data + vw = self.client_module.attn.attention.v_proj.weight.data + + qkvw = torch.cat((qw, kw, vw), dim=0) + + return self.linear_layer, \ + qkvw, \ + None, \ + self.client_module.attn.attention.out_proj.weight.data, \ + self.client_module.attn.attention.out_proj.bias.data, \ + self.scale_attention + + def mlp(self): + return self.linear_layer, \ + self.client_module.mlp.c_fc.weight.data, \ + self.client_module.mlp.c_fc.bias.data, \ + self.client_module.mlp.c_proj.weight.data, \ + self.client_module.mlp.c_proj.bias.data + + def layerNorm(self): + return self.client_module.ln_2.weight.data, \ + self.client_module.ln_2.bias.data, \ + self.client_module.ln_1.weight.data, \ + self.client_module.ln_1.bias.data + + +class MegatronLayerPolicy(DSPolicy): + _orig_layer_class = None + + def __init__(self, client_module, version=0, inference=True): + super().__init__(inference) + self.client_module = client_module + # we use megatron version to differentiate between the old and new + # megatron-lm source code + self.version = version + if MegatronLayerPolicy._orig_layer_class is None: + try: + import megatron + from megatron.model.transformer import ParallelTransformerLayer + MegatronLayerPolicy._orig_layer_class = ParallelTransformerLayer + except ImportError: + MegatronLayerPolicy._orig_layer_class = None + + def get_hidden_heads(self): + return self.client_module.attention.query_key_value.weight.data.shape[1], \ + self.client_module.attention.num_attention_heads + + def attention(self): + if self.inference: + if self.version == 0: + attention = self.client_module.attention + else: + attention = self.client_module.self_attention + + return self.linear_layer, \ + attention.query_key_value.weight.data, \ + attention.query_key_value.bias.data, \ + attention.dense.weight.data, \ + attention.dense.bias.data, \ + self.scale_attention + + def mlp(self): + return self.linear_layer, \ + self.client_module.mlp.dense_h_to_4h.weight.data, \ + self.client_module.mlp.dense_h_to_4h.bias.data, \ + self.client_module.mlp.dense_4h_to_h.weight.data, \ + self.client_module.mlp.dense_4h_to_h.bias.data + + def layerNorm(self): + return self.client_module.post_attention_layernorm.weight.data, \ + self.client_module.post_attention_layernorm.bias.data, \ + self.client_module.input_layernorm.weight.data, \ + self.client_module.input_layernorm.bias.data + + +class HFGPT2LayerPolicy(DSPolicy): + _orig_layer_class = None + + def __init__(self, client_module, inference=True): + # HuggingFace GPT2 uses convolutional layer instead of linear layer + super().__init__(inference, linear_layer=False) + self.client_module = client_module + try: + import transformers + HFGPT2LayerPolicy._orig_layer_class = transformers.models.gpt2.modeling_gpt2.GPT2Block + except ImportError: + HFGPT2LayerPolicy._orig_layer_class = None + + def get_hidden_heads(self): + return self.client_module.attn.embed_dim, \ + self.client_module.attn.num_heads + + def attention(self): + return self.linear_layer, \ + self.client_module.attn.c_attn.weight.data, \ + self.client_module.attn.c_attn.bias.data, \ + self.client_module.attn.c_proj.weight.data, \ + self.client_module.attn.c_proj.bias.data, \ + self.scale_attention + + def mlp(self): + return self.linear_layer, \ + self.client_module.mlp.c_fc.weight.data, \ + self.client_module.mlp.c_fc.bias.data, \ + self.client_module.mlp.c_proj.weight.data, \ + self.client_module.mlp.c_proj.bias.data + + def layerNorm(self): + return self.client_module.ln_2.weight.data, \ + self.client_module.ln_2.bias.data, \ + self.client_module.ln_1.weight.data, \ + self.client_module.ln_1.bias.data + + +replace_policies = [ + HFBertLayerPolicy, + HFGPTNEOLayerPolicy, + MegatronLayerPolicy, + HFGPT2LayerPolicy, +] diff --git a/deepspeed/ops/adagrad/cpu_adagrad.py b/deepspeed/ops/adagrad/cpu_adagrad.py index 4c86a2158..44b052fa3 100755 --- a/deepspeed/ops/adagrad/cpu_adagrad.py +++ b/deepspeed/ops/adagrad/cpu_adagrad.py @@ -1,135 +1,135 @@ -''' -Copyright 2020 The Microsoft DeepSpeed Team -''' - -import math -import torch -import time -from pathlib import Path -from ..op_builder import CPUAdagradBuilder -from deepspeed.utils.logging import should_log_le - - -class DeepSpeedCPUAdagrad(torch.optim.Optimizer): - optimizer_id = 0 - - def __init__(self, - model_params, - lr=1e-2, - eps=1e-10, - weight_decay=0, - amsgrad=False, - fp32_optimizer_states=True): - - default_args = dict(lr=lr, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) - super(DeepSpeedCPUAdagrad, self).__init__(model_params, default_args) - - self.opt_id = DeepSpeedCPUAdagrad.optimizer_id - DeepSpeedCPUAdagrad.optimizer_id = DeepSpeedCPUAdagrad.optimizer_id + 1 - self.fp32_optimizer_states = fp32_optimizer_states - self.ds_opt_adagrad = CPUAdagradBuilder().load() - - self.ds_opt_adagrad.create_adagrad(self.opt_id, - lr, - eps, - weight_decay, - should_log_le("info")) - - def __del__(self): - # need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize - # is used multiple times in the same process (notebook or pytest worker) - self.ds_opt_adagrad.destroy_adagrad(self.opt_id) - - def __setstate__(self, state): - super(DeepSpeedCPUAdagrad, self).__setstate__(state) - for group in self.param_groups: - group.setdefault('amsgrad', False) - - @torch.no_grad() - def step(self, closure=None, fp16_param_groups=None): - """Update the model parameters. - - .. note:: - This method will be called internally by ZeRO-Offload. DeepSpeed - users should still use ``engine.step()`` as shown in the - `Getting Started - `_ guide. - - Args: - closure (callable, optional): closure to compute the loss. - Defaults to ``None``. - fp16_param_groups: FP16 GPU parameters to update. Performing the - copy here reduces communication time. Defaults to ``None``. - - Returns: - loss: if ``closure`` is provided. Otherwise ``None``. - """ - - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group_id, group in enumerate(self.param_groups): - for param_id, p in enumerate(group['params']): - - if p.grad is None: - continue - - state = self.state[p] - # State initialization - if len(state) == 0: - #print(f'group {group_id} param {param_id} = {p.numel()}') - state['step'] = 0 - - #use full precision by default unless self.fp32_optimizer_states is off - state_dtype = torch.float if self.fp32_optimizer_states else p.dtype - - #memory_format=torch.preserve_format) - # gradient variances - state['exp_avg_sq'] = torch.zeros_like(p.data, - dtype=state_dtype, - device='cpu') - #memory_format=torch.preserve_format) - - state['step'] += 1 - - if p.grad.is_sparse == True: - sparse_param = p.sparse_mask(p.grad) - sparse_exp_avg_sq = state['exp_avg_sq'].sparse_mask(p.grad) - self.ds_opt_adagrad.adagrad_update(self.opt_id, - state['step'], - group['lr'], - group['eps'], - group['weight_decay'], - sparse_param.values(), - p.grad.values(), - sparse_exp_avg_sq.values()) - p[sparse_param.indices()] = sparse_param.values() - state['exp_avg_sq'][ - sparse_exp_avg_sq.indices()] = sparse_exp_avg_sq.values() - if fp16_param_groups is not None: - fp16_param_groups[group_id][param_id][ - sparse_param.indices()] = sparse_param.values() - else: - if fp16_param_groups is not None: - self.ds_opt_adagrad.adagrad_update_copy( - self.opt_id, - state['step'], - group['lr'], - group['eps'], - group['weight_decay'], - p.data, - p.grad.data, - state['exp_avg_sq'], - fp16_param_groups[group_id][param_id].data) - else: - self.ds_opt_adagrad.adagrad_update(self.opt_id, - state['step'], - group['lr'], - group['eps'], - group['weight_decay'], - p.data, - p.grad.data, - state['exp_avg_sq']) - return loss +''' +Copyright 2020 The Microsoft DeepSpeed Team +''' + +import math +import torch +import time +from pathlib import Path +from ..op_builder import CPUAdagradBuilder +from deepspeed.utils.logging import should_log_le + + +class DeepSpeedCPUAdagrad(torch.optim.Optimizer): + optimizer_id = 0 + + def __init__(self, + model_params, + lr=1e-2, + eps=1e-10, + weight_decay=0, + amsgrad=False, + fp32_optimizer_states=True): + + default_args = dict(lr=lr, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) + super(DeepSpeedCPUAdagrad, self).__init__(model_params, default_args) + + self.opt_id = DeepSpeedCPUAdagrad.optimizer_id + DeepSpeedCPUAdagrad.optimizer_id = DeepSpeedCPUAdagrad.optimizer_id + 1 + self.fp32_optimizer_states = fp32_optimizer_states + self.ds_opt_adagrad = CPUAdagradBuilder().load() + + self.ds_opt_adagrad.create_adagrad(self.opt_id, + lr, + eps, + weight_decay, + should_log_le("info")) + + def __del__(self): + # need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize + # is used multiple times in the same process (notebook or pytest worker) + self.ds_opt_adagrad.destroy_adagrad(self.opt_id) + + def __setstate__(self, state): + super(DeepSpeedCPUAdagrad, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None, fp16_param_groups=None): + """Update the model parameters. + + .. note:: + This method will be called internally by ZeRO-Offload. DeepSpeed + users should still use ``engine.step()`` as shown in the + `Getting Started + `_ guide. + + Args: + closure (callable, optional): closure to compute the loss. + Defaults to ``None``. + fp16_param_groups: FP16 GPU parameters to update. Performing the + copy here reduces communication time. Defaults to ``None``. + + Returns: + loss: if ``closure`` is provided. Otherwise ``None``. + """ + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group_id, group in enumerate(self.param_groups): + for param_id, p in enumerate(group['params']): + + if p.grad is None: + continue + + state = self.state[p] + # State initialization + if len(state) == 0: + #print(f'group {group_id} param {param_id} = {p.numel()}') + state['step'] = 0 + + #use full precision by default unless self.fp32_optimizer_states is off + state_dtype = torch.float if self.fp32_optimizer_states else p.dtype + + #memory_format=torch.preserve_format) + # gradient variances + state['exp_avg_sq'] = torch.zeros_like(p.data, + dtype=state_dtype, + device='cpu') + #memory_format=torch.preserve_format) + + state['step'] += 1 + + if p.grad.is_sparse == True: + sparse_param = p.sparse_mask(p.grad) + sparse_exp_avg_sq = state['exp_avg_sq'].sparse_mask(p.grad) + self.ds_opt_adagrad.adagrad_update(self.opt_id, + state['step'], + group['lr'], + group['eps'], + group['weight_decay'], + sparse_param.values(), + p.grad.values(), + sparse_exp_avg_sq.values()) + p[sparse_param.indices()] = sparse_param.values() + state['exp_avg_sq'][ + sparse_exp_avg_sq.indices()] = sparse_exp_avg_sq.values() + if fp16_param_groups is not None: + fp16_param_groups[group_id][param_id][ + sparse_param.indices()] = sparse_param.values() + else: + if fp16_param_groups is not None: + self.ds_opt_adagrad.adagrad_update_copy( + self.opt_id, + state['step'], + group['lr'], + group['eps'], + group['weight_decay'], + p.data, + p.grad.data, + state['exp_avg_sq'], + fp16_param_groups[group_id][param_id].data) + else: + self.ds_opt_adagrad.adagrad_update(self.opt_id, + state['step'], + group['lr'], + group['eps'], + group['weight_decay'], + p.data, + p.grad.data, + state['exp_avg_sq']) + return loss diff --git a/deepspeed/ops/adam/__init__.py b/deepspeed/ops/adam/__init__.py index 6e620b36b..6ab6cbd37 100755 --- a/deepspeed/ops/adam/__init__.py +++ b/deepspeed/ops/adam/__init__.py @@ -1,2 +1,2 @@ -from .cpu_adam import DeepSpeedCPUAdam -from .fused_adam import FusedAdam +from .cpu_adam import DeepSpeedCPUAdam +from .fused_adam import FusedAdam diff --git a/deepspeed/ops/adam/cpu_adam.py b/deepspeed/ops/adam/cpu_adam.py index 5d6b59714..9304cdeac 100755 --- a/deepspeed/ops/adam/cpu_adam.py +++ b/deepspeed/ops/adam/cpu_adam.py @@ -1,186 +1,186 @@ -''' -Copyright 2020 The Microsoft DeepSpeed Team -''' - -import math -import torch -import time -from pathlib import Path -from ..op_builder import CPUAdamBuilder -from deepspeed.utils.logging import should_log_le - - -class DeepSpeedCPUAdam(torch.optim.Optimizer): - optimizer_id = 0 - - def __init__(self, - model_params, - lr=1e-3, - bias_correction=True, - betas=(0.9, - 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - adamw_mode=True, - fp32_optimizer_states=True): - """Fast vectorized implementation of two variations of Adam optimizer on CPU: - - * Adam: A Method for Stochastic Optimization: (https://arxiv.org/abs/1412.6980); - * AdamW: Fixing Weight Decay Regularization in Adam (https://arxiv.org/abs/1711.05101) - - DeepSpeed CPU Adam(W) provides between 5x to 7x speedup over torch.optim.adam(W). - In order to apply this optimizer, the model requires to have its master parameter (in FP32) - reside on the CPU memory. - - To train on a heterogeneous system, such as coordinating CPU and GPU, DeepSpeed offers - the ZeRO-Offload technology which efficiently offloads the optimizer states into CPU memory, - with minimal impact on training throughput. DeepSpeedCPUAdam plays an important role to minimize - the overhead of the optimizer's latency on CPU. Please refer to ZeRO-Offload tutorial - (https://www.deepspeed.ai/tutorials/zero-offload/) for more information on how to enable this technology. - - For calling step function, there are two options available: (1) update optimizer's states and (2) update - optimizer's states and copy the parameters back to GPU at the same time. We have seen that the second - option can bring 30% higher throughput than the doing the copy separately using option one. - - - .. note:: - We recommend using our `config - `_ - to allow :meth:`deepspeed.initialize` to build this optimizer - for you. - - - Arguments: - model_params (iterable): iterable of parameters to optimize or dicts defining - parameter groups. - lr (float, optional): learning rate. (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square. (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability. (default: 1e-8) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - (default: False) NOT SUPPORTED in DeepSpeed CPUAdam! - adamw_mode: select between Adam and AdamW implementations (default: AdamW) - full_precision_optimizer_states: creates momementum and variance in full precision regardless of - the precision of the parameters (default: True) - """ - - default_args = dict(lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - bias_correction=bias_correction, - amsgrad=amsgrad) - super(DeepSpeedCPUAdam, self).__init__(model_params, default_args) - - self.opt_id = DeepSpeedCPUAdam.optimizer_id - DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1 - self.adam_w_mode = adamw_mode - self.fp32_optimizer_states = fp32_optimizer_states - self.ds_opt_adam = CPUAdamBuilder().load() - - self.ds_opt_adam.create_adam(self.opt_id, - lr, - betas[0], - betas[1], - eps, - weight_decay, - adamw_mode, - should_log_le("info")) - - def __del__(self): - # need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize - # is used multiple times in the same process (notebook or pytest worker) - self.ds_opt_adam.destroy_adam(self.opt_id) - - def __setstate__(self, state): - super(DeepSpeedCPUAdam, self).__setstate__(state) - for group in self.param_groups: - group.setdefault('amsgrad', False) - - @torch.no_grad() - def step(self, closure=None, fp16_param_groups=None): - """Update the model parameters. - - .. note:: - This method will be called internally by ZeRO-Offload. DeepSpeed - users should still use ``engine.step()`` as shown in the - `Getting Started - `_ guide. - - Args: - closure (callable, optional): closure to compute the loss. - Defaults to ``None``. - fp16_param_groups: FP16 GPU parameters to update. Performing the - copy here reduces communication time. Defaults to ``None``. - - Returns: - loss: if ``closure`` is provided. Otherwise ``None``. - """ - - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group_id, group in enumerate(self.param_groups): - for param_id, p in enumerate(group['params']): - - if p.grad is None: - continue - - state = self.state[p] - # State initialization - if len(state) == 0: - #print(f'group {group_id} param {param_id} = {p.numel()}') - state['step'] = 0 - - #use full precision by default unless self.fp32_optimizer_states is off - state_dtype = torch.float if self.fp32_optimizer_states else p.dtype - - # gradient momentums - state['exp_avg'] = torch.zeros_like(p.data, - dtype=state_dtype, - device='cpu') - #memory_format=torch.preserve_format) - # gradient variances - state['exp_avg_sq'] = torch.zeros_like(p.data, - dtype=state_dtype, - device='cpu') - #memory_format=torch.preserve_format) - - state['step'] += 1 - beta1, beta2 = group['betas'] - - if fp16_param_groups is not None: - self.ds_opt_adam.adam_update_copy( - self.opt_id, - state['step'], - group['lr'], - beta1, - beta2, - group['eps'], - group['weight_decay'], - group['bias_correction'], - p.data, - p.grad.data, - state['exp_avg'], - state['exp_avg_sq'], - fp16_param_groups[group_id][param_id].data) - else: - self.ds_opt_adam.adam_update(self.opt_id, - state['step'], - group['lr'], - beta1, - beta2, - group['eps'], - group['weight_decay'], - group['bias_correction'], - p.data, - p.grad.data, - state['exp_avg'], - state['exp_avg_sq']) - return loss +''' +Copyright 2020 The Microsoft DeepSpeed Team +''' + +import math +import torch +import time +from pathlib import Path +from ..op_builder import CPUAdamBuilder +from deepspeed.utils.logging import should_log_le + + +class DeepSpeedCPUAdam(torch.optim.Optimizer): + optimizer_id = 0 + + def __init__(self, + model_params, + lr=1e-3, + bias_correction=True, + betas=(0.9, + 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + adamw_mode=True, + fp32_optimizer_states=True): + """Fast vectorized implementation of two variations of Adam optimizer on CPU: + + * Adam: A Method for Stochastic Optimization: (https://arxiv.org/abs/1412.6980); + * AdamW: Fixing Weight Decay Regularization in Adam (https://arxiv.org/abs/1711.05101) + + DeepSpeed CPU Adam(W) provides between 5x to 7x speedup over torch.optim.adam(W). + In order to apply this optimizer, the model requires to have its master parameter (in FP32) + reside on the CPU memory. + + To train on a heterogeneous system, such as coordinating CPU and GPU, DeepSpeed offers + the ZeRO-Offload technology which efficiently offloads the optimizer states into CPU memory, + with minimal impact on training throughput. DeepSpeedCPUAdam plays an important role to minimize + the overhead of the optimizer's latency on CPU. Please refer to ZeRO-Offload tutorial + (https://www.deepspeed.ai/tutorials/zero-offload/) for more information on how to enable this technology. + + For calling step function, there are two options available: (1) update optimizer's states and (2) update + optimizer's states and copy the parameters back to GPU at the same time. We have seen that the second + option can bring 30% higher throughput than the doing the copy separately using option one. + + + .. note:: + We recommend using our `config + `_ + to allow :meth:`deepspeed.initialize` to build this optimizer + for you. + + + Arguments: + model_params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED in DeepSpeed CPUAdam! + adamw_mode: select between Adam and AdamW implementations (default: AdamW) + full_precision_optimizer_states: creates momementum and variance in full precision regardless of + the precision of the parameters (default: True) + """ + + default_args = dict(lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + bias_correction=bias_correction, + amsgrad=amsgrad) + super(DeepSpeedCPUAdam, self).__init__(model_params, default_args) + + self.opt_id = DeepSpeedCPUAdam.optimizer_id + DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1 + self.adam_w_mode = adamw_mode + self.fp32_optimizer_states = fp32_optimizer_states + self.ds_opt_adam = CPUAdamBuilder().load() + + self.ds_opt_adam.create_adam(self.opt_id, + lr, + betas[0], + betas[1], + eps, + weight_decay, + adamw_mode, + should_log_le("info")) + + def __del__(self): + # need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize + # is used multiple times in the same process (notebook or pytest worker) + self.ds_opt_adam.destroy_adam(self.opt_id) + + def __setstate__(self, state): + super(DeepSpeedCPUAdam, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None, fp16_param_groups=None): + """Update the model parameters. + + .. note:: + This method will be called internally by ZeRO-Offload. DeepSpeed + users should still use ``engine.step()`` as shown in the + `Getting Started + `_ guide. + + Args: + closure (callable, optional): closure to compute the loss. + Defaults to ``None``. + fp16_param_groups: FP16 GPU parameters to update. Performing the + copy here reduces communication time. Defaults to ``None``. + + Returns: + loss: if ``closure`` is provided. Otherwise ``None``. + """ + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group_id, group in enumerate(self.param_groups): + for param_id, p in enumerate(group['params']): + + if p.grad is None: + continue + + state = self.state[p] + # State initialization + if len(state) == 0: + #print(f'group {group_id} param {param_id} = {p.numel()}') + state['step'] = 0 + + #use full precision by default unless self.fp32_optimizer_states is off + state_dtype = torch.float if self.fp32_optimizer_states else p.dtype + + # gradient momentums + state['exp_avg'] = torch.zeros_like(p.data, + dtype=state_dtype, + device='cpu') + #memory_format=torch.preserve_format) + # gradient variances + state['exp_avg_sq'] = torch.zeros_like(p.data, + dtype=state_dtype, + device='cpu') + #memory_format=torch.preserve_format) + + state['step'] += 1 + beta1, beta2 = group['betas'] + + if fp16_param_groups is not None: + self.ds_opt_adam.adam_update_copy( + self.opt_id, + state['step'], + group['lr'], + beta1, + beta2, + group['eps'], + group['weight_decay'], + group['bias_correction'], + p.data, + p.grad.data, + state['exp_avg'], + state['exp_avg_sq'], + fp16_param_groups[group_id][param_id].data) + else: + self.ds_opt_adam.adam_update(self.opt_id, + state['step'], + group['lr'], + beta1, + beta2, + group['eps'], + group['weight_decay'], + group['bias_correction'], + p.data, + p.grad.data, + state['exp_avg'], + state['exp_avg_sq']) + return loss diff --git a/deepspeed/ops/aio/__init__.py b/deepspeed/ops/aio/__init__.py index 50e6c9a3c..d25f81573 100755 --- a/deepspeed/ops/aio/__init__.py +++ b/deepspeed/ops/aio/__init__.py @@ -1,6 +1,6 @@ -''' -Copyright 2020 The Microsoft DeepSpeed Team. -Licensed under the MIT license. -''' - -from ..op_builder import AsyncIOBuilder +''' +Copyright 2020 The Microsoft DeepSpeed Team. +Licensed under the MIT license. +''' + +from ..op_builder import AsyncIOBuilder diff --git a/deepspeed/ops/sparse_attention/bert_sparse_self_attention.py b/deepspeed/ops/sparse_attention/bert_sparse_self_attention.py index 7a0bd4f4c..6c134d71f 100755 --- a/deepspeed/ops/sparse_attention/bert_sparse_self_attention.py +++ b/deepspeed/ops/sparse_attention/bert_sparse_self_attention.py @@ -1,78 +1,78 @@ -""" -Copyright 2020 The Microsoft DeepSpeed Team -""" - -from torch import nn -from deepspeed.ops.sparse_attention import SparseSelfAttention, FixedSparsityConfig - - -class BertSparseSelfAttention(nn.Module): - """Implements Sparse Self Attention layer of Bert model based on https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/nvidia/modelingpreln.py#L373 - - For more information please see, TODO DeepSpeed Sparse Transformer. - - For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial. - """ - def __init__( - self, - config, - # SparsityConfig parameters needs to be set accordingly - sparsity_config=FixedSparsityConfig(num_heads=4)): - """Initialize the bert sparse self attention layer. - - Note) you can use any of the provided sparsity configs or simply add yours! - - Arguments: - config: required: Bert model config - sparsity_config: optional: this parameter determines sparsity pattern configuration; it is based on FixedSparsityConfig class. - """ - - super(BertSparseSelfAttention, self).__init__() - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, - config.num_attention_heads)) - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(config.hidden_size, self.all_head_size) - self.key = nn.Linear(config.hidden_size, self.all_head_size) - self.value = nn.Linear(config.hidden_size, self.all_head_size) - - self.sparse_self_attention = SparseSelfAttention(sparsity_config) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, - self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward(self, hidden_states, attention_mask): - """Applies forward phase of bert sparse self attention - - Arguments: - hidden_states: required: hidden_states tensor of the bert model - attn_mask: required: a mask tensor of size (SequenceLength X SequenceLength); currently only 2D is supported - - Return: - context_layer: a dense tensor containing attention context - """ - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) - - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) - - context_layer = self.sparse_self_attention(query_layer, - key_layer, - value_layer, - key_padding_mask=attention_mask) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, ) - context_layer = context_layer.view(*new_context_layer_shape) - return context_layer +""" +Copyright 2020 The Microsoft DeepSpeed Team +""" + +from torch import nn +from deepspeed.ops.sparse_attention import SparseSelfAttention, FixedSparsityConfig + + +class BertSparseSelfAttention(nn.Module): + """Implements Sparse Self Attention layer of Bert model based on https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/nvidia/modelingpreln.py#L373 + + For more information please see, TODO DeepSpeed Sparse Transformer. + + For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial. + """ + def __init__( + self, + config, + # SparsityConfig parameters needs to be set accordingly + sparsity_config=FixedSparsityConfig(num_heads=4)): + """Initialize the bert sparse self attention layer. + + Note) you can use any of the provided sparsity configs or simply add yours! + + Arguments: + config: required: Bert model config + sparsity_config: optional: this parameter determines sparsity pattern configuration; it is based on FixedSparsityConfig class. + """ + + super(BertSparseSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, + config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.sparse_self_attention = SparseSelfAttention(sparsity_config) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + """Applies forward phase of bert sparse self attention + + Arguments: + hidden_states: required: hidden_states tensor of the bert model + attn_mask: required: a mask tensor of size (SequenceLength X SequenceLength); currently only 2D is supported + + Return: + context_layer: a dense tensor containing attention context + """ + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + context_layer = self.sparse_self_attention(query_layer, + key_layer, + value_layer, + key_padding_mask=attention_mask) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py index 8bf3c5d58..9a5bee886 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -1,912 +1,912 @@ -''' -Copyright (c) Microsoft Corporation -Licensed under the MIT license. - -Use to partition the activations stored for backward propagation -Therefore reduces the memory consumption -Also implements CPU checkpointing and contiguous memory checkpointing -Reduces memory consumption and memory fragmentation - -Code for rng checkpointing taken from NVIDIA Megatron-LM mpu/random.py -b886b7bb972afe72bac0f5de4f42a4a7bae8ebef -''' - -# Parts of the code here are adapted from PyTorch -# repo: https://github.com/pytorch/pytorch -import copy -import torch -import contextlib -import torch.distributed as dist - -import mmap -from torch import _C -from torch.cuda import _lazy_call, device as device_ctx_manager - -from deepspeed.runtime.config import DeepSpeedConfig -from deepspeed.utils import logger -from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank -from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers - -# DeepSpeed Checkpointing Enabled or Disabled -deepspeed_checkpointing_enabled = False - -# MP parameters -mpu = None -mp_rank = None -mp_size = None -mp_group = None - -# Model Parameters -num_layers = None - -# Checkpointing buffers -contiguous_data_buffers = [] -data_offsets = [] - -contiguous_size_buffers = [] -size_offsets = [] - -timers = None - -# optimization flags -PARTITION_ACTIVATIONS = False -CPU_CHECKPOINT = False -CONTIGUOUS_CHECKPOINTING = False -SYNCHRONIZE = False -PROFILE_TIME = False - -# Default name for the model parallel rng tracker. -_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' -transport_stream = None -cuda_device = None - - -def detach_variable(inputs, device=None): - if isinstance(inputs, tuple): - out = [] - for inp in inputs: - if not isinstance(inp, torch.Tensor): - out.append(inp) - continue - - requires_grad = inp.requires_grad - - if device is not None: - x = inp.to(device=device) - else: - x = inp - - x = x.detach() - x.requires_grad = requires_grad - out.append(x) - return tuple(out) - else: - raise RuntimeError( - "Only tuple of tensors is supported. Got Unsupported input type: ", - type(inputs).__name__) - - -def _set_cuda_rng_state(new_state, device=-1): - """Sets the random number generator state of the current GPU. - - Arguments: - new_state (torch.ByteTensor): The desired state - This function is adapted from PyTorch repo (torch.cuda.set_rng_state) - with a single change: the input state is not cloned. Cloning caused - major performance issues for +4 GPU cases. - """ - if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): - # older PyTorch - def cb(): - with device_ctx_manager(device): - _C._cuda_setRNGState(new_state) - else: - # newer PyTorch - if device == -1: - device = torch.device('cuda') - elif isinstance(device, str): - device = torch.device(device) - elif isinstance(device, int): - device = torch.device('cuda', device) - - def cb(): - idx = device.index - if idx is None: - idx = torch.cuda.current_device() - default_generator = torch.cuda.default_generators[idx] - default_generator.set_state(new_state) - - _lazy_call(cb) - - -class CudaRNGStatesTracker: - """Tracker for the cuda RNG states. - - Using the `add` method, a cuda rng state is initialized based on - the input `seed` and is assigned to `name`. Later, by forking the - rng state, we can perform operations and return to our starting - cuda state. - """ - def __init__(self): - # Map from a string name to the cuda rng state. - self.states_ = {} - # Seeds are just for book keeping and ensure no seed is set twice. - self.seeds_ = set() - - def reset(self): - """Set to the initial state (no tracker).""" - self.states_ = {} - self.seeds_ = set() - - def get_states(self): - """Get rng states. Copy the dictionary so we have direct - pointers to the states, not just a pointer to the dictionary.""" - return copy.copy(self.states_) - - def set_states(self, states): - """Set the rng states. For efficiency purposes, we do not check - the size of seed for compatibility.""" - self.states_ = states - - def add(self, name, seed): - """Track the rng state.""" - # Check seed is not already used. - if seed in self.seeds_: - raise Exception('seed {} already exists'.format(seed)) - self.seeds_.add(seed) - # Check that state is not already defined. - if name in self.states_: - raise Exception('cuda rng state {} already exists'.format(name)) - # Get the current rng state. - orig_rng_state = torch.cuda.get_rng_state() - # Set the new state and store it. - torch.cuda.manual_seed(seed) - self.states_[name] = torch.cuda.get_rng_state() - # Reset rng state to what it was. - _set_cuda_rng_state(orig_rng_state) - - @contextlib.contextmanager - def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): - """Fork the cuda rng state, perform operations, and exit with - the original state.""" - # Check if we have added the state - if name not in self.states_: - raise Exception('cuda rng state {} is not added'.format(name)) - # Store current rng state. - orig_cuda_rng_state = torch.cuda.get_rng_state() - # Set rng state to the desired one - _set_cuda_rng_state(self.states_[name]) - # Do the stuff we wanted to do. - try: - yield - finally: - # Update the current rng state for later use. - self.states_[name] = torch.cuda.get_rng_state() - # And set the state to the original state we started with. - _set_cuda_rng_state(orig_cuda_rng_state) - - -# RNG tracker object. -_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() - - -def get_cuda_rng_tracker(): - """Get cuda rng tracker.""" - return _CUDA_RNG_STATE_TRACKER - - -def model_parallel_cuda_manual_seed(seed): - """Initialize model parallel cuda seed. - - This function should be called after the model parallel is - initialized. Also, no torch.cuda.manual_seed should be called - after this function. Basically, this is replacement for that - function. - Two set of RNG states are tracked: - default state: This is for data parallelism and is the same among a - set of model parallel GPUs but different across - different model parallel groups. This is used for - example for dropout in the non-model-parallel regions. - model-parallel state: This state is different among a set of model - parallel GPUs, but the same across data parallel - groups. This is used for example for dropout in - model parallel regions. - """ - global mpu - - tp_rank = bwc_tensor_model_parallel_rank(mpu) - - # 2718 is just for fun and any POSITIVE value will work. - offset = seed + 2718 - model_parallel_seed = offset + tp_rank - # Data parallel gets the original seed. - data_parallel_seed = seed - - if torch.distributed.get_rank() == 0: - logger.info( - '> initializing model parallel cuda seeds on global rank {}, ' - 'model parallel rank {}, and data parallel rank {} with ' - 'model parallel seed: {} and data parallel seed: {}'.format( - torch.distributed.get_rank(), - tp_rank, - mpu.get_data_parallel_rank(), - model_parallel_seed, - data_parallel_seed), - ) - _CUDA_RNG_STATE_TRACKER.reset() - # Set the default state. - torch.cuda.manual_seed(data_parallel_seed) - # and model parallel state. - _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed) - - -def get_partition_start(item): - global mp_rank, mp_size, mp_group - size = item.numel() - partition_size = size / mp_size - start = partition_size * mp_rank - return int(start) - - -def get_partition_size(item): - global mp_rank, mp_size, mp_group - size = item.numel() - assert size % mp_size == 0, "Doesn't handle if partition activation if item is not divisible by mp size" - partition_size = size / mp_size - return int(partition_size) - - -def gather_partitioned_activations(tensors, device=None): - global mp_rank, mp_size, mp_group - assert len(tensors) % 2 == 0, f'Expected even count of tensors, instead got {len(tensors)}' - inputs = [] - num_args = int(len(tensors) / 2) - for i in range(num_args): - - item = tensors[2 * i] - size = tensors[2 * i + 1] - - if not is_activation_to_checkpoint(item): - inputs.append(item) - continue - - partition_size = item.numel() - tensor_size = partition_size * mp_size - if device is not None: - flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=device) - else: - flat_tensor = torch.zeros([tensor_size], - dtype=item.dtype, - device=item.device) - partitions = [] - for i in range(mp_size): - part_i = flat_tensor.narrow(0, partition_size * i, partition_size) - if i == mp_rank: - part_i.copy_(item) - partitions.append(part_i) - if mp_group is not None: - dist.all_gather(partitions, partitions[mp_rank], group=mp_group) - input_tensor = flat_tensor.view(list(size.numpy())) - item.data = input_tensor.data - - inputs.append(item) - - return tuple(inputs) - - -def extract_tensors(all_objects): - """ - Separate objects in list/tuple into tensors and non-tensors and create a mapping to enable re-aggregation. - The order of tensors and non-tensors is preserved in their respective output groups. - - Parameters: - all_objects (list/tuple): Objects containing tensors and non-tensors to be split. - - Returns: - tuple: Containing tensors, non-tensors, and bools of whether each position in original list/tuple was a tensor. - - """ - tensor_objects = [v for v in all_objects if torch.is_tensor(v)] - non_tensor_objects = [v for v in all_objects if not torch.is_tensor(v)] - tensor_flags = [torch.is_tensor(v) for v in all_objects] - if type(all_objects) is tuple: - return tuple(tensor_objects), tuple(non_tensor_objects), tuple(tensor_flags) - return tensor_objects, non_tensor_objects, tensor_flags - - -def merge_tensors(tensor_objects, non_tensor_objects, tensor_flags): - """ - Merge two lists (or tuples) of tensors and non-tensors using a mapping of positions in merged list (or tuple). - - Parameters: - tensor_objects (list/tuple): Tensors to merge. - non_tensor_objects (list/tuple): Non-tensors to merge. - tensor_flags (list/tuple): Indicates whether each position in output is a tensor. - - Returns: - tuple: Merge of tensors and non-tensors - """ - merged_objects = [] - tensor_idx = 0 - non_tensor_idx = 0 - - real_tensor_flags = None - - # remove the flags that are assigned to the size of the flattened tensors - if PARTITION_ACTIVATIONS: - real_tensor_flags = [] - previous_flag = False - for flag in tensor_flags: - if previous_flag: - previous_flag = False - continue - previous_flag = flag - real_tensor_flags.append(flag) - else: - real_tensor_flags = tensor_flags - - for is_tensor in real_tensor_flags: - if is_tensor: - merged_objects.append(tensor_objects[tensor_idx]) - tensor_idx += 1 - else: - merged_objects.append(non_tensor_objects[non_tensor_idx]) - non_tensor_idx += 1 - - return tuple(merged_objects) - - -def is_activation_to_checkpoint(item): - """ - Is an activation to be checkpointed - """ - global mp_size - return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size - - -def partition_activations(args, cpu_checkpoint, contiguous_checkpoint): - global contiguous_data_buffers, data_offsets - - inputs = [] - num_non_fp_tensors = 0 - - for arg_index, item in enumerate(args): - if not is_activation_to_checkpoint(item): - inputs.append(item) - num_non_fp_tensors += 1 - continue - - i = arg_index - num_non_fp_tensors - partition_size = get_partition_size(item) - partition = item.detach().contiguous().view(-1).narrow( - 0, - get_partition_start(item), - partition_size).clone() - - buffer_device = torch.device('cpu') if cpu_checkpoint else partition.device - - if contiguous_checkpoint: - if i >= len(contiguous_data_buffers): - tensor_list = [ - torch.tensor(()).new_empty([partition_size], - dtype=partition.dtype, - device=buffer_device) - for _ in range(num_layers) - ] - contiguous_data_buffers.append(tensor_list) - data_offsets.append(0) - elif contiguous_data_buffers[i] is None: - tensor_list = [ - torch.tensor(()).new_empty([partition_size], - dtype=partition.dtype, - device=buffer_device) - for _ in range(num_layers) - ] - contiguous_data_buffers[i] = tensor_list - data_offsets[i] = 0 - - # Because the 'new_empty' returns uninitialized pages, - # the pages need to be populated during the cudaMemcpy time - # which increases the data copy time. To avoid this, we - # pre-populate these pages by simply writing 0 ahead of - # the actual cudaMemcpy operation time. Due to the - # previously launched GPU kernels, there is a small - # window of time here for CPUs to populate pages asynchronously. - contiguous_data_buffers[i][data_offsets[i]].data[range( - 0, - contiguous_data_buffers[i][data_offsets[i]].data.shape[0], - int(mmap.PAGESIZE / - contiguous_data_buffers[i][data_offsets[i]].data.element_size()) - )] = 0 - - contiguous_partition = contiguous_data_buffers[i][ - data_offsets[i]].data.copy_(partition.data) - data_offsets[i] = data_offsets[i] + 1 - inputs.append(contiguous_partition) - else: - partition = partition.cpu() if CPU_CHECKPOINT else partition - inputs.append(partition) - - return inputs - - -def get_partitioned_activations_for_backward(args, inputs, contiguous_checkpoint): - global contiguous_size_buffers, size_offsets - - new_args = [] - num_non_fp_tensors = 0 - - for arg_index, (arg, inp) in enumerate(zip(args, inputs)): - size = torch.tensor(arg.size()) if torch.is_tensor(arg) else None - if not is_activation_to_checkpoint(arg): - new_args.append(arg) - new_args.append(size) - num_non_fp_tensors += 1 - continue - - arg.data = inp.data - new_args.append(arg) - i = arg_index - num_non_fp_tensors - - if contiguous_checkpoint: - numel = size.numel() - if i >= len(contiguous_size_buffers): - tmp = torch.tensor(()) - contiguous_size_buffers.append( - tmp.new_empty([numel * num_layers], - dtype=size.dtype, - device=size.device)) - size_offsets.append(0) - elif contiguous_size_buffers[i] is None: - tmp = torch.tensor(()) - contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers], - dtype=size.dtype, - device=size.device) - size_offsets[i] = 0 - - contiguous_size = contiguous_size_buffers[i].narrow( - 0, - size_offsets[i], - numel).data.copy_(size.data) - contiguous_size = contiguous_size.view_as(size) - size_offsets[i] = size_offsets[i] + numel - new_args.append(contiguous_size) - else: - new_args.append(size) - - return new_args - - -def get_cpu_activations_for_backward(args, inputs): - new_args = [] - for i, (arg, inp) in enumerate(zip(args, inputs)): - if not is_activation_to_checkpoint(arg): - new_args.append(arg) - continue - - arg.data = inp.data - new_args.append(arg) - - return new_args - - -class CheckpointFunction(torch.autograd.Function): - """This function is adapted from torch.utils.checkpoint with - two main changes: - 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` - 2) the states in the model parallel tracker are also properly - tracked/set/reset. - 3) Performance activation partitioning, contiguous memory optimization - 4) CPU Checkpointing - 5) Profile forward and backward functions - """ - @staticmethod - def forward(ctx, run_function, all_outputs, *args): - global mpu, timers, SYNCHRONIZE, PROFILE_TIME - - def save_args_for_backward(*all_args): - tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args) - ctx.deepspeed_saved_tensors = tensor_args - ctx.non_tensor_args = non_tensor_args - ctx.tensor_flags = tensor_flags - - if SYNCHRONIZE: - torch.cuda.synchronize() - - if timers is None and PROFILE_TIME: - timers = Timers() - - if PROFILE_TIME: - timers('forward').start() - - ctx.run_function = run_function - global num_layers - global mp_rank, mp_size, mp_group - global contiguous_data_buffers, contiguous_size_buffers - global data_offsets, size_offsets - if mp_rank is None: - if mpu is not None: - if hasattr(mpu, 'get_tensor_model_parallel_rank'): - mp_rank = mpu.get_tensor_model_parallel_rank() - mp_size = mpu.get_tensor_model_parallel_world_size() - mp_group = mpu.get_tensor_model_parallel_group() - else: - mp_rank = mpu.get_model_parallel_rank() - mp_size = mpu.get_model_parallel_world_size() - mp_group = mpu.get_model_parallel_group() - else: - mp_rank = 0 - mp_size = 1 - mp_group = None - - global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset - - if cuda_device is None: - see_memory_usage("First Forward Beginning", force=False) - if dist.get_rank() == 0: - logger.info(f"Activation Checkpointing Information") - logger.info( - f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}" - ) - logger.info( - f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers" - ) - logger.info(f"----Synchronization {SYNCHRONIZE}") - logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}") - - cuda_device = torch.cuda.current_device() - transport_stream = torch.cuda.Stream(device=cuda_device) - - if PARTITION_ACTIVATIONS: - inputs = partition_activations(args, - CPU_CHECKPOINT, - CONTIGUOUS_CHECKPOINTING) - elif CPU_CHECKPOINT: - inputs = copy_to_device(args, - device=torch.device('cpu'), - criterion_func=is_activation_to_checkpoint) - - # just in case something funky is happening such as reuse of inputs - inputs_cuda = copy_to_device(args, - device=cuda_device, - criterion_func=is_activation_to_checkpoint) - - # Copy the rng states. - ctx.fwd_cpu_rng_state = torch.get_rng_state() - ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() - ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() - - see_memory_usage("Before running forward on the layer", force=False) - # ctx.save_for_backward(*args) - with torch.no_grad(): - outputs = run_function(*inputs_cuda) - - see_memory_usage("After running forward on the layer", force=False) - del inputs_cuda - - if PARTITION_ACTIVATIONS: - new_args = get_partitioned_activations_for_backward( - args, - inputs, - CONTIGUOUS_CHECKPOINTING) - assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}' - save_args_for_backward(*new_args) - elif CPU_CHECKPOINT: - new_args = get_cpu_activations_for_backward(args, inputs) - save_args_for_backward(*new_args) - else: - save_args_for_backward(*args) - - if PROFILE_TIME: - timers('forward').stop() - timers.log(['forward']) - if SYNCHRONIZE: - torch.cuda.synchronize() - - # Tensors returned from forward() may not be differentiable. - if torch.is_tensor(outputs): - non_grad_outputs = [outputs] if not outputs.is_floating_point() else [] - else: - non_grad_outputs = [ - o for o in outputs if torch.is_tensor(o) and not o.is_floating_point() - ] - ctx.mark_non_differentiable(*non_grad_outputs) - - if torch.is_tensor(outputs): - all_outputs += [outputs] - return outputs - else: - all_outputs += outputs - outputs, _, _ = extract_tensors(all_objects=outputs) - return tuple(outputs) - - @staticmethod - def backward(ctx, *grads): - global timers - see_memory_usage("In backward", force=False) - # removing pointers to the contiguous buffer memory - # so that they can be garbage collected once the checkpoints - # have been used - if SYNCHRONIZE: - torch.cuda.synchronize() - if PROFILE_TIME: - timers('backward').start() - - if CONTIGUOUS_CHECKPOINTING: - global data_offsets, size_offsets - global contiguous_data_buffers, contiguous_size_buffers - - for buffers in contiguous_data_buffers: - buffers = [] - - # frees up all the pointers to the checkpoints except for the ones - # stored by save for backward - contiguous_data_buffers = [] - contiguous_size_buffers = [] - data_offsets = [] - size_offsets = [] - - see_memory_usage("In backward checkpointing code", force=False) - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError("Checkpointing is not compatible with .grad(), " - "please use .backward() if possible") - - global cuda_device, transport_stream, PARTITION_ACTIVATIONS - - if PARTITION_ACTIVATIONS: - # with torch.cuda.stream(transport_stream): - inputs = gather_partitioned_activations( - ctx.deepspeed_saved_tensors, - device=cuda_device if CPU_CHECKPOINT else None) - detached_inputs = detach_variable(inputs) - elif CPU_CHECKPOINT: - inputs = move_to_device(ctx.deepspeed_saved_tensors, - cuda_device, - is_activation_to_checkpoint) - detached_inputs = detach_variable(inputs) - else: - inputs = ctx.deepspeed_saved_tensors - detached_inputs = detach_variable(inputs) - - # Add non tensor input args - detached_inputs = merge_tensors(tensor_objects=detached_inputs, - non_tensor_objects=ctx.non_tensor_args, - tensor_flags=ctx.tensor_flags) - - # Store the current states. - bwd_cpu_rng_state = torch.get_rng_state() - bwd_cuda_rng_state = torch.cuda.get_rng_state() - bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() - - # Set the states to what it used to be before the forward pass. - torch.set_rng_state(ctx.fwd_cpu_rng_state) - _set_cuda_rng_state(ctx.fwd_cuda_rng_state) - get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) - - # if PARTITION_ACTIVATIONS: - # current_stream=torch.cuda.current_stream() - # current_stream.wait_stream(transport_stream) - - see_memory_usage("In backward checkpointing code before forward", force=False) - - with torch.enable_grad(): - outputs = ctx.run_function(*detached_inputs) - - see_memory_usage("In backward checkpointing code after forward", force=False) - # Set the states back to what it was at the start of this function. - torch.set_rng_state(bwd_cpu_rng_state) - _set_cuda_rng_state(bwd_cuda_rng_state) - get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) - - if isinstance(outputs, torch.Tensor): - outputs = (outputs, ) - - # Filter out non tensor outputs - outputs, _, _ = extract_tensors(all_objects=outputs) - - # Construct arguments to autograd.backward(). - # This is usually just outputs and grads, but forward() can return tensors that - # are not differentiable. - output_tensors = [] - grad_tensors = [] - for out, grad in zip(outputs, grads): - if out.requires_grad: - output_tensors.append(out) - grad_tensors.append(grad) - - see_memory_usage("In backward checkpointing code before backward", force=False) - - torch.autograd.backward(output_tensors, grad_tensors) - - see_memory_usage("After backward checkpointing code after backward", force=False) - - if PROFILE_TIME: - timers('backward').stop() - timers.log(['backward']) - if SYNCHRONIZE: - torch.cuda.synchronize() - ret_list = [None, None] # first None for ctx - for inp in detached_inputs: - if torch.is_tensor(inp): - ret_list.append(inp.grad) - else: - ret_list.append(None) - - return tuple(ret_list) - - -def checkpoint(function, *args): - """Checkpoint a model or part of the model. - This has been directly copied from torch.utils.checkpoint. """ - - all_outputs = [] - CheckpointFunction.apply(function, all_outputs, *args) - if len(all_outputs) == 1: - return all_outputs[0] - else: - return tuple(all_outputs) - - -def partition_activations_in_checkpoint(partition_activation): - global PARTITION_ACTIVATIONS - PARTITION_ACTIVATIONS = partition_activation - if dist.get_rank() == 0: - logger.info( - f"**************Partition Activations {PARTITION_ACTIVATIONS}************") - - -def set_num_layers(nlayers): - global num_layers - num_layers = nlayers - - -def reset(): - """Resets memory buffers related to contiguous memory optimizations. - Should be called during eval when multiple forward propagations are - computed without any backward propagation that usually clears these - buffers. - Arguments: - None - - Return: - None - """ - if CONTIGUOUS_CHECKPOINTING: - global data_offsets, size_offsets - global contiguous_data_buffers, contiguous_size_buffers - - for buffers in contiguous_data_buffers: - buffers = [] - - # frees up all the pointers to the checkpoints except for the ones - # stored by save for backward - contiguous_data_buffers = [] - contiguous_size_buffers = [] - data_offsets = [] - size_offsets = [] - - -def _configure_using_config_file(config, mpu=None): - global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \ - CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME - - config = DeepSpeedConfig(config, mpu=mpu).activation_checkpointing_config - if dist.get_rank() == 0: - logger.info(config.repr()) - PARTITION_ACTIVATIONS = config.partition_activations - CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization - num_layers = config.number_checkpoints - CPU_CHECKPOINT = config.cpu_checkpointing - SYNCHRONIZE = config.synchronize_checkpoint_boundary - PROFILE_TIME = config.profile - - -def _configure_defaults(): - - global mpu, num_layers, deepspeed_checkpointing_enabled - - global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \ - CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME - - PARTITION_ACTIVATIONS = False - CONTIGUOUS_CHECKPOINTING = False - num_layers = False - CPU_CHECKPOINT = False - SYNCHRONIZE = False - PROFILE_TIME = False - deepspeed_checkpointing_enabled = True - - -def configure( - mpu_, - deepspeed_config=None, - partition_activations=None, - contiguous_checkpointing=None, - num_checkpoints=None, - checkpoint_in_cpu=None, - synchronize=None, - profile=None, -): - """Configure DeepSpeed Activation Checkpointing. - - Arguments: - mpu_: Optional: An object that implements the following methods - get_model_parallel_rank/group/world_size, and get_data_parallel_rank/group/world_size - - deepspeed_config: Optional: DeepSpeed Config json file when provided will be used to - configure DeepSpeed Activation Checkpointing - - partition_activations: Optional: Partitions activation checkpoint across model parallel - GPUs when enabled. By default False. Will overwrite deepspeed_config if provided - - contiguous_checkpointing: Optional: Copies activation checkpoints to a contiguous memory - buffer. Works only with homogeneous checkpoints when partition_activations is enabled. - Must provide num_checkpoints. By default False. Will overwrite deepspeed_config if - provided - - num_checkpoints: Optional: Number of activation checkpoints stored during the forward - propagation of the model. Used to calculate the buffer size for contiguous_checkpointing - Will overwrite deepspeed_config if provided - - checkpoint_in_cpu: Optional: Moves the activation checkpoint to CPU. Only works with - partition_activation. Default is false. Will overwrite deepspeed_config if provided - - synchronize: Optional: Performs torch.cuda.synchronize() at the beginning and end of - each call to deepspeed.checkpointing.checkpoint for both forward and backward pass. - By default false. Will overwrite deepspeed_config if provided - - profile: Optional: Logs the forward and backward time for each - deepspeed.checkpointing.checkpoint invocation. Will overwrite deepspeed_config - if provided - - Returns: - None - """ - global mpu, num_layers, deepspeed_checkpointing_enabled - - global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \ - CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME - - _configure_defaults() - - if mpu_ is not None: - mpu = mpu_ - - if deepspeed_config is not None: - _configure_using_config_file(deepspeed_config, mpu=mpu) - - if partition_activations is not None: - PARTITION_ACTIVATIONS = partition_activations - - if contiguous_checkpointing is not None: - CONTIGUOUS_CHECKPOINTING = contiguous_checkpointing - - if num_checkpoints is not None: - num_layers = num_checkpoints - - if checkpoint_in_cpu is not None: - CPU_CHECKPOINT = checkpoint_in_cpu - - if synchronize is not None: - SYNCHRONIZE = synchronize - - if profile is not None: - PROFILE_TIME = profile - - if CONTIGUOUS_CHECKPOINTING: - assert PARTITION_ACTIVATIONS, "Contiguous Checkpointing is only available with partitioned activations. Set partitioned activations to true in deepspeed config" - if CONTIGUOUS_CHECKPOINTING: - assert num_layers is not None, "Must specify the number of layers with contiguous memory checkpointing" - - -def is_configured(): - """True if deepspeed activation checkpointing has been configured - by calling deepspeed.checkpointing.configure, else returns false - - Arguments: - None - - Return: - True of configured, else False - """ - return deepspeed_checkpointing_enabled +''' +Copyright (c) Microsoft Corporation +Licensed under the MIT license. + +Use to partition the activations stored for backward propagation +Therefore reduces the memory consumption +Also implements CPU checkpointing and contiguous memory checkpointing +Reduces memory consumption and memory fragmentation + +Code for rng checkpointing taken from NVIDIA Megatron-LM mpu/random.py +b886b7bb972afe72bac0f5de4f42a4a7bae8ebef +''' + +# Parts of the code here are adapted from PyTorch +# repo: https://github.com/pytorch/pytorch +import copy +import torch +import contextlib +import torch.distributed as dist + +import mmap +from torch import _C +from torch.cuda import _lazy_call, device as device_ctx_manager + +from deepspeed.runtime.config import DeepSpeedConfig +from deepspeed.utils import logger +from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank +from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers + +# DeepSpeed Checkpointing Enabled or Disabled +deepspeed_checkpointing_enabled = False + +# MP parameters +mpu = None +mp_rank = None +mp_size = None +mp_group = None + +# Model Parameters +num_layers = None + +# Checkpointing buffers +contiguous_data_buffers = [] +data_offsets = [] + +contiguous_size_buffers = [] +size_offsets = [] + +timers = None + +# optimization flags +PARTITION_ACTIVATIONS = False +CPU_CHECKPOINT = False +CONTIGUOUS_CHECKPOINTING = False +SYNCHRONIZE = False +PROFILE_TIME = False + +# Default name for the model parallel rng tracker. +_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' +transport_stream = None +cuda_device = None + + +def detach_variable(inputs, device=None): + if isinstance(inputs, tuple): + out = [] + for inp in inputs: + if not isinstance(inp, torch.Tensor): + out.append(inp) + continue + + requires_grad = inp.requires_grad + + if device is not None: + x = inp.to(device=device) + else: + x = inp + + x = x.detach() + x.requires_grad = requires_grad + out.append(x) + return tuple(out) + else: + raise RuntimeError( + "Only tuple of tensors is supported. Got Unsupported input type: ", + type(inputs).__name__) + + +def _set_cuda_rng_state(new_state, device=-1): + """Sets the random number generator state of the current GPU. + + Arguments: + new_state (torch.ByteTensor): The desired state + This function is adapted from PyTorch repo (torch.cuda.set_rng_state) + with a single change: the input state is not cloned. Cloning caused + major performance issues for +4 GPU cases. + """ + if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): + # older PyTorch + def cb(): + with device_ctx_manager(device): + _C._cuda_setRNGState(new_state) + else: + # newer PyTorch + if device == -1: + device = torch.device('cuda') + elif isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device('cuda', device) + + def cb(): + idx = device.index + if idx is None: + idx = torch.cuda.current_device() + default_generator = torch.cuda.default_generators[idx] + default_generator.set_state(new_state) + + _lazy_call(cb) + + +class CudaRNGStatesTracker: + """Tracker for the cuda RNG states. + + Using the `add` method, a cuda rng state is initialized based on + the input `seed` and is assigned to `name`. Later, by forking the + rng state, we can perform operations and return to our starting + cuda state. + """ + def __init__(self): + # Map from a string name to the cuda rng state. + self.states_ = {} + # Seeds are just for book keeping and ensure no seed is set twice. + self.seeds_ = set() + + def reset(self): + """Set to the initial state (no tracker).""" + self.states_ = {} + self.seeds_ = set() + + def get_states(self): + """Get rng states. Copy the dictionary so we have direct + pointers to the states, not just a pointer to the dictionary.""" + return copy.copy(self.states_) + + def set_states(self, states): + """Set the rng states. For efficiency purposes, we do not check + the size of seed for compatibility.""" + self.states_ = states + + def add(self, name, seed): + """Track the rng state.""" + # Check seed is not already used. + if seed in self.seeds_: + raise Exception('seed {} already exists'.format(seed)) + self.seeds_.add(seed) + # Check that state is not already defined. + if name in self.states_: + raise Exception('cuda rng state {} already exists'.format(name)) + # Get the current rng state. + orig_rng_state = torch.cuda.get_rng_state() + # Set the new state and store it. + torch.cuda.manual_seed(seed) + self.states_[name] = torch.cuda.get_rng_state() + # Reset rng state to what it was. + _set_cuda_rng_state(orig_rng_state) + + @contextlib.contextmanager + def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): + """Fork the cuda rng state, perform operations, and exit with + the original state.""" + # Check if we have added the state + if name not in self.states_: + raise Exception('cuda rng state {} is not added'.format(name)) + # Store current rng state. + orig_cuda_rng_state = torch.cuda.get_rng_state() + # Set rng state to the desired one + _set_cuda_rng_state(self.states_[name]) + # Do the stuff we wanted to do. + try: + yield + finally: + # Update the current rng state for later use. + self.states_[name] = torch.cuda.get_rng_state() + # And set the state to the original state we started with. + _set_cuda_rng_state(orig_cuda_rng_state) + + +# RNG tracker object. +_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + + +def get_cuda_rng_tracker(): + """Get cuda rng tracker.""" + return _CUDA_RNG_STATE_TRACKER + + +def model_parallel_cuda_manual_seed(seed): + """Initialize model parallel cuda seed. + + This function should be called after the model parallel is + initialized. Also, no torch.cuda.manual_seed should be called + after this function. Basically, this is replacement for that + function. + Two set of RNG states are tracked: + default state: This is for data parallelism and is the same among a + set of model parallel GPUs but different across + different model parallel groups. This is used for + example for dropout in the non-model-parallel regions. + model-parallel state: This state is different among a set of model + parallel GPUs, but the same across data parallel + groups. This is used for example for dropout in + model parallel regions. + """ + global mpu + + tp_rank = bwc_tensor_model_parallel_rank(mpu) + + # 2718 is just for fun and any POSITIVE value will work. + offset = seed + 2718 + model_parallel_seed = offset + tp_rank + # Data parallel gets the original seed. + data_parallel_seed = seed + + if torch.distributed.get_rank() == 0: + logger.info( + '> initializing model parallel cuda seeds on global rank {}, ' + 'model parallel rank {}, and data parallel rank {} with ' + 'model parallel seed: {} and data parallel seed: {}'.format( + torch.distributed.get_rank(), + tp_rank, + mpu.get_data_parallel_rank(), + model_parallel_seed, + data_parallel_seed), + ) + _CUDA_RNG_STATE_TRACKER.reset() + # Set the default state. + torch.cuda.manual_seed(data_parallel_seed) + # and model parallel state. + _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed) + + +def get_partition_start(item): + global mp_rank, mp_size, mp_group + size = item.numel() + partition_size = size / mp_size + start = partition_size * mp_rank + return int(start) + + +def get_partition_size(item): + global mp_rank, mp_size, mp_group + size = item.numel() + assert size % mp_size == 0, "Doesn't handle if partition activation if item is not divisible by mp size" + partition_size = size / mp_size + return int(partition_size) + + +def gather_partitioned_activations(tensors, device=None): + global mp_rank, mp_size, mp_group + assert len(tensors) % 2 == 0, f'Expected even count of tensors, instead got {len(tensors)}' + inputs = [] + num_args = int(len(tensors) / 2) + for i in range(num_args): + + item = tensors[2 * i] + size = tensors[2 * i + 1] + + if not is_activation_to_checkpoint(item): + inputs.append(item) + continue + + partition_size = item.numel() + tensor_size = partition_size * mp_size + if device is not None: + flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=device) + else: + flat_tensor = torch.zeros([tensor_size], + dtype=item.dtype, + device=item.device) + partitions = [] + for i in range(mp_size): + part_i = flat_tensor.narrow(0, partition_size * i, partition_size) + if i == mp_rank: + part_i.copy_(item) + partitions.append(part_i) + if mp_group is not None: + dist.all_gather(partitions, partitions[mp_rank], group=mp_group) + input_tensor = flat_tensor.view(list(size.numpy())) + item.data = input_tensor.data + + inputs.append(item) + + return tuple(inputs) + + +def extract_tensors(all_objects): + """ + Separate objects in list/tuple into tensors and non-tensors and create a mapping to enable re-aggregation. + The order of tensors and non-tensors is preserved in their respective output groups. + + Parameters: + all_objects (list/tuple): Objects containing tensors and non-tensors to be split. + + Returns: + tuple: Containing tensors, non-tensors, and bools of whether each position in original list/tuple was a tensor. + + """ + tensor_objects = [v for v in all_objects if torch.is_tensor(v)] + non_tensor_objects = [v for v in all_objects if not torch.is_tensor(v)] + tensor_flags = [torch.is_tensor(v) for v in all_objects] + if type(all_objects) is tuple: + return tuple(tensor_objects), tuple(non_tensor_objects), tuple(tensor_flags) + return tensor_objects, non_tensor_objects, tensor_flags + + +def merge_tensors(tensor_objects, non_tensor_objects, tensor_flags): + """ + Merge two lists (or tuples) of tensors and non-tensors using a mapping of positions in merged list (or tuple). + + Parameters: + tensor_objects (list/tuple): Tensors to merge. + non_tensor_objects (list/tuple): Non-tensors to merge. + tensor_flags (list/tuple): Indicates whether each position in output is a tensor. + + Returns: + tuple: Merge of tensors and non-tensors + """ + merged_objects = [] + tensor_idx = 0 + non_tensor_idx = 0 + + real_tensor_flags = None + + # remove the flags that are assigned to the size of the flattened tensors + if PARTITION_ACTIVATIONS: + real_tensor_flags = [] + previous_flag = False + for flag in tensor_flags: + if previous_flag: + previous_flag = False + continue + previous_flag = flag + real_tensor_flags.append(flag) + else: + real_tensor_flags = tensor_flags + + for is_tensor in real_tensor_flags: + if is_tensor: + merged_objects.append(tensor_objects[tensor_idx]) + tensor_idx += 1 + else: + merged_objects.append(non_tensor_objects[non_tensor_idx]) + non_tensor_idx += 1 + + return tuple(merged_objects) + + +def is_activation_to_checkpoint(item): + """ + Is an activation to be checkpointed + """ + global mp_size + return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size + + +def partition_activations(args, cpu_checkpoint, contiguous_checkpoint): + global contiguous_data_buffers, data_offsets + + inputs = [] + num_non_fp_tensors = 0 + + for arg_index, item in enumerate(args): + if not is_activation_to_checkpoint(item): + inputs.append(item) + num_non_fp_tensors += 1 + continue + + i = arg_index - num_non_fp_tensors + partition_size = get_partition_size(item) + partition = item.detach().contiguous().view(-1).narrow( + 0, + get_partition_start(item), + partition_size).clone() + + buffer_device = torch.device('cpu') if cpu_checkpoint else partition.device + + if contiguous_checkpoint: + if i >= len(contiguous_data_buffers): + tensor_list = [ + torch.tensor(()).new_empty([partition_size], + dtype=partition.dtype, + device=buffer_device) + for _ in range(num_layers) + ] + contiguous_data_buffers.append(tensor_list) + data_offsets.append(0) + elif contiguous_data_buffers[i] is None: + tensor_list = [ + torch.tensor(()).new_empty([partition_size], + dtype=partition.dtype, + device=buffer_device) + for _ in range(num_layers) + ] + contiguous_data_buffers[i] = tensor_list + data_offsets[i] = 0 + + # Because the 'new_empty' returns uninitialized pages, + # the pages need to be populated during the cudaMemcpy time + # which increases the data copy time. To avoid this, we + # pre-populate these pages by simply writing 0 ahead of + # the actual cudaMemcpy operation time. Due to the + # previously launched GPU kernels, there is a small + # window of time here for CPUs to populate pages asynchronously. + contiguous_data_buffers[i][data_offsets[i]].data[range( + 0, + contiguous_data_buffers[i][data_offsets[i]].data.shape[0], + int(mmap.PAGESIZE / + contiguous_data_buffers[i][data_offsets[i]].data.element_size()) + )] = 0 + + contiguous_partition = contiguous_data_buffers[i][ + data_offsets[i]].data.copy_(partition.data) + data_offsets[i] = data_offsets[i] + 1 + inputs.append(contiguous_partition) + else: + partition = partition.cpu() if CPU_CHECKPOINT else partition + inputs.append(partition) + + return inputs + + +def get_partitioned_activations_for_backward(args, inputs, contiguous_checkpoint): + global contiguous_size_buffers, size_offsets + + new_args = [] + num_non_fp_tensors = 0 + + for arg_index, (arg, inp) in enumerate(zip(args, inputs)): + size = torch.tensor(arg.size()) if torch.is_tensor(arg) else None + if not is_activation_to_checkpoint(arg): + new_args.append(arg) + new_args.append(size) + num_non_fp_tensors += 1 + continue + + arg.data = inp.data + new_args.append(arg) + i = arg_index - num_non_fp_tensors + + if contiguous_checkpoint: + numel = size.numel() + if i >= len(contiguous_size_buffers): + tmp = torch.tensor(()) + contiguous_size_buffers.append( + tmp.new_empty([numel * num_layers], + dtype=size.dtype, + device=size.device)) + size_offsets.append(0) + elif contiguous_size_buffers[i] is None: + tmp = torch.tensor(()) + contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers], + dtype=size.dtype, + device=size.device) + size_offsets[i] = 0 + + contiguous_size = contiguous_size_buffers[i].narrow( + 0, + size_offsets[i], + numel).data.copy_(size.data) + contiguous_size = contiguous_size.view_as(size) + size_offsets[i] = size_offsets[i] + numel + new_args.append(contiguous_size) + else: + new_args.append(size) + + return new_args + + +def get_cpu_activations_for_backward(args, inputs): + new_args = [] + for i, (arg, inp) in enumerate(zip(args, inputs)): + if not is_activation_to_checkpoint(arg): + new_args.append(arg) + continue + + arg.data = inp.data + new_args.append(arg) + + return new_args + + +class CheckpointFunction(torch.autograd.Function): + """This function is adapted from torch.utils.checkpoint with + two main changes: + 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` + 2) the states in the model parallel tracker are also properly + tracked/set/reset. + 3) Performance activation partitioning, contiguous memory optimization + 4) CPU Checkpointing + 5) Profile forward and backward functions + """ + @staticmethod + def forward(ctx, run_function, all_outputs, *args): + global mpu, timers, SYNCHRONIZE, PROFILE_TIME + + def save_args_for_backward(*all_args): + tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args) + ctx.deepspeed_saved_tensors = tensor_args + ctx.non_tensor_args = non_tensor_args + ctx.tensor_flags = tensor_flags + + if SYNCHRONIZE: + torch.cuda.synchronize() + + if timers is None and PROFILE_TIME: + timers = Timers() + + if PROFILE_TIME: + timers('forward').start() + + ctx.run_function = run_function + global num_layers + global mp_rank, mp_size, mp_group + global contiguous_data_buffers, contiguous_size_buffers + global data_offsets, size_offsets + if mp_rank is None: + if mpu is not None: + if hasattr(mpu, 'get_tensor_model_parallel_rank'): + mp_rank = mpu.get_tensor_model_parallel_rank() + mp_size = mpu.get_tensor_model_parallel_world_size() + mp_group = mpu.get_tensor_model_parallel_group() + else: + mp_rank = mpu.get_model_parallel_rank() + mp_size = mpu.get_model_parallel_world_size() + mp_group = mpu.get_model_parallel_group() + else: + mp_rank = 0 + mp_size = 1 + mp_group = None + + global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset + + if cuda_device is None: + see_memory_usage("First Forward Beginning", force=False) + if dist.get_rank() == 0: + logger.info(f"Activation Checkpointing Information") + logger.info( + f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}" + ) + logger.info( + f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers" + ) + logger.info(f"----Synchronization {SYNCHRONIZE}") + logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}") + + cuda_device = torch.cuda.current_device() + transport_stream = torch.cuda.Stream(device=cuda_device) + + if PARTITION_ACTIVATIONS: + inputs = partition_activations(args, + CPU_CHECKPOINT, + CONTIGUOUS_CHECKPOINTING) + elif CPU_CHECKPOINT: + inputs = copy_to_device(args, + device=torch.device('cpu'), + criterion_func=is_activation_to_checkpoint) + + # just in case something funky is happening such as reuse of inputs + inputs_cuda = copy_to_device(args, + device=cuda_device, + criterion_func=is_activation_to_checkpoint) + + # Copy the rng states. + ctx.fwd_cpu_rng_state = torch.get_rng_state() + ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() + ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + see_memory_usage("Before running forward on the layer", force=False) + # ctx.save_for_backward(*args) + with torch.no_grad(): + outputs = run_function(*inputs_cuda) + + see_memory_usage("After running forward on the layer", force=False) + del inputs_cuda + + if PARTITION_ACTIVATIONS: + new_args = get_partitioned_activations_for_backward( + args, + inputs, + CONTIGUOUS_CHECKPOINTING) + assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}' + save_args_for_backward(*new_args) + elif CPU_CHECKPOINT: + new_args = get_cpu_activations_for_backward(args, inputs) + save_args_for_backward(*new_args) + else: + save_args_for_backward(*args) + + if PROFILE_TIME: + timers('forward').stop() + timers.log(['forward']) + if SYNCHRONIZE: + torch.cuda.synchronize() + + # Tensors returned from forward() may not be differentiable. + if torch.is_tensor(outputs): + non_grad_outputs = [outputs] if not outputs.is_floating_point() else [] + else: + non_grad_outputs = [ + o for o in outputs if torch.is_tensor(o) and not o.is_floating_point() + ] + ctx.mark_non_differentiable(*non_grad_outputs) + + if torch.is_tensor(outputs): + all_outputs += [outputs] + return outputs + else: + all_outputs += outputs + outputs, _, _ = extract_tensors(all_objects=outputs) + return tuple(outputs) + + @staticmethod + def backward(ctx, *grads): + global timers + see_memory_usage("In backward", force=False) + # removing pointers to the contiguous buffer memory + # so that they can be garbage collected once the checkpoints + # have been used + if SYNCHRONIZE: + torch.cuda.synchronize() + if PROFILE_TIME: + timers('backward').start() + + if CONTIGUOUS_CHECKPOINTING: + global data_offsets, size_offsets + global contiguous_data_buffers, contiguous_size_buffers + + for buffers in contiguous_data_buffers: + buffers = [] + + # frees up all the pointers to the checkpoints except for the ones + # stored by save for backward + contiguous_data_buffers = [] + contiguous_size_buffers = [] + data_offsets = [] + size_offsets = [] + + see_memory_usage("In backward checkpointing code", force=False) + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError("Checkpointing is not compatible with .grad(), " + "please use .backward() if possible") + + global cuda_device, transport_stream, PARTITION_ACTIVATIONS + + if PARTITION_ACTIVATIONS: + # with torch.cuda.stream(transport_stream): + inputs = gather_partitioned_activations( + ctx.deepspeed_saved_tensors, + device=cuda_device if CPU_CHECKPOINT else None) + detached_inputs = detach_variable(inputs) + elif CPU_CHECKPOINT: + inputs = move_to_device(ctx.deepspeed_saved_tensors, + cuda_device, + is_activation_to_checkpoint) + detached_inputs = detach_variable(inputs) + else: + inputs = ctx.deepspeed_saved_tensors + detached_inputs = detach_variable(inputs) + + # Add non tensor input args + detached_inputs = merge_tensors(tensor_objects=detached_inputs, + non_tensor_objects=ctx.non_tensor_args, + tensor_flags=ctx.tensor_flags) + + # Store the current states. + bwd_cpu_rng_state = torch.get_rng_state() + bwd_cuda_rng_state = torch.cuda.get_rng_state() + bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + # Set the states to what it used to be before the forward pass. + torch.set_rng_state(ctx.fwd_cpu_rng_state) + _set_cuda_rng_state(ctx.fwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) + + # if PARTITION_ACTIVATIONS: + # current_stream=torch.cuda.current_stream() + # current_stream.wait_stream(transport_stream) + + see_memory_usage("In backward checkpointing code before forward", force=False) + + with torch.enable_grad(): + outputs = ctx.run_function(*detached_inputs) + + see_memory_usage("In backward checkpointing code after forward", force=False) + # Set the states back to what it was at the start of this function. + torch.set_rng_state(bwd_cpu_rng_state) + _set_cuda_rng_state(bwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs, ) + + # Filter out non tensor outputs + outputs, _, _ = extract_tensors(all_objects=outputs) + + # Construct arguments to autograd.backward(). + # This is usually just outputs and grads, but forward() can return tensors that + # are not differentiable. + output_tensors = [] + grad_tensors = [] + for out, grad in zip(outputs, grads): + if out.requires_grad: + output_tensors.append(out) + grad_tensors.append(grad) + + see_memory_usage("In backward checkpointing code before backward", force=False) + + torch.autograd.backward(output_tensors, grad_tensors) + + see_memory_usage("After backward checkpointing code after backward", force=False) + + if PROFILE_TIME: + timers('backward').stop() + timers.log(['backward']) + if SYNCHRONIZE: + torch.cuda.synchronize() + ret_list = [None, None] # first None for ctx + for inp in detached_inputs: + if torch.is_tensor(inp): + ret_list.append(inp.grad) + else: + ret_list.append(None) + + return tuple(ret_list) + + +def checkpoint(function, *args): + """Checkpoint a model or part of the model. + This has been directly copied from torch.utils.checkpoint. """ + + all_outputs = [] + CheckpointFunction.apply(function, all_outputs, *args) + if len(all_outputs) == 1: + return all_outputs[0] + else: + return tuple(all_outputs) + + +def partition_activations_in_checkpoint(partition_activation): + global PARTITION_ACTIVATIONS + PARTITION_ACTIVATIONS = partition_activation + if dist.get_rank() == 0: + logger.info( + f"**************Partition Activations {PARTITION_ACTIVATIONS}************") + + +def set_num_layers(nlayers): + global num_layers + num_layers = nlayers + + +def reset(): + """Resets memory buffers related to contiguous memory optimizations. + Should be called during eval when multiple forward propagations are + computed without any backward propagation that usually clears these + buffers. + Arguments: + None + + Return: + None + """ + if CONTIGUOUS_CHECKPOINTING: + global data_offsets, size_offsets + global contiguous_data_buffers, contiguous_size_buffers + + for buffers in contiguous_data_buffers: + buffers = [] + + # frees up all the pointers to the checkpoints except for the ones + # stored by save for backward + contiguous_data_buffers = [] + contiguous_size_buffers = [] + data_offsets = [] + size_offsets = [] + + +def _configure_using_config_file(config, mpu=None): + global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \ + CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME + + config = DeepSpeedConfig(config, mpu=mpu).activation_checkpointing_config + if dist.get_rank() == 0: + logger.info(config.repr()) + PARTITION_ACTIVATIONS = config.partition_activations + CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization + num_layers = config.number_checkpoints + CPU_CHECKPOINT = config.cpu_checkpointing + SYNCHRONIZE = config.synchronize_checkpoint_boundary + PROFILE_TIME = config.profile + + +def _configure_defaults(): + + global mpu, num_layers, deepspeed_checkpointing_enabled + + global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \ + CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME + + PARTITION_ACTIVATIONS = False + CONTIGUOUS_CHECKPOINTING = False + num_layers = False + CPU_CHECKPOINT = False + SYNCHRONIZE = False + PROFILE_TIME = False + deepspeed_checkpointing_enabled = True + + +def configure( + mpu_, + deepspeed_config=None, + partition_activations=None, + contiguous_checkpointing=None, + num_checkpoints=None, + checkpoint_in_cpu=None, + synchronize=None, + profile=None, +): + """Configure DeepSpeed Activation Checkpointing. + + Arguments: + mpu_: Optional: An object that implements the following methods + get_model_parallel_rank/group/world_size, and get_data_parallel_rank/group/world_size + + deepspeed_config: Optional: DeepSpeed Config json file when provided will be used to + configure DeepSpeed Activation Checkpointing + + partition_activations: Optional: Partitions activation checkpoint across model parallel + GPUs when enabled. By default False. Will overwrite deepspeed_config if provided + + contiguous_checkpointing: Optional: Copies activation checkpoints to a contiguous memory + buffer. Works only with homogeneous checkpoints when partition_activations is enabled. + Must provide num_checkpoints. By default False. Will overwrite deepspeed_config if + provided + + num_checkpoints: Optional: Number of activation checkpoints stored during the forward + propagation of the model. Used to calculate the buffer size for contiguous_checkpointing + Will overwrite deepspeed_config if provided + + checkpoint_in_cpu: Optional: Moves the activation checkpoint to CPU. Only works with + partition_activation. Default is false. Will overwrite deepspeed_config if provided + + synchronize: Optional: Performs torch.cuda.synchronize() at the beginning and end of + each call to deepspeed.checkpointing.checkpoint for both forward and backward pass. + By default false. Will overwrite deepspeed_config if provided + + profile: Optional: Logs the forward and backward time for each + deepspeed.checkpointing.checkpoint invocation. Will overwrite deepspeed_config + if provided + + Returns: + None + """ + global mpu, num_layers, deepspeed_checkpointing_enabled + + global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \ + CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME + + _configure_defaults() + + if mpu_ is not None: + mpu = mpu_ + + if deepspeed_config is not None: + _configure_using_config_file(deepspeed_config, mpu=mpu) + + if partition_activations is not None: + PARTITION_ACTIVATIONS = partition_activations + + if contiguous_checkpointing is not None: + CONTIGUOUS_CHECKPOINTING = contiguous_checkpointing + + if num_checkpoints is not None: + num_layers = num_checkpoints + + if checkpoint_in_cpu is not None: + CPU_CHECKPOINT = checkpoint_in_cpu + + if synchronize is not None: + SYNCHRONIZE = synchronize + + if profile is not None: + PROFILE_TIME = profile + + if CONTIGUOUS_CHECKPOINTING: + assert PARTITION_ACTIVATIONS, "Contiguous Checkpointing is only available with partitioned activations. Set partitioned activations to true in deepspeed config" + if CONTIGUOUS_CHECKPOINTING: + assert num_layers is not None, "Must specify the number of layers with contiguous memory checkpointing" + + +def is_configured(): + """True if deepspeed activation checkpointing has been configured + by calling deepspeed.checkpointing.configure, else returns false + + Arguments: + None + + Return: + True of configured, else False + """ + return deepspeed_checkpointing_enabled diff --git a/deepspeed/runtime/activation_checkpointing/config.py b/deepspeed/runtime/activation_checkpointing/config.py index 19e904980..0ab59ac64 100755 --- a/deepspeed/runtime/activation_checkpointing/config.py +++ b/deepspeed/runtime/activation_checkpointing/config.py @@ -1,103 +1,103 @@ -""" -Copyright (c) Microsoft Corporation -Licensed under the MIT license. -""" - -from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject - -######################################### -# DeepSpeed Activation Checkpointing -######################################### -# Activation Checkpointing Allows to save memory by only keeping a select few -#activations for the backpropagation. -ACTIVATION_CHKPT_FORMAT = ''' -Activation Checkpointing should be configured as: -"session_params": { - "activation_checkpointing": { - "partitioned_activations": [true|false], - "number_checkpoints": 100, - "contiguous_memory_optimization": [true|false], - "cpu_checkpointing": [true|false] - "profile": [true|false], - "synchronize_checkpoint_boundary": [true|false], - } -} -''' - -ACT_CHKPT_PARTITION_ACTIVATIONS = 'partition_activations' -ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT = False - -ACT_CHKPT_NUMBER_CHECKPOINTS = 'number_checkpoints' -ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT = None - -ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION = 'contiguous_memory_optimization' -ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT = False - -ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY = 'synchronize_checkpoint_boundary' -ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT = False - -ACT_CHKPT_PROFILE = 'profile' -ACT_CHKPT_PROFILE_DEFAULT = False - -ACT_CHKPT_CPU_CHECKPOINTING = 'cpu_checkpointing' -ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT = False - -ACT_CHKPT = 'activation_checkpointing' - -ACT_CHKPT_DEFAULT = { - ACT_CHKPT_PARTITION_ACTIVATIONS: ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT, - ACT_CHKPT_NUMBER_CHECKPOINTS: ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT, - ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION: - ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT, - ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY: - ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT, - ACT_CHKPT_PROFILE: ACT_CHKPT_PROFILE_DEFAULT, - ACT_CHKPT_CPU_CHECKPOINTING: ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT -} - - -class DeepSpeedActivationCheckpointingConfig(DeepSpeedConfigObject): - def __init__(self, param_dict): - super(DeepSpeedActivationCheckpointingConfig, self).__init__() - - self.partition_activations = None - self.contiguous_memory_optimization = None - self.cpu_checkpointing = None - self.number_checkpoints = None - self.synchronize_checkpoint_boundary = None - self.profile = None - - if ACT_CHKPT in param_dict.keys(): - act_chkpt_config_dict = param_dict[ACT_CHKPT] - else: - act_chkpt_config_dict = ACT_CHKPT_DEFAULT - - self._initialize(act_chkpt_config_dict) - - def _initialize(self, act_chkpt_config_dict): - self.partition_activations = get_scalar_param( - act_chkpt_config_dict, - ACT_CHKPT_PARTITION_ACTIVATIONS, - ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT) - - self.contiguous_memory_optimization = get_scalar_param( - act_chkpt_config_dict, - ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION, - ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT) - - self.cpu_checkpointing = get_scalar_param(act_chkpt_config_dict, - ACT_CHKPT_CPU_CHECKPOINTING, - ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT) - - self.number_checkpoints = get_scalar_param(act_chkpt_config_dict, - ACT_CHKPT_NUMBER_CHECKPOINTS, - ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT) - - self.profile = get_scalar_param(act_chkpt_config_dict, - ACT_CHKPT_PROFILE, - ACT_CHKPT_PROFILE_DEFAULT) - - self.synchronize_checkpoint_boundary = get_scalar_param( - act_chkpt_config_dict, - ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY, - ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT) +""" +Copyright (c) Microsoft Corporation +Licensed under the MIT license. +""" + +from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject + +######################################### +# DeepSpeed Activation Checkpointing +######################################### +# Activation Checkpointing Allows to save memory by only keeping a select few +#activations for the backpropagation. +ACTIVATION_CHKPT_FORMAT = ''' +Activation Checkpointing should be configured as: +"session_params": { + "activation_checkpointing": { + "partitioned_activations": [true|false], + "number_checkpoints": 100, + "contiguous_memory_optimization": [true|false], + "cpu_checkpointing": [true|false] + "profile": [true|false], + "synchronize_checkpoint_boundary": [true|false], + } +} +''' + +ACT_CHKPT_PARTITION_ACTIVATIONS = 'partition_activations' +ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT = False + +ACT_CHKPT_NUMBER_CHECKPOINTS = 'number_checkpoints' +ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT = None + +ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION = 'contiguous_memory_optimization' +ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT = False + +ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY = 'synchronize_checkpoint_boundary' +ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT = False + +ACT_CHKPT_PROFILE = 'profile' +ACT_CHKPT_PROFILE_DEFAULT = False + +ACT_CHKPT_CPU_CHECKPOINTING = 'cpu_checkpointing' +ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT = False + +ACT_CHKPT = 'activation_checkpointing' + +ACT_CHKPT_DEFAULT = { + ACT_CHKPT_PARTITION_ACTIVATIONS: ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT, + ACT_CHKPT_NUMBER_CHECKPOINTS: ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT, + ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION: + ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT, + ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY: + ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT, + ACT_CHKPT_PROFILE: ACT_CHKPT_PROFILE_DEFAULT, + ACT_CHKPT_CPU_CHECKPOINTING: ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT +} + + +class DeepSpeedActivationCheckpointingConfig(DeepSpeedConfigObject): + def __init__(self, param_dict): + super(DeepSpeedActivationCheckpointingConfig, self).__init__() + + self.partition_activations = None + self.contiguous_memory_optimization = None + self.cpu_checkpointing = None + self.number_checkpoints = None + self.synchronize_checkpoint_boundary = None + self.profile = None + + if ACT_CHKPT in param_dict.keys(): + act_chkpt_config_dict = param_dict[ACT_CHKPT] + else: + act_chkpt_config_dict = ACT_CHKPT_DEFAULT + + self._initialize(act_chkpt_config_dict) + + def _initialize(self, act_chkpt_config_dict): + self.partition_activations = get_scalar_param( + act_chkpt_config_dict, + ACT_CHKPT_PARTITION_ACTIVATIONS, + ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT) + + self.contiguous_memory_optimization = get_scalar_param( + act_chkpt_config_dict, + ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION, + ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT) + + self.cpu_checkpointing = get_scalar_param(act_chkpt_config_dict, + ACT_CHKPT_CPU_CHECKPOINTING, + ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT) + + self.number_checkpoints = get_scalar_param(act_chkpt_config_dict, + ACT_CHKPT_NUMBER_CHECKPOINTS, + ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT) + + self.profile = get_scalar_param(act_chkpt_config_dict, + ACT_CHKPT_PROFILE, + ACT_CHKPT_PROFILE_DEFAULT) + + self.synchronize_checkpoint_boundary = get_scalar_param( + act_chkpt_config_dict, + ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY, + ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT) diff --git a/deepspeed/runtime/config_utils.py b/deepspeed/runtime/config_utils.py index 1d499cdcb..199c773f4 100755 --- a/deepspeed/runtime/config_utils.py +++ b/deepspeed/runtime/config_utils.py @@ -1,80 +1,80 @@ -""" -Copyright (c) Microsoft Corporation -Licensed under the MIT license. -""" -""" -Collection of DeepSpeed configuration utilities -""" -import json -import collections - - -# adapted from https://stackoverflow.com/a/50701137/9201239 -class ScientificNotationEncoder(json.JSONEncoder): - """ - This class overrides ``json.dumps`` default formatter. - - This version keeps everything as normal except formats numbers bigger than 1e3 using scientific notation. - - Just pass ``cls=ScientificNotationEncoder`` to ``json.dumps`` to activate it - - """ - def iterencode(self, o, _one_shot=False, level=0): - indent = self.indent if self.indent is not None else 4 - prefix_close = " " * level * indent - level += 1 - prefix = " " * level * indent - if isinstance(o, bool): - return "true" if o else "false" - elif isinstance(o, float) or isinstance(o, int): - if o > 1e3: - return f"{o:e}" - else: - return f"{o}" - elif isinstance(o, collections.Mapping): - x = [ - f'\n{prefix}"{k}": {self.iterencode(v, level=level)}' for k, - v in o.items() - ] - return "{" + ', '.join(x) + f"\n{prefix_close}" + "}" - elif isinstance(o, collections.Sequence) and not isinstance(o, str): - return f"[{ f', '.join(map(self.iterencode, o)) }]" - return "\n, ".join(super().iterencode(o, _one_shot)) - - -class DeepSpeedConfigObject(object): - """ - For json serialization - """ - def repr(self): - return self.__dict__ - - def __repr__(self): - return json.dumps( - self.__dict__, - sort_keys=True, - indent=4, - cls=ScientificNotationEncoder, - ) - - -def get_scalar_param(param_dict, param_name, param_default_value): - return param_dict.get(param_name, param_default_value) - - -def get_list_param(param_dict, param_name, param_default_value): - return param_dict.get(param_name, param_default_value) - - -def get_dict_param(param_dict, param_name, param_default_value): - return param_dict.get(param_name, param_default_value) - - -def dict_raise_error_on_duplicate_keys(ordered_pairs): - """Reject duplicate keys.""" - d = dict((k, v) for k, v in ordered_pairs) - if len(d) != len(ordered_pairs): - counter = collections.Counter([pair[0] for pair in ordered_pairs]) - keys = [key for key, value in counter.items() if value > 1] - raise ValueError("Duplicate keys in DeepSpeed config: {}".format(keys)) - return d +""" +Copyright (c) Microsoft Corporation +Licensed under the MIT license. +""" +""" +Collection of DeepSpeed configuration utilities +""" +import json +import collections + + +# adapted from https://stackoverflow.com/a/50701137/9201239 +class ScientificNotationEncoder(json.JSONEncoder): + """ + This class overrides ``json.dumps`` default formatter. + + This version keeps everything as normal except formats numbers bigger than 1e3 using scientific notation. + + Just pass ``cls=ScientificNotationEncoder`` to ``json.dumps`` to activate it + + """ + def iterencode(self, o, _one_shot=False, level=0): + indent = self.indent if self.indent is not None else 4 + prefix_close = " " * level * indent + level += 1 + prefix = " " * level * indent + if isinstance(o, bool): + return "true" if o else "false" + elif isinstance(o, float) or isinstance(o, int): + if o > 1e3: + return f"{o:e}" + else: + return f"{o}" + elif isinstance(o, collections.Mapping): + x = [ + f'\n{prefix}"{k}": {self.iterencode(v, level=level)}' for k, + v in o.items() + ] + return "{" + ', '.join(x) + f"\n{prefix_close}" + "}" + elif isinstance(o, collections.Sequence) and not isinstance(o, str): + return f"[{ f', '.join(map(self.iterencode, o)) }]" + return "\n, ".join(super().iterencode(o, _one_shot)) + + +class DeepSpeedConfigObject(object): + """ + For json serialization + """ + def repr(self): + return self.__dict__ + + def __repr__(self): + return json.dumps( + self.__dict__, + sort_keys=True, + indent=4, + cls=ScientificNotationEncoder, + ) + + +def get_scalar_param(param_dict, param_name, param_default_value): + return param_dict.get(param_name, param_default_value) + + +def get_list_param(param_dict, param_name, param_default_value): + return param_dict.get(param_name, param_default_value) + + +def get_dict_param(param_dict, param_name, param_default_value): + return param_dict.get(param_name, param_default_value) + + +def dict_raise_error_on_duplicate_keys(ordered_pairs): + """Reject duplicate keys.""" + d = dict((k, v) for k, v in ordered_pairs) + if len(d) != len(ordered_pairs): + counter = collections.Counter([pair[0] for pair in ordered_pairs]) + keys = [key for key, value in counter.items() if value > 1] + raise ValueError("Duplicate keys in DeepSpeed config: {}".format(keys)) + return d diff --git a/deepspeed/runtime/eigenvalue.py b/deepspeed/runtime/eigenvalue.py index b1974d975..490899bda 100755 --- a/deepspeed/runtime/eigenvalue.py +++ b/deepspeed/runtime/eigenvalue.py @@ -1,152 +1,152 @@ -import torch -from deepspeed.utils import log_dist -import numpy as np -import logging - - -class Eigenvalue(object): - def __init__(self, - verbose=False, - max_iter=100, - tol=1e-2, - stability=0, - gas_boundary_resolution=1, - layer_name='', - layer_num=0): - super().__init__() - - self.verbose = verbose - self.max_iter = max_iter - self.tol = tol - self.stability = stability - self.gas_boundary_resolution = gas_boundary_resolution - self.layer_name = layer_name - self.layer_num = layer_num - - assert len(self.layer_name) > 0 and layer_num > 0 - - log_dist( - f'enabled eigenvalue with verbose={verbose}, max_iter={max_iter}, tol={tol}, stability={stability}, gas_boundary_resolution={gas_boundary_resolution}, layer_name={layer_name}, layer_num={layer_num}', - ranks=[0]) - - # Replace all nan/pos-inf/neg-inf to zero - # TODO: Pytorch new version may add this function, replace this one by then. - def nan_to_num(self, x): - device = x.device - x = x.cpu().numpy() - x = np.nan_to_num(x=x, copy=False, nan=0.0, posinf=0.0, neginf=0.0) - return torch.from_numpy(x).to(device) - - def normalize(self, v): - norm_squared = self.inner_product(v, v) - norm = norm_squared**0.5 + self.stability - normalized_vectors = [vector / norm for vector in v] - normalized_vectors = [self.nan_to_num(vector) for vector in normalized_vectors] - return normalized_vectors - - def inner_product(self, xs, ys): - return sum([torch.sum(x * y) for (x, y) in zip(xs, ys)]) - - def get_layers(self, module): - scope_names = self.layer_name.split('.') - assert len(scope_names) > 0 - - m = module - for name in scope_names: - assert hasattr(m, name), "layer_name configuration is invalid." - m = getattr(m, name) - - return m - - def compute_eigenvalue(self, module, device=None, scale=1.0): - block_eigenvalue = [] - param_keys = [] - layers = self.get_layers(module) - - for block in range(self.layer_num): - model_block = layers[block] - - # We found this randn() has obvious accuracy impact in some cases, save/recover random state here. - rng_state = torch.random.get_rng_state() - if device is None: - v = [ - torch.randn(p.size()) for p in model_block.parameters() - if p.grad is not None and p.grad.grad_fn is not None - ] - else: - v = [ - torch.randn(p.size(), - device=device) for p in model_block.parameters() - if p.grad is not None and p.grad.grad_fn is not None - ] - torch.random.set_rng_state(rng_state) - - grads = [ - param.grad for param in model_block.parameters() - if param.grad is not None and param.grad.grad_fn is not None - ] - params = [ - param for param in model_block.parameters() - if param.grad is not None and param.grad.grad_fn is not None - ] - - layer_keys = [id(p) for p in model_block.parameters()] - param_keys.append(layer_keys) - - v = self.normalize(v) - - # Disable eigenvalue if the model doesn't support second order gradients computation, - # e.g. when enabling DS transformer kernel. - if len(grads) == 0 or len(params) == 0: - log_dist(f'The model does NOT support eigenvalue computation.', - ranks=[0], - level=logging.WARNING) - return [] - - i = 0 - eigenvalue_current, eigenvalue_previous = 1., 0. - - while (i < self.max_iter) and abs(eigenvalue_current) > 0 and (abs( - (eigenvalue_current - eigenvalue_previous) / - eigenvalue_current) >= self.tol): # test convergence criteria - eigenvalue_previous = eigenvalue_current - - Hv = torch.autograd.grad(grads, - params, - grad_outputs=v, - only_inputs=True, - retain_graph=True) - #Hv = [hv.float() for hv in Hv] - Hv = [self.nan_to_num(hv).float() for hv in Hv] - - eigenvalue_current = self.inner_product(Hv, v).item() - - v = self.normalize(Hv) - v = [x / scale for x in v] - i += 1 - - eigenvalue_current *= scale - block_eigenvalue.append(eigenvalue_current) - - if self.verbose: - log_dist( - f'block: {block}, power iteration: {i}, eigenvalue: {eigenvalue_current}', - ranks=[0]) - - block_eigenvalue = self.post_process(block_eigenvalue) - - if self.verbose: - log_dist(f'post processed block_eigenvalue: {block_eigenvalue}', ranks=[0]) - - # {param_id: (eigenvalue, layer_id)} - ev_dict = {} - for i, (layer_keys, value) in enumerate(zip(param_keys, block_eigenvalue)): - ev_dict.update(dict.fromkeys(layer_keys, (value, i))) - - return ev_dict - - # 1. Map all eigenvalues to [0, 1.0]. - # 2. Some layers can't generate valid eigenvalues on fp16 precision, use 1.0 instead. - def post_process(self, value_list): - max_value = abs(max(value_list, key=abs)) - return [abs(v) / max_value if v != 0.0 else 1.0 for v in value_list] +import torch +from deepspeed.utils import log_dist +import numpy as np +import logging + + +class Eigenvalue(object): + def __init__(self, + verbose=False, + max_iter=100, + tol=1e-2, + stability=0, + gas_boundary_resolution=1, + layer_name='', + layer_num=0): + super().__init__() + + self.verbose = verbose + self.max_iter = max_iter + self.tol = tol + self.stability = stability + self.gas_boundary_resolution = gas_boundary_resolution + self.layer_name = layer_name + self.layer_num = layer_num + + assert len(self.layer_name) > 0 and layer_num > 0 + + log_dist( + f'enabled eigenvalue with verbose={verbose}, max_iter={max_iter}, tol={tol}, stability={stability}, gas_boundary_resolution={gas_boundary_resolution}, layer_name={layer_name}, layer_num={layer_num}', + ranks=[0]) + + # Replace all nan/pos-inf/neg-inf to zero + # TODO: Pytorch new version may add this function, replace this one by then. + def nan_to_num(self, x): + device = x.device + x = x.cpu().numpy() + x = np.nan_to_num(x=x, copy=False, nan=0.0, posinf=0.0, neginf=0.0) + return torch.from_numpy(x).to(device) + + def normalize(self, v): + norm_squared = self.inner_product(v, v) + norm = norm_squared**0.5 + self.stability + normalized_vectors = [vector / norm for vector in v] + normalized_vectors = [self.nan_to_num(vector) for vector in normalized_vectors] + return normalized_vectors + + def inner_product(self, xs, ys): + return sum([torch.sum(x * y) for (x, y) in zip(xs, ys)]) + + def get_layers(self, module): + scope_names = self.layer_name.split('.') + assert len(scope_names) > 0 + + m = module + for name in scope_names: + assert hasattr(m, name), "layer_name configuration is invalid." + m = getattr(m, name) + + return m + + def compute_eigenvalue(self, module, device=None, scale=1.0): + block_eigenvalue = [] + param_keys = [] + layers = self.get_layers(module) + + for block in range(self.layer_num): + model_block = layers[block] + + # We found this randn() has obvious accuracy impact in some cases, save/recover random state here. + rng_state = torch.random.get_rng_state() + if device is None: + v = [ + torch.randn(p.size()) for p in model_block.parameters() + if p.grad is not None and p.grad.grad_fn is not None + ] + else: + v = [ + torch.randn(p.size(), + device=device) for p in model_block.parameters() + if p.grad is not None and p.grad.grad_fn is not None + ] + torch.random.set_rng_state(rng_state) + + grads = [ + param.grad for param in model_block.parameters() + if param.grad is not None and param.grad.grad_fn is not None + ] + params = [ + param for param in model_block.parameters() + if param.grad is not None and param.grad.grad_fn is not None + ] + + layer_keys = [id(p) for p in model_block.parameters()] + param_keys.append(layer_keys) + + v = self.normalize(v) + + # Disable eigenvalue if the model doesn't support second order gradients computation, + # e.g. when enabling DS transformer kernel. + if len(grads) == 0 or len(params) == 0: + log_dist(f'The model does NOT support eigenvalue computation.', + ranks=[0], + level=logging.WARNING) + return [] + + i = 0 + eigenvalue_current, eigenvalue_previous = 1., 0. + + while (i < self.max_iter) and abs(eigenvalue_current) > 0 and (abs( + (eigenvalue_current - eigenvalue_previous) / + eigenvalue_current) >= self.tol): # test convergence criteria + eigenvalue_previous = eigenvalue_current + + Hv = torch.autograd.grad(grads, + params, + grad_outputs=v, + only_inputs=True, + retain_graph=True) + #Hv = [hv.float() for hv in Hv] + Hv = [self.nan_to_num(hv).float() for hv in Hv] + + eigenvalue_current = self.inner_product(Hv, v).item() + + v = self.normalize(Hv) + v = [x / scale for x in v] + i += 1 + + eigenvalue_current *= scale + block_eigenvalue.append(eigenvalue_current) + + if self.verbose: + log_dist( + f'block: {block}, power iteration: {i}, eigenvalue: {eigenvalue_current}', + ranks=[0]) + + block_eigenvalue = self.post_process(block_eigenvalue) + + if self.verbose: + log_dist(f'post processed block_eigenvalue: {block_eigenvalue}', ranks=[0]) + + # {param_id: (eigenvalue, layer_id)} + ev_dict = {} + for i, (layer_keys, value) in enumerate(zip(param_keys, block_eigenvalue)): + ev_dict.update(dict.fromkeys(layer_keys, (value, i))) + + return ev_dict + + # 1. Map all eigenvalues to [0, 1.0]. + # 2. Some layers can't generate valid eigenvalues on fp16 precision, use 1.0 instead. + def post_process(self, value_list): + max_value = abs(max(value_list, key=abs)) + return [abs(v) / max_value if v != 0.0 else 1.0 for v in value_list] diff --git a/deepspeed/runtime/progressive_layer_drop.py b/deepspeed/runtime/progressive_layer_drop.py index 770978a94..41c08cfd9 100755 --- a/deepspeed/runtime/progressive_layer_drop.py +++ b/deepspeed/runtime/progressive_layer_drop.py @@ -1,33 +1,33 @@ -import numpy as np -from deepspeed.utils import log_dist - - -class ProgressiveLayerDrop(object): - r""" Progressive Layer Dropping (PLD) for model training. - This implements the PLD technique for compressed model training - from this paper: https://arxiv.org/pdf/2010.13369.pdf - Args: - theta (float): a hyper-parameter that controls the trade-off between training time and robustness. - The lower the theta value, the faster the training speed. Default value: 0.5. - gamma (float): a hyper-parameter that controls how fast the drop ratio increases. Default value: 0.001. - """ - def __init__(self, theta=0.5, gamma=0.001): - super().__init__() - - self.theta = theta - self.gamma = gamma - self.current_theta = 1.0 - log_dist(f'Enabled progressive layer dropping (theta = {self.theta})', ranks=[0]) - - def get_state(self): - kwargs = {'progressive_layer_drop': True, 'pld_theta': self.get_theta()} - return kwargs - - def get_theta(self): - return self.current_theta - - def update_state(self, global_step): - def _prob(x, gamma, p): - return (1. - p) * np.exp(-gamma * x) + p - - self.current_theta = _prob(global_step, self.gamma, self.theta) +import numpy as np +from deepspeed.utils import log_dist + + +class ProgressiveLayerDrop(object): + r""" Progressive Layer Dropping (PLD) for model training. + This implements the PLD technique for compressed model training + from this paper: https://arxiv.org/pdf/2010.13369.pdf + Args: + theta (float): a hyper-parameter that controls the trade-off between training time and robustness. + The lower the theta value, the faster the training speed. Default value: 0.5. + gamma (float): a hyper-parameter that controls how fast the drop ratio increases. Default value: 0.001. + """ + def __init__(self, theta=0.5, gamma=0.001): + super().__init__() + + self.theta = theta + self.gamma = gamma + self.current_theta = 1.0 + log_dist(f'Enabled progressive layer dropping (theta = {self.theta})', ranks=[0]) + + def get_state(self): + kwargs = {'progressive_layer_drop': True, 'pld_theta': self.get_theta()} + return kwargs + + def get_theta(self): + return self.current_theta + + def update_state(self, global_step): + def _prob(x, gamma, p): + return (1. - p) * np.exp(-gamma * x) + p + + self.current_theta = _prob(global_step, self.gamma, self.theta) diff --git a/deepspeed/runtime/quantize.py b/deepspeed/runtime/quantize.py index a23d189aa..05fc50201 100755 --- a/deepspeed/runtime/quantize.py +++ b/deepspeed/runtime/quantize.py @@ -1,224 +1,224 @@ -import torch -import math -from deepspeed.utils import log_dist -from deepspeed.utils import logger -from deepspeed.ops.quantizer import ds_quantizer - -# number of 2-dimensional parameters in a layer -# this is set for transformer-based models -TWO_D_PARAMS = 6 - - -class Quantizer(object): - def __init__(self, - q_target_bits=8, - q_start_bits=16, - q_period=100, - q_offset=100, - q_groups=1, - q_mixed_fp16=False, - q_change_ratio=0.01, - q_type=0, - q_rounding=0, - q_verbose=False, - q_eigenvalue=False, - use_quantizer_kernel=False, - layer_num=0): - - self.q_target_bits = q_target_bits - - self.q_start_bits = [q_start_bits] * (layer_num if layer_num != 0 else 1) - self.q_period = [q_period] * (layer_num if layer_num != 0 else 1) - self.q_offset = q_offset - self.q_groups = q_groups - self.q_mixed_fp16 = q_mixed_fp16 - self.q_change_ratio = q_change_ratio - self.q_type = q_type - self.qsteps = 0 - self.q_init_period = q_period - self.quantize_real_ratio = 1.000 - self.q_verbose = q_verbose - self.q_eigenvalue = q_eigenvalue - self.use_quantizer_kernel = use_quantizer_kernel - self.q_rounding = q_rounding - self.layer_num = layer_num - - def any_precision_switch(self): - if self.layer_num == 0: - return True - result = False - for index in range(self.layer_num): - if self.q_start_bits[index] != self.q_target_bits: - next_step = self.qsteps + ( - TWO_D_PARAMS * (self.layer_num if self.layer_num != 0 else 1)) - if next_step >= self.q_period[index]: - result = True - return result - - def quantize(self, - parameter_group, - overflow, - eigenvalue_enabled, - block_eigenvalue={}): - - if overflow and not eigenvalue_enabled: - return - - self.step() - - self.update_fp16_ratio() - - for i in range(len(parameter_group)): - for p in parameter_group[i]: - if len(p.size()) > 1: - param_id = id(p) - eigenvalue, layer_id = block_eigenvalue[param_id] if param_id in block_eigenvalue else (None, 0) - if eigenvalue is not None: - factor = 1 + math.floor(eigenvalue * 4) - p.data = self.compute_quantization(p.data, layer_id, factor) - else: - p.data = self.compute_quantization(p.data, layer_id) - - def step(self): - self.qsteps += (TWO_D_PARAMS * (self.layer_num if self.layer_num != 0 else 1)) - - def sr_quantize(self, input_flat, input_g, scale): - # Random number generator (Uniform) - p = torch.cuda.FloatTensor(input_flat.size(), - device=input_flat.device).uniform_() - p = torch.split(p, p.size(0) // self.q_groups) - add_s = torch.zeros_like(input_flat) - add_s = torch.split(add_s, add_s.size(0) // self.q_groups) - - scale = [q_range / (2 * max(g.max(), g.min().abs())) for g in input_g] - # Quantize with INT rounding - input_flat = [(g * s).int().float() / s for (g, s) in zip(input_g, scale)] - # Compute the error - error = [((g - q).abs() / s) for (g, s, q) in zip(input_g, scale, input_flat)] - # Stochastic Rounding - add_s = [ - a_s.masked_fill_(pg < err_g, - 1 / s) for (a_s, - pg, - err_g, - s) in zip(add_s, - p, - error, - scale) - ] - add_s = [ - a_s * (g > 0).float() - a_s * (g < 0).float() for a_s, - g in zip(add_s, - input_flat) - ] - input_flat = [((q + a_s) * s).clamp(-(q_range >> 1), - (q_range >> 1) - 1) / s for q, - a_s, - s in zip(input_flat, - add_s, - scale)] - return input_flat - - def mixed_fp16_quantize(self, input, input_q, index): - if self.q_mixed_fp16 and self.q_start_bits[index] >= (self.q_target_bits - 1): - input_q = input * self.quantize_real_ratio + ( - 1 - self.quantize_real_ratio) * input_q - return input_q - return input_q - - def compute_quantization(self, input, index=0, factor=1): - # fixing the quantization bits based on the training steps - # when reducing 1 bit at each period, we increase the period - # to go slowly toward the target quantization bits - # the period and starting bit can be configured - if self.q_offset > 0: - if self.qsteps >= self.q_offset: - self.q_offset = 0 - self.qsteps = 0 - else: - return input - - if self.q_start_bits[index] != self.q_target_bits: - if self.qsteps >= self.q_period[index]: - self.quantize_real_ratio = 1.0 - if self.q_eigenvalue: - self.q_period[index] <<= 1 - self.q_period[index] *= factor - self.q_start_bits[index] -= 1 - else: - for i in range(len(self.q_start_bits)): - self.q_start_bits[i] -= 1 - self.q_period[i] <<= 1 - if self.q_verbose: - logger.info( - f'Quantization settings: current bit-precision = {self.q_start_bits[index]}, step = {self.qsteps}, quantization period = {self.q_period[index]}, index = {index}' - ) - assert (self.q_start_bits[index] >= self.q_target_bits), \ - 'Quantization bit is lower than target precision bits!' - - # quantize the weights base on the selected bits and the value-range - if not self.use_quantizer_kernel: - q_range = 2**self.q_start_bits[index] - input_flat = input.view(-1) - input_g = torch.split(input_flat, input_flat.size(0) // self.q_groups) - if self.q_type == 0: #symmetric - if self.use_quantizer_kernel: - input_q = ds_quantizer(input.clone(), - self.q_groups, - self.q_start_bits[index]) - else: - scale = [q_range / (2 * max(g.max(), g.min().abs())) for g in input_g] - if self.q_rounding == 0: # Nearest value rounding - input_flat = [(g * s).round().clamp(-(q_range >> 1), - (q_range >> 1) - 1) / s for g, - s in zip(input_g, - scale)] - else: # Stochastic Rounding - if self.use_quantizer_kernel: - input_q = ds_quantizer(input.clone(), - self.q_groups, - self.q_start_bits[index], - sr=True) - else: - input_flat = self.sr_quantize(input_flat, input_g) - else: #asymmetric - if self.q_rounding == 0: - if self.use_quantizer_kernel: - input_q = ds_quantizer(input.clone(), - self.q_groups, - self.q_start_bits[index], - asym=True) - else: - scale = [(g.max() - g.min()) / q_range for g in input_g] - input_flat = [ - ((g - g.min()) / s).round().clamp(0, - (q_range - 1)) * s + g.min() - for g, - s in zip(input_g, - scale) - ] - else: - input_q = ds_quantizer(input.clone(), - self.q_groups, - self.q_start_bits[index], - asym=True) - - if self.use_quantizer_kernel or (self.q_type and self.q_rounding): - return self.mixed_fp16_quantize(input, input_q, index) - else: - if self.q_mixed_fp16 and self.q_start_bits[index] >= (self.q_target_bits - - 1): - input_flat = [(self.quantize_real_ratio * g) + - ((1 - self.quantize_real_ratio) * g_q) for g, - g_q in zip(input_g, - input_flat)] - input_q = torch.cat(input_flat) - input_q = input_q.reshape(input.size()) - return input_q - - def update_fp16_ratio(self): - if self.q_mixed_fp16: - if self.quantize_real_ratio > 0: - self.quantize_real_ratio -= self.q_change_ratio - else: - self.quantize_real_ratio = 0.000 +import torch +import math +from deepspeed.utils import log_dist +from deepspeed.utils import logger +from deepspeed.ops.quantizer import ds_quantizer + +# number of 2-dimensional parameters in a layer +# this is set for transformer-based models +TWO_D_PARAMS = 6 + + +class Quantizer(object): + def __init__(self, + q_target_bits=8, + q_start_bits=16, + q_period=100, + q_offset=100, + q_groups=1, + q_mixed_fp16=False, + q_change_ratio=0.01, + q_type=0, + q_rounding=0, + q_verbose=False, + q_eigenvalue=False, + use_quantizer_kernel=False, + layer_num=0): + + self.q_target_bits = q_target_bits + + self.q_start_bits = [q_start_bits] * (layer_num if layer_num != 0 else 1) + self.q_period = [q_period] * (layer_num if layer_num != 0 else 1) + self.q_offset = q_offset + self.q_groups = q_groups + self.q_mixed_fp16 = q_mixed_fp16 + self.q_change_ratio = q_change_ratio + self.q_type = q_type + self.qsteps = 0 + self.q_init_period = q_period + self.quantize_real_ratio = 1.000 + self.q_verbose = q_verbose + self.q_eigenvalue = q_eigenvalue + self.use_quantizer_kernel = use_quantizer_kernel + self.q_rounding = q_rounding + self.layer_num = layer_num + + def any_precision_switch(self): + if self.layer_num == 0: + return True + result = False + for index in range(self.layer_num): + if self.q_start_bits[index] != self.q_target_bits: + next_step = self.qsteps + ( + TWO_D_PARAMS * (self.layer_num if self.layer_num != 0 else 1)) + if next_step >= self.q_period[index]: + result = True + return result + + def quantize(self, + parameter_group, + overflow, + eigenvalue_enabled, + block_eigenvalue={}): + + if overflow and not eigenvalue_enabled: + return + + self.step() + + self.update_fp16_ratio() + + for i in range(len(parameter_group)): + for p in parameter_group[i]: + if len(p.size()) > 1: + param_id = id(p) + eigenvalue, layer_id = block_eigenvalue[param_id] if param_id in block_eigenvalue else (None, 0) + if eigenvalue is not None: + factor = 1 + math.floor(eigenvalue * 4) + p.data = self.compute_quantization(p.data, layer_id, factor) + else: + p.data = self.compute_quantization(p.data, layer_id) + + def step(self): + self.qsteps += (TWO_D_PARAMS * (self.layer_num if self.layer_num != 0 else 1)) + + def sr_quantize(self, input_flat, input_g, scale): + # Random number generator (Uniform) + p = torch.cuda.FloatTensor(input_flat.size(), + device=input_flat.device).uniform_() + p = torch.split(p, p.size(0) // self.q_groups) + add_s = torch.zeros_like(input_flat) + add_s = torch.split(add_s, add_s.size(0) // self.q_groups) + + scale = [q_range / (2 * max(g.max(), g.min().abs())) for g in input_g] + # Quantize with INT rounding + input_flat = [(g * s).int().float() / s for (g, s) in zip(input_g, scale)] + # Compute the error + error = [((g - q).abs() / s) for (g, s, q) in zip(input_g, scale, input_flat)] + # Stochastic Rounding + add_s = [ + a_s.masked_fill_(pg < err_g, + 1 / s) for (a_s, + pg, + err_g, + s) in zip(add_s, + p, + error, + scale) + ] + add_s = [ + a_s * (g > 0).float() - a_s * (g < 0).float() for a_s, + g in zip(add_s, + input_flat) + ] + input_flat = [((q + a_s) * s).clamp(-(q_range >> 1), + (q_range >> 1) - 1) / s for q, + a_s, + s in zip(input_flat, + add_s, + scale)] + return input_flat + + def mixed_fp16_quantize(self, input, input_q, index): + if self.q_mixed_fp16 and self.q_start_bits[index] >= (self.q_target_bits - 1): + input_q = input * self.quantize_real_ratio + ( + 1 - self.quantize_real_ratio) * input_q + return input_q + return input_q + + def compute_quantization(self, input, index=0, factor=1): + # fixing the quantization bits based on the training steps + # when reducing 1 bit at each period, we increase the period + # to go slowly toward the target quantization bits + # the period and starting bit can be configured + if self.q_offset > 0: + if self.qsteps >= self.q_offset: + self.q_offset = 0 + self.qsteps = 0 + else: + return input + + if self.q_start_bits[index] != self.q_target_bits: + if self.qsteps >= self.q_period[index]: + self.quantize_real_ratio = 1.0 + if self.q_eigenvalue: + self.q_period[index] <<= 1 + self.q_period[index] *= factor + self.q_start_bits[index] -= 1 + else: + for i in range(len(self.q_start_bits)): + self.q_start_bits[i] -= 1 + self.q_period[i] <<= 1 + if self.q_verbose: + logger.info( + f'Quantization settings: current bit-precision = {self.q_start_bits[index]}, step = {self.qsteps}, quantization period = {self.q_period[index]}, index = {index}' + ) + assert (self.q_start_bits[index] >= self.q_target_bits), \ + 'Quantization bit is lower than target precision bits!' + + # quantize the weights base on the selected bits and the value-range + if not self.use_quantizer_kernel: + q_range = 2**self.q_start_bits[index] + input_flat = input.view(-1) + input_g = torch.split(input_flat, input_flat.size(0) // self.q_groups) + if self.q_type == 0: #symmetric + if self.use_quantizer_kernel: + input_q = ds_quantizer(input.clone(), + self.q_groups, + self.q_start_bits[index]) + else: + scale = [q_range / (2 * max(g.max(), g.min().abs())) for g in input_g] + if self.q_rounding == 0: # Nearest value rounding + input_flat = [(g * s).round().clamp(-(q_range >> 1), + (q_range >> 1) - 1) / s for g, + s in zip(input_g, + scale)] + else: # Stochastic Rounding + if self.use_quantizer_kernel: + input_q = ds_quantizer(input.clone(), + self.q_groups, + self.q_start_bits[index], + sr=True) + else: + input_flat = self.sr_quantize(input_flat, input_g) + else: #asymmetric + if self.q_rounding == 0: + if self.use_quantizer_kernel: + input_q = ds_quantizer(input.clone(), + self.q_groups, + self.q_start_bits[index], + asym=True) + else: + scale = [(g.max() - g.min()) / q_range for g in input_g] + input_flat = [ + ((g - g.min()) / s).round().clamp(0, + (q_range - 1)) * s + g.min() + for g, + s in zip(input_g, + scale) + ] + else: + input_q = ds_quantizer(input.clone(), + self.q_groups, + self.q_start_bits[index], + asym=True) + + if self.use_quantizer_kernel or (self.q_type and self.q_rounding): + return self.mixed_fp16_quantize(input, input_q, index) + else: + if self.q_mixed_fp16 and self.q_start_bits[index] >= (self.q_target_bits - + 1): + input_flat = [(self.quantize_real_ratio * g) + + ((1 - self.quantize_real_ratio) * g_q) for g, + g_q in zip(input_g, + input_flat)] + input_q = torch.cat(input_flat) + input_q = input_q.reshape(input.size()) + return input_q + + def update_fp16_ratio(self): + if self.q_mixed_fp16: + if self.quantize_real_ratio > 0: + self.quantize_real_ratio -= self.q_change_ratio + else: + self.quantize_real_ratio = 0.000 diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 2698f798a..ae2b4e14f 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1,3478 +1,3478 @@ -""" -"Copyright 2020 The Microsoft DeepSpeed Team. -Licensed under the MIT license. -""" - -import sys -import os -from collections import defaultdict, OrderedDict -import itertools -import torch -from torch.distributed.distributed_c10d import _get_global_rank -import torch.distributed as dist -import math -from torch._six import inf -from torch.autograd import Variable - -from deepspeed.utils.logging import logger -from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler -from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter, DummyOptim -from deepspeed.runtime.zero.partition_parameters import * -from deepspeed.runtime.zero.partition_parameters import _init_external_params -from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS -from deepspeed.ops.adam import DeepSpeedCPUAdam -from deepspeed.ops.op_builder import UtilsBuilder -from deepspeed.runtime.zero.offload_constants import * -from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus -from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper -from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper - -# Toggle this to true to enable correctness test -# with gradient partitioning and without -pg_correctness_test = False - -FWD_MODULE_STACK = list() -from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id, debug_param2name_id_numel, debug_param2name_id_shape_device, debug_module2name_class, printflock, log_rank_file - - -def print_rank_0(message, debug=False, force=False): - rank = torch.distributed.get_rank() - if rank == 0 and (debug or force): - print(message) - # other variations - # - print for all ranks w/o interleaving - # printflock(f"[{rank}] {message}") - # - print to log file per rank - # log_rank_file(rank, message) - - -def input(msg): - return - - -def split_half_float_double(tensors): - dtypes = [ - "torch.cuda.HalfTensor", - "torch.cuda.FloatTensor", - "torch.cuda.DoubleTensor" - ] - buckets = [] - for i, dtype in enumerate(dtypes): - bucket = [t for t in tensors if t.type() == dtype] - if bucket: - buckets.append(bucket) - return buckets - - -def isclose(a, b, rtol=1e-09, atol=0.0): - return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol) - - -def lcm(x, y): - from fractions import gcd # or can import gcd from `math` in Python 3 - return x * y // gcd(x, y) - - -def move_to_cpu(tensor_list): - for tensor in tensor_list: - tensor.data = tensor.data.cpu() - - -def get_all_parameters(sub_module, recurse=False): - return itertools.chain(sub_module.named_parameters(recurse=recurse), - sub_module.ds_external_parameters()) - - -#apply torch.autograd.Function that calls a backward_function to tensors in output -def _apply_to_tensors_only(module, functional, backward_function, outputs): - if type(outputs) is tuple: - touched_outputs = [] - for output in outputs: - touched_output = _apply_to_tensors_only(module, - functional, - backward_function, - output) - touched_outputs.append(touched_output) - return tuple(touched_outputs) - elif type(outputs) is torch.Tensor: - return functional.apply(module, backward_function, outputs) - else: - return outputs - - -#for each tensor in outputs run the forward_function and register backward_function as hook -def _apply_forward_and_backward_to_tensors_only(module, - forward_function, - backward_function, - outputs): - if type(outputs) is tuple: - touched_outputs = [] - for output in outputs: - touched_output = _apply_forward_and_backward_to_tensors_only( - module, - forward_function, - backward_function, - output) - touched_outputs.append(touched_output) - return tuple(touched_outputs) - elif type(outputs) is torch.Tensor: - forward_function(outputs) - if outputs.requires_grad: - outputs.register_hook(backward_function) - return outputs - else: - return outputs - - -class ZeROOrderedDict(OrderedDict): - def __init__(self, parent_module, *args, **kwargs): - """A replacement for ``collections.OrderedDict`` to detect external ZeRO params. - - Args: - parent_module (``collections.OrderedDict``): the collection to replace - """ - - super().__init__(*args, **kwargs) - self._parent_module = parent_module - self._in_forward = False - - def __getitem__(self, key): - param = super().__getitem__(key) - - # Params can be registered as None (e.g., bias) - if param is None: - return param - - if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: - if self._parent_module._parameters._in_forward: - print_rank_0(f'Registering external parameter from getter {key}', - force=False) - register_external_parameter(FWD_MODULE_STACK[-1], param) - param.all_gather() - - return param - - -def _inject_parameters(module, cls): - for module in module.modules(): - if cls == ZeROOrderedDict: - new_param = cls(parent_module=module) - else: - new_param = cls() - - for key, param in module._parameters.items(): - new_param[key] = param - module._parameters = new_param - - -# TODO Needs to be implemented -class PrefetchCoordinator(object): - def __init__(self): - # step_id keeps track of the number of sub-modules invoked so far - # the step_id is tracking forward and backward sequence of sub-modules - self.step_id = 0 - - # stores the sequence of sub modules in forward+backward pass - self.sub_module_trace = [] - - # maps sub_module id to submodule objects - self.id_to_sub_module_map = {} - - # stores the total number of parameters in each sub_module - self.id_to_sub_module_size_map = {} - - self.trace_completed = False - - self.most_recent_sub_module_step = {} - - # reuse distances - self.reuse_numel_for_step_id = {} - - def record_trace(self, sub_module): - if not self.trace_completed: - self.sub_module_trace.append(sub_module.id) - self.id_to_sub_module_map[sub_module.id] = sub_module - - def print_trace(self): - print_rank_0( - f"The module trace is : {[self.id_to_sub_module_map[module_id].id for module_id in self.sub_module_trace]}" - ) - - def increment_step(self, sub_module): - self.most_recent_sub_module_step[sub_module.id] = self.step_id - self.step_id += 1 - - def reset_step(self): - self.step_id = 0 - - # returns the next numel parameters that will be used next but are not available or inflight - def get_params_to_prefetch(self, sub_module, numel=2000000): - - # numel_in_sub_module = 0 - # for name, param in sub_module.named_parameters(recurse=False): - # numel_in_sub_module += param.ds_numel - - # #if numel_in_sub_module < (numel // 2): - # return [] - - # tracing failed. The sub_module passed at the step_id must match with the sub_module during tracing - if sub_module.id != self.sub_module_trace[self.step_id]: - print_rank_0( - f"Tracing failed. Prefetching is disabled at sub-module: {debug_module2name_id(sub_module)}" - ) - return [] - - params_to_prefetch = [] - total_numel_to_prefetch = 0 - - for i in range(self.step_id, len(self.sub_module_trace)): - module_id = self.sub_module_trace[i] - for _, param in get_all_parameters(self.id_to_sub_module_map[module_id]): - if param.ds_status is ZeroParamStatus.NOT_AVAILABLE and ( - param.ds_id not in [p.ds_id for p in params_to_prefetch]): - params_to_prefetch.append(param) - total_numel_to_prefetch += param.ds_numel - #print_rank_0(f"Total numel to prefetch: {total_numel_to_prefetch}. Param: {param.ds_shape} and numel {param.ds_numel}, numel limit {numel}") - if total_numel_to_prefetch >= numel: # and total_numel_to_prefetch > (numel_in_sub_module // 2): - return params_to_prefetch - - return params_to_prefetch - - # checks if this sub_module will be used again and if so then returns the number of elements - # in the parameters used between this sub_module and the reuse of this sub_module - def get_reuse_distance_in_numel(self, sub_module, sub_module_step_id=None): - #assert is_forward is not None, "is_forward must be set to True for Forward Propagation and False for backward Propagation" - is_there_reuse = False - reuse_distance_in_numel = 1000000000000 - - # set the appropriate trace - trace = self.sub_module_trace - total_steps = len(trace) - if sub_module_step_id is None: - sub_module_step_id = self.most_recent_sub_module_step[sub_module.id] - - # tracing failed. The sub_module passed at the step_id must match with the sub_module during tracing - if sub_module.id != trace[sub_module_step_id]: - print_rank_0( - f"Tracing failed. Cannot tell if the sub_module: {sub_module.id} is reused" - ) - return reuse_distance_in_numel - - # return cached value - if sub_module_step_id in self.reuse_numel_for_step_id: - return self.reuse_numel_for_step_id[sub_module_step_id] - - start_step = self.step_id - print_rank_0(f"Step id is {self.step_id} ") - for step_id in range(start_step, total_steps): - print_rank_0(f"Trace id {trace[step_id]} and sub_module id {sub_module.id}") - if sub_module.id == trace[step_id]: - end_step = step_id - - is_there_reuse = True - reuse_distance_in_numel = self._distance_in_numel( - start_step, - end_step, - trace) - break - - self.reuse_numel_for_step_id[sub_module_step_id] = reuse_distance_in_numel - - return reuse_distance_in_numel - - def _distance_in_numel(self, start_step, end_step, trace): - distance_in_numel = 0 - for step_id in range(start_step, end_step): - module_id = trace[step_id] - for _, param in self.id_to_sub_module_map[module_id].named_parameters(recurse=False): - distance_in_numel += param.ds_numel - for _, param in self.id_to_sub_module_map[module_id].ds_external_parameters(): - distance_in_numel += param.ds_numel - return distance_in_numel - - -class PartitionedParameterCoordinator(object): - def __init__(self, - comm_stream=None, - max_reuse_distance_in_numel=500000000, - max_available_parameters_in_numel=700000000): - - self.in_flight_handles = [] - self.params_in_flight = [] - self.comm_stream = comm_stream if comm_stream is not None else torch.cuda.current_stream( - ) - self.prefetch_coordinator = PrefetchCoordinator() - self.hierarchy = 0 - - self.total_available_parameter_numel = 0 - self.max_available_parameters_in_numel = max_available_parameters_in_numel - - # max distance between two use of the module beyond which module is released - self.max_reuse_distance_in_numel = max_reuse_distance_in_numel - - def _increment_available_parameter_numel(self, increment): - self.total_available_parameter_numel += increment - - def _decrement_available_parameter_numel(self, decrement): - self.total_available_parameter_numel -= decrement - - '''-----------------------Tracing and Prefetching ---------------''' - - def record_trace(self, sub_module): - self.prefetch_coordinator.record_trace(sub_module) - - def finish_tracing(self, print_trace=False): - self.prefetch_coordinator.trace_completed = True - - if print_trace: - self.prefetch_coordinator.print_trace() - - #swap in parameter partitions from nvme for those parameters that will be used - # after the ones that are already being prefetched into full parameters - def _prefetch_nvme_param_partitions(self, sub_module, params_in_flight): - numel_in_flight = sum([param.ds_tensor.ds_numel for param in params_in_flight]) - upcoming_param_list = self.prefetch_coordinator.get_params_to_prefetch( - sub_module, - numel=2 * numel_in_flight) - swap_in_params = [] - for param in upcoming_param_list: - if len(swap_in_params) >= param.nvme_swapper.available_swap_in_buffers(): - break - if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE: - swap_in_params.append(param) - - if len(swap_in_params) > 0: - swap_in_params[0].nvme_swapper.swap_in(swap_in_params, async_op=True) - - # Pre fetches the parameters for sub_modules that comes after - # the current sub_module. This call is asynchronous - def prefetch_next_sub_modules(self, sub_module, numel=5000000, nvme=False): - - params_to_prefetch = [] - if not self.prefetch_coordinator.trace_completed: - return params_to_prefetch - - # prefetch if there is no current prefetching in flight - if not self.in_flight_handles and self.total_available_parameter_numel < self.max_available_parameters_in_numel: - params_to_prefetch = self.prefetch_coordinator.get_params_to_prefetch( - sub_module, - numel=numel) - - self._all_gather(params_to_prefetch, async_op=True) - for param in params_to_prefetch: - param.ds_status = ZeroParamStatus.INFLIGHT - - # keeping track of number of elements consumed by available parameters - self._increment_available_parameter_numel(param.ds_numel) - - if nvme: - self._prefetch_nvme_param_partitions(sub_module, params_to_prefetch) - - self._print_prefetch_elements_info(sub_module, params_to_prefetch) - print_rank_0( - f"{'--' * self.hierarchy}--PreFetching parameters {[param.ds_id for param in params_to_prefetch]} and available {self.total_available_parameter_numel}, max limit {self.max_available_parameters_in_numel}", - force=False) - - def _print_prefetch_elements_info(self, sub_module, params_to_prefetch): - sub_module_numel = 0.0 - for name, param in sub_module.named_parameters(recurse=False): - sub_module_numel += param.ds_numel - numel_being_prefetched = 0 - for param in params_to_prefetch: - numel_being_prefetched = param.ds_numel - print_rank_0( - f"{'--' * self.hierarchy}--PreFetching {numel_being_prefetched} numels and number of numel in the next sub module is {sub_module_numel}", - force=False) - - def increment_step(self, sub_module): - self.prefetch_coordinator.increment_step(sub_module) - - def reset_step(self): - self.prefetch_coordinator.reset_step() - - '''----------------------------------------------------------------------''' - - # Fetches the parameters in the sub_module - # This call is blocking - def fetch_sub_module(self, sub_module): - partitioned_params = [] - params_in_flight = False - print_rank_0( - f"{'--' * self.hierarchy}Fetching params in module {debug_module2name_class(sub_module)}" - ) - params_to_fetch = [ - param for _, - param in sub_module.named_parameters(recurse=False) - ] - # print([n for n,p in sub_module.named_parameters(recurse=False)]) - - if hasattr(sub_module, 'ds_external_parameters'): - print_rank_0( - f"{'--' * self.hierarchy}--Fetching external parameters {sub_module.ds_external_parameters()}" - ) - params_to_fetch += [ - param for _, - param in sub_module.ds_external_parameters() - ] - # for _, param in sub_module.named_parameters(recurse=False): - for param in params_to_fetch: - param.ds_active_sub_modules += 1 - print_rank_0( - f"{'--' * self.hierarchy}--Fetching parameters {debug_param2name_id_shape(param)} with active sub modules {param.ds_active_sub_modules}" - ) - - if param.ds_status == ZeroParamStatus.AVAILABLE: - print_rank_0( - f"{'--' * self.hierarchy}--Parameter {debug_param2name_id(param)} is already available" - ) - - if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: - print_rank_0( - f"{'--' * self.hierarchy}--Parameter {debug_param2name_id(param)} is being fetched" - ) - partitioned_params.append(param) - - # keeping track of number of elements consumed by available parameters - self._increment_available_parameter_numel(param.ds_numel) - print_rank_0(f"Incrementing with parameter id {param.ds_id}") - - if param.ds_status == ZeroParamStatus.INFLIGHT: - params_in_flight = True - print_rank_0( - f"{'--' * self.hierarchy}--Parameters {debug_param2name_id(param)} is already in flight (prefetched)" - ) - self.hierarchy += 1 - - # parameters are partitioned and need to be allgathered - self._all_gather(partitioned_params, async_op=False) - - # parameters are inflight and communication needs to be completed - if partitioned_params or params_in_flight: - self._synchronize_communication() - - for _, param in sub_module.named_parameters(recurse=False): - param.ds_status = ZeroParamStatus.AVAILABLE - print_rank_0( - f"Param {debug_param2name_id_shape_device(param)} norm={param.norm()}", - force=False) - #print_rank_0(f"After fetching (id, shape, device): {[(param.ds_id, param.shape, param.device) for param in sub_module.named_parameters(recurse=False)]}") - - def release_sub_module(self, sub_module): - self.hierarchy -= 1 - print_rank_0( - f"{'--' * self.hierarchy}Releasing params in module {debug_module2name_class(sub_module)}" - ) - params_to_release = [ - param for _, - param in sub_module.named_parameters(recurse=False) - ] - - if hasattr(sub_module, 'ds_external_parameters'): - #print_rank_0(f"Releasing external parameters {sub_module.ds_external_parameters()}") - params_to_release += [ - param for _, - param in sub_module.ds_external_parameters() - ] - - # for _, param in sub_module.named_parameters(recurse=False): - for param in params_to_release: - param.ds_active_sub_modules -= 1 - if not param.ds_active_sub_modules and not self._keep_for_later( - sub_module) and not param.ds_persist: - print_rank_0( - f"{'--' * self.hierarchy}--Releasing parameter {debug_param2name_id_numel(param)} active sub modules {param.ds_active_sub_modules} and keep for later {self._keep_for_later(sub_module)}", - force=False) - - # Keeping track of number of elements that are consumed by available parameters - self._decrement_available_parameter_numel(param.ds_numel) - see_memory_usage( - f"Before releasing param {debug_param2name_id_numel(param)}", - force=False) - param.partition(hierarchy=self.hierarchy) - see_memory_usage( - f"After releasing param {debug_param2name_id_numel(param)}", - force=False) - - param.ds_status = ZeroParamStatus.NOT_AVAILABLE - else: - - print_rank_0( - f"{'--' * self.hierarchy}--Did not release param {debug_param2name_id_numel(param)} with active sub modules {param.ds_active_sub_modules}, keep for later={self._keep_for_later(sub_module)} and persistence={param.ds_persist}", - force=False) - - def release_and_reset_parameter(self, param): - param.ds_active_sub_modules = 0 - if param.ds_status == ZeroParamStatus.AVAILABLE: - print_rank_0( - f"Releasing unpartitioned param {debug_param2name_id_numel(param)} active sub-modules {param.ds_active_sub_modules} and persistence {param.ds_persist}" - ) - self._decrement_available_parameter_numel(param.ds_numel) - param.partition() - - def _keep_for_later(self, sub_module): - if not self.prefetch_coordinator.trace_completed: - return False - if self.max_reuse_distance_in_numel == 0: - return False - reuse_distance_in_numel = self.prefetch_coordinator.get_reuse_distance_in_numel( - sub_module) - #print_rank_0(f"Reuse distance and numel for sub_module id {sub_module.id} is {reuse_distance_in_numel}") - return reuse_distance_in_numel < self.max_reuse_distance_in_numel - - def _all_gather(self, partitioned_params, async_op=False): - with torch.cuda.stream(self.comm_stream): - handles = partitioned_params[0].all_gather( - param_list=partitioned_params, - async_op=async_op, - hierarchy=self.hierarchy) if partitioned_params else None - - if handles is not None: - self.in_flight_handles.extend(handles) - self.params_in_flight.extend(partitioned_params) - - def _synchronize_communication(self, synchronize_streams=True): - assert len(self.params_in_flight) == len(self.in_flight_handles) - for handle, param in zip(self.in_flight_handles, self.params_in_flight): - if handle is not None: - with torch.cuda.stream(self.comm_stream): - handle.wait() - param.ds_status = ZeroParamStatus.AVAILABLE - self.comm_stream.synchronize() - torch.cuda.synchronize() if synchronize_streams else None - self.in_flight_handles = [] - self.params_in_flight = [] - - -class PreBackwardFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, module, pre_backward_function, outputs): - ctx.module = module - ctx.pre_backward_function = pre_backward_function - if not hasattr(module, "applied_pre_backward_ref_cnt"): - module.applied_pre_backward_ref_cnt = 0 - module.applied_pre_backward_ref_cnt += 1 - #print(f"After Forward: {ctx.module.__class__.__name__}") - outputs = outputs.detach() - return outputs - - @staticmethod - def backward(ctx, *args): - #print(f"Before Backward: {ctx.module.__class__.__name__}") - ctx.pre_backward_function(ctx.module) - return (None, None) + args - - -class PostBackwardFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, module, pre_backward_function, output): - ctx.module = module - if output.requires_grad: - #TODO SOME TIMES post backward does not seem to be triggered debug in detail - #Should only cause increase in memory not correctness issue - #if output.grad_fn.__class__.__name__ == 'ViewBackward': - # ctx.view=True - # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") - #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." - #if module.ds_grads_remaining == 0: - # print(f"Before Forward: {ctx.module.__class__.__name__}") - module.ds_grads_remaining += 1 - ctx.pre_backward_function = pre_backward_function - output = output.detach() - return output - - @staticmethod - def backward(ctx, *args): - ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 - if ctx.module.ds_grads_remaining == 0: - ctx.pre_backward_function(ctx.module) - #print(f"After Backward: {ctx.module.__class__.__name__}") - return (None, None) + args - - -INITIAL_MICRO_STEP_ID = -1 - - -class FP16_DeepSpeedZeroOptimizer_Stage3(object): - """ - DeepSpeedZeroOptimizer designed to reduce the memory footprint - required for training large deep learning models. - - For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models - https://arxiv.org/abs/1910.02054 - - For usage examples, refer to TODO: DeepSpeed Tutorial - - """ - def __init__(self, - module, - init_optimizer, - timers, - ds_config, - static_loss_scale=1.0, - dynamic_loss_scale=False, - dynamic_loss_args=None, - verbose=True, - contiguous_gradients=True, - reduce_bucket_size=500000000, - prefetch_bucket_size=50000000, - max_reuse_distance=1000000000, - max_live_parameters=1000000000, - param_persistence_threshold=100000, - dp_process_group=None, - reduce_scatter=True, - overlap_comm=False, - offload_optimizer_config=None, - offload_param_config=None, - sub_group_size=1000000000000, - mpu=None, - clip_grad=0.0, - communication_data_type=torch.float16, - postscale_gradients=True, - gradient_predivide_factor=1.0, - gradient_accumulation_steps=1, - elastic_checkpoint=False, - aio_config=None): - - see_memory_usage("Stage 3 initialize beginning", force=False) - - if dist.get_rank() == 0: - logger.info(f"Reduce bucket size {reduce_bucket_size}") - logger.info(f"Allgather bucket size {prefetch_bucket_size}") - # The fused optimizer does all the work. We need this layer for two reason: - # 1. maintain same user API from apex.fp16_utils - # 2. keep common stuff here in case we need to add ne552w fused optimizer later - - # differences from apex.fp16_utils: - # - assume all model params in fp16 - # - assume all params requires grad - # - flat by groups, not keeping state. TODO: remove state explicitly? - # - master gard and unflat master weight never exist. TODO: a way to save out unflat master? - if not torch.cuda.is_available: - raise SystemError("Cannot use fp16 without CUDA.") - self.optimizer = init_optimizer - - # Load pre-built or JIT compile (un)flatten ops - util_ops = UtilsBuilder().load() - self.flatten = util_ops.flatten - self.unflatten = util_ops.unflatten - self.dtype = self.optimizer.param_groups[0]['params'][0].dtype - self._global_grad_norm = 0. - - self.optimizer_swapper = None - self.swap_optimizer = False - - self.offload_optimizer = False - self.offload_optimizer_pin_memory = False - self.offload_optimizer_fast_init = False - self.offload_param = False - self.offload_param_pin_memory = False - self.params_in_nvme_and_cpu = False - self.max_params_in_cpu = 0 - - self._configure_offloading(offload_optimizer_config, offload_param_config) - - self._convert_to_zero_parameters(ds_config, module, mpu) - - for m in module.modules(): - _init_external_params(m) - - self.module = module - self.elastic_checkpoint = elastic_checkpoint - self.overlap_comm = overlap_comm - - # Replace ._parameters with a new class to enable auto-registration of - # external parameters - _inject_parameters(module, ZeROOrderedDict) - - if self.overlap_comm: - self.gpu_sum = torch.zeros(1, dtype=torch.float).cuda() - - self.deepspeed_adam_offload = (self.offload_optimizer - and type(init_optimizer) == DeepSpeedCPUAdam) - - self.device = torch.cuda.current_device( - ) if not self.offload_optimizer else OFFLOAD_CPU_DEVICE - ############################################################################ - - see_memory_usage("Before Partitioned Parameter Coordinator", force=False) - - fetch_stream = torch.cuda.Stream() if self.overlap_comm else None - self.param_coordinator = PartitionedParameterCoordinator( - comm_stream=fetch_stream, - max_reuse_distance_in_numel=int(max_reuse_distance), - max_available_parameters_in_numel=int(max_live_parameters)) - - see_memory_usage("After Partitioned Parameter Coordinator", force=False) - - #self.param_coordinator = PartitionedParameterCoordinator(comm_stream=torch.cuda.Stream()) - #-------------Stage 3 Setup-------------------# - # parameters smaller than the threshold will be collectively gathered at the - # end of the optimizer step and will be kept till the end of the backward pass - # TODO maybe worth just replicating these parameters and doing all reduce for them - self.persistence_threshold = int(param_persistence_threshold) - - self.persistent_parameters = self.persistent_parameters() - - self.setup_zero_stage3_hooks() - - #resetting ds_tensor just in case parameters have been changed after initialization - #example .half() or .to() - #self.reset_ds_tensor() - #---------------------------------------------# - - self.timers = timers - - self.reduce_scatter = reduce_scatter - - self.dp_process_group = dp_process_group - - self.partition_count = dist.get_world_size(group=self.dp_process_group) - - if mpu is None: - self.model_parallel_group = None - self.model_parallel_rank = 0 - else: - self.model_parallel_group = mpu.get_model_parallel_group() - self.model_parallel_rank = mpu.get_model_parallel_rank() - - self.overflow = False - self.clip_grad = clip_grad - self.communication_data_type = communication_data_type - self.gradient_predivide_factor = gradient_predivide_factor - self.postscale_gradients = postscale_gradients - self.gradient_accumulation_steps = gradient_accumulation_steps - self.micro_step_id = INITIAL_MICRO_STEP_ID - - if self.reduce_scatter: - assert self.communication_data_type in (torch.float16, torch.bfloat16), f"ZeRO-3 supports only float16 or bfloat16 communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" - assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-2 with reduce scatter enabled" - assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-2 with reduce scatter enabled" - - # Holds the mode parameter - # The param.data may not hold any meaningful data - # when param's status is NOT_AVAILABLE or IN_FLGHT - self.fp16_groups = [] - - # Hold partitioned parameters - self.fp16_partitioned_groups = [] - - # Holds a fused and flattened copy of the parameters - self.fp16_partitioned_groups_flat = [] - self.fp16_partitioned_groups_flat_numel = [] - - #defragmented pinned memory - self.param_groups_fp16_flat_cpu_memory = [] - - #a single 32-bit partition of the parallel partitioned parameters - #that this process will update - self.fp32_partitioned_groups_flat = [] - self.next_swappable_fp32_partitioned_groups = [] - - # number of elements per partition in each group - self.partition_size = [] - - self.all_reduce_print = False - - self.prefetch_elements = int(prefetch_bucket_size) - - # padding on each partition for alignment purposes - self.groups_padding = [] - - self.sub_group_size = sub_group_size - - self.sub_group_to_group_id = {} - see_memory_usage("Before creating fp16 partitions", force=False) - self._create_fp16_partitions_with_defragmentation() - num_fp16_subgroups = len(self.fp16_partitioned_groups_flat) - see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}", - force=False) - - # Optimizer tensor swapping - if self.swap_optimizer: - self._configure_tensor_swapping(offload_optimizer_config, aio_config) - - see_memory_usage("Before creating fp32 partitions", force=False) - if not isinstance(self.optimizer, DummyOptim): - self._create_fp32_partitions() - see_memory_usage("After creating fp32 partitions", force=False) - dist.barrier() - - # To support pipelined optimizer swapping - if not isinstance(init_optimizer, DummyOptim): - self._create_next_swappable_fp32_groups() - - see_memory_usage("Before initializing optimizer states", force=False) - if not isinstance(init_optimizer, DummyOptim): - self.initialize_optimizer_states() - see_memory_usage("After initializing optimizer states", force=False) - dist.barrier() - - if dist.get_rank() == 0: - logger.info(f"optimizer state initialized") - - self.reduce_bucket_size = int(reduce_bucket_size) - - self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False) - - self.reduction_stream = torch.cuda.Stream( - ) if self.overlap_comm else torch.cuda.current_stream() - self.callback_queued = False - self.copy_grad_stream = torch.cuda.Stream() - - self.param_dict = {} - - # map between param_id and bool to specify if a param is in this partition - self.is_param_in_current_partition = {} - - self.contiguous_gradients = contiguous_gradients - self.extra_large_param_to_reduce = None - self.grads_in_ipg_bucket = [] - self.params_in_ipg_bucket = [] - self.elements_in_ipg_bucket = 0 - self.params_already_reduced = [] - self.is_gradient_accumulation_boundary = True - self._release_ipg_buffers() - self.previous_reduced_grads = None - - # simplified param id - self.param_id = {} - - count = 0 - for i, params_group in enumerate(self.fp16_groups): - for param in params_group: - unique_id = id(param) - self.param_id[unique_id] = count - self.param_dict[count] = param - self.params_already_reduced.append(False) - count = count + 1 - - #Largest partitioned param - largest_partitioned_param_numel = max([ - max([tensor.numel() for tensor in fp16_partitioned_group]) - for fp16_partitioned_group in self.fp16_partitioned_groups - ]) - print_rank_0( - f'Largest partitioned param numel = {largest_partitioned_param_numel}', - force=False) - - see_memory_usage(f"Before Set Grad positions", force=False) - - self.grad_position = {} - self.set_grad_positions() - see_memory_usage(f"Before CPU Offload initialization", force=False) - - self.grads_in_partition = None - - if self.offload_optimizer: - self.accumulated_grads_in_cpu = {} - self.norm_for_param_grads = {} - self.local_overflow = False - self.temp_grad_buffer_for_gpu_offload = torch.zeros( - largest_partitioned_param_numel, - device=torch.cuda.current_device(), - dtype=self.dtype) - self.temp_grad_gpu_buffer = torch.zeros(largest_partitioned_param_numel, - device=torch.cuda.current_device(), - dtype=self.dtype) - see_memory_usage(f"After CPU Offload initialization", force=False) - - # stores if a partition has been reduced in this step - self.is_partition_reduced = {} - - # stores if a grad in a partition has been computed or not - self.is_grad_computed = {} - - # will store the averaged gradients required by this paritition - self.averaged_gradients = {} - - #creates backward hooks for gradient partitioning - self.create_reduce_and_remove_grad_hooks() - - #exit(0) - - # we may have a way of fusing dynamic scale. Do not support for now - if self.dtype == torch.float or not dynamic_loss_scale: - loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale - - self.dynamic_loss_scale = False - self.loss_scaler = LossScaler(scale=loss_scale_value) - cur_iter = 0 - else: - if dynamic_loss_args is None: - self.loss_scaler = DynamicLossScaler() - else: - self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) - - self.dynamic_loss_scale = True - - self.debug_fp16_grads = [{} for _ in self.fp16_groups] - - if dist.get_rank(group=self.dp_process_group) == 0: - see_memory_usage(f"After initializing ZeRO optimizer", force=False) - - def _configure_offloading(self, offload_optimizer_config, offload_param_config): - ###################### offload optimizer setup ################################## - if offload_optimizer_config is not None: - self.offload_optimizer = True - self.offload_optimizer_pin_memory = offload_optimizer_config[ - OFFLOAD_OPTIMIZER_PIN_MEMORY] - self.swap_optimizer = offload_optimizer_config[ - OFFLOAD_OPTIMIZER_DEVICE] == OFFLOAD_NVME_DEVICE - self.offload_optimizer_fast_init = offload_optimizer_config[ - OFFLOAD_OPTIMIZER_FAST_INIT] - - ###################### offload param setup ################################## - if offload_param_config is not None: - if not isinstance(self.optimizer, DummyOptim): - assert self.offload_optimizer, "parameter offload is only available with optimizer state offload" - self.offload_param = True - self.offload_param_pin_memory = offload_param_config[ - OFFLOAD_PARAM_PIN_MEMORY] - self.params_in_nvme_and_cpu = offload_param_config[ - OFFLOAD_PARAM_DEVICE] == OFFLOAD_NVME_DEVICE - self.max_params_in_cpu = offload_param_config[OFFLOAD_PARAM_MAX_IN_CPU] - print_rank_0( - f"FP16 params swapping is {self.params_in_nvme_and_cpu}, Max params in CPU is {self.max_params_in_cpu}", - force=False) - - def _convert_to_zero_parameters(self, ds_config, module, mpu): - non_zero_params = [p for p in module.parameters() if not is_zero_param(p)] - if non_zero_params: - zero_params = [p for p in module.parameters() if is_zero_param(p)] - if zero_params: - zero_params[0].convert_to_zero_parameters(param_list=non_zero_params) - else: - group = None - if mpu: - group = mpu.get_data_parallel_group() - - if self.params_in_nvme_and_cpu: - remote_device = OFFLOAD_NVME_DEVICE - elif self.offload_param: - remote_device = OFFLOAD_CPU_DEVICE - else: - remote_device = None - - Init(module=module, - data_parallel_group=group, - dtype=self.dtype, - config_dict_or_path=ds_config, - remote_device=remote_device, - pin_memory=self.offload_param_pin_memory, - mpu=mpu) - - def _configure_tensor_swapping(self, offload_optimizer_config, aio_config): - nvme_swap_folder = os.path.join( - offload_optimizer_config[OFFLOAD_OPTIMIZER_NVME_PATH], - 'zero_stage_3') - os.makedirs(nvme_swap_folder, exist_ok=True) - if torch.distributed.get_rank() == 0: - logger.info(f'Tensor Swapping: Adding optimizer tensors') - - swapper_type = PipelinedOptimizerSwapper if offload_optimizer_config[ - OFFLOAD_OPTIMIZER_PIPELINE] else PartitionedOptimizerSwapper - - self.optimizer_swapper = swapper_type( - swap_config=offload_optimizer_config, - aio_config=aio_config, - base_folder=nvme_swap_folder, - optimizer=self.optimizer, - largest_numel=max(self.fp16_partitioned_groups_flat_numel), - device=self.device, - dtype=torch.float32, - timers=self.timers) - - def _create_fp16_partitions(self): - dist.barrier() - partition_id = dist.get_rank(group=self.dp_process_group) - - # loop to deal with groups - for j, param_group in enumerate(self.optimizer.param_groups): - - sub_groups = self._create_fp16_sub_groups(param_group['params']) - for sub_group in sub_groups: - i = len(self.fp16_groups) - - # push this group to list before modify - self.fp16_groups.append(sub_group) - self.sub_group_to_group_id[i] = j - - #These are the list of the partitioned parameters - self.fp16_partitioned_groups.append( - [param.ds_tensor for param in self.fp16_groups[i]]) - - print_rank_0( - f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}" - ) - - # Record padding required to align group to world size (only applies to last rank) - if partition_id == dist.get_world_size(group=self.dp_process_group) - 1: - padding = [p.padding_size() for p in self.fp16_groups[i]] - else: - padding = [0] * len(self.fp16_groups[i]) - self.groups_padding.append(padding) - - #not sure why apex was cloning the weights before flattening - #removing cloning here - see_memory_usage(f"Before Flattening param group {i}", force=False) - - if not self.offload_param: - see_memory_usage(f"Before moving param group {i} to CPU", - force=False) - #move all the parameters to cpu to free up GPU space for creating flat buffer - move_to_cpu(self.fp16_partitioned_groups[i]) - see_memory_usage(f"After moving param group {i} to CPU", force=False) - - #create flat buffer in CPU and move to GPU - self.fp16_partitioned_groups_flat.append( - self.flatten_dense_tensors_aligned( - self.fp16_partitioned_groups[i], - dist.get_world_size(group=self.dp_process_group)).cuda( - torch.cuda.current_device())) - see_memory_usage( - f"After flattening and moving param group {i} to GPU", - force=False) - else: - #Without the detach, seems like the flattening becomes part of the - #model graph causing errors downstream - self.fp16_partitioned_groups_flat.append( - self.flatten_dense_tensors_aligned( - self.fp16_partitioned_groups[i], - dist.get_world_size( - group=self.dp_process_group)).detach().pin_memory()) - - see_memory_usage(f"After Flattening param group {i}", force=False) - - see_memory_usage(f"After Flattening param group {i}", force=False) - - #set model fp16 weight to slices of flattened buffer - updated_params = self.unflatten(self.fp16_partitioned_groups_flat[i], - self.fp16_partitioned_groups[i]) - - for partitioned_param, q in zip(self.fp16_partitioned_groups[i], updated_params): - partitioned_param.data = q.data - - def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False): - '''If flat buffer is None then the parameters in the param_list are - not copied to the flat buffer. This is because they excede the number of max_params_in_cpu - Some of these parameters may aready be in CPU in unflattened buffers - or they maybe in GPU, or they maybe in NVME. If they are in NVME, then - they will be marked as NOT_AVAILABLE, and will be moved to CPU when they are - needed during training.''' - if flat_buffer is None: - # this dst buffer is on NVMe, so skip this - return - - start = 0 - for param in param_list: - src = param.ds_tensor - dest = flat_buffer.narrow(0, start, src.ds_numel) - start = start + src.ds_numel - '''if the parameter was initialized in nvme then bring it to the destination buffer directly''' - if src.status == PartitionedParamStatus.NOT_AVAILABLE: - print_rank_0( - f"Swapping in {param.ds_id} with partition size {param.ds_tensor.ds_numel} permanently to CPU" - ) - param.nvme_swapper.swap_into_buffer(param, dest) - src.data = dest.data - src.status = PartitionedParamStatus.AVAILABLE - else: - assert src.status == PartitionedParamStatus.AVAILABLE, "Partitioned Param must be available here" - if not avoid_copy: - dest.data.copy_(src.data) - src.data = dest.data - - # Final location must be gpu/cpu in this case - param.ds_tensor.final_location = 'not-nvme' - - def _create_param_groups_fp16_flat_cpu_memory(self): - - aggregate_params_count = 0 - - for j, param_group in enumerate(self.optimizer.param_groups): - params_in_group = sum([p.ds_tensor.ds_numel for p in param_group['params']]) - - flat_buffer_size = params_in_group - - if self.params_in_nvme_and_cpu and \ - aggregate_params_count + params_in_group > self.max_params_in_cpu: - - flat_buffer_size = max(0, - self.max_params_in_cpu - aggregate_params_count) - - aggregate_params_count += params_in_group - - if flat_buffer_size > 0: - print_rank_0(f"group {j} flat buffer size {flat_buffer_size}", - force=False) - self.param_groups_fp16_flat_cpu_memory.append( - torch.empty(int(flat_buffer_size), - dtype=self.dtype, - pin_memory=True)) - else: - print_rank_0( - f"No flat buffer size. Param group size was {params_in_group}", - force=False) - - self.param_groups_fp16_flat_cpu_memory.append( - torch.empty(1, - dtype=self.dtype)) - - def _create_fp16_partitions_with_defragmentation(self): - dist.barrier() - partition_id = dist.get_rank(group=self.dp_process_group) - create_fp16_flat_reuse_buffer = False - largest_partition_numel = [] - max_partition_numel = 0 - - #create a flat CPU memory allocation for each param group - if self.offload_param: - self._create_param_groups_fp16_flat_cpu_memory() - - # loop to deal with groups - for j, param_group in enumerate(self.optimizer.param_groups): - - sub_groups = self._create_fp16_sub_groups(param_group['params']) - print_rank_0(f'fp16 group {j} has {len(sub_groups)} subgroups', force=False) - - flat_offset = 0 - for sub_group in sub_groups: - i = len(self.fp16_groups) - - # push this group to list before modify - self.fp16_groups.append(sub_group) - self.sub_group_to_group_id[i] = j - - # comment out for zero_to_fp32 debug - # if torch.distributed.get_rank() == 0: - # for param in self.fp16_groups[i]: - # print(f"{debug_param2name_id_shape(param)} {param.ds_shape}") - - #These are the list of the partitioned parameters - self.fp16_partitioned_groups.append( - [param.ds_tensor for param in self.fp16_groups[i]]) - - total_elements = sum( - [t.ds_numel for t in self.fp16_partitioned_groups[i]]) - self.fp16_partitioned_groups_flat_numel.append(total_elements) - - if total_elements > max_partition_numel: - largest_partition_numel = [ - t.ds_numel for t in self.fp16_partitioned_groups[i] - ] - max_partition_numel = total_elements - - print_rank_0( - f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}" - ) - - # Record padding required to align group to world size (only applies to last rank) - if partition_id == dist.get_world_size(group=self.dp_process_group) - 1: - padding = [p.padding_size() for p in self.fp16_groups[i]] - else: - padding = [0] * len(self.fp16_groups[i]) - self.groups_padding.append(padding) - - #not sure why apex was cloning the weights before flattening - #removing cloning here - see_memory_usage(f"Before Flattening param subgroup {i}", force=False) - - #all partitioned parameters remain in GPU during training - if not self.offload_param: - see_memory_usage(f"Before moving param subgroup group {i} to CPU", - force=False) - #move all the parameters to cpu to free up GPU space for creating flat buffer - move_to_cpu(self.fp16_partitioned_groups[i]) - see_memory_usage(f"After moving param subgroup {i} to CPU", - force=False) - - #create flat buffer in CPU and move to GPU - self.fp16_partitioned_groups_flat.append( - self.flatten_dense_tensors_aligned( - self.fp16_partitioned_groups[i], - 1).cuda(torch.cuda.current_device())) - see_memory_usage( - f"After flattening and moving param subgroup {i} to GPU", - force=False) - - #all partitioned parameters are in CPU during training - else: - print_rank_0(f"Params in nvme and cpu {self.params_in_nvme_and_cpu}") - #Flat buffer may not be available for parameters that reside in NVME - if not self.params_in_nvme_and_cpu or flat_offset + total_elements <= self.param_groups_fp16_flat_cpu_memory[ - j].numel(): - fp16_partitioned_group_flat = self.param_groups_fp16_flat_cpu_memory[ - j].narrow(0, - flat_offset, - total_elements) - print_rank_0( - f"Creating a flat buffer for subgroup {i} requiring {total_elements} elements, and cumulative CPU elements {flat_offset + total_elements}", - force=False) - #these parameters reside in NVME and - elif self.params_in_nvme_and_cpu: - fp16_partitioned_group_flat = None - print_rank_0( - f"No flat buffer for sub group {i} of {total_elements} elements", - force=False) - else: - assert False, "Either params are in nvme, or they are in CPU memory. This code path should not be triggered. Please see you max_params_in_cpu and params_in_nvme configs" - - self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat) - flat_offset += total_elements - - # move param to flat buffer for both param offload on/off - self._move_to_flat_buffer(self.fp16_groups[i], - self.fp16_partitioned_groups_flat[i], - avoid_copy=not self.offload_param) - - see_memory_usage(f"After Flattening param group {i}", force=False) - - #create a pinned memory to be used for swapping out params to NVME after optimizer step - if self.fp16_partitioned_groups_flat[-1] is None: - create_fp16_flat_reuse_buffer = True - - see_memory_usage(f"After Flattening param subgroup {i}", force=False) - - if create_fp16_flat_reuse_buffer: - assert len(largest_partition_numel) > 0, f'Unexpected that largest partition is empty' - self.fp16_groups[0][0].nvme_swapper.reserve_partitioned_swap_space( - largest_partition_numel) - - def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id): - offset = 0 - elements_in_sub_group = sum( - [t.ds_numel for t in self.fp16_partitioned_groups[sub_group_id]]) - assert (flat_buffer.numel() == elements_in_sub_group) - for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]): - dest = flat_buffer.narrow(0, offset, partitioned_param.ds_numel) - if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE: - print_rank_0( - f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.ds_tensor.ds_numel}" - ) - param.nvme_swapper.swap_in([param], async_op=False) - dest.data.copy_(partitioned_param.data) - param.nvme_swapper.remove_partition_and_release_buffers([param]) - print_rank_0(f"Swapping in {param.ds_id} done") - else: - dest.data.copy_(partitioned_param.data) - offset += partitioned_param.ds_numel - - def _create_next_swappable_fp32_groups(self): - reverse_order_indices = [ - i for i in range(len(self.fp32_partitioned_groups_flat)) - ] - reverse_order_indices.reverse() - - next_group = None - for i in reverse_order_indices: - self.next_swappable_fp32_partitioned_groups.append(next_group) - if self._swappable_optimizer_subgroup(i): - next_group = self.fp32_partitioned_groups_flat[i] - - self.next_swappable_fp32_partitioned_groups.reverse() - - def _get_sub_group_partitions(self, sub_group_id): - sub_group_partitions = [] - for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]): - if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE: - swap_path = param.nvme_swapper.get_path(param, True) - sub_group_partitions.append((partitioned_param, - param.ds_tensor.ds_numel, - swap_path)) - else: - sub_group_partitions.append((partitioned_param, - partitioned_param.ds_numel, - None)) - - return sub_group_partitions - - def _create_fp32_partitions(self): - cpu_memory_usage = 0 - cpu_memory_sub_groups = 0 - nvme_memory_usage = 0 - num_swappable_partitions = 0 - num_swap_from_nvme_partitions = 0 - num_swap_from_cpu_partitions = 0 - swap_from_nvme_memory_usage = 0 - swap_from_cpu_memory_usage = 0 - GIGA_BYTES = (1024**3) - - swappable_fp32_tensors = [] - swappable_fp16_src_tensors = [] - nvme_fp16_partitions_info = [] - nvme_fp16_num_elems = [] - nvme_fp32_dest_tensors = [] - fp32_element_size = torch.tensor([], dtype=torch.float32).element_size() - - for i, tensor in enumerate(self.fp16_partitioned_groups_flat): - num_elements = self.fp16_partitioned_groups_flat_numel[i] - - # a partition of the fp32 master weights that will be updated by this process - if self._swappable_optimizer_subgroup(i): - self.fp32_partitioned_groups_flat.append(torch.Tensor()) - nvme_memory_usage += (fp32_element_size * num_elements) - num_swappable_partitions += 1 - - if self.params_in_nvme_and_cpu and tensor is None: - num_swap_from_nvme_partitions += 1 - swap_from_nvme_memory_usage += (fp32_element_size * num_elements) - if self.offload_optimizer_fast_init: - sub_group_partitions = self._get_sub_group_partitions(i) - nvme_fp16_partitions_info.append(sub_group_partitions) - nvme_fp16_num_elems.append(num_elements) - nvme_fp32_dest_tensors.append( - self.fp32_partitioned_groups_flat[i]) - else: - unpinned_fp32_buffer = torch.empty(num_elements, - device=self.device, - dtype=torch.float) - self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) - self.optimizer_swapper.initialize_parameters( - parameters=[self.fp32_partitioned_groups_flat[i]], - src_tensors=[unpinned_fp32_buffer]) - else: - num_swap_from_cpu_partitions += 1 - swap_from_cpu_memory_usage += (fp32_element_size * num_elements) - swappable_fp32_tensors.append(self.fp32_partitioned_groups_flat[i]) - swappable_fp16_src_tensors.append( - self.fp16_partitioned_groups_flat[i]) - else: - cpu_memory_usage += (fp32_element_size * num_elements) - cpu_memory_sub_groups += 1 - - if self.params_in_nvme_and_cpu and tensor is None: - unpinned_fp32_buffer = torch.empty(num_elements, - device=self.device, - dtype=torch.float) - self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) - self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer) - else: - self.fp32_partitioned_groups_flat.append( - self.fp16_partitioned_groups_flat[i].to( - self.device).clone().float().detach()) - - self.fp32_partitioned_groups_flat[ - i].requires_grad = True # keep this in case internal optimizer uses it - - if len(swappable_fp32_tensors) > 0: - self.optimizer_swapper.initialize_parameters( - parameters=swappable_fp32_tensors, - src_tensors=swappable_fp16_src_tensors) - - if len(nvme_fp32_dest_tensors) > 0: - fp16_pinned_buffers = self.fp16_groups[0][ - 0].nvme_swapper.reserve_available_buffers() - assert len(fp16_pinned_buffers) > 0 - self.optimizer_swapper.initialize_from_swapped_fp16_params( - fp16_partitions_info=nvme_fp16_partitions_info, - fp16_num_elems=nvme_fp16_num_elems, - fp16_pinned_buffers=fp16_pinned_buffers, - fp32_parameters=nvme_fp32_dest_tensors) - self.fp16_groups[0][0].nvme_swapper.release_reserved_buffers() - - nvme_gigabytes = nvme_memory_usage / GIGA_BYTES - print_rank_0( - f'Swappable FP32 Partitions: count={num_swappable_partitions} size={nvme_gigabytes:5.2f} GB', - force=False) - if self.params_in_nvme_and_cpu: - print_rank_0( - f'Swap from NVMe Partitions: count = {num_swap_from_nvme_partitions}, size = {swap_from_nvme_memory_usage/GIGA_BYTES:5.2f}GB', - force=False) - print_rank_0( - f'Swap from CPU Partitions: count = {num_swap_from_cpu_partitions}, size = {swap_from_cpu_memory_usage/GIGA_BYTES:5.2f}GB', - force=False) - - cpu_memory_gigabytes = cpu_memory_usage / GIGA_BYTES - print_rank_0( - f'In-Memory FP32 Partitions: count={cpu_memory_sub_groups} size={cpu_memory_gigabytes:5.2f} GB', - force=False) - - # Clear for on-the-fly population before the optimizer step - for param_group in self.optimizer.param_groups: - param_group['params'] = [] - - def _create_fp16_sub_groups(self, params_group): - - params_group_numel = sum([param.partitioned_size() for param in params_group]) - sub_group_size = self.sub_group_size - - if sub_group_size is None or sub_group_size >= params_group_numel: - return [params_group] - - sub_groups = [] - sub_group = [] - local_sub_group_size = 0 - for param in params_group: - - sub_group.append(param) - local_sub_group_size += param.partitioned_size() - - if local_sub_group_size >= sub_group_size or id(param) == id( - params_group[-1]): - - sub_groups.append(sub_group) - - sub_group = [] - local_sub_group_size = 0 - - return sub_groups - - # def reset_ds_tensor(self): - # for name, param in self.module.named_parameters(recurse=True): - # assert hasattr(param,'ds_id'), "Parameters have not been converted to be Zero 3 compatible" - # assert (param.ds_status == ZeroParamStatus.NOT_AVAILABLE), "All the parameters must have been partitioned by now" - # param.ds_tensor.data = param.data - - def setup_zero_stage3_hooks(self): - self.hierarchy = 0 - self._register_hooks_recursively(self.module) - - #reset step at the beginning of forward - def _pre_forward_hook(module, *args): - self.param_coordinator.reset_step() - - #reset step if in inference mode - def _end_of_forward_hook(module, *args): - - if not torch._C.is_grad_enabled(): - self.param_coordinator.reset_step() - - #likely one of them should be enough but just to be safe - self.module.register_forward_hook(_end_of_forward_hook) - self.module.register_forward_pre_hook(_pre_forward_hook) - - # Add top module to stack trace - global FWD_MODULE_STACK - FWD_MODULE_STACK.append(self.module) - - def persistent_parameters(self): - persistent_params = [] - total_persistent_parameters = 0 - params_count = 0 - for _, param in self.module.named_parameters(recurse=True): - if param.ds_numel < self.persistence_threshold: - params_count += 1 - param.ds_persist = True - persistent_params.append(param) - total_persistent_parameters += param.ds_numel - - print_rank_0( - f"ZeRO 3: Total persistent parameters: {total_persistent_parameters} in {params_count} params", - force=False) - return persistent_params - - def _register_hooks_recursively(self, module, count=[0]): - my_count = count[0] - module.id = my_count - - #print(f"{module.__class__} : {module.id}") - - for child in module.children(): - count[0] = count[0] + 1 - self._register_hooks_recursively(child, count=count) - - def _pre_forward_module_hook(module, *args): - self.pre_sub_module_forward_function(module) - - def _post_forward_module_hook(module, input, output): - global FWD_MODULE_STACK - FWD_MODULE_STACK.pop() - if output is None: - output = [] - elif not isinstance(output, (list, tuple)): - if torch.is_tensor(output): - output = [output] - else: - #print(f'got UNKNOWN type {type(output)}') - outputs = [] - output = output if isinstance(output, dict) else vars(output) - for name, val in output.items(): - if not name.startswith('__') and torch.is_tensor(val): - outputs.append(val) - output = outputs - #print(f'convert output to {output}') - - for item in filter(lambda item: is_zero_param(item), output): - if not any(id(item) in m._external_params for m in FWD_MODULE_STACK): - item.ds_active_sub_modules += 1 - module_to_register = FWD_MODULE_STACK[-1] - print_rank_0( - f'Registering dangling parameter for module {module_to_register.__class__.__name__}.', - force=False) - register_external_parameter(module_to_register, item) - - # It's possible that the parameter was already external to the completed module. If so, remove it the - # registration as it will be covered by the outer module instead. - if id(item) in module._external_params: - print_rank_0( - f' Unregistering nested dangling parameter from module {module.__class__.__name__}', - force=False) - unregister_external_parameter(module, item) - - item.all_gather() - - self.post_sub_module_forward_function(module) - - def _pre_backward_module_hook(module, inputs, output): - def _run_before_backward_function(sub_module): - # some models (e.g. Albert) may run multiple forwards on the same layer in a loop - # before doing backwards, so each backward will need a pre-fetch - using reference - # counting to support this scenario - #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}") - if sub_module.applied_pre_backward_ref_cnt > 0: - self.pre_sub_module_backward_function(sub_module) - sub_module.applied_pre_backward_ref_cnt -= 1 - #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}") - - return _apply_to_tensors_only(module, - PreBackwardFunction, - _run_before_backward_function, - output) - - #This is an alternate to doing _post_backward_module_hook - #it uses tensor.register_hook instead of using torch.autograd.Function - def _alternate_post_backward_module_hook(module, inputs): - module.ds_grads_remaining = 0 - - #print(f"Before Forward {module.__class__.__name__}") - - def _run_after_backward_hook(*unused): - module.ds_grads_remaining = module.ds_grads_remaining - 1 - if module.ds_grads_remaining == 0: - #print(f"After backward {module.__class__.__name__}") - self.post_sub_module_backward_function(module) - - def _run_before_forward_function(input): - if input.requires_grad: - module.ds_grads_remaining += 1 - - return _apply_forward_and_backward_to_tensors_only( - module, - _run_before_forward_function, - _run_after_backward_hook, - inputs) - - def _post_backward_module_hook(module, inputs): - module.ds_grads_remaining = 0 - - def _run_after_backward_function(sub_module): - if sub_module.ds_grads_remaining == 0: - self.post_sub_module_backward_function(sub_module) - - return _apply_to_tensors_only(module, - PostBackwardFunction, - _run_after_backward_function, - inputs) - - # Pre forward hook - module.register_forward_pre_hook(_pre_forward_module_hook) - # Post forward hook - module.register_forward_hook(_post_forward_module_hook) - - # Pre backward hook - module.register_forward_hook(_pre_backward_module_hook) - - # post backward hook - module.register_forward_pre_hook(_post_backward_module_hook) - - def pre_sub_module_forward_function(self, sub_module): - see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", - force=False) - - global FWD_MODULE_STACK - FWD_MODULE_STACK.append(sub_module) - - self.param_coordinator.record_trace(sub_module) - - self.param_coordinator.fetch_sub_module(sub_module) - see_memory_usage( - f"Before sub module function {sub_module.__class__.__name__} after fetch", - force=False) - - self.param_coordinator.prefetch_next_sub_modules( - sub_module, - numel=self.prefetch_elements, - nvme=self.params_in_nvme_and_cpu) - see_memory_usage( - f"Before sub module function {sub_module.__class__.__name__} after prefetch", - force=False) - - self.param_coordinator.increment_step(sub_module) - - def post_sub_module_forward_function(self, sub_module): - see_memory_usage( - f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release", - force=False) - - self.param_coordinator.release_sub_module(sub_module) - - see_memory_usage( - f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release", - force=False) - - def pre_sub_module_backward_function(self, sub_module): - self.param_coordinator.record_trace(sub_module) - - self.param_coordinator.fetch_sub_module(sub_module) - - self.param_coordinator.prefetch_next_sub_modules(sub_module, - numel=self.prefetch_elements) - - self.param_coordinator.increment_step(sub_module) - - def post_sub_module_backward_function(self, sub_module): - see_memory_usage( - f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release", - force=False) - self.param_coordinator.release_sub_module(sub_module) - see_memory_usage( - f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release", - force=False) - - def _release_ipg_buffers(self): - if self.contiguous_gradients: - self.ipg_buffer = None - if not self.offload_optimizer and self.is_gradient_accumulation_boundary: - self.grads_in_partition = None - - self.grads_in_partition_offset = 0 - - def _optimizer_step(self, sub_group_id): - param_group_id = self.sub_group_to_group_id[sub_group_id] - fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] - fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] - self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] - - self.optimizer.step() - self.optimizer.param_groups[param_group_id]['params'] = [] - - def _swappable_optimizer_subgroup(self, sub_group_id): - if not self.swap_optimizer: - return False - - return self.optimizer_swapper.swappable_tensor( - None, - numel=self.fp16_partitioned_groups_flat_numel[sub_group_id]) - - def _partitioned_params_swap_out(self, i): - offset = 0 - fp32_param = self.fp32_partitioned_groups_flat[i] - assert fp32_param is not None, \ - f'fp32 parameters of sub_group {i} is None' - - swap_fp16_params = [] - swap_fp32_params = [] - for param, partitioned_param in zip(self.fp16_groups[i], self.fp16_partitioned_groups[i]): - src = fp32_param.narrow(0, offset, partitioned_param.ds_numel) - if partitioned_param.status == PartitionedParamStatus.AVAILABLE: - partitioned_param.data.copy_(src.data) - else: - swap_fp32_params.append(src) - swap_fp16_params.append(param) - offset += partitioned_param.ds_numel - - if len(swap_fp16_params): - swap_fp16_params[0].nvme_swapper.swap_out_partitioned_params( - dst_fp16_params=swap_fp16_params, - src_fp32_params=swap_fp32_params) - - def initialize_optimizer_states(self): - num_subgroups = len(self.fp16_groups) - - largest_numel = max( - [sum([p.ds_numel for p in psg]) for psg in self.fp16_partitioned_groups]) - gradient_dtype = self.fp32_partitioned_groups_flat[0].dtype - gradient_buffer = torch.zeros(int(largest_numel), - dtype=gradient_dtype, - device=self.device) - - timers = self.timers - timer_names = set() - - if self.swap_optimizer: - self.optimizer_swapper.init_timers() - - INIT_OPTIMIZER_TIMER = 'init_optimizer_state' - timer_names.add(INIT_OPTIMIZER_TIMER) - self.start_timers([INIT_OPTIMIZER_TIMER]) - - for i, group in enumerate(self.fp16_groups): - swappable_optimizer_subgroup = self._swappable_optimizer_subgroup(i) - swappable_param_subgroup = self.fp16_partitioned_groups_flat[i] is None - - num_elements = int(self.fp16_partitioned_groups_flat_numel[i]) - - see_memory_usage( - f'[Begin] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', - force=False) - - if swappable_optimizer_subgroup: - self._optimizer_states_and_gradient_swap_in(i, timer_names) - - if self.offload_optimizer and not swappable_optimizer_subgroup: - subgroup_gradient_buffer = torch.zeros(num_elements, - dtype=gradient_dtype, - device=self.device) - if self.offload_optimizer_pin_memory: - subgroup_gradient_buffer = subgroup_gradient_buffer.pin_memory() - - self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer - else: - self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow( - 0, - 0, - num_elements) - - self._optimizer_step(i) - - if swappable_param_subgroup: - self._partitioned_params_swap_out(i) - - if swappable_optimizer_subgroup: - self._optimizer_states_and_gradient_swap_out(i, timer_names) - - see_memory_usage( - f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', - force=False) - - self.stop_timers([INIT_OPTIMIZER_TIMER]) - self.log_timers(timer_names) - - if self.swap_optimizer: - self.optimizer_swapper.log_timers() - - if not self.offload_optimizer: - for group in self.fp32_partitioned_groups_flat: - group.grad = None - - # Reset steps - return - - ######################################################################### - #########################ZeRO Partition Gradients######################## - ######################################################################### - - def get_first_param_index(self, group_id, param_group, partition_id): - for index, param in enumerate(param_group): - param_id = self.get_param_id(param) - if partition_id in self.param_to_partition_ids[group_id][param_id]: - return index - return None - - def initialize_gradient_partitioning_data_structures(self): - - total_partitions = dist.get_world_size(group=self.dp_process_group) - - for i, param_group in enumerate(self.fp16_groups): - - self.param_to_partition_ids[i] = {} - self.is_partition_reduced[i] = {} - self.total_grads_in_partition[i] = {} - self.remaining_grads_in_partition[i] = {} - self.is_grad_computed[i] = {} - self.grad_partition_insertion_offset[i] = {} - self.grad_start_offset[i] = {} - self.first_param_index_in_partition[i] = {} - - for partition_id in range(total_partitions): - self.is_grad_computed[i][partition_id] = {} - self.grad_partition_insertion_offset[i][partition_id] = {} - self.grad_start_offset[i][partition_id] = {} - self.initialize_gradient_partition(i, param_group, partition_id) - self.is_partition_reduced[i][partition_id] = False - self.first_param_index_in_partition[i][ - partition_id] = self.get_first_param_index( - i, - param_group, - partition_id) - - def independent_gradient_partition_epilogue(self): - self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0) - self.reduce_ipg_grads() - self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0) - - if self.overlap_comm: - self.reduction_stream.synchronize() - - with torch.cuda.stream(self.reduction_stream): - self.partition_previous_reduced_grads() - - # if dist.get_rank() == 0: - # logger.info("Params already reduced %s", self.params_already_reduced) - for i in range(len(self.params_already_reduced)): - self.params_already_reduced[i] = False - - #in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad - #TODO: use a similar code path for both cpu_offload and non-cpu offload - if not self.offload_optimizer: - for i, sub_group in enumerate(self.fp16_groups): - self.averaged_gradients[i] = [ - torch.zeros_like(param.ds_tensor) if param.grad is None else - param.grad.data.narrow(0, - 0, - param.ds_tensor.numel()) - for param in sub_group - ] - # self.averaged_gradients[i] = self.get_flat_partition( - # self.fp16_groups[i], - # 0, - # self.fp32_partitioned_groups_flat[i].numel(), - # return_tensor_list=True) - - self._release_ipg_buffers() - - see_memory_usage(f"End ipg_epilogue", force=False) - - # resets all partition to no reduced - # sets remaining grads to the total number of grads in each partition - # set is grad computed to false for all grads in partition - def reset_partition_gradient_structures(self): - total_partitions = dist.get_world_size(group=self.dp_process_group) - for i, _ in enumerate(self.fp16_groups): - for partition_id in range(total_partitions): - self.is_partition_reduced[i][partition_id] = False - self.remaining_grads_in_partition[i][ - partition_id] = self.total_grads_in_partition[i][partition_id] - - for param_id in self.is_grad_computed[i][partition_id]: - self.is_grad_computed[i][partition_id][param_id] = False - - def initialize_gradient_partition(self, i, param_group, partition_id): - def set_key_value_list(dictionary, key, value): - if key in dictionary: - dictionary[key].append(value) - else: - dictionary[key] = [value] - - def increment_value(dictionary, key): - if key in dictionary: - dictionary[key] += 1 - else: - dictionary[key] = 1 - - partition_size = self.partition_size[i] - - start_index = partition_size * partition_id - end_index = partition_size * (partition_id + 1) - - current_index = 0 - first_offset = 0 - - for param in param_group: - - param_size = param.numel() - param_id = self.get_param_id(param) - - if (current_index >= start_index and current_index < end_index): - set_key_value_list(self.param_to_partition_ids[i], - param_id, - partition_id) - increment_value(self.total_grads_in_partition[i], partition_id) - - self.is_grad_computed[i][partition_id][param_id] = False - - self.grad_partition_insertion_offset[i][partition_id][ - param_id] = current_index - start_index - self.grad_start_offset[i][partition_id][param_id] = 0 - - elif start_index > current_index and start_index < (current_index + - param_size): - assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition" - first_offset = start_index - current_index - - set_key_value_list(self.param_to_partition_ids[i], - param_id, - partition_id) - increment_value(self.total_grads_in_partition[i], partition_id) - - self.is_grad_computed[i][partition_id][param_id] = False - - self.grad_partition_insertion_offset[i][partition_id][param_id] = 0 - self.grad_start_offset[i][partition_id][param_id] = first_offset - - current_index = current_index + param_size - - def overlapping_partition_gradients_reduce_epilogue(self): - self.independent_gradient_partition_epilogue() - self.zero_grad() - - def create_reduce_and_remove_grad_hooks(self): - print_rank_0(f'[Begin] Create gradient reduction hooks') - self.grad_accs = [] - for i, param_group in enumerate(self.fp16_groups): - for param in param_group: - if param.requires_grad: - #print_rank_0(f" Before all gather {param.device}, {param.shape}") - - # The hook must be created in un-partitioned parameter - param.all_gather() - - #print(f"After all gather {param.device}, {param.shape}") - def wrapper(param, i): - param_tmp = param.expand_as(param) - grad_acc = param_tmp.grad_fn.next_functions[0][0] - - def reduce_partition_and_remove_grads(*notneeded): - self.reduce_ready_partitions_and_remove_grads(param, i) - - grad_acc.register_hook(reduce_partition_and_remove_grads) - self.grad_accs.append(grad_acc) - - #print(f"param grad fn {param.expand_as(param).grad_fn}") - wrapper(param, i) - - # Partition the parameter after creating the hook - param.partition() - print_rank_0(f'[End] Create gradient reduction hooks') - - def get_param_id(self, param): - unique_id = id(param) - return self.param_id[unique_id] - - def report_ipg_memory_usage(self, tag, param_elems): - elem_count = self.elements_in_ipg_bucket + param_elems - percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size - see_memory_usage( - f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}", - force=False) - - ###############Idependent Partition Gradient ######################## - def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): - #print_rank_0(f"Inside reduce ipg buckets. {debug_param2name_id_shape(param)}, ipg elements {self.elements_in_ipg_bucket}, reduce bucket size {self.reduce_bucket_size}", force=True) - - # Because the ipg bucket is initialized with a random place holder tensor, we must - # explicitly check that the bucket has any real data in it (self.elements_in_ipg_bucket > - # 0). Otherwise if the incoming param.ds_numel is large, this branch may get triggered on a - # garbage data and `self.average_tensor()` will crash because its params_to_reduce will be - # empty, while reduction_list will have that garbage data. - if self.elements_in_ipg_bucket > 0 and self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size: - self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", - param.ds_numel) - - self.reduce_ipg_grads() - - if self.contiguous_gradients and self.overlap_comm: - # Swap ipg_index between 0 and 1 - self.ipg_index = 1 - self.ipg_index - self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", - param.ds_numel) - - param_id = self.get_param_id(param) - assert self.params_already_reduced[param_id] == False, \ - f"The parameter {param_id} has already been reduced. \ - Gradient computed twice for this partition. \ - Multiple gradient reduction is currently not supported" - - # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening - if param.ds_numel > self.reduce_bucket_size: - self.extra_large_param_to_reduce = param - - elif self.contiguous_gradients: - #print_rank_0("before new grad tensor move") - new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow( - 0, - self.elements_in_ipg_bucket, - param.ds_numel) - #print_rank_0("after new grad tensor move") - new_grad_tensor.copy_(param.grad.view(-1)) - param.grad.data = new_grad_tensor.data.view_as(param.grad) - - self.elements_in_ipg_bucket += param.ds_numel - self.grads_in_ipg_bucket.append(param.grad) - self.params_in_ipg_bucket.append((i, param, param_id)) - self.report_ipg_memory_usage("End ipg_remove_grads", 0) - - def gradient_reduction_w_predivide(self, tensor): - dp_world_size = dist.get_world_size(group=self.dp_process_group) - - tensor_to_allreduce = tensor - - if self.communication_data_type != tensor.dtype: - tensor_to_allreduce = tensor.to(self.communication_data_type) - - if self.postscale_gradients: - if self.gradient_predivide_factor != 1.0: - tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor) - - dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) - - if self.gradient_predivide_factor != dp_world_size: - tensor_to_allreduce.mul_(self.gradient_predivide_factor / dp_world_size) - else: - tensor_to_allreduce.div_(dp_world_size) - dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) - - if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: - tensor.copy_(tensor_to_allreduce) - - return tensor - - def average_tensor(self, tensors, params_to_reduce): - with torch.cuda.stream(self.reduction_stream): - if not self.reduce_scatter: - for tensor in tensors: - self.gradient_reduction_w_predivide(tensor) - return - - for tensor in tensors: - tensor.div_(dist.get_world_size(group=self.dp_process_group)) - - # reduction resulting with each rank only holding the gradient partition it owns - # This could either be a reduce scatter or a reduce op depending on how - # parameters are partitionied. The method is implemented by the - # DeepSpeed param extensions to the pytorch parameter, so its up to - # the extension to define what happens here - params_to_reduce[0].reduce_gradients_at_owner( - param_list=params_to_reduce, - hierarchy=self.param_coordinator.hierarchy) - - def set_grad_positions(self): - for i, group in enumerate(self.fp16_groups): - current_offset = 0 - for param in group: - param_id = self.get_param_id(param) - num_elements = param.ds_tensor.ds_numel - - self.grad_position[param_id] = [ - int(i), - int(current_offset), - int(num_elements) - ] - #print(f"param id {param_id} i:{i}, ds_tensor {num_elements} numel {param.numel()}") - current_offset += num_elements - - def async_accumulate_grad_in_cpu_via_gpu(self, param, acc_grad_cpu_partition): - - # copy to a preexisiting buffer to avoid memory allocation penalty - dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow( - 0, - 0, - param.ds_tensor.ds_numel) - - if self.micro_step_id > 0: - dest_buffer.copy_(acc_grad_cpu_partition.view(-1), non_blocking=True) - param.grad.data.view(-1).add_(dest_buffer) - - # at the boundary we will send 32bit directly - if not self.is_gradient_accumulation_boundary: - acc_grad_cpu_partition.data.copy_(param.grad.data.view(-1), - non_blocking=True) - - def _constant_buffered_norm2(self, input, buffer_size=250000000): - norm = None - for part in input.view(-1).split(buffer_size): - if norm is None: - norm = part.data.double().norm(2)**2.0 - else: - norm += part.data.double().norm(2)**2.0 - return norm**0.5 - - def set_norm_for_param_grad_in_gpu(self, param): - param_id = self.get_param_id(param) - #self.norm_for_param_grads[param_id] = param.grad.data.double().norm(2) - #Using a more memory efficient version - self.norm_for_param_grads[param_id] = self._constant_buffered_norm2(param.grad) - - def update_overflow_tracker_for_param_grad(self, param): - #Credit to our user David Minn - if param.grad is not None: - if self.overlap_comm: - self.gpu_sum = self.gpu_sum + param.grad.data.float().sum() - elif self._has_inf_or_nan(param.grad.data): - self.local_overflow = True - - def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param, fp32_grad_tensor): - with torch.cuda.stream(self.copy_grad_stream): - param_id = self.get_param_id(param) - src_tensor = param.grad.view(-1).float() - #print(f"src_tensor {src_tensor.size()} and fp32 grad {fp32_grad_tensor.size()}") - fp32_grad_tensor.copy_(src_tensor, non_blocking=True) - param.grad = None - - def complete_grad_norm_calculation_for_cpu_offload(self, params): - total_norm = 0.0 - norm_type = 2.0 - for p in params: - if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): - param_id = self.get_param_id(p) - if param_id in self.norm_for_param_grads.keys(): - param_norm = self.norm_for_param_grads[param_id] - total_norm += param_norm.item()**2 - - # Sum across all model parallel GPUs. - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.SUM, - group=self.dp_process_group) - - self._model_parallel_all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.SUM) - - total_norm = total_norm_cuda[0].item()**(1. / norm_type) - - if total_norm == float( - 'inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 - - return total_norm - - def partition_previous_reduced_grads(self): - if not self.previous_reduced_grads: - return - - if self.offload_optimizer: - allocate_grads_in_partition = self.grads_in_partition is None\ - and self.gradient_accumulation_steps > 1 - else: - allocate_grads_in_partition = self.grads_in_partition is None - - if allocate_grads_in_partition: - self.grads_in_partition = [] - - for i, group in enumerate(self.fp16_groups): - total_size = 0 - for param_in_partition in group: - total_size += param_in_partition.ds_tensor.ds_numel - - see_memory_usage( - f"group {i} before creating {total_size} reduced gradients into partition", - force=False) - if self.offload_param_pin_memory: - self.grads_in_partition.append( - torch.zeros(int(total_size), - dtype=self.dtype, - device=self.device).pin_memory()) - else: - self.grads_in_partition.append( - torch.zeros(int(total_size), - dtype=self.dtype, - device=self.device)) - see_memory_usage( - f"group {i} after creating {total_size} reduced gradients into partition", - force=False) - - if self.offload_optimizer: - offload_fp32_gradients = {} - offload_fp32_offsets = {} - - with torch.cuda.stream(self.copy_grad_stream): - self.reduction_stream.synchronize() - for param in self.previous_reduced_grads: - - [i, - dest_offset, - num_elements] = self.grad_position[self.get_param_id(param)] - - if self.offload_optimizer: - param.partition_gradients( - partition_buffers=self.temp_grad_gpu_buffer) - #with torch.cuda.stream(self.copy_grad_stream): - # self.reduction_stream.synchronize() - - if self.gradient_accumulation_steps > 1: - # The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer - fp16_grad_tensor = self.grads_in_partition[i].narrow( - 0, - dest_offset, - num_elements) - self.async_accumulate_grad_in_cpu_via_gpu( - param, - fp16_grad_tensor) - - if self.is_gradient_accumulation_boundary: - - self.set_norm_for_param_grad_in_gpu(param) - - self.update_overflow_tracker_for_param_grad(param) - - if self._swappable_optimizer_subgroup(i): - if not i in offload_fp32_gradients.keys(): - offload_fp32_gradients[i] = [] - offload_fp32_offsets[i] = [] - - offload_fp32_gradients[i].append(param.grad.view(-1).float()) - param.grad = None - offload_fp32_offsets[i].append(dest_offset) - else: - fp32_grad_tensor = self.fp32_partitioned_groups_flat[ - i].grad.narrow(0, - dest_offset, - num_elements) - - self.async_inplace_copy_grad_to_fp32_buffer_from_gpu( - param, - fp32_grad_tensor) - else: - # The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer - fp16_grad_tensor = self.grads_in_partition[i].narrow( - 0, - dest_offset, - num_elements) - param.partition_gradients( - partition_buffers=fp16_grad_tensor, - accumulate=True if self.micro_step_id > 0 else False) - - if self.offload_optimizer and self.swap_optimizer: - for i in offload_fp32_gradients.keys(): - self.optimizer_swapper.swap_out_gradients( - parameter=self.fp32_partitioned_groups_flat[i], - gradient_offsets=offload_fp32_offsets[i], - gradient_tensors=offload_fp32_gradients[i]) - - self.previous_reduced_grads = [] - - def reduce_ipg_grads(self, extra_param=None): - if self.overlap_comm: - self.reduction_stream.synchronize() - - with torch.cuda.stream(self.reduction_stream): - self.partition_previous_reduced_grads() - - params_to_reduce = [param for i, param, param_id in self.params_in_ipg_bucket] - #print(f"Params in ipg bucket {self.params_in_ipg_bucket}") - #print(f"Reducing {[(debug_param2name_id_shape(param), param.grad) for param in params_to_reduce]}") - #exit(0) - if self.contiguous_gradients: - reduction_list = [self.ipg_buffer[self.ipg_index]] - if self.extra_large_param_to_reduce is not None: - reduction_list.append(self.extra_large_param_to_reduce.grad) - self.extra_large_param_to_reduce = None - self.average_tensor(reduction_list, params_to_reduce) - else: - self.buffered_reduce_fallback( - None, - self.grads_in_ipg_bucket, - elements_per_buffer=self.elements_in_ipg_bucket) - - for _, param, param_id in self.params_in_ipg_bucket: - self.params_already_reduced[param_id] = True - - self.previous_reduced_grads = params_to_reduce - - self.grads_in_ipg_bucket = [] - self.params_in_ipg_bucket = [] - self.elements_in_ipg_bucket = 0 - ##################################################################### - - def reduce_ready_partitions_and_remove_grads(self, param, i): - #print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True) - self.reduce_independent_p_g_buckets_and_remove_grads(param, i) - - def zero_reduced_gradients(self, partition_id, i): - def are_all_related_partitions_reduced(params_id): - for partition_id in self.param_to_partition_ids[i][params_id]: - if not self.is_partition_reduced[i][partition_id]: - return False - return True - - for params_id in self.is_grad_computed[i][partition_id]: - if are_all_related_partitions_reduced(params_id): - self.param_dict[params_id].grad = None - - def flatten_and_print(self, message, tensors, start=0, n=5): - flatten_tensor = self.flatten(tensors) - - def print_func(): - logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n)) - - self.sequential_execution(print_func, message) - - def get_grads_to_reduce(self, i, partition_id): - def get_reducible_portion(key): - grad = self.param_dict[key].grad - total_elements = grad.numel() - start = self.grad_start_offset[i][partition_id][key] - num_elements = min( - total_elements - start, - self.partition_size[i] - - self.grad_partition_insertion_offset[i][partition_id][key]) - if not pg_correctness_test: - if num_elements == total_elements: - return grad - else: - return grad.contiguous().view(-1).narrow(0, - int(start), - int(num_elements)) - else: - if num_elements == total_elements: - return grad.clone() - else: - return grad.clone().contiguous().view(-1).narrow( - 0, - int(start), - int(num_elements)) - - grads_to_reduce = [] - for key in self.is_grad_computed[i][partition_id]: - grad = get_reducible_portion(key) - grads_to_reduce.append(grad) - return grads_to_reduce - - def sequential_execution(self, function, message, group=None): - if group is None: - group = self.dp_process_group - if dist.get_rank(group=group) == 0: - logger.info(message) - for id in range(dist.get_world_size(group=group)): - if id == dist.get_rank(group=group): - function() - dist.barrier(group=group) - - def set_none_gradients_to_zero(self, i, partition_id): - for param_id in self.is_grad_computed[i][partition_id]: - param = self.param_dict[param_id] - if param.grad is None: - param.grad = torch.zero_like(param) - - ######################Reduction Related Methods############################## - - def allreduce_bucket(self, - bucket, - communication_data_type=torch.float16, - rank=None, - log=None): - rank = None - tensor = self.flatten(bucket) - - tensor_to_allreduce = tensor - - if pg_correctness_test: - communication_data_type = torch.float32 - - if communication_data_type != tensor.dtype: - tensor_to_allreduce = tensor.to(communication_data_type) - - tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group)) - - if rank is None: - # "All Reducing" - dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) - else: - global_rank = _get_global_rank(self.dp_process_group, rank) - dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group) - - if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: - if rank is None or rank == dist.get_rank(group=self.dp_process_group): - tensor.copy_(tensor_to_allreduce) - - return tensor - - # if rank is specified do a reduction instead of an allreduce - def allreduce_and_copy(self, small_bucket, rank=None, log=None): - with torch.cuda.stream(self.reduction_stream): - allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log) - if rank is None or rank == dist.get_rank(group=self.dp_process_group): - for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): - buf.copy_(synced) - - def allreduce_no_retain(self, - bucket, - numel_per_bucket=500000000, - rank=None, - log=None): - small_bucket = [] - numel = 0 - for tensor in bucket: - small_bucket.append(tensor) - numel = numel + tensor.numel() - if numel > numel_per_bucket: - self.allreduce_and_copy(small_bucket, rank=rank, log=None) - small_bucket = [] - if len(small_bucket) > 0: - self.allreduce_and_copy(small_bucket, rank=rank, log=log) - - # allows using reduction of gradients instead of using all_reduce - def buffered_reduce_fallback(self, - rank, - grads, - elements_per_buffer=500000000, - log=None): - split_buckets = split_half_float_double(grads) - - for i, bucket in enumerate(split_buckets): - self.allreduce_no_retain(bucket, - numel_per_bucket=elements_per_buffer, - rank=rank, - log=log) - - ############################################################################# - ############################################################################# - ############################################################################# - - # views the tensor as multiple partitions and returns - # those partitions - def get_data_parallel_partitions(self, tensor): - partitions = [] - - dp = dist.get_world_size(group=self.dp_process_group) - dp_id = dist.get_rank(group=self.dp_process_group) - - total_num_elements = tensor.numel() - - base_size = total_num_elements // dp - remaining = total_num_elements % dp - - start = 0 - for id in range(dp): - partition_size = base_size - if id < remaining: - partition_size = partition_size + 1 - partitions.append(tensor.narrow(0, start, partition_size)) - start = start + partition_size - return partitions - - def get_partition_info(self, tensor_list, partition_size, partition_id): - params_in_partition = [] - params_not_in_partition = [] - - start_index = partition_size * partition_id - end_index = partition_size * (partition_id + 1) - - current_index = 0 - first_offset = 0 - - for tensor in tensor_list: - - tensor_size = tensor.numel() - - if (current_index >= start_index and current_index < end_index): - params_in_partition.append(tensor) - - elif start_index > current_index and start_index < (current_index + - tensor_size): - params_in_partition.append(tensor) - - assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition" - first_offset = start_index - current_index - - else: - params_not_in_partition.append(tensor) - - current_index = current_index + tensor_size - - return params_in_partition, params_not_in_partition, first_offset - - def zero_grad(self, set_grads_to_None=True): - """ - Zero FP16 parameter grads. - """ - # FP32 grad should never exist. - # For speed, set model fp16 grad to None by default - for group in self.fp16_groups: - for p in group: - if set_grads_to_None: - p.grad = None - else: - if p.grad is not None: - p.grad.detach_() - p.grad.zero_() - - def _model_parallel_all_reduce(self, tensor, op): - """ Perform all reduce within model parallel group, if any. - """ - if self.model_parallel_group is None: - pass - else: - torch.distributed.all_reduce(tensor=tensor, - op=op, - group=self.model_parallel_group) - - def get_grad_norm_direct(self, gradients, params, norm_type=2): - """Clips gradient norm of an iterable of parameters. - - This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and - added functionality to handle model parallel parameters. Note that - the gradients are modified in place. - - Arguments: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have gradients normalized - max_norm (float or int): max norm of the gradients - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - - Returns: - Total norm of the parameters (viewed as a single vector). - """ - norm_type = float(norm_type) - if norm_type == inf: - total_norm = max(g.data.abs().max() for g in gradients) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.MAX, - group=self.dp_process_group) - - # Take max across all GPUs. - self._model_parallel_all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.MAX) - total_norm = total_norm_cuda[0].item() - else: - total_norm = 0.0 - # if dist.get_rank() == 0: - # logger.info(f"Total Norm beginning {total_norm}") - for g, p in zip(gradients, params): - if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): - param_norm = g.data.double().norm(2) - total_norm += param_norm.item()**2 - # Sum across all model parallel GPUs. - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.SUM, - group=self.dp_process_group) - - self._model_parallel_all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.SUM) - - total_norm = total_norm_cuda[0].item()**(1. / norm_type) - - if total_norm == float( - 'inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 - - return total_norm - - # creates a flat fused tensor from the tensor list starting at the first_offset - # in the first tensor of the list. If there are not enough elements in the tensor - # list then the flat tensor will be padded with zeros - def get_flat_partition(self, - tensor_list, - first_offset, - partition_size, - return_tensor_list=False): - flat_tensor_list = [] - current_size = 0 - for i, tensor in enumerate(tensor_list): - if tensor.grad is None: - tensor.grad = torch.zeros_like(tensor) - - tensor = tensor.grad - num_elements = tensor.numel() - tensor_offset = 0 - - # we need to offset to get to the right element - if i == 0 and first_offset > 0: - tensor_offset = first_offset - num_elements = num_elements - tensor_offset - - # we dont need all elements of the tensor - if num_elements > (partition_size - current_size): - num_elements = partition_size - current_size - - # we need a narrow view of the tensor based on the tensor offset and number of elements that - # we need from this tensor - if tensor_offset > 0 or num_elements < tensor.numel(): - flat_tensor_list.append(tensor.contiguous().view(-1).narrow( - 0, - int(tensor_offset), - int(num_elements))) - else: - flat_tensor_list.append(tensor) - - current_size = current_size + num_elements - - # this means its the last partition and does not align with the dp boundary. We need to pad before flattening - if current_size < partition_size: - flat_tensor_list.append( - torch.zeros(int(partition_size - current_size), - dtype=tensor_list[0].dtype, - device=tensor_list[0].device)) - - if return_tensor_list: - return flat_tensor_list - - return self.flatten(flat_tensor_list) - - def free_grad_in_param_list(self, param_list): - for p in param_list: - p.grad = None - - def reset_cpu_buffers(self): - self.norm_for_param_grads = {} - self.local_overflow = False - - def log_timers(self, timer_names): - if self.timers is None: - return - - self.timers.log(names=list(timer_names)) - - def start_timers(self, timer_names): - if self.timers is None: - return - - for name in timer_names: - self.timers(name).start() - - def stop_timers(self, timer_names): - if self.timers is None: - return - - for name in timer_names: - self.timers(name).stop() - - def _pre_step(self): - self.micro_step_id = INITIAL_MICRO_STEP_ID - - print_rank_0(f"Inside Step function") - see_memory_usage(f"In step before checking overflow", force=False) - - print_rank_0("Finished Tracing at Beginning of Step") - self.param_coordinator.hierarchy = 0 - self.param_coordinator.finish_tracing(print_trace=True) - - self.param_coordinator.reset_step() - - print_rank_0("Finished Tracing at Beginning of Step") - - def _get_norm_groups(self): - norm_groups = [] - for i, group in enumerate(self.fp16_groups): - if self.offload_optimizer: - norm_groups.append( - self.complete_grad_norm_calculation_for_cpu_offload( - self.fp16_groups[i])) - else: - norm_groups.append( - self.get_grad_norm_direct(self.averaged_gradients[i], - self.fp16_groups[i])) - return norm_groups - - def _prepare_fp32_grad_for_sub_group(self, sub_group_id): - partition_id = dist.get_rank(group=self.dp_process_group) - - single_grad_partition = self.flatten(self.averaged_gradients[sub_group_id]).to( - self.fp32_partitioned_groups_flat[sub_group_id].dtype) - - assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \ - "averaged gradients have different number of elements that partition size {} {} {} {}".format( - single_grad_partition.numel(), self.fp32_partitioned_groups_flat[sub_group_id].numel(), sub_group_id, partition_id) - - self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition - - # release all the gradient since we have already created a necessary copy in dp_grad_partition - self.zero_grad() - - self.averaged_gradients[sub_group_id] = None - - def _prepare_sub_group(self, sub_group_id, timer_names=set()): - see_memory_usage(f'Before prepare optimizer sub group {sub_group_id}', - force=False) - if self._swappable_optimizer_subgroup(sub_group_id): - self._optimizer_states_and_gradient_swap_in(sub_group_id, timer_names) - elif not self.offload_optimizer: - self._prepare_fp32_grad_for_sub_group(sub_group_id) - see_memory_usage(f'After prepare optimizer sub group {sub_group_id}', - force=False) - - def _optimizer_states_and_gradient_swap_in(self, sub_group_id, timer_names=set()): - param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] - fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id]) - assert self._swappable_optimizer_subgroup(sub_group_id), \ - f'Parameter {fp32_param_id} of numel={param_length} is not swappable' - - OPTIMIZER_SWAP_IN_STATE = 'optimizer_swap_in_state' - see_memory_usage(f'pre-step Before swapping in optimizer tensors {sub_group_id}', - force=False) - self.start_timers([OPTIMIZER_SWAP_IN_STATE]) - - self.optimizer_swapper.swap_in_optimizer_state( - parameter=self.fp32_partitioned_groups_flat[sub_group_id], - async_parameter=self.next_swappable_fp32_partitioned_groups[sub_group_id]) - - self.stop_timers([OPTIMIZER_SWAP_IN_STATE]) - timer_names.add(OPTIMIZER_SWAP_IN_STATE) - see_memory_usage(f'pre-step After swapping in optimizer tensors {sub_group_id}', - force=False) - - def _release_sub_group(self, sub_group_id, timer_names=set()): - see_memory_usage(f'Before release optimizer sub group {sub_group_id}', - force=False) - # get rid of the fp32 gradients. Not needed anymore - if not self.offload_optimizer: - self.fp32_partitioned_groups_flat[sub_group_id].grad = None - - if self._swappable_optimizer_subgroup(sub_group_id): - self._optimizer_states_and_gradient_swap_out(sub_group_id, timer_names) - see_memory_usage(f'After release optimizer sub group {sub_group_id}', - force=False) - - # create a flat tensor aligned at the alignment boundary - def flatten_dense_tensors_aligned(self, tensor_list, alignment): - num_elements = 0 - for tens in tensor_list: - num_elements = num_elements + tens.numel() - - remaining = num_elements % alignment - - if remaining: - elements_to_add = alignment - remaining - pad_tensor = torch.zeros(elements_to_add, - device=tensor_list[0].device, - dtype=tensor_list[0].dtype) - padded_tensor_list = tensor_list + [pad_tensor] - - num_elements = num_elements + elements_to_add - else: - padded_tensor_list = tensor_list - - return self.flatten(padded_tensor_list) - - def _optimizer_states_and_gradient_swap_out(self, sub_group_id, timer_names=set()): - param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] - fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id]) - assert self._swappable_optimizer_subgroup(sub_group_id), \ - f'Parameter {fp32_param_id} of numel={param_length} is not swappable' - - OPTIMIZER_SWAP_OUT_STATE = 'optimizer_swap_out_state' - see_memory_usage( - f'post-step Before swapping out optimizer tensors {sub_group_id}', - force=False) - self.start_timers([OPTIMIZER_SWAP_OUT_STATE]) - - self.optimizer_swapper.swap_out_optimizer_state( - parameter=self.fp32_partitioned_groups_flat[sub_group_id], - async_swap=self.next_swappable_fp32_partitioned_groups[sub_group_id] is - not None) - - self.stop_timers([OPTIMIZER_SWAP_OUT_STATE]) - see_memory_usage( - f'post-step After swapping out optimizer tensors {sub_group_id}', - force=False) - timer_names.add(OPTIMIZER_SWAP_OUT_STATE) - - # get rid of the fp32 gradients. Not needed anymore - self.fp32_partitioned_groups_flat[sub_group_id].grad = None - - def _unflatten_partitioned_parameters(self, sub_group_id): - updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id], - self.fp16_partitioned_groups[sub_group_id]) - - for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): - partitioned_param.data = q.data - - def _overflow_clean_up(self, prev_scale): - see_memory_usage('After overflow before clearing gradients', force=False) - self.zero_grad() - - if self.offload_optimizer: - self.reset_cpu_buffers() - else: - self.averaged_gradients = {} - - see_memory_usage('After overflow after clearing gradients', force=False) - - if torch.distributed.get_rank() == 0: - logger.info( - "[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, " - "reducing to {}".format(dist.get_rank(), - prev_scale, - self.loss_scale)) - - def _overflow_check_and_loss_scale_update(self): - - # First compute norm for all group so we know if there is overflow - self.check_overflow() - - #loss scaling related computation - prev_scale = self.loss_scale - self._update_scale(self.overflow) - - if self.overflow: - self._overflow_clean_up(prev_scale) - - return self.overflow - - def _post_step(self, timer_names=set()): - if self.offload_optimizer: - self.reset_cpu_buffers() - - #Gathering persisting parameters - if len(self.persistent_parameters) > 0: - self.persistent_parameters[0].all_gather(self.persistent_parameters) - - if self.swap_optimizer: - self.optimizer_swapper.log_timers() - - self.log_timers(timer_names) - - see_memory_usage('After zero_optimizer step', force=False) - print_rank_0(f"------------------Finishing Step-----------------------") - - def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id): - if self.fp16_partitioned_groups_flat[sub_group_id] is not None: - self.fp16_partitioned_groups_flat[sub_group_id].data.copy_( - self.fp32_partitioned_groups_flat[sub_group_id].data) - - #unflatten fp16 parameter subgroup - self._unflatten_partitioned_parameters(sub_group_id) - else: - self._partitioned_params_swap_out(sub_group_id) - - def step(self, closure=None): - """ - Not supporting closure. - """ - self._pre_step() - - #checks for overflow, adjust the loss scale accordingly - if self._overflow_check_and_loss_scale_update(): - if self.swap_optimizer: - self.optimizer_swapper.log_timers() - return - - norm_groups = self._get_norm_groups() - self._global_grad_norm = get_global_norm(norm_list=norm_groups) - - timer_names = set() - - timer_names.add('optimizer_step') - self.start_timers(['optimizer_step']) - - #update parameters one sub group at a time - for sub_group_id, group in enumerate(self.fp16_groups): - - #prepare optimizer states, gradients and fp32 parameters for update - self._prepare_sub_group(sub_group_id, timer_names) - - #scale the fp32 gradients - self.unscale_and_clip_grads(sub_group_id, self._global_grad_norm) - - #apply the optimizer step on the sub group and copy fp32 parameters to fp16 - self._optimizer_step(sub_group_id) - - #put fp16 parameters in appropriate location - self._reassign_or_swap_out_partitioned_parameters(sub_group_id) - - #release memory or swap out optimizer states of fp32 parameters - self._release_sub_group(sub_group_id, timer_names) - - self.stop_timers(['optimizer_step']) - - self._post_step(timer_names) - return - - def dump_pre_step_gradients(self, debug_fp32_grads): - # Dump gradient norms for debugging - for i, _ in enumerate(self.fp16_groups): - print(f'Pre-Step Dump Norms for Group {i} FP16P, FP16G, FP32G, FP32GUC') - for fp16_param, fp32_grad in zip(self.fp16_groups[i], debug_fp32_grads[i]): - param_id = self.get_param_id(fp16_param) - fp16_grad_norm = self.debug_fp16_grads[i][param_id] - - fp32_grad_norm = [float(t.data.float().norm(2)) for t in fp32_grad] - norm_list = [fp16_grad_norm, fp32_grad_norm] - print(f'Pre-Step Norms {i} {param_id} = {norm_list}') - - def dump_post_step_gradients(self): - # Dump gradient norms for debugging - for i, group in enumerate(self.fp16_groups): - print( - f'Post-Step Dump Norms for Group {i} FP16P, FP16DS, FP16FLAT, FP32FLAT') - unflat_fp16 = self.unflatten(self.fp16_groups_flat[i], self.fp16_groups[i]) - unflat_fp32 = self.unflatten(self.fp32_partitioned_groups_flat[i], - self.fp16_groups[i]) - for j, p in enumerate(self.fp16_groups[i]): - param_id = self.get_param_id(p) - param_norm = float(p.data.float().norm(2)) - ds_norm = float(p.ds_tensor.data.float().norm(2)) - - unflat_norm = [ - float(t.data.float().norm(2)) - for t in [unflat_fp16[j], - unflat_fp32[j]] - ] - norm_list = [param_norm, ds_norm] + unflat_norm - print(f'Post-Step Norms {i} {param_id} = {norm_list}') - - def unscale_and_clip_grads(self, sub_group_id, total_norm): - grad_groups_flat = [self.fp32_partitioned_groups_flat[sub_group_id].grad] - - # compute combined scale factor for this group - combined_scale = self.loss_scale - if self.clip_grad > 0.: - # norm is in fact norm*scale - clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad - if clip > 1: - combined_scale = clip * self.loss_scale - - for grad in grad_groups_flat: - if isinstance(grad, list): - sub_partitions = grad - for g in sub_partitions: - g.data.mul_(1. / combined_scale) - else: - grad.data.mul_(1. / combined_scale) - - def _check_overflow(self, partition_gradients=True): - self.overflow = self.has_overflow(partition_gradients) - - # `params` is a list / generator of torch.Variable - def has_overflow_serial(self, params, is_grad_list=False): - for p in params: - if p.grad is not None and self._has_inf_or_nan(p.grad.data): - return True - - return False - - def has_overflow_partitioned_grads_serial(self): - for i in range(len(self.fp16_groups)): - for j, grad in enumerate(self.averaged_gradients[i]): - if grad is not None and self._has_inf_or_nan(grad.data, j): - return True - return False - - def has_overflow(self, partition_gradients=True): - if partition_gradients: - if self.overlap_comm: - self.local_overflow = self._has_inf_or_nan(self.gpu_sum) - self.gpu_sum = torch.zeros(1, dtype=torch.float).cuda() - - overflow = self.local_overflow if self.offload_optimizer else self.has_overflow_partitioned_grads_serial( - ) - #overflow = self.has_overflow_partitioned_grads_serial() - overflow_gpu = torch.cuda.ByteTensor([overflow]) - torch.distributed.all_reduce(overflow_gpu, - op=torch.distributed.ReduceOp.MAX, - group=self.dp_process_group) - - else: - params = [] - for group in self.fp16_groups: - for param in group: - params.append(param) - - overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients) - overflow_gpu = torch.cuda.ByteTensor([overflow]) - - # Since each model parallel GPU carries only part of the model, - # make sure overflow flag is synced across all the model parallel GPUs - self._model_parallel_all_reduce(tensor=overflow_gpu, - op=torch.distributed.ReduceOp.MAX) - - overflow = overflow_gpu[0].item() - return bool(overflow) - - # `x` is a torch.Tensor - @staticmethod - def _has_inf_or_nan(x, j=None): - try: - # if x is half, the .float() incurs an additional deep copy, but it's necessary if - # Pytorch's .sum() creates a one-element tensor of the same type as x - # (which is true for some recent version of pytorch). - cpu_sum = float(x.float().sum()) - # More efficient version that can be used if .sum() returns a Python scalar - # cpu_sum = float(x.sum()) - except RuntimeError as instance: - # We want to check if inst is actually an overflow exception. - # RuntimeError could come from a different error. - # If so, we still want the exception to propagate. - if "value cannot be converted" not in instance.args[0]: - raise - return True - else: - if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: - return True - return False - - def backward(self, loss, retain_graph=False): - """ - :attr:`backward` performs the following steps: - - 1. fp32_loss = loss.float() - 2. scaled_loss = fp32_loss*loss_scale - 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves - """ - self.micro_step_id += 1 - print_rank_0( - f"Total fully available parameters {self.param_coordinator.total_available_parameter_numel}" - ) - - if self.swap_optimizer: - self.optimizer_swapper.pre_backward() - - see_memory_usage(f"Before backward", force=False) - if self.contiguous_gradients: - self.ipg_buffer = [] - buf_0 = torch.empty(self.reduce_bucket_size, - dtype=self.dtype, - device=torch.cuda.current_device()) - self.ipg_buffer.append(buf_0) - - # Use double buffers to avoid data access conflict when overlap_comm is enabled. - if self.overlap_comm: - buf_1 = torch.empty(self.reduce_bucket_size, - dtype=self.dtype, - device=torch.cuda.current_device()) - self.ipg_buffer.append(buf_1) - self.ipg_index = 0 - - self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) - '''Partitioning Parameters that were not partitioned - Usually if parameters of modules whose input parameters do not require - grad computation do not trigger post call and will therefore will remain unpartitioned ''' - self._partition_all_parameters() - - if self.swap_optimizer: - self.optimizer_swapper.post_backward() - - def _partition_all_parameters(self): - for name, param in self.module.named_parameters(recurse=True): - self.param_coordinator.release_and_reset_parameter(param) - - def check_overflow(self, partition_gradients=True): - self._check_overflow(partition_gradients) - - def _update_scale(self, has_overflow=False): - self.loss_scaler.update_scale(has_overflow) - - # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" - def _get_state(self): - return self.optimizer.state - - def _set_state(self, value): - self.optimizer.state = value - - state = property(_get_state, _set_state) - - # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" - # (for example, to adjust the learning rate) - def _get_param_groups(self): - return self.optimizer.param_groups - - def _set_param_groups(self, value): - self.optimizer.param_groups = value - - param_groups = property(_get_param_groups, _set_param_groups) - - # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" - def _get_loss_scale(self): - return self.loss_scaler.loss_scale - - def _set_loss_scale(self, value): - self.loss_scaler.cur_scale = value - - loss_scale = property(_get_loss_scale, _set_loss_scale) - cur_scale = property(_get_loss_scale, _set_loss_scale) - - def _get_lean_tensors(self, padded_flattened_tensor, group_tensors, paddings): - # Remove paddings from flattened tensor - individual_tensors = self.unflatten(padded_flattened_tensor, group_tensors) - lean_lengths = [t.numel() - pad for t, pad in zip(group_tensors, paddings)] - lean_tensors = [t[:len] for t, len in zip(individual_tensors, lean_lengths)] - #logger.info(f'rank {dist.get_rank()}: lean_tensors = {[t.numel() for t in lean_tensors]}') - return lean_tensors - - #TODO REVISIT this for stage 3 - def get_lean_optimizer_state(self): - # Return optimizer states after removing paddings. - # This method assumes that each param group contains a single flattened tensor. - optimizer_groups_state = [] - - for i, group in enumerate(self.optimizer.param_groups): - p = group['params'][0] - lean_state = {} - for key, value in self.optimizer.state[p].items(): - if torch.is_tensor(value): - padded_lens = [t.numel() for t in self.fp16_partitioned_groups[i]] - lean_state[key] = self._get_lean_tensors( - value, - self.fp16_partitioned_groups[i], - self.groups_padding[i]) - lean_flat_len = sum([t.numel() for t in lean_state[key]]) - else: - lean_state[key] = value - - optimizer_groups_state.append(lean_state) - - return optimizer_groups_state - - def get_groups_without_padding(self, groups_with_padding): - # Return group tensor after removing paddings added for alignment to DP world size. - groups_without_padding = [] - for i, group in enumerate(groups_with_padding): - lean_group = self._get_lean_tensors(group, - self.fp16_partitioned_groups[i], - self.groups_padding[i]) - groups_without_padding.append(lean_group) - - return groups_without_padding - - def _set_fp32_optimizer_param_groups(self): - for sub_group_id, _ in enumerate(self.fp16_groups): - param_group_id = self.sub_group_to_group_id[sub_group_id] - self.optimizer.param_groups[param_group_id]['params'].append( - self.fp32_partitioned_groups_flat[sub_group_id]) - - def _clear_fp32_optimizer_param_groups(self): - for param_group in self.optimizer.param_groups: - param_group['params'] = [] - - def _rigid_state_dict(self): - state_dict = {} - state_dict['zero_stage'] = ZERO_OPTIMIZATION_WEIGHTS - state_dict['loss_scaler'] = self.loss_scaler - state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale - state_dict['overflow'] = self.overflow - state_dict['partition_count'] = self.partition_count - - self._set_fp32_optimizer_param_groups() - state_dict['optimizer_state_dict'] = self.optimizer.state_dict() - state_dict['fp32_flat_groups'] = self.fp32_partitioned_groups_flat - self._clear_fp32_optimizer_param_groups() - - return state_dict - - def state_dict(self): - """ - Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. - This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict - of the contained Pytorch optimizer. - Example:: - checkpoint = {} - checkpoint['model'] = model.state_dict() - checkpoint['optimizer'] = optimizer.state_dict() - torch.save(checkpoint, "saved.pth") - """ - if self.elastic_checkpoint: - raise NotImplementedError( - "ZeRO-3 does not yet support elastic checkpointing, please disable for now." - ) - - if self.swap_optimizer or self.params_in_nvme_and_cpu: - raise NotImplementedError( - "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now." - ) - - return self._rigid_state_dict() - - -# Restore base optimizer fp32 weights from checkpoint by: -# 1) Merging fp32 weights from checkpoints of all partitions -# 2) Extracting fp32 weights for current partition from merged weights -# 3) Using extracted weights to update base optimizer weights directly. - - def _restore_from_fp32_weights(self, all_state_dict): - - flat_local_partition = [] - for i in range(len(self.fp32_partitioned_groups_flat)): - merged_partitions = [sd['fp32_groups'][i] for sd in all_state_dict] - flat_local_partition.append(self._get_flattened_partition(merged_partitions)) - - for current, saved in zip(self.fp32_partitioned_groups_flat, flat_local_partition): - current.data.copy_(saved.data) - - # Restore base optimizer fp32 weights from ZeRO fp16 weights - def _restore_from_fp16_weights(self): - for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat): - fp32_partition.data.copy_(fp16_partitions.data) - - # Refresh the fp32 master params from the fp16 copies. - def refresh_fp32_params(self): - self._restore_from_fp16_weights() - - # Extract flattened partition for current rank from all partitions - def _get_flattened_partition(self, all_partition_states): - partition_id = dist.get_rank(group=self.dp_process_group) - alignment = dist.get_world_size(group=self.dp_process_group) - - param_partitions = [[] for _ in range(len(all_partition_states[0]))] - for i, partition in enumerate(all_partition_states): - for j, param in enumerate(partition): - param_partitions[j].append(param) - - local_state_partitions = [] - for param_index, param_slices in enumerate(param_partitions): - flattened_merged_tensor = self.flatten_dense_tensors_aligned( - param_slices, - alignment) - new_partitions = self.get_data_parallel_partitions(flattened_merged_tensor) - local_state_partitions.append(new_partitions[partition_id]) - - if torch.is_tensor(local_state_partitions[0]): - return self.flatten_dense_tensors_aligned(local_state_partitions, alignment) - - # Assume non-tensor states are not partitioned and equal across ranks, so return first one - return local_state_partitions[0] - - # Restore base optimizer state from checkpoint by - # 1) Merging optimizer state from checkpoints of all partitions - # 2) Extracting optimizer state for current partition from the merged state - # 3) Using the extracted value to directly update the base optimizer. - def _restore_base_optimizer_state(self, all_state_dict): - base_optimizer_group_states = [] - for i in range(len(self.optimizer.param_groups)): - partition_states = {} - all_partition_group_states = [ - sd['base_optimizer_state'][i] for sd in all_state_dict - ] - for key in all_partition_group_states[0].keys(): - all_partition_states = [ - all_states[key] for all_states in all_partition_group_states - ] - partition_states[key] = self._get_flattened_partition( - all_partition_states) - base_optimizer_group_states.append(partition_states) - - for i, group in enumerate(self.optimizer.param_groups): - p = group['params'][0] - for key, saved in base_optimizer_group_states[i].items(): - if torch.is_tensor(self.optimizer.state[p][key]): - self.optimizer.state[p][key].data.copy_(saved.data) - else: - self.optimizer.state[p][key] = saved - - def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True): - # I think it should actually be ok to reload the optimizer before the model. - self.loss_scaler = state_dict['loss_scaler'] - self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] - self.overflow = state_dict['overflow'] - - if load_optimizer_states: - self._set_fp32_optimizer_param_groups() - self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) - self._clear_fp32_optimizer_param_groups() - - # restore fp32 partitions - for curr_param, saved_param in zip(self.fp32_partitioned_groups_flat, state_dict['fp32_flat_groups']): - curr_param.data.copy_(saved_param.data) - - # restore fp16 partitions from fp32 - for sub_group_id in range(len(self.fp32_partitioned_groups_flat)): - fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] - fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] - fp16_param.data.copy_(fp32_param.data) - - # update fp16 unflattened params - for sub_group_id in range(len(self.fp16_partitioned_groups_flat)): - updated_params = self.unflatten( - self.fp16_partitioned_groups_flat[sub_group_id], - self.fp16_partitioned_groups[sub_group_id]) - - for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): - partitioned_param.data = q.data - - # TODO: Support different/changing load/save DP degree. - def load_state_dict(self, - state_dict_list, - load_optimizer_states=True, - load_from_fp32_weights=False): - r"""Loading a ZeRO checkpoint - Arguments: - state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition. - Note that the number of saved partitions may differ from number of loading partitions to support - changing GPU count, specifically DP world size, between saving and loading checkpoints. - load_optimizer_states: Boolean indicating whether or not to load base optimizer states - load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32 - copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss). - """ - """ - Loads a state_dict created by an earlier call to state_dict(). - If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, - whose parameters in turn came from ``model``, it is expected that the user - will call ``model.load_state_dict()`` before - ``fp16_optimizer_instance.load_state_dict()`` is called. - Example:: - model = torch.nn.Linear(D_in, D_out).cuda().half() - optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) - ... - checkpoint = torch.load("saved.pth") - model.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - """ - - if self.elastic_checkpoint: - raise NotImplementedError( - "ZeRO-3 does not yet support elastic checkpointing, please disable for now." - ) - - if self.swap_optimizer or self.params_in_nvme_and_cpu: - raise NotImplementedError( - "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now." - ) - - self._rigid_load_state_dict( - state_dict_list[dist.get_rank(group=self.dp_process_group)], - load_optimizer_states=load_optimizer_states) - - if len(self.persistent_parameters) > 0: - self.persistent_parameters[0].partition(self.persistent_parameters) - self.persistent_parameters[0].all_gather(self.persistent_parameters) - - def save_checkpoint_prologue(self): - self._partition_all_parameters() - - def save_checkpoint_epilogue(self): - if len(self.persistent_parameters) > 0: - self.persistent_parameters[0].all_gather(self.persistent_parameters) - - -def _handle_overflow(cpu_sum, x, i): - import math - rank = torch.distributed.get_rank() - if rank == 0: - t_i = -1 - for v_i, v in enumerate(x.data.contiguous().view(-1)): - if not math.isfinite(float(v)): - t_i = v_i - break - logger.info( - f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}" - ) - - -def estimate_zero3_model_states_mem_needs(total_params, - largest_layer_params, - num_gpus_per_node=1, - num_nodes=1, - cpu_offload=True, - cpu_offload_params=True, - zero_init=True, - additional_buffer_factor=1.5): - - total_gpus = num_nodes * num_gpus_per_node - gpus_factor = 1 / num_nodes - largest_layer_memory = (4 * largest_layer_params) - - if cpu_offload: - if cpu_offload_params: - gpu_mem = largest_layer_memory - - if zero_init: - cpu_mem = total_params * 18 * gpus_factor * additional_buffer_factor - else: - cpu_mem = total_params * max(4 * num_gpus_per_node, - 18 * gpus_factor) * additional_buffer_factor - - else: - gpu_mem = largest_layer_memory + int(2 * total_params / total_gpus) - - if zero_init: - cpu_mem = total_params * 16 * gpus_factor * additional_buffer_factor - else: - cpu_mem = total_params * max(4 * num_gpus_per_node, - 16 * gpus_factor) * additional_buffer_factor - else: - gpu_mem = largest_layer_memory + int(18 * total_params / total_gpus) - if zero_init: - cpu_mem = largest_layer_params * 4 * num_gpus_per_node * additional_buffer_factor - else: - cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor - - return int(cpu_mem), int(gpu_mem), largest_layer_memory - - -def model_to_params(model): - # shared params calculated only once - total_params = sum( - dict((p.data_ptr(), - p.numel()) for p in model.parameters()).values()) - - largest_layer_params = 0 - for m in model.modules(): - # assuming no shared params within a single layer - layer_params = sum(p.numel() for p in m.parameters(recurse=False)) - largest_layer_params = max(largest_layer_params, layer_params) - - return total_params, largest_layer_params - - -import math - - -def estimate_zero3_model_states_mem_needs_all_live(model, - num_gpus_per_node=1, - num_nodes=1, - additional_buffer_factor=1.5): - """ - Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients - for a given ``model`` and hardware setup. - - If you have an actual model object, use this function and everything will be derived - automatically. - - If it's a hypothetical model, use ``estimate_zero3_model_states_mem_needs_all_cold`` where you have to pass - the ``total_params`` and ``largest_layer_params`` explicitly. - - Args: - - ``model``: ``nn.Module`` object - - ``num_gpus_per_node``: how many gpus per node (defaults to 1) - - ``num_nodes``: how many nodes (defaults to 1), - - ``additional_buffer_factor``: estimation factor (defaults to 1.5): - - """ - - total_params, largest_layer_params = model_to_params(model) - - estimate_zero3_model_states_mem_needs_all_cold( - total_params=total_params, - largest_layer_params=largest_layer_params, - num_gpus_per_node=num_gpus_per_node, - num_nodes=num_nodes, - additional_buffer_factor=additional_buffer_factor) - - -def estimate_zero3_model_states_mem_needs_all_cold(total_params, - largest_layer_params, - num_gpus_per_node=1, - num_nodes=1, - additional_buffer_factor=1.5): - """ - Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients - for a given ``model`` and hardware setup. - - If it's a hypothetical model, use this function where you have to pass - the ``total_params`` and ``largest_layer_params`` explicitly. - - If you have an actual model object, use ``estimate_zero3_model_states_mem_needs_all_live`` and everything - will be derived automatically. - - Args: - - ``total_params``: total model params - - ``largest_layer_params``: largest layer's params - - ``num_gpus_per_node``: how many gpus per node (defaults to 1) - - ``num_nodes``: how many nodes (defaults to 1), - - ``additional_buffer_factor``: estimation factor (defaults to 1.5): - - """ - def format_options(cpu_offload, cpu_offload_params, zero_init): - enabled = [] - padded_cpu_str = f'{OFFLOAD_CPU_DEVICE:4}' - param_device = padded_cpu_str if cpu_offload_params else "none" - enabled.append(f"{OFFLOAD_PARAM}={param_device}") - optimizer_device = padded_cpu_str if cpu_offload else "none" - enabled.append(f"{OFFLOAD_OPTIMIZER}={optimizer_device}") - enabled.append(f"zero_init={1 if zero_init else 0}") - return ", ".join(enabled) - - nodes_str = "nodes" if num_nodes > 1 else "node" - gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU" - print( - "Estimated memory needed for params, optim states and gradients for a:\n" - f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n" - f"SW: Model with {int(total_params/1e6)}M total params, {int(largest_layer_params/1e6)}M largest layer params." - ) - print(" per CPU | per GPU | Options") - for cpu_offload in [True, False]: - for cpu_offload_params in [True, False]: - if not cpu_offload and cpu_offload_params: - continue - for zero_init in [True, False]: - cpu_mem, gpu_mem, largest_layer_memory = estimate_zero3_model_states_mem_needs( - total_params=total_params, - largest_layer_params=largest_layer_params, - num_gpus_per_node=num_gpus_per_node, - num_nodes=num_nodes, - cpu_offload=cpu_offload, - cpu_offload_params=cpu_offload_params, - zero_init=zero_init, - additional_buffer_factor=additional_buffer_factor - ) - - options_str = format_options(cpu_offload=cpu_offload, - cpu_offload_params=cpu_offload_params, - zero_init=zero_init) - print( - f" {cpu_mem/2**30:7.2f}GB | {gpu_mem/2**30:6.2f}GB | {options_str}") +""" +"Copyright 2020 The Microsoft DeepSpeed Team. +Licensed under the MIT license. +""" + +import sys +import os +from collections import defaultdict, OrderedDict +import itertools +import torch +from torch.distributed.distributed_c10d import _get_global_rank +import torch.distributed as dist +import math +from torch._six import inf +from torch.autograd import Variable + +from deepspeed.utils.logging import logger +from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler +from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter, DummyOptim +from deepspeed.runtime.zero.partition_parameters import * +from deepspeed.runtime.zero.partition_parameters import _init_external_params +from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS +from deepspeed.ops.adam import DeepSpeedCPUAdam +from deepspeed.ops.op_builder import UtilsBuilder +from deepspeed.runtime.zero.offload_constants import * +from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus +from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper +from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper + +# Toggle this to true to enable correctness test +# with gradient partitioning and without +pg_correctness_test = False + +FWD_MODULE_STACK = list() +from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id, debug_param2name_id_numel, debug_param2name_id_shape_device, debug_module2name_class, printflock, log_rank_file + + +def print_rank_0(message, debug=False, force=False): + rank = torch.distributed.get_rank() + if rank == 0 and (debug or force): + print(message) + # other variations + # - print for all ranks w/o interleaving + # printflock(f"[{rank}] {message}") + # - print to log file per rank + # log_rank_file(rank, message) + + +def input(msg): + return + + +def split_half_float_double(tensors): + dtypes = [ + "torch.cuda.HalfTensor", + "torch.cuda.FloatTensor", + "torch.cuda.DoubleTensor" + ] + buckets = [] + for i, dtype in enumerate(dtypes): + bucket = [t for t in tensors if t.type() == dtype] + if bucket: + buckets.append(bucket) + return buckets + + +def isclose(a, b, rtol=1e-09, atol=0.0): + return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol) + + +def lcm(x, y): + from fractions import gcd # or can import gcd from `math` in Python 3 + return x * y // gcd(x, y) + + +def move_to_cpu(tensor_list): + for tensor in tensor_list: + tensor.data = tensor.data.cpu() + + +def get_all_parameters(sub_module, recurse=False): + return itertools.chain(sub_module.named_parameters(recurse=recurse), + sub_module.ds_external_parameters()) + + +#apply torch.autograd.Function that calls a backward_function to tensors in output +def _apply_to_tensors_only(module, functional, backward_function, outputs): + if type(outputs) is tuple: + touched_outputs = [] + for output in outputs: + touched_output = _apply_to_tensors_only(module, + functional, + backward_function, + output) + touched_outputs.append(touched_output) + return tuple(touched_outputs) + elif type(outputs) is torch.Tensor: + return functional.apply(module, backward_function, outputs) + else: + return outputs + + +#for each tensor in outputs run the forward_function and register backward_function as hook +def _apply_forward_and_backward_to_tensors_only(module, + forward_function, + backward_function, + outputs): + if type(outputs) is tuple: + touched_outputs = [] + for output in outputs: + touched_output = _apply_forward_and_backward_to_tensors_only( + module, + forward_function, + backward_function, + output) + touched_outputs.append(touched_output) + return tuple(touched_outputs) + elif type(outputs) is torch.Tensor: + forward_function(outputs) + if outputs.requires_grad: + outputs.register_hook(backward_function) + return outputs + else: + return outputs + + +class ZeROOrderedDict(OrderedDict): + def __init__(self, parent_module, *args, **kwargs): + """A replacement for ``collections.OrderedDict`` to detect external ZeRO params. + + Args: + parent_module (``collections.OrderedDict``): the collection to replace + """ + + super().__init__(*args, **kwargs) + self._parent_module = parent_module + self._in_forward = False + + def __getitem__(self, key): + param = super().__getitem__(key) + + # Params can be registered as None (e.g., bias) + if param is None: + return param + + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if self._parent_module._parameters._in_forward: + print_rank_0(f'Registering external parameter from getter {key}', + force=False) + register_external_parameter(FWD_MODULE_STACK[-1], param) + param.all_gather() + + return param + + +def _inject_parameters(module, cls): + for module in module.modules(): + if cls == ZeROOrderedDict: + new_param = cls(parent_module=module) + else: + new_param = cls() + + for key, param in module._parameters.items(): + new_param[key] = param + module._parameters = new_param + + +# TODO Needs to be implemented +class PrefetchCoordinator(object): + def __init__(self): + # step_id keeps track of the number of sub-modules invoked so far + # the step_id is tracking forward and backward sequence of sub-modules + self.step_id = 0 + + # stores the sequence of sub modules in forward+backward pass + self.sub_module_trace = [] + + # maps sub_module id to submodule objects + self.id_to_sub_module_map = {} + + # stores the total number of parameters in each sub_module + self.id_to_sub_module_size_map = {} + + self.trace_completed = False + + self.most_recent_sub_module_step = {} + + # reuse distances + self.reuse_numel_for_step_id = {} + + def record_trace(self, sub_module): + if not self.trace_completed: + self.sub_module_trace.append(sub_module.id) + self.id_to_sub_module_map[sub_module.id] = sub_module + + def print_trace(self): + print_rank_0( + f"The module trace is : {[self.id_to_sub_module_map[module_id].id for module_id in self.sub_module_trace]}" + ) + + def increment_step(self, sub_module): + self.most_recent_sub_module_step[sub_module.id] = self.step_id + self.step_id += 1 + + def reset_step(self): + self.step_id = 0 + + # returns the next numel parameters that will be used next but are not available or inflight + def get_params_to_prefetch(self, sub_module, numel=2000000): + + # numel_in_sub_module = 0 + # for name, param in sub_module.named_parameters(recurse=False): + # numel_in_sub_module += param.ds_numel + + # #if numel_in_sub_module < (numel // 2): + # return [] + + # tracing failed. The sub_module passed at the step_id must match with the sub_module during tracing + if sub_module.id != self.sub_module_trace[self.step_id]: + print_rank_0( + f"Tracing failed. Prefetching is disabled at sub-module: {debug_module2name_id(sub_module)}" + ) + return [] + + params_to_prefetch = [] + total_numel_to_prefetch = 0 + + for i in range(self.step_id, len(self.sub_module_trace)): + module_id = self.sub_module_trace[i] + for _, param in get_all_parameters(self.id_to_sub_module_map[module_id]): + if param.ds_status is ZeroParamStatus.NOT_AVAILABLE and ( + param.ds_id not in [p.ds_id for p in params_to_prefetch]): + params_to_prefetch.append(param) + total_numel_to_prefetch += param.ds_numel + #print_rank_0(f"Total numel to prefetch: {total_numel_to_prefetch}. Param: {param.ds_shape} and numel {param.ds_numel}, numel limit {numel}") + if total_numel_to_prefetch >= numel: # and total_numel_to_prefetch > (numel_in_sub_module // 2): + return params_to_prefetch + + return params_to_prefetch + + # checks if this sub_module will be used again and if so then returns the number of elements + # in the parameters used between this sub_module and the reuse of this sub_module + def get_reuse_distance_in_numel(self, sub_module, sub_module_step_id=None): + #assert is_forward is not None, "is_forward must be set to True for Forward Propagation and False for backward Propagation" + is_there_reuse = False + reuse_distance_in_numel = 1000000000000 + + # set the appropriate trace + trace = self.sub_module_trace + total_steps = len(trace) + if sub_module_step_id is None: + sub_module_step_id = self.most_recent_sub_module_step[sub_module.id] + + # tracing failed. The sub_module passed at the step_id must match with the sub_module during tracing + if sub_module.id != trace[sub_module_step_id]: + print_rank_0( + f"Tracing failed. Cannot tell if the sub_module: {sub_module.id} is reused" + ) + return reuse_distance_in_numel + + # return cached value + if sub_module_step_id in self.reuse_numel_for_step_id: + return self.reuse_numel_for_step_id[sub_module_step_id] + + start_step = self.step_id + print_rank_0(f"Step id is {self.step_id} ") + for step_id in range(start_step, total_steps): + print_rank_0(f"Trace id {trace[step_id]} and sub_module id {sub_module.id}") + if sub_module.id == trace[step_id]: + end_step = step_id + + is_there_reuse = True + reuse_distance_in_numel = self._distance_in_numel( + start_step, + end_step, + trace) + break + + self.reuse_numel_for_step_id[sub_module_step_id] = reuse_distance_in_numel + + return reuse_distance_in_numel + + def _distance_in_numel(self, start_step, end_step, trace): + distance_in_numel = 0 + for step_id in range(start_step, end_step): + module_id = trace[step_id] + for _, param in self.id_to_sub_module_map[module_id].named_parameters(recurse=False): + distance_in_numel += param.ds_numel + for _, param in self.id_to_sub_module_map[module_id].ds_external_parameters(): + distance_in_numel += param.ds_numel + return distance_in_numel + + +class PartitionedParameterCoordinator(object): + def __init__(self, + comm_stream=None, + max_reuse_distance_in_numel=500000000, + max_available_parameters_in_numel=700000000): + + self.in_flight_handles = [] + self.params_in_flight = [] + self.comm_stream = comm_stream if comm_stream is not None else torch.cuda.current_stream( + ) + self.prefetch_coordinator = PrefetchCoordinator() + self.hierarchy = 0 + + self.total_available_parameter_numel = 0 + self.max_available_parameters_in_numel = max_available_parameters_in_numel + + # max distance between two use of the module beyond which module is released + self.max_reuse_distance_in_numel = max_reuse_distance_in_numel + + def _increment_available_parameter_numel(self, increment): + self.total_available_parameter_numel += increment + + def _decrement_available_parameter_numel(self, decrement): + self.total_available_parameter_numel -= decrement + + '''-----------------------Tracing and Prefetching ---------------''' + + def record_trace(self, sub_module): + self.prefetch_coordinator.record_trace(sub_module) + + def finish_tracing(self, print_trace=False): + self.prefetch_coordinator.trace_completed = True + + if print_trace: + self.prefetch_coordinator.print_trace() + + #swap in parameter partitions from nvme for those parameters that will be used + # after the ones that are already being prefetched into full parameters + def _prefetch_nvme_param_partitions(self, sub_module, params_in_flight): + numel_in_flight = sum([param.ds_tensor.ds_numel for param in params_in_flight]) + upcoming_param_list = self.prefetch_coordinator.get_params_to_prefetch( + sub_module, + numel=2 * numel_in_flight) + swap_in_params = [] + for param in upcoming_param_list: + if len(swap_in_params) >= param.nvme_swapper.available_swap_in_buffers(): + break + if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE: + swap_in_params.append(param) + + if len(swap_in_params) > 0: + swap_in_params[0].nvme_swapper.swap_in(swap_in_params, async_op=True) + + # Pre fetches the parameters for sub_modules that comes after + # the current sub_module. This call is asynchronous + def prefetch_next_sub_modules(self, sub_module, numel=5000000, nvme=False): + + params_to_prefetch = [] + if not self.prefetch_coordinator.trace_completed: + return params_to_prefetch + + # prefetch if there is no current prefetching in flight + if not self.in_flight_handles and self.total_available_parameter_numel < self.max_available_parameters_in_numel: + params_to_prefetch = self.prefetch_coordinator.get_params_to_prefetch( + sub_module, + numel=numel) + + self._all_gather(params_to_prefetch, async_op=True) + for param in params_to_prefetch: + param.ds_status = ZeroParamStatus.INFLIGHT + + # keeping track of number of elements consumed by available parameters + self._increment_available_parameter_numel(param.ds_numel) + + if nvme: + self._prefetch_nvme_param_partitions(sub_module, params_to_prefetch) + + self._print_prefetch_elements_info(sub_module, params_to_prefetch) + print_rank_0( + f"{'--' * self.hierarchy}--PreFetching parameters {[param.ds_id for param in params_to_prefetch]} and available {self.total_available_parameter_numel}, max limit {self.max_available_parameters_in_numel}", + force=False) + + def _print_prefetch_elements_info(self, sub_module, params_to_prefetch): + sub_module_numel = 0.0 + for name, param in sub_module.named_parameters(recurse=False): + sub_module_numel += param.ds_numel + numel_being_prefetched = 0 + for param in params_to_prefetch: + numel_being_prefetched = param.ds_numel + print_rank_0( + f"{'--' * self.hierarchy}--PreFetching {numel_being_prefetched} numels and number of numel in the next sub module is {sub_module_numel}", + force=False) + + def increment_step(self, sub_module): + self.prefetch_coordinator.increment_step(sub_module) + + def reset_step(self): + self.prefetch_coordinator.reset_step() + + '''----------------------------------------------------------------------''' + + # Fetches the parameters in the sub_module + # This call is blocking + def fetch_sub_module(self, sub_module): + partitioned_params = [] + params_in_flight = False + print_rank_0( + f"{'--' * self.hierarchy}Fetching params in module {debug_module2name_class(sub_module)}" + ) + params_to_fetch = [ + param for _, + param in sub_module.named_parameters(recurse=False) + ] + # print([n for n,p in sub_module.named_parameters(recurse=False)]) + + if hasattr(sub_module, 'ds_external_parameters'): + print_rank_0( + f"{'--' * self.hierarchy}--Fetching external parameters {sub_module.ds_external_parameters()}" + ) + params_to_fetch += [ + param for _, + param in sub_module.ds_external_parameters() + ] + # for _, param in sub_module.named_parameters(recurse=False): + for param in params_to_fetch: + param.ds_active_sub_modules += 1 + print_rank_0( + f"{'--' * self.hierarchy}--Fetching parameters {debug_param2name_id_shape(param)} with active sub modules {param.ds_active_sub_modules}" + ) + + if param.ds_status == ZeroParamStatus.AVAILABLE: + print_rank_0( + f"{'--' * self.hierarchy}--Parameter {debug_param2name_id(param)} is already available" + ) + + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + print_rank_0( + f"{'--' * self.hierarchy}--Parameter {debug_param2name_id(param)} is being fetched" + ) + partitioned_params.append(param) + + # keeping track of number of elements consumed by available parameters + self._increment_available_parameter_numel(param.ds_numel) + print_rank_0(f"Incrementing with parameter id {param.ds_id}") + + if param.ds_status == ZeroParamStatus.INFLIGHT: + params_in_flight = True + print_rank_0( + f"{'--' * self.hierarchy}--Parameters {debug_param2name_id(param)} is already in flight (prefetched)" + ) + self.hierarchy += 1 + + # parameters are partitioned and need to be allgathered + self._all_gather(partitioned_params, async_op=False) + + # parameters are inflight and communication needs to be completed + if partitioned_params or params_in_flight: + self._synchronize_communication() + + for _, param in sub_module.named_parameters(recurse=False): + param.ds_status = ZeroParamStatus.AVAILABLE + print_rank_0( + f"Param {debug_param2name_id_shape_device(param)} norm={param.norm()}", + force=False) + #print_rank_0(f"After fetching (id, shape, device): {[(param.ds_id, param.shape, param.device) for param in sub_module.named_parameters(recurse=False)]}") + + def release_sub_module(self, sub_module): + self.hierarchy -= 1 + print_rank_0( + f"{'--' * self.hierarchy}Releasing params in module {debug_module2name_class(sub_module)}" + ) + params_to_release = [ + param for _, + param in sub_module.named_parameters(recurse=False) + ] + + if hasattr(sub_module, 'ds_external_parameters'): + #print_rank_0(f"Releasing external parameters {sub_module.ds_external_parameters()}") + params_to_release += [ + param for _, + param in sub_module.ds_external_parameters() + ] + + # for _, param in sub_module.named_parameters(recurse=False): + for param in params_to_release: + param.ds_active_sub_modules -= 1 + if not param.ds_active_sub_modules and not self._keep_for_later( + sub_module) and not param.ds_persist: + print_rank_0( + f"{'--' * self.hierarchy}--Releasing parameter {debug_param2name_id_numel(param)} active sub modules {param.ds_active_sub_modules} and keep for later {self._keep_for_later(sub_module)}", + force=False) + + # Keeping track of number of elements that are consumed by available parameters + self._decrement_available_parameter_numel(param.ds_numel) + see_memory_usage( + f"Before releasing param {debug_param2name_id_numel(param)}", + force=False) + param.partition(hierarchy=self.hierarchy) + see_memory_usage( + f"After releasing param {debug_param2name_id_numel(param)}", + force=False) + + param.ds_status = ZeroParamStatus.NOT_AVAILABLE + else: + + print_rank_0( + f"{'--' * self.hierarchy}--Did not release param {debug_param2name_id_numel(param)} with active sub modules {param.ds_active_sub_modules}, keep for later={self._keep_for_later(sub_module)} and persistence={param.ds_persist}", + force=False) + + def release_and_reset_parameter(self, param): + param.ds_active_sub_modules = 0 + if param.ds_status == ZeroParamStatus.AVAILABLE: + print_rank_0( + f"Releasing unpartitioned param {debug_param2name_id_numel(param)} active sub-modules {param.ds_active_sub_modules} and persistence {param.ds_persist}" + ) + self._decrement_available_parameter_numel(param.ds_numel) + param.partition() + + def _keep_for_later(self, sub_module): + if not self.prefetch_coordinator.trace_completed: + return False + if self.max_reuse_distance_in_numel == 0: + return False + reuse_distance_in_numel = self.prefetch_coordinator.get_reuse_distance_in_numel( + sub_module) + #print_rank_0(f"Reuse distance and numel for sub_module id {sub_module.id} is {reuse_distance_in_numel}") + return reuse_distance_in_numel < self.max_reuse_distance_in_numel + + def _all_gather(self, partitioned_params, async_op=False): + with torch.cuda.stream(self.comm_stream): + handles = partitioned_params[0].all_gather( + param_list=partitioned_params, + async_op=async_op, + hierarchy=self.hierarchy) if partitioned_params else None + + if handles is not None: + self.in_flight_handles.extend(handles) + self.params_in_flight.extend(partitioned_params) + + def _synchronize_communication(self, synchronize_streams=True): + assert len(self.params_in_flight) == len(self.in_flight_handles) + for handle, param in zip(self.in_flight_handles, self.params_in_flight): + if handle is not None: + with torch.cuda.stream(self.comm_stream): + handle.wait() + param.ds_status = ZeroParamStatus.AVAILABLE + self.comm_stream.synchronize() + torch.cuda.synchronize() if synchronize_streams else None + self.in_flight_handles = [] + self.params_in_flight = [] + + +class PreBackwardFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, module, pre_backward_function, outputs): + ctx.module = module + ctx.pre_backward_function = pre_backward_function + if not hasattr(module, "applied_pre_backward_ref_cnt"): + module.applied_pre_backward_ref_cnt = 0 + module.applied_pre_backward_ref_cnt += 1 + #print(f"After Forward: {ctx.module.__class__.__name__}") + outputs = outputs.detach() + return outputs + + @staticmethod + def backward(ctx, *args): + #print(f"Before Backward: {ctx.module.__class__.__name__}") + ctx.pre_backward_function(ctx.module) + return (None, None) + args + + +class PostBackwardFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, module, pre_backward_function, output): + ctx.module = module + if output.requires_grad: + #TODO SOME TIMES post backward does not seem to be triggered debug in detail + #Should only cause increase in memory not correctness issue + #if output.grad_fn.__class__.__name__ == 'ViewBackward': + # ctx.view=True + # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") + #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." + #if module.ds_grads_remaining == 0: + # print(f"Before Forward: {ctx.module.__class__.__name__}") + module.ds_grads_remaining += 1 + ctx.pre_backward_function = pre_backward_function + output = output.detach() + return output + + @staticmethod + def backward(ctx, *args): + ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 + if ctx.module.ds_grads_remaining == 0: + ctx.pre_backward_function(ctx.module) + #print(f"After Backward: {ctx.module.__class__.__name__}") + return (None, None) + args + + +INITIAL_MICRO_STEP_ID = -1 + + +class FP16_DeepSpeedZeroOptimizer_Stage3(object): + """ + DeepSpeedZeroOptimizer designed to reduce the memory footprint + required for training large deep learning models. + + For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models + https://arxiv.org/abs/1910.02054 + + For usage examples, refer to TODO: DeepSpeed Tutorial + + """ + def __init__(self, + module, + init_optimizer, + timers, + ds_config, + static_loss_scale=1.0, + dynamic_loss_scale=False, + dynamic_loss_args=None, + verbose=True, + contiguous_gradients=True, + reduce_bucket_size=500000000, + prefetch_bucket_size=50000000, + max_reuse_distance=1000000000, + max_live_parameters=1000000000, + param_persistence_threshold=100000, + dp_process_group=None, + reduce_scatter=True, + overlap_comm=False, + offload_optimizer_config=None, + offload_param_config=None, + sub_group_size=1000000000000, + mpu=None, + clip_grad=0.0, + communication_data_type=torch.float16, + postscale_gradients=True, + gradient_predivide_factor=1.0, + gradient_accumulation_steps=1, + elastic_checkpoint=False, + aio_config=None): + + see_memory_usage("Stage 3 initialize beginning", force=False) + + if dist.get_rank() == 0: + logger.info(f"Reduce bucket size {reduce_bucket_size}") + logger.info(f"Allgather bucket size {prefetch_bucket_size}") + # The fused optimizer does all the work. We need this layer for two reason: + # 1. maintain same user API from apex.fp16_utils + # 2. keep common stuff here in case we need to add ne552w fused optimizer later + + # differences from apex.fp16_utils: + # - assume all model params in fp16 + # - assume all params requires grad + # - flat by groups, not keeping state. TODO: remove state explicitly? + # - master gard and unflat master weight never exist. TODO: a way to save out unflat master? + if not torch.cuda.is_available: + raise SystemError("Cannot use fp16 without CUDA.") + self.optimizer = init_optimizer + + # Load pre-built or JIT compile (un)flatten ops + util_ops = UtilsBuilder().load() + self.flatten = util_ops.flatten + self.unflatten = util_ops.unflatten + self.dtype = self.optimizer.param_groups[0]['params'][0].dtype + self._global_grad_norm = 0. + + self.optimizer_swapper = None + self.swap_optimizer = False + + self.offload_optimizer = False + self.offload_optimizer_pin_memory = False + self.offload_optimizer_fast_init = False + self.offload_param = False + self.offload_param_pin_memory = False + self.params_in_nvme_and_cpu = False + self.max_params_in_cpu = 0 + + self._configure_offloading(offload_optimizer_config, offload_param_config) + + self._convert_to_zero_parameters(ds_config, module, mpu) + + for m in module.modules(): + _init_external_params(m) + + self.module = module + self.elastic_checkpoint = elastic_checkpoint + self.overlap_comm = overlap_comm + + # Replace ._parameters with a new class to enable auto-registration of + # external parameters + _inject_parameters(module, ZeROOrderedDict) + + if self.overlap_comm: + self.gpu_sum = torch.zeros(1, dtype=torch.float).cuda() + + self.deepspeed_adam_offload = (self.offload_optimizer + and type(init_optimizer) == DeepSpeedCPUAdam) + + self.device = torch.cuda.current_device( + ) if not self.offload_optimizer else OFFLOAD_CPU_DEVICE + ############################################################################ + + see_memory_usage("Before Partitioned Parameter Coordinator", force=False) + + fetch_stream = torch.cuda.Stream() if self.overlap_comm else None + self.param_coordinator = PartitionedParameterCoordinator( + comm_stream=fetch_stream, + max_reuse_distance_in_numel=int(max_reuse_distance), + max_available_parameters_in_numel=int(max_live_parameters)) + + see_memory_usage("After Partitioned Parameter Coordinator", force=False) + + #self.param_coordinator = PartitionedParameterCoordinator(comm_stream=torch.cuda.Stream()) + #-------------Stage 3 Setup-------------------# + # parameters smaller than the threshold will be collectively gathered at the + # end of the optimizer step and will be kept till the end of the backward pass + # TODO maybe worth just replicating these parameters and doing all reduce for them + self.persistence_threshold = int(param_persistence_threshold) + + self.persistent_parameters = self.persistent_parameters() + + self.setup_zero_stage3_hooks() + + #resetting ds_tensor just in case parameters have been changed after initialization + #example .half() or .to() + #self.reset_ds_tensor() + #---------------------------------------------# + + self.timers = timers + + self.reduce_scatter = reduce_scatter + + self.dp_process_group = dp_process_group + + self.partition_count = dist.get_world_size(group=self.dp_process_group) + + if mpu is None: + self.model_parallel_group = None + self.model_parallel_rank = 0 + else: + self.model_parallel_group = mpu.get_model_parallel_group() + self.model_parallel_rank = mpu.get_model_parallel_rank() + + self.overflow = False + self.clip_grad = clip_grad + self.communication_data_type = communication_data_type + self.gradient_predivide_factor = gradient_predivide_factor + self.postscale_gradients = postscale_gradients + self.gradient_accumulation_steps = gradient_accumulation_steps + self.micro_step_id = INITIAL_MICRO_STEP_ID + + if self.reduce_scatter: + assert self.communication_data_type in (torch.float16, torch.bfloat16), f"ZeRO-3 supports only float16 or bfloat16 communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" + assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-2 with reduce scatter enabled" + assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-2 with reduce scatter enabled" + + # Holds the mode parameter + # The param.data may not hold any meaningful data + # when param's status is NOT_AVAILABLE or IN_FLGHT + self.fp16_groups = [] + + # Hold partitioned parameters + self.fp16_partitioned_groups = [] + + # Holds a fused and flattened copy of the parameters + self.fp16_partitioned_groups_flat = [] + self.fp16_partitioned_groups_flat_numel = [] + + #defragmented pinned memory + self.param_groups_fp16_flat_cpu_memory = [] + + #a single 32-bit partition of the parallel partitioned parameters + #that this process will update + self.fp32_partitioned_groups_flat = [] + self.next_swappable_fp32_partitioned_groups = [] + + # number of elements per partition in each group + self.partition_size = [] + + self.all_reduce_print = False + + self.prefetch_elements = int(prefetch_bucket_size) + + # padding on each partition for alignment purposes + self.groups_padding = [] + + self.sub_group_size = sub_group_size + + self.sub_group_to_group_id = {} + see_memory_usage("Before creating fp16 partitions", force=False) + self._create_fp16_partitions_with_defragmentation() + num_fp16_subgroups = len(self.fp16_partitioned_groups_flat) + see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}", + force=False) + + # Optimizer tensor swapping + if self.swap_optimizer: + self._configure_tensor_swapping(offload_optimizer_config, aio_config) + + see_memory_usage("Before creating fp32 partitions", force=False) + if not isinstance(self.optimizer, DummyOptim): + self._create_fp32_partitions() + see_memory_usage("After creating fp32 partitions", force=False) + dist.barrier() + + # To support pipelined optimizer swapping + if not isinstance(init_optimizer, DummyOptim): + self._create_next_swappable_fp32_groups() + + see_memory_usage("Before initializing optimizer states", force=False) + if not isinstance(init_optimizer, DummyOptim): + self.initialize_optimizer_states() + see_memory_usage("After initializing optimizer states", force=False) + dist.barrier() + + if dist.get_rank() == 0: + logger.info(f"optimizer state initialized") + + self.reduce_bucket_size = int(reduce_bucket_size) + + self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False) + + self.reduction_stream = torch.cuda.Stream( + ) if self.overlap_comm else torch.cuda.current_stream() + self.callback_queued = False + self.copy_grad_stream = torch.cuda.Stream() + + self.param_dict = {} + + # map between param_id and bool to specify if a param is in this partition + self.is_param_in_current_partition = {} + + self.contiguous_gradients = contiguous_gradients + self.extra_large_param_to_reduce = None + self.grads_in_ipg_bucket = [] + self.params_in_ipg_bucket = [] + self.elements_in_ipg_bucket = 0 + self.params_already_reduced = [] + self.is_gradient_accumulation_boundary = True + self._release_ipg_buffers() + self.previous_reduced_grads = None + + # simplified param id + self.param_id = {} + + count = 0 + for i, params_group in enumerate(self.fp16_groups): + for param in params_group: + unique_id = id(param) + self.param_id[unique_id] = count + self.param_dict[count] = param + self.params_already_reduced.append(False) + count = count + 1 + + #Largest partitioned param + largest_partitioned_param_numel = max([ + max([tensor.numel() for tensor in fp16_partitioned_group]) + for fp16_partitioned_group in self.fp16_partitioned_groups + ]) + print_rank_0( + f'Largest partitioned param numel = {largest_partitioned_param_numel}', + force=False) + + see_memory_usage(f"Before Set Grad positions", force=False) + + self.grad_position = {} + self.set_grad_positions() + see_memory_usage(f"Before CPU Offload initialization", force=False) + + self.grads_in_partition = None + + if self.offload_optimizer: + self.accumulated_grads_in_cpu = {} + self.norm_for_param_grads = {} + self.local_overflow = False + self.temp_grad_buffer_for_gpu_offload = torch.zeros( + largest_partitioned_param_numel, + device=torch.cuda.current_device(), + dtype=self.dtype) + self.temp_grad_gpu_buffer = torch.zeros(largest_partitioned_param_numel, + device=torch.cuda.current_device(), + dtype=self.dtype) + see_memory_usage(f"After CPU Offload initialization", force=False) + + # stores if a partition has been reduced in this step + self.is_partition_reduced = {} + + # stores if a grad in a partition has been computed or not + self.is_grad_computed = {} + + # will store the averaged gradients required by this paritition + self.averaged_gradients = {} + + #creates backward hooks for gradient partitioning + self.create_reduce_and_remove_grad_hooks() + + #exit(0) + + # we may have a way of fusing dynamic scale. Do not support for now + if self.dtype == torch.float or not dynamic_loss_scale: + loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale + + self.dynamic_loss_scale = False + self.loss_scaler = LossScaler(scale=loss_scale_value) + cur_iter = 0 + else: + if dynamic_loss_args is None: + self.loss_scaler = DynamicLossScaler() + else: + self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) + + self.dynamic_loss_scale = True + + self.debug_fp16_grads = [{} for _ in self.fp16_groups] + + if dist.get_rank(group=self.dp_process_group) == 0: + see_memory_usage(f"After initializing ZeRO optimizer", force=False) + + def _configure_offloading(self, offload_optimizer_config, offload_param_config): + ###################### offload optimizer setup ################################## + if offload_optimizer_config is not None: + self.offload_optimizer = True + self.offload_optimizer_pin_memory = offload_optimizer_config[ + OFFLOAD_OPTIMIZER_PIN_MEMORY] + self.swap_optimizer = offload_optimizer_config[ + OFFLOAD_OPTIMIZER_DEVICE] == OFFLOAD_NVME_DEVICE + self.offload_optimizer_fast_init = offload_optimizer_config[ + OFFLOAD_OPTIMIZER_FAST_INIT] + + ###################### offload param setup ################################## + if offload_param_config is not None: + if not isinstance(self.optimizer, DummyOptim): + assert self.offload_optimizer, "parameter offload is only available with optimizer state offload" + self.offload_param = True + self.offload_param_pin_memory = offload_param_config[ + OFFLOAD_PARAM_PIN_MEMORY] + self.params_in_nvme_and_cpu = offload_param_config[ + OFFLOAD_PARAM_DEVICE] == OFFLOAD_NVME_DEVICE + self.max_params_in_cpu = offload_param_config[OFFLOAD_PARAM_MAX_IN_CPU] + print_rank_0( + f"FP16 params swapping is {self.params_in_nvme_and_cpu}, Max params in CPU is {self.max_params_in_cpu}", + force=False) + + def _convert_to_zero_parameters(self, ds_config, module, mpu): + non_zero_params = [p for p in module.parameters() if not is_zero_param(p)] + if non_zero_params: + zero_params = [p for p in module.parameters() if is_zero_param(p)] + if zero_params: + zero_params[0].convert_to_zero_parameters(param_list=non_zero_params) + else: + group = None + if mpu: + group = mpu.get_data_parallel_group() + + if self.params_in_nvme_and_cpu: + remote_device = OFFLOAD_NVME_DEVICE + elif self.offload_param: + remote_device = OFFLOAD_CPU_DEVICE + else: + remote_device = None + + Init(module=module, + data_parallel_group=group, + dtype=self.dtype, + config_dict_or_path=ds_config, + remote_device=remote_device, + pin_memory=self.offload_param_pin_memory, + mpu=mpu) + + def _configure_tensor_swapping(self, offload_optimizer_config, aio_config): + nvme_swap_folder = os.path.join( + offload_optimizer_config[OFFLOAD_OPTIMIZER_NVME_PATH], + 'zero_stage_3') + os.makedirs(nvme_swap_folder, exist_ok=True) + if torch.distributed.get_rank() == 0: + logger.info(f'Tensor Swapping: Adding optimizer tensors') + + swapper_type = PipelinedOptimizerSwapper if offload_optimizer_config[ + OFFLOAD_OPTIMIZER_PIPELINE] else PartitionedOptimizerSwapper + + self.optimizer_swapper = swapper_type( + swap_config=offload_optimizer_config, + aio_config=aio_config, + base_folder=nvme_swap_folder, + optimizer=self.optimizer, + largest_numel=max(self.fp16_partitioned_groups_flat_numel), + device=self.device, + dtype=torch.float32, + timers=self.timers) + + def _create_fp16_partitions(self): + dist.barrier() + partition_id = dist.get_rank(group=self.dp_process_group) + + # loop to deal with groups + for j, param_group in enumerate(self.optimizer.param_groups): + + sub_groups = self._create_fp16_sub_groups(param_group['params']) + for sub_group in sub_groups: + i = len(self.fp16_groups) + + # push this group to list before modify + self.fp16_groups.append(sub_group) + self.sub_group_to_group_id[i] = j + + #These are the list of the partitioned parameters + self.fp16_partitioned_groups.append( + [param.ds_tensor for param in self.fp16_groups[i]]) + + print_rank_0( + f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}" + ) + + # Record padding required to align group to world size (only applies to last rank) + if partition_id == dist.get_world_size(group=self.dp_process_group) - 1: + padding = [p.padding_size() for p in self.fp16_groups[i]] + else: + padding = [0] * len(self.fp16_groups[i]) + self.groups_padding.append(padding) + + #not sure why apex was cloning the weights before flattening + #removing cloning here + see_memory_usage(f"Before Flattening param group {i}", force=False) + + if not self.offload_param: + see_memory_usage(f"Before moving param group {i} to CPU", + force=False) + #move all the parameters to cpu to free up GPU space for creating flat buffer + move_to_cpu(self.fp16_partitioned_groups[i]) + see_memory_usage(f"After moving param group {i} to CPU", force=False) + + #create flat buffer in CPU and move to GPU + self.fp16_partitioned_groups_flat.append( + self.flatten_dense_tensors_aligned( + self.fp16_partitioned_groups[i], + dist.get_world_size(group=self.dp_process_group)).cuda( + torch.cuda.current_device())) + see_memory_usage( + f"After flattening and moving param group {i} to GPU", + force=False) + else: + #Without the detach, seems like the flattening becomes part of the + #model graph causing errors downstream + self.fp16_partitioned_groups_flat.append( + self.flatten_dense_tensors_aligned( + self.fp16_partitioned_groups[i], + dist.get_world_size( + group=self.dp_process_group)).detach().pin_memory()) + + see_memory_usage(f"After Flattening param group {i}", force=False) + + see_memory_usage(f"After Flattening param group {i}", force=False) + + #set model fp16 weight to slices of flattened buffer + updated_params = self.unflatten(self.fp16_partitioned_groups_flat[i], + self.fp16_partitioned_groups[i]) + + for partitioned_param, q in zip(self.fp16_partitioned_groups[i], updated_params): + partitioned_param.data = q.data + + def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False): + '''If flat buffer is None then the parameters in the param_list are + not copied to the flat buffer. This is because they excede the number of max_params_in_cpu + Some of these parameters may aready be in CPU in unflattened buffers + or they maybe in GPU, or they maybe in NVME. If they are in NVME, then + they will be marked as NOT_AVAILABLE, and will be moved to CPU when they are + needed during training.''' + if flat_buffer is None: + # this dst buffer is on NVMe, so skip this + return + + start = 0 + for param in param_list: + src = param.ds_tensor + dest = flat_buffer.narrow(0, start, src.ds_numel) + start = start + src.ds_numel + '''if the parameter was initialized in nvme then bring it to the destination buffer directly''' + if src.status == PartitionedParamStatus.NOT_AVAILABLE: + print_rank_0( + f"Swapping in {param.ds_id} with partition size {param.ds_tensor.ds_numel} permanently to CPU" + ) + param.nvme_swapper.swap_into_buffer(param, dest) + src.data = dest.data + src.status = PartitionedParamStatus.AVAILABLE + else: + assert src.status == PartitionedParamStatus.AVAILABLE, "Partitioned Param must be available here" + if not avoid_copy: + dest.data.copy_(src.data) + src.data = dest.data + + # Final location must be gpu/cpu in this case + param.ds_tensor.final_location = 'not-nvme' + + def _create_param_groups_fp16_flat_cpu_memory(self): + + aggregate_params_count = 0 + + for j, param_group in enumerate(self.optimizer.param_groups): + params_in_group = sum([p.ds_tensor.ds_numel for p in param_group['params']]) + + flat_buffer_size = params_in_group + + if self.params_in_nvme_and_cpu and \ + aggregate_params_count + params_in_group > self.max_params_in_cpu: + + flat_buffer_size = max(0, + self.max_params_in_cpu - aggregate_params_count) + + aggregate_params_count += params_in_group + + if flat_buffer_size > 0: + print_rank_0(f"group {j} flat buffer size {flat_buffer_size}", + force=False) + self.param_groups_fp16_flat_cpu_memory.append( + torch.empty(int(flat_buffer_size), + dtype=self.dtype, + pin_memory=True)) + else: + print_rank_0( + f"No flat buffer size. Param group size was {params_in_group}", + force=False) + + self.param_groups_fp16_flat_cpu_memory.append( + torch.empty(1, + dtype=self.dtype)) + + def _create_fp16_partitions_with_defragmentation(self): + dist.barrier() + partition_id = dist.get_rank(group=self.dp_process_group) + create_fp16_flat_reuse_buffer = False + largest_partition_numel = [] + max_partition_numel = 0 + + #create a flat CPU memory allocation for each param group + if self.offload_param: + self._create_param_groups_fp16_flat_cpu_memory() + + # loop to deal with groups + for j, param_group in enumerate(self.optimizer.param_groups): + + sub_groups = self._create_fp16_sub_groups(param_group['params']) + print_rank_0(f'fp16 group {j} has {len(sub_groups)} subgroups', force=False) + + flat_offset = 0 + for sub_group in sub_groups: + i = len(self.fp16_groups) + + # push this group to list before modify + self.fp16_groups.append(sub_group) + self.sub_group_to_group_id[i] = j + + # comment out for zero_to_fp32 debug + # if torch.distributed.get_rank() == 0: + # for param in self.fp16_groups[i]: + # print(f"{debug_param2name_id_shape(param)} {param.ds_shape}") + + #These are the list of the partitioned parameters + self.fp16_partitioned_groups.append( + [param.ds_tensor for param in self.fp16_groups[i]]) + + total_elements = sum( + [t.ds_numel for t in self.fp16_partitioned_groups[i]]) + self.fp16_partitioned_groups_flat_numel.append(total_elements) + + if total_elements > max_partition_numel: + largest_partition_numel = [ + t.ds_numel for t in self.fp16_partitioned_groups[i] + ] + max_partition_numel = total_elements + + print_rank_0( + f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}" + ) + + # Record padding required to align group to world size (only applies to last rank) + if partition_id == dist.get_world_size(group=self.dp_process_group) - 1: + padding = [p.padding_size() for p in self.fp16_groups[i]] + else: + padding = [0] * len(self.fp16_groups[i]) + self.groups_padding.append(padding) + + #not sure why apex was cloning the weights before flattening + #removing cloning here + see_memory_usage(f"Before Flattening param subgroup {i}", force=False) + + #all partitioned parameters remain in GPU during training + if not self.offload_param: + see_memory_usage(f"Before moving param subgroup group {i} to CPU", + force=False) + #move all the parameters to cpu to free up GPU space for creating flat buffer + move_to_cpu(self.fp16_partitioned_groups[i]) + see_memory_usage(f"After moving param subgroup {i} to CPU", + force=False) + + #create flat buffer in CPU and move to GPU + self.fp16_partitioned_groups_flat.append( + self.flatten_dense_tensors_aligned( + self.fp16_partitioned_groups[i], + 1).cuda(torch.cuda.current_device())) + see_memory_usage( + f"After flattening and moving param subgroup {i} to GPU", + force=False) + + #all partitioned parameters are in CPU during training + else: + print_rank_0(f"Params in nvme and cpu {self.params_in_nvme_and_cpu}") + #Flat buffer may not be available for parameters that reside in NVME + if not self.params_in_nvme_and_cpu or flat_offset + total_elements <= self.param_groups_fp16_flat_cpu_memory[ + j].numel(): + fp16_partitioned_group_flat = self.param_groups_fp16_flat_cpu_memory[ + j].narrow(0, + flat_offset, + total_elements) + print_rank_0( + f"Creating a flat buffer for subgroup {i} requiring {total_elements} elements, and cumulative CPU elements {flat_offset + total_elements}", + force=False) + #these parameters reside in NVME and + elif self.params_in_nvme_and_cpu: + fp16_partitioned_group_flat = None + print_rank_0( + f"No flat buffer for sub group {i} of {total_elements} elements", + force=False) + else: + assert False, "Either params are in nvme, or they are in CPU memory. This code path should not be triggered. Please see you max_params_in_cpu and params_in_nvme configs" + + self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat) + flat_offset += total_elements + + # move param to flat buffer for both param offload on/off + self._move_to_flat_buffer(self.fp16_groups[i], + self.fp16_partitioned_groups_flat[i], + avoid_copy=not self.offload_param) + + see_memory_usage(f"After Flattening param group {i}", force=False) + + #create a pinned memory to be used for swapping out params to NVME after optimizer step + if self.fp16_partitioned_groups_flat[-1] is None: + create_fp16_flat_reuse_buffer = True + + see_memory_usage(f"After Flattening param subgroup {i}", force=False) + + if create_fp16_flat_reuse_buffer: + assert len(largest_partition_numel) > 0, f'Unexpected that largest partition is empty' + self.fp16_groups[0][0].nvme_swapper.reserve_partitioned_swap_space( + largest_partition_numel) + + def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id): + offset = 0 + elements_in_sub_group = sum( + [t.ds_numel for t in self.fp16_partitioned_groups[sub_group_id]]) + assert (flat_buffer.numel() == elements_in_sub_group) + for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]): + dest = flat_buffer.narrow(0, offset, partitioned_param.ds_numel) + if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE: + print_rank_0( + f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.ds_tensor.ds_numel}" + ) + param.nvme_swapper.swap_in([param], async_op=False) + dest.data.copy_(partitioned_param.data) + param.nvme_swapper.remove_partition_and_release_buffers([param]) + print_rank_0(f"Swapping in {param.ds_id} done") + else: + dest.data.copy_(partitioned_param.data) + offset += partitioned_param.ds_numel + + def _create_next_swappable_fp32_groups(self): + reverse_order_indices = [ + i for i in range(len(self.fp32_partitioned_groups_flat)) + ] + reverse_order_indices.reverse() + + next_group = None + for i in reverse_order_indices: + self.next_swappable_fp32_partitioned_groups.append(next_group) + if self._swappable_optimizer_subgroup(i): + next_group = self.fp32_partitioned_groups_flat[i] + + self.next_swappable_fp32_partitioned_groups.reverse() + + def _get_sub_group_partitions(self, sub_group_id): + sub_group_partitions = [] + for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]): + if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE: + swap_path = param.nvme_swapper.get_path(param, True) + sub_group_partitions.append((partitioned_param, + param.ds_tensor.ds_numel, + swap_path)) + else: + sub_group_partitions.append((partitioned_param, + partitioned_param.ds_numel, + None)) + + return sub_group_partitions + + def _create_fp32_partitions(self): + cpu_memory_usage = 0 + cpu_memory_sub_groups = 0 + nvme_memory_usage = 0 + num_swappable_partitions = 0 + num_swap_from_nvme_partitions = 0 + num_swap_from_cpu_partitions = 0 + swap_from_nvme_memory_usage = 0 + swap_from_cpu_memory_usage = 0 + GIGA_BYTES = (1024**3) + + swappable_fp32_tensors = [] + swappable_fp16_src_tensors = [] + nvme_fp16_partitions_info = [] + nvme_fp16_num_elems = [] + nvme_fp32_dest_tensors = [] + fp32_element_size = torch.tensor([], dtype=torch.float32).element_size() + + for i, tensor in enumerate(self.fp16_partitioned_groups_flat): + num_elements = self.fp16_partitioned_groups_flat_numel[i] + + # a partition of the fp32 master weights that will be updated by this process + if self._swappable_optimizer_subgroup(i): + self.fp32_partitioned_groups_flat.append(torch.Tensor()) + nvme_memory_usage += (fp32_element_size * num_elements) + num_swappable_partitions += 1 + + if self.params_in_nvme_and_cpu and tensor is None: + num_swap_from_nvme_partitions += 1 + swap_from_nvme_memory_usage += (fp32_element_size * num_elements) + if self.offload_optimizer_fast_init: + sub_group_partitions = self._get_sub_group_partitions(i) + nvme_fp16_partitions_info.append(sub_group_partitions) + nvme_fp16_num_elems.append(num_elements) + nvme_fp32_dest_tensors.append( + self.fp32_partitioned_groups_flat[i]) + else: + unpinned_fp32_buffer = torch.empty(num_elements, + device=self.device, + dtype=torch.float) + self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) + self.optimizer_swapper.initialize_parameters( + parameters=[self.fp32_partitioned_groups_flat[i]], + src_tensors=[unpinned_fp32_buffer]) + else: + num_swap_from_cpu_partitions += 1 + swap_from_cpu_memory_usage += (fp32_element_size * num_elements) + swappable_fp32_tensors.append(self.fp32_partitioned_groups_flat[i]) + swappable_fp16_src_tensors.append( + self.fp16_partitioned_groups_flat[i]) + else: + cpu_memory_usage += (fp32_element_size * num_elements) + cpu_memory_sub_groups += 1 + + if self.params_in_nvme_and_cpu and tensor is None: + unpinned_fp32_buffer = torch.empty(num_elements, + device=self.device, + dtype=torch.float) + self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) + self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer) + else: + self.fp32_partitioned_groups_flat.append( + self.fp16_partitioned_groups_flat[i].to( + self.device).clone().float().detach()) + + self.fp32_partitioned_groups_flat[ + i].requires_grad = True # keep this in case internal optimizer uses it + + if len(swappable_fp32_tensors) > 0: + self.optimizer_swapper.initialize_parameters( + parameters=swappable_fp32_tensors, + src_tensors=swappable_fp16_src_tensors) + + if len(nvme_fp32_dest_tensors) > 0: + fp16_pinned_buffers = self.fp16_groups[0][ + 0].nvme_swapper.reserve_available_buffers() + assert len(fp16_pinned_buffers) > 0 + self.optimizer_swapper.initialize_from_swapped_fp16_params( + fp16_partitions_info=nvme_fp16_partitions_info, + fp16_num_elems=nvme_fp16_num_elems, + fp16_pinned_buffers=fp16_pinned_buffers, + fp32_parameters=nvme_fp32_dest_tensors) + self.fp16_groups[0][0].nvme_swapper.release_reserved_buffers() + + nvme_gigabytes = nvme_memory_usage / GIGA_BYTES + print_rank_0( + f'Swappable FP32 Partitions: count={num_swappable_partitions} size={nvme_gigabytes:5.2f} GB', + force=False) + if self.params_in_nvme_and_cpu: + print_rank_0( + f'Swap from NVMe Partitions: count = {num_swap_from_nvme_partitions}, size = {swap_from_nvme_memory_usage/GIGA_BYTES:5.2f}GB', + force=False) + print_rank_0( + f'Swap from CPU Partitions: count = {num_swap_from_cpu_partitions}, size = {swap_from_cpu_memory_usage/GIGA_BYTES:5.2f}GB', + force=False) + + cpu_memory_gigabytes = cpu_memory_usage / GIGA_BYTES + print_rank_0( + f'In-Memory FP32 Partitions: count={cpu_memory_sub_groups} size={cpu_memory_gigabytes:5.2f} GB', + force=False) + + # Clear for on-the-fly population before the optimizer step + for param_group in self.optimizer.param_groups: + param_group['params'] = [] + + def _create_fp16_sub_groups(self, params_group): + + params_group_numel = sum([param.partitioned_size() for param in params_group]) + sub_group_size = self.sub_group_size + + if sub_group_size is None or sub_group_size >= params_group_numel: + return [params_group] + + sub_groups = [] + sub_group = [] + local_sub_group_size = 0 + for param in params_group: + + sub_group.append(param) + local_sub_group_size += param.partitioned_size() + + if local_sub_group_size >= sub_group_size or id(param) == id( + params_group[-1]): + + sub_groups.append(sub_group) + + sub_group = [] + local_sub_group_size = 0 + + return sub_groups + + # def reset_ds_tensor(self): + # for name, param in self.module.named_parameters(recurse=True): + # assert hasattr(param,'ds_id'), "Parameters have not been converted to be Zero 3 compatible" + # assert (param.ds_status == ZeroParamStatus.NOT_AVAILABLE), "All the parameters must have been partitioned by now" + # param.ds_tensor.data = param.data + + def setup_zero_stage3_hooks(self): + self.hierarchy = 0 + self._register_hooks_recursively(self.module) + + #reset step at the beginning of forward + def _pre_forward_hook(module, *args): + self.param_coordinator.reset_step() + + #reset step if in inference mode + def _end_of_forward_hook(module, *args): + + if not torch._C.is_grad_enabled(): + self.param_coordinator.reset_step() + + #likely one of them should be enough but just to be safe + self.module.register_forward_hook(_end_of_forward_hook) + self.module.register_forward_pre_hook(_pre_forward_hook) + + # Add top module to stack trace + global FWD_MODULE_STACK + FWD_MODULE_STACK.append(self.module) + + def persistent_parameters(self): + persistent_params = [] + total_persistent_parameters = 0 + params_count = 0 + for _, param in self.module.named_parameters(recurse=True): + if param.ds_numel < self.persistence_threshold: + params_count += 1 + param.ds_persist = True + persistent_params.append(param) + total_persistent_parameters += param.ds_numel + + print_rank_0( + f"ZeRO 3: Total persistent parameters: {total_persistent_parameters} in {params_count} params", + force=False) + return persistent_params + + def _register_hooks_recursively(self, module, count=[0]): + my_count = count[0] + module.id = my_count + + #print(f"{module.__class__} : {module.id}") + + for child in module.children(): + count[0] = count[0] + 1 + self._register_hooks_recursively(child, count=count) + + def _pre_forward_module_hook(module, *args): + self.pre_sub_module_forward_function(module) + + def _post_forward_module_hook(module, input, output): + global FWD_MODULE_STACK + FWD_MODULE_STACK.pop() + if output is None: + output = [] + elif not isinstance(output, (list, tuple)): + if torch.is_tensor(output): + output = [output] + else: + #print(f'got UNKNOWN type {type(output)}') + outputs = [] + output = output if isinstance(output, dict) else vars(output) + for name, val in output.items(): + if not name.startswith('__') and torch.is_tensor(val): + outputs.append(val) + output = outputs + #print(f'convert output to {output}') + + for item in filter(lambda item: is_zero_param(item), output): + if not any(id(item) in m._external_params for m in FWD_MODULE_STACK): + item.ds_active_sub_modules += 1 + module_to_register = FWD_MODULE_STACK[-1] + print_rank_0( + f'Registering dangling parameter for module {module_to_register.__class__.__name__}.', + force=False) + register_external_parameter(module_to_register, item) + + # It's possible that the parameter was already external to the completed module. If so, remove it the + # registration as it will be covered by the outer module instead. + if id(item) in module._external_params: + print_rank_0( + f' Unregistering nested dangling parameter from module {module.__class__.__name__}', + force=False) + unregister_external_parameter(module, item) + + item.all_gather() + + self.post_sub_module_forward_function(module) + + def _pre_backward_module_hook(module, inputs, output): + def _run_before_backward_function(sub_module): + # some models (e.g. Albert) may run multiple forwards on the same layer in a loop + # before doing backwards, so each backward will need a pre-fetch - using reference + # counting to support this scenario + #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}") + if sub_module.applied_pre_backward_ref_cnt > 0: + self.pre_sub_module_backward_function(sub_module) + sub_module.applied_pre_backward_ref_cnt -= 1 + #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}") + + return _apply_to_tensors_only(module, + PreBackwardFunction, + _run_before_backward_function, + output) + + #This is an alternate to doing _post_backward_module_hook + #it uses tensor.register_hook instead of using torch.autograd.Function + def _alternate_post_backward_module_hook(module, inputs): + module.ds_grads_remaining = 0 + + #print(f"Before Forward {module.__class__.__name__}") + + def _run_after_backward_hook(*unused): + module.ds_grads_remaining = module.ds_grads_remaining - 1 + if module.ds_grads_remaining == 0: + #print(f"After backward {module.__class__.__name__}") + self.post_sub_module_backward_function(module) + + def _run_before_forward_function(input): + if input.requires_grad: + module.ds_grads_remaining += 1 + + return _apply_forward_and_backward_to_tensors_only( + module, + _run_before_forward_function, + _run_after_backward_hook, + inputs) + + def _post_backward_module_hook(module, inputs): + module.ds_grads_remaining = 0 + + def _run_after_backward_function(sub_module): + if sub_module.ds_grads_remaining == 0: + self.post_sub_module_backward_function(sub_module) + + return _apply_to_tensors_only(module, + PostBackwardFunction, + _run_after_backward_function, + inputs) + + # Pre forward hook + module.register_forward_pre_hook(_pre_forward_module_hook) + # Post forward hook + module.register_forward_hook(_post_forward_module_hook) + + # Pre backward hook + module.register_forward_hook(_pre_backward_module_hook) + + # post backward hook + module.register_forward_pre_hook(_post_backward_module_hook) + + def pre_sub_module_forward_function(self, sub_module): + see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", + force=False) + + global FWD_MODULE_STACK + FWD_MODULE_STACK.append(sub_module) + + self.param_coordinator.record_trace(sub_module) + + self.param_coordinator.fetch_sub_module(sub_module) + see_memory_usage( + f"Before sub module function {sub_module.__class__.__name__} after fetch", + force=False) + + self.param_coordinator.prefetch_next_sub_modules( + sub_module, + numel=self.prefetch_elements, + nvme=self.params_in_nvme_and_cpu) + see_memory_usage( + f"Before sub module function {sub_module.__class__.__name__} after prefetch", + force=False) + + self.param_coordinator.increment_step(sub_module) + + def post_sub_module_forward_function(self, sub_module): + see_memory_usage( + f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release", + force=False) + + self.param_coordinator.release_sub_module(sub_module) + + see_memory_usage( + f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release", + force=False) + + def pre_sub_module_backward_function(self, sub_module): + self.param_coordinator.record_trace(sub_module) + + self.param_coordinator.fetch_sub_module(sub_module) + + self.param_coordinator.prefetch_next_sub_modules(sub_module, + numel=self.prefetch_elements) + + self.param_coordinator.increment_step(sub_module) + + def post_sub_module_backward_function(self, sub_module): + see_memory_usage( + f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release", + force=False) + self.param_coordinator.release_sub_module(sub_module) + see_memory_usage( + f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release", + force=False) + + def _release_ipg_buffers(self): + if self.contiguous_gradients: + self.ipg_buffer = None + if not self.offload_optimizer and self.is_gradient_accumulation_boundary: + self.grads_in_partition = None + + self.grads_in_partition_offset = 0 + + def _optimizer_step(self, sub_group_id): + param_group_id = self.sub_group_to_group_id[sub_group_id] + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] + self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] + + self.optimizer.step() + self.optimizer.param_groups[param_group_id]['params'] = [] + + def _swappable_optimizer_subgroup(self, sub_group_id): + if not self.swap_optimizer: + return False + + return self.optimizer_swapper.swappable_tensor( + None, + numel=self.fp16_partitioned_groups_flat_numel[sub_group_id]) + + def _partitioned_params_swap_out(self, i): + offset = 0 + fp32_param = self.fp32_partitioned_groups_flat[i] + assert fp32_param is not None, \ + f'fp32 parameters of sub_group {i} is None' + + swap_fp16_params = [] + swap_fp32_params = [] + for param, partitioned_param in zip(self.fp16_groups[i], self.fp16_partitioned_groups[i]): + src = fp32_param.narrow(0, offset, partitioned_param.ds_numel) + if partitioned_param.status == PartitionedParamStatus.AVAILABLE: + partitioned_param.data.copy_(src.data) + else: + swap_fp32_params.append(src) + swap_fp16_params.append(param) + offset += partitioned_param.ds_numel + + if len(swap_fp16_params): + swap_fp16_params[0].nvme_swapper.swap_out_partitioned_params( + dst_fp16_params=swap_fp16_params, + src_fp32_params=swap_fp32_params) + + def initialize_optimizer_states(self): + num_subgroups = len(self.fp16_groups) + + largest_numel = max( + [sum([p.ds_numel for p in psg]) for psg in self.fp16_partitioned_groups]) + gradient_dtype = self.fp32_partitioned_groups_flat[0].dtype + gradient_buffer = torch.zeros(int(largest_numel), + dtype=gradient_dtype, + device=self.device) + + timers = self.timers + timer_names = set() + + if self.swap_optimizer: + self.optimizer_swapper.init_timers() + + INIT_OPTIMIZER_TIMER = 'init_optimizer_state' + timer_names.add(INIT_OPTIMIZER_TIMER) + self.start_timers([INIT_OPTIMIZER_TIMER]) + + for i, group in enumerate(self.fp16_groups): + swappable_optimizer_subgroup = self._swappable_optimizer_subgroup(i) + swappable_param_subgroup = self.fp16_partitioned_groups_flat[i] is None + + num_elements = int(self.fp16_partitioned_groups_flat_numel[i]) + + see_memory_usage( + f'[Begin] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', + force=False) + + if swappable_optimizer_subgroup: + self._optimizer_states_and_gradient_swap_in(i, timer_names) + + if self.offload_optimizer and not swappable_optimizer_subgroup: + subgroup_gradient_buffer = torch.zeros(num_elements, + dtype=gradient_dtype, + device=self.device) + if self.offload_optimizer_pin_memory: + subgroup_gradient_buffer = subgroup_gradient_buffer.pin_memory() + + self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer + else: + self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow( + 0, + 0, + num_elements) + + self._optimizer_step(i) + + if swappable_param_subgroup: + self._partitioned_params_swap_out(i) + + if swappable_optimizer_subgroup: + self._optimizer_states_and_gradient_swap_out(i, timer_names) + + see_memory_usage( + f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', + force=False) + + self.stop_timers([INIT_OPTIMIZER_TIMER]) + self.log_timers(timer_names) + + if self.swap_optimizer: + self.optimizer_swapper.log_timers() + + if not self.offload_optimizer: + for group in self.fp32_partitioned_groups_flat: + group.grad = None + + # Reset steps + return + + ######################################################################### + #########################ZeRO Partition Gradients######################## + ######################################################################### + + def get_first_param_index(self, group_id, param_group, partition_id): + for index, param in enumerate(param_group): + param_id = self.get_param_id(param) + if partition_id in self.param_to_partition_ids[group_id][param_id]: + return index + return None + + def initialize_gradient_partitioning_data_structures(self): + + total_partitions = dist.get_world_size(group=self.dp_process_group) + + for i, param_group in enumerate(self.fp16_groups): + + self.param_to_partition_ids[i] = {} + self.is_partition_reduced[i] = {} + self.total_grads_in_partition[i] = {} + self.remaining_grads_in_partition[i] = {} + self.is_grad_computed[i] = {} + self.grad_partition_insertion_offset[i] = {} + self.grad_start_offset[i] = {} + self.first_param_index_in_partition[i] = {} + + for partition_id in range(total_partitions): + self.is_grad_computed[i][partition_id] = {} + self.grad_partition_insertion_offset[i][partition_id] = {} + self.grad_start_offset[i][partition_id] = {} + self.initialize_gradient_partition(i, param_group, partition_id) + self.is_partition_reduced[i][partition_id] = False + self.first_param_index_in_partition[i][ + partition_id] = self.get_first_param_index( + i, + param_group, + partition_id) + + def independent_gradient_partition_epilogue(self): + self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0) + self.reduce_ipg_grads() + self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0) + + if self.overlap_comm: + self.reduction_stream.synchronize() + + with torch.cuda.stream(self.reduction_stream): + self.partition_previous_reduced_grads() + + # if dist.get_rank() == 0: + # logger.info("Params already reduced %s", self.params_already_reduced) + for i in range(len(self.params_already_reduced)): + self.params_already_reduced[i] = False + + #in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad + #TODO: use a similar code path for both cpu_offload and non-cpu offload + if not self.offload_optimizer: + for i, sub_group in enumerate(self.fp16_groups): + self.averaged_gradients[i] = [ + torch.zeros_like(param.ds_tensor) if param.grad is None else + param.grad.data.narrow(0, + 0, + param.ds_tensor.numel()) + for param in sub_group + ] + # self.averaged_gradients[i] = self.get_flat_partition( + # self.fp16_groups[i], + # 0, + # self.fp32_partitioned_groups_flat[i].numel(), + # return_tensor_list=True) + + self._release_ipg_buffers() + + see_memory_usage(f"End ipg_epilogue", force=False) + + # resets all partition to no reduced + # sets remaining grads to the total number of grads in each partition + # set is grad computed to false for all grads in partition + def reset_partition_gradient_structures(self): + total_partitions = dist.get_world_size(group=self.dp_process_group) + for i, _ in enumerate(self.fp16_groups): + for partition_id in range(total_partitions): + self.is_partition_reduced[i][partition_id] = False + self.remaining_grads_in_partition[i][ + partition_id] = self.total_grads_in_partition[i][partition_id] + + for param_id in self.is_grad_computed[i][partition_id]: + self.is_grad_computed[i][partition_id][param_id] = False + + def initialize_gradient_partition(self, i, param_group, partition_id): + def set_key_value_list(dictionary, key, value): + if key in dictionary: + dictionary[key].append(value) + else: + dictionary[key] = [value] + + def increment_value(dictionary, key): + if key in dictionary: + dictionary[key] += 1 + else: + dictionary[key] = 1 + + partition_size = self.partition_size[i] + + start_index = partition_size * partition_id + end_index = partition_size * (partition_id + 1) + + current_index = 0 + first_offset = 0 + + for param in param_group: + + param_size = param.numel() + param_id = self.get_param_id(param) + + if (current_index >= start_index and current_index < end_index): + set_key_value_list(self.param_to_partition_ids[i], + param_id, + partition_id) + increment_value(self.total_grads_in_partition[i], partition_id) + + self.is_grad_computed[i][partition_id][param_id] = False + + self.grad_partition_insertion_offset[i][partition_id][ + param_id] = current_index - start_index + self.grad_start_offset[i][partition_id][param_id] = 0 + + elif start_index > current_index and start_index < (current_index + + param_size): + assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition" + first_offset = start_index - current_index + + set_key_value_list(self.param_to_partition_ids[i], + param_id, + partition_id) + increment_value(self.total_grads_in_partition[i], partition_id) + + self.is_grad_computed[i][partition_id][param_id] = False + + self.grad_partition_insertion_offset[i][partition_id][param_id] = 0 + self.grad_start_offset[i][partition_id][param_id] = first_offset + + current_index = current_index + param_size + + def overlapping_partition_gradients_reduce_epilogue(self): + self.independent_gradient_partition_epilogue() + self.zero_grad() + + def create_reduce_and_remove_grad_hooks(self): + print_rank_0(f'[Begin] Create gradient reduction hooks') + self.grad_accs = [] + for i, param_group in enumerate(self.fp16_groups): + for param in param_group: + if param.requires_grad: + #print_rank_0(f" Before all gather {param.device}, {param.shape}") + + # The hook must be created in un-partitioned parameter + param.all_gather() + + #print(f"After all gather {param.device}, {param.shape}") + def wrapper(param, i): + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + + def reduce_partition_and_remove_grads(*notneeded): + self.reduce_ready_partitions_and_remove_grads(param, i) + + grad_acc.register_hook(reduce_partition_and_remove_grads) + self.grad_accs.append(grad_acc) + + #print(f"param grad fn {param.expand_as(param).grad_fn}") + wrapper(param, i) + + # Partition the parameter after creating the hook + param.partition() + print_rank_0(f'[End] Create gradient reduction hooks') + + def get_param_id(self, param): + unique_id = id(param) + return self.param_id[unique_id] + + def report_ipg_memory_usage(self, tag, param_elems): + elem_count = self.elements_in_ipg_bucket + param_elems + percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size + see_memory_usage( + f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}", + force=False) + + ###############Idependent Partition Gradient ######################## + def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): + #print_rank_0(f"Inside reduce ipg buckets. {debug_param2name_id_shape(param)}, ipg elements {self.elements_in_ipg_bucket}, reduce bucket size {self.reduce_bucket_size}", force=True) + + # Because the ipg bucket is initialized with a random place holder tensor, we must + # explicitly check that the bucket has any real data in it (self.elements_in_ipg_bucket > + # 0). Otherwise if the incoming param.ds_numel is large, this branch may get triggered on a + # garbage data and `self.average_tensor()` will crash because its params_to_reduce will be + # empty, while reduction_list will have that garbage data. + if self.elements_in_ipg_bucket > 0 and self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size: + self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", + param.ds_numel) + + self.reduce_ipg_grads() + + if self.contiguous_gradients and self.overlap_comm: + # Swap ipg_index between 0 and 1 + self.ipg_index = 1 - self.ipg_index + self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", + param.ds_numel) + + param_id = self.get_param_id(param) + assert self.params_already_reduced[param_id] == False, \ + f"The parameter {param_id} has already been reduced. \ + Gradient computed twice for this partition. \ + Multiple gradient reduction is currently not supported" + + # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening + if param.ds_numel > self.reduce_bucket_size: + self.extra_large_param_to_reduce = param + + elif self.contiguous_gradients: + #print_rank_0("before new grad tensor move") + new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow( + 0, + self.elements_in_ipg_bucket, + param.ds_numel) + #print_rank_0("after new grad tensor move") + new_grad_tensor.copy_(param.grad.view(-1)) + param.grad.data = new_grad_tensor.data.view_as(param.grad) + + self.elements_in_ipg_bucket += param.ds_numel + self.grads_in_ipg_bucket.append(param.grad) + self.params_in_ipg_bucket.append((i, param, param_id)) + self.report_ipg_memory_usage("End ipg_remove_grads", 0) + + def gradient_reduction_w_predivide(self, tensor): + dp_world_size = dist.get_world_size(group=self.dp_process_group) + + tensor_to_allreduce = tensor + + if self.communication_data_type != tensor.dtype: + tensor_to_allreduce = tensor.to(self.communication_data_type) + + if self.postscale_gradients: + if self.gradient_predivide_factor != 1.0: + tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor) + + dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) + + if self.gradient_predivide_factor != dp_world_size: + tensor_to_allreduce.mul_(self.gradient_predivide_factor / dp_world_size) + else: + tensor_to_allreduce.div_(dp_world_size) + dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) + + if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: + tensor.copy_(tensor_to_allreduce) + + return tensor + + def average_tensor(self, tensors, params_to_reduce): + with torch.cuda.stream(self.reduction_stream): + if not self.reduce_scatter: + for tensor in tensors: + self.gradient_reduction_w_predivide(tensor) + return + + for tensor in tensors: + tensor.div_(dist.get_world_size(group=self.dp_process_group)) + + # reduction resulting with each rank only holding the gradient partition it owns + # This could either be a reduce scatter or a reduce op depending on how + # parameters are partitionied. The method is implemented by the + # DeepSpeed param extensions to the pytorch parameter, so its up to + # the extension to define what happens here + params_to_reduce[0].reduce_gradients_at_owner( + param_list=params_to_reduce, + hierarchy=self.param_coordinator.hierarchy) + + def set_grad_positions(self): + for i, group in enumerate(self.fp16_groups): + current_offset = 0 + for param in group: + param_id = self.get_param_id(param) + num_elements = param.ds_tensor.ds_numel + + self.grad_position[param_id] = [ + int(i), + int(current_offset), + int(num_elements) + ] + #print(f"param id {param_id} i:{i}, ds_tensor {num_elements} numel {param.numel()}") + current_offset += num_elements + + def async_accumulate_grad_in_cpu_via_gpu(self, param, acc_grad_cpu_partition): + + # copy to a preexisiting buffer to avoid memory allocation penalty + dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow( + 0, + 0, + param.ds_tensor.ds_numel) + + if self.micro_step_id > 0: + dest_buffer.copy_(acc_grad_cpu_partition.view(-1), non_blocking=True) + param.grad.data.view(-1).add_(dest_buffer) + + # at the boundary we will send 32bit directly + if not self.is_gradient_accumulation_boundary: + acc_grad_cpu_partition.data.copy_(param.grad.data.view(-1), + non_blocking=True) + + def _constant_buffered_norm2(self, input, buffer_size=250000000): + norm = None + for part in input.view(-1).split(buffer_size): + if norm is None: + norm = part.data.double().norm(2)**2.0 + else: + norm += part.data.double().norm(2)**2.0 + return norm**0.5 + + def set_norm_for_param_grad_in_gpu(self, param): + param_id = self.get_param_id(param) + #self.norm_for_param_grads[param_id] = param.grad.data.double().norm(2) + #Using a more memory efficient version + self.norm_for_param_grads[param_id] = self._constant_buffered_norm2(param.grad) + + def update_overflow_tracker_for_param_grad(self, param): + #Credit to our user David Minn + if param.grad is not None: + if self.overlap_comm: + self.gpu_sum = self.gpu_sum + param.grad.data.float().sum() + elif self._has_inf_or_nan(param.grad.data): + self.local_overflow = True + + def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param, fp32_grad_tensor): + with torch.cuda.stream(self.copy_grad_stream): + param_id = self.get_param_id(param) + src_tensor = param.grad.view(-1).float() + #print(f"src_tensor {src_tensor.size()} and fp32 grad {fp32_grad_tensor.size()}") + fp32_grad_tensor.copy_(src_tensor, non_blocking=True) + param.grad = None + + def complete_grad_norm_calculation_for_cpu_offload(self, params): + total_norm = 0.0 + norm_type = 2.0 + for p in params: + if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): + param_id = self.get_param_id(p) + if param_id in self.norm_for_param_grads.keys(): + param_norm = self.norm_for_param_grads[param_id] + total_norm += param_norm.item()**2 + + # Sum across all model parallel GPUs. + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.SUM, + group=self.dp_process_group) + + self._model_parallel_all_reduce(tensor=total_norm_cuda, + op=torch.distributed.ReduceOp.SUM) + + total_norm = total_norm_cuda[0].item()**(1. / norm_type) + + if total_norm == float( + 'inf') or total_norm == -float('inf') or total_norm != total_norm: + total_norm = -1 + + return total_norm + + def partition_previous_reduced_grads(self): + if not self.previous_reduced_grads: + return + + if self.offload_optimizer: + allocate_grads_in_partition = self.grads_in_partition is None\ + and self.gradient_accumulation_steps > 1 + else: + allocate_grads_in_partition = self.grads_in_partition is None + + if allocate_grads_in_partition: + self.grads_in_partition = [] + + for i, group in enumerate(self.fp16_groups): + total_size = 0 + for param_in_partition in group: + total_size += param_in_partition.ds_tensor.ds_numel + + see_memory_usage( + f"group {i} before creating {total_size} reduced gradients into partition", + force=False) + if self.offload_param_pin_memory: + self.grads_in_partition.append( + torch.zeros(int(total_size), + dtype=self.dtype, + device=self.device).pin_memory()) + else: + self.grads_in_partition.append( + torch.zeros(int(total_size), + dtype=self.dtype, + device=self.device)) + see_memory_usage( + f"group {i} after creating {total_size} reduced gradients into partition", + force=False) + + if self.offload_optimizer: + offload_fp32_gradients = {} + offload_fp32_offsets = {} + + with torch.cuda.stream(self.copy_grad_stream): + self.reduction_stream.synchronize() + for param in self.previous_reduced_grads: + + [i, + dest_offset, + num_elements] = self.grad_position[self.get_param_id(param)] + + if self.offload_optimizer: + param.partition_gradients( + partition_buffers=self.temp_grad_gpu_buffer) + #with torch.cuda.stream(self.copy_grad_stream): + # self.reduction_stream.synchronize() + + if self.gradient_accumulation_steps > 1: + # The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer + fp16_grad_tensor = self.grads_in_partition[i].narrow( + 0, + dest_offset, + num_elements) + self.async_accumulate_grad_in_cpu_via_gpu( + param, + fp16_grad_tensor) + + if self.is_gradient_accumulation_boundary: + + self.set_norm_for_param_grad_in_gpu(param) + + self.update_overflow_tracker_for_param_grad(param) + + if self._swappable_optimizer_subgroup(i): + if not i in offload_fp32_gradients.keys(): + offload_fp32_gradients[i] = [] + offload_fp32_offsets[i] = [] + + offload_fp32_gradients[i].append(param.grad.view(-1).float()) + param.grad = None + offload_fp32_offsets[i].append(dest_offset) + else: + fp32_grad_tensor = self.fp32_partitioned_groups_flat[ + i].grad.narrow(0, + dest_offset, + num_elements) + + self.async_inplace_copy_grad_to_fp32_buffer_from_gpu( + param, + fp32_grad_tensor) + else: + # The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer + fp16_grad_tensor = self.grads_in_partition[i].narrow( + 0, + dest_offset, + num_elements) + param.partition_gradients( + partition_buffers=fp16_grad_tensor, + accumulate=True if self.micro_step_id > 0 else False) + + if self.offload_optimizer and self.swap_optimizer: + for i in offload_fp32_gradients.keys(): + self.optimizer_swapper.swap_out_gradients( + parameter=self.fp32_partitioned_groups_flat[i], + gradient_offsets=offload_fp32_offsets[i], + gradient_tensors=offload_fp32_gradients[i]) + + self.previous_reduced_grads = [] + + def reduce_ipg_grads(self, extra_param=None): + if self.overlap_comm: + self.reduction_stream.synchronize() + + with torch.cuda.stream(self.reduction_stream): + self.partition_previous_reduced_grads() + + params_to_reduce = [param for i, param, param_id in self.params_in_ipg_bucket] + #print(f"Params in ipg bucket {self.params_in_ipg_bucket}") + #print(f"Reducing {[(debug_param2name_id_shape(param), param.grad) for param in params_to_reduce]}") + #exit(0) + if self.contiguous_gradients: + reduction_list = [self.ipg_buffer[self.ipg_index]] + if self.extra_large_param_to_reduce is not None: + reduction_list.append(self.extra_large_param_to_reduce.grad) + self.extra_large_param_to_reduce = None + self.average_tensor(reduction_list, params_to_reduce) + else: + self.buffered_reduce_fallback( + None, + self.grads_in_ipg_bucket, + elements_per_buffer=self.elements_in_ipg_bucket) + + for _, param, param_id in self.params_in_ipg_bucket: + self.params_already_reduced[param_id] = True + + self.previous_reduced_grads = params_to_reduce + + self.grads_in_ipg_bucket = [] + self.params_in_ipg_bucket = [] + self.elements_in_ipg_bucket = 0 + ##################################################################### + + def reduce_ready_partitions_and_remove_grads(self, param, i): + #print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True) + self.reduce_independent_p_g_buckets_and_remove_grads(param, i) + + def zero_reduced_gradients(self, partition_id, i): + def are_all_related_partitions_reduced(params_id): + for partition_id in self.param_to_partition_ids[i][params_id]: + if not self.is_partition_reduced[i][partition_id]: + return False + return True + + for params_id in self.is_grad_computed[i][partition_id]: + if are_all_related_partitions_reduced(params_id): + self.param_dict[params_id].grad = None + + def flatten_and_print(self, message, tensors, start=0, n=5): + flatten_tensor = self.flatten(tensors) + + def print_func(): + logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n)) + + self.sequential_execution(print_func, message) + + def get_grads_to_reduce(self, i, partition_id): + def get_reducible_portion(key): + grad = self.param_dict[key].grad + total_elements = grad.numel() + start = self.grad_start_offset[i][partition_id][key] + num_elements = min( + total_elements - start, + self.partition_size[i] - + self.grad_partition_insertion_offset[i][partition_id][key]) + if not pg_correctness_test: + if num_elements == total_elements: + return grad + else: + return grad.contiguous().view(-1).narrow(0, + int(start), + int(num_elements)) + else: + if num_elements == total_elements: + return grad.clone() + else: + return grad.clone().contiguous().view(-1).narrow( + 0, + int(start), + int(num_elements)) + + grads_to_reduce = [] + for key in self.is_grad_computed[i][partition_id]: + grad = get_reducible_portion(key) + grads_to_reduce.append(grad) + return grads_to_reduce + + def sequential_execution(self, function, message, group=None): + if group is None: + group = self.dp_process_group + if dist.get_rank(group=group) == 0: + logger.info(message) + for id in range(dist.get_world_size(group=group)): + if id == dist.get_rank(group=group): + function() + dist.barrier(group=group) + + def set_none_gradients_to_zero(self, i, partition_id): + for param_id in self.is_grad_computed[i][partition_id]: + param = self.param_dict[param_id] + if param.grad is None: + param.grad = torch.zero_like(param) + + ######################Reduction Related Methods############################## + + def allreduce_bucket(self, + bucket, + communication_data_type=torch.float16, + rank=None, + log=None): + rank = None + tensor = self.flatten(bucket) + + tensor_to_allreduce = tensor + + if pg_correctness_test: + communication_data_type = torch.float32 + + if communication_data_type != tensor.dtype: + tensor_to_allreduce = tensor.to(communication_data_type) + + tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group)) + + if rank is None: + # "All Reducing" + dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) + else: + global_rank = _get_global_rank(self.dp_process_group, rank) + dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group) + + if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: + if rank is None or rank == dist.get_rank(group=self.dp_process_group): + tensor.copy_(tensor_to_allreduce) + + return tensor + + # if rank is specified do a reduction instead of an allreduce + def allreduce_and_copy(self, small_bucket, rank=None, log=None): + with torch.cuda.stream(self.reduction_stream): + allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log) + if rank is None or rank == dist.get_rank(group=self.dp_process_group): + for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): + buf.copy_(synced) + + def allreduce_no_retain(self, + bucket, + numel_per_bucket=500000000, + rank=None, + log=None): + small_bucket = [] + numel = 0 + for tensor in bucket: + small_bucket.append(tensor) + numel = numel + tensor.numel() + if numel > numel_per_bucket: + self.allreduce_and_copy(small_bucket, rank=rank, log=None) + small_bucket = [] + if len(small_bucket) > 0: + self.allreduce_and_copy(small_bucket, rank=rank, log=log) + + # allows using reduction of gradients instead of using all_reduce + def buffered_reduce_fallback(self, + rank, + grads, + elements_per_buffer=500000000, + log=None): + split_buckets = split_half_float_double(grads) + + for i, bucket in enumerate(split_buckets): + self.allreduce_no_retain(bucket, + numel_per_bucket=elements_per_buffer, + rank=rank, + log=log) + + ############################################################################# + ############################################################################# + ############################################################################# + + # views the tensor as multiple partitions and returns + # those partitions + def get_data_parallel_partitions(self, tensor): + partitions = [] + + dp = dist.get_world_size(group=self.dp_process_group) + dp_id = dist.get_rank(group=self.dp_process_group) + + total_num_elements = tensor.numel() + + base_size = total_num_elements // dp + remaining = total_num_elements % dp + + start = 0 + for id in range(dp): + partition_size = base_size + if id < remaining: + partition_size = partition_size + 1 + partitions.append(tensor.narrow(0, start, partition_size)) + start = start + partition_size + return partitions + + def get_partition_info(self, tensor_list, partition_size, partition_id): + params_in_partition = [] + params_not_in_partition = [] + + start_index = partition_size * partition_id + end_index = partition_size * (partition_id + 1) + + current_index = 0 + first_offset = 0 + + for tensor in tensor_list: + + tensor_size = tensor.numel() + + if (current_index >= start_index and current_index < end_index): + params_in_partition.append(tensor) + + elif start_index > current_index and start_index < (current_index + + tensor_size): + params_in_partition.append(tensor) + + assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition" + first_offset = start_index - current_index + + else: + params_not_in_partition.append(tensor) + + current_index = current_index + tensor_size + + return params_in_partition, params_not_in_partition, first_offset + + def zero_grad(self, set_grads_to_None=True): + """ + Zero FP16 parameter grads. + """ + # FP32 grad should never exist. + # For speed, set model fp16 grad to None by default + for group in self.fp16_groups: + for p in group: + if set_grads_to_None: + p.grad = None + else: + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + + def _model_parallel_all_reduce(self, tensor, op): + """ Perform all reduce within model parallel group, if any. + """ + if self.model_parallel_group is None: + pass + else: + torch.distributed.all_reduce(tensor=tensor, + op=op, + group=self.model_parallel_group) + + def get_grad_norm_direct(self, gradients, params, norm_type=2): + """Clips gradient norm of an iterable of parameters. + + This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and + added functionality to handle model parallel parameters. Note that + the gradients are modified in place. + + Arguments: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + + Returns: + Total norm of the parameters (viewed as a single vector). + """ + norm_type = float(norm_type) + if norm_type == inf: + total_norm = max(g.data.abs().max() for g in gradients) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.MAX, + group=self.dp_process_group) + + # Take max across all GPUs. + self._model_parallel_all_reduce(tensor=total_norm_cuda, + op=torch.distributed.ReduceOp.MAX) + total_norm = total_norm_cuda[0].item() + else: + total_norm = 0.0 + # if dist.get_rank() == 0: + # logger.info(f"Total Norm beginning {total_norm}") + for g, p in zip(gradients, params): + if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): + param_norm = g.data.double().norm(2) + total_norm += param_norm.item()**2 + # Sum across all model parallel GPUs. + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.SUM, + group=self.dp_process_group) + + self._model_parallel_all_reduce(tensor=total_norm_cuda, + op=torch.distributed.ReduceOp.SUM) + + total_norm = total_norm_cuda[0].item()**(1. / norm_type) + + if total_norm == float( + 'inf') or total_norm == -float('inf') or total_norm != total_norm: + total_norm = -1 + + return total_norm + + # creates a flat fused tensor from the tensor list starting at the first_offset + # in the first tensor of the list. If there are not enough elements in the tensor + # list then the flat tensor will be padded with zeros + def get_flat_partition(self, + tensor_list, + first_offset, + partition_size, + return_tensor_list=False): + flat_tensor_list = [] + current_size = 0 + for i, tensor in enumerate(tensor_list): + if tensor.grad is None: + tensor.grad = torch.zeros_like(tensor) + + tensor = tensor.grad + num_elements = tensor.numel() + tensor_offset = 0 + + # we need to offset to get to the right element + if i == 0 and first_offset > 0: + tensor_offset = first_offset + num_elements = num_elements - tensor_offset + + # we dont need all elements of the tensor + if num_elements > (partition_size - current_size): + num_elements = partition_size - current_size + + # we need a narrow view of the tensor based on the tensor offset and number of elements that + # we need from this tensor + if tensor_offset > 0 or num_elements < tensor.numel(): + flat_tensor_list.append(tensor.contiguous().view(-1).narrow( + 0, + int(tensor_offset), + int(num_elements))) + else: + flat_tensor_list.append(tensor) + + current_size = current_size + num_elements + + # this means its the last partition and does not align with the dp boundary. We need to pad before flattening + if current_size < partition_size: + flat_tensor_list.append( + torch.zeros(int(partition_size - current_size), + dtype=tensor_list[0].dtype, + device=tensor_list[0].device)) + + if return_tensor_list: + return flat_tensor_list + + return self.flatten(flat_tensor_list) + + def free_grad_in_param_list(self, param_list): + for p in param_list: + p.grad = None + + def reset_cpu_buffers(self): + self.norm_for_param_grads = {} + self.local_overflow = False + + def log_timers(self, timer_names): + if self.timers is None: + return + + self.timers.log(names=list(timer_names)) + + def start_timers(self, timer_names): + if self.timers is None: + return + + for name in timer_names: + self.timers(name).start() + + def stop_timers(self, timer_names): + if self.timers is None: + return + + for name in timer_names: + self.timers(name).stop() + + def _pre_step(self): + self.micro_step_id = INITIAL_MICRO_STEP_ID + + print_rank_0(f"Inside Step function") + see_memory_usage(f"In step before checking overflow", force=False) + + print_rank_0("Finished Tracing at Beginning of Step") + self.param_coordinator.hierarchy = 0 + self.param_coordinator.finish_tracing(print_trace=True) + + self.param_coordinator.reset_step() + + print_rank_0("Finished Tracing at Beginning of Step") + + def _get_norm_groups(self): + norm_groups = [] + for i, group in enumerate(self.fp16_groups): + if self.offload_optimizer: + norm_groups.append( + self.complete_grad_norm_calculation_for_cpu_offload( + self.fp16_groups[i])) + else: + norm_groups.append( + self.get_grad_norm_direct(self.averaged_gradients[i], + self.fp16_groups[i])) + return norm_groups + + def _prepare_fp32_grad_for_sub_group(self, sub_group_id): + partition_id = dist.get_rank(group=self.dp_process_group) + + single_grad_partition = self.flatten(self.averaged_gradients[sub_group_id]).to( + self.fp32_partitioned_groups_flat[sub_group_id].dtype) + + assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \ + "averaged gradients have different number of elements that partition size {} {} {} {}".format( + single_grad_partition.numel(), self.fp32_partitioned_groups_flat[sub_group_id].numel(), sub_group_id, partition_id) + + self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition + + # release all the gradient since we have already created a necessary copy in dp_grad_partition + self.zero_grad() + + self.averaged_gradients[sub_group_id] = None + + def _prepare_sub_group(self, sub_group_id, timer_names=set()): + see_memory_usage(f'Before prepare optimizer sub group {sub_group_id}', + force=False) + if self._swappable_optimizer_subgroup(sub_group_id): + self._optimizer_states_and_gradient_swap_in(sub_group_id, timer_names) + elif not self.offload_optimizer: + self._prepare_fp32_grad_for_sub_group(sub_group_id) + see_memory_usage(f'After prepare optimizer sub group {sub_group_id}', + force=False) + + def _optimizer_states_and_gradient_swap_in(self, sub_group_id, timer_names=set()): + param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] + fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id]) + assert self._swappable_optimizer_subgroup(sub_group_id), \ + f'Parameter {fp32_param_id} of numel={param_length} is not swappable' + + OPTIMIZER_SWAP_IN_STATE = 'optimizer_swap_in_state' + see_memory_usage(f'pre-step Before swapping in optimizer tensors {sub_group_id}', + force=False) + self.start_timers([OPTIMIZER_SWAP_IN_STATE]) + + self.optimizer_swapper.swap_in_optimizer_state( + parameter=self.fp32_partitioned_groups_flat[sub_group_id], + async_parameter=self.next_swappable_fp32_partitioned_groups[sub_group_id]) + + self.stop_timers([OPTIMIZER_SWAP_IN_STATE]) + timer_names.add(OPTIMIZER_SWAP_IN_STATE) + see_memory_usage(f'pre-step After swapping in optimizer tensors {sub_group_id}', + force=False) + + def _release_sub_group(self, sub_group_id, timer_names=set()): + see_memory_usage(f'Before release optimizer sub group {sub_group_id}', + force=False) + # get rid of the fp32 gradients. Not needed anymore + if not self.offload_optimizer: + self.fp32_partitioned_groups_flat[sub_group_id].grad = None + + if self._swappable_optimizer_subgroup(sub_group_id): + self._optimizer_states_and_gradient_swap_out(sub_group_id, timer_names) + see_memory_usage(f'After release optimizer sub group {sub_group_id}', + force=False) + + # create a flat tensor aligned at the alignment boundary + def flatten_dense_tensors_aligned(self, tensor_list, alignment): + num_elements = 0 + for tens in tensor_list: + num_elements = num_elements + tens.numel() + + remaining = num_elements % alignment + + if remaining: + elements_to_add = alignment - remaining + pad_tensor = torch.zeros(elements_to_add, + device=tensor_list[0].device, + dtype=tensor_list[0].dtype) + padded_tensor_list = tensor_list + [pad_tensor] + + num_elements = num_elements + elements_to_add + else: + padded_tensor_list = tensor_list + + return self.flatten(padded_tensor_list) + + def _optimizer_states_and_gradient_swap_out(self, sub_group_id, timer_names=set()): + param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] + fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id]) + assert self._swappable_optimizer_subgroup(sub_group_id), \ + f'Parameter {fp32_param_id} of numel={param_length} is not swappable' + + OPTIMIZER_SWAP_OUT_STATE = 'optimizer_swap_out_state' + see_memory_usage( + f'post-step Before swapping out optimizer tensors {sub_group_id}', + force=False) + self.start_timers([OPTIMIZER_SWAP_OUT_STATE]) + + self.optimizer_swapper.swap_out_optimizer_state( + parameter=self.fp32_partitioned_groups_flat[sub_group_id], + async_swap=self.next_swappable_fp32_partitioned_groups[sub_group_id] is + not None) + + self.stop_timers([OPTIMIZER_SWAP_OUT_STATE]) + see_memory_usage( + f'post-step After swapping out optimizer tensors {sub_group_id}', + force=False) + timer_names.add(OPTIMIZER_SWAP_OUT_STATE) + + # get rid of the fp32 gradients. Not needed anymore + self.fp32_partitioned_groups_flat[sub_group_id].grad = None + + def _unflatten_partitioned_parameters(self, sub_group_id): + updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id], + self.fp16_partitioned_groups[sub_group_id]) + + for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): + partitioned_param.data = q.data + + def _overflow_clean_up(self, prev_scale): + see_memory_usage('After overflow before clearing gradients', force=False) + self.zero_grad() + + if self.offload_optimizer: + self.reset_cpu_buffers() + else: + self.averaged_gradients = {} + + see_memory_usage('After overflow after clearing gradients', force=False) + + if torch.distributed.get_rank() == 0: + logger.info( + "[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, " + "reducing to {}".format(dist.get_rank(), + prev_scale, + self.loss_scale)) + + def _overflow_check_and_loss_scale_update(self): + + # First compute norm for all group so we know if there is overflow + self.check_overflow() + + #loss scaling related computation + prev_scale = self.loss_scale + self._update_scale(self.overflow) + + if self.overflow: + self._overflow_clean_up(prev_scale) + + return self.overflow + + def _post_step(self, timer_names=set()): + if self.offload_optimizer: + self.reset_cpu_buffers() + + #Gathering persisting parameters + if len(self.persistent_parameters) > 0: + self.persistent_parameters[0].all_gather(self.persistent_parameters) + + if self.swap_optimizer: + self.optimizer_swapper.log_timers() + + self.log_timers(timer_names) + + see_memory_usage('After zero_optimizer step', force=False) + print_rank_0(f"------------------Finishing Step-----------------------") + + def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id): + if self.fp16_partitioned_groups_flat[sub_group_id] is not None: + self.fp16_partitioned_groups_flat[sub_group_id].data.copy_( + self.fp32_partitioned_groups_flat[sub_group_id].data) + + #unflatten fp16 parameter subgroup + self._unflatten_partitioned_parameters(sub_group_id) + else: + self._partitioned_params_swap_out(sub_group_id) + + def step(self, closure=None): + """ + Not supporting closure. + """ + self._pre_step() + + #checks for overflow, adjust the loss scale accordingly + if self._overflow_check_and_loss_scale_update(): + if self.swap_optimizer: + self.optimizer_swapper.log_timers() + return + + norm_groups = self._get_norm_groups() + self._global_grad_norm = get_global_norm(norm_list=norm_groups) + + timer_names = set() + + timer_names.add('optimizer_step') + self.start_timers(['optimizer_step']) + + #update parameters one sub group at a time + for sub_group_id, group in enumerate(self.fp16_groups): + + #prepare optimizer states, gradients and fp32 parameters for update + self._prepare_sub_group(sub_group_id, timer_names) + + #scale the fp32 gradients + self.unscale_and_clip_grads(sub_group_id, self._global_grad_norm) + + #apply the optimizer step on the sub group and copy fp32 parameters to fp16 + self._optimizer_step(sub_group_id) + + #put fp16 parameters in appropriate location + self._reassign_or_swap_out_partitioned_parameters(sub_group_id) + + #release memory or swap out optimizer states of fp32 parameters + self._release_sub_group(sub_group_id, timer_names) + + self.stop_timers(['optimizer_step']) + + self._post_step(timer_names) + return + + def dump_pre_step_gradients(self, debug_fp32_grads): + # Dump gradient norms for debugging + for i, _ in enumerate(self.fp16_groups): + print(f'Pre-Step Dump Norms for Group {i} FP16P, FP16G, FP32G, FP32GUC') + for fp16_param, fp32_grad in zip(self.fp16_groups[i], debug_fp32_grads[i]): + param_id = self.get_param_id(fp16_param) + fp16_grad_norm = self.debug_fp16_grads[i][param_id] + + fp32_grad_norm = [float(t.data.float().norm(2)) for t in fp32_grad] + norm_list = [fp16_grad_norm, fp32_grad_norm] + print(f'Pre-Step Norms {i} {param_id} = {norm_list}') + + def dump_post_step_gradients(self): + # Dump gradient norms for debugging + for i, group in enumerate(self.fp16_groups): + print( + f'Post-Step Dump Norms for Group {i} FP16P, FP16DS, FP16FLAT, FP32FLAT') + unflat_fp16 = self.unflatten(self.fp16_groups_flat[i], self.fp16_groups[i]) + unflat_fp32 = self.unflatten(self.fp32_partitioned_groups_flat[i], + self.fp16_groups[i]) + for j, p in enumerate(self.fp16_groups[i]): + param_id = self.get_param_id(p) + param_norm = float(p.data.float().norm(2)) + ds_norm = float(p.ds_tensor.data.float().norm(2)) + + unflat_norm = [ + float(t.data.float().norm(2)) + for t in [unflat_fp16[j], + unflat_fp32[j]] + ] + norm_list = [param_norm, ds_norm] + unflat_norm + print(f'Post-Step Norms {i} {param_id} = {norm_list}') + + def unscale_and_clip_grads(self, sub_group_id, total_norm): + grad_groups_flat = [self.fp32_partitioned_groups_flat[sub_group_id].grad] + + # compute combined scale factor for this group + combined_scale = self.loss_scale + if self.clip_grad > 0.: + # norm is in fact norm*scale + clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad + if clip > 1: + combined_scale = clip * self.loss_scale + + for grad in grad_groups_flat: + if isinstance(grad, list): + sub_partitions = grad + for g in sub_partitions: + g.data.mul_(1. / combined_scale) + else: + grad.data.mul_(1. / combined_scale) + + def _check_overflow(self, partition_gradients=True): + self.overflow = self.has_overflow(partition_gradients) + + # `params` is a list / generator of torch.Variable + def has_overflow_serial(self, params, is_grad_list=False): + for p in params: + if p.grad is not None and self._has_inf_or_nan(p.grad.data): + return True + + return False + + def has_overflow_partitioned_grads_serial(self): + for i in range(len(self.fp16_groups)): + for j, grad in enumerate(self.averaged_gradients[i]): + if grad is not None and self._has_inf_or_nan(grad.data, j): + return True + return False + + def has_overflow(self, partition_gradients=True): + if partition_gradients: + if self.overlap_comm: + self.local_overflow = self._has_inf_or_nan(self.gpu_sum) + self.gpu_sum = torch.zeros(1, dtype=torch.float).cuda() + + overflow = self.local_overflow if self.offload_optimizer else self.has_overflow_partitioned_grads_serial( + ) + #overflow = self.has_overflow_partitioned_grads_serial() + overflow_gpu = torch.cuda.ByteTensor([overflow]) + torch.distributed.all_reduce(overflow_gpu, + op=torch.distributed.ReduceOp.MAX, + group=self.dp_process_group) + + else: + params = [] + for group in self.fp16_groups: + for param in group: + params.append(param) + + overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients) + overflow_gpu = torch.cuda.ByteTensor([overflow]) + + # Since each model parallel GPU carries only part of the model, + # make sure overflow flag is synced across all the model parallel GPUs + self._model_parallel_all_reduce(tensor=overflow_gpu, + op=torch.distributed.ReduceOp.MAX) + + overflow = overflow_gpu[0].item() + return bool(overflow) + + # `x` is a torch.Tensor + @staticmethod + def _has_inf_or_nan(x, j=None): + try: + # if x is half, the .float() incurs an additional deep copy, but it's necessary if + # Pytorch's .sum() creates a one-element tensor of the same type as x + # (which is true for some recent version of pytorch). + cpu_sum = float(x.float().sum()) + # More efficient version that can be used if .sum() returns a Python scalar + # cpu_sum = float(x.sum()) + except RuntimeError as instance: + # We want to check if inst is actually an overflow exception. + # RuntimeError could come from a different error. + # If so, we still want the exception to propagate. + if "value cannot be converted" not in instance.args[0]: + raise + return True + else: + if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: + return True + return False + + def backward(self, loss, retain_graph=False): + """ + :attr:`backward` performs the following steps: + + 1. fp32_loss = loss.float() + 2. scaled_loss = fp32_loss*loss_scale + 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves + """ + self.micro_step_id += 1 + print_rank_0( + f"Total fully available parameters {self.param_coordinator.total_available_parameter_numel}" + ) + + if self.swap_optimizer: + self.optimizer_swapper.pre_backward() + + see_memory_usage(f"Before backward", force=False) + if self.contiguous_gradients: + self.ipg_buffer = [] + buf_0 = torch.empty(self.reduce_bucket_size, + dtype=self.dtype, + device=torch.cuda.current_device()) + self.ipg_buffer.append(buf_0) + + # Use double buffers to avoid data access conflict when overlap_comm is enabled. + if self.overlap_comm: + buf_1 = torch.empty(self.reduce_bucket_size, + dtype=self.dtype, + device=torch.cuda.current_device()) + self.ipg_buffer.append(buf_1) + self.ipg_index = 0 + + self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) + '''Partitioning Parameters that were not partitioned + Usually if parameters of modules whose input parameters do not require + grad computation do not trigger post call and will therefore will remain unpartitioned ''' + self._partition_all_parameters() + + if self.swap_optimizer: + self.optimizer_swapper.post_backward() + + def _partition_all_parameters(self): + for name, param in self.module.named_parameters(recurse=True): + self.param_coordinator.release_and_reset_parameter(param) + + def check_overflow(self, partition_gradients=True): + self._check_overflow(partition_gradients) + + def _update_scale(self, has_overflow=False): + self.loss_scaler.update_scale(has_overflow) + + # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" + def _get_state(self): + return self.optimizer.state + + def _set_state(self, value): + self.optimizer.state = value + + state = property(_get_state, _set_state) + + # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" + # (for example, to adjust the learning rate) + def _get_param_groups(self): + return self.optimizer.param_groups + + def _set_param_groups(self, value): + self.optimizer.param_groups = value + + param_groups = property(_get_param_groups, _set_param_groups) + + # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" + def _get_loss_scale(self): + return self.loss_scaler.loss_scale + + def _set_loss_scale(self, value): + self.loss_scaler.cur_scale = value + + loss_scale = property(_get_loss_scale, _set_loss_scale) + cur_scale = property(_get_loss_scale, _set_loss_scale) + + def _get_lean_tensors(self, padded_flattened_tensor, group_tensors, paddings): + # Remove paddings from flattened tensor + individual_tensors = self.unflatten(padded_flattened_tensor, group_tensors) + lean_lengths = [t.numel() - pad for t, pad in zip(group_tensors, paddings)] + lean_tensors = [t[:len] for t, len in zip(individual_tensors, lean_lengths)] + #logger.info(f'rank {dist.get_rank()}: lean_tensors = {[t.numel() for t in lean_tensors]}') + return lean_tensors + + #TODO REVISIT this for stage 3 + def get_lean_optimizer_state(self): + # Return optimizer states after removing paddings. + # This method assumes that each param group contains a single flattened tensor. + optimizer_groups_state = [] + + for i, group in enumerate(self.optimizer.param_groups): + p = group['params'][0] + lean_state = {} + for key, value in self.optimizer.state[p].items(): + if torch.is_tensor(value): + padded_lens = [t.numel() for t in self.fp16_partitioned_groups[i]] + lean_state[key] = self._get_lean_tensors( + value, + self.fp16_partitioned_groups[i], + self.groups_padding[i]) + lean_flat_len = sum([t.numel() for t in lean_state[key]]) + else: + lean_state[key] = value + + optimizer_groups_state.append(lean_state) + + return optimizer_groups_state + + def get_groups_without_padding(self, groups_with_padding): + # Return group tensor after removing paddings added for alignment to DP world size. + groups_without_padding = [] + for i, group in enumerate(groups_with_padding): + lean_group = self._get_lean_tensors(group, + self.fp16_partitioned_groups[i], + self.groups_padding[i]) + groups_without_padding.append(lean_group) + + return groups_without_padding + + def _set_fp32_optimizer_param_groups(self): + for sub_group_id, _ in enumerate(self.fp16_groups): + param_group_id = self.sub_group_to_group_id[sub_group_id] + self.optimizer.param_groups[param_group_id]['params'].append( + self.fp32_partitioned_groups_flat[sub_group_id]) + + def _clear_fp32_optimizer_param_groups(self): + for param_group in self.optimizer.param_groups: + param_group['params'] = [] + + def _rigid_state_dict(self): + state_dict = {} + state_dict['zero_stage'] = ZERO_OPTIMIZATION_WEIGHTS + state_dict['loss_scaler'] = self.loss_scaler + state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale + state_dict['overflow'] = self.overflow + state_dict['partition_count'] = self.partition_count + + self._set_fp32_optimizer_param_groups() + state_dict['optimizer_state_dict'] = self.optimizer.state_dict() + state_dict['fp32_flat_groups'] = self.fp32_partitioned_groups_flat + self._clear_fp32_optimizer_param_groups() + + return state_dict + + def state_dict(self): + """ + Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. + This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict + of the contained Pytorch optimizer. + Example:: + checkpoint = {} + checkpoint['model'] = model.state_dict() + checkpoint['optimizer'] = optimizer.state_dict() + torch.save(checkpoint, "saved.pth") + """ + if self.elastic_checkpoint: + raise NotImplementedError( + "ZeRO-3 does not yet support elastic checkpointing, please disable for now." + ) + + if self.swap_optimizer or self.params_in_nvme_and_cpu: + raise NotImplementedError( + "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now." + ) + + return self._rigid_state_dict() + + +# Restore base optimizer fp32 weights from checkpoint by: +# 1) Merging fp32 weights from checkpoints of all partitions +# 2) Extracting fp32 weights for current partition from merged weights +# 3) Using extracted weights to update base optimizer weights directly. + + def _restore_from_fp32_weights(self, all_state_dict): + + flat_local_partition = [] + for i in range(len(self.fp32_partitioned_groups_flat)): + merged_partitions = [sd['fp32_groups'][i] for sd in all_state_dict] + flat_local_partition.append(self._get_flattened_partition(merged_partitions)) + + for current, saved in zip(self.fp32_partitioned_groups_flat, flat_local_partition): + current.data.copy_(saved.data) + + # Restore base optimizer fp32 weights from ZeRO fp16 weights + def _restore_from_fp16_weights(self): + for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat): + fp32_partition.data.copy_(fp16_partitions.data) + + # Refresh the fp32 master params from the fp16 copies. + def refresh_fp32_params(self): + self._restore_from_fp16_weights() + + # Extract flattened partition for current rank from all partitions + def _get_flattened_partition(self, all_partition_states): + partition_id = dist.get_rank(group=self.dp_process_group) + alignment = dist.get_world_size(group=self.dp_process_group) + + param_partitions = [[] for _ in range(len(all_partition_states[0]))] + for i, partition in enumerate(all_partition_states): + for j, param in enumerate(partition): + param_partitions[j].append(param) + + local_state_partitions = [] + for param_index, param_slices in enumerate(param_partitions): + flattened_merged_tensor = self.flatten_dense_tensors_aligned( + param_slices, + alignment) + new_partitions = self.get_data_parallel_partitions(flattened_merged_tensor) + local_state_partitions.append(new_partitions[partition_id]) + + if torch.is_tensor(local_state_partitions[0]): + return self.flatten_dense_tensors_aligned(local_state_partitions, alignment) + + # Assume non-tensor states are not partitioned and equal across ranks, so return first one + return local_state_partitions[0] + + # Restore base optimizer state from checkpoint by + # 1) Merging optimizer state from checkpoints of all partitions + # 2) Extracting optimizer state for current partition from the merged state + # 3) Using the extracted value to directly update the base optimizer. + def _restore_base_optimizer_state(self, all_state_dict): + base_optimizer_group_states = [] + for i in range(len(self.optimizer.param_groups)): + partition_states = {} + all_partition_group_states = [ + sd['base_optimizer_state'][i] for sd in all_state_dict + ] + for key in all_partition_group_states[0].keys(): + all_partition_states = [ + all_states[key] for all_states in all_partition_group_states + ] + partition_states[key] = self._get_flattened_partition( + all_partition_states) + base_optimizer_group_states.append(partition_states) + + for i, group in enumerate(self.optimizer.param_groups): + p = group['params'][0] + for key, saved in base_optimizer_group_states[i].items(): + if torch.is_tensor(self.optimizer.state[p][key]): + self.optimizer.state[p][key].data.copy_(saved.data) + else: + self.optimizer.state[p][key] = saved + + def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True): + # I think it should actually be ok to reload the optimizer before the model. + self.loss_scaler = state_dict['loss_scaler'] + self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] + self.overflow = state_dict['overflow'] + + if load_optimizer_states: + self._set_fp32_optimizer_param_groups() + self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) + self._clear_fp32_optimizer_param_groups() + + # restore fp32 partitions + for curr_param, saved_param in zip(self.fp32_partitioned_groups_flat, state_dict['fp32_flat_groups']): + curr_param.data.copy_(saved_param.data) + + # restore fp16 partitions from fp32 + for sub_group_id in range(len(self.fp32_partitioned_groups_flat)): + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] + fp16_param.data.copy_(fp32_param.data) + + # update fp16 unflattened params + for sub_group_id in range(len(self.fp16_partitioned_groups_flat)): + updated_params = self.unflatten( + self.fp16_partitioned_groups_flat[sub_group_id], + self.fp16_partitioned_groups[sub_group_id]) + + for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): + partitioned_param.data = q.data + + # TODO: Support different/changing load/save DP degree. + def load_state_dict(self, + state_dict_list, + load_optimizer_states=True, + load_from_fp32_weights=False): + r"""Loading a ZeRO checkpoint + Arguments: + state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition. + Note that the number of saved partitions may differ from number of loading partitions to support + changing GPU count, specifically DP world size, between saving and loading checkpoints. + load_optimizer_states: Boolean indicating whether or not to load base optimizer states + load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32 + copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss). + """ + """ + Loads a state_dict created by an earlier call to state_dict(). + If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, + whose parameters in turn came from ``model``, it is expected that the user + will call ``model.load_state_dict()`` before + ``fp16_optimizer_instance.load_state_dict()`` is called. + Example:: + model = torch.nn.Linear(D_in, D_out).cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) + ... + checkpoint = torch.load("saved.pth") + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + """ + + if self.elastic_checkpoint: + raise NotImplementedError( + "ZeRO-3 does not yet support elastic checkpointing, please disable for now." + ) + + if self.swap_optimizer or self.params_in_nvme_and_cpu: + raise NotImplementedError( + "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now." + ) + + self._rigid_load_state_dict( + state_dict_list[dist.get_rank(group=self.dp_process_group)], + load_optimizer_states=load_optimizer_states) + + if len(self.persistent_parameters) > 0: + self.persistent_parameters[0].partition(self.persistent_parameters) + self.persistent_parameters[0].all_gather(self.persistent_parameters) + + def save_checkpoint_prologue(self): + self._partition_all_parameters() + + def save_checkpoint_epilogue(self): + if len(self.persistent_parameters) > 0: + self.persistent_parameters[0].all_gather(self.persistent_parameters) + + +def _handle_overflow(cpu_sum, x, i): + import math + rank = torch.distributed.get_rank() + if rank == 0: + t_i = -1 + for v_i, v in enumerate(x.data.contiguous().view(-1)): + if not math.isfinite(float(v)): + t_i = v_i + break + logger.info( + f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}" + ) + + +def estimate_zero3_model_states_mem_needs(total_params, + largest_layer_params, + num_gpus_per_node=1, + num_nodes=1, + cpu_offload=True, + cpu_offload_params=True, + zero_init=True, + additional_buffer_factor=1.5): + + total_gpus = num_nodes * num_gpus_per_node + gpus_factor = 1 / num_nodes + largest_layer_memory = (4 * largest_layer_params) + + if cpu_offload: + if cpu_offload_params: + gpu_mem = largest_layer_memory + + if zero_init: + cpu_mem = total_params * 18 * gpus_factor * additional_buffer_factor + else: + cpu_mem = total_params * max(4 * num_gpus_per_node, + 18 * gpus_factor) * additional_buffer_factor + + else: + gpu_mem = largest_layer_memory + int(2 * total_params / total_gpus) + + if zero_init: + cpu_mem = total_params * 16 * gpus_factor * additional_buffer_factor + else: + cpu_mem = total_params * max(4 * num_gpus_per_node, + 16 * gpus_factor) * additional_buffer_factor + else: + gpu_mem = largest_layer_memory + int(18 * total_params / total_gpus) + if zero_init: + cpu_mem = largest_layer_params * 4 * num_gpus_per_node * additional_buffer_factor + else: + cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor + + return int(cpu_mem), int(gpu_mem), largest_layer_memory + + +def model_to_params(model): + # shared params calculated only once + total_params = sum( + dict((p.data_ptr(), + p.numel()) for p in model.parameters()).values()) + + largest_layer_params = 0 + for m in model.modules(): + # assuming no shared params within a single layer + layer_params = sum(p.numel() for p in m.parameters(recurse=False)) + largest_layer_params = max(largest_layer_params, layer_params) + + return total_params, largest_layer_params + + +import math + + +def estimate_zero3_model_states_mem_needs_all_live(model, + num_gpus_per_node=1, + num_nodes=1, + additional_buffer_factor=1.5): + """ + Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients + for a given ``model`` and hardware setup. + + If you have an actual model object, use this function and everything will be derived + automatically. + + If it's a hypothetical model, use ``estimate_zero3_model_states_mem_needs_all_cold`` where you have to pass + the ``total_params`` and ``largest_layer_params`` explicitly. + + Args: + - ``model``: ``nn.Module`` object + - ``num_gpus_per_node``: how many gpus per node (defaults to 1) + - ``num_nodes``: how many nodes (defaults to 1), + - ``additional_buffer_factor``: estimation factor (defaults to 1.5): + + """ + + total_params, largest_layer_params = model_to_params(model) + + estimate_zero3_model_states_mem_needs_all_cold( + total_params=total_params, + largest_layer_params=largest_layer_params, + num_gpus_per_node=num_gpus_per_node, + num_nodes=num_nodes, + additional_buffer_factor=additional_buffer_factor) + + +def estimate_zero3_model_states_mem_needs_all_cold(total_params, + largest_layer_params, + num_gpus_per_node=1, + num_nodes=1, + additional_buffer_factor=1.5): + """ + Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients + for a given ``model`` and hardware setup. + + If it's a hypothetical model, use this function where you have to pass + the ``total_params`` and ``largest_layer_params`` explicitly. + + If you have an actual model object, use ``estimate_zero3_model_states_mem_needs_all_live`` and everything + will be derived automatically. + + Args: + - ``total_params``: total model params + - ``largest_layer_params``: largest layer's params + - ``num_gpus_per_node``: how many gpus per node (defaults to 1) + - ``num_nodes``: how many nodes (defaults to 1), + - ``additional_buffer_factor``: estimation factor (defaults to 1.5): + + """ + def format_options(cpu_offload, cpu_offload_params, zero_init): + enabled = [] + padded_cpu_str = f'{OFFLOAD_CPU_DEVICE:4}' + param_device = padded_cpu_str if cpu_offload_params else "none" + enabled.append(f"{OFFLOAD_PARAM}={param_device}") + optimizer_device = padded_cpu_str if cpu_offload else "none" + enabled.append(f"{OFFLOAD_OPTIMIZER}={optimizer_device}") + enabled.append(f"zero_init={1 if zero_init else 0}") + return ", ".join(enabled) + + nodes_str = "nodes" if num_nodes > 1 else "node" + gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU" + print( + "Estimated memory needed for params, optim states and gradients for a:\n" + f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n" + f"SW: Model with {int(total_params/1e6)}M total params, {int(largest_layer_params/1e6)}M largest layer params." + ) + print(" per CPU | per GPU | Options") + for cpu_offload in [True, False]: + for cpu_offload_params in [True, False]: + if not cpu_offload and cpu_offload_params: + continue + for zero_init in [True, False]: + cpu_mem, gpu_mem, largest_layer_memory = estimate_zero3_model_states_mem_needs( + total_params=total_params, + largest_layer_params=largest_layer_params, + num_gpus_per_node=num_gpus_per_node, + num_nodes=num_nodes, + cpu_offload=cpu_offload, + cpu_offload_params=cpu_offload_params, + zero_init=zero_init, + additional_buffer_factor=additional_buffer_factor + ) + + options_str = format_options(cpu_offload=cpu_offload, + cpu_offload_params=cpu_offload_params, + zero_init=zero_init) + print( + f" {cpu_mem/2**30:7.2f}GB | {gpu_mem/2**30:6.2f}GB | {options_str}") diff --git a/docs/README.md b/docs/README.md index 0ac7783f3..4b80f6bd4 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,49 +1,49 @@ -# DeepSpeed Documentation - -This directory includes the source code for the website and documentation of DeepSpeed. The `code-docs/` directory is used to build [deepspeed.readthedocs.io](https://deepspeed.readthedocs.io/en/latest/). - -[deepspeed.ai](https://www.deepspeed.ai/) is the recommended way to read all DeepSpeed documentation. Directly viewing the Markdown files in this directory will not include images and other features. - -## Building the documentation locally -You can serve the DeepSpeed website locally. This is especially useful for development. - -### Prerequisites -The DeepSpeed website relies on [Jekyll](https://jekyllrb.com/). There are several [guides for installation](https://jekyllrb.com/docs/installation/). The instructions below assume you are in an Ubuntu environment and have been tested on WSL. - -First ensure that you have the necessary packages (e.g., `make` and `zlib`). -``` -sudo apt-get install build-essential zlib1g-dev ruby-full -``` - -Add these lines to your `.bashrc` or equivalent to ensure you have permissions to install Ruby packages without `sudo`. -``` -export GEM_HOME="$HOME/gems" -export PATH="$HOME/gems/bin:$PATH" -``` -Don't forget to `source ~/.bashrc` afterwards 😊. - - -Now we can install Jekyll and [Bundler](https://bundler.io/): -``` -gem install jekyll bundler -``` - -### Start a local webserver -We now need to install the required Ruby packages for the website. - -**NOTE**: you should change to this folder (i.e., docs) before running the installation command to avoid this [error](https://stackoverflow.com/questions/10012181/bundle-install-returns-could-not-locate-gemfile/35157872): - -> Could not locate Gemfile - -**NOTE**: this step frequently hangs when connected to a VPN (including MSVPN). Simply disconnect for the package installation. - - -``` -bundle install -``` - -You can now start a local webserver via: -``` -bundle exec jekyll serve -``` -The website should now be accessible at [http://localhost:4000](http://localhost:4000) +# DeepSpeed Documentation + +This directory includes the source code for the website and documentation of DeepSpeed. The `code-docs/` directory is used to build [deepspeed.readthedocs.io](https://deepspeed.readthedocs.io/en/latest/). + +[deepspeed.ai](https://www.deepspeed.ai/) is the recommended way to read all DeepSpeed documentation. Directly viewing the Markdown files in this directory will not include images and other features. + +## Building the documentation locally +You can serve the DeepSpeed website locally. This is especially useful for development. + +### Prerequisites +The DeepSpeed website relies on [Jekyll](https://jekyllrb.com/). There are several [guides for installation](https://jekyllrb.com/docs/installation/). The instructions below assume you are in an Ubuntu environment and have been tested on WSL. + +First ensure that you have the necessary packages (e.g., `make` and `zlib`). +``` +sudo apt-get install build-essential zlib1g-dev ruby-full +``` + +Add these lines to your `.bashrc` or equivalent to ensure you have permissions to install Ruby packages without `sudo`. +``` +export GEM_HOME="$HOME/gems" +export PATH="$HOME/gems/bin:$PATH" +``` +Don't forget to `source ~/.bashrc` afterwards 😊. + + +Now we can install Jekyll and [Bundler](https://bundler.io/): +``` +gem install jekyll bundler +``` + +### Start a local webserver +We now need to install the required Ruby packages for the website. + +**NOTE**: you should change to this folder (i.e., docs) before running the installation command to avoid this [error](https://stackoverflow.com/questions/10012181/bundle-install-returns-could-not-locate-gemfile/35157872): + +> Could not locate Gemfile + +**NOTE**: this step frequently hangs when connected to a VPN (including MSVPN). Simply disconnect for the package installation. + + +``` +bundle install +``` + +You can now start a local webserver via: +``` +bundle exec jekyll serve +``` +The website should now be accessible at [http://localhost:4000](http://localhost:4000) diff --git a/docs/_posts/2021-03-08-zero3-offload.md b/docs/_posts/2021-03-08-zero3-offload.md index fa12ab5b2..3fba666ea 100644 --- a/docs/_posts/2021-03-08-zero3-offload.md +++ b/docs/_posts/2021-03-08-zero3-offload.md @@ -1,100 +1,100 @@ ---- -layout: single -title: "DeepSpeed ZeRO-3 Offload" -excerpt: "" -categories: news -new_post: true -date: 2021-03-08 00:00:00 ---- -Today we are announcing the release of ZeRO-3 Offload, a highly efficient and easy to use implementation of ZeRO Stage 3 and ZeRO Offload combined, geared towards our continued goal of democratizing AI by making efficient large-scale DL training available to everyone. The key benefits of ZeRO-3 Offload are: - -* Unprecedented memory efficiency to run very large models on a limited number of GPU resources - e.g., fine-tune models with over 40B parameters on a single GPU and over 2 Trillion parameters on 512 GPUs! -* Extremely Easy to use: - * Scale to over a trillion parameters without the need to combine multiple parallelism techniques in complicated ways. - * For existing DeepSpeed users, turn on ZeRO-3 Offload with just a few flags in DeepSpeed Config file. -* High-performance per-GPU throughput and super-linear scalability across GPUs for distributed training. - * With 1 Trillion parameters, ZeRO-3 Offload sustains 25 PetaFlops in compute performance on 512 NVIDIA V100 GPUs, achieving 49 TFlops/GPU. - * Up to 2x improvement in throughput compared to ZeRO- 2 Offload on single GPU - - -

Overview of ZeRO family of technology

- -The Zero Redundancy Optimizer (abbreviated ZeRO) is a family of memory optimization technologies for large-scale distributed deep learning. Unlike data parallelism (that is efficient but can only support a limited model size) or model parallelism (that can support larger model sizes but requires significant code refactoring while adding communication overhead that limits efficiency), ZeRO allows fitting larger models in memory without requiring code refactoring while remaining very efficient. ZeRO does so by eliminating the memory redundancy that is inherent in data parallelism while limiting the communication overhead to a minimum. -ZeRO removes the memory redundancies across data-parallel processes by partitioning the three model states (optimizer states, gradients, and parameters) across data-parallel processes instead of replicating them. By doing this, it boosts memory efficiency compared to classic data-parallelism while retaining its computational granularity and communication efficiency. -There are three stages in ZeRO corresponding to three model states, as shown in the Figure 1: the first stage (ZeRO-1) partitions only the optimizer states, the second stage (ZeRO-2) partitions both the optimizer states and the gradients and the final stage (ZeRO-3) partitions all three model states (for more details see the ZeRO [paper](https://arxiv.org/abs/1910.02054v3)). - - - - -Figure 1. Overview of ZeRO memory savings - -In addition to these three stages, ZeRO family of technology also consists of ZeRO-2 Offload. ZeRO-2 Offload is a heterogenous DL training technology that works in conjunction with ZeRO-2 to offload partitioned optimizer states and gradients to CPU memory. ZeRO-2 Offload offers the full memory advantage of ZeRO-2 even on a single GPU, while at the same time offering great scalability of ZeRO-2 on multi-GPU setup. DeepSpeed library has been offering ZeRO-2 Offload since Sept 2020. For details, please see below: - -* ZeRO: [Stage 1 blog](https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/), [Stage 2 blog](https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/), [Tutorial](/tutorials/zero) -* ZeRO-Offload: [Blog](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/#toc-heading-3), [Tutorials](/tutorials/zero-offload), [Paper link](https://arxiv.org/abs/2101.06840) - -

ZeRO-3 Offload

-With today’s release of ZeRO-3 Offload, we are adding support for partitioning and offloading parameters in addition to optimizer states and gradients partitioning already supported by ZeRO-2 Offload in DeepSpeed. With parameter partitioning ZeRO-3 Offload implements the full set of features in the three stages of ZeRO, that allows for a linear growth in model size with the number of GPUs. In addition, ZeRO-3 Offload can also optionally offload all these model states to CPU to further reduce GPU memory consumption, leveraging both CPU and GPU to maximize memory and compute efficiency of the entire system. - -We believe ZeRO-3 Offload offers a massive leap for large model training, in three regards: - -i) Unprecedented model scale, - -ii) Ease of supporting very-large models, and - -iii) Achieving excellent training efficiency. - - -

Unprecedented model scale

-Unlike ZeRO-2 and ZeRO-Offload where the parameters have to fit in the memory of a single GPU, ZeRO-3 Offload can partition the parameters across GPUs, and offload them to CPU, supporting model sizes that are much larger than the memory on a single GPU. Furthermore, ZeRO-3 Offload goes beyond the state-of-the-art hybrid 3D-parallelism (data, model and pipeline parallelism combined). While 3D Parallelism is limited by the aggregate GPU memory, ZeRO-3 Offload can exploit both GPU and CPU memory, the latter of which is much larger and cheaper compared to GPU memory. This allows ZeRO-3 Offload to train larger model sizes with the given GPU and CPU resources than any other currently available technology. - -Model Scale on Single GPU: ZeRO-3 Offload can train models with over 40B parameters efficiently on a single GPU (e.g., 32GB V100 GPU + 1.5TB CPU memory). This is 3x larger than what is possible with ZeRO-2 Offload, the current state-of-the art. - -Model Scale on Multi-GPUs: With ZeRO-3 Offload you can train a trillion and two trillion parameter models on NVIDIA 32GB V100 DGX-2 cluster with 256 GPUs and 512 GPUs, respectively. In contrast, the state-of-art 3D Parallelism requires 800 GPUs, and 1600 GPUs, respectively, to fit the same sized models. This represents a 3x reduction in GPUs required to fit models with over a trillion parameters. - -

Ease of supporting very large models

-From a system perspective, training models with hundreds of billions and trillions of parameters is extremely challenging. Data parallelism cannot scale the model size much further beyond a billion parameters, model parallelism (with tensor slicing) cannot be used to scale model size efficiently beyond a single node boundary due to massive communication overheads, and pipeline parallelism cannot scale beyond the number of layers available in a model, which limits both the model size and the number of GPUs that it can scale to. - -The only existing parallel technology available that can scale to over a trillion parameters on massively parallel GPU clusters is the [3D parallelism](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/#toc-heading-0) that combines data, model and pipeline parallelism in complex ways. While such a system can be very efficient, it requires major model code refactoring from data scientists to split the model into load balanced pipeline stages. This also makes 3D parallelism inflexible in the type of models that it can support, since models with complex dependency graphs cannot be easily converted into a load balanced pipeline. - -ZeRO-3 Offload address these challenges in two ways: - -i) With ground-breaking memory efficiency, ZeRO-3 and ZeRO-3 Offload are the only DL parallel technology that can efficiently scale to over a trillion parameters by itself, without requiring a hybrid parallelism strategy, greatly simplifying the system stack for DL training. - -ii) ZeRO-3 Offload requires virtually no model refactoring from model scientists, liberating data scientists to scale up complex models to hundreds of billions to trillions of parameters. - -

Excellent training efficiency

-High-performance per-GPU throughput on multiple nodes: ZeRO-3 Offload offers excellent training efficiency for multi-billion and trillion parameter models on multiple nodes. It achieves a sustained throughput of up to 50 Tflops per GPU running on 32 DGX2 nodes comprising 512 NVIDIA V100 GPUs (see Figure 2). In comparison, the standard data parallel training with PyTorch can only achieve 30 TFlops per GPU for a 1.2B parameter model, the largest model that can be trained using data parallelism alone. - - - - -Figure 2. ZeRO-3 Offload: Multi-billion and trillion parameter model throughput on 512 V100 GPUs - -ZeRO-3 Offload obtains high efficiency despite the 50% communication overhead of ZeRO Stage 3 compared to standard data parallel training for a fixed batch size. This is made possible through a communication overlap centric design and implementation, which allows ZeRO-3 Offload to hide nearly all of the communication volume with computation, while taking advantage of a larger batch size for improved efficiency resulting from better GPU memory efficiency. - - -Efficient multi-billion parameter model training on a single GPU: ZeRO-3 Offload further democratizes AI by enabling efficient training of multi-billion parameter models on a single GPU. For single GPU training, ZeRO-3 Offload provides benefits over ZeRO-2 Offload along two dimensions. First, ZeRO-3 Offload increases the size of models trainable on a single V100 from 13B to 40B. Second, for ZeRO-3 Offload provides speedups (e.g., 2.3X for 13B) compared to ZeRO-2 Offload for model sizes trainable by both solutions. These results are summarized in Figure 3. - - - - -Figure 3. Multi-billion parameter model training on one V100 GPU - -Super-Linear scalability across GPUs: Additionally, ZeRO-3 Offload also preserves the super-linear scalability characteristics that we have demonstrated with all our previous ZeRO technologies (ZeRO Stage 1, ZeRO Stage 2 and ZeRO Offload). ZeRO-3 Offload can exploit the aggregate PCI-E bandwidth between GPU and CPU across all the GPUs in multi-GPU training configuration, and at the same time, it can also exploit the aggregate CPU compute across all the nodes. As a result, the CPU-GPU-CPU communication time as well as the optimizer update time decreases linearly with number of GPUs and nodes, respectively, allowing ZeRO-3 Offload to exhibit super-linear scaling (see Figure 4). - - - - -Figure 4. ZeRO-3 Offload Superlinear Scalability for a 200B parameter model. - -

How to use ZeRO-3 Offload

-As with many other existing DeepSpeed features, once the user model has been converted to use DeepSpeed, enabling ZeRO-3 Offload is as easy as turning on a couple of flags in DeepSpeed Config file. Supporting advanced features like weight sharing, or enabling extremely large models that requires to be partitioned across GPUs/nodes to fit in GPU/CPU memory, can be done with just a couple of additional lines of code change using the ZeRO-3 Offload API. - -If you are already a DeepSpeed user, you can find our detailed tutorial on ZeRO-3 Offload below. If you are new to DeepSpeed, we recommend that you start at the getting started page before trying out our ZeRO-3 Offload Tutorial. - -* DeepSpeed: [Getting Started Page](/getting-started/) - -* ZeRO-3 Offload [Documentation](https://deepspeed.readthedocs.io/en/latest/zero3.html), [Tutorial](/tutorials/zero/#training-trillion-scale-models-with-zero-3-offload) - -The DeepSpeed Team is very excited to share ZeRO-3 Offload with the DL community. +--- +layout: single +title: "DeepSpeed ZeRO-3 Offload" +excerpt: "" +categories: news +new_post: true +date: 2021-03-08 00:00:00 +--- +Today we are announcing the release of ZeRO-3 Offload, a highly efficient and easy to use implementation of ZeRO Stage 3 and ZeRO Offload combined, geared towards our continued goal of democratizing AI by making efficient large-scale DL training available to everyone. The key benefits of ZeRO-3 Offload are: + +* Unprecedented memory efficiency to run very large models on a limited number of GPU resources - e.g., fine-tune models with over 40B parameters on a single GPU and over 2 Trillion parameters on 512 GPUs! +* Extremely Easy to use: + * Scale to over a trillion parameters without the need to combine multiple parallelism techniques in complicated ways. + * For existing DeepSpeed users, turn on ZeRO-3 Offload with just a few flags in DeepSpeed Config file. +* High-performance per-GPU throughput and super-linear scalability across GPUs for distributed training. + * With 1 Trillion parameters, ZeRO-3 Offload sustains 25 PetaFlops in compute performance on 512 NVIDIA V100 GPUs, achieving 49 TFlops/GPU. + * Up to 2x improvement in throughput compared to ZeRO- 2 Offload on single GPU + + +

Overview of ZeRO family of technology

+ +The Zero Redundancy Optimizer (abbreviated ZeRO) is a family of memory optimization technologies for large-scale distributed deep learning. Unlike data parallelism (that is efficient but can only support a limited model size) or model parallelism (that can support larger model sizes but requires significant code refactoring while adding communication overhead that limits efficiency), ZeRO allows fitting larger models in memory without requiring code refactoring while remaining very efficient. ZeRO does so by eliminating the memory redundancy that is inherent in data parallelism while limiting the communication overhead to a minimum. +ZeRO removes the memory redundancies across data-parallel processes by partitioning the three model states (optimizer states, gradients, and parameters) across data-parallel processes instead of replicating them. By doing this, it boosts memory efficiency compared to classic data-parallelism while retaining its computational granularity and communication efficiency. +There are three stages in ZeRO corresponding to three model states, as shown in the Figure 1: the first stage (ZeRO-1) partitions only the optimizer states, the second stage (ZeRO-2) partitions both the optimizer states and the gradients and the final stage (ZeRO-3) partitions all three model states (for more details see the ZeRO [paper](https://arxiv.org/abs/1910.02054v3)). + + + + +Figure 1. Overview of ZeRO memory savings + +In addition to these three stages, ZeRO family of technology also consists of ZeRO-2 Offload. ZeRO-2 Offload is a heterogenous DL training technology that works in conjunction with ZeRO-2 to offload partitioned optimizer states and gradients to CPU memory. ZeRO-2 Offload offers the full memory advantage of ZeRO-2 even on a single GPU, while at the same time offering great scalability of ZeRO-2 on multi-GPU setup. DeepSpeed library has been offering ZeRO-2 Offload since Sept 2020. For details, please see below: + +* ZeRO: [Stage 1 blog](https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/), [Stage 2 blog](https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/), [Tutorial](/tutorials/zero) +* ZeRO-Offload: [Blog](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/#toc-heading-3), [Tutorials](/tutorials/zero-offload), [Paper link](https://arxiv.org/abs/2101.06840) + +

ZeRO-3 Offload

+With today’s release of ZeRO-3 Offload, we are adding support for partitioning and offloading parameters in addition to optimizer states and gradients partitioning already supported by ZeRO-2 Offload in DeepSpeed. With parameter partitioning ZeRO-3 Offload implements the full set of features in the three stages of ZeRO, that allows for a linear growth in model size with the number of GPUs. In addition, ZeRO-3 Offload can also optionally offload all these model states to CPU to further reduce GPU memory consumption, leveraging both CPU and GPU to maximize memory and compute efficiency of the entire system. + +We believe ZeRO-3 Offload offers a massive leap for large model training, in three regards: + +i) Unprecedented model scale, + +ii) Ease of supporting very-large models, and + +iii) Achieving excellent training efficiency. + + +

Unprecedented model scale

+Unlike ZeRO-2 and ZeRO-Offload where the parameters have to fit in the memory of a single GPU, ZeRO-3 Offload can partition the parameters across GPUs, and offload them to CPU, supporting model sizes that are much larger than the memory on a single GPU. Furthermore, ZeRO-3 Offload goes beyond the state-of-the-art hybrid 3D-parallelism (data, model and pipeline parallelism combined). While 3D Parallelism is limited by the aggregate GPU memory, ZeRO-3 Offload can exploit both GPU and CPU memory, the latter of which is much larger and cheaper compared to GPU memory. This allows ZeRO-3 Offload to train larger model sizes with the given GPU and CPU resources than any other currently available technology. + +Model Scale on Single GPU: ZeRO-3 Offload can train models with over 40B parameters efficiently on a single GPU (e.g., 32GB V100 GPU + 1.5TB CPU memory). This is 3x larger than what is possible with ZeRO-2 Offload, the current state-of-the art. + +Model Scale on Multi-GPUs: With ZeRO-3 Offload you can train a trillion and two trillion parameter models on NVIDIA 32GB V100 DGX-2 cluster with 256 GPUs and 512 GPUs, respectively. In contrast, the state-of-art 3D Parallelism requires 800 GPUs, and 1600 GPUs, respectively, to fit the same sized models. This represents a 3x reduction in GPUs required to fit models with over a trillion parameters. + +

Ease of supporting very large models

+From a system perspective, training models with hundreds of billions and trillions of parameters is extremely challenging. Data parallelism cannot scale the model size much further beyond a billion parameters, model parallelism (with tensor slicing) cannot be used to scale model size efficiently beyond a single node boundary due to massive communication overheads, and pipeline parallelism cannot scale beyond the number of layers available in a model, which limits both the model size and the number of GPUs that it can scale to. + +The only existing parallel technology available that can scale to over a trillion parameters on massively parallel GPU clusters is the [3D parallelism](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/#toc-heading-0) that combines data, model and pipeline parallelism in complex ways. While such a system can be very efficient, it requires major model code refactoring from data scientists to split the model into load balanced pipeline stages. This also makes 3D parallelism inflexible in the type of models that it can support, since models with complex dependency graphs cannot be easily converted into a load balanced pipeline. + +ZeRO-3 Offload address these challenges in two ways: + +i) With ground-breaking memory efficiency, ZeRO-3 and ZeRO-3 Offload are the only DL parallel technology that can efficiently scale to over a trillion parameters by itself, without requiring a hybrid parallelism strategy, greatly simplifying the system stack for DL training. + +ii) ZeRO-3 Offload requires virtually no model refactoring from model scientists, liberating data scientists to scale up complex models to hundreds of billions to trillions of parameters. + +

Excellent training efficiency

+High-performance per-GPU throughput on multiple nodes: ZeRO-3 Offload offers excellent training efficiency for multi-billion and trillion parameter models on multiple nodes. It achieves a sustained throughput of up to 50 Tflops per GPU running on 32 DGX2 nodes comprising 512 NVIDIA V100 GPUs (see Figure 2). In comparison, the standard data parallel training with PyTorch can only achieve 30 TFlops per GPU for a 1.2B parameter model, the largest model that can be trained using data parallelism alone. + + + + +Figure 2. ZeRO-3 Offload: Multi-billion and trillion parameter model throughput on 512 V100 GPUs + +ZeRO-3 Offload obtains high efficiency despite the 50% communication overhead of ZeRO Stage 3 compared to standard data parallel training for a fixed batch size. This is made possible through a communication overlap centric design and implementation, which allows ZeRO-3 Offload to hide nearly all of the communication volume with computation, while taking advantage of a larger batch size for improved efficiency resulting from better GPU memory efficiency. + + +Efficient multi-billion parameter model training on a single GPU: ZeRO-3 Offload further democratizes AI by enabling efficient training of multi-billion parameter models on a single GPU. For single GPU training, ZeRO-3 Offload provides benefits over ZeRO-2 Offload along two dimensions. First, ZeRO-3 Offload increases the size of models trainable on a single V100 from 13B to 40B. Second, for ZeRO-3 Offload provides speedups (e.g., 2.3X for 13B) compared to ZeRO-2 Offload for model sizes trainable by both solutions. These results are summarized in Figure 3. + + + + +Figure 3. Multi-billion parameter model training on one V100 GPU + +Super-Linear scalability across GPUs: Additionally, ZeRO-3 Offload also preserves the super-linear scalability characteristics that we have demonstrated with all our previous ZeRO technologies (ZeRO Stage 1, ZeRO Stage 2 and ZeRO Offload). ZeRO-3 Offload can exploit the aggregate PCI-E bandwidth between GPU and CPU across all the GPUs in multi-GPU training configuration, and at the same time, it can also exploit the aggregate CPU compute across all the nodes. As a result, the CPU-GPU-CPU communication time as well as the optimizer update time decreases linearly with number of GPUs and nodes, respectively, allowing ZeRO-3 Offload to exhibit super-linear scaling (see Figure 4). + + + + +Figure 4. ZeRO-3 Offload Superlinear Scalability for a 200B parameter model. + +

How to use ZeRO-3 Offload

+As with many other existing DeepSpeed features, once the user model has been converted to use DeepSpeed, enabling ZeRO-3 Offload is as easy as turning on a couple of flags in DeepSpeed Config file. Supporting advanced features like weight sharing, or enabling extremely large models that requires to be partitioned across GPUs/nodes to fit in GPU/CPU memory, can be done with just a couple of additional lines of code change using the ZeRO-3 Offload API. + +If you are already a DeepSpeed user, you can find our detailed tutorial on ZeRO-3 Offload below. If you are new to DeepSpeed, we recommend that you start at the getting started page before trying out our ZeRO-3 Offload Tutorial. + +* DeepSpeed: [Getting Started Page](/getting-started/) + +* ZeRO-3 Offload [Documentation](https://deepspeed.readthedocs.io/en/latest/zero3.html), [Tutorial](/tutorials/zero/#training-trillion-scale-models-with-zero-3-offload) + +The DeepSpeed Team is very excited to share ZeRO-3 Offload with the DL community. diff --git a/docs/_posts/2021-05-05-inference-kernel-optimization.md b/docs/_posts/2021-05-05-inference-kernel-optimization.md index 18ab7c321..9b9e747a2 100644 --- a/docs/_posts/2021-05-05-inference-kernel-optimization.md +++ b/docs/_posts/2021-05-05-inference-kernel-optimization.md @@ -1,73 +1,73 @@ ---- -layout: single -title: "DeepSpeed Inference: Multi-GPU inference with customized inference kernels and quantization support" -excerpt: "" -categories: news -new_post: false -date: 2021-03-16 00:00:00 ---- -While DeepSpeed supports training advanced large-scale models, using these trained models in the desired application scenarios is still challenging due to three major limitations in existing inference solutions: 1) lack of support for multi-GPU inference to fit large models and meet latency requirements, 2) limited GPU kernel performance when running inference with small batch sizes, and 3) difficulties in exploiting quantization, which includes both quantizing the model to reduce the model size and latency as well as supporting high-performance inference of quantized models without specialized hardware. - -To handle these challenges, we introduce DeepSpeed Inference, which seamlessly adds high-performance inference support to large models trained in DeepSpeed with three key features: inference-adapted parallelism for multi-GPU inference, inference-optimized kernels tuned for small batch sizes, and flexible support for quantize-aware training and inference kernels for quantized models. - -## Multi-GPU Inference with Adaptive Parallelism - -Parallelism is an effective approach to fit large models and reduce per-device memory consumption for both training and inference. However, simply applying training parallelism choices and degree to inference does not work well. The MP and PP configuration is normally set during the model training, apart from the data parallelism (DP), based on the memory footprint and computation style, and resource budget. On one hand, inference computation intrinsically requires less memory, so it can afford a larger partition per device. It helps reduce the degree of parallelism needed for model deployment. On the other hand, optimizing latency or meeting latency requirements is often a first-class citizen in inference while training optimizes throughput. - -To obtain desired latency, DeepSpeed Inference automatically adapts MP as an effective approach to reduce model latency, and its parallelism degree is often determined first. With MP, we can split the mode and parallelize computational operations across multiple devices (GPUs) to reduce latency, but it reduces computation granularity and increases communication that may hurt throughput. Once the latency target has been met, DeepSpeed can apply pipeline parallelism to maximize the throughput. Overall, DeepSpeed Inference supports flexible adaptation of both parallelism approach and degree choices from training to inference, minimizing latency while saving deployment costs. - - -## Customized Inference Kernels for Boosted Compute Efficiency of Transformer Blocks - -To achieve high compute efficiency, DeepSpeed-inference offers inference kernels tailored for Transformer blocks through operator fusion, taking model-parallelism for multi-GPU into account. The main difference between our kernel-fusion scheme and similar approaches is that we not only fuse element-wise operations (such as bias-add, residual, and activation function), but also merge the General matrix multiply (GeMM) operations with other operations. To do this, we design an efficient implementation for the vector-matrix or skinny matrix-matrix multiplication that allows us to fuse more operations at the reduction boundary of GeMM operations. - -# Kernel-Fusion - -We take two main policies for fusing operations: 1) keeping the access-pattern of inputs and outputs intact throughout the sequence of operations fused together; 2) fusing operations at each all-reduce boundary. The first policy ensures that different thread-blocks won’t encounter transferring data between Streaming-Multiprocessors (SMs). This is due to no straight-forward communication among SMs other than using the main memory which adds the block-synching overhead because of non-deterministic behavior of memory access. The reason behind the second policy is that we cannot continue the execution unless the partial results are reduced among the model-parallel GPUs. - -![Inference-Kernel-Fusion](/assets/images/inference-kernel-fusion.png){: .align-center} - -Figure 1: Transformer Layer with Megatron-style model-parallelism all-reduce components. The figure illustrates the parts of layer fused together with broken lines (width of line shows the fusion depth). - -Figure 1 shows the different components of a Transformer layer, and the groups of operations considered for fusion in our inference optimization. We also consider the NVIDIA Megatron-LM style of parallelism that partitions attention (Attn) and feed-forward (FF) blocks across multiple GPUs. Thus, we include the two all-reduce operations that reduce the results among parallel GPUs after Attn and FF blocks. As Figure 1 shows, we fuse the operations inside a Transformer layer at four main regions: -1. Input Layer-Norm plus Query, Key, and Value GeMMs and their bias adds. -2. Transform plus Attention. -3. Intermediate FF, Layer-Norm, Bias-add, Residual, and Gaussian Error Linear Unit (GELU). -4. Bias-add plus Residual. - -To fuse these operations, we exploit shared-memory as an intermediate cache for transferring data between reduction operations used in layer-norm and GeMM, and the element-wise operations. Moreover, we use the warp-level instructions to communicate data between threads when reducing partial computations. In addition, we use a new schedule for GeMM operations, which allows for fusing as many operations as needed for the third kernel-fusion. We also combine the GeMMs for the attention computation in the second kernel-fusion, by using an implicit matrix transformation in order to reduce the memory pressure. Compared to the unfused computation style using cuBLAS GeMM, we improve the performance by 1.5x, 2.9x. 3x, and 1.2x for all these kernel-fusions, respectively. - -## Seamless pipeline from training to inference with automatic kernel-injection - -To run the model in Inference mode, DeepSpeed simply requires the location of the model checkpoints and the desired parallelism configuration, i.e., MP/PP degree. DeepSpeed Inference kernels can also be enabled for many well-known model architectures such as HuggingFace (Bert and GPT-2) or Megatron GPT-based models using a pre-defined policy map that maps the original parameters to the parameters in the inference kernels. For other transformer-based models, user can specify their own policy map. Note that DS-Inference can run independent of the training pipeline as long as it receives all model checkpoints, and the DeepSpeed Transformer kernels for inference can be injected into any Transformer model if the right mapping policy is defined. For more information on how to enable Transformer inference kernel as well as specifying parallelism, please refer to out [inference tutorial](https://www.deepspeed.ai/tutorials/inference-tutorial/). - - -## Flexible quantization support - -To further reduce the inference cost for large-scale models, we created the DeepSpeed Quantization Toolkit, supporting flexible quantize-aware training and high-performance kernels for quantized inference. - -For training, we introduce a novel approach called Mixture of Quantization (MoQ), which is inspired by mixed-precision training while seamlessly applying quantization. With MoQ, we can control the precision of the model by simulating the impact of quantization when updating the parameters at each step of training. Moreover, it supports flexible quantization policies and schedules—we find that by dynamically adjusting the number of quantization bits during training, the final quantized model provides higher accuracy under the same compression ratio. To adapt to different tasks, MoQ can also leverage the second order information of models to detect their sensitivity to precision and adjust the quantization schedule and target accordingly. - -To maximize the performance gains from the quantization model, we provide inference kernels tailored for quantized models that reduce latency through optimizing data movement but do not require specialized hardware. Finally, our toolkit does not require any code changes on the client side, making it easy to use. - -## Performance results - -Boosting throughput and reducing inference cost. Figure 3 shows the inference throughput per GPU for the three model sizes corresponding to the three Transformer networks, GPT-2, Turing-NLG, and GPT-3. DeepSpeed Inference increases in per-GPU throughput by 2 to 4 times when using the same precision of FP16 as the baseline. By enabling quantization, we boost throughput further. We reach a throughput improvement of 3x for GPT-2, 5x for Turing-NLG, and 3x for a model that is similar in characteristics and size to GPT-3, which directly translates to 3–5x inference cost reduction on serving these large models. In addition, we achieve these throughput and cost improvements without compromising latency as shown in Figure 5. - -![Inference-Throughput](/assets/images/inference-throughput.png){: .align-center} - -Figure 3: Inference throughput for different model sizes. DeepSpeed Inference achieves 3x to 5x higher throughput than baseline. - -One source of inference cost reduction is through reducing the number of GPUs for hosting large models as shown in Figure 4. The optimized GPU resources comes from 1) using inference-adapted parallelism, allowing users to adjust the model and pipeline parallelism degree from the trained model checkpoints, and 2) shrinking model memory footprint by half with INT8 quantization. As shown in this figure, we use 2x less GPUs to run inference for the 17B model size by adapting the parallelism. Together with INT8 quantization through DeepSpeed MoQ, we use 4x and 2x fewer GPUs for 17B and 175B sizes respectively. - -![Inference-Throughput](/assets/images/gpu-numbers.png){: .align-center} - -Figure 4: Number of GPUs used for running inference on the different model sizes shown in Figure 4. - -Reducing inference latency. For the application scenarios where inference latency is critical, we can increase model parallelism degree in DeepSpeed Inference to reduce inference latency further. As Figure 5 depicts, we can reduce the latency by 2.3x compared to PyTorch as we increase the model-parallelism size to 4. Furthermore, we can still have high latency improvement with a fewer number of GPUs by adapting the parallelism at inference and using MoQ to quantize the model. We obtain 1.3x and 1.9x speedups while using 4x and 2x lower resources than baseline, respectively. - -For the application scenarios where inference latency is critical, we can increase model parallelism degree in DeepSpeed Inference to reduce inference latency further. As Figure 5 depicts, we can reduce the latency by 2.3x compared to PyTorch as we increase the model-parallelism size to 4. Furthermore, we can still have high latency improvement with a fewer number of GPUs by adapting the parallelism at inference and using MoQ to quantize the model. We obtain 1.3x and 1.9x speedups while using 4x and 2x lower resources than baseline, respectively. - -![Inference-Throughput](/assets/images/inference-latency.png){: .align-center} - -Figure 5. Inference latency for the 17B model using different parallelism configuration to optimize latency. +--- +layout: single +title: "DeepSpeed Inference: Multi-GPU inference with customized inference kernels and quantization support" +excerpt: "" +categories: news +new_post: false +date: 2021-03-16 00:00:00 +--- +While DeepSpeed supports training advanced large-scale models, using these trained models in the desired application scenarios is still challenging due to three major limitations in existing inference solutions: 1) lack of support for multi-GPU inference to fit large models and meet latency requirements, 2) limited GPU kernel performance when running inference with small batch sizes, and 3) difficulties in exploiting quantization, which includes both quantizing the model to reduce the model size and latency as well as supporting high-performance inference of quantized models without specialized hardware. + +To handle these challenges, we introduce DeepSpeed Inference, which seamlessly adds high-performance inference support to large models trained in DeepSpeed with three key features: inference-adapted parallelism for multi-GPU inference, inference-optimized kernels tuned for small batch sizes, and flexible support for quantize-aware training and inference kernels for quantized models. + +## Multi-GPU Inference with Adaptive Parallelism + +Parallelism is an effective approach to fit large models and reduce per-device memory consumption for both training and inference. However, simply applying training parallelism choices and degree to inference does not work well. The MP and PP configuration is normally set during the model training, apart from the data parallelism (DP), based on the memory footprint and computation style, and resource budget. On one hand, inference computation intrinsically requires less memory, so it can afford a larger partition per device. It helps reduce the degree of parallelism needed for model deployment. On the other hand, optimizing latency or meeting latency requirements is often a first-class citizen in inference while training optimizes throughput. + +To obtain desired latency, DeepSpeed Inference automatically adapts MP as an effective approach to reduce model latency, and its parallelism degree is often determined first. With MP, we can split the mode and parallelize computational operations across multiple devices (GPUs) to reduce latency, but it reduces computation granularity and increases communication that may hurt throughput. Once the latency target has been met, DeepSpeed can apply pipeline parallelism to maximize the throughput. Overall, DeepSpeed Inference supports flexible adaptation of both parallelism approach and degree choices from training to inference, minimizing latency while saving deployment costs. + + +## Customized Inference Kernels for Boosted Compute Efficiency of Transformer Blocks + +To achieve high compute efficiency, DeepSpeed-inference offers inference kernels tailored for Transformer blocks through operator fusion, taking model-parallelism for multi-GPU into account. The main difference between our kernel-fusion scheme and similar approaches is that we not only fuse element-wise operations (such as bias-add, residual, and activation function), but also merge the General matrix multiply (GeMM) operations with other operations. To do this, we design an efficient implementation for the vector-matrix or skinny matrix-matrix multiplication that allows us to fuse more operations at the reduction boundary of GeMM operations. + +# Kernel-Fusion + +We take two main policies for fusing operations: 1) keeping the access-pattern of inputs and outputs intact throughout the sequence of operations fused together; 2) fusing operations at each all-reduce boundary. The first policy ensures that different thread-blocks won’t encounter transferring data between Streaming-Multiprocessors (SMs). This is due to no straight-forward communication among SMs other than using the main memory which adds the block-synching overhead because of non-deterministic behavior of memory access. The reason behind the second policy is that we cannot continue the execution unless the partial results are reduced among the model-parallel GPUs. + +![Inference-Kernel-Fusion](/assets/images/inference-kernel-fusion.png){: .align-center} + +Figure 1: Transformer Layer with Megatron-style model-parallelism all-reduce components. The figure illustrates the parts of layer fused together with broken lines (width of line shows the fusion depth). + +Figure 1 shows the different components of a Transformer layer, and the groups of operations considered for fusion in our inference optimization. We also consider the NVIDIA Megatron-LM style of parallelism that partitions attention (Attn) and feed-forward (FF) blocks across multiple GPUs. Thus, we include the two all-reduce operations that reduce the results among parallel GPUs after Attn and FF blocks. As Figure 1 shows, we fuse the operations inside a Transformer layer at four main regions: +1. Input Layer-Norm plus Query, Key, and Value GeMMs and their bias adds. +2. Transform plus Attention. +3. Intermediate FF, Layer-Norm, Bias-add, Residual, and Gaussian Error Linear Unit (GELU). +4. Bias-add plus Residual. + +To fuse these operations, we exploit shared-memory as an intermediate cache for transferring data between reduction operations used in layer-norm and GeMM, and the element-wise operations. Moreover, we use the warp-level instructions to communicate data between threads when reducing partial computations. In addition, we use a new schedule for GeMM operations, which allows for fusing as many operations as needed for the third kernel-fusion. We also combine the GeMMs for the attention computation in the second kernel-fusion, by using an implicit matrix transformation in order to reduce the memory pressure. Compared to the unfused computation style using cuBLAS GeMM, we improve the performance by 1.5x, 2.9x. 3x, and 1.2x for all these kernel-fusions, respectively. + +## Seamless pipeline from training to inference with automatic kernel-injection + +To run the model in Inference mode, DeepSpeed simply requires the location of the model checkpoints and the desired parallelism configuration, i.e., MP/PP degree. DeepSpeed Inference kernels can also be enabled for many well-known model architectures such as HuggingFace (Bert and GPT-2) or Megatron GPT-based models using a pre-defined policy map that maps the original parameters to the parameters in the inference kernels. For other transformer-based models, user can specify their own policy map. Note that DS-Inference can run independent of the training pipeline as long as it receives all model checkpoints, and the DeepSpeed Transformer kernels for inference can be injected into any Transformer model if the right mapping policy is defined. For more information on how to enable Transformer inference kernel as well as specifying parallelism, please refer to out [inference tutorial](https://www.deepspeed.ai/tutorials/inference-tutorial/). + + +## Flexible quantization support + +To further reduce the inference cost for large-scale models, we created the DeepSpeed Quantization Toolkit, supporting flexible quantize-aware training and high-performance kernels for quantized inference. + +For training, we introduce a novel approach called Mixture of Quantization (MoQ), which is inspired by mixed-precision training while seamlessly applying quantization. With MoQ, we can control the precision of the model by simulating the impact of quantization when updating the parameters at each step of training. Moreover, it supports flexible quantization policies and schedules—we find that by dynamically adjusting the number of quantization bits during training, the final quantized model provides higher accuracy under the same compression ratio. To adapt to different tasks, MoQ can also leverage the second order information of models to detect their sensitivity to precision and adjust the quantization schedule and target accordingly. + +To maximize the performance gains from the quantization model, we provide inference kernels tailored for quantized models that reduce latency through optimizing data movement but do not require specialized hardware. Finally, our toolkit does not require any code changes on the client side, making it easy to use. + +## Performance results + +Boosting throughput and reducing inference cost. Figure 3 shows the inference throughput per GPU for the three model sizes corresponding to the three Transformer networks, GPT-2, Turing-NLG, and GPT-3. DeepSpeed Inference increases in per-GPU throughput by 2 to 4 times when using the same precision of FP16 as the baseline. By enabling quantization, we boost throughput further. We reach a throughput improvement of 3x for GPT-2, 5x for Turing-NLG, and 3x for a model that is similar in characteristics and size to GPT-3, which directly translates to 3–5x inference cost reduction on serving these large models. In addition, we achieve these throughput and cost improvements without compromising latency as shown in Figure 5. + +![Inference-Throughput](/assets/images/inference-throughput.png){: .align-center} + +Figure 3: Inference throughput for different model sizes. DeepSpeed Inference achieves 3x to 5x higher throughput than baseline. + +One source of inference cost reduction is through reducing the number of GPUs for hosting large models as shown in Figure 4. The optimized GPU resources comes from 1) using inference-adapted parallelism, allowing users to adjust the model and pipeline parallelism degree from the trained model checkpoints, and 2) shrinking model memory footprint by half with INT8 quantization. As shown in this figure, we use 2x less GPUs to run inference for the 17B model size by adapting the parallelism. Together with INT8 quantization through DeepSpeed MoQ, we use 4x and 2x fewer GPUs for 17B and 175B sizes respectively. + +![Inference-Throughput](/assets/images/gpu-numbers.png){: .align-center} + +Figure 4: Number of GPUs used for running inference on the different model sizes shown in Figure 4. + +Reducing inference latency. For the application scenarios where inference latency is critical, we can increase model parallelism degree in DeepSpeed Inference to reduce inference latency further. As Figure 5 depicts, we can reduce the latency by 2.3x compared to PyTorch as we increase the model-parallelism size to 4. Furthermore, we can still have high latency improvement with a fewer number of GPUs by adapting the parallelism at inference and using MoQ to quantize the model. We obtain 1.3x and 1.9x speedups while using 4x and 2x lower resources than baseline, respectively. + +For the application scenarios where inference latency is critical, we can increase model parallelism degree in DeepSpeed Inference to reduce inference latency further. As Figure 5 depicts, we can reduce the latency by 2.3x compared to PyTorch as we increase the model-parallelism size to 4. Furthermore, we can still have high latency improvement with a fewer number of GPUs by adapting the parallelism at inference and using MoQ to quantize the model. We obtain 1.3x and 1.9x speedups while using 4x and 2x lower resources than baseline, respectively. + +![Inference-Throughput](/assets/images/inference-latency.png){: .align-center} + +Figure 5. Inference latency for the 17B model using different parallelism configuration to optimize latency. diff --git a/docs/_tutorials/mixture-of-experts.md b/docs/_tutorials/mixture-of-experts.md index 39c85ebdb..ef8ca1756 100644 --- a/docs/_tutorials/mixture-of-experts.md +++ b/docs/_tutorials/mixture-of-experts.md @@ -1,197 +1,197 @@ ---- -title: "Mixture of Experts" ---- - -DeepSpeed v0.5 introduces new support for training Mixture of Experts (MoE) models. MoE models are an emerging class of sparsely activated models that have sublinear compute costs with respect to their parameters. For example, the [Switch Transformer](https://arxiv.org/abs/2101.03961) consists of over 1.6 trillion parameters, while the compute required to train it is approximately equal to that of a 10 billion-parameter dense model. This increase in model size offers tremendous accuracy gains for a constant compute budget. - -For more details on results and further discussion, please see our press release: [DeepSpeed powers 8x larger MoE model training with high performance]({{ site.press_release_v5 }}). - -## Getting started with a simple MoE example - -**Note:** DeepSpeed MoE requires Pytorch 1.8 or above. -{: .notice--info} - -As a simple starting point we will show how to apply DeepSpeed MoE to a cifar10 example. Please refer to -our [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) going forward. - -If you are adding MoE to an existing model you can use the snippet below to help guide you: - - -### Expert groups initialization - -DeepSpeed MoE supports five different forms of parallelism, and it exploits both GPU and CPU memory. Its flexible design enables users to mix different types of prevalent parallelism techniques, as shown in the table below. - -| Short Name | Flexible Parallelism Configurations | Benefit | -| ---------------- | ------------------------------------| --------------------------------------------------------------------------- | -| E | Expert | Scales the model size by increasing the number of experts | -| E + D | Expert + Data | Accelerates training throughput by scaling to multiple data parallel groups | -| E + Z | Expert + ZeRO-powered data | Partitions the nonexpert parameters to support larger base models | -| E + D + M | Expert + Data + Model | Supports massive hidden sizes and even larger base models than E+Z | -| E + D + Z | Expert + Data + ZeRO-powered data | Supports massive hidden sizes and even larger base models than E+Z | -| E + Z-Off + M | Expert + ZeRO-Offload + Model | Leverages both GPU and CPU memory for large MoE models on limited # of GPUs | - -To support different forms of parallelism, we create a notion of DeepSpeed process groups that resides in ```deepspeed.utils.groups.py``` - -For most cases, the model training code needs to initialize these groups by calling -```python -deepspeed.utils.groups.initialize(ep_size="desired expert-parallel world size") -``` - -The GPUs (or ranks) participating in an expert-parallel group will distribute the total number of experts specified by the model training code argument num_experts. - -### MoE layer API - -The hidden_size is the input dimension of a particular layer and the output dimension is the same as that. This could lead to some changes to your model definition, especially for vision/convolutional models because the input/output dimensions don't match in certain cases. E.g. in the CIFAR-10 example, we modify the third fully connected layer to add the MoE layer. To cater for this, we need to add an additional fully-connected layer, whose input dimension is equal to the output dimension of the MoE layer. - -Original model config - -```python - self.fc3 = nn.Linear(84, 10) -``` - -Updated with MoE Layers - -```python - self.fc3 = nn.Linear(84, 84) - self.fc3 = deepspeed.moe.layer.MoE(hidden_size=84, expert=self.fc3, num_experts=args.num_experts, ...) - self.fc4 = nn.Linear(84, 10) -``` - -### An Example Scenario - -Given a total number of GPUs in our world size and a subset of GPUs in our expert-parallel world as follows. - -```python -WORLD_SIZE = 4 -EP_WORLD_SIZE = 2 -EXPERTS = 8 -``` - -The user code needs to initialize the groups as follows. - -```python -groups.initialize (ep_size=EP_WORLD_SIZE) -``` - -After that, the model code needs to use the deepspeed.moe.layer.MoE API as follows. - -```python -self.experts = deepspeed.moe.layer.MoE(hidden_size=input_dim, expert=ExpertModule(), num_experts=EXPERTS) -``` -With the above two commands, the DeepSpeed runtime will be set to train an MoE model with a total of 8 experts on 4 GPUs in 4 experts/GPU mode. We call this the E + D mode as described earlier in the table. - -For more advanced use case of the groups API including the inter-operability with Megatron style mpu object, watch this space! - - -```python -import torch -import deepspeed -import deepspeed.utils.groups as groups -from deepspeed.moe.layer import MoE - -WORLD_SIZE = 4 -EP_WORLD_SIZE = 2 -EXPERTS = 8 - -groups.initialize(ep_size=EP_WORLD_SIZE) - -fc3 = torch.nn.Linear(84, 84) -fc3 = MoE(hidden_size=84, expert=self.fc3, num_experts=EXPERTS, k=1) -fc4 = torch.nn.Linear(84, 10) - -``` - -For a runnable end-to-end example, please look at [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) - -### Combining ZeRO-Offload and DeepSpeed MoE for very large models - -To use MoE Layers in DeepSpeed, we rely on two parameter groups that are passed to an optimizer. A concrete example to create such groups is available from the [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar). - -The relevant function that creates these param groups is as follows. - -```python -def create_moe_param_groups(model): - from deepspeed.moe.utils import is_moe_param - - params_with_weight_decay = {'params': [], 'name': 'weight_decay_params'} - moe_params_with_weight_decay = { - 'params': [], - 'moe': True, - 'name': 'weight_decay_moe_params' - } - - for module_ in model.modules(): - moe_params_with_weight_decay['params'].extend([ - p for n, p in list(module_._parameters.items()) - if p is not None and is_moe_param(p) - ]) - params_with_weight_decay['params'].extend([ - p for n, p in list(module_._parameters.items()) - if p is not None and not is_moe_param(p) - ]) - - return params_with_weight_decay, moe_params_with_weight_decay -``` - -The above param groups can then be fed to the ZeRO stage-2 optimizer as follows. - -```python - -net = Net() - -parameters = create_moe_param_groups(net) - -model_engine, optimizer, trainloader, __ = deepspeed.initialize( - args=args, model=net, model_parameters=parameters, training_data=trainset) -``` - -We are working on automating this functionality in the DeepSpeed ZeRO optimizer so the model training code can be simplified further. - -To run the [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) with ZeRO-Offload (stage 2) and MoE, please set the ds_config flags - -```json -"zero_optimization": { - "stage": 2, - "allgather_partitions": true, - "reduce_scatter": true, - "allgather_bucket_size": 50000000, - "reduce_bucket_size": 50000000, - "overlap_comm": true, - "contiguous_gradients": true, - "cpu_offload": true - } -``` - -An additional optimization to save memory for extremely large model training on limited number of GPUs has also been introduced. Please enable that using the following config flag to the fp16 optimizer in ds_config. - - ```json - "fp16": { - "enabled": true, - "fp16_master_weights_and_grads": true, - } - ``` - - - - - - -## Random Token Selection - -We have devised a new technique called “Random Token Selection” that greatly improves convergence. Random token selection addresses the limitation of biased selection problem in MoE model training. Our upcoming paper describes this technique and its results in detail. This feature is already part of the DeepSpeed runtime and is enabled by default so users can take advantage without any config flags or command-line arguments. - -## Advanced MoE usage - -Watch this space! We plan to add more interesting and detailed examples of using DeepSpeed MoE in the coming weeks. +--- +title: "Mixture of Experts" +--- + +DeepSpeed v0.5 introduces new support for training Mixture of Experts (MoE) models. MoE models are an emerging class of sparsely activated models that have sublinear compute costs with respect to their parameters. For example, the [Switch Transformer](https://arxiv.org/abs/2101.03961) consists of over 1.6 trillion parameters, while the compute required to train it is approximately equal to that of a 10 billion-parameter dense model. This increase in model size offers tremendous accuracy gains for a constant compute budget. + +For more details on results and further discussion, please see our press release: [DeepSpeed powers 8x larger MoE model training with high performance]({{ site.press_release_v5 }}). + +## Getting started with a simple MoE example + +**Note:** DeepSpeed MoE requires Pytorch 1.8 or above. +{: .notice--info} + +As a simple starting point we will show how to apply DeepSpeed MoE to a cifar10 example. Please refer to +our [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) going forward. + +If you are adding MoE to an existing model you can use the snippet below to help guide you: + + +### Expert groups initialization + +DeepSpeed MoE supports five different forms of parallelism, and it exploits both GPU and CPU memory. Its flexible design enables users to mix different types of prevalent parallelism techniques, as shown in the table below. + +| Short Name | Flexible Parallelism Configurations | Benefit | +| ---------------- | ------------------------------------| --------------------------------------------------------------------------- | +| E | Expert | Scales the model size by increasing the number of experts | +| E + D | Expert + Data | Accelerates training throughput by scaling to multiple data parallel groups | +| E + Z | Expert + ZeRO-powered data | Partitions the nonexpert parameters to support larger base models | +| E + D + M | Expert + Data + Model | Supports massive hidden sizes and even larger base models than E+Z | +| E + D + Z | Expert + Data + ZeRO-powered data | Supports massive hidden sizes and even larger base models than E+Z | +| E + Z-Off + M | Expert + ZeRO-Offload + Model | Leverages both GPU and CPU memory for large MoE models on limited # of GPUs | + +To support different forms of parallelism, we create a notion of DeepSpeed process groups that resides in ```deepspeed.utils.groups.py``` + +For most cases, the model training code needs to initialize these groups by calling +```python +deepspeed.utils.groups.initialize(ep_size="desired expert-parallel world size") +``` + +The GPUs (or ranks) participating in an expert-parallel group will distribute the total number of experts specified by the model training code argument num_experts. + +### MoE layer API + +The hidden_size is the input dimension of a particular layer and the output dimension is the same as that. This could lead to some changes to your model definition, especially for vision/convolutional models because the input/output dimensions don't match in certain cases. E.g. in the CIFAR-10 example, we modify the third fully connected layer to add the MoE layer. To cater for this, we need to add an additional fully-connected layer, whose input dimension is equal to the output dimension of the MoE layer. + +Original model config + +```python + self.fc3 = nn.Linear(84, 10) +``` + +Updated with MoE Layers + +```python + self.fc3 = nn.Linear(84, 84) + self.fc3 = deepspeed.moe.layer.MoE(hidden_size=84, expert=self.fc3, num_experts=args.num_experts, ...) + self.fc4 = nn.Linear(84, 10) +``` + +### An Example Scenario + +Given a total number of GPUs in our world size and a subset of GPUs in our expert-parallel world as follows. + +```python +WORLD_SIZE = 4 +EP_WORLD_SIZE = 2 +EXPERTS = 8 +``` + +The user code needs to initialize the groups as follows. + +```python +groups.initialize (ep_size=EP_WORLD_SIZE) +``` + +After that, the model code needs to use the deepspeed.moe.layer.MoE API as follows. + +```python +self.experts = deepspeed.moe.layer.MoE(hidden_size=input_dim, expert=ExpertModule(), num_experts=EXPERTS) +``` +With the above two commands, the DeepSpeed runtime will be set to train an MoE model with a total of 8 experts on 4 GPUs in 4 experts/GPU mode. We call this the E + D mode as described earlier in the table. + +For more advanced use case of the groups API including the inter-operability with Megatron style mpu object, watch this space! + + +```python +import torch +import deepspeed +import deepspeed.utils.groups as groups +from deepspeed.moe.layer import MoE + +WORLD_SIZE = 4 +EP_WORLD_SIZE = 2 +EXPERTS = 8 + +groups.initialize(ep_size=EP_WORLD_SIZE) + +fc3 = torch.nn.Linear(84, 84) +fc3 = MoE(hidden_size=84, expert=self.fc3, num_experts=EXPERTS, k=1) +fc4 = torch.nn.Linear(84, 10) + +``` + +For a runnable end-to-end example, please look at [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) + +### Combining ZeRO-Offload and DeepSpeed MoE for very large models + +To use MoE Layers in DeepSpeed, we rely on two parameter groups that are passed to an optimizer. A concrete example to create such groups is available from the [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar). + +The relevant function that creates these param groups is as follows. + +```python +def create_moe_param_groups(model): + from deepspeed.moe.utils import is_moe_param + + params_with_weight_decay = {'params': [], 'name': 'weight_decay_params'} + moe_params_with_weight_decay = { + 'params': [], + 'moe': True, + 'name': 'weight_decay_moe_params' + } + + for module_ in model.modules(): + moe_params_with_weight_decay['params'].extend([ + p for n, p in list(module_._parameters.items()) + if p is not None and is_moe_param(p) + ]) + params_with_weight_decay['params'].extend([ + p for n, p in list(module_._parameters.items()) + if p is not None and not is_moe_param(p) + ]) + + return params_with_weight_decay, moe_params_with_weight_decay +``` + +The above param groups can then be fed to the ZeRO stage-2 optimizer as follows. + +```python + +net = Net() + +parameters = create_moe_param_groups(net) + +model_engine, optimizer, trainloader, __ = deepspeed.initialize( + args=args, model=net, model_parameters=parameters, training_data=trainset) +``` + +We are working on automating this functionality in the DeepSpeed ZeRO optimizer so the model training code can be simplified further. + +To run the [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) with ZeRO-Offload (stage 2) and MoE, please set the ds_config flags + +```json +"zero_optimization": { + "stage": 2, + "allgather_partitions": true, + "reduce_scatter": true, + "allgather_bucket_size": 50000000, + "reduce_bucket_size": 50000000, + "overlap_comm": true, + "contiguous_gradients": true, + "cpu_offload": true + } +``` + +An additional optimization to save memory for extremely large model training on limited number of GPUs has also been introduced. Please enable that using the following config flag to the fp16 optimizer in ds_config. + + ```json + "fp16": { + "enabled": true, + "fp16_master_weights_and_grads": true, + } + ``` + + + + + + +## Random Token Selection + +We have devised a new technique called “Random Token Selection” that greatly improves convergence. Random token selection addresses the limitation of biased selection problem in MoE model training. Our upcoming paper describes this technique and its results in detail. This feature is already part of the DeepSpeed runtime and is enabled by default so users can take advantage without any config flags or command-line arguments. + +## Advanced MoE usage + +Watch this space! We plan to add more interesting and detailed examples of using DeepSpeed MoE in the coming weeks. diff --git a/docs/_tutorials/progressive_layer_dropping.md b/docs/_tutorials/progressive_layer_dropping.md index 8a447e97c..8c184dfc6 100755 --- a/docs/_tutorials/progressive_layer_dropping.md +++ b/docs/_tutorials/progressive_layer_dropping.md @@ -1,155 +1,155 @@ ---- -title: "Accelerating Training of Transformer-Based Language Models with Progressive Layer Dropping" - ---- - -In this tutorial, we are going to introduce the progressive layer dropping (PLD) in DeepSpeed and provide examples on how to use PLD. PLD allows to train Transformer networks such as BERT 24% faster under the same number of samples and 2.5 times faster to get similar accuracy on downstream tasks. Detailed description of PLD and the experimental results are available in our [technical report](https://arxiv.org/pdf/2010.13369.pdf). - -To illustrate how to use PLD in DeepSpeed, we show how to enable PLD to pre-train a BERT model and fine-tune the pre-trained model on the GLUE datasets. - -## Running Pre-training with DeepSpeed and PLD - -To perform pre-training, one needs to first prepare the datasets. For this part, please refer our [BERT Pre-training](/tutorials/bert-pretraining/) post, which contains detailed information on how to do data downloading and pre-processing. For the below experiment, we use Wikipedia text and Bookcorpus, similar as [Devlin et. al.](https://arxiv.org/abs/1810.04805). - -The main part of pre-training is done in `deepspeed_train.py`, which has -already been modified to use DeepSpeed. The `ds_train_bert_progressive_layer_drop_bsz4k_seq128.sh` is the shell script that launches the pre-training with DeepSpeed and PLD. - -```shell -bash ds_train_bert_progressive_layer_drop_bsz4k_seq128.sh -``` - -Most of the flags in the above script should be familiar if you have stepped through the BERT pre-training [tutorial](/tutorials/bert-pretraining/). To enable training with PLD, one needs to enable PLD in both the client script and in the DeepSpeed engine. To enable PLD in the client script, one needs to add the following command line flag to enable progressive layer dropping on Transformer blocks. - -```bash ---progressive_layer_drop -``` - -To enable PLD in DeepSpeed, one needs to update the json configuration file with an appropriate PLD configuration dictionary like below: - -```json -{ - ... - "progressive_layer_drop": { - "enabled": true, - "theta": 0.5, - "gamma": 0.001 - } -} -``` - -we recommend a PLD theta value of 0.5 and gamma of 0.001 because these have worked well in our experiments. - -With these configuration changes, the DeepSpeed engine should print a runtime message as below: - - [INFO] [logging.py:60:log_dist] [Rank 0] Enabled progressive layer dropping (theta = 0.5) - -The `deepspeed_bsz4k_progressive_layer_drop_config_seq128.json` file allows users to specify DeepSpeed options in terms of batch size, micro batch size, optimizer, learning rate, sequence length, and other parameters. Below is the DeepSpeed configuration file we use for running BERT and PLD. - -```json -{ - "train_batch_size": 4096, - "train_micro_batch_size_per_gpu": 16, - "steps_per_print": 1000, - "prescale_gradients": true, - "gradient_predivide_factor": 8, - "optimizer": { - "type": "Adam", - "params": { - "lr": 1e-3, - "weight_decay": 0.01, - "bias_correction": false - } - }, - "gradient_clipping": 1.0, - - "wall_clock_breakdown": false, - - "fp16": { - "enabled": true, - "loss_scale": 0 - }, - - "progressive_layer_drop": { - "enabled": true, - "theta": 0.5, - "gamma": 0.001 - } -} -``` - -Note that the above configuration assumes training on 64 X 32GB V100 GPUs. Each GPU uses a micro batch size of 16 and accumulates gradients until the effective batch size reaches 4096. If you have GPUs with less memory, you may need to reduce "train_micro_batch_size_per_gpu". Alternatively, if you have more GPUs, you can increase the "train_batch_size" to increase training speed. We use the following hyperparameters for pre-training BERT with PLD enabled. - -| Parameters | Value | -| ------------------------------ | ----------------------- | -| Effective batch size | 4K | -| Train micro batch size per GPU | 16 | -| Optimizer | Adam | -| Peak learning rate | 1e-3 | -| Sequence-length | 128 | -| Learning rate scheduler | Warmup linear decay exp | -| Warmup ratio | 0.02 | -| Decay rate | 0.99 | -| Decay step | 1000 | -| Weight decay | 0.01 | -| Gradient clipping | 1.0 | - -Table 1. Pre-training hyperparameters - -**Note:** DeepSpeed now supports PreLayerNorm as the default way for training BERT, because of its ability to avoid vanishing gradient, stabilize optimization, and performance gains, as described in our fastest BERT training [blog post](https://www.deepspeed.ai/news/2020/05/27/fastest-bert-training.html). We therefore support the switchable Transformer block directly on the the BERT with PreLayerNorm. The implementation can be found at "example\bing_bert\nvidia\modelingpreln_layerdrop.py". - -## Fine-tuning with DeepSpeed on GLUE Tasks - -We use GLUE for fine-tuning tasks. GLUE (General Language Understanding Evaluation benchmark) (https://gluebenchmark.com/) is a collection of sentence or sentence-pair natural language understanding tasks including question answering, sentiment analysis, and textual entailment. It is designed to favor sample-efficient learning and knowledge-transfer across a range of different linguistic tasks in different domains. - -One can download all GLUE data using the provided helper [script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e). Once the data has been downloaded, one can set up the data and move the data to "/data/GlueData", which is the default location for hosting GLUE data. We then can use the PLD pre-trained BERT model checkpoint to run the fine-tuning. - -The main part of fine-tuning is done in `run_glue_classifier_bert_base.py`, which has -already been modified to use DeepSpeed. Before the fine-tuning, one needs to specify the BERT model configuration through the following config in `run_glue_classifier_bert_base.py`. In this case, it has already been modified to be the same as the configuration of the pre-trained model. - -```json - bert_model_config = { - "vocab_size_or_config_json_file": 119547, - "hidden_size": 768, - "num_hidden_layers": 12, - "num_attention_heads": 12, - "intermediate_size": 3072, - "hidden_act": "gelu", - "hidden_dropout_prob": 0.1, - "attention_probs_dropout_prob": 0.1, - "max_position_embeddings": 512, - "type_vocab_size": 2, - "initializer_range": 0.02 - } -``` - -Next, one can load a DeepSpeed style checkpoint with the following command, which has also already been added in the script. - -```shell -model.load_state_dict(checkpoint_state_dict['module'], strict=False) -``` - -Finally, the `run_glue_classifier_bert_base.sh` script invokes pre-training and setups several hyperparameters relevant to fine-tuning. - -```shell -bash run_glue_bert_base_finetune.sh [task] [batch size] [learning rate] [number of epochs] [job name] [checkpoint path] -``` - -An example would be: - -```shell -bash run_glue_bert_base_finetune.sh MNLI 32 3e-5 5 "fine_tune_MNLI" deepspeed_checkpoint.pt -``` - - - -### Expected Results - -The fine-tuning results can be found under the "logs" directory, and below are expected results for PLD on GLUE tasks. The "Lr" row indicates the learning rate we use for getting the corresponding accuracy result for each task. - -| | RTE | MRPC | STS-B | CoLA | SST-2 | QNLI | QQP | MNLI-m/mm | GLUE | -| ---------------------- | :--: | --------- | --------- | ---- | ----- | ---- | --------- | --------- | ---- | -| Metrics | Acc. | F1/Acc. | PCC/SCC | Acc. | Acc. | Acc. | F1/Acc. | Acc. | | -| Bert_{base} (original) | 66.4 | 88.9/84.8 | 87.1/89.2 | 52.1 | 93.5 | 90.5 | 71.2/89.2 | 84.6/83.4 | 80.7 | -| Bert_{base} (Our impl) | 67.8 | 88.0/86.0 | 89.5/89.2 | 52.5 | 91.2 | 87.1 | 89.0/90.6 | 82.5/83.4 | 82.1 | -| PLD | 69.3 | 86.6/84.3 | 90.0/89.6 | 55.8 | 91.6 | 90.7 | 89.6/91.2 | 84.1/83.8 | 82.9 | -| Lr | 7e-5 | 9e-5 | 7e-5 | 5e-5 | 7e-5 | 9e-5 | 2e-4 | 3e-5 | | +--- +title: "Accelerating Training of Transformer-Based Language Models with Progressive Layer Dropping" + +--- + +In this tutorial, we are going to introduce the progressive layer dropping (PLD) in DeepSpeed and provide examples on how to use PLD. PLD allows to train Transformer networks such as BERT 24% faster under the same number of samples and 2.5 times faster to get similar accuracy on downstream tasks. Detailed description of PLD and the experimental results are available in our [technical report](https://arxiv.org/pdf/2010.13369.pdf). + +To illustrate how to use PLD in DeepSpeed, we show how to enable PLD to pre-train a BERT model and fine-tune the pre-trained model on the GLUE datasets. + +## Running Pre-training with DeepSpeed and PLD + +To perform pre-training, one needs to first prepare the datasets. For this part, please refer our [BERT Pre-training](/tutorials/bert-pretraining/) post, which contains detailed information on how to do data downloading and pre-processing. For the below experiment, we use Wikipedia text and Bookcorpus, similar as [Devlin et. al.](https://arxiv.org/abs/1810.04805). + +The main part of pre-training is done in `deepspeed_train.py`, which has +already been modified to use DeepSpeed. The `ds_train_bert_progressive_layer_drop_bsz4k_seq128.sh` is the shell script that launches the pre-training with DeepSpeed and PLD. + +```shell +bash ds_train_bert_progressive_layer_drop_bsz4k_seq128.sh +``` + +Most of the flags in the above script should be familiar if you have stepped through the BERT pre-training [tutorial](/tutorials/bert-pretraining/). To enable training with PLD, one needs to enable PLD in both the client script and in the DeepSpeed engine. To enable PLD in the client script, one needs to add the following command line flag to enable progressive layer dropping on Transformer blocks. + +```bash +--progressive_layer_drop +``` + +To enable PLD in DeepSpeed, one needs to update the json configuration file with an appropriate PLD configuration dictionary like below: + +```json +{ + ... + "progressive_layer_drop": { + "enabled": true, + "theta": 0.5, + "gamma": 0.001 + } +} +``` + +we recommend a PLD theta value of 0.5 and gamma of 0.001 because these have worked well in our experiments. + +With these configuration changes, the DeepSpeed engine should print a runtime message as below: + + [INFO] [logging.py:60:log_dist] [Rank 0] Enabled progressive layer dropping (theta = 0.5) + +The `deepspeed_bsz4k_progressive_layer_drop_config_seq128.json` file allows users to specify DeepSpeed options in terms of batch size, micro batch size, optimizer, learning rate, sequence length, and other parameters. Below is the DeepSpeed configuration file we use for running BERT and PLD. + +```json +{ + "train_batch_size": 4096, + "train_micro_batch_size_per_gpu": 16, + "steps_per_print": 1000, + "prescale_gradients": true, + "gradient_predivide_factor": 8, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3, + "weight_decay": 0.01, + "bias_correction": false + } + }, + "gradient_clipping": 1.0, + + "wall_clock_breakdown": false, + + "fp16": { + "enabled": true, + "loss_scale": 0 + }, + + "progressive_layer_drop": { + "enabled": true, + "theta": 0.5, + "gamma": 0.001 + } +} +``` + +Note that the above configuration assumes training on 64 X 32GB V100 GPUs. Each GPU uses a micro batch size of 16 and accumulates gradients until the effective batch size reaches 4096. If you have GPUs with less memory, you may need to reduce "train_micro_batch_size_per_gpu". Alternatively, if you have more GPUs, you can increase the "train_batch_size" to increase training speed. We use the following hyperparameters for pre-training BERT with PLD enabled. + +| Parameters | Value | +| ------------------------------ | ----------------------- | +| Effective batch size | 4K | +| Train micro batch size per GPU | 16 | +| Optimizer | Adam | +| Peak learning rate | 1e-3 | +| Sequence-length | 128 | +| Learning rate scheduler | Warmup linear decay exp | +| Warmup ratio | 0.02 | +| Decay rate | 0.99 | +| Decay step | 1000 | +| Weight decay | 0.01 | +| Gradient clipping | 1.0 | + +Table 1. Pre-training hyperparameters + +**Note:** DeepSpeed now supports PreLayerNorm as the default way for training BERT, because of its ability to avoid vanishing gradient, stabilize optimization, and performance gains, as described in our fastest BERT training [blog post](https://www.deepspeed.ai/news/2020/05/27/fastest-bert-training.html). We therefore support the switchable Transformer block directly on the the BERT with PreLayerNorm. The implementation can be found at "example\bing_bert\nvidia\modelingpreln_layerdrop.py". + +## Fine-tuning with DeepSpeed on GLUE Tasks + +We use GLUE for fine-tuning tasks. GLUE (General Language Understanding Evaluation benchmark) (https://gluebenchmark.com/) is a collection of sentence or sentence-pair natural language understanding tasks including question answering, sentiment analysis, and textual entailment. It is designed to favor sample-efficient learning and knowledge-transfer across a range of different linguistic tasks in different domains. + +One can download all GLUE data using the provided helper [script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e). Once the data has been downloaded, one can set up the data and move the data to "/data/GlueData", which is the default location for hosting GLUE data. We then can use the PLD pre-trained BERT model checkpoint to run the fine-tuning. + +The main part of fine-tuning is done in `run_glue_classifier_bert_base.py`, which has +already been modified to use DeepSpeed. Before the fine-tuning, one needs to specify the BERT model configuration through the following config in `run_glue_classifier_bert_base.py`. In this case, it has already been modified to be the same as the configuration of the pre-trained model. + +```json + bert_model_config = { + "vocab_size_or_config_json_file": 119547, + "hidden_size": 768, + "num_hidden_layers": 12, + "num_attention_heads": 12, + "intermediate_size": 3072, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "type_vocab_size": 2, + "initializer_range": 0.02 + } +``` + +Next, one can load a DeepSpeed style checkpoint with the following command, which has also already been added in the script. + +```shell +model.load_state_dict(checkpoint_state_dict['module'], strict=False) +``` + +Finally, the `run_glue_classifier_bert_base.sh` script invokes pre-training and setups several hyperparameters relevant to fine-tuning. + +```shell +bash run_glue_bert_base_finetune.sh [task] [batch size] [learning rate] [number of epochs] [job name] [checkpoint path] +``` + +An example would be: + +```shell +bash run_glue_bert_base_finetune.sh MNLI 32 3e-5 5 "fine_tune_MNLI" deepspeed_checkpoint.pt +``` + + + +### Expected Results + +The fine-tuning results can be found under the "logs" directory, and below are expected results for PLD on GLUE tasks. The "Lr" row indicates the learning rate we use for getting the corresponding accuracy result for each task. + +| | RTE | MRPC | STS-B | CoLA | SST-2 | QNLI | QQP | MNLI-m/mm | GLUE | +| ---------------------- | :--: | --------- | --------- | ---- | ----- | ---- | --------- | --------- | ---- | +| Metrics | Acc. | F1/Acc. | PCC/SCC | Acc. | Acc. | Acc. | F1/Acc. | Acc. | | +| Bert_{base} (original) | 66.4 | 88.9/84.8 | 87.1/89.2 | 52.1 | 93.5 | 90.5 | 71.2/89.2 | 84.6/83.4 | 80.7 | +| Bert_{base} (Our impl) | 67.8 | 88.0/86.0 | 89.5/89.2 | 52.5 | 91.2 | 87.1 | 89.0/90.6 | 82.5/83.4 | 82.1 | +| PLD | 69.3 | 86.6/84.3 | 90.0/89.6 | 55.8 | 91.6 | 90.7 | 89.6/91.2 | 84.1/83.8 | 82.9 | +| Lr | 7e-5 | 9e-5 | 7e-5 | 5e-5 | 7e-5 | 9e-5 | 2e-4 | 3e-5 | | diff --git a/docs/_tutorials/zero-offload.md b/docs/_tutorials/zero-offload.md index 8b0f56ec5..404355090 100644 --- a/docs/_tutorials/zero-offload.md +++ b/docs/_tutorials/zero-offload.md @@ -1,75 +1,74 @@ ---- -title: "ZeRO-Offload" ---- -ZeRO-3 Offload consists of a subset of features in our newly released ZeRO-Infinity. Read our [ZeRO-Infinity blog](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/) to learn more! - -We recommend that you read the tutorials on [Getting Started](/getting-started/) and [ZeRO](/tutorials/zero/) before stepping through this tutorial. - -ZeRO-Offload is a ZeRO optimization that offloads the optimizer memory and computation from the GPU to the host CPU. ZeRO-Offload enables large models with up to 13 billion parameters to be efficiently trained on a single GPU. In this tutorial we will use ZeRO-Offload to train a 10-billion parameter GPT-2 model in DeepSpeed. Furthermore, *using ZeRO-Offload in a DeepSpeed model is quick and easy because all you need is to change a few configurations in the DeepSpeed configuration json*. No code changes are needed. - -## ZeRO-Offload Overview -For large model training, optimizers such as [Adam](https://arxiv.org/abs/1412.6980), can consume a significant amount of GPU compute and memory. ZeRO-Offload reduces the GPU compute and memory requirements of such models by leveraging compute and memory resources on the host CPU to execute the optimizer. Furthermore, to prevent the optimizer from becoming a bottleneck, ZeRO-Offload uses DeepSpeed's highly optimized CPU implementation of Adam called [DeeSpeedCPUAdam](https://github.com/microsoft/DeepSpeed/tree/master/deepspeed/ops/adam). DeepSpeedCPUAdam is 5X--7X faster than the standard PyTorch implementation. To deep dive into the design and performance of ZeRO-Offload, please see our [blog post](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/#toc-heading-3). - -## Training Environment -For this tutorial, we will configure a 10 billion parameter GPT-2 model using the DeepSpeed [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM-v1.1.5-ZeRO3) GPT-2 code. We advise stepping through the Megatron-LM [tutorial](/tutorials/megatron/) if you have not previously done so. We will use a single [NVIDIA Tesla V100-SXM3 Tensor Core GPU](https://www.nvidia.com/en-us/data-center/v100/) with 32GB RAM for this exercise. - -## Training a 10B parameter GPT-2 on 1 V100 GPU -We need to make changes to the Megatron-LM launch script and to the DeepSpeed configuration json. - -### Megatron-LM GPT-2 launch script changes -We need to apply two changes to the launch script for the DeepSpeed Megatron-LM GPT-2 model. The first change is to configure a 10B parameter GPT-2 model with activation checkpointing enabled, which can be achieved by the following set of changes: - -```bash - --model-parallel-size 1 \ - --num-layers 50 \ - --hidden-size 4096 \ - --num-attention-heads 32 \ - --batch-size 10 \ - --deepspeed_config ds_zero_offload.config \ - --checkpoint-activations -``` - -Most of the flags in the changes above should be familiar if you have stepped through the Megatron-LM [tutorial](/tutorials/megatron/). - -Second, we need to apply the following changes to ensure that only one GPU is used for training. -```bash - deepspeed --num_nodes 1 --num_gpus 1 ... -``` - -### DeepSpeed Configuration Changes -ZeRO-Offload leverages much for ZeRO stage 2 mechanisms, and so the configuration changes to enable ZeRO-Offload is an extension of those required to enable ZeRO stage 2. The **zero_optimization** key to enable ZeRO-Offload is shown below: - -```json -{ - "zero_optimization": { - "stage": 2, - "cpu_offload": true, - "contiguous_gradients": true, - "overlap_comm": true - } -} -``` - -As seen above, in addition to setting the _stage_ field to **2** (to enable ZeRO stage 2), we also need to set _cpu_offload_ flag to **true** to enable ZeRO-Offload optimizations. In addition, we can set other ZeRO stage 2 optimization flags, such as _overlap_comm_ to tune ZeRO-Offload performance. With these changes we can now run the model. We share some screenshots of the training below. - -Here is a screenshot of the training log: - - - - - - -Here is a screenshot of nvidia-smi showing that only GPU 0 is active during training: - - - - - -Finally, here is a screenshot of htop showing host CPU and memory activity during optimizer computation: - - - - - -Congratulations! You have completed the ZeRO-Offload tutorial. - +--- +title: "ZeRO-Offload" +--- +ZeRO-3 Offload consists of a subset of features in our newly released ZeRO-Infinity. Read our [ZeRO-Infinity blog](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/) to learn more! + +We recommend that you read the tutorials on [Getting Started](/getting-started/) and [ZeRO](/tutorials/zero/) before stepping through this tutorial. + +ZeRO-Offload is a ZeRO optimization that offloads the optimizer memory and computation from the GPU to the host CPU. ZeRO-Offload enables large models with up to 13 billion parameters to be efficiently trained on a single GPU. In this tutorial we will use ZeRO-Offload to train a 10-billion parameter GPT-2 model in DeepSpeed. Furthermore, *using ZeRO-Offload in a DeepSpeed model is quick and easy because all you need is to change a few configurations in the DeepSpeed configuration json*. No code changes are needed. + +## ZeRO-Offload Overview +For large model training, optimizers such as [Adam](https://arxiv.org/abs/1412.6980), can consume a significant amount of GPU compute and memory. ZeRO-Offload reduces the GPU compute and memory requirements of such models by leveraging compute and memory resources on the host CPU to execute the optimizer. Furthermore, to prevent the optimizer from becoming a bottleneck, ZeRO-Offload uses DeepSpeed's highly optimized CPU implementation of Adam called [DeeSpeedCPUAdam](https://github.com/microsoft/DeepSpeed/tree/master/deepspeed/ops/adam). DeepSpeedCPUAdam is 5X--7X faster than the standard PyTorch implementation. To deep dive into the design and performance of ZeRO-Offload, please see our [blog post](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/#toc-heading-3). + +## Training Environment +For this tutorial, we will configure a 10 billion parameter GPT-2 model using the DeepSpeed [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM-v1.1.5-ZeRO3) GPT-2 code. We advise stepping through the Megatron-LM [tutorial](/tutorials/megatron/) if you have not previously done so. We will use a single [NVIDIA Tesla V100-SXM3 Tensor Core GPU](https://www.nvidia.com/en-us/data-center/v100/) with 32GB RAM for this exercise. + +## Training a 10B parameter GPT-2 on 1 V100 GPU +We need to make changes to the Megatron-LM launch script and to the DeepSpeed configuration json. + +### Megatron-LM GPT-2 launch script changes +We need to apply two changes to the launch script for the DeepSpeed Megatron-LM GPT-2 model. The first change is to configure a 10B parameter GPT-2 model with activation checkpointing enabled, which can be achieved by the following set of changes: + +```bash + --model-parallel-size 1 \ + --num-layers 50 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --batch-size 10 \ + --deepspeed_config ds_zero_offload.config \ + --checkpoint-activations +``` + +Most of the flags in the changes above should be familiar if you have stepped through the Megatron-LM [tutorial](/tutorials/megatron/). + +Second, we need to apply the following changes to ensure that only one GPU is used for training. +```bash + deepspeed --num_nodes 1 --num_gpus 1 ... +``` + +### DeepSpeed Configuration Changes +ZeRO-Offload leverages much for ZeRO stage 2 mechanisms, and so the configuration changes to enable ZeRO-Offload is an extension of those required to enable ZeRO stage 2. The **zero_optimization** key to enable ZeRO-Offload is shown below: + +```json +{ + "zero_optimization": { + "stage": 2, + "cpu_offload": true, + "contiguous_gradients": true, + "overlap_comm": true + } +} +``` + +As seen above, in addition to setting the _stage_ field to **2** (to enable ZeRO stage 2), we also need to set _cpu_offload_ flag to **true** to enable ZeRO-Offload optimizations. In addition, we can set other ZeRO stage 2 optimization flags, such as _overlap_comm_ to tune ZeRO-Offload performance. With these changes we can now run the model. We share some screenshots of the training below. + +Here is a screenshot of the training log: + + + + + + +Here is a screenshot of nvidia-smi showing that only GPU 0 is active during training: + + + + + +Finally, here is a screenshot of htop showing host CPU and memory activity during optimizer computation: + + + + + +Congratulations! You have completed the ZeRO-Offload tutorial. diff --git a/docs/_tutorials/zero.md b/docs/_tutorials/zero.md index 01595c113..adc0bbb40 100644 --- a/docs/_tutorials/zero.md +++ b/docs/_tutorials/zero.md @@ -1,301 +1,301 @@ ---- -title: "Zero Redundancy Optimizer (ZeRO)" ---- -If you have not done so already, we advise that you read the DeepSpeed tutorials on [Getting Started](/getting-started/) and [Megatron-LM GPT-2](/tutorials/megatron/) before stepping through this tutorial. - -In this tutorial, we will apply the ZeRO optimizer to the [Megatron-LM GPT-2](https://github.com/NVIDIA/Megatron-LM) model. ZeRO is a powerful set of memory optimization techniques that enable effective FP16 training of large models with trillions of parameters, such as [GPT-2](https://openai.com/blog/better-language-models/) and [Turing-NLG 17B](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/). Compared to the alternative model parallelism approaches for training large models, a key appeal of ZeRO is that no model code modifications are required. As this tutorial will demonstrate, *using ZeRO in a DeepSpeed model is quick and easy because all you need is to change a few configurations in the DeepSpeed configuration JSON*. No code changes are needed. - -## ZeRO Overview -ZeRO leverages the aggregate computation and memory resources of data parallelism to reduce the memory and compute requirements of each device (GPU) used for model training. ZeRO reduces the memory consumption of each GPU by partitioning the various model training states (weights, gradients, and optimizer states) across the available devices (GPUs and CPUs) in the distributed training hardware. Concretely, ZeRO is being implemented as incremental stages of optimizations, where optimizations in earlier stages are available in the later stages. To deep dive into ZeRO, please see our [paper](https://arxiv.org/abs/1910.02054v3). - -* **Stage 1**: The optimizer states (e.g., for [Adam optimizer](https://arxiv.org/abs/1412.6980), 32-bit weights, and the first, and second moment estimates) are partitioned across the processes, so that each process updates only its partition. - -* **Stage 2**: The reduced 32-bit gradients for updating the model weights are also partitioned such that each process retains only the gradients corresponding to its portion of the optimizer states. - -* **Stage 3**: The 16-bit model parameters are partitioned across the processes. ZeRO-3 will automatically collect and partition them during the forward and backward passes. - -In addition, ZeRO-3 includes the *infinity offload engine* to form ZeRO-Infinity ([paper](https://arxiv.org/abs/2104.07857)), which can offload to both CPU and NVMe memory for huge memory savings. - -## Training environment -We use the DeepSpeed [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM-v1.1.5-ZeRO3) GPT-2 code for this exercise. You can step through the Megatron-LM [tutorial](/tutorials/megatron/) to familiarize yourself with the code. We will train the models in this tutorial on [NVIDIA Tesla V100-SXM3 Tensor Core GPUs](https://www.nvidia.com/en-us/data-center/v100/) with 32GB RAM. - -## Enabling ZeRO Optimization -To enable ZeRO optimizations for a DeepSpeed model, we simply add the **_zero_optimization_** key to the DeepSpeed JSON configuration. A full description of configuration knobs of the **zero_optimization** key is available [here](/docs/config-json/#zero-optimizations-for-fp16-training). - -### Training a 1.5B Parameter GPT-2 model -We demonstrate the benefits of ZeRO stage 1 by showing that it enables data parallel training of a 1.5 billion parameter GPT-2 model on eight V100 GPUs. We configure training to use a batch size of 1 per device to ensure that the memory consumption is primarily due to model parameters and optimizer states. We create this training scenario by applying the following modifications to the deepspeed launch script: - -```bash - --model-parallel-size 1 \ - --num-layers 48 \ - --hidden-size 1600 \ - --num-attention-heads 16 \ - --batch-size 1 \ - --deepspeed_config ds_zero_stage_1.config \ -``` - -Training this model without ZeRO fails with an out-of-memory (OOM) error as shown below: - - - - - -A key reason why this model does not fit in GPU memory is that the Adam optimizer states for the model consume 18GB; a significant portion of the 32GB RAM. By using ZeRO stage 1 to partition the optimizer state among eight data parallel ranks, the per-device memory consumption can be reduced to 2.25GB, thus making the model trainable. To enable ZeRO stage 1, we simply update the DeepSpeed JSON config file as below: - -```json -{ - "zero_optimization": { - "stage":1, - "reduce_bucket_size": 5e8 - } -} -``` -As seen above, we set two fields in the **zero_optimization** key. Specifically we set the _stage_ field to 1, and the optional _reduce_bucket_size_ for gradient reduction to 500M. With ZeRO stage 1 enabled, the model can now train smoothly on 8 GPUs without running out of memory. Below we provide some screenshots of the model training: - - - - - - - - - - - -From the nvidia-smi screenshot above we can see that only GPUs 6-7 are being used for training the model. With ZeRO stage 1 we can further reduce the per-device memory consumption by increasing the data parallelism degree. These memory savings can be leveraged to either increase model size and/or batch size. In contrast, such benefits are not possible with data parallelism alone. - -### Training a 10B Parameter GPT-2 model -ZeRO stage 2 optimizations further increases the size of models that can be trained using data parallelism. We show this by training a model with 10B parameters using 32 V100 GPUs. - -First, we need to configure a 10B parameter model with activation checkpointing enabled. This can be done by applying the following GPT-2 model configuration changes to the DeepSpeed launch script. - -```bash - --model-parallel-size 1 \ - --num-layers 50 \ - --hidden-size 4096 \ - --num-attention-heads 32 \ - --batch-size 1 \ - --deepspeed_config ds_zero_stage_2.config \ - --checkpoint-activations -``` - -Next, we need to update the DeepSpeed JSON configuration, as shown below, to enable ZeRO stage 2 optimizations: - -```json -{ - "zero_optimization": { - "stage":2, - "contiguous_gradients": true, - "overlap_comm": true, - "reduce_scatter": true, - "reduce_bucket_size": 5e8, - "allgather_bucket_size": 5e8 - } -} -``` - -In the above changes, we have set the _stage_ field to 2, and configured other optimization knobs that are available in ZeRO stage 2. For example, we have enabled _contiguous_gradients_ to reduce memory fragmentation during backward pass. A full description of these optimization knobs is available [here](/docs/config-json/#zero-optimizations-for-fp16-training). With these changes, we can now launch the training run. - -Here is a screenshot of the training log: - - - - - -Here is a screenshot of nvidia-smi showing GPU activity during training: - - - - - -### Training trillion-scale models with ZeRO-Infinity - -ZeRO-3, the third stage of ZeRO, partitions the full model state (i.e., -weights, gradients, and optimizer states) to scale memory savings linearly -with the degree of data parallelism. ZeRO-3 can be enabled in the JSON -configuration. A full description of these configurations is available -[here](/docs/config-json/#zero-optimizations-for-fp16-training). - - -#### Offloading to CPU and NVMe with ZeRO-Infinity - -ZeRO-Infinity uses DeepSpeed's infinity offload engine to offload the full -model state to CPU or NVMe memory, allowing for even larger model sizes. Offloading -can be enabled inside the DeepSpeed configuration: - -```diff -@@ -6,5 +6,11 @@ - "zero_optimization": { - "stage": 3, - "contiguous_gradients": true, - "stage3_max_live_parameters": 1e9, - "stage3_max_reuse_distance": 1e9, - "stage3_prefetch_bucket_size": 1e7, - "stage3_param_persistence_threshold": 1e5, - "reduce_bucket_size": 1e7, -- "sub_group_size": 1e9 -+ "sub_group_size": 1e9, -+ "offload_optimizer": { -+ "device": "cpu" -+ }, -+ "offload_param": { -+ "device": "cpu" -+ } - } -``` - -**ZeRO-Infinity vs ZeRO-Offload:** -DeepSpeed first included offloading capabilities with ZeRO-Offload, -a system for offloading optimizer and gradient states to CPU memory -within ZeRO-2. ZeRO-Infinity is the next generation of offloading -capabilities accessible to ZeRO-3. ZeRO-Infinity is able to offload -more data than ZeRO-Offload and has more effective bandwidth utilization -and overlapping of computation and communication. -{: .notice--info} - - - - -#### Allocating Massive Megatron-LM Models - -We make two further changes to model initialization in order to support models -that exceed *local* system memory, but not *total* system memory. - -1. Allocate the model in a memory-scalable fashion. The model parameters will -be allocated and immediately partitioned across the data parallel group. If -`remote_device` is `"cpu"` or `"nvme"`, the model will also be allocated in CPU/NVMe memory -instead of GPU memory. Please see the full -[ZeRO-3 Init docs](https://deepspeed.readthedocs.io/en/latest/zero3.html#deepspeed.zero.Init) -for more details. - - ```python - with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), - remote_device=get_args().remote_device, - enabled=get_args().zero_stage==3): - model = GPT2Model(num_tokentypes=0, parallel_output=True) - ``` - -2. Gather the embeddings weight for initialization. DeepSpeed will automatically -gather a module's parameters during its constructor and for its forward and backward pass. -However, additional accesses must coordinate with DeepSpeed to ensure that parameter data -is gathered and subsequently partitioned. If the tensor is modified, the `modifier_rank` -argument should also be used to ensure all ranks have a consistent view of -the data. Please see the full -[GatheredParameters docs](https://deepspeed.readthedocs.io/en/latest/zero3.html#deepspeed.zero.GatheredParameters) -for more details. - - ```python - self.position_embeddings = torch.nn.Embedding(...) - with deepspeed.zero.GatheredParameters(self.position_embeddings.weight, - modifier_rank=0): - # Initialize the position embeddings. - self.init_method(self.position_embeddings.weight) - - ... - - self.tokentype_embeddings = torch.nn.Embedding(...) - with deepspeed.zero.GatheredParameters(self.tokentype_embeddings.weight, - modifier_rank=0): - # Initialize the token-type embeddings. - self.init_method(self.tokentype_embeddings.weight) - ``` - -#### Memory-centric tiling -ZeRO-Infinity includes a replacement for `Linear` layers that further reduces memory. -We optionally tile the model parallel linear layers found in each Transformer layer. Note -that model parallelism and tiling can be combined by specifying the corresponding -base class when building the layer. -The `deepspeed.zero.TiledLinear` module exploits the data fetch and release -pattern of ZeRO-3 to reduce the working memory requirements by breaking down -a large operator into smaller tiles that can be executed sequentially. - -We include the changes for one example from Megatron-LM's [ParallelMLP](https://github.com/microsoft/DeepSpeedExamples/blob/bdf8e59aede8c8e0577e8d4d557298ca8515268f/Megatron-LM-v1.1.5-ZeRO3/megatron/model/transformer.py#L82). Three more -model-parallel layers in `transformer.py` proceed similarly. - -The model parallel layers of Megatron-LM have a special form in which the -additive `bias` of the layer is delayed and instead returned from `forward()` -to be fused with a later operator. DeepSpeed's -`deepspeed.zero.TiledLinearReturnBias` subclass of `TiledLinear` simply also -forwards the returned `bias` parameter without accumulating. - -```diff -@@ -1,6 +1,9 @@ --self.dense_h_to_4h = mpu.ColumnParallelLinear( -+self.dense_h_to_4h = deepspeed.zero.TiledLinearReturnBias( - args.hidden_size, - 4 * args.hidden_size, -+ in_splits=args.tile_factor, -+ out_splits=4*args.tile_factor, -+ linear_cls=mpu.ColumnParallelLinear, - gather_output=False, - init_method=init_method, - skip_bias_add=True) -``` - -Note that we scale `in_splits` and `out_splits` proportionally with `input_size` and `output_size`. This -results in tiles of fixed size `[hidden/tile_factor, hidden/tile_factor]`. - -#### Registering external parameters - -**Deprecated:** -DeepSpeed version `0.3.15` introduced automatic external parameter -registration and this step is no longer needed. -{: .notice--info} - - -## Extracting weights - -If you need to take the pretrained weights out of Deepspeed here is what you can do for getting fp16 weights: - -- under ZeRO-2 `state_dict` contains the fp16 model weights and these can be saved normally with `torch.save`. -- under ZeRO-3 `state_dict` contains just the placeholders since the model weights are partitioned across multiple GPUs. If you want to get to these weights enable: - -```json - "zero_optimization": { - "stage3_gather_fp16_weights_on_model_save": true - }, -``` -And then save the model using: - -```python - if self.deepspeed: - self.deepspeed.save_fp16_model(output_dir, output_file) -``` - -Because it requires consolidation of the weights on one GPU it can be slow and memory demanding, so only use this feature when needed. - -Note that if `stage3_gather_fp16_weights_on_model_save` is `False`, no weights will be saved (again, because `state_dict` doesn't have them. -You can use this method to save ZeRO-2 weights as well. - -If you'd like to get the fp32 weights, we supply a special script that can do offline consolidation. It requires no configuration files or GPUs. Here is an example of its usage: - -``` bash -$ cd /path/to/checkpoint_dir -$ ./zero_to_fp32.py . pytorch_model.bin -Processing zero checkpoint at global_step1 -Detected checkpoint of type zero stage 3, world_size: 2 -Saving fp32 state dict to pytorch_model.bin (total_numel=60506624) -``` - -The `zero_to_fp32.py` gets created automatically when you save a checkpoint. - -Note: currently this script uses 2x memory (general RAM) of the size of the final checkpoint. - -Alternatively, if you have plenty of spare CPU memory and instead of getting the file you want your model to be updated to its fp32 weights, you can do the following at the end of the training: - -``` python - from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint - fp32_model = load_state_dict_from_zero_checkpoint(deepspeed.module, checkpoint_dir) -``` - -Beware, that the model will be good for saving, but no longer good for continuing the training and will require a `deepspeed.initialize()` anew. - -If you just want the `state_dict`, you can do: - -``` python - from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint - state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) -``` - - -Congratulations! You have completed the ZeRO tutorial. +--- +title: "Zero Redundancy Optimizer (ZeRO)" +--- +If you have not done so already, we advise that you read the DeepSpeed tutorials on [Getting Started](/getting-started/) and [Megatron-LM GPT-2](/tutorials/megatron/) before stepping through this tutorial. + +In this tutorial, we will apply the ZeRO optimizer to the [Megatron-LM GPT-2](https://github.com/NVIDIA/Megatron-LM) model. ZeRO is a powerful set of memory optimization techniques that enable effective FP16 training of large models with trillions of parameters, such as [GPT-2](https://openai.com/blog/better-language-models/) and [Turing-NLG 17B](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/). Compared to the alternative model parallelism approaches for training large models, a key appeal of ZeRO is that no model code modifications are required. As this tutorial will demonstrate, *using ZeRO in a DeepSpeed model is quick and easy because all you need is to change a few configurations in the DeepSpeed configuration JSON*. No code changes are needed. + +## ZeRO Overview +ZeRO leverages the aggregate computation and memory resources of data parallelism to reduce the memory and compute requirements of each device (GPU) used for model training. ZeRO reduces the memory consumption of each GPU by partitioning the various model training states (weights, gradients, and optimizer states) across the available devices (GPUs and CPUs) in the distributed training hardware. Concretely, ZeRO is being implemented as incremental stages of optimizations, where optimizations in earlier stages are available in the later stages. To deep dive into ZeRO, please see our [paper](https://arxiv.org/abs/1910.02054v3). + +* **Stage 1**: The optimizer states (e.g., for [Adam optimizer](https://arxiv.org/abs/1412.6980), 32-bit weights, and the first, and second moment estimates) are partitioned across the processes, so that each process updates only its partition. + +* **Stage 2**: The reduced 32-bit gradients for updating the model weights are also partitioned such that each process retains only the gradients corresponding to its portion of the optimizer states. + +* **Stage 3**: The 16-bit model parameters are partitioned across the processes. ZeRO-3 will automatically collect and partition them during the forward and backward passes. + +In addition, ZeRO-3 includes the *infinity offload engine* to form ZeRO-Infinity ([paper](https://arxiv.org/abs/2104.07857)), which can offload to both CPU and NVMe memory for huge memory savings. + +## Training environment +We use the DeepSpeed [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM-v1.1.5-ZeRO3) GPT-2 code for this exercise. You can step through the Megatron-LM [tutorial](/tutorials/megatron/) to familiarize yourself with the code. We will train the models in this tutorial on [NVIDIA Tesla V100-SXM3 Tensor Core GPUs](https://www.nvidia.com/en-us/data-center/v100/) with 32GB RAM. + +## Enabling ZeRO Optimization +To enable ZeRO optimizations for a DeepSpeed model, we simply add the **_zero_optimization_** key to the DeepSpeed JSON configuration. A full description of configuration knobs of the **zero_optimization** key is available [here](/docs/config-json/#zero-optimizations-for-fp16-training). + +### Training a 1.5B Parameter GPT-2 model +We demonstrate the benefits of ZeRO stage 1 by showing that it enables data parallel training of a 1.5 billion parameter GPT-2 model on eight V100 GPUs. We configure training to use a batch size of 1 per device to ensure that the memory consumption is primarily due to model parameters and optimizer states. We create this training scenario by applying the following modifications to the deepspeed launch script: + +```bash + --model-parallel-size 1 \ + --num-layers 48 \ + --hidden-size 1600 \ + --num-attention-heads 16 \ + --batch-size 1 \ + --deepspeed_config ds_zero_stage_1.config \ +``` + +Training this model without ZeRO fails with an out-of-memory (OOM) error as shown below: + + + + + +A key reason why this model does not fit in GPU memory is that the Adam optimizer states for the model consume 18GB; a significant portion of the 32GB RAM. By using ZeRO stage 1 to partition the optimizer state among eight data parallel ranks, the per-device memory consumption can be reduced to 2.25GB, thus making the model trainable. To enable ZeRO stage 1, we simply update the DeepSpeed JSON config file as below: + +```json +{ + "zero_optimization": { + "stage":1, + "reduce_bucket_size": 5e8 + } +} +``` +As seen above, we set two fields in the **zero_optimization** key. Specifically we set the _stage_ field to 1, and the optional _reduce_bucket_size_ for gradient reduction to 500M. With ZeRO stage 1 enabled, the model can now train smoothly on 8 GPUs without running out of memory. Below we provide some screenshots of the model training: + + + + + + + + + + + +From the nvidia-smi screenshot above we can see that only GPUs 6-7 are being used for training the model. With ZeRO stage 1 we can further reduce the per-device memory consumption by increasing the data parallelism degree. These memory savings can be leveraged to either increase model size and/or batch size. In contrast, such benefits are not possible with data parallelism alone. + +### Training a 10B Parameter GPT-2 model +ZeRO stage 2 optimizations further increases the size of models that can be trained using data parallelism. We show this by training a model with 10B parameters using 32 V100 GPUs. + +First, we need to configure a 10B parameter model with activation checkpointing enabled. This can be done by applying the following GPT-2 model configuration changes to the DeepSpeed launch script. + +```bash + --model-parallel-size 1 \ + --num-layers 50 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --batch-size 1 \ + --deepspeed_config ds_zero_stage_2.config \ + --checkpoint-activations +``` + +Next, we need to update the DeepSpeed JSON configuration, as shown below, to enable ZeRO stage 2 optimizations: + +```json +{ + "zero_optimization": { + "stage":2, + "contiguous_gradients": true, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "allgather_bucket_size": 5e8 + } +} +``` + +In the above changes, we have set the _stage_ field to 2, and configured other optimization knobs that are available in ZeRO stage 2. For example, we have enabled _contiguous_gradients_ to reduce memory fragmentation during backward pass. A full description of these optimization knobs is available [here](/docs/config-json/#zero-optimizations-for-fp16-training). With these changes, we can now launch the training run. + +Here is a screenshot of the training log: + + + + + +Here is a screenshot of nvidia-smi showing GPU activity during training: + + + + + +### Training trillion-scale models with ZeRO-Infinity + +ZeRO-3, the third stage of ZeRO, partitions the full model state (i.e., +weights, gradients, and optimizer states) to scale memory savings linearly +with the degree of data parallelism. ZeRO-3 can be enabled in the JSON +configuration. A full description of these configurations is available +[here](/docs/config-json/#zero-optimizations-for-fp16-training). + + +#### Offloading to CPU and NVMe with ZeRO-Infinity + +ZeRO-Infinity uses DeepSpeed's infinity offload engine to offload the full +model state to CPU or NVMe memory, allowing for even larger model sizes. Offloading +can be enabled inside the DeepSpeed configuration: + +```diff +@@ -6,5 +6,11 @@ + "zero_optimization": { + "stage": 3, + "contiguous_gradients": true, + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_prefetch_bucket_size": 1e7, + "stage3_param_persistence_threshold": 1e5, + "reduce_bucket_size": 1e7, +- "sub_group_size": 1e9 ++ "sub_group_size": 1e9, ++ "offload_optimizer": { ++ "device": "cpu" ++ }, ++ "offload_param": { ++ "device": "cpu" ++ } + } +``` + +**ZeRO-Infinity vs ZeRO-Offload:** +DeepSpeed first included offloading capabilities with ZeRO-Offload, +a system for offloading optimizer and gradient states to CPU memory +within ZeRO-2. ZeRO-Infinity is the next generation of offloading +capabilities accessible to ZeRO-3. ZeRO-Infinity is able to offload +more data than ZeRO-Offload and has more effective bandwidth utilization +and overlapping of computation and communication. +{: .notice--info} + + + + +#### Allocating Massive Megatron-LM Models + +We make two further changes to model initialization in order to support models +that exceed *local* system memory, but not *total* system memory. + +1. Allocate the model in a memory-scalable fashion. The model parameters will +be allocated and immediately partitioned across the data parallel group. If +`remote_device` is `"cpu"` or `"nvme"`, the model will also be allocated in CPU/NVMe memory +instead of GPU memory. Please see the full +[ZeRO-3 Init docs](https://deepspeed.readthedocs.io/en/latest/zero3.html#deepspeed.zero.Init) +for more details. + + ```python + with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), + remote_device=get_args().remote_device, + enabled=get_args().zero_stage==3): + model = GPT2Model(num_tokentypes=0, parallel_output=True) + ``` + +2. Gather the embeddings weight for initialization. DeepSpeed will automatically +gather a module's parameters during its constructor and for its forward and backward pass. +However, additional accesses must coordinate with DeepSpeed to ensure that parameter data +is gathered and subsequently partitioned. If the tensor is modified, the `modifier_rank` +argument should also be used to ensure all ranks have a consistent view of +the data. Please see the full +[GatheredParameters docs](https://deepspeed.readthedocs.io/en/latest/zero3.html#deepspeed.zero.GatheredParameters) +for more details. + + ```python + self.position_embeddings = torch.nn.Embedding(...) + with deepspeed.zero.GatheredParameters(self.position_embeddings.weight, + modifier_rank=0): + # Initialize the position embeddings. + self.init_method(self.position_embeddings.weight) + + ... + + self.tokentype_embeddings = torch.nn.Embedding(...) + with deepspeed.zero.GatheredParameters(self.tokentype_embeddings.weight, + modifier_rank=0): + # Initialize the token-type embeddings. + self.init_method(self.tokentype_embeddings.weight) + ``` + +#### Memory-centric tiling +ZeRO-Infinity includes a replacement for `Linear` layers that further reduces memory. +We optionally tile the model parallel linear layers found in each Transformer layer. Note +that model parallelism and tiling can be combined by specifying the corresponding +base class when building the layer. +The `deepspeed.zero.TiledLinear` module exploits the data fetch and release +pattern of ZeRO-3 to reduce the working memory requirements by breaking down +a large operator into smaller tiles that can be executed sequentially. + +We include the changes for one example from Megatron-LM's [ParallelMLP](https://github.com/microsoft/DeepSpeedExamples/blob/bdf8e59aede8c8e0577e8d4d557298ca8515268f/Megatron-LM-v1.1.5-ZeRO3/megatron/model/transformer.py#L82). Three more +model-parallel layers in `transformer.py` proceed similarly. + +The model parallel layers of Megatron-LM have a special form in which the +additive `bias` of the layer is delayed and instead returned from `forward()` +to be fused with a later operator. DeepSpeed's +`deepspeed.zero.TiledLinearReturnBias` subclass of `TiledLinear` simply also +forwards the returned `bias` parameter without accumulating. + +```diff +@@ -1,6 +1,9 @@ +-self.dense_h_to_4h = mpu.ColumnParallelLinear( ++self.dense_h_to_4h = deepspeed.zero.TiledLinearReturnBias( + args.hidden_size, + 4 * args.hidden_size, ++ in_splits=args.tile_factor, ++ out_splits=4*args.tile_factor, ++ linear_cls=mpu.ColumnParallelLinear, + gather_output=False, + init_method=init_method, + skip_bias_add=True) +``` + +Note that we scale `in_splits` and `out_splits` proportionally with `input_size` and `output_size`. This +results in tiles of fixed size `[hidden/tile_factor, hidden/tile_factor]`. + +#### Registering external parameters + +**Deprecated:** +DeepSpeed version `0.3.15` introduced automatic external parameter +registration and this step is no longer needed. +{: .notice--info} + + +## Extracting weights + +If you need to take the pretrained weights out of Deepspeed here is what you can do for getting fp16 weights: + +- under ZeRO-2 `state_dict` contains the fp16 model weights and these can be saved normally with `torch.save`. +- under ZeRO-3 `state_dict` contains just the placeholders since the model weights are partitioned across multiple GPUs. If you want to get to these weights enable: + +```json + "zero_optimization": { + "stage3_gather_fp16_weights_on_model_save": true + }, +``` +And then save the model using: + +```python + if self.deepspeed: + self.deepspeed.save_fp16_model(output_dir, output_file) +``` + +Because it requires consolidation of the weights on one GPU it can be slow and memory demanding, so only use this feature when needed. + +Note that if `stage3_gather_fp16_weights_on_model_save` is `False`, no weights will be saved (again, because `state_dict` doesn't have them. +You can use this method to save ZeRO-2 weights as well. + +If you'd like to get the fp32 weights, we supply a special script that can do offline consolidation. It requires no configuration files or GPUs. Here is an example of its usage: + +``` bash +$ cd /path/to/checkpoint_dir +$ ./zero_to_fp32.py . pytorch_model.bin +Processing zero checkpoint at global_step1 +Detected checkpoint of type zero stage 3, world_size: 2 +Saving fp32 state dict to pytorch_model.bin (total_numel=60506624) +``` + +The `zero_to_fp32.py` gets created automatically when you save a checkpoint. + +Note: currently this script uses 2x memory (general RAM) of the size of the final checkpoint. + +Alternatively, if you have plenty of spare CPU memory and instead of getting the file you want your model to be updated to its fp32 weights, you can do the following at the end of the training: + +``` python + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + fp32_model = load_state_dict_from_zero_checkpoint(deepspeed.module, checkpoint_dir) +``` + +Beware, that the model will be good for saving, but no longer good for continuing the training and will require a `deepspeed.initialize()` anew. + +If you just want the `state_dict`, you can do: + +``` python + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) +``` + + +Congratulations! You have completed the ZeRO tutorial. diff --git a/docs/code-docs/source/schedulers.rst b/docs/code-docs/source/schedulers.rst index c7b67cbb2..5bc23ffb0 100755 --- a/docs/code-docs/source/schedulers.rst +++ b/docs/code-docs/source/schedulers.rst @@ -1,25 +1,25 @@ -Learning Rate Schedulers -=================== - -DeepSpeed offers implementations of ``LRRangeTest``, ``OneCycle``, ``WarmupLR``, ``WarmupDecayLR`` learning rate schedulers. When using a DeepSpeed's learning rate scheduler (specified in the `ds_config.json` file), DeepSpeed calls the `step()` method of the scheduler at every training step (when `model_engine.step()` is executed). When not using a DeepSpeed's learning rate scheduler: - * if the schedule is supposed to execute at every training step, then the user can pass the scheduler to `deepspeed.initialize` when initializing the DeepSpeed engine and let DeepSpeed manage it for update or save/restore. - * if the schedule is supposed to execute at any other interval (e.g., training epochs), then the user should NOT pass the scheduler to DeepSpeed during initialization and must manage it explicitly. - -LRRangeTest ---------------------------- -.. autoclass:: deepspeed.runtime.lr_schedules.LRRangeTest - - -OneCycle ---------------------------- -.. autoclass:: deepspeed.runtime.lr_schedules.OneCycle - - -WarmupLR ---------------------------- -.. autoclass:: deepspeed.runtime.lr_schedules.WarmupLR - - -WarmupDecayLR ---------------------------- -.. autoclass:: deepspeed.runtime.lr_schedules.WarmupDecayLR +Learning Rate Schedulers +=================== + +DeepSpeed offers implementations of ``LRRangeTest``, ``OneCycle``, ``WarmupLR``, ``WarmupDecayLR`` learning rate schedulers. When using a DeepSpeed's learning rate scheduler (specified in the `ds_config.json` file), DeepSpeed calls the `step()` method of the scheduler at every training step (when `model_engine.step()` is executed). When not using a DeepSpeed's learning rate scheduler: + * if the schedule is supposed to execute at every training step, then the user can pass the scheduler to `deepspeed.initialize` when initializing the DeepSpeed engine and let DeepSpeed manage it for update or save/restore. + * if the schedule is supposed to execute at any other interval (e.g., training epochs), then the user should NOT pass the scheduler to DeepSpeed during initialization and must manage it explicitly. + +LRRangeTest +--------------------------- +.. autoclass:: deepspeed.runtime.lr_schedules.LRRangeTest + + +OneCycle +--------------------------- +.. autoclass:: deepspeed.runtime.lr_schedules.OneCycle + + +WarmupLR +--------------------------- +.. autoclass:: deepspeed.runtime.lr_schedules.WarmupLR + + +WarmupDecayLR +--------------------------- +.. autoclass:: deepspeed.runtime.lr_schedules.WarmupDecayLR diff --git a/tests/perf/adam_test.py b/tests/perf/adam_test.py index 0f29cab46..1ddcd44bb 100755 --- a/tests/perf/adam_test.py +++ b/tests/perf/adam_test.py @@ -1,24 +1,24 @@ -import torch -from deepspeed.ops.adam import DeepSpeedCPUAdam -import time - -device = 'cpu' -model_size = 1 * 1024**3 -group_size = [model_size, 274432] - -param = [torch.nn.Parameter(torch.ones(size, device=device)) for size in group_size] -optimizer = DeepSpeedCPUAdam(param) -#torch.set_num_threads(128) -for i, p in enumerate(param): - p.grad = torch.ones(group_size[i], device=device) -#param.grad = torch.ones(model_size, device=device) -avg = 0 -for i in range(100): - start = time.time() - optimizer.step() - stop = time.time() - avg += (stop - start) - for i, p in enumerate(param): - p.grad = torch.ones(group_size[i], device=device) * 2 - #param.grad = torch.ones(model_size, device=device) * 2 -print("Elapsed Time is ", avg / 100) +import torch +from deepspeed.ops.adam import DeepSpeedCPUAdam +import time + +device = 'cpu' +model_size = 1 * 1024**3 +group_size = [model_size, 274432] + +param = [torch.nn.Parameter(torch.ones(size, device=device)) for size in group_size] +optimizer = DeepSpeedCPUAdam(param) +#torch.set_num_threads(128) +for i, p in enumerate(param): + p.grad = torch.ones(group_size[i], device=device) +#param.grad = torch.ones(model_size, device=device) +avg = 0 +for i in range(100): + start = time.time() + optimizer.step() + stop = time.time() + avg += (stop - start) + for i, p in enumerate(param): + p.grad = torch.ones(group_size[i], device=device) * 2 + #param.grad = torch.ones(model_size, device=device) * 2 +print("Elapsed Time is ", avg / 100) diff --git a/tests/perf/adam_test1.py b/tests/perf/adam_test1.py index b0aba0fcd..88f1a1c59 100755 --- a/tests/perf/adam_test1.py +++ b/tests/perf/adam_test1.py @@ -1,22 +1,22 @@ -import torch -from deepspeed.ops.adam import DeepSpeedCPUAdam -import time - -device = 'cpu' -model_size = 1 * 1024**3 -param = torch.nn.Parameter(torch.ones(model_size, device=device)) -param_fp16 = torch.nn.Parameter(torch.ones(model_size, - dtype=torch.half, - device='cuda:0')) - -optimizer = DeepSpeedCPUAdam([param]) -#torch.set_num_threads(128) -param.grad = torch.ones(model_size, device=device) -avg = 0 -for i in range(100): - start = time.time() - optimizer.step(fp16_param_groups=[param_fp16]) - stop = time.time() - avg += (stop - start) - param.grad = torch.ones(model_size, device=device) * 2 -print("Elapsed Time is ", avg / 100) +import torch +from deepspeed.ops.adam import DeepSpeedCPUAdam +import time + +device = 'cpu' +model_size = 1 * 1024**3 +param = torch.nn.Parameter(torch.ones(model_size, device=device)) +param_fp16 = torch.nn.Parameter(torch.ones(model_size, + dtype=torch.half, + device='cuda:0')) + +optimizer = DeepSpeedCPUAdam([param]) +#torch.set_num_threads(128) +param.grad = torch.ones(model_size, device=device) +avg = 0 +for i in range(100): + start = time.time() + optimizer.step(fp16_param_groups=[param_fp16]) + stop = time.time() + avg += (stop - start) + param.grad = torch.ones(model_size, device=device) * 2 +print("Elapsed Time is ", avg / 100) diff --git a/tests/unit/ds_batch_config.json b/tests/unit/ds_batch_config.json index 2558a5b9d..2e86c1929 100755 --- a/tests/unit/ds_batch_config.json +++ b/tests/unit/ds_batch_config.json @@ -1,15 +1,15 @@ -{ - "train_batch_size": 2, - "gradient_accumulation_steps": 1, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.00015 - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": true, - "loss_scale": 0 - } - } +{ + "train_batch_size": 2, + "gradient_accumulation_steps": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": true, + "loss_scale": 0 + } + } diff --git a/tests/unit/modelingpreln.py b/tests/unit/modelingpreln.py index 43f210ec9..7661303a4 100755 --- a/tests/unit/modelingpreln.py +++ b/tests/unit/modelingpreln.py @@ -1,1692 +1,1692 @@ -# DeepSpeed note, code taken from commit 3d59216cec89a363649b4fe3d15295ba936ced0f -# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/BERT/modeling.py - -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch BERT model.""" - -from __future__ import absolute_import, division, print_function, unicode_literals - -import copy -import json -import logging -import math -import os -import shutil -import tarfile -import tempfile -import sys -from io import open - -import torch -from torch import nn -from torch.nn import CrossEntropyLoss -from torch.utils import checkpoint -import torch.distributed as dist - -from torch.nn import Module -from torch.nn.parameter import Parameter -import torch.nn.functional as F -import torch.nn.init as init -import time - -#from numba import cuda - -#from deepspeed_cuda import DeepSpeedSoftmaxConfig, DeepSpeedSoftmax - -logger = logging.getLogger(__name__) - -PRETRAINED_MODEL_ARCHIVE_MAP = { - 'bert-base-uncased': - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", - 'bert-large-uncased': - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", - 'bert-base-cased': - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", - 'bert-large-cased': - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", - 'bert-base-multilingual-uncased': - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", - 'bert-base-multilingual-cased': - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", - 'bert-base-chinese': - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", -} -CONFIG_NAME = 'bert_config.json' -WEIGHTS_NAME = 'pytorch_model.bin' -TF_WEIGHTS_NAME = 'model.ckpt' - - -def load_tf_weights_in_bert(model, tf_checkpoint_path): - """ Load tf checkpoints in a pytorch model - """ - try: - import re - import numpy as np - import tensorflow as tf - except ImportError: - print( - "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions.") - raise - tf_path = os.path.abspath(tf_checkpoint_path) - print("Converting TensorFlow checkpoint from {}".format(tf_path)) - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - print("Loading TF weight {} with shape {}".format(name, shape)) - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split('/') - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any(n in ["adam_v", "adam_m"] for n in name): - print("Skipping {}".format("/".join(name))) - continue - pointer = model - for m_name in name: - if re.fullmatch(r'[A-Za-z]+_\d+', m_name): - l = re.split(r'_(\d+)', m_name) - else: - l = [m_name] - if l[0] == 'kernel' or l[0] == 'gamma': - pointer = getattr(pointer, 'weight') - elif l[0] == 'output_bias' or l[0] == 'beta': - pointer = getattr(pointer, 'bias') - elif l[0] == 'output_weights': - pointer = getattr(pointer, 'weight') - else: - pointer = getattr(pointer, l[0]) - if len(l) >= 2: - num = int(l[1]) - pointer = pointer[num] - if m_name[-11:] == '_embeddings': - pointer = getattr(pointer, 'weight') - elif m_name == 'kernel': - array = np.transpose(array) - try: - assert pointer.shape == array.shape - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - print("Initialize PyTorch weight {}".format(name)) - pointer.data = torch.from_numpy(array) - return model - - -""" -@torch.jit.script -def f_gelu(x): - return x * 0.5 * (1.0 + torch.erf(x / 1.41421)) -@torch.jit.script -def bias_gelu(bias, y): - x = bias + y - return x * 0.5 * (1.0 + torch.erf(x / 1.41421)) -@torch.jit.script -def bias_tanh(bias, y): - x = bias + y - return torch.tanh(x) - """ - - -def f_gelu(x): - x_type = x.dtype - x = x.float() - x = x * 0.5 * (1.0 + torch.erf(x / 1.41421)) - return x.to(x_type) - - -def bias_gelu(bias, y): - y_type = y.dtype - x = bias.float() + y.float() - x = x * 0.5 * (1.0 + torch.erf(x / 1.41421)) - return x.to(y_type) - - -def bias_tanh(bias, y): - y_type = y.dtype - x = bias.float() + y.float() - x = torch.tanh(x) - return x.to(y_type) - - -def gelu(x): - """Implementation of the gelu activation function. - For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): - 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) - Also see https://arxiv.org/abs/1606.08415 - """ - return f_gelu(x) - - -def swish(x): - return x * torch.sigmoid(x) - - -ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} - - -class GPUTimer: - def __init__(self): - super().__init__() - self.start = cuda.event() - self.stop = cuda.event() - - def record(self): - self.start.record() - - def elapsed(self): - self.stop.record() - self.stop.synchronize() - return self.start.elapsed_time(self.stop) / 1000.0 - - -class LinearActivation(Module): - r"""Fused Linear and activation Module. - """ - __constants__ = ['bias'] - - def __init__(self, - in_features, - out_features, - weights, - biases, - act='gelu', - bias=True): - super(LinearActivation, self).__init__() - self.in_features = in_features - self.out_features = out_features - self.fused_gelu = False - self.fused_tanh = False - if isinstance(act, - str) or (sys.version_info[0] == 2 and isinstance(act, - unicode)): - if bias and act == 'gelu': - self.fused_gelu = True - elif bias and act == 'tanh': - self.fused_tanh = True - else: - self.act_fn = ACT2FN[act] - else: - self.act_fn = act - #self.weight = Parameter(torch.Tensor(out_features, in_features)) - self.weight = weights[5] - self.bias = biases[5] - #if bias: - # self.bias = Parameter(torch.Tensor(out_features)) - #else: - # self.register_parameter('bias', None) - #self.reset_parameters() - - def reset_parameters(self): - init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) - bound = 1 / math.sqrt(fan_in) - init.uniform_(self.bias, -bound, bound) - - def forward(self, input): - if self.fused_gelu: - #timing = [] - #t1 = GPUTimer() - #t1.record() - y = F.linear(input, self.weight, None) - #timing.append(t1.elapsed()) - #t1.record() - bg = bias_gelu(self.bias, y) - #timing.append(t1.elapsed()) - return bg - elif self.fused_tanh: - return bias_tanh(self.bias, F.linear(input, self.weight, None)) - else: - return self.act_fn(F.linear(input, self.weight, self.bias)) - - def extra_repr(self): - return 'in_features={}, out_features={}, bias={}'.format( - self.in_features, - self.out_features, - self.bias is not None) - - -class BertConfig(object): - """Configuration class to store the configuration of a `BertModel`. - """ - def __init__(self, - vocab_size_or_config_json_file, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - batch_size=8, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02, - fp16=False): - """Constructs BertConfig. - - Args: - vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. - hidden_size: Size of the encoder layers and the pooler layer. - num_hidden_layers: Number of hidden layers in the Transformer encoder. - num_attention_heads: Number of attention heads for each attention layer in - the Transformer encoder. - intermediate_size: The size of the "intermediate" (i.e., feed-forward) - layer in the Transformer encoder. - hidden_act: The non-linear activation function (function or string) in the - encoder and pooler. If string, "gelu", "relu" and "swish" are supported. - hidden_dropout_prob: The dropout probability for all fully connected - layers in the embeddings, encoder, and pooler. - attention_probs_dropout_prob: The dropout ratio for the attention - probabilities. - max_position_embeddings: The maximum sequence length that this model might - ever be used with. Typically set this to something large just in case - (e.g., 512 or 1024 or 2048). - type_vocab_size: The vocabulary size of the `token_type_ids` passed into - `BertModel`. - initializer_range: The sttdev of the truncated_normal_initializer for - initializing all weight matrices. - """ - if isinstance(vocab_size_or_config_json_file, - str) or (sys.version_info[0] == 2 - and isinstance(vocab_size_or_config_json_file, - unicode)): - with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: - json_config = json.loads(reader.read()) - for key, value in json_config.items(): - self.__dict__[key] = value - elif isinstance(vocab_size_or_config_json_file, int): - self.vocab_size = vocab_size_or_config_json_file - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.batch_size = batch_size - self.hidden_act = hidden_act - self.intermediate_size = intermediate_size - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.initializer_range = initializer_range - self.fp16 = fp16 - else: - raise ValueError("First argument must be either a vocabulary size (int)" - "or the path to a pretrained model config file (str)") - - @classmethod - def from_dict(cls, json_object): - """Constructs a `BertConfig` from a Python dictionary of parameters.""" - config = BertConfig(vocab_size_or_config_json_file=-1) - for key, value in json_object.items(): - config.__dict__[key] = value - return config - - @classmethod - def from_json_file(cls, json_file): - """Constructs a `BertConfig` from a json file of parameters.""" - with open(json_file, "r", encoding='utf-8') as reader: - text = reader.read() - return cls.from_dict(json.loads(text)) - - def __repr__(self): - return str(self.to_json_string()) - - def to_dict(self): - """Serializes this instance to a Python dictionary.""" - output = copy.deepcopy(self.__dict__) - return output - - def to_json_string(self): - """Serializes this instance to a JSON string.""" - return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" - - -try: - import apex - #apex.amp.register_half_function(apex.normalization.fused_layer_norm, 'FusedLayerNorm') - import apex.normalization - #apex.amp.register_float_function(apex.normalization.FusedLayerNorm, 'forward') - BertLayerNorm = apex.normalization.FusedLayerNorm -except ImportError: - print( - "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex." - ) - - class BertLayerNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-12): - """Construct a layernorm module in the TF style (epsilon inside the square root). - """ - super(BertLayerNorm, self).__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.bias = nn.Parameter(torch.zeros(hidden_size)) - self.variance_epsilon = eps - - def forward(self, x): - pdtype = x.dtype - x = x.float() - u = x.mean(-1, keepdim=True) - s = (x - u).pow(2).mean(-1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.variance_epsilon) - return self.weight * x.to(pdtype) + self.bias - - #def forward(self, x): - # u = x.mean(-1, keepdim=True) - # s = (x - u).pow(2).mean(-1, keepdim=True) - # x = (x - u) / torch.sqrt(s + self.variance_epsilon) - # return self.weight * x + self.bias - - -class BertEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings. - """ - def __init__(self, config): - super(BertEmbeddings, self).__init__() - self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) - self.position_embeddings = nn.Embedding(config.max_position_embeddings, - config.hidden_size) - self.token_type_embeddings = nn.Embedding(config.type_vocab_size, - config.hidden_size) - - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file - self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, input_ids, token_type_ids=None): - seq_length = input_ids.size(1) - position_ids = torch.arange(seq_length, - dtype=torch.long, - device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - words_embeddings = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = words_embeddings + position_embeddings + token_type_embeddings - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - -class BertSelfAttention(nn.Module): - def __init__(self, i, config, weights, biases): - super(BertSelfAttention, self).__init__() - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, - config.num_attention_heads)) - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(config.hidden_size, self.all_head_size) - self.query.weight = weights[0] - self.query.bias = biases[0] - self.key = nn.Linear(config.hidden_size, self.all_head_size) - self.key.weight = weights[1] - self.key.bias = biases[1] - self.value = nn.Linear(config.hidden_size, self.all_head_size) - self.value.weight = weights[2] - self.value.bias = biases[2] - - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.softmax = nn.Softmax(dim=-1) - #self.softmax_config = DeepSpeedSoftmaxConfig() - #self.softmax_config.batch_size = config.batch_size - #self.softmax_config.max_seq_length = config.max_position_embeddings - #self.softmax_config.hidden_size = config.hidden_size - #self.softmax_config.heads = config.num_attention_heads - #self.softmax_config.softmax_id = i - #self.softmax_config.fp16 = config.fp16 - #self.softmax_config.prob_drop_out = 0.0 - #self.softmax = DeepSpeedSoftmax(i, self.softmax_config) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, - self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def transpose_key_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, - self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 3, 1) - - def forward(self, hidden_states, attention_mask, grads=None): - #timing = [] - #t1 = GPUTimer() - #t1.record() - mixed_query_layer = self.query(hidden_states) - - #timing.append(t1.elapsed()) - #print("Query elapsed: %s" % (time.clock() - start)) - #t1.record() - mixed_key_layer = self.key(hidden_states) - - #timing.append(t1.elapsed()) - #print("Key elapsed: %s" % (time.clock() - start)) - #t1.record() - mixed_value_layer = self.value(hidden_states) - #timing.append(t1.elapsed()) - #print("Value elapsed: %s" % (time.clock() - start)) - - #t1.record() - query_layer = self.transpose_for_scores(mixed_query_layer) - # print(query_layer) - #timing.append(t1.elapsed()) - #print("Query-Transform elapsed: %s" % (time.clock() - start)) - #t1.record() - key_layer = self.transpose_key_for_scores(mixed_key_layer) - # print(key_layer) - #timing.append(t1.elapsed()) - #print("Key-Transform elapsed: %s" % (time.clock() - start)) - #t1.record() - value_layer = self.transpose_for_scores(mixed_value_layer) - #print(value_layer) - #timing.append(t1.elapsed()) - #print("Value-Transform elapsed: %s" % (time.clock() - start)) - - # Take the dot product between "query" and "key" to get the raw attention scores. - #t1.record() - #print(query_layer.shape) - #print(key_layer.shape) - attention_scores = torch.matmul(query_layer, key_layer) - #print(attention_scores.shape) - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - #print("Pytorch: ", attention_scores) - #timing.append(t1.elapsed()) - #print("Attention-Score elapsed: %s" % (time.clock() - start)) - # Apply the attention mask is (precomputed for all layers in BertModel forward() function) - #t1.record() - - # context_layer = self.softmax(query_layer, key_layer, value_layer, attention_mask) - #print("context shape is :", context_layer.shape) - #print("Cuda-ext:, ", attention_scores1) - # Normalize the attention scores to probabilities. - ####attention_probs = self.softmax(attention_scores) - #timing.append(t1.elapsed()) - #print("Softmax elapsed: %s" % (time.clock() - start)) - #t1 = GPUTimer() - #t1.record() - attention_scores = attention_scores + attention_mask - attention_probs = self.softmax(attention_scores) - #attention_scores = self.softmax(attention_scores, attention_mask) - #print("Softmax elapse {0:8.2f} ms", t1.elapsed() * 1000) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - #t1.record() - context_layer = torch.matmul(attention_probs, value_layer) - #timing.append(t1.elapsed()) - #print("Context elapsed: %s" % (time.clock() - start)) - #t1.record() - #context_layer1 = context_layer.permute( - # 0, 1, 3, 2, 4).contiguous() - #if grads is not None: - # context_layer.register_hook(lambda x, self = self : grads.append([x, "Context"])) - context_layer1 = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer1.size()[:-2] + (self.all_head_size, ) - context_layer1 = context_layer1.view(*new_context_layer_shape) - #timing.append(t1.elapsed()) - #print("Context-Transform elapsed: %s" % (time.clock() - start)) - - if grads is not None: - query_layer.register_hook(lambda x, self=self: grads.append([x, "Query"])) - key_layer.register_hook(lambda x, self=self: grads.append([x, "Key"])) - value_layer.register_hook(lambda x, self=self: grads.append([x, "Value"])) - - return context_layer1 - - -class BertSelfOutput(nn.Module): - def __init__(self, config, weights, biases): - super(BertSelfOutput, self).__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.dense.weight = weights[3] - self.dense.bias = biases[3] - self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - #timing = [] - #t1 = GPUTimer() - #t1.record() - hidden_states = self.dense(hidden_states) - #timing.append(t1.elapsed()) - #print("Attention Output elapsed: %s" % (time.clock() - start)) - hidden_states = self.dropout(hidden_states) - #t1.record() - #hidden_states = self.LayerNorm(hidden_states + input_tensor) - #timing.append(t1.elapsed()) - #print("LayerNorm elapsed: %s" % (time.clock() - start)) - return hidden_states - - def get_w(self): - return self.dense.weight - - -class BertAttention(nn.Module): - def __init__(self, i, config, weights, biases): - super(BertAttention, self).__init__() - self.self = BertSelfAttention(i, config, weights, biases) - self.output = BertSelfOutput(config, weights, biases) - - def forward(self, input_tensor, attention_mask): - self_output = self.self(input_tensor, attention_mask) - attention_output = self.output(self_output, input_tensor) - return attention_output - - def get_w(self): - return self.output.get_w() - - -class BertIntermediate(nn.Module): - def __init__(self, config, weights, biases): - super(BertIntermediate, self).__init__() - self.dense_act = LinearActivation(config.hidden_size, - config.intermediate_size, - weights, - biases, - act=config.hidden_act) - - def forward(self, hidden_states): - hidden_states = self.dense_act(hidden_states) - return hidden_states - - -class BertOutput(nn.Module): - def __init__(self, config, weights, biases): - super(BertOutput, self).__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.dense.weight = weights[6] - self.dense.bias = biases[6] - self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - #timing = [] - #t1 = GPUTimer() - #t1.record() - #print (hidden_states) - #print (self.dense.weight) - hidden_states = self.dense(hidden_states) - #timing.append(t1.elapsed()) - #print("FF2 elapsed: %s" % (time.clock() - start)) - hidden_states = self.dropout(hidden_states) - #t1.record() - #hidden_states = self.LayerNorm(hidden_states + input_tensor) - #timing.append(t1.elapsed()) - #print("LayerNorm elapsed: %s" % (time.clock() - start)) - return hidden_states - - -class BertLayer(nn.Module): - def __init__(self, i, config, weights, biases): - super(BertLayer, self).__init__() - self.attention = BertAttention(i, config, weights, biases) - self.PreAttentionLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) - self.PostAttentionLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) - self.intermediate = BertIntermediate(config, weights, biases) - self.output = BertOutput(config, weights, biases) - self.weight = weights - self.biases = biases - - def forward(self, hidden_states, attention_mask, grads, collect_all_grads=False): - input_layer_norm = self.PreAttentionLayerNorm(hidden_states) - attention_output = self.attention(input_layer_norm, attention_mask) - #print ("hidden shape is :", hidden_states.shape) - intermediate_input = hidden_states + attention_output - - intermediate_layer_norm = self.PostAttentionLayerNorm(intermediate_input) - intermediate_output = self.intermediate(intermediate_layer_norm) - layer_output = self.output(intermediate_output, attention_output) - - #attention_output = self.attention(hidden_states, attention_mask) - #intermediate_output = self.intermediate(attention_output) - #layer_output = self.output(intermediate_output, attention_output) - - if collect_all_grads: - # self.weight[0].register_hook(lambda x, self=self: grads.append([x,"Q_W"])) - # self.biases[0].register_hook(lambda x, self=self: grads.append([x,"Q_B"])) - # self.weight[1].register_hook(lambda x, self=self: grads.append([x,"K_W"])) - # self.biases[1].register_hook(lambda x, self=self: grads.append([x,"K_B"])) - self.weight[2].register_hook(lambda x, self=self: grads.append([x, "V_W"])) - self.biases[2].register_hook(lambda x, self=self: grads.append([x, "V_B"])) - self.weight[3].register_hook(lambda x, self=self: grads.append([x, "O_W"])) - self.biases[3].register_hook(lambda x, self=self: grads.append([x, "O_B"])) - self.PostAttentionLayerNorm.weight.register_hook( - lambda x, - self=self: grads.append([x, - "N2_W"])) - self.PostAttentionLayerNorm.bias.register_hook( - lambda x, - self=self: grads.append([x, - "N2_B"])) - self.weight[5].register_hook(lambda x, self=self: grads.append([x, "int_W"])) - self.biases[5].register_hook(lambda x, self=self: grads.append([x, "int_B"])) - self.weight[6].register_hook(lambda x, self=self: grads.append([x, "out_W"])) - self.biases[6].register_hook(lambda x, self=self: grads.append([x, "out_B"])) - self.PreAttentionLayerNorm.weight.register_hook( - lambda x, - self=self: grads.append([x, - "norm_W"])) - self.PreAttentionLayerNorm.bias.register_hook( - lambda x, - self=self: grads.append([x, - "norm_B"])) - - return layer_output + intermediate_input - - def get_w(self): - return self.attention.get_w() - - -class BertEncoder(nn.Module): - def __init__(self, config, weights, biases): - super(BertEncoder, self).__init__() - #layer = BertLayer(config, weights, biases) - self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) - - self.layer = nn.ModuleList([ - copy.deepcopy(BertLayer(i, - config, - weights, - biases)) for i in range(config.num_hidden_layers) - ]) - self.grads = [] - self.graph = [] - - def get_grads(self): - return self.grads - - # def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): - # all_encoder_layers = [] - # for layer_module in self.layer: - # hidden_states = layer_module(hidden_states, attention_mask) - # if output_all_encoded_layers: - # all_encoder_layers.append(hidden_states) - # if not output_all_encoded_layers: - # all_encoder_layers.append(hidden_states) - # return all_encoder_layers - - def get_modules(self, big_node, input): - for mdl in big_node.named_children(): - graph.append(mdl) - get_modules(self, mdl, input) - - def forward(self, - hidden_states, - attention_mask, - output_all_encoded_layers=True, - checkpoint_activations=False): - all_encoder_layers = [] - - def custom(start, end): - def custom_forward(*inputs): - layers = self.layer[start:end] - x_ = inputs[0] - for layer in layers: - x_ = layer(x_, inputs[1]) - return x_ - - return custom_forward - - if checkpoint_activations: - l = 0 - num_layers = len(self.layer) - chunk_length = math.ceil(math.sqrt(num_layers)) - while l < num_layers: - hidden_states = checkpoint.checkpoint(custom(l, - l + chunk_length), - hidden_states, - attention_mask * 1) - l += chunk_length - # decoder layers - else: - for i, layer_module in enumerate(self.layer): - hidden_states = layer_module(hidden_states, - attention_mask, - self.grads, - collect_all_grads=True) - hidden_states.register_hook( - lambda x, - i=i, - self=self: self.grads.append([x, - "hidden_state"])) - #print("pytorch weight is: ", layer_module.get_w()) - - if output_all_encoded_layers: - all_encoder_layers.append((hidden_states)) - - if not output_all_encoded_layers or checkpoint_activations: - hidden_states = self.FinalLayerNorm(hidden_states) - all_encoder_layers.append((hidden_states)) - return all_encoder_layers - - -#class BertEncoder(nn.Module): -# def __init__(self, config): -# super(BertEncoder, self).__init__() -# layer = BertLayer(config) -# self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) -# -# def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): -# all_encoder_layers = [] -# for layer_module in self.layer: -# hidden_states = layer_module(hidden_states, attention_mask) -# if output_all_encoded_layers: -# all_encoder_layers.append(hidden_states) -# if not output_all_encoded_layers: -# all_encoder_layers.append(hidden_states) -# return all_encoder_layers - - -class BertPooler(nn.Module): - def __init__(self, config): - super(BertPooler, self).__init__() - self.dense_act = LinearActivation(config.hidden_size, - config.hidden_size, - act="tanh") - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense_act(first_token_tensor) - return pooled_output - - -class BertPredictionHeadTransform(nn.Module): - def __init__(self, config): - super(BertPredictionHeadTransform, self).__init__() - self.dense_act = LinearActivation(config.hidden_size, - config.hidden_size, - act=config.hidden_act) - self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) - - def forward(self, hidden_states): - hidden_states = self.dense_act(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -class BertLMPredictionHead(nn.Module): - def __init__(self, config, bert_model_embedding_weights): - super(BertLMPredictionHead, self).__init__() - self.transform = BertPredictionHeadTransform(config) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = nn.Linear(bert_model_embedding_weights.size(1), - bert_model_embedding_weights.size(0), - bias=False) - self.decoder.weight = bert_model_embedding_weights - self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) - - def forward(self, hidden_states): - hidden_states = self.transform(hidden_states) - torch.cuda.nvtx.range_push( - "decoder input.size() = {}, weight.size() = {}".format( - hidden_states.size(), - self.decoder.weight.size())) - hidden_states = self.decoder(hidden_states) + self.bias - torch.cuda.nvtx.range_pop() - return hidden_states - - -class BertOnlyMLMHead(nn.Module): - def __init__(self, config, bert_model_embedding_weights): - super(BertOnlyMLMHead, self).__init__() - self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) - - def forward(self, sequence_output): - prediction_scores = self.predictions(sequence_output) - return prediction_scores - - -class BertOnlyNSPHead(nn.Module): - def __init__(self, config): - super(BertOnlyNSPHead, self).__init__() - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, pooled_output): - seq_relationship_score = self.seq_relationship(pooled_output) - return seq_relationship_score - - -class BertPreTrainingHeads(nn.Module): - def __init__(self, config, bert_model_embedding_weights): - super(BertPreTrainingHeads, self).__init__() - self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, sequence_output, pooled_output): - prediction_scores = self.predictions(sequence_output) - seq_relationship_score = self.seq_relationship(pooled_output) - return prediction_scores, seq_relationship_score - - -class BertPreTrainedModel(nn.Module): - """ An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. - """ - def __init__(self, config, *inputs, **kwargs): - super(BertPreTrainedModel, self).__init__() - if not isinstance(config, BertConfig): - raise ValueError( - "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " - "To create a model from a Google pretrained model use " - "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( - self.__class__.__name__, - self.__class__.__name__)) - self.config = config - - def init_bert_weights(self, module): - """ Initialize the weights. - """ - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - elif isinstance(module, BertLayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - @classmethod - def from_pretrained(cls, - pretrained_model_name_or_path, - state_dict=None, - cache_dir=None, - from_tf=False, - *inputs, - **kwargs): - """ - Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. - Download and cache the pre-trained model file if needed. - - Params: - pretrained_model_name_or_path: either: - - a str with the name of a pre-trained model to load selected in the list of: - . `bert-base-uncased` - . `bert-large-uncased` - . `bert-base-cased` - . `bert-large-cased` - . `bert-base-multilingual-uncased` - . `bert-base-multilingual-cased` - . `bert-base-chinese` - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `model.chkpt` a TensorFlow checkpoint - from_tf: should we load the weights from a locally saved TensorFlow checkpoint - cache_dir: an optional path to a folder in which the pre-trained models will be cached. - state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models - *inputs, **kwargs: additional input for the specific Bert class - (ex: num_labels for BertForSequenceClassification) - """ - if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: - archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] - else: - archive_file = pretrained_model_name_or_path - if resolved_archive_file == archive_file: - logger.info("loading archive file {}".format(archive_file)) - else: - logger.info("loading archive file {} from cache at {}".format( - archive_file, - resolved_archive_file)) - tempdir = None - if os.path.isdir(resolved_archive_file) or from_tf: - serialization_dir = resolved_archive_file - else: - # Extract archive to temp dir - tempdir = tempfile.mkdtemp() - logger.info("extracting archive file {} to temp dir {}".format( - resolved_archive_file, - tempdir)) - with tarfile.open(resolved_archive_file, 'r:gz') as archive: - archive.extractall(tempdir) - serialization_dir = tempdir - # Load config - config_file = os.path.join(serialization_dir, CONFIG_NAME) - config = BertConfig.from_json_file(config_file) - logger.info("Model config {}".format(config)) - # Instantiate model. - model = cls(config, *inputs, **kwargs) - if state_dict is None and not from_tf: - weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) - state_dict = torch.load( - weights_path, - map_location='cpu' if not torch.cuda.is_available() else None) - if tempdir: - # Clean up temp dir - shutil.rmtree(tempdir) - if from_tf: - # Directly load from a TensorFlow checkpoint - weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) - return load_tf_weights_in_bert(model, weights_path) - # Load from a PyTorch state_dict - old_keys = [] - new_keys = [] - for key in state_dict.keys(): - new_key = None - if 'gamma' in key: - new_key = key.replace('gamma', 'weight') - if 'beta' in key: - new_key = key.replace('beta', 'bias') - if new_key: - old_keys.append(key) - new_keys.append(new_key) - for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key] = state_dict.pop(old_key) - - missing_keys = [] - unexpected_keys = [] - error_msgs = [] - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata - - def load(module, prefix=''): - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - module._load_from_state_dict(state_dict, - prefix, - local_metadata, - True, - missing_keys, - unexpected_keys, - error_msgs) - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + '.') - - start_prefix = '' - if not hasattr(model, - 'bert') and any(s.startswith('bert.') for s in state_dict.keys()): - start_prefix = 'bert.' - load(model, prefix=start_prefix) - if len(missing_keys) > 0: - logger.info("Weights of {} not initialized from pretrained model: {}".format( - model.__class__.__name__, - missing_keys)) - if len(unexpected_keys) > 0: - logger.info("Weights from pretrained model not used in {}: {}".format( - model.__class__.__name__, - unexpected_keys)) - if len(error_msgs) > 0: - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - model.__class__.__name__, - "\n\t".join(error_msgs))) - return model - - -class BertModel(BertPreTrainedModel): - """BERT model ("Bidirectional Embedding Representations from a Transformer"). - - Params: - config: a BertConfig class instance with the configuration to build a new model - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. - - Outputs: Tuple of (encoded_layers, pooled_output) - `encoded_layers`: controlled by `output_all_encoded_layers` argument: - - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end - of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each - encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], - - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding - to the last attention block of shape [batch_size, sequence_length, hidden_size], - `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a - classifier pretrained on top of the hidden state associated to the first character of the - input (`CLS`) to train on the Next-Sentence task (see BERT's paper). - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = modeling.BertModel(config=config) - all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config): - super(BertModel, self).__init__(config) - self.embeddings = BertEmbeddings(config) - self.encoder = BertEncoder(config) - self.pooler = BertPooler(config) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - output_all_encoded_layers=True, - checkpoint_activations=False): - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to(dtype=next( - self.parameters()).dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - - embedding_output = self.embeddings(input_ids, token_type_ids) - encoded_layers = self.encoder( - embedding_output, - extended_attention_mask, - output_all_encoded_layers=output_all_encoded_layers, - checkpoint_activations=checkpoint_activations) - sequence_output = encoded_layers[-1] - pooled_output = self.pooler(sequence_output) - if not output_all_encoded_layers: - encoded_layers = encoded_layers[-1] - return encoded_layers, pooled_output - - -class BertForPreTraining(BertPreTrainedModel): - """BERT model with pre-training heads. - This module comprises the BERT model followed by the two pre-training heads: - - the masked language modeling head, and - - the next sentence classification head. - - Params: - config: a BertConfig class instance with the configuration to build a new model. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] - with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss - is only computed for the labels set in [0, ..., vocab_size] - `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size] - with indices selected in [0, 1]. - 0 => next sentence is the continuation, 1 => next sentence is a random sentence. - - Outputs: - if `masked_lm_labels` and `next_sentence_label` are not `None`: - Outputs the total_loss which is the sum of the masked language modeling loss and the next - sentence classification loss. - if `masked_lm_labels` or `next_sentence_label` is `None`: - Outputs a tuple comprising - - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and - - the next sentence classification logits of shape [batch_size, 2]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForPreTraining(config) - masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config, args): - super(BertForPreTraining, self).__init__(config) - self.summary_writer = None - if dist.get_rank() == 0: - self.summary_writer = args.summary_writer - self.samples_per_step = dist.get_world_size() * args.train_batch_size - self.sample_count = self.samples_per_step - self.bert = BertModel(config) - self.cls = BertPreTrainingHeads(config, - self.bert.embeddings.word_embeddings.weight) - self.apply(self.init_bert_weights) - - def log_summary_writer(self, logs: dict, base='Train'): - if dist.get_rank() == 0: - module_name = "Samples" #self._batch_module_name.get(batch_type, self._get_batch_type_error(batch_type)) - for key, log in logs.items(): - self.summary_writer.add_scalar(f'{base}/{module_name}/{key}', - log, - self.sample_count) - self.sample_count += self.samples_per_step - - def forward(self, batch, log=True): - #input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, checkpoint_activations=False): - input_ids = batch[1] - token_type_ids = batch[3] - attention_mask = batch[2] - masked_lm_labels = batch[5] - next_sentence_label = batch[4] - checkpoint_activations = False - - sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, - output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations) - prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) - - if masked_lm_labels is not None and next_sentence_label is not None: - loss_fct = CrossEntropyLoss(ignore_index=-1) - masked_lm_loss = loss_fct(prediction_scores.view(-1, - self.config.vocab_size), - masked_lm_labels.view(-1)) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, - 2), - next_sentence_label.view(-1)) - #print("loss is {} {}".format(masked_lm_loss, next_sentence_loss)) - total_loss = masked_lm_loss + next_sentence_loss - # if log: - # self.log_summary_writer(logs={'train_loss': total_loss.item()}) - return total_loss - else: - return prediction_scores, seq_relationship_score - - -class BertForMaskedLM(BertPreTrainedModel): - """BERT model with the masked language modeling head. - This module comprises the BERT model followed by the masked language modeling head. - - Params: - config: a BertConfig class instance with the configuration to build a new model. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] - with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss - is only computed for the labels set in [0, ..., vocab_size] - - Outputs: - if `masked_lm_labels` is not `None`: - Outputs the masked language modeling loss. - if `masked_lm_labels` is `None`: - Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForMaskedLM(config) - masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config): - super(BertForMaskedLM, self).__init__(config) - self.bert = BertModel(config) - self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - masked_lm_labels=None, - checkpoint_activations=False): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, - output_all_encoded_layers=False) - prediction_scores = self.cls(sequence_output) - - if masked_lm_labels is not None: - loss_fct = CrossEntropyLoss(ignore_index=-1) - masked_lm_loss = loss_fct(prediction_scores.view(-1, - self.config.vocab_size), - masked_lm_labels.view(-1)) - return masked_lm_loss - else: - return prediction_scores - - -class BertForNextSentencePrediction(BertPreTrainedModel): - """BERT model with next sentence prediction head. - This module comprises the BERT model followed by the next sentence classification head. - - Params: - config: a BertConfig class instance with the configuration to build a new model. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] - with indices selected in [0, 1]. - 0 => next sentence is the continuation, 1 => next sentence is a random sentence. - - Outputs: - if `next_sentence_label` is not `None`: - Outputs the total_loss which is the sum of the masked language modeling loss and the next - sentence classification loss. - if `next_sentence_label` is `None`: - Outputs the next sentence classification logits of shape [batch_size, 2]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForNextSentencePrediction(config) - seq_relationship_logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config): - super(BertForNextSentencePrediction, self).__init__(config) - self.bert = BertModel(config) - self.cls = BertOnlyNSPHead(config) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - next_sentence_label=None, - checkpoint_activations=False): - _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, - output_all_encoded_layers=False) - seq_relationship_score = self.cls(pooled_output) - - if next_sentence_label is not None: - loss_fct = CrossEntropyLoss(ignore_index=-1) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, - 2), - next_sentence_label.view(-1)) - return next_sentence_loss - else: - return seq_relationship_score - - -class BertForSequenceClassification(BertPreTrainedModel): - """BERT model for classification. - This module is composed of the BERT model with a linear layer on top of - the pooled output. - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - `num_labels`: the number of classes for the classifier. Default = 2. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] - with indices selected in [0, ..., num_labels]. - - Outputs: - if `labels` is not `None`: - Outputs the CrossEntropy classification loss of the output with the labels. - if `labels` is `None`: - Outputs the classification logits of shape [batch_size, num_labels]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - num_labels = 2 - - model = BertForSequenceClassification(config, num_labels) - logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config, num_labels): - super(BertForSequenceClassification, self).__init__(config) - self.num_labels = num_labels - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, num_labels) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - labels=None, - checkpoint_activations=False): - _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - return loss - else: - return logits - - -class BertForMultipleChoice(BertPreTrainedModel): - """BERT model for multiple choice tasks. - This module is composed of the BERT model with a linear layer on top of - the pooled output. - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - `num_choices`: the number of classes for the classifier. Default = 2. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] - with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` - and type 1 corresponds to a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] - with indices selected in [0, ..., num_choices]. - - Outputs: - if `labels` is not `None`: - Outputs the CrossEntropy classification loss of the output with the labels. - if `labels` is `None`: - Outputs the classification logits of shape [batch_size, num_labels]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) - input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) - token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - num_choices = 2 - - model = BertForMultipleChoice(config, num_choices) - logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config, num_choices): - super(BertForMultipleChoice, self).__init__(config) - self.num_choices = num_choices - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, 1) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - labels=None, - checkpoint_activations=False): - flat_input_ids = input_ids.view(-1, input_ids.size(-1)) - flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) - flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) - _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False) - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - reshaped_logits = logits.view(-1, self.num_choices) - - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(reshaped_logits, labels) - return loss - else: - return reshaped_logits - - -class BertForTokenClassification(BertPreTrainedModel): - """BERT model for token-level classification. - This module is composed of the BERT model with a linear layer on top of - the full hidden state of the last layer. - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - `num_labels`: the number of classes for the classifier. Default = 2. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length] - with indices selected in [0, ..., num_labels]. - - Outputs: - if `labels` is not `None`: - Outputs the CrossEntropy classification loss of the output with the labels. - if `labels` is `None`: - Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - num_labels = 2 - - model = BertForTokenClassification(config, num_labels) - logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config, num_labels): - super(BertForTokenClassification, self).__init__(config) - self.num_labels = num_labels - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, num_labels) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - labels=None, - checkpoint_activations=False): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) - - if labels is not None: - loss_fct = CrossEntropyLoss() - # Only keep active parts of the loss - if attention_mask is not None: - active_loss = attention_mask.view(-1) == 1 - active_logits = logits.view(-1, self.num_labels)[active_loss] - active_labels = labels.view(-1)[active_loss] - loss = loss_fct(active_logits, active_labels) - else: - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - return loss - else: - return logits - - -class BertForQuestionAnswering(BertPreTrainedModel): - """BERT model for Question Answering (span extraction). - This module is composed of the BERT model with a linear layer on top of - the sequence output that computes start_logits and end_logits - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. - Positions are clamped to the length of the sequence and position outside of the sequence are not taken - into account for computing the loss. - `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. - Positions are clamped to the length of the sequence and position outside of the sequence are not taken - into account for computing the loss. - - Outputs: - if `start_positions` and `end_positions` are not `None`: - Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. - if `start_positions` or `end_positions` is `None`: - Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end - position tokens of shape [batch_size, sequence_length]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForQuestionAnswering(config) - start_logits, end_logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config): - super(BertForQuestionAnswering, self).__init__(config) - self.bert = BertModel(config) - # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version - # self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.qa_outputs = nn.Linear(config.hidden_size, 2) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - start_positions=None, - end_positions=None, - checkpoint_activations=False): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions.clamp_(0, ignored_index) - end_positions.clamp_(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - return total_loss - else: - return start_logits, end_logits +# DeepSpeed note, code taken from commit 3d59216cec89a363649b4fe3d15295ba936ced0f +# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/BERT/modeling.py + +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import copy +import json +import logging +import math +import os +import shutil +import tarfile +import tempfile +import sys +from io import open + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.utils import checkpoint +import torch.distributed as dist + +from torch.nn import Module +from torch.nn.parameter import Parameter +import torch.nn.functional as F +import torch.nn.init as init +import time + +#from numba import cuda + +#from deepspeed_cuda import DeepSpeedSoftmaxConfig, DeepSpeedSoftmax + +logger = logging.getLogger(__name__) + +PRETRAINED_MODEL_ARCHIVE_MAP = { + 'bert-base-uncased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", + 'bert-large-uncased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", + 'bert-base-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", + 'bert-large-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", + 'bert-base-multilingual-uncased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", + 'bert-base-multilingual-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", + 'bert-base-chinese': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", +} +CONFIG_NAME = 'bert_config.json' +WEIGHTS_NAME = 'pytorch_model.bin' +TF_WEIGHTS_NAME = 'model.ckpt' + + +def load_tf_weights_in_bert(model, tf_checkpoint_path): + """ Load tf checkpoints in a pytorch model + """ + try: + import re + import numpy as np + import tensorflow as tf + except ImportError: + print( + "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions.") + raise + tf_path = os.path.abspath(tf_checkpoint_path) + print("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + print("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split('/') + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any(n in ["adam_v", "adam_m"] for n in name): + print("Skipping {}".format("/".join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r'[A-Za-z]+_\d+', m_name): + l = re.split(r'_(\d+)', m_name) + else: + l = [m_name] + if l[0] == 'kernel' or l[0] == 'gamma': + pointer = getattr(pointer, 'weight') + elif l[0] == 'output_bias' or l[0] == 'beta': + pointer = getattr(pointer, 'bias') + elif l[0] == 'output_weights': + pointer = getattr(pointer, 'weight') + else: + pointer = getattr(pointer, l[0]) + if len(l) >= 2: + num = int(l[1]) + pointer = pointer[num] + if m_name[-11:] == '_embeddings': + pointer = getattr(pointer, 'weight') + elif m_name == 'kernel': + array = np.transpose(array) + try: + assert pointer.shape == array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + print("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + return model + + +""" +@torch.jit.script +def f_gelu(x): + return x * 0.5 * (1.0 + torch.erf(x / 1.41421)) +@torch.jit.script +def bias_gelu(bias, y): + x = bias + y + return x * 0.5 * (1.0 + torch.erf(x / 1.41421)) +@torch.jit.script +def bias_tanh(bias, y): + x = bias + y + return torch.tanh(x) + """ + + +def f_gelu(x): + x_type = x.dtype + x = x.float() + x = x * 0.5 * (1.0 + torch.erf(x / 1.41421)) + return x.to(x_type) + + +def bias_gelu(bias, y): + y_type = y.dtype + x = bias.float() + y.float() + x = x * 0.5 * (1.0 + torch.erf(x / 1.41421)) + return x.to(y_type) + + +def bias_tanh(bias, y): + y_type = y.dtype + x = bias.float() + y.float() + x = torch.tanh(x) + return x.to(y_type) + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ + return f_gelu(x) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class GPUTimer: + def __init__(self): + super().__init__() + self.start = cuda.event() + self.stop = cuda.event() + + def record(self): + self.start.record() + + def elapsed(self): + self.stop.record() + self.stop.synchronize() + return self.start.elapsed_time(self.stop) / 1000.0 + + +class LinearActivation(Module): + r"""Fused Linear and activation Module. + """ + __constants__ = ['bias'] + + def __init__(self, + in_features, + out_features, + weights, + biases, + act='gelu', + bias=True): + super(LinearActivation, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.fused_gelu = False + self.fused_tanh = False + if isinstance(act, + str) or (sys.version_info[0] == 2 and isinstance(act, + unicode)): + if bias and act == 'gelu': + self.fused_gelu = True + elif bias and act == 'tanh': + self.fused_tanh = True + else: + self.act_fn = ACT2FN[act] + else: + self.act_fn = act + #self.weight = Parameter(torch.Tensor(out_features, in_features)) + self.weight = weights[5] + self.bias = biases[5] + #if bias: + # self.bias = Parameter(torch.Tensor(out_features)) + #else: + # self.register_parameter('bias', None) + #self.reset_parameters() + + def reset_parameters(self): + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def forward(self, input): + if self.fused_gelu: + #timing = [] + #t1 = GPUTimer() + #t1.record() + y = F.linear(input, self.weight, None) + #timing.append(t1.elapsed()) + #t1.record() + bg = bias_gelu(self.bias, y) + #timing.append(t1.elapsed()) + return bg + elif self.fused_tanh: + return bias_tanh(self.bias, F.linear(input, self.weight, None)) + else: + return self.act_fn(F.linear(input, self.weight, self.bias)) + + def extra_repr(self): + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features, + self.out_features, + self.bias is not None) + + +class BertConfig(object): + """Configuration class to store the configuration of a `BertModel`. + """ + def __init__(self, + vocab_size_or_config_json_file, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + batch_size=8, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + fp16=False): + """Constructs BertConfig. + + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. + hidden_size: Size of the encoder layers and the pooler layer. + num_hidden_layers: Number of hidden layers in the Transformer encoder. + num_attention_heads: Number of attention heads for each attention layer in + the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported. + hidden_dropout_prob: The dropout probability for all fully connected + layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed into + `BertModel`. + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. + """ + if isinstance(vocab_size_or_config_json_file, + str) or (sys.version_info[0] == 2 + and isinstance(vocab_size_or_config_json_file, + unicode)): + with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.batch_size = batch_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.fp16 = fp16 + else: + raise ValueError("First argument must be either a vocabulary size (int)" + "or the path to a pretrained model config file (str)") + + @classmethod + def from_dict(cls, json_object): + """Constructs a `BertConfig` from a Python dictionary of parameters.""" + config = BertConfig(vocab_size_or_config_json_file=-1) + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BertConfig` from a json file of parameters.""" + with open(json_file, "r", encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + + +try: + import apex + #apex.amp.register_half_function(apex.normalization.fused_layer_norm, 'FusedLayerNorm') + import apex.normalization + #apex.amp.register_float_function(apex.normalization.FusedLayerNorm, 'forward') + BertLayerNorm = apex.normalization.FusedLayerNorm +except ImportError: + print( + "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex." + ) + + class BertLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). + """ + super(BertLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + pdtype = x.dtype + x = x.float() + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x.to(pdtype) + self.bias + + #def forward(self, x): + # u = x.mean(-1, keepdim=True) + # s = (x - u).pow(2).mean(-1, keepdim=True) + # x = (x - u) / torch.sqrt(s + self.variance_epsilon) + # return self.weight * x + self.bias + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + def __init__(self, config): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None): + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, + dtype=torch.long, + device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, i, config, weights, biases): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, + config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.query.weight = weights[0] + self.query.bias = biases[0] + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.key.weight = weights[1] + self.key.bias = biases[1] + self.value = nn.Linear(config.hidden_size, self.all_head_size) + self.value.weight = weights[2] + self.value.bias = biases[2] + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.softmax = nn.Softmax(dim=-1) + #self.softmax_config = DeepSpeedSoftmaxConfig() + #self.softmax_config.batch_size = config.batch_size + #self.softmax_config.max_seq_length = config.max_position_embeddings + #self.softmax_config.hidden_size = config.hidden_size + #self.softmax_config.heads = config.num_attention_heads + #self.softmax_config.softmax_id = i + #self.softmax_config.fp16 = config.fp16 + #self.softmax_config.prob_drop_out = 0.0 + #self.softmax = DeepSpeedSoftmax(i, self.softmax_config) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def transpose_key_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 3, 1) + + def forward(self, hidden_states, attention_mask, grads=None): + #timing = [] + #t1 = GPUTimer() + #t1.record() + mixed_query_layer = self.query(hidden_states) + + #timing.append(t1.elapsed()) + #print("Query elapsed: %s" % (time.clock() - start)) + #t1.record() + mixed_key_layer = self.key(hidden_states) + + #timing.append(t1.elapsed()) + #print("Key elapsed: %s" % (time.clock() - start)) + #t1.record() + mixed_value_layer = self.value(hidden_states) + #timing.append(t1.elapsed()) + #print("Value elapsed: %s" % (time.clock() - start)) + + #t1.record() + query_layer = self.transpose_for_scores(mixed_query_layer) + # print(query_layer) + #timing.append(t1.elapsed()) + #print("Query-Transform elapsed: %s" % (time.clock() - start)) + #t1.record() + key_layer = self.transpose_key_for_scores(mixed_key_layer) + # print(key_layer) + #timing.append(t1.elapsed()) + #print("Key-Transform elapsed: %s" % (time.clock() - start)) + #t1.record() + value_layer = self.transpose_for_scores(mixed_value_layer) + #print(value_layer) + #timing.append(t1.elapsed()) + #print("Value-Transform elapsed: %s" % (time.clock() - start)) + + # Take the dot product between "query" and "key" to get the raw attention scores. + #t1.record() + #print(query_layer.shape) + #print(key_layer.shape) + attention_scores = torch.matmul(query_layer, key_layer) + #print(attention_scores.shape) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + #print("Pytorch: ", attention_scores) + #timing.append(t1.elapsed()) + #print("Attention-Score elapsed: %s" % (time.clock() - start)) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + #t1.record() + + # context_layer = self.softmax(query_layer, key_layer, value_layer, attention_mask) + #print("context shape is :", context_layer.shape) + #print("Cuda-ext:, ", attention_scores1) + # Normalize the attention scores to probabilities. + ####attention_probs = self.softmax(attention_scores) + #timing.append(t1.elapsed()) + #print("Softmax elapsed: %s" % (time.clock() - start)) + #t1 = GPUTimer() + #t1.record() + attention_scores = attention_scores + attention_mask + attention_probs = self.softmax(attention_scores) + #attention_scores = self.softmax(attention_scores, attention_mask) + #print("Softmax elapse {0:8.2f} ms", t1.elapsed() * 1000) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + #t1.record() + context_layer = torch.matmul(attention_probs, value_layer) + #timing.append(t1.elapsed()) + #print("Context elapsed: %s" % (time.clock() - start)) + #t1.record() + #context_layer1 = context_layer.permute( + # 0, 1, 3, 2, 4).contiguous() + #if grads is not None: + # context_layer.register_hook(lambda x, self = self : grads.append([x, "Context"])) + context_layer1 = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer1.size()[:-2] + (self.all_head_size, ) + context_layer1 = context_layer1.view(*new_context_layer_shape) + #timing.append(t1.elapsed()) + #print("Context-Transform elapsed: %s" % (time.clock() - start)) + + if grads is not None: + query_layer.register_hook(lambda x, self=self: grads.append([x, "Query"])) + key_layer.register_hook(lambda x, self=self: grads.append([x, "Key"])) + value_layer.register_hook(lambda x, self=self: grads.append([x, "Value"])) + + return context_layer1 + + +class BertSelfOutput(nn.Module): + def __init__(self, config, weights, biases): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dense.weight = weights[3] + self.dense.bias = biases[3] + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + #timing = [] + #t1 = GPUTimer() + #t1.record() + hidden_states = self.dense(hidden_states) + #timing.append(t1.elapsed()) + #print("Attention Output elapsed: %s" % (time.clock() - start)) + hidden_states = self.dropout(hidden_states) + #t1.record() + #hidden_states = self.LayerNorm(hidden_states + input_tensor) + #timing.append(t1.elapsed()) + #print("LayerNorm elapsed: %s" % (time.clock() - start)) + return hidden_states + + def get_w(self): + return self.dense.weight + + +class BertAttention(nn.Module): + def __init__(self, i, config, weights, biases): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(i, config, weights, biases) + self.output = BertSelfOutput(config, weights, biases) + + def forward(self, input_tensor, attention_mask): + self_output = self.self(input_tensor, attention_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output + + def get_w(self): + return self.output.get_w() + + +class BertIntermediate(nn.Module): + def __init__(self, config, weights, biases): + super(BertIntermediate, self).__init__() + self.dense_act = LinearActivation(config.hidden_size, + config.intermediate_size, + weights, + biases, + act=config.hidden_act) + + def forward(self, hidden_states): + hidden_states = self.dense_act(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config, weights, biases): + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dense.weight = weights[6] + self.dense.bias = biases[6] + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + #timing = [] + #t1 = GPUTimer() + #t1.record() + #print (hidden_states) + #print (self.dense.weight) + hidden_states = self.dense(hidden_states) + #timing.append(t1.elapsed()) + #print("FF2 elapsed: %s" % (time.clock() - start)) + hidden_states = self.dropout(hidden_states) + #t1.record() + #hidden_states = self.LayerNorm(hidden_states + input_tensor) + #timing.append(t1.elapsed()) + #print("LayerNorm elapsed: %s" % (time.clock() - start)) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, i, config, weights, biases): + super(BertLayer, self).__init__() + self.attention = BertAttention(i, config, weights, biases) + self.PreAttentionLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.PostAttentionLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.intermediate = BertIntermediate(config, weights, biases) + self.output = BertOutput(config, weights, biases) + self.weight = weights + self.biases = biases + + def forward(self, hidden_states, attention_mask, grads, collect_all_grads=False): + input_layer_norm = self.PreAttentionLayerNorm(hidden_states) + attention_output = self.attention(input_layer_norm, attention_mask) + #print ("hidden shape is :", hidden_states.shape) + intermediate_input = hidden_states + attention_output + + intermediate_layer_norm = self.PostAttentionLayerNorm(intermediate_input) + intermediate_output = self.intermediate(intermediate_layer_norm) + layer_output = self.output(intermediate_output, attention_output) + + #attention_output = self.attention(hidden_states, attention_mask) + #intermediate_output = self.intermediate(attention_output) + #layer_output = self.output(intermediate_output, attention_output) + + if collect_all_grads: + # self.weight[0].register_hook(lambda x, self=self: grads.append([x,"Q_W"])) + # self.biases[0].register_hook(lambda x, self=self: grads.append([x,"Q_B"])) + # self.weight[1].register_hook(lambda x, self=self: grads.append([x,"K_W"])) + # self.biases[1].register_hook(lambda x, self=self: grads.append([x,"K_B"])) + self.weight[2].register_hook(lambda x, self=self: grads.append([x, "V_W"])) + self.biases[2].register_hook(lambda x, self=self: grads.append([x, "V_B"])) + self.weight[3].register_hook(lambda x, self=self: grads.append([x, "O_W"])) + self.biases[3].register_hook(lambda x, self=self: grads.append([x, "O_B"])) + self.PostAttentionLayerNorm.weight.register_hook( + lambda x, + self=self: grads.append([x, + "N2_W"])) + self.PostAttentionLayerNorm.bias.register_hook( + lambda x, + self=self: grads.append([x, + "N2_B"])) + self.weight[5].register_hook(lambda x, self=self: grads.append([x, "int_W"])) + self.biases[5].register_hook(lambda x, self=self: grads.append([x, "int_B"])) + self.weight[6].register_hook(lambda x, self=self: grads.append([x, "out_W"])) + self.biases[6].register_hook(lambda x, self=self: grads.append([x, "out_B"])) + self.PreAttentionLayerNorm.weight.register_hook( + lambda x, + self=self: grads.append([x, + "norm_W"])) + self.PreAttentionLayerNorm.bias.register_hook( + lambda x, + self=self: grads.append([x, + "norm_B"])) + + return layer_output + intermediate_input + + def get_w(self): + return self.attention.get_w() + + +class BertEncoder(nn.Module): + def __init__(self, config, weights, biases): + super(BertEncoder, self).__init__() + #layer = BertLayer(config, weights, biases) + self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + + self.layer = nn.ModuleList([ + copy.deepcopy(BertLayer(i, + config, + weights, + biases)) for i in range(config.num_hidden_layers) + ]) + self.grads = [] + self.graph = [] + + def get_grads(self): + return self.grads + + # def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): + # all_encoder_layers = [] + # for layer_module in self.layer: + # hidden_states = layer_module(hidden_states, attention_mask) + # if output_all_encoded_layers: + # all_encoder_layers.append(hidden_states) + # if not output_all_encoded_layers: + # all_encoder_layers.append(hidden_states) + # return all_encoder_layers + + def get_modules(self, big_node, input): + for mdl in big_node.named_children(): + graph.append(mdl) + get_modules(self, mdl, input) + + def forward(self, + hidden_states, + attention_mask, + output_all_encoded_layers=True, + checkpoint_activations=False): + all_encoder_layers = [] + + def custom(start, end): + def custom_forward(*inputs): + layers = self.layer[start:end] + x_ = inputs[0] + for layer in layers: + x_ = layer(x_, inputs[1]) + return x_ + + return custom_forward + + if checkpoint_activations: + l = 0 + num_layers = len(self.layer) + chunk_length = math.ceil(math.sqrt(num_layers)) + while l < num_layers: + hidden_states = checkpoint.checkpoint(custom(l, + l + chunk_length), + hidden_states, + attention_mask * 1) + l += chunk_length + # decoder layers + else: + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module(hidden_states, + attention_mask, + self.grads, + collect_all_grads=True) + hidden_states.register_hook( + lambda x, + i=i, + self=self: self.grads.append([x, + "hidden_state"])) + #print("pytorch weight is: ", layer_module.get_w()) + + if output_all_encoded_layers: + all_encoder_layers.append((hidden_states)) + + if not output_all_encoded_layers or checkpoint_activations: + hidden_states = self.FinalLayerNorm(hidden_states) + all_encoder_layers.append((hidden_states)) + return all_encoder_layers + + +#class BertEncoder(nn.Module): +# def __init__(self, config): +# super(BertEncoder, self).__init__() +# layer = BertLayer(config) +# self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) +# +# def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): +# all_encoder_layers = [] +# for layer_module in self.layer: +# hidden_states = layer_module(hidden_states, attention_mask) +# if output_all_encoded_layers: +# all_encoder_layers.append(hidden_states) +# if not output_all_encoded_layers: +# all_encoder_layers.append(hidden_states) +# return all_encoder_layers + + +class BertPooler(nn.Module): + def __init__(self, config): + super(BertPooler, self).__init__() + self.dense_act = LinearActivation(config.hidden_size, + config.hidden_size, + act="tanh") + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense_act(first_token_tensor) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super(BertPredictionHeadTransform, self).__init__() + self.dense_act = LinearActivation(config.hidden_size, + config.hidden_size, + act=config.hidden_act) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + + def forward(self, hidden_states): + hidden_states = self.dense_act(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertLMPredictionHead, self).__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(bert_model_embedding_weights.size(1), + bert_model_embedding_weights.size(0), + bias=False) + self.decoder.weight = bert_model_embedding_weights + self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + torch.cuda.nvtx.range_push( + "decoder input.size() = {}, weight.size() = {}".format( + hidden_states.size(), + self.decoder.weight.size())) + hidden_states = self.decoder(hidden_states) + self.bias + torch.cuda.nvtx.range_pop() + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertOnlyMLMHead, self).__init__() + self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + def __init__(self, config): + super(BertOnlyNSPHead, self).__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertPreTrainingHeads, self).__init__() + self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(nn.Module): + """ An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + def __init__(self, config, *inputs, **kwargs): + super(BertPreTrainedModel, self).__init__() + if not isinstance(config, BertConfig): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " + "To create a model from a Google pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, + self.__class__.__name__)) + self.config = config + + def init_bert_weights(self, module): + """ Initialize the weights. + """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def from_pretrained(cls, + pretrained_model_name_or_path, + state_dict=None, + cache_dir=None, + from_tf=False, + *inputs, + **kwargs): + """ + Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. + Download and cache the pre-trained model file if needed. + + Params: + pretrained_model_name_or_path: either: + - a str with the name of a pre-trained model to load selected in the list of: + . `bert-base-uncased` + . `bert-large-uncased` + . `bert-base-cased` + . `bert-large-cased` + . `bert-base-multilingual-uncased` + . `bert-base-multilingual-cased` + . `bert-base-chinese` + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `model.chkpt` a TensorFlow checkpoint + from_tf: should we load the weights from a locally saved TensorFlow checkpoint + cache_dir: an optional path to a folder in which the pre-trained models will be cached. + state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models + *inputs, **kwargs: additional input for the specific Bert class + (ex: num_labels for BertForSequenceClassification) + """ + if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: + archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] + else: + archive_file = pretrained_model_name_or_path + if resolved_archive_file == archive_file: + logger.info("loading archive file {}".format(archive_file)) + else: + logger.info("loading archive file {} from cache at {}".format( + archive_file, + resolved_archive_file)) + tempdir = None + if os.path.isdir(resolved_archive_file) or from_tf: + serialization_dir = resolved_archive_file + else: + # Extract archive to temp dir + tempdir = tempfile.mkdtemp() + logger.info("extracting archive file {} to temp dir {}".format( + resolved_archive_file, + tempdir)) + with tarfile.open(resolved_archive_file, 'r:gz') as archive: + archive.extractall(tempdir) + serialization_dir = tempdir + # Load config + config_file = os.path.join(serialization_dir, CONFIG_NAME) + config = BertConfig.from_json_file(config_file) + logger.info("Model config {}".format(config)) + # Instantiate model. + model = cls(config, *inputs, **kwargs) + if state_dict is None and not from_tf: + weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) + state_dict = torch.load( + weights_path, + map_location='cpu' if not torch.cuda.is_available() else None) + if tempdir: + # Clean up temp dir + shutil.rmtree(tempdir) + if from_tf: + # Directly load from a TensorFlow checkpoint + weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) + return load_tf_weights_in_bert(model, weights_path) + # Load from a PyTorch state_dict + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict(state_dict, + prefix, + local_metadata, + True, + missing_keys, + unexpected_keys, + error_msgs) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + start_prefix = '' + if not hasattr(model, + 'bert') and any(s.startswith('bert.') for s in state_dict.keys()): + start_prefix = 'bert.' + load(model, prefix=start_prefix) + if len(missing_keys) > 0: + logger.info("Weights of {} not initialized from pretrained model: {}".format( + model.__class__.__name__, + missing_keys)) + if len(unexpected_keys) > 0: + logger.info("Weights from pretrained model not used in {}: {}".format( + model.__class__.__name__, + unexpected_keys)) + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + model.__class__.__name__, + "\n\t".join(error_msgs))) + return model + + +class BertModel(BertPreTrainedModel): + """BERT model ("Bidirectional Embedding Representations from a Transformer"). + + Params: + config: a BertConfig class instance with the configuration to build a new model + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. + + Outputs: Tuple of (encoded_layers, pooled_output) + `encoded_layers`: controlled by `output_all_encoded_layers` argument: + - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end + of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each + encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], + - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding + to the last attention block of shape [batch_size, sequence_length, hidden_size], + `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a + classifier pretrained on top of the hidden state associated to the first character of the + input (`CLS`) to train on the Next-Sentence task (see BERT's paper). + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = modeling.BertModel(config=config) + all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertModel, self).__init__(config) + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + output_all_encoded_layers=True, + checkpoint_activations=False): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=next( + self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings(input_ids, token_type_ids) + encoded_layers = self.encoder( + embedding_output, + extended_attention_mask, + output_all_encoded_layers=output_all_encoded_layers, + checkpoint_activations=checkpoint_activations) + sequence_output = encoded_layers[-1] + pooled_output = self.pooler(sequence_output) + if not output_all_encoded_layers: + encoded_layers = encoded_layers[-1] + return encoded_layers, pooled_output + + +class BertForPreTraining(BertPreTrainedModel): + """BERT model with pre-training heads. + This module comprises the BERT model followed by the two pre-training heads: + - the masked language modeling head, and + - the next sentence classification head. + + Params: + config: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] + with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss + is only computed for the labels set in [0, ..., vocab_size] + `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size] + with indices selected in [0, 1]. + 0 => next sentence is the continuation, 1 => next sentence is a random sentence. + + Outputs: + if `masked_lm_labels` and `next_sentence_label` are not `None`: + Outputs the total_loss which is the sum of the masked language modeling loss and the next + sentence classification loss. + if `masked_lm_labels` or `next_sentence_label` is `None`: + Outputs a tuple comprising + - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and + - the next sentence classification logits of shape [batch_size, 2]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = BertForPreTraining(config) + masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config, args): + super(BertForPreTraining, self).__init__(config) + self.summary_writer = None + if dist.get_rank() == 0: + self.summary_writer = args.summary_writer + self.samples_per_step = dist.get_world_size() * args.train_batch_size + self.sample_count = self.samples_per_step + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads(config, + self.bert.embeddings.word_embeddings.weight) + self.apply(self.init_bert_weights) + + def log_summary_writer(self, logs: dict, base='Train'): + if dist.get_rank() == 0: + module_name = "Samples" #self._batch_module_name.get(batch_type, self._get_batch_type_error(batch_type)) + for key, log in logs.items(): + self.summary_writer.add_scalar(f'{base}/{module_name}/{key}', + log, + self.sample_count) + self.sample_count += self.samples_per_step + + def forward(self, batch, log=True): + #input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, checkpoint_activations=False): + input_ids = batch[1] + token_type_ids = batch[3] + attention_mask = batch[2] + masked_lm_labels = batch[5] + next_sentence_label = batch[4] + checkpoint_activations = False + + sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, + output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations) + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + if masked_lm_labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + masked_lm_loss = loss_fct(prediction_scores.view(-1, + self.config.vocab_size), + masked_lm_labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, + 2), + next_sentence_label.view(-1)) + #print("loss is {} {}".format(masked_lm_loss, next_sentence_loss)) + total_loss = masked_lm_loss + next_sentence_loss + # if log: + # self.log_summary_writer(logs={'train_loss': total_loss.item()}) + return total_loss + else: + return prediction_scores, seq_relationship_score + + +class BertForMaskedLM(BertPreTrainedModel): + """BERT model with the masked language modeling head. + This module comprises the BERT model followed by the masked language modeling head. + + Params: + config: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] + with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss + is only computed for the labels set in [0, ..., vocab_size] + + Outputs: + if `masked_lm_labels` is not `None`: + Outputs the masked language modeling loss. + if `masked_lm_labels` is `None`: + Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = BertForMaskedLM(config) + masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertForMaskedLM, self).__init__(config) + self.bert = BertModel(config) + self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + masked_lm_labels=None, + checkpoint_activations=False): + sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, + output_all_encoded_layers=False) + prediction_scores = self.cls(sequence_output) + + if masked_lm_labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + masked_lm_loss = loss_fct(prediction_scores.view(-1, + self.config.vocab_size), + masked_lm_labels.view(-1)) + return masked_lm_loss + else: + return prediction_scores + + +class BertForNextSentencePrediction(BertPreTrainedModel): + """BERT model with next sentence prediction head. + This module comprises the BERT model followed by the next sentence classification head. + + Params: + config: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] + with indices selected in [0, 1]. + 0 => next sentence is the continuation, 1 => next sentence is a random sentence. + + Outputs: + if `next_sentence_label` is not `None`: + Outputs the total_loss which is the sum of the masked language modeling loss and the next + sentence classification loss. + if `next_sentence_label` is `None`: + Outputs the next sentence classification logits of shape [batch_size, 2]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = BertForNextSentencePrediction(config) + seq_relationship_logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertForNextSentencePrediction, self).__init__(config) + self.bert = BertModel(config) + self.cls = BertOnlyNSPHead(config) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + next_sentence_label=None, + checkpoint_activations=False): + _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, + output_all_encoded_layers=False) + seq_relationship_score = self.cls(pooled_output) + + if next_sentence_label is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, + 2), + next_sentence_label.view(-1)) + return next_sentence_loss + else: + return seq_relationship_score + + +class BertForSequenceClassification(BertPreTrainedModel): + """BERT model for classification. + This module is composed of the BERT model with a linear layer on top of + the pooled output. + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `num_labels`: the number of classes for the classifier. Default = 2. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] + with indices selected in [0, ..., num_labels]. + + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. + if `labels` is `None`: + Outputs the classification logits of shape [batch_size, num_labels]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + num_labels = 2 + + model = BertForSequenceClassification(config, num_labels) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config, num_labels): + super(BertForSequenceClassification, self).__init__(config) + self.num_labels = num_labels + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, num_labels) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + labels=None, + checkpoint_activations=False): + _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + return loss + else: + return logits + + +class BertForMultipleChoice(BertPreTrainedModel): + """BERT model for multiple choice tasks. + This module is composed of the BERT model with a linear layer on top of + the pooled output. + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `num_choices`: the number of classes for the classifier. Default = 2. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] + with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` + and type 1 corresponds to a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] + with indices selected in [0, ..., num_choices]. + + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. + if `labels` is `None`: + Outputs the classification logits of shape [batch_size, num_labels]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) + input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) + token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + num_choices = 2 + + model = BertForMultipleChoice(config, num_choices) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config, num_choices): + super(BertForMultipleChoice, self).__init__(config) + self.num_choices = num_choices + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + labels=None, + checkpoint_activations=False): + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) + _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, self.num_choices) + + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + return loss + else: + return reshaped_logits + + +class BertForTokenClassification(BertPreTrainedModel): + """BERT model for token-level classification. + This module is composed of the BERT model with a linear layer on top of + the full hidden state of the last layer. + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `num_labels`: the number of classes for the classifier. Default = 2. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length] + with indices selected in [0, ..., num_labels]. + + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. + if `labels` is `None`: + Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + num_labels = 2 + + model = BertForTokenClassification(config, num_labels) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config, num_labels): + super(BertForTokenClassification, self).__init__(config) + self.num_labels = num_labels + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, num_labels) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + labels=None, + checkpoint_activations=False): + sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels)[active_loss] + active_labels = labels.view(-1)[active_loss] + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + return loss + else: + return logits + + +class BertForQuestionAnswering(BertPreTrainedModel): + """BERT model for Question Answering (span extraction). + This module is composed of the BERT model with a linear layer on top of + the sequence output that computes start_logits and end_logits + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. + Positions are clamped to the length of the sequence and position outside of the sequence are not taken + into account for computing the loss. + `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. + Positions are clamped to the length of the sequence and position outside of the sequence are not taken + into account for computing the loss. + + Outputs: + if `start_positions` and `end_positions` are not `None`: + Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. + if `start_positions` or `end_positions` is `None`: + Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end + position tokens of shape [batch_size, sequence_length]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = BertForQuestionAnswering(config) + start_logits, end_logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertForQuestionAnswering, self).__init__(config) + self.bert = BertModel(config) + # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version + # self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + start_positions=None, + end_positions=None, + checkpoint_activations=False): + sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + return total_loss + else: + return start_logits, end_logits diff --git a/tests/unit/test_aio.py b/tests/unit/test_aio.py index daa633fe5..48272bade 100755 --- a/tests/unit/test_aio.py +++ b/tests/unit/test_aio.py @@ -1,335 +1,335 @@ -import pytest -import os -import filecmp -import torch -import deepspeed -import torch.distributed as dist -from common import distributed_test -from deepspeed.ops.aio import AsyncIOBuilder - -MEGA_BYTE = 1024**2 -BLOCK_SIZE = MEGA_BYTE -QUEUE_DEPTH = 2 -IO_SIZE = 16 * MEGA_BYTE -IO_PARALLEL = 2 - - -def _skip_if_no_aio(): - if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: - pytest.skip('Skip tests since async-io is not compatible') - - -def _do_ref_write(tmpdir, index=0): - file_suffix = f'{dist.get_rank()}_{index}' - ref_file = os.path.join(tmpdir, f'_py_random_{file_suffix}.pt') - ref_buffer = os.urandom(IO_SIZE) - with open(ref_file, 'wb') as f: - f.write(ref_buffer) - - return ref_file, ref_buffer - - -def _get_test_file_and_buffer(tmpdir, ref_buffer, cuda_device, index=0): - file_suffix = f'{dist.get_rank()}_{index}' - test_file = os.path.join(tmpdir, f'_aio_write_random_{file_suffix}.pt') - if cuda_device: - test_buffer = torch.cuda.ByteTensor(list(ref_buffer)) - else: - test_buffer = torch.ByteTensor(list(ref_buffer)).pin_memory() - - return test_file, test_buffer - - -def _validate_handle_state(handle, single_submit, overlap_events): - assert handle.get_single_submit() == single_submit - assert handle.get_overlap_events() == overlap_events - assert handle.get_thread_count() == IO_PARALLEL - assert handle.get_block_size() == BLOCK_SIZE - assert handle.get_queue_depth() == QUEUE_DEPTH - - -@pytest.mark.parametrize('single_submit, overlap_events', - [(False, - False), - (False, - True), - (True, - False), - (True, - True)]) -def test_parallel_read(tmpdir, single_submit, overlap_events): - _skip_if_no_aio() - - @distributed_test(world_size=[2]) - def _test_parallel_read(single_submit, overlap_events): - ref_file, _ = _do_ref_write(tmpdir) - - aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu').pin_memory() - h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, - QUEUE_DEPTH, - single_submit, - overlap_events, - IO_PARALLEL) - - _validate_handle_state(h, single_submit, overlap_events) - - read_status = h.sync_pread(aio_buffer, ref_file) - assert read_status == 1 - - with open(ref_file, 'rb') as f: - ref_buffer = list(f.read()) - assert ref_buffer == aio_buffer.tolist() - - _test_parallel_read(single_submit, overlap_events) - - -@pytest.mark.parametrize('single_submit, overlap_events, cuda_device', - [(False, - False, - False), - (False, - True, - False), - (True, - False, - False), - (True, - True, - False), - (False, - False, - True), - (True, - True, - True)]) -def test_async_read(tmpdir, single_submit, overlap_events, cuda_device): - - _skip_if_no_aio() - - @distributed_test(world_size=[2]) - def _test_async_read(single_submit, overlap_events, cuda_device): - ref_file, _ = _do_ref_write(tmpdir) - - if cuda_device: - aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device='cuda') - else: - aio_buffer = torch.empty(IO_SIZE, - dtype=torch.uint8, - device='cpu').pin_memory() - - h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, - QUEUE_DEPTH, - single_submit, - overlap_events, - IO_PARALLEL) - - _validate_handle_state(h, single_submit, overlap_events) - - read_status = h.async_pread(aio_buffer, ref_file) - assert read_status == 0 - - wait_status = h.wait() - assert wait_status == 1 - - with open(ref_file, 'rb') as f: - ref_buffer = list(f.read()) - assert ref_buffer == aio_buffer.tolist() - - _test_async_read(single_submit, overlap_events, cuda_device) - - -@pytest.mark.parametrize('single_submit, overlap_events', - [(False, - False), - (False, - True), - (True, - False), - (True, - True)]) -def test_parallel_write(tmpdir, single_submit, overlap_events): - - _skip_if_no_aio() - - @distributed_test(world_size=[2]) - def _test_parallel_write(single_submit, overlap_events): - ref_file, ref_buffer = _do_ref_write(tmpdir) - - aio_file, aio_buffer = _get_test_file_and_buffer(tmpdir, ref_buffer, False) - - h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, - QUEUE_DEPTH, - single_submit, - overlap_events, - IO_PARALLEL) - - _validate_handle_state(h, single_submit, overlap_events) - - write_status = h.sync_pwrite(aio_buffer, aio_file) - assert write_status == 1 - - assert os.path.isfile(aio_file) - - filecmp.clear_cache() - assert filecmp.cmp(ref_file, aio_file, shallow=False) - - _test_parallel_write(single_submit, overlap_events) - - -@pytest.mark.parametrize('single_submit, overlap_events, cuda_device', - [(False, - False, - False), - (False, - True, - False), - (True, - False, - False), - (True, - True, - False), - (False, - False, - True), - (True, - True, - True)]) -def test_async_write(tmpdir, single_submit, overlap_events, cuda_device): - - _skip_if_no_aio() - - @distributed_test(world_size=[2]) - def _test_async_write(single_submit, overlap_events, cuda_device): - ref_file, ref_buffer = _do_ref_write(tmpdir) - - aio_file, aio_buffer = _get_test_file_and_buffer(tmpdir, ref_buffer, cuda_device) - - h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, - QUEUE_DEPTH, - single_submit, - overlap_events, - IO_PARALLEL) - - _validate_handle_state(h, single_submit, overlap_events) - - write_status = h.async_pwrite(aio_buffer, aio_file) - assert write_status == 0 - - wait_status = h.wait() - assert wait_status == 1 - - assert os.path.isfile(aio_file) - - filecmp.clear_cache() - assert filecmp.cmp(ref_file, aio_file, shallow=False) - - _test_async_write(single_submit, overlap_events, cuda_device) - - -@pytest.mark.parametrize('async_queue, cuda_device', - [(2, - False), - (4, - False), - (2, - True), - (4, - True)]) -def test_async_queue_read(tmpdir, async_queue, cuda_device): - - _skip_if_no_aio() - - @distributed_test(world_size=[2]) - def _test_async_queue_read(async_queue, cuda_device): - ref_files = [] - for i in range(async_queue): - f, _ = _do_ref_write(tmpdir, i) - ref_files.append(f) - - aio_buffers = [] - for i in range(async_queue): - if cuda_device: - buf = torch.empty(IO_SIZE, dtype=torch.uint8, device='cuda') - else: - buf = torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu').pin_memory() - aio_buffers.append(buf) - - single_submit = True - overlap_events = True - h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, - QUEUE_DEPTH, - single_submit, - overlap_events, - IO_PARALLEL) - - _validate_handle_state(h, single_submit, overlap_events) - - for i in range(async_queue): - read_status = h.async_pread(aio_buffers[i], ref_files[i]) - assert read_status == 0 - - wait_status = h.wait() - assert wait_status == async_queue - - for i in range(async_queue): - with open(ref_files[i], 'rb') as f: - ref_buffer = list(f.read()) - assert ref_buffer == aio_buffers[i].tolist() - - _test_async_queue_read(async_queue, cuda_device) - - -@pytest.mark.parametrize('async_queue, cuda_device', - [(2, - False), - (7, - False), - (2, - True), - (7, - True)]) -def test_async_queue_write(tmpdir, async_queue, cuda_device): - - _skip_if_no_aio() - - @distributed_test(world_size=[2]) - def _test_async_queue_write(async_queue, cuda_device): - ref_files = [] - ref_buffers = [] - for i in range(async_queue): - f, buf = _do_ref_write(tmpdir, i) - ref_files.append(f) - ref_buffers.append(buf) - - aio_files = [] - aio_buffers = [] - for i in range(async_queue): - f, buf = _get_test_file_and_buffer(tmpdir, ref_buffers[i], cuda_device, i) - aio_files.append(f) - aio_buffers.append(buf) - - single_submit = True - overlap_events = True - h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, - QUEUE_DEPTH, - single_submit, - overlap_events, - IO_PARALLEL) - - _validate_handle_state(h, single_submit, overlap_events) - - for i in range(async_queue): - read_status = h.async_pwrite(aio_buffers[i], aio_files[i]) - assert read_status == 0 - - wait_status = h.wait() - assert wait_status == async_queue - - for i in range(async_queue): - assert os.path.isfile(aio_files[i]) - - filecmp.clear_cache() - assert filecmp.cmp(ref_files[i], aio_files[i], shallow=False) - - _test_async_queue_write(async_queue, cuda_device) +import pytest +import os +import filecmp +import torch +import deepspeed +import torch.distributed as dist +from common import distributed_test +from deepspeed.ops.aio import AsyncIOBuilder + +MEGA_BYTE = 1024**2 +BLOCK_SIZE = MEGA_BYTE +QUEUE_DEPTH = 2 +IO_SIZE = 16 * MEGA_BYTE +IO_PARALLEL = 2 + + +def _skip_if_no_aio(): + if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: + pytest.skip('Skip tests since async-io is not compatible') + + +def _do_ref_write(tmpdir, index=0): + file_suffix = f'{dist.get_rank()}_{index}' + ref_file = os.path.join(tmpdir, f'_py_random_{file_suffix}.pt') + ref_buffer = os.urandom(IO_SIZE) + with open(ref_file, 'wb') as f: + f.write(ref_buffer) + + return ref_file, ref_buffer + + +def _get_test_file_and_buffer(tmpdir, ref_buffer, cuda_device, index=0): + file_suffix = f'{dist.get_rank()}_{index}' + test_file = os.path.join(tmpdir, f'_aio_write_random_{file_suffix}.pt') + if cuda_device: + test_buffer = torch.cuda.ByteTensor(list(ref_buffer)) + else: + test_buffer = torch.ByteTensor(list(ref_buffer)).pin_memory() + + return test_file, test_buffer + + +def _validate_handle_state(handle, single_submit, overlap_events): + assert handle.get_single_submit() == single_submit + assert handle.get_overlap_events() == overlap_events + assert handle.get_thread_count() == IO_PARALLEL + assert handle.get_block_size() == BLOCK_SIZE + assert handle.get_queue_depth() == QUEUE_DEPTH + + +@pytest.mark.parametrize('single_submit, overlap_events', + [(False, + False), + (False, + True), + (True, + False), + (True, + True)]) +def test_parallel_read(tmpdir, single_submit, overlap_events): + _skip_if_no_aio() + + @distributed_test(world_size=[2]) + def _test_parallel_read(single_submit, overlap_events): + ref_file, _ = _do_ref_write(tmpdir) + + aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu').pin_memory() + h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, + QUEUE_DEPTH, + single_submit, + overlap_events, + IO_PARALLEL) + + _validate_handle_state(h, single_submit, overlap_events) + + read_status = h.sync_pread(aio_buffer, ref_file) + assert read_status == 1 + + with open(ref_file, 'rb') as f: + ref_buffer = list(f.read()) + assert ref_buffer == aio_buffer.tolist() + + _test_parallel_read(single_submit, overlap_events) + + +@pytest.mark.parametrize('single_submit, overlap_events, cuda_device', + [(False, + False, + False), + (False, + True, + False), + (True, + False, + False), + (True, + True, + False), + (False, + False, + True), + (True, + True, + True)]) +def test_async_read(tmpdir, single_submit, overlap_events, cuda_device): + + _skip_if_no_aio() + + @distributed_test(world_size=[2]) + def _test_async_read(single_submit, overlap_events, cuda_device): + ref_file, _ = _do_ref_write(tmpdir) + + if cuda_device: + aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device='cuda') + else: + aio_buffer = torch.empty(IO_SIZE, + dtype=torch.uint8, + device='cpu').pin_memory() + + h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, + QUEUE_DEPTH, + single_submit, + overlap_events, + IO_PARALLEL) + + _validate_handle_state(h, single_submit, overlap_events) + + read_status = h.async_pread(aio_buffer, ref_file) + assert read_status == 0 + + wait_status = h.wait() + assert wait_status == 1 + + with open(ref_file, 'rb') as f: + ref_buffer = list(f.read()) + assert ref_buffer == aio_buffer.tolist() + + _test_async_read(single_submit, overlap_events, cuda_device) + + +@pytest.mark.parametrize('single_submit, overlap_events', + [(False, + False), + (False, + True), + (True, + False), + (True, + True)]) +def test_parallel_write(tmpdir, single_submit, overlap_events): + + _skip_if_no_aio() + + @distributed_test(world_size=[2]) + def _test_parallel_write(single_submit, overlap_events): + ref_file, ref_buffer = _do_ref_write(tmpdir) + + aio_file, aio_buffer = _get_test_file_and_buffer(tmpdir, ref_buffer, False) + + h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, + QUEUE_DEPTH, + single_submit, + overlap_events, + IO_PARALLEL) + + _validate_handle_state(h, single_submit, overlap_events) + + write_status = h.sync_pwrite(aio_buffer, aio_file) + assert write_status == 1 + + assert os.path.isfile(aio_file) + + filecmp.clear_cache() + assert filecmp.cmp(ref_file, aio_file, shallow=False) + + _test_parallel_write(single_submit, overlap_events) + + +@pytest.mark.parametrize('single_submit, overlap_events, cuda_device', + [(False, + False, + False), + (False, + True, + False), + (True, + False, + False), + (True, + True, + False), + (False, + False, + True), + (True, + True, + True)]) +def test_async_write(tmpdir, single_submit, overlap_events, cuda_device): + + _skip_if_no_aio() + + @distributed_test(world_size=[2]) + def _test_async_write(single_submit, overlap_events, cuda_device): + ref_file, ref_buffer = _do_ref_write(tmpdir) + + aio_file, aio_buffer = _get_test_file_and_buffer(tmpdir, ref_buffer, cuda_device) + + h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, + QUEUE_DEPTH, + single_submit, + overlap_events, + IO_PARALLEL) + + _validate_handle_state(h, single_submit, overlap_events) + + write_status = h.async_pwrite(aio_buffer, aio_file) + assert write_status == 0 + + wait_status = h.wait() + assert wait_status == 1 + + assert os.path.isfile(aio_file) + + filecmp.clear_cache() + assert filecmp.cmp(ref_file, aio_file, shallow=False) + + _test_async_write(single_submit, overlap_events, cuda_device) + + +@pytest.mark.parametrize('async_queue, cuda_device', + [(2, + False), + (4, + False), + (2, + True), + (4, + True)]) +def test_async_queue_read(tmpdir, async_queue, cuda_device): + + _skip_if_no_aio() + + @distributed_test(world_size=[2]) + def _test_async_queue_read(async_queue, cuda_device): + ref_files = [] + for i in range(async_queue): + f, _ = _do_ref_write(tmpdir, i) + ref_files.append(f) + + aio_buffers = [] + for i in range(async_queue): + if cuda_device: + buf = torch.empty(IO_SIZE, dtype=torch.uint8, device='cuda') + else: + buf = torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu').pin_memory() + aio_buffers.append(buf) + + single_submit = True + overlap_events = True + h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, + QUEUE_DEPTH, + single_submit, + overlap_events, + IO_PARALLEL) + + _validate_handle_state(h, single_submit, overlap_events) + + for i in range(async_queue): + read_status = h.async_pread(aio_buffers[i], ref_files[i]) + assert read_status == 0 + + wait_status = h.wait() + assert wait_status == async_queue + + for i in range(async_queue): + with open(ref_files[i], 'rb') as f: + ref_buffer = list(f.read()) + assert ref_buffer == aio_buffers[i].tolist() + + _test_async_queue_read(async_queue, cuda_device) + + +@pytest.mark.parametrize('async_queue, cuda_device', + [(2, + False), + (7, + False), + (2, + True), + (7, + True)]) +def test_async_queue_write(tmpdir, async_queue, cuda_device): + + _skip_if_no_aio() + + @distributed_test(world_size=[2]) + def _test_async_queue_write(async_queue, cuda_device): + ref_files = [] + ref_buffers = [] + for i in range(async_queue): + f, buf = _do_ref_write(tmpdir, i) + ref_files.append(f) + ref_buffers.append(buf) + + aio_files = [] + aio_buffers = [] + for i in range(async_queue): + f, buf = _get_test_file_and_buffer(tmpdir, ref_buffers[i], cuda_device, i) + aio_files.append(f) + aio_buffers.append(buf) + + single_submit = True + overlap_events = True + h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, + QUEUE_DEPTH, + single_submit, + overlap_events, + IO_PARALLEL) + + _validate_handle_state(h, single_submit, overlap_events) + + for i in range(async_queue): + read_status = h.async_pwrite(aio_buffers[i], aio_files[i]) + assert read_status == 0 + + wait_status = h.wait() + assert wait_status == async_queue + + for i in range(async_queue): + assert os.path.isfile(aio_files[i]) + + filecmp.clear_cache() + assert filecmp.cmp(ref_files[i], aio_files[i], shallow=False) + + _test_async_queue_write(async_queue, cuda_device) diff --git a/tests/unit/test_cpu_adagrad.py b/tests/unit/test_cpu_adagrad.py index b8a025fe0..f2ba26255 100755 --- a/tests/unit/test_cpu_adagrad.py +++ b/tests/unit/test_cpu_adagrad.py @@ -1,125 +1,125 @@ -import torch -import numpy as np -import pytest - -import deepspeed -from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad -from deepspeed.ops.op_builder import CPUAdagradBuilder - -if not deepspeed.ops.__compatible_ops__[CPUAdagradBuilder.NAME]: - pytest.skip("cpu-adagrad is not compatible") - - -def check_equal(first, second, atol=1e-2, verbose=False): - x = first.detach().numpy() - y = second.detach().numpy() - if verbose: - print("x = {}".format(x.flatten())) - print("y = {}".format(y.flatten())) - print('-' * 80) - np.testing.assert_allclose(x, y, err_msg="param-update mismatch!", atol=atol) - - -@pytest.mark.parametrize('model_size', - [ - (64), - (22), - (55), - (127), - (1024), - (1048576), - (30000000), - ]) # yapf: disable -def test_cpu_adagrad_opt(model_size): - device = 'cpu' - rng_state = torch.get_rng_state() - param = torch.nn.Parameter(torch.randn(model_size, device=device)) - torch.set_rng_state(rng_state) - param1 = torch.nn.Parameter(torch.randn(model_size, device=device)) - torch.set_rng_state(rng_state) - - optimizer = DeepSpeedCPUAdagrad([param]) - optimizer1 = torch.optim.Adagrad([param1]) - - for i in range(10): - rng_state = torch.get_rng_state() - param.grad = torch.randn(model_size, device=device) - torch.set_rng_state(rng_state) - param1.grad = torch.randn(model_size, device=device) - optimizer.step() - optimizer1.step() - - check_equal(param, param1, atol=1e-2, verbose=True) - - -@pytest.mark.parametrize('model_size,vocabulary_size,dim', - [ - (16 * 2, 16 * 4, 16), - (16 * 32, 16 * 256, 16), - (16 * 256, 16 * 16384, 16), - ]) # yapf: disable -def test_cpu_adagrad_opt_sparse_embedding(model_size, vocabulary_size, dim): - device = 'cpu' - rng_state = torch.get_rng_state() - - def gen_sparse_grad(vocabulary_size, dim, num_indices, dtype, device): - i = torch.randint(vocabulary_size, - size=(1, - num_indices), - dtype=torch.int64, - device=device) - v = torch.randn(num_indices, dim, dtype=dtype, device=device) - t = torch.sparse_coo_tensor(i, v, (vocabulary_size, dim), device=device) - t = t.coalesce() - new_i = (t.indices().view(-1, - 1).repeat(1, - dim) * dim + - torch.tensor(range(dim))).flatten().unsqueeze(0) - new_v = t.values().flatten() - new_t = torch.sparse_coo_tensor(new_i, - new_v, - (vocabulary_size * dim, - ), - device=device) - new_t = new_t.coalesce() - new_t.requires_grad = False - return new_t - - voc_size = vocabulary_size - dim = dim - num_indices = int(model_size // dim) - dtype = torch.float32 - - param = torch.nn.Parameter(torch.randn((voc_size * dim, - ), - dtype=dtype, - device=device), - requires_grad=True) - torch.set_rng_state(rng_state) - param1 = torch.nn.Parameter(torch.randn((voc_size * dim, - ), - dtype=dtype, - device=device), - requires_grad=True) - torch.set_rng_state(rng_state) - - optimizer = DeepSpeedCPUAdagrad([param]) - optimizer1 = torch.optim.Adagrad([param1]) - - for i in range(10): - torch.set_rng_state(rng_state) - param.grad = gen_sparse_grad(voc_size, - dim, - num_indices, - dtype=dtype, - device=device) - torch.set_rng_state(rng_state) - param1.grad = gen_sparse_grad(voc_size, - dim, - num_indices, - dtype=dtype, - device=device) - optimizer.step() - optimizer1.step() - - check_equal(param, param1, atol=1e-2, verbose=True) +import torch +import numpy as np +import pytest + +import deepspeed +from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad +from deepspeed.ops.op_builder import CPUAdagradBuilder + +if not deepspeed.ops.__compatible_ops__[CPUAdagradBuilder.NAME]: + pytest.skip("cpu-adagrad is not compatible") + + +def check_equal(first, second, atol=1e-2, verbose=False): + x = first.detach().numpy() + y = second.detach().numpy() + if verbose: + print("x = {}".format(x.flatten())) + print("y = {}".format(y.flatten())) + print('-' * 80) + np.testing.assert_allclose(x, y, err_msg="param-update mismatch!", atol=atol) + + +@pytest.mark.parametrize('model_size', + [ + (64), + (22), + (55), + (127), + (1024), + (1048576), + (30000000), + ]) # yapf: disable +def test_cpu_adagrad_opt(model_size): + device = 'cpu' + rng_state = torch.get_rng_state() + param = torch.nn.Parameter(torch.randn(model_size, device=device)) + torch.set_rng_state(rng_state) + param1 = torch.nn.Parameter(torch.randn(model_size, device=device)) + torch.set_rng_state(rng_state) + + optimizer = DeepSpeedCPUAdagrad([param]) + optimizer1 = torch.optim.Adagrad([param1]) + + for i in range(10): + rng_state = torch.get_rng_state() + param.grad = torch.randn(model_size, device=device) + torch.set_rng_state(rng_state) + param1.grad = torch.randn(model_size, device=device) + optimizer.step() + optimizer1.step() + + check_equal(param, param1, atol=1e-2, verbose=True) + + +@pytest.mark.parametrize('model_size,vocabulary_size,dim', + [ + (16 * 2, 16 * 4, 16), + (16 * 32, 16 * 256, 16), + (16 * 256, 16 * 16384, 16), + ]) # yapf: disable +def test_cpu_adagrad_opt_sparse_embedding(model_size, vocabulary_size, dim): + device = 'cpu' + rng_state = torch.get_rng_state() + + def gen_sparse_grad(vocabulary_size, dim, num_indices, dtype, device): + i = torch.randint(vocabulary_size, + size=(1, + num_indices), + dtype=torch.int64, + device=device) + v = torch.randn(num_indices, dim, dtype=dtype, device=device) + t = torch.sparse_coo_tensor(i, v, (vocabulary_size, dim), device=device) + t = t.coalesce() + new_i = (t.indices().view(-1, + 1).repeat(1, + dim) * dim + + torch.tensor(range(dim))).flatten().unsqueeze(0) + new_v = t.values().flatten() + new_t = torch.sparse_coo_tensor(new_i, + new_v, + (vocabulary_size * dim, + ), + device=device) + new_t = new_t.coalesce() + new_t.requires_grad = False + return new_t + + voc_size = vocabulary_size + dim = dim + num_indices = int(model_size // dim) + dtype = torch.float32 + + param = torch.nn.Parameter(torch.randn((voc_size * dim, + ), + dtype=dtype, + device=device), + requires_grad=True) + torch.set_rng_state(rng_state) + param1 = torch.nn.Parameter(torch.randn((voc_size * dim, + ), + dtype=dtype, + device=device), + requires_grad=True) + torch.set_rng_state(rng_state) + + optimizer = DeepSpeedCPUAdagrad([param]) + optimizer1 = torch.optim.Adagrad([param1]) + + for i in range(10): + torch.set_rng_state(rng_state) + param.grad = gen_sparse_grad(voc_size, + dim, + num_indices, + dtype=dtype, + device=device) + torch.set_rng_state(rng_state) + param1.grad = gen_sparse_grad(voc_size, + dim, + num_indices, + dtype=dtype, + device=device) + optimizer.step() + optimizer1.step() + + check_equal(param, param1, atol=1e-2, verbose=True) diff --git a/tests/unit/test_cpu_adam.py b/tests/unit/test_cpu_adam.py index dd5527b01..94453c7e8 100755 --- a/tests/unit/test_cpu_adam.py +++ b/tests/unit/test_cpu_adam.py @@ -1,62 +1,62 @@ -import argparse -import torch -import time -import numpy as np -import pytest -import copy - -import deepspeed -from deepspeed.ops.adam import FusedAdam -from deepspeed.ops.op_builder import CPUAdamBuilder - -if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: - pytest.skip("cpu-adam is not compatible") - - -def check_equal(first, second, atol=1e-2, verbose=False): - x = first.detach().numpy() - y = second.detach().numpy() - if verbose: - print("x = {}".format(x.flatten())) - print("y = {}".format(y.flatten())) - print('-' * 80) - np.testing.assert_allclose(x, y, err_msg="param-update mismatch!", atol=atol) - -@pytest.mark.parametrize('model_size', - [ - (64), - (22), - (55), - (127), - (1024), - (1048576), - ]) # yapf: disable -def test_cpu_adam_opt(model_size): - from deepspeed.ops.adam import DeepSpeedCPUAdam - device = 'cpu' - rng_state = torch.get_rng_state() - param = torch.nn.Parameter(torch.randn(model_size, device=device)) - torch.set_rng_state(rng_state) - param1 = torch.nn.Parameter(torch.randn(model_size, device=device)) - torch.set_rng_state(rng_state) - param2_data = torch.randn(model_size, device=device).cuda() - param2 = torch.nn.Parameter(param2_data) - - optimizer1 = torch.optim.AdamW([param1]) - optimizer2 = FusedAdam([param2]) - optimizer = DeepSpeedCPUAdam([param]) - - for i in range(10): - rng_state = torch.get_rng_state() - param.grad = torch.randn(model_size, device=device) - torch.set_rng_state(rng_state) - param1.grad = torch.randn(model_size, device=device) - torch.set_rng_state(rng_state) - param2.grad = torch.randn(model_size, device=device).cuda() - - optimizer.step() - optimizer2.step() - optimizer1.step() - - check_equal(param, param1, atol=1e-2, verbose=True) - check_equal(param, param2.cpu(), atol=1e-2, verbose=True) +import argparse +import torch +import time +import numpy as np +import pytest +import copy + +import deepspeed +from deepspeed.ops.adam import FusedAdam +from deepspeed.ops.op_builder import CPUAdamBuilder + +if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + pytest.skip("cpu-adam is not compatible") + + +def check_equal(first, second, atol=1e-2, verbose=False): + x = first.detach().numpy() + y = second.detach().numpy() + if verbose: + print("x = {}".format(x.flatten())) + print("y = {}".format(y.flatten())) + print('-' * 80) + np.testing.assert_allclose(x, y, err_msg="param-update mismatch!", atol=atol) + +@pytest.mark.parametrize('model_size', + [ + (64), + (22), + (55), + (127), + (1024), + (1048576), + ]) # yapf: disable +def test_cpu_adam_opt(model_size): + from deepspeed.ops.adam import DeepSpeedCPUAdam + device = 'cpu' + rng_state = torch.get_rng_state() + param = torch.nn.Parameter(torch.randn(model_size, device=device)) + torch.set_rng_state(rng_state) + param1 = torch.nn.Parameter(torch.randn(model_size, device=device)) + torch.set_rng_state(rng_state) + param2_data = torch.randn(model_size, device=device).cuda() + param2 = torch.nn.Parameter(param2_data) + + optimizer1 = torch.optim.AdamW([param1]) + optimizer2 = FusedAdam([param2]) + optimizer = DeepSpeedCPUAdam([param]) + + for i in range(10): + rng_state = torch.get_rng_state() + param.grad = torch.randn(model_size, device=device) + torch.set_rng_state(rng_state) + param1.grad = torch.randn(model_size, device=device) + torch.set_rng_state(rng_state) + param2.grad = torch.randn(model_size, device=device).cuda() + + optimizer.step() + optimizer2.step() + optimizer1.step() + + check_equal(param, param1, atol=1e-2, verbose=True) + check_equal(param, param2.cpu(), atol=1e-2, verbose=True) diff --git a/tests/unit/test_onebit.py b/tests/unit/test_onebit.py index 9796a7095..18a0244c3 100644 --- a/tests/unit/test_onebit.py +++ b/tests/unit/test_onebit.py @@ -1,920 +1,920 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.distributed as dist -import deepspeed -import argparse -import pytest -import copy -import json -import os -import numpy as np -import time - -from deepspeed.runtime.pipe.topology import PipeDataParallelTopology, PipeModelDataParallelTopology -PipeTopo = PipeDataParallelTopology -from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec -from common import distributed_test -from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args -from test_pipe import AlexNetPipe, train_cifar - -TORCH_MAJOR = int(torch.__version__.split('.')[0]) -TORCH_MINOR = int(torch.__version__.split('.')[1]) -if TORCH_MAJOR < 1 or TORCH_MINOR < 8: - pytest.skip("NCCL-based 1-bit compression requires torch 1.8 or higher", - allow_module_level=True) - - -def test_onebitadam_fp16_basic(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitAdam", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl" - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - } - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - - @distributed_test(world_size=[1, 2]) - def _test_onebitadam_fp16_basic(args, model, hidden_dim): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, - total_samples=50, - hidden_dim=hidden_dim, - device=model.device) - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - - _test_onebitadam_fp16_basic(args=args, model=model, hidden_dim=hidden_dim) - - -def test_onebitadam_fp32_basic(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitAdam", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl" - } - }, - "gradient_clipping": 1.0, - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - - @distributed_test(world_size=[1, 2]) - def _test_onebitadam_fp32_basic(args, model, hidden_dim): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, - total_samples=50, - hidden_dim=hidden_dim, - device=model.device, - dtype=torch.float) - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - - _test_onebitadam_fp32_basic(args=args, model=model, hidden_dim=hidden_dim) - - -def test_onebitadam_exp_avg_mask(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitAdam", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl" - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - } - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - param_optimizer = list(model.named_parameters()) - mask1 = torch.zeros_like(param_optimizer[0][1].data) - for col in range(mask1.size()[1]): - mask1[0][col] += 1 - mask1 = torch.flatten(mask1) - optimizer_grouped_parameters = [{ - 'params': [param_optimizer[0][1]], - 'weight_decay': 0.01, - 'exp_avg_mask': mask1 - }, - { - 'params': [param_optimizer[1][1]], - 'weight_decay': 0.01 - }] - - @distributed_test(world_size=[2]) - def _test_onebitadam_exp_avg_mask(args, model, hidden_dim): - model, optimizer, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=optimizer_grouped_parameters) - data_loader = random_dataloader(model=model, - total_samples=50, - hidden_dim=hidden_dim, - device=model.device) - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - # Test whether the momentum mask works - for v in optimizer.state.values(): - if v['exp_avg'].size() == mask1.size(): - assert torch.allclose(v['exp_avg'], v['exp_avg'].mul_(mask1.to(device=v['exp_avg'].device)), atol=1e-07), f"Momentum mask is not working properly" - - _test_onebitadam_exp_avg_mask(args=args, model=model, hidden_dim=hidden_dim) - - -def test_onebitadam_checkpointing(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitAdam", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl" - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - } - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - param_optimizer = list(model.named_parameters()) - mask1 = torch.zeros_like(param_optimizer[0][1].data) - mask2 = torch.zeros_like(param_optimizer[0][1].data) - for col in range(mask1.size()[1]): - mask1[0][col] += 1 - mask2[1][col] += 1 - mask1 = torch.flatten(mask1) - mask2 = torch.flatten(mask2) - - optimizer_grouped_parameters_1 = [{ - 'params': [param_optimizer[0][1]], - 'weight_decay': 0.01, - 'exp_avg_mask': mask1 - }, - { - 'params': [param_optimizer[1][1]], - 'weight_decay': 0.01 - }] - - optimizer_grouped_parameters_2 = [{ - 'params': [param_optimizer[0][1]], - 'weight_decay': 0.01, - 'exp_avg_mask': mask2 - }, - { - 'params': [param_optimizer[1][1]], - 'weight_decay': 0.01 - }] - - optimizer_grouped_parameters_3 = [{ - 'params': [param_optimizer[0][1]], - 'weight_decay': 0.01 - }, - { - 'params': [param_optimizer[1][1]], - 'weight_decay': 0.01 - }] - - @distributed_test(world_size=[2]) - def _test_onebitadam_checkpointing(mask1, mask2, args, model, hidden_dim): - model_1, optimizer_1, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=optimizer_grouped_parameters_1) - data_loader = random_dataloader(model=model_1, - total_samples=10, - hidden_dim=hidden_dim, - device=model_1.device) - for n, batch in enumerate(data_loader): - loss = model_1(batch[0], batch[1]) - model_1.backward(loss) - model_1.step() - # Test whether momentum mask still exist after saving checkpoint - assert optimizer_1.optimizer.adam_freeze_key is True - mask1 = mask1.to(device=optimizer_1.param_groups[0]['exp_avg_mask'].device) - assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Incorrect momentum mask" - save_folder = os.path.join(tmpdir, 'saved_checkpoint') - model_1.save_checkpoint(save_folder, tag=None) - assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Momentum mask should not change after saving checkpoint" - - - model_2, optimizer_2, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=optimizer_grouped_parameters_2) - # Test whether momentum mask stays the same after loading checkpoint - mask2 = mask2.to(device=optimizer_2.param_groups[0]['exp_avg_mask'].device) - assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Incorrect momentum mask" - model_2.load_checkpoint(save_folder, - tag=None, - load_optimizer_states=True, - load_lr_scheduler_states=True) - assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Momentum mask should not change after loading checkpoint" - # Test whether worker&server error is resetted - for v in optimizer_2.state.values(): - assert 'worker_error' not in v, f"Incorrect worker error" - assert 'server_error' not in v, f"Incorrect server error" - assert optimizer_2.optimizer.adam_freeze_key is True - - model_3, optimizer_3, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=optimizer_grouped_parameters_3) - optimizer_3.optimizer.freeze_step = 20 - data_loader = random_dataloader(model=model_3, - total_samples=50, - hidden_dim=hidden_dim, - device=model_3.device) - for n, batch in enumerate(data_loader): - loss = model_3(batch[0], batch[1]) - model_3.backward(loss) - model_3.step() - assert optimizer_3.optimizer.adam_freeze_key is True - # Test whether momentum mask stays the same after loading checkpoint - assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Incorrect momentum mask" - model_3.load_checkpoint(save_folder, - tag=None, - load_optimizer_states=True, - load_lr_scheduler_states=True) - assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Momentum mask should not change after loading checkpoint" - # Test whether worker&server error is resetted - for v in optimizer_3.state.values(): - assert 'worker_error' not in v, f"Incorrect worker error" - assert 'server_error' not in v, f"Incorrect server error" - assert optimizer_3.optimizer.adam_freeze_key is False - - _test_onebitadam_checkpointing(mask1, - mask2, - args=args, - model=model, - hidden_dim=hidden_dim) - - -def test_onebitadam_checkpointing_overflow(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitAdam", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl" - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - } - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - - @distributed_test(world_size=[2]) - def _test_onebitadam_checkpointing_overflow(args, model, hidden_dim): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, - total_samples=100, - hidden_dim=hidden_dim, - device=model.device) - save_folder = os.path.join(tmpdir, 'saved_checkpoint') - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - if dist.get_rank() == 0 and n >= 10: - loss = loss * 1000000.0 - model.backward(loss) - dist.barrier() - model.step() - dist.barrier() - model.save_checkpoint(save_folder, tag=None) - - _test_onebitadam_checkpointing_overflow(args=args, - model=model, - hidden_dim=hidden_dim) - - -@pytest.mark.parametrize('topo', - [ - PipeTopo(num_pp=1, - num_dp=4), - PipeTopo(num_pp=2, - num_dp=2), - PipeTopo(num_pp=4, - num_dp=1), - ]) -def test_onebitadam_fp16_pipeline(topo, tmpdir): - config_dict = { - "train_batch_size": 16, - "train_micro_batch_size_per_gpu": 4, - "steps_per_print": 20, - "optimizer": { - "type": "OneBitAdam", - "params": { - "lr": 0.00001, - "betas": [0.9, - 0.999], - "eps": 1e-8, - "weight_decay": 3e-7, - "freeze_step": 200, - "cuda_aware": False, - "comm_backend_name": "nccl" - } - }, - "gradient_clipping": 1.0, - "zero_optimization": { - "stage": 0 - }, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - }, - "pipeline": { - "seed_layers": True, - "activation_checkpoint_interval": 1 - } - } - args = args_from_dict(tmpdir, config_dict) - - # Allocate model for consistent initial weights. - init_net = AlexNetPipe() - - @distributed_test(world_size=4) - def _helper(topo, tmpdir, steps=500): - assert steps >= 100 - - test_net = copy.deepcopy(init_net) - test_model = PipelineModule(layers=test_net.to_layers(), - topology=topo, - loss_fn=nn.CrossEntropyLoss()) - - test_losses = train_cifar(test_model, - args, - num_steps=steps, - fp16=config_dict['fp16']['enabled']) - - _helper(topo, tmpdir) - - -def test_onebitlamb_fp16_basic(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitLamb", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "max_coeff": 0.3, - "min_coeff": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl", - "coeff_beta": 0.9, - "factor_max": 1.0, - "factor_min": 0.5, - "factor_threshold": 0.1 - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - } - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - - @distributed_test(world_size=[1, 2]) - def _test_onebitlamb_fp16_basic(args, model, hidden_dim): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, - total_samples=50, - hidden_dim=hidden_dim, - device=model.device) - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - - _test_onebitlamb_fp16_basic(args=args, model=model, hidden_dim=hidden_dim) - - -def test_onebitlamb_fp32_basic(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitLamb", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "max_coeff": 0.3, - "min_coeff": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl", - "coeff_beta": 0.9, - "factor_max": 1.0, - "factor_min": 0.5, - "factor_threshold": 0.1 - } - }, - "gradient_clipping": 1.0, - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - - @distributed_test(world_size=[1, 2]) - def _test_onebitlamb_fp32_basic(args, model, hidden_dim): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, - total_samples=50, - hidden_dim=hidden_dim, - device=model.device, - dtype=torch.float) - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - - _test_onebitlamb_fp32_basic(args=args, model=model, hidden_dim=hidden_dim) - - -def test_onebitlamb_exp_avg_mask(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitLamb", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "max_coeff": 0.3, - "min_coeff": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl", - "coeff_beta": 0.9, - "factor_max": 1.0, - "factor_min": 0.5, - "factor_threshold": 0.1 - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - } - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - param_optimizer = list(model.named_parameters()) - mask1 = torch.zeros_like(param_optimizer[0][1].data) - for col in range(mask1.size()[1]): - mask1[0][col] += 1 - optimizer_grouped_parameters = [{ - 'params': [param_optimizer[0][1]], - 'weight_decay': 0.01, - 'exp_avg_mask': mask1 - }, - { - 'params': [param_optimizer[1][1]], - 'weight_decay': 0.01 - }] - - @distributed_test(world_size=[2]) - def _test_onebitlamb_exp_avg_mask(args, model, hidden_dim): - model, optimizer, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=optimizer_grouped_parameters) - data_loader = random_dataloader(model=model, - total_samples=50, - hidden_dim=hidden_dim, - device=model.device) - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - # Test whether the momentum mask works - for v in optimizer.state.values(): - if v['exp_avg'].size() == mask1.size(): - assert torch.allclose(v['exp_avg'], v['exp_avg'].mul_(mask1.to(device=v['exp_avg'].device)), atol=1e-07), f"Momentum mask is not working properly" - - _test_onebitlamb_exp_avg_mask(args=args, model=model, hidden_dim=hidden_dim) - - -def test_onebitlamb_checkpointing(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitLamb", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "max_coeff": 0.3, - "min_coeff": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl", - "coeff_beta": 0.9, - "factor_max": 1.0, - "factor_min": 0.5, - "factor_threshold": 0.1 - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - } - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - param_optimizer = list(model.named_parameters()) - mask1 = torch.zeros_like(param_optimizer[0][1].data) - mask2 = torch.zeros_like(param_optimizer[0][1].data) - for col in range(mask1.size()[1]): - mask1[0][col] += 1 - mask2[1][col] += 1 - - optimizer_grouped_parameters_1 = [{ - 'params': [param_optimizer[0][1]], - 'weight_decay': 0.01, - 'exp_avg_mask': mask1 - }, - { - 'params': [param_optimizer[1][1]], - 'weight_decay': 0.01 - }] - - optimizer_grouped_parameters_2 = [{ - 'params': [param_optimizer[0][1]], - 'weight_decay': 0.01, - 'exp_avg_mask': mask2 - }, - { - 'params': [param_optimizer[1][1]], - 'weight_decay': 0.01 - }] - - optimizer_grouped_parameters_3 = [{ - 'params': [param_optimizer[0][1]], - 'weight_decay': 0.01 - }, - { - 'params': [param_optimizer[1][1]], - 'weight_decay': 0.01 - }] - - @distributed_test(world_size=[2]) - def _test_onebitlamb_checkpointing(mask1, mask2, args, model, hidden_dim): - model_1, optimizer_1, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=optimizer_grouped_parameters_1) - data_loader = random_dataloader(model=model_1, - total_samples=10, - hidden_dim=hidden_dim, - device=model_1.device) - for n, batch in enumerate(data_loader): - loss = model_1(batch[0], batch[1]) - model_1.backward(loss) - model_1.step() - # Test whether momentum mask still exist after saving checkpoint - assert optimizer_1.optimizer.lamb_freeze_key is True - mask1 = mask1.to(device=optimizer_1.param_groups[0]['exp_avg_mask'].device) - assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Incorrect momentum mask" - scaling_coeff_1 = [] - for v in optimizer_1.state.values(): - assert 'scaling_coeff' in v, f"Incorrect scaling_coeff" - scaling_coeff_1.append(v['scaling_coeff']) - save_folder = os.path.join(tmpdir, 'saved_checkpoint') - model_1.save_checkpoint(save_folder, tag=None) - assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Momentum mask should not change after saving checkpoint" - - - model_2, optimizer_2, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=optimizer_grouped_parameters_2) - # Test whether momentum mask stays the same after loading checkpoint - mask2 = mask2.to(device=optimizer_2.param_groups[0]['exp_avg_mask'].device) - assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Incorrect momentum mask" - model_2.load_checkpoint(save_folder, - tag=None, - load_optimizer_states=True, - load_lr_scheduler_states=True) - assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Momentum mask should not change after loading checkpoint" - # Test whether worker&server error is resetted - assert len(optimizer_2.optimizer.worker_errors) == 0, f"Incorrect worker error" - assert len(optimizer_2.optimizer.server_errors) == 0, f"Incorrect server error" - # Test whether scaling_coeffs is loaded correctly - scaling_coeff_2 = [] - for v in optimizer_2.state.values(): - assert 'scaling_coeff' in v, f"Incorrect scaling_coeff" - scaling_coeff_2.append(v['scaling_coeff']) - assert list(sorted(scaling_coeff_2)) == list(sorted(scaling_coeff_1)), f"Incorrect scaling_coeffs" - assert optimizer_2.optimizer.lamb_freeze_key is True - - model_3, optimizer_3, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=optimizer_grouped_parameters_3) - optimizer_3.optimizer.freeze_step = 20 - data_loader = random_dataloader(model=model_3, - total_samples=50, - hidden_dim=hidden_dim, - device=model_3.device) - for n, batch in enumerate(data_loader): - loss = model_3(batch[0], batch[1]) - model_3.backward(loss) - model_3.step() - assert optimizer_3.optimizer.lamb_freeze_key is True - # Test whether momentum mask stays the same after loading checkpoint - assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Incorrect momentum mask" - model_3.load_checkpoint(save_folder, - tag=None, - load_optimizer_states=True, - load_lr_scheduler_states=True) - assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Momentum mask should not change after loading checkpoint" - # Test whether worker&server error is resetted - assert len(optimizer_3.optimizer.worker_errors) == 0, f"Incorrect worker error" - assert len(optimizer_3.optimizer.server_errors) == 0, f"Incorrect server error" - # Test whether scaling_coeffs, lamb_coeff_freeze, last_factor are resetted - for v in optimizer_3.state.values(): - assert v['lamb_coeff_freeze'] == 0.0, f"Incorrect lamb_coeff_freeze" - assert v['last_factor'] == 1.0, f"Incorrect last_factor" - assert 'scaling_coeff' not in v, f"Incorrect scaling_coeff" - assert optimizer_3.optimizer.lamb_freeze_key is False - - _test_onebitlamb_checkpointing(mask1, - mask2, - args=args, - model=model, - hidden_dim=hidden_dim) - - -def test_onebitlamb_checkpointing_overflow(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitLamb", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "max_coeff": 0.3, - "min_coeff": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl", - "coeff_beta": 0.9, - "factor_max": 1.0, - "factor_min": 0.5, - "factor_threshold": 0.1 - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - } - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - - @distributed_test(world_size=[2]) - def _test_onebitlamb_checkpointing_overflow(args, model, hidden_dim): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, - total_samples=100, - hidden_dim=hidden_dim, - device=model.device) - save_folder = os.path.join(tmpdir, 'saved_checkpoint') - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - if dist.get_rank() == 0 and n >= 10: - loss = loss * 1000000.0 - model.backward(loss) - dist.barrier() - model.step() - dist.barrier() - model.save_checkpoint(save_folder, tag=None) - - _test_onebitlamb_checkpointing_overflow(args=args, - model=model, - hidden_dim=hidden_dim) - - -@pytest.mark.parametrize('topo', - [ - PipeTopo(num_pp=1, - num_dp=4), - PipeTopo(num_pp=2, - num_dp=2), - PipeTopo(num_pp=4, - num_dp=1), - ]) -def test_onebitlamb_fp16_pipeline(topo, tmpdir): - config_dict = { - "train_batch_size": 16, - "train_micro_batch_size_per_gpu": 4, - "steps_per_print": 20, - "optimizer": { - "type": "OneBitLamb", - "params": { - "lr": 0.00001, - "betas": [0.9, - 0.999], - "eps": 1e-8, - "weight_decay": 3e-7, - "freeze_step": 200, - "cuda_aware": False, - "comm_backend_name": "nccl" - } - }, - "gradient_clipping": 1.0, - "zero_optimization": { - "stage": 0 - }, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - }, - "pipeline": { - "seed_layers": True, - "activation_checkpoint_interval": 1 - } - } - args = args_from_dict(tmpdir, config_dict) - - # Allocate model for consistent initial weights. - init_net = AlexNetPipe() - - @distributed_test(world_size=4) - def _helper(topo, tmpdir, steps=500): - assert steps >= 100 - - test_net = copy.deepcopy(init_net) - test_model = PipelineModule(layers=test_net.to_layers(), - topology=topo, - loss_fn=nn.CrossEntropyLoss()) - - test_losses = train_cifar(test_model, - args, - num_steps=steps, - fp16=config_dict['fp16']['enabled']) - - _helper(topo, tmpdir) - - -def test_compressed_allreduce_basic(tmpdir): - @distributed_test(world_size=[1, 2]) - def _test_compressed_allreduce_basic(): - from deepspeed.runtime.comm.nccl import NcclBackend - size = dist.get_world_size() - rank = dist.get_rank() - backend = NcclBackend() - local_rank = dist.get_rank() - device = torch.device("cuda", dist.get_rank()) - - # A simulated compression function using torch.distributed - def torch_sim(a): - a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0) - scale = a.norm() / np.sqrt(a.numel()) - a_compressed = scale * a_sign - a_sign = None - worker_error = a - a_compressed - dist.all_reduce(a_compressed) - a_compressed.mul_(1 / dist.get_world_size()) - a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_( - 2.0) - a_list = torch.chunk(a_compressed, chunks=dist.get_world_size()) - server_scale = [ - chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list - ] - a_sign_list = torch.chunk(a_server_sign, dist.get_world_size()) - a_server_compressed = torch.cat( - [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())]) - rank = dist.get_rank() - server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank] - torch.cuda.synchronize() - torch.distributed.barrier() - return a_server_compressed, worker_error, server_error - - tensor_size = 300 * 2**20 - server_size = int(tensor_size / size) - if tensor_size % (8 * size) != 0: - right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size))) - else: - right_tensor_size = tensor_size - right_server_size = right_tensor_size // size - - # Adding bias to the initialization of the gradient we are communicating - # In order to get rid of the case where some elements in the gradient are too small - a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank - - worker_error = torch.zeros(right_tensor_size, device=device) - server_error = torch.zeros(right_server_size, device=device) - - a_torch, worker_error_torch, server_error_torch = torch_sim(a) - torch.cuda.empty_cache() - - a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank) - - threshold = 1e-6 - magnitude_threshold = 1e-6 - diff_mask = (a_after - a_torch) > threshold - diff_server_mask = torch.chunk(diff_mask, size)[rank] - mpi_server = torch.chunk(a_after, size)[rank] + server_error - torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch - - # If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic - # The test would skip those numbers that are too small in compensated_server_m - check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold - if torch.sum(check_mag_mask) != 0: - print('Fails at {} of positions'.format(torch.sum(check_mag_mask))) - assert torch.sum(diff_server_mask) == 0 or torch.sum(check_mag_mask) == 0 - - _test_compressed_allreduce_basic() +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +import deepspeed +import argparse +import pytest +import copy +import json +import os +import numpy as np +import time + +from deepspeed.runtime.pipe.topology import PipeDataParallelTopology, PipeModelDataParallelTopology +PipeTopo = PipeDataParallelTopology +from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec +from common import distributed_test +from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args +from test_pipe import AlexNetPipe, train_cifar + +TORCH_MAJOR = int(torch.__version__.split('.')[0]) +TORCH_MINOR = int(torch.__version__.split('.')[1]) +if TORCH_MAJOR < 1 or TORCH_MINOR < 8: + pytest.skip("NCCL-based 1-bit compression requires torch 1.8 or higher", + allow_module_level=True) + + +def test_onebitadam_fp16_basic(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[1, 2]) + def _test_onebitadam_fp16_basic(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + _test_onebitadam_fp16_basic(args=args, model=model, hidden_dim=hidden_dim) + + +def test_onebitadam_fp32_basic(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[1, 2]) + def _test_onebitadam_fp32_basic(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + _test_onebitadam_fp32_basic(args=args, model=model, hidden_dim=hidden_dim) + + +def test_onebitadam_exp_avg_mask(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + param_optimizer = list(model.named_parameters()) + mask1 = torch.zeros_like(param_optimizer[0][1].data) + for col in range(mask1.size()[1]): + mask1[0][col] += 1 + mask1 = torch.flatten(mask1) + optimizer_grouped_parameters = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01, + 'exp_avg_mask': mask1 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + @distributed_test(world_size=[2]) + def _test_onebitadam_exp_avg_mask(args, model, hidden_dim): + model, optimizer, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + # Test whether the momentum mask works + for v in optimizer.state.values(): + if v['exp_avg'].size() == mask1.size(): + assert torch.allclose(v['exp_avg'], v['exp_avg'].mul_(mask1.to(device=v['exp_avg'].device)), atol=1e-07), f"Momentum mask is not working properly" + + _test_onebitadam_exp_avg_mask(args=args, model=model, hidden_dim=hidden_dim) + + +def test_onebitadam_checkpointing(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + param_optimizer = list(model.named_parameters()) + mask1 = torch.zeros_like(param_optimizer[0][1].data) + mask2 = torch.zeros_like(param_optimizer[0][1].data) + for col in range(mask1.size()[1]): + mask1[0][col] += 1 + mask2[1][col] += 1 + mask1 = torch.flatten(mask1) + mask2 = torch.flatten(mask2) + + optimizer_grouped_parameters_1 = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01, + 'exp_avg_mask': mask1 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + optimizer_grouped_parameters_2 = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01, + 'exp_avg_mask': mask2 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + optimizer_grouped_parameters_3 = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + @distributed_test(world_size=[2]) + def _test_onebitadam_checkpointing(mask1, mask2, args, model, hidden_dim): + model_1, optimizer_1, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters_1) + data_loader = random_dataloader(model=model_1, + total_samples=10, + hidden_dim=hidden_dim, + device=model_1.device) + for n, batch in enumerate(data_loader): + loss = model_1(batch[0], batch[1]) + model_1.backward(loss) + model_1.step() + # Test whether momentum mask still exist after saving checkpoint + assert optimizer_1.optimizer.adam_freeze_key is True + mask1 = mask1.to(device=optimizer_1.param_groups[0]['exp_avg_mask'].device) + assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Incorrect momentum mask" + save_folder = os.path.join(tmpdir, 'saved_checkpoint') + model_1.save_checkpoint(save_folder, tag=None) + assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Momentum mask should not change after saving checkpoint" + + + model_2, optimizer_2, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters_2) + # Test whether momentum mask stays the same after loading checkpoint + mask2 = mask2.to(device=optimizer_2.param_groups[0]['exp_avg_mask'].device) + assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Incorrect momentum mask" + model_2.load_checkpoint(save_folder, + tag=None, + load_optimizer_states=True, + load_lr_scheduler_states=True) + assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Momentum mask should not change after loading checkpoint" + # Test whether worker&server error is resetted + for v in optimizer_2.state.values(): + assert 'worker_error' not in v, f"Incorrect worker error" + assert 'server_error' not in v, f"Incorrect server error" + assert optimizer_2.optimizer.adam_freeze_key is True + + model_3, optimizer_3, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters_3) + optimizer_3.optimizer.freeze_step = 20 + data_loader = random_dataloader(model=model_3, + total_samples=50, + hidden_dim=hidden_dim, + device=model_3.device) + for n, batch in enumerate(data_loader): + loss = model_3(batch[0], batch[1]) + model_3.backward(loss) + model_3.step() + assert optimizer_3.optimizer.adam_freeze_key is True + # Test whether momentum mask stays the same after loading checkpoint + assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Incorrect momentum mask" + model_3.load_checkpoint(save_folder, + tag=None, + load_optimizer_states=True, + load_lr_scheduler_states=True) + assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Momentum mask should not change after loading checkpoint" + # Test whether worker&server error is resetted + for v in optimizer_3.state.values(): + assert 'worker_error' not in v, f"Incorrect worker error" + assert 'server_error' not in v, f"Incorrect server error" + assert optimizer_3.optimizer.adam_freeze_key is False + + _test_onebitadam_checkpointing(mask1, + mask2, + args=args, + model=model, + hidden_dim=hidden_dim) + + +def test_onebitadam_checkpointing_overflow(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[2]) + def _test_onebitadam_checkpointing_overflow(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=100, + hidden_dim=hidden_dim, + device=model.device) + save_folder = os.path.join(tmpdir, 'saved_checkpoint') + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + if dist.get_rank() == 0 and n >= 10: + loss = loss * 1000000.0 + model.backward(loss) + dist.barrier() + model.step() + dist.barrier() + model.save_checkpoint(save_folder, tag=None) + + _test_onebitadam_checkpointing_overflow(args=args, + model=model, + hidden_dim=hidden_dim) + + +@pytest.mark.parametrize('topo', + [ + PipeTopo(num_pp=1, + num_dp=4), + PipeTopo(num_pp=2, + num_dp=2), + PipeTopo(num_pp=4, + num_dp=1), + ]) +def test_onebitadam_fp16_pipeline(topo, tmpdir): + config_dict = { + "train_batch_size": 16, + "train_micro_batch_size_per_gpu": 4, + "steps_per_print": 20, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": 0.00001, + "betas": [0.9, + 0.999], + "eps": 1e-8, + "weight_decay": 3e-7, + "freeze_step": 200, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": 0 + }, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + }, + "pipeline": { + "seed_layers": True, + "activation_checkpoint_interval": 1 + } + } + args = args_from_dict(tmpdir, config_dict) + + # Allocate model for consistent initial weights. + init_net = AlexNetPipe() + + @distributed_test(world_size=4) + def _helper(topo, tmpdir, steps=500): + assert steps >= 100 + + test_net = copy.deepcopy(init_net) + test_model = PipelineModule(layers=test_net.to_layers(), + topology=topo, + loss_fn=nn.CrossEntropyLoss()) + + test_losses = train_cifar(test_model, + args, + num_steps=steps, + fp16=config_dict['fp16']['enabled']) + + _helper(topo, tmpdir) + + +def test_onebitlamb_fp16_basic(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl", + "coeff_beta": 0.9, + "factor_max": 1.0, + "factor_min": 0.5, + "factor_threshold": 0.1 + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[1, 2]) + def _test_onebitlamb_fp16_basic(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + _test_onebitlamb_fp16_basic(args=args, model=model, hidden_dim=hidden_dim) + + +def test_onebitlamb_fp32_basic(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl", + "coeff_beta": 0.9, + "factor_max": 1.0, + "factor_min": 0.5, + "factor_threshold": 0.1 + } + }, + "gradient_clipping": 1.0, + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[1, 2]) + def _test_onebitlamb_fp32_basic(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + _test_onebitlamb_fp32_basic(args=args, model=model, hidden_dim=hidden_dim) + + +def test_onebitlamb_exp_avg_mask(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl", + "coeff_beta": 0.9, + "factor_max": 1.0, + "factor_min": 0.5, + "factor_threshold": 0.1 + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + param_optimizer = list(model.named_parameters()) + mask1 = torch.zeros_like(param_optimizer[0][1].data) + for col in range(mask1.size()[1]): + mask1[0][col] += 1 + optimizer_grouped_parameters = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01, + 'exp_avg_mask': mask1 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + @distributed_test(world_size=[2]) + def _test_onebitlamb_exp_avg_mask(args, model, hidden_dim): + model, optimizer, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + # Test whether the momentum mask works + for v in optimizer.state.values(): + if v['exp_avg'].size() == mask1.size(): + assert torch.allclose(v['exp_avg'], v['exp_avg'].mul_(mask1.to(device=v['exp_avg'].device)), atol=1e-07), f"Momentum mask is not working properly" + + _test_onebitlamb_exp_avg_mask(args=args, model=model, hidden_dim=hidden_dim) + + +def test_onebitlamb_checkpointing(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl", + "coeff_beta": 0.9, + "factor_max": 1.0, + "factor_min": 0.5, + "factor_threshold": 0.1 + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + param_optimizer = list(model.named_parameters()) + mask1 = torch.zeros_like(param_optimizer[0][1].data) + mask2 = torch.zeros_like(param_optimizer[0][1].data) + for col in range(mask1.size()[1]): + mask1[0][col] += 1 + mask2[1][col] += 1 + + optimizer_grouped_parameters_1 = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01, + 'exp_avg_mask': mask1 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + optimizer_grouped_parameters_2 = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01, + 'exp_avg_mask': mask2 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + optimizer_grouped_parameters_3 = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + @distributed_test(world_size=[2]) + def _test_onebitlamb_checkpointing(mask1, mask2, args, model, hidden_dim): + model_1, optimizer_1, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters_1) + data_loader = random_dataloader(model=model_1, + total_samples=10, + hidden_dim=hidden_dim, + device=model_1.device) + for n, batch in enumerate(data_loader): + loss = model_1(batch[0], batch[1]) + model_1.backward(loss) + model_1.step() + # Test whether momentum mask still exist after saving checkpoint + assert optimizer_1.optimizer.lamb_freeze_key is True + mask1 = mask1.to(device=optimizer_1.param_groups[0]['exp_avg_mask'].device) + assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Incorrect momentum mask" + scaling_coeff_1 = [] + for v in optimizer_1.state.values(): + assert 'scaling_coeff' in v, f"Incorrect scaling_coeff" + scaling_coeff_1.append(v['scaling_coeff']) + save_folder = os.path.join(tmpdir, 'saved_checkpoint') + model_1.save_checkpoint(save_folder, tag=None) + assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Momentum mask should not change after saving checkpoint" + + + model_2, optimizer_2, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters_2) + # Test whether momentum mask stays the same after loading checkpoint + mask2 = mask2.to(device=optimizer_2.param_groups[0]['exp_avg_mask'].device) + assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Incorrect momentum mask" + model_2.load_checkpoint(save_folder, + tag=None, + load_optimizer_states=True, + load_lr_scheduler_states=True) + assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Momentum mask should not change after loading checkpoint" + # Test whether worker&server error is resetted + assert len(optimizer_2.optimizer.worker_errors) == 0, f"Incorrect worker error" + assert len(optimizer_2.optimizer.server_errors) == 0, f"Incorrect server error" + # Test whether scaling_coeffs is loaded correctly + scaling_coeff_2 = [] + for v in optimizer_2.state.values(): + assert 'scaling_coeff' in v, f"Incorrect scaling_coeff" + scaling_coeff_2.append(v['scaling_coeff']) + assert list(sorted(scaling_coeff_2)) == list(sorted(scaling_coeff_1)), f"Incorrect scaling_coeffs" + assert optimizer_2.optimizer.lamb_freeze_key is True + + model_3, optimizer_3, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters_3) + optimizer_3.optimizer.freeze_step = 20 + data_loader = random_dataloader(model=model_3, + total_samples=50, + hidden_dim=hidden_dim, + device=model_3.device) + for n, batch in enumerate(data_loader): + loss = model_3(batch[0], batch[1]) + model_3.backward(loss) + model_3.step() + assert optimizer_3.optimizer.lamb_freeze_key is True + # Test whether momentum mask stays the same after loading checkpoint + assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Incorrect momentum mask" + model_3.load_checkpoint(save_folder, + tag=None, + load_optimizer_states=True, + load_lr_scheduler_states=True) + assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Momentum mask should not change after loading checkpoint" + # Test whether worker&server error is resetted + assert len(optimizer_3.optimizer.worker_errors) == 0, f"Incorrect worker error" + assert len(optimizer_3.optimizer.server_errors) == 0, f"Incorrect server error" + # Test whether scaling_coeffs, lamb_coeff_freeze, last_factor are resetted + for v in optimizer_3.state.values(): + assert v['lamb_coeff_freeze'] == 0.0, f"Incorrect lamb_coeff_freeze" + assert v['last_factor'] == 1.0, f"Incorrect last_factor" + assert 'scaling_coeff' not in v, f"Incorrect scaling_coeff" + assert optimizer_3.optimizer.lamb_freeze_key is False + + _test_onebitlamb_checkpointing(mask1, + mask2, + args=args, + model=model, + hidden_dim=hidden_dim) + + +def test_onebitlamb_checkpointing_overflow(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl", + "coeff_beta": 0.9, + "factor_max": 1.0, + "factor_min": 0.5, + "factor_threshold": 0.1 + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[2]) + def _test_onebitlamb_checkpointing_overflow(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=100, + hidden_dim=hidden_dim, + device=model.device) + save_folder = os.path.join(tmpdir, 'saved_checkpoint') + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + if dist.get_rank() == 0 and n >= 10: + loss = loss * 1000000.0 + model.backward(loss) + dist.barrier() + model.step() + dist.barrier() + model.save_checkpoint(save_folder, tag=None) + + _test_onebitlamb_checkpointing_overflow(args=args, + model=model, + hidden_dim=hidden_dim) + + +@pytest.mark.parametrize('topo', + [ + PipeTopo(num_pp=1, + num_dp=4), + PipeTopo(num_pp=2, + num_dp=2), + PipeTopo(num_pp=4, + num_dp=1), + ]) +def test_onebitlamb_fp16_pipeline(topo, tmpdir): + config_dict = { + "train_batch_size": 16, + "train_micro_batch_size_per_gpu": 4, + "steps_per_print": 20, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00001, + "betas": [0.9, + 0.999], + "eps": 1e-8, + "weight_decay": 3e-7, + "freeze_step": 200, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": 0 + }, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + }, + "pipeline": { + "seed_layers": True, + "activation_checkpoint_interval": 1 + } + } + args = args_from_dict(tmpdir, config_dict) + + # Allocate model for consistent initial weights. + init_net = AlexNetPipe() + + @distributed_test(world_size=4) + def _helper(topo, tmpdir, steps=500): + assert steps >= 100 + + test_net = copy.deepcopy(init_net) + test_model = PipelineModule(layers=test_net.to_layers(), + topology=topo, + loss_fn=nn.CrossEntropyLoss()) + + test_losses = train_cifar(test_model, + args, + num_steps=steps, + fp16=config_dict['fp16']['enabled']) + + _helper(topo, tmpdir) + + +def test_compressed_allreduce_basic(tmpdir): + @distributed_test(world_size=[1, 2]) + def _test_compressed_allreduce_basic(): + from deepspeed.runtime.comm.nccl import NcclBackend + size = dist.get_world_size() + rank = dist.get_rank() + backend = NcclBackend() + local_rank = dist.get_rank() + device = torch.device("cuda", dist.get_rank()) + + # A simulated compression function using torch.distributed + def torch_sim(a): + a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0) + scale = a.norm() / np.sqrt(a.numel()) + a_compressed = scale * a_sign + a_sign = None + worker_error = a - a_compressed + dist.all_reduce(a_compressed) + a_compressed.mul_(1 / dist.get_world_size()) + a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_( + 2.0) + a_list = torch.chunk(a_compressed, chunks=dist.get_world_size()) + server_scale = [ + chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list + ] + a_sign_list = torch.chunk(a_server_sign, dist.get_world_size()) + a_server_compressed = torch.cat( + [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())]) + rank = dist.get_rank() + server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank] + torch.cuda.synchronize() + torch.distributed.barrier() + return a_server_compressed, worker_error, server_error + + tensor_size = 300 * 2**20 + server_size = int(tensor_size / size) + if tensor_size % (8 * size) != 0: + right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size))) + else: + right_tensor_size = tensor_size + right_server_size = right_tensor_size // size + + # Adding bias to the initialization of the gradient we are communicating + # In order to get rid of the case where some elements in the gradient are too small + a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank + + worker_error = torch.zeros(right_tensor_size, device=device) + server_error = torch.zeros(right_server_size, device=device) + + a_torch, worker_error_torch, server_error_torch = torch_sim(a) + torch.cuda.empty_cache() + + a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank) + + threshold = 1e-6 + magnitude_threshold = 1e-6 + diff_mask = (a_after - a_torch) > threshold + diff_server_mask = torch.chunk(diff_mask, size)[rank] + mpi_server = torch.chunk(a_after, size)[rank] + server_error + torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch + + # If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic + # The test would skip those numbers that are too small in compensated_server_m + check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold + if torch.sum(check_mag_mask) != 0: + print('Fails at {} of positions'.format(torch.sum(check_mag_mask))) + assert torch.sum(diff_server_mask) == 0 or torch.sum(check_mag_mask) == 0 + + _test_compressed_allreduce_basic() diff --git a/tests/unit/test_pld.py b/tests/unit/test_pld.py index 784aeff03..d8fa8488f 100755 --- a/tests/unit/test_pld.py +++ b/tests/unit/test_pld.py @@ -1,117 +1,117 @@ -import numpy as np -import deepspeed -import pytest -from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop -from common import distributed_test -from simple_model import SimpleModel, PLD_SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict - - -@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0]) -def test_pld_schedule(tmpdir, theta): - gamma = 0.001 - - pld_scheduler = ProgressiveLayerDrop(theta, gamma) - for i in range(10): - pld_scheduler.update_state(i) - expected_theta = (1. - theta) * np.exp(-gamma * i) + theta - actual_theta = pld_scheduler.get_theta() - assert expected_theta == actual_theta - - -@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0]) -def test_pld_model(tmpdir, theta): - gamma = 0.001 - config_dict = { - "train_batch_size": 1, - "steps_per_print": 1, - "optimizer": { - "type": 'Adam', - "params": { - "lr": 0.0001 - } - }, - "fp16": { - "enabled": True - }, - "progressive_layer_drop": { - "enabled": True, - "theta": theta, - "gamma": gamma - } - } - - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = PLD_SimpleModel(hidden_dim, empty_grad=False) - - @distributed_test(world_size=[1]) - def _test_pld_model(args, model, hidden_dim, theta, gamma): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - - data_loader = random_dataloader(model=model, - total_samples=50, - hidden_dim=hidden_dim, - device=model.device) - - for i, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - - expected_theta = (1. - theta) * np.exp(-gamma * i) + theta - actual_theta = model.get_pld_theta() - assert expected_theta == actual_theta - - _test_pld_model(args=args, - model=model, - hidden_dim=hidden_dim, - theta=theta, - gamma=gamma) - - -def test_non_pld_model(tmpdir): - gamma = 0.001 - theta = 0.5 - config_dict = { - "train_batch_size": 1, - "steps_per_print": 1, - "optimizer": { - "type": 'Adam', - "params": { - "lr": 0.0001 - } - }, - "fp16": { - "enabled": True - }, - "progressive_layer_drop": { - "enabled": True, - "theta": theta, - "gamma": gamma - } - } - - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim, empty_grad=False) - - @distributed_test(world_size=[1]) - def _test_non_pld_model(args, model, hidden_dim): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - - data_loader = random_dataloader(model=model, - total_samples=1, - hidden_dim=hidden_dim, - device=model.device) - - for i, batch in enumerate(data_loader): - with pytest.raises(TypeError): - loss = model(batch[0], batch[1]) - - _test_non_pld_model(args=args, model=model, hidden_dim=hidden_dim) +import numpy as np +import deepspeed +import pytest +from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop +from common import distributed_test +from simple_model import SimpleModel, PLD_SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict + + +@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0]) +def test_pld_schedule(tmpdir, theta): + gamma = 0.001 + + pld_scheduler = ProgressiveLayerDrop(theta, gamma) + for i in range(10): + pld_scheduler.update_state(i) + expected_theta = (1. - theta) * np.exp(-gamma * i) + theta + actual_theta = pld_scheduler.get_theta() + assert expected_theta == actual_theta + + +@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0]) +def test_pld_model(tmpdir, theta): + gamma = 0.001 + config_dict = { + "train_batch_size": 1, + "steps_per_print": 1, + "optimizer": { + "type": 'Adam', + "params": { + "lr": 0.0001 + } + }, + "fp16": { + "enabled": True + }, + "progressive_layer_drop": { + "enabled": True, + "theta": theta, + "gamma": gamma + } + } + + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = PLD_SimpleModel(hidden_dim, empty_grad=False) + + @distributed_test(world_size=[1]) + def _test_pld_model(args, model, hidden_dim, theta, gamma): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device) + + for i, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + expected_theta = (1. - theta) * np.exp(-gamma * i) + theta + actual_theta = model.get_pld_theta() + assert expected_theta == actual_theta + + _test_pld_model(args=args, + model=model, + hidden_dim=hidden_dim, + theta=theta, + gamma=gamma) + + +def test_non_pld_model(tmpdir): + gamma = 0.001 + theta = 0.5 + config_dict = { + "train_batch_size": 1, + "steps_per_print": 1, + "optimizer": { + "type": 'Adam', + "params": { + "lr": 0.0001 + } + }, + "fp16": { + "enabled": True + }, + "progressive_layer_drop": { + "enabled": True, + "theta": theta, + "gamma": gamma + } + } + + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim, empty_grad=False) + + @distributed_test(world_size=[1]) + def _test_non_pld_model(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + + data_loader = random_dataloader(model=model, + total_samples=1, + hidden_dim=hidden_dim, + device=model.device) + + for i, batch in enumerate(data_loader): + with pytest.raises(TypeError): + loss = model(batch[0], batch[1]) + + _test_non_pld_model(args=args, model=model, hidden_dim=hidden_dim)