[BE] Extend linter to detect DOS newlines (#86973)

Fix DOS newlines in `onednn/decompose_silu.[cpp|h]` introduced by https://github.com/pytorch/pytorch/pull/85591 as well as one in `.github/PULL_REQUEST_TEMPLATE.md`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86973
Approved by: https://github.com/huydhn, https://github.com/izaitsevfb
This commit is contained in:
Nikita Shulga
2022-10-15 00:20:42 +00:00
committed by PyTorch MergeBot
parent b8aa1767cd
commit 3924aa75b1
5 changed files with 160 additions and 140 deletions

View File

@ -1 +1 @@
Fixes #ISSUE_NUMBER
Fixes #ISSUE_NUMBER

View File

@ -288,8 +288,10 @@ include_patterns=['**']
exclude_patterns=[
'**/contrib/**',
'third_party/**',
'**/*.bat',
'**/*.expect',
'**/*.ipynb',
'**/*.ps1',
'**/*.ptl',
'tools/clang_format_hash/**',
'test/cpp/jit/upgrader_models/*.ptl',

View File

@ -4,13 +4,13 @@ NEWLINE: Checks files to make sure there are no trailing newlines.
import argparse
import json
import logging
import os
import sys
from enum import Enum
from typing import NamedTuple, Optional
from typing import List, NamedTuple, Optional
NEWLINE = 10 # ASCII "\n"
CARRIAGE_RETURN = 13 # ASCII "\r"
LINTER_CODE = "NEWLINE"
@ -33,78 +33,96 @@ class LintMessage(NamedTuple):
description: Optional[str]
def correct_trailing_newlines(filename: str) -> bool:
with open(filename, "rb") as f:
a = len(f.read(2))
if a == 0:
return True
elif a == 1:
# file is wrong whether or not the only byte is a newline
return False
else:
f.seek(-2, os.SEEK_END)
b, c = f.read(2)
# no ASCII byte is part of any non-ASCII character in UTF-8
return b != NEWLINE and c == NEWLINE
def check_file(filename: str) -> Optional[LintMessage]:
logging.debug("Checking file %s", filename)
with open(filename, "rb") as f:
a = len(f.read(2))
if a == 0:
# File is empty, just leave it alone.
return None
elif a == 1:
# file is wrong whether or not the only byte is a newline
lines = f.readlines()
if len(lines) == 0:
# File is empty, just leave it alone.
return None
if len(lines) == 1 and len(lines[0]) == 1:
# file is wrong whether or not the only byte is a newline
return LintMessage(
path=filename,
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.ERROR,
name="testestTrailing newline",
original=None,
replacement=None,
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
)
if len(lines[-1]) == 1 and lines[-1][0] == NEWLINE:
try:
original = b"".join(lines).decode("utf-8")
except Exception as err:
return LintMessage(
path=filename,
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.ERROR,
name="testestTrailing newline",
name="Decoding failure",
original=None,
replacement=None,
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
description=f"utf-8 decoding failed due to {err.__class__.__name__}:\n{err}",
)
else:
# Read the last two bytes
f.seek(-2, os.SEEK_END)
b, c = f.read(2)
# no ASCII byte is part of any non-ASCII character in UTF-8
if b != NEWLINE and c == NEWLINE:
return None
else:
f.seek(0)
try:
original = f.read().decode("utf-8")
except Exception as err:
return LintMessage(
path=filename,
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.ERROR,
name="Decoding failure",
original=None,
replacement=None,
description=f"utf-8 decoding failed due to {err.__class__.__name__}:\n{err}",
)
return LintMessage(
path=filename,
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.ERROR,
name="Trailing newline",
original=original,
replacement=original.rstrip("\n") + "\n",
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
)
has_changes = False
original_lines: Optional[List[bytes]] = None
for idx, line in enumerate(lines):
if len(line) >= 2 and line[-1] == NEWLINE and line[-2] == CARRIAGE_RETURN:
if not has_changes:
original_lines = list(lines)
has_changes = True
lines[idx] = line[:-2] + b"\n"
return LintMessage(
path=filename,
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.ERROR,
name="Trailing newline",
original=original,
replacement=original.rstrip("\n") + "\n",
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
)
if has_changes:
try:
assert original_lines is not None
original = b"".join(original_lines).decode("utf-8")
replacement = b"".join(lines).decode("utf-8")
except Exception as err:
return LintMessage(
path=filename,
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.ERROR,
name="Decoding failure",
original=None,
replacement=None,
description=f"utf-8 decoding failed due to {err.__class__.__name__}:\n{err}",
)
return LintMessage(
path=filename,
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.ERROR,
name="DOS newline",
original=original,
replacement=replacement,
description="DOS newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
)
return None
if __name__ == "__main__":

View File

@ -1,65 +1,65 @@
#include <torch/csrc/jit/codegen/onednn/decompose_silu.h>
#include <torch/csrc/jit/codegen/onednn/operator.h>
#include <ATen/code_template.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
bool shouldDecomposeSilu(Node* node) {
if (node->kind() != aten::silu) {
return false;
}
auto inputToSilu = node->input(0)->node();
if (inputToSilu->kind() == aten::_convolution) {
// TODO: remove transpose check once the bridge supported ConvTranspose
bool transposed = Operator::Bool(inputToSilu, 6);
return !transposed;
}
if (inputToSilu->kind() == aten::linear) {
return true;
}
return false;
}
void DecomposeSilu(Node* node) {
if (shouldDecomposeSilu(node)) {
auto dtype = node->input(0)->type()->expect<TensorType>();
WithInsertPoint guard(node);
auto g = node->owningGraph();
auto sigmoid = g->insert(aten::sigmoid, {node->input(0)});
sigmoid->setType(dtype);
auto mul = g->insert(aten::mul, {sigmoid, node->input(0)});
mul->setType(dtype);
node->output()->replaceAllUsesWith(mul);
}
}
static void DecomposeSilu(Block* block) {
for (auto node : block->nodes()) {
for (auto sub : node->blocks()) {
DecomposeSilu(sub);
}
if (node->kind() == aten::silu) {
DecomposeSilu(node);
}
}
}
void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph) {
DecomposeSilu(graph->block());
EliminateDeadCode(graph);
}
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
#include <torch/csrc/jit/codegen/onednn/decompose_silu.h>
#include <torch/csrc/jit/codegen/onednn/operator.h>
#include <ATen/code_template.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
bool shouldDecomposeSilu(Node* node) {
if (node->kind() != aten::silu) {
return false;
}
auto inputToSilu = node->input(0)->node();
if (inputToSilu->kind() == aten::_convolution) {
// TODO: remove transpose check once the bridge supported ConvTranspose
bool transposed = Operator::Bool(inputToSilu, 6);
return !transposed;
}
if (inputToSilu->kind() == aten::linear) {
return true;
}
return false;
}
void DecomposeSilu(Node* node) {
if (shouldDecomposeSilu(node)) {
auto dtype = node->input(0)->type()->expect<TensorType>();
WithInsertPoint guard(node);
auto g = node->owningGraph();
auto sigmoid = g->insert(aten::sigmoid, {node->input(0)});
sigmoid->setType(dtype);
auto mul = g->insert(aten::mul, {sigmoid, node->input(0)});
mul->setType(dtype);
node->output()->replaceAllUsesWith(mul);
}
}
static void DecomposeSilu(Block* block) {
for (auto node : block->nodes()) {
for (auto sub : node->blocks()) {
DecomposeSilu(sub);
}
if (node->kind() == aten::silu) {
DecomposeSilu(node);
}
}
}
void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph) {
DecomposeSilu(graph->block());
EliminateDeadCode(graph);
}
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch

View File

@ -1,15 +1,15 @@
#pragma once
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph);
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
#pragma once
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph);
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch