mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] Extend linter to detect DOS newlines (#86973)
Fix DOS newlines in `onednn/decompose_silu.[cpp|h]` introduced by https://github.com/pytorch/pytorch/pull/85591 as well as one in `.github/PULL_REQUEST_TEMPLATE.md` Pull Request resolved: https://github.com/pytorch/pytorch/pull/86973 Approved by: https://github.com/huydhn, https://github.com/izaitsevfb
This commit is contained in:
committed by
PyTorch MergeBot
parent
b8aa1767cd
commit
3924aa75b1
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
@ -1 +1 @@
|
|||||||
Fixes #ISSUE_NUMBER
|
Fixes #ISSUE_NUMBER
|
||||||
|
@ -288,8 +288,10 @@ include_patterns=['**']
|
|||||||
exclude_patterns=[
|
exclude_patterns=[
|
||||||
'**/contrib/**',
|
'**/contrib/**',
|
||||||
'third_party/**',
|
'third_party/**',
|
||||||
|
'**/*.bat',
|
||||||
'**/*.expect',
|
'**/*.expect',
|
||||||
'**/*.ipynb',
|
'**/*.ipynb',
|
||||||
|
'**/*.ps1',
|
||||||
'**/*.ptl',
|
'**/*.ptl',
|
||||||
'tools/clang_format_hash/**',
|
'tools/clang_format_hash/**',
|
||||||
'test/cpp/jit/upgrader_models/*.ptl',
|
'test/cpp/jit/upgrader_models/*.ptl',
|
||||||
|
@ -4,13 +4,13 @@ NEWLINE: Checks files to make sure there are no trailing newlines.
|
|||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import NamedTuple, Optional
|
from typing import List, NamedTuple, Optional
|
||||||
|
|
||||||
NEWLINE = 10 # ASCII "\n"
|
NEWLINE = 10 # ASCII "\n"
|
||||||
|
CARRIAGE_RETURN = 13 # ASCII "\r"
|
||||||
LINTER_CODE = "NEWLINE"
|
LINTER_CODE = "NEWLINE"
|
||||||
|
|
||||||
|
|
||||||
@ -33,78 +33,96 @@ class LintMessage(NamedTuple):
|
|||||||
description: Optional[str]
|
description: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
def correct_trailing_newlines(filename: str) -> bool:
|
|
||||||
with open(filename, "rb") as f:
|
|
||||||
a = len(f.read(2))
|
|
||||||
if a == 0:
|
|
||||||
return True
|
|
||||||
elif a == 1:
|
|
||||||
# file is wrong whether or not the only byte is a newline
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
f.seek(-2, os.SEEK_END)
|
|
||||||
b, c = f.read(2)
|
|
||||||
# no ASCII byte is part of any non-ASCII character in UTF-8
|
|
||||||
return b != NEWLINE and c == NEWLINE
|
|
||||||
|
|
||||||
|
|
||||||
def check_file(filename: str) -> Optional[LintMessage]:
|
def check_file(filename: str) -> Optional[LintMessage]:
|
||||||
logging.debug("Checking file %s", filename)
|
logging.debug("Checking file %s", filename)
|
||||||
|
|
||||||
with open(filename, "rb") as f:
|
with open(filename, "rb") as f:
|
||||||
a = len(f.read(2))
|
lines = f.readlines()
|
||||||
if a == 0:
|
|
||||||
# File is empty, just leave it alone.
|
if len(lines) == 0:
|
||||||
return None
|
# File is empty, just leave it alone.
|
||||||
elif a == 1:
|
return None
|
||||||
# file is wrong whether or not the only byte is a newline
|
|
||||||
|
if len(lines) == 1 and len(lines[0]) == 1:
|
||||||
|
# file is wrong whether or not the only byte is a newline
|
||||||
|
return LintMessage(
|
||||||
|
path=filename,
|
||||||
|
line=None,
|
||||||
|
char=None,
|
||||||
|
code=LINTER_CODE,
|
||||||
|
severity=LintSeverity.ERROR,
|
||||||
|
name="testestTrailing newline",
|
||||||
|
original=None,
|
||||||
|
replacement=None,
|
||||||
|
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(lines[-1]) == 1 and lines[-1][0] == NEWLINE:
|
||||||
|
try:
|
||||||
|
original = b"".join(lines).decode("utf-8")
|
||||||
|
except Exception as err:
|
||||||
return LintMessage(
|
return LintMessage(
|
||||||
path=filename,
|
path=filename,
|
||||||
line=None,
|
line=None,
|
||||||
char=None,
|
char=None,
|
||||||
code=LINTER_CODE,
|
code=LINTER_CODE,
|
||||||
severity=LintSeverity.ERROR,
|
severity=LintSeverity.ERROR,
|
||||||
name="testestTrailing newline",
|
name="Decoding failure",
|
||||||
original=None,
|
original=None,
|
||||||
replacement=None,
|
replacement=None,
|
||||||
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
|
description=f"utf-8 decoding failed due to {err.__class__.__name__}:\n{err}",
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
return LintMessage(
|
||||||
# Read the last two bytes
|
path=filename,
|
||||||
f.seek(-2, os.SEEK_END)
|
line=None,
|
||||||
b, c = f.read(2)
|
char=None,
|
||||||
# no ASCII byte is part of any non-ASCII character in UTF-8
|
code=LINTER_CODE,
|
||||||
if b != NEWLINE and c == NEWLINE:
|
severity=LintSeverity.ERROR,
|
||||||
return None
|
name="Trailing newline",
|
||||||
else:
|
original=original,
|
||||||
f.seek(0)
|
replacement=original.rstrip("\n") + "\n",
|
||||||
try:
|
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
|
||||||
original = f.read().decode("utf-8")
|
)
|
||||||
except Exception as err:
|
has_changes = False
|
||||||
return LintMessage(
|
original_lines: Optional[List[bytes]] = None
|
||||||
path=filename,
|
for idx, line in enumerate(lines):
|
||||||
line=None,
|
if len(line) >= 2 and line[-1] == NEWLINE and line[-2] == CARRIAGE_RETURN:
|
||||||
char=None,
|
if not has_changes:
|
||||||
code=LINTER_CODE,
|
original_lines = list(lines)
|
||||||
severity=LintSeverity.ERROR,
|
has_changes = True
|
||||||
name="Decoding failure",
|
lines[idx] = line[:-2] + b"\n"
|
||||||
original=None,
|
|
||||||
replacement=None,
|
|
||||||
description=f"utf-8 decoding failed due to {err.__class__.__name__}:\n{err}",
|
|
||||||
)
|
|
||||||
|
|
||||||
return LintMessage(
|
if has_changes:
|
||||||
path=filename,
|
try:
|
||||||
line=None,
|
assert original_lines is not None
|
||||||
char=None,
|
original = b"".join(original_lines).decode("utf-8")
|
||||||
code=LINTER_CODE,
|
replacement = b"".join(lines).decode("utf-8")
|
||||||
severity=LintSeverity.ERROR,
|
except Exception as err:
|
||||||
name="Trailing newline",
|
return LintMessage(
|
||||||
original=original,
|
path=filename,
|
||||||
replacement=original.rstrip("\n") + "\n",
|
line=None,
|
||||||
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
|
char=None,
|
||||||
)
|
code=LINTER_CODE,
|
||||||
|
severity=LintSeverity.ERROR,
|
||||||
|
name="Decoding failure",
|
||||||
|
original=None,
|
||||||
|
replacement=None,
|
||||||
|
description=f"utf-8 decoding failed due to {err.__class__.__name__}:\n{err}",
|
||||||
|
)
|
||||||
|
return LintMessage(
|
||||||
|
path=filename,
|
||||||
|
line=None,
|
||||||
|
char=None,
|
||||||
|
code=LINTER_CODE,
|
||||||
|
severity=LintSeverity.ERROR,
|
||||||
|
name="DOS newline",
|
||||||
|
original=original,
|
||||||
|
replacement=replacement,
|
||||||
|
description="DOS newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,65 +1,65 @@
|
|||||||
#include <torch/csrc/jit/codegen/onednn/decompose_silu.h>
|
#include <torch/csrc/jit/codegen/onednn/decompose_silu.h>
|
||||||
#include <torch/csrc/jit/codegen/onednn/operator.h>
|
#include <torch/csrc/jit/codegen/onednn/operator.h>
|
||||||
|
|
||||||
#include <ATen/code_template.h>
|
#include <ATen/code_template.h>
|
||||||
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace fuser {
|
namespace fuser {
|
||||||
namespace onednn {
|
namespace onednn {
|
||||||
|
|
||||||
bool shouldDecomposeSilu(Node* node) {
|
bool shouldDecomposeSilu(Node* node) {
|
||||||
if (node->kind() != aten::silu) {
|
if (node->kind() != aten::silu) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto inputToSilu = node->input(0)->node();
|
auto inputToSilu = node->input(0)->node();
|
||||||
if (inputToSilu->kind() == aten::_convolution) {
|
if (inputToSilu->kind() == aten::_convolution) {
|
||||||
// TODO: remove transpose check once the bridge supported ConvTranspose
|
// TODO: remove transpose check once the bridge supported ConvTranspose
|
||||||
bool transposed = Operator::Bool(inputToSilu, 6);
|
bool transposed = Operator::Bool(inputToSilu, 6);
|
||||||
return !transposed;
|
return !transposed;
|
||||||
}
|
}
|
||||||
if (inputToSilu->kind() == aten::linear) {
|
if (inputToSilu->kind() == aten::linear) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void DecomposeSilu(Node* node) {
|
void DecomposeSilu(Node* node) {
|
||||||
if (shouldDecomposeSilu(node)) {
|
if (shouldDecomposeSilu(node)) {
|
||||||
auto dtype = node->input(0)->type()->expect<TensorType>();
|
auto dtype = node->input(0)->type()->expect<TensorType>();
|
||||||
|
|
||||||
WithInsertPoint guard(node);
|
WithInsertPoint guard(node);
|
||||||
auto g = node->owningGraph();
|
auto g = node->owningGraph();
|
||||||
auto sigmoid = g->insert(aten::sigmoid, {node->input(0)});
|
auto sigmoid = g->insert(aten::sigmoid, {node->input(0)});
|
||||||
sigmoid->setType(dtype);
|
sigmoid->setType(dtype);
|
||||||
|
|
||||||
auto mul = g->insert(aten::mul, {sigmoid, node->input(0)});
|
auto mul = g->insert(aten::mul, {sigmoid, node->input(0)});
|
||||||
mul->setType(dtype);
|
mul->setType(dtype);
|
||||||
|
|
||||||
node->output()->replaceAllUsesWith(mul);
|
node->output()->replaceAllUsesWith(mul);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void DecomposeSilu(Block* block) {
|
static void DecomposeSilu(Block* block) {
|
||||||
for (auto node : block->nodes()) {
|
for (auto node : block->nodes()) {
|
||||||
for (auto sub : node->blocks()) {
|
for (auto sub : node->blocks()) {
|
||||||
DecomposeSilu(sub);
|
DecomposeSilu(sub);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (node->kind() == aten::silu) {
|
if (node->kind() == aten::silu) {
|
||||||
DecomposeSilu(node);
|
DecomposeSilu(node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph) {
|
void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph) {
|
||||||
DecomposeSilu(graph->block());
|
DecomposeSilu(graph->block());
|
||||||
EliminateDeadCode(graph);
|
EliminateDeadCode(graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onednn
|
} // namespace onednn
|
||||||
} // namespace fuser
|
} // namespace fuser
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -1,15 +1,15 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/csrc/jit/ir/ir.h>
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace fuser {
|
namespace fuser {
|
||||||
namespace onednn {
|
namespace onednn {
|
||||||
|
|
||||||
void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph);
|
void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph);
|
||||||
|
|
||||||
} // namespace onednn
|
} // namespace onednn
|
||||||
} // namespace fuser
|
} // namespace fuser
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
Reference in New Issue
Block a user