[BE] Extend linter to detect DOS newlines (#86973)

Fix DOS newlines in `onednn/decompose_silu.[cpp|h]` introduced by https://github.com/pytorch/pytorch/pull/85591 as well as one in `.github/PULL_REQUEST_TEMPLATE.md`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86973
Approved by: https://github.com/huydhn, https://github.com/izaitsevfb
This commit is contained in:
Nikita Shulga
2022-10-15 00:20:42 +00:00
committed by PyTorch MergeBot
parent b8aa1767cd
commit 3924aa75b1
5 changed files with 160 additions and 140 deletions

View File

@ -1 +1 @@
Fixes #ISSUE_NUMBER Fixes #ISSUE_NUMBER

View File

@ -288,8 +288,10 @@ include_patterns=['**']
exclude_patterns=[ exclude_patterns=[
'**/contrib/**', '**/contrib/**',
'third_party/**', 'third_party/**',
'**/*.bat',
'**/*.expect', '**/*.expect',
'**/*.ipynb', '**/*.ipynb',
'**/*.ps1',
'**/*.ptl', '**/*.ptl',
'tools/clang_format_hash/**', 'tools/clang_format_hash/**',
'test/cpp/jit/upgrader_models/*.ptl', 'test/cpp/jit/upgrader_models/*.ptl',

View File

@ -4,13 +4,13 @@ NEWLINE: Checks files to make sure there are no trailing newlines.
import argparse import argparse
import json import json
import logging import logging
import os
import sys import sys
from enum import Enum from enum import Enum
from typing import NamedTuple, Optional from typing import List, NamedTuple, Optional
NEWLINE = 10 # ASCII "\n" NEWLINE = 10 # ASCII "\n"
CARRIAGE_RETURN = 13 # ASCII "\r"
LINTER_CODE = "NEWLINE" LINTER_CODE = "NEWLINE"
@ -33,78 +33,96 @@ class LintMessage(NamedTuple):
description: Optional[str] description: Optional[str]
def correct_trailing_newlines(filename: str) -> bool:
with open(filename, "rb") as f:
a = len(f.read(2))
if a == 0:
return True
elif a == 1:
# file is wrong whether or not the only byte is a newline
return False
else:
f.seek(-2, os.SEEK_END)
b, c = f.read(2)
# no ASCII byte is part of any non-ASCII character in UTF-8
return b != NEWLINE and c == NEWLINE
def check_file(filename: str) -> Optional[LintMessage]: def check_file(filename: str) -> Optional[LintMessage]:
logging.debug("Checking file %s", filename) logging.debug("Checking file %s", filename)
with open(filename, "rb") as f: with open(filename, "rb") as f:
a = len(f.read(2)) lines = f.readlines()
if a == 0:
# File is empty, just leave it alone. if len(lines) == 0:
return None # File is empty, just leave it alone.
elif a == 1: return None
# file is wrong whether or not the only byte is a newline
if len(lines) == 1 and len(lines[0]) == 1:
# file is wrong whether or not the only byte is a newline
return LintMessage(
path=filename,
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.ERROR,
name="testestTrailing newline",
original=None,
replacement=None,
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
)
if len(lines[-1]) == 1 and lines[-1][0] == NEWLINE:
try:
original = b"".join(lines).decode("utf-8")
except Exception as err:
return LintMessage( return LintMessage(
path=filename, path=filename,
line=None, line=None,
char=None, char=None,
code=LINTER_CODE, code=LINTER_CODE,
severity=LintSeverity.ERROR, severity=LintSeverity.ERROR,
name="testestTrailing newline", name="Decoding failure",
original=None, original=None,
replacement=None, replacement=None,
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.", description=f"utf-8 decoding failed due to {err.__class__.__name__}:\n{err}",
) )
else: return LintMessage(
# Read the last two bytes path=filename,
f.seek(-2, os.SEEK_END) line=None,
b, c = f.read(2) char=None,
# no ASCII byte is part of any non-ASCII character in UTF-8 code=LINTER_CODE,
if b != NEWLINE and c == NEWLINE: severity=LintSeverity.ERROR,
return None name="Trailing newline",
else: original=original,
f.seek(0) replacement=original.rstrip("\n") + "\n",
try: description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
original = f.read().decode("utf-8") )
except Exception as err: has_changes = False
return LintMessage( original_lines: Optional[List[bytes]] = None
path=filename, for idx, line in enumerate(lines):
line=None, if len(line) >= 2 and line[-1] == NEWLINE and line[-2] == CARRIAGE_RETURN:
char=None, if not has_changes:
code=LINTER_CODE, original_lines = list(lines)
severity=LintSeverity.ERROR, has_changes = True
name="Decoding failure", lines[idx] = line[:-2] + b"\n"
original=None,
replacement=None,
description=f"utf-8 decoding failed due to {err.__class__.__name__}:\n{err}",
)
return LintMessage( if has_changes:
path=filename, try:
line=None, assert original_lines is not None
char=None, original = b"".join(original_lines).decode("utf-8")
code=LINTER_CODE, replacement = b"".join(lines).decode("utf-8")
severity=LintSeverity.ERROR, except Exception as err:
name="Trailing newline", return LintMessage(
original=original, path=filename,
replacement=original.rstrip("\n") + "\n", line=None,
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.", char=None,
) code=LINTER_CODE,
severity=LintSeverity.ERROR,
name="Decoding failure",
original=None,
replacement=None,
description=f"utf-8 decoding failed due to {err.__class__.__name__}:\n{err}",
)
return LintMessage(
path=filename,
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.ERROR,
name="DOS newline",
original=original,
replacement=replacement,
description="DOS newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
)
return None
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,65 +1,65 @@
#include <torch/csrc/jit/codegen/onednn/decompose_silu.h> #include <torch/csrc/jit/codegen/onednn/decompose_silu.h>
#include <torch/csrc/jit/codegen/onednn/operator.h> #include <torch/csrc/jit/codegen/onednn/operator.h>
#include <ATen/code_template.h> #include <ATen/code_template.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h> #include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h> #include <torch/csrc/jit/passes/subgraph_rewrite.h>
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace fuser { namespace fuser {
namespace onednn { namespace onednn {
bool shouldDecomposeSilu(Node* node) { bool shouldDecomposeSilu(Node* node) {
if (node->kind() != aten::silu) { if (node->kind() != aten::silu) {
return false; return false;
} }
auto inputToSilu = node->input(0)->node(); auto inputToSilu = node->input(0)->node();
if (inputToSilu->kind() == aten::_convolution) { if (inputToSilu->kind() == aten::_convolution) {
// TODO: remove transpose check once the bridge supported ConvTranspose // TODO: remove transpose check once the bridge supported ConvTranspose
bool transposed = Operator::Bool(inputToSilu, 6); bool transposed = Operator::Bool(inputToSilu, 6);
return !transposed; return !transposed;
} }
if (inputToSilu->kind() == aten::linear) { if (inputToSilu->kind() == aten::linear) {
return true; return true;
} }
return false; return false;
} }
void DecomposeSilu(Node* node) { void DecomposeSilu(Node* node) {
if (shouldDecomposeSilu(node)) { if (shouldDecomposeSilu(node)) {
auto dtype = node->input(0)->type()->expect<TensorType>(); auto dtype = node->input(0)->type()->expect<TensorType>();
WithInsertPoint guard(node); WithInsertPoint guard(node);
auto g = node->owningGraph(); auto g = node->owningGraph();
auto sigmoid = g->insert(aten::sigmoid, {node->input(0)}); auto sigmoid = g->insert(aten::sigmoid, {node->input(0)});
sigmoid->setType(dtype); sigmoid->setType(dtype);
auto mul = g->insert(aten::mul, {sigmoid, node->input(0)}); auto mul = g->insert(aten::mul, {sigmoid, node->input(0)});
mul->setType(dtype); mul->setType(dtype);
node->output()->replaceAllUsesWith(mul); node->output()->replaceAllUsesWith(mul);
} }
} }
static void DecomposeSilu(Block* block) { static void DecomposeSilu(Block* block) {
for (auto node : block->nodes()) { for (auto node : block->nodes()) {
for (auto sub : node->blocks()) { for (auto sub : node->blocks()) {
DecomposeSilu(sub); DecomposeSilu(sub);
} }
if (node->kind() == aten::silu) { if (node->kind() == aten::silu) {
DecomposeSilu(node); DecomposeSilu(node);
} }
} }
} }
void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph) { void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph) {
DecomposeSilu(graph->block()); DecomposeSilu(graph->block());
EliminateDeadCode(graph); EliminateDeadCode(graph);
} }
} // namespace onednn } // namespace onednn
} // namespace fuser } // namespace fuser
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch

View File

@ -1,15 +1,15 @@
#pragma once #pragma once
#include <torch/csrc/jit/ir/ir.h> #include <torch/csrc/jit/ir/ir.h>
namespace torch { namespace torch {
namespace jit { namespace jit {
namespace fuser { namespace fuser {
namespace onednn { namespace onednn {
void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph); void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph);
} // namespace onednn } // namespace onednn
} // namespace fuser } // namespace fuser
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch