mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[BE] Extend linter to detect DOS newlines (#86973)
Fix DOS newlines in `onednn/decompose_silu.[cpp|h]` introduced by https://github.com/pytorch/pytorch/pull/85591 as well as one in `.github/PULL_REQUEST_TEMPLATE.md` Pull Request resolved: https://github.com/pytorch/pytorch/pull/86973 Approved by: https://github.com/huydhn, https://github.com/izaitsevfb
This commit is contained in:
committed by
PyTorch MergeBot
parent
b8aa1767cd
commit
3924aa75b1
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
@ -1 +1 @@
|
||||
Fixes #ISSUE_NUMBER
|
||||
Fixes #ISSUE_NUMBER
|
||||
|
@ -288,8 +288,10 @@ include_patterns=['**']
|
||||
exclude_patterns=[
|
||||
'**/contrib/**',
|
||||
'third_party/**',
|
||||
'**/*.bat',
|
||||
'**/*.expect',
|
||||
'**/*.ipynb',
|
||||
'**/*.ps1',
|
||||
'**/*.ptl',
|
||||
'tools/clang_format_hash/**',
|
||||
'test/cpp/jit/upgrader_models/*.ptl',
|
||||
|
@ -4,13 +4,13 @@ NEWLINE: Checks files to make sure there are no trailing newlines.
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from enum import Enum
|
||||
from typing import NamedTuple, Optional
|
||||
from typing import List, NamedTuple, Optional
|
||||
|
||||
NEWLINE = 10 # ASCII "\n"
|
||||
CARRIAGE_RETURN = 13 # ASCII "\r"
|
||||
LINTER_CODE = "NEWLINE"
|
||||
|
||||
|
||||
@ -33,78 +33,96 @@ class LintMessage(NamedTuple):
|
||||
description: Optional[str]
|
||||
|
||||
|
||||
def correct_trailing_newlines(filename: str) -> bool:
|
||||
with open(filename, "rb") as f:
|
||||
a = len(f.read(2))
|
||||
if a == 0:
|
||||
return True
|
||||
elif a == 1:
|
||||
# file is wrong whether or not the only byte is a newline
|
||||
return False
|
||||
else:
|
||||
f.seek(-2, os.SEEK_END)
|
||||
b, c = f.read(2)
|
||||
# no ASCII byte is part of any non-ASCII character in UTF-8
|
||||
return b != NEWLINE and c == NEWLINE
|
||||
|
||||
|
||||
def check_file(filename: str) -> Optional[LintMessage]:
|
||||
logging.debug("Checking file %s", filename)
|
||||
|
||||
with open(filename, "rb") as f:
|
||||
a = len(f.read(2))
|
||||
if a == 0:
|
||||
# File is empty, just leave it alone.
|
||||
return None
|
||||
elif a == 1:
|
||||
# file is wrong whether or not the only byte is a newline
|
||||
lines = f.readlines()
|
||||
|
||||
if len(lines) == 0:
|
||||
# File is empty, just leave it alone.
|
||||
return None
|
||||
|
||||
if len(lines) == 1 and len(lines[0]) == 1:
|
||||
# file is wrong whether or not the only byte is a newline
|
||||
return LintMessage(
|
||||
path=filename,
|
||||
line=None,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.ERROR,
|
||||
name="testestTrailing newline",
|
||||
original=None,
|
||||
replacement=None,
|
||||
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
|
||||
)
|
||||
|
||||
if len(lines[-1]) == 1 and lines[-1][0] == NEWLINE:
|
||||
try:
|
||||
original = b"".join(lines).decode("utf-8")
|
||||
except Exception as err:
|
||||
return LintMessage(
|
||||
path=filename,
|
||||
line=None,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.ERROR,
|
||||
name="testestTrailing newline",
|
||||
name="Decoding failure",
|
||||
original=None,
|
||||
replacement=None,
|
||||
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
|
||||
description=f"utf-8 decoding failed due to {err.__class__.__name__}:\n{err}",
|
||||
)
|
||||
|
||||
else:
|
||||
# Read the last two bytes
|
||||
f.seek(-2, os.SEEK_END)
|
||||
b, c = f.read(2)
|
||||
# no ASCII byte is part of any non-ASCII character in UTF-8
|
||||
if b != NEWLINE and c == NEWLINE:
|
||||
return None
|
||||
else:
|
||||
f.seek(0)
|
||||
try:
|
||||
original = f.read().decode("utf-8")
|
||||
except Exception as err:
|
||||
return LintMessage(
|
||||
path=filename,
|
||||
line=None,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.ERROR,
|
||||
name="Decoding failure",
|
||||
original=None,
|
||||
replacement=None,
|
||||
description=f"utf-8 decoding failed due to {err.__class__.__name__}:\n{err}",
|
||||
)
|
||||
return LintMessage(
|
||||
path=filename,
|
||||
line=None,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.ERROR,
|
||||
name="Trailing newline",
|
||||
original=original,
|
||||
replacement=original.rstrip("\n") + "\n",
|
||||
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
|
||||
)
|
||||
has_changes = False
|
||||
original_lines: Optional[List[bytes]] = None
|
||||
for idx, line in enumerate(lines):
|
||||
if len(line) >= 2 and line[-1] == NEWLINE and line[-2] == CARRIAGE_RETURN:
|
||||
if not has_changes:
|
||||
original_lines = list(lines)
|
||||
has_changes = True
|
||||
lines[idx] = line[:-2] + b"\n"
|
||||
|
||||
return LintMessage(
|
||||
path=filename,
|
||||
line=None,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.ERROR,
|
||||
name="Trailing newline",
|
||||
original=original,
|
||||
replacement=original.rstrip("\n") + "\n",
|
||||
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
|
||||
)
|
||||
if has_changes:
|
||||
try:
|
||||
assert original_lines is not None
|
||||
original = b"".join(original_lines).decode("utf-8")
|
||||
replacement = b"".join(lines).decode("utf-8")
|
||||
except Exception as err:
|
||||
return LintMessage(
|
||||
path=filename,
|
||||
line=None,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.ERROR,
|
||||
name="Decoding failure",
|
||||
original=None,
|
||||
replacement=None,
|
||||
description=f"utf-8 decoding failed due to {err.__class__.__name__}:\n{err}",
|
||||
)
|
||||
return LintMessage(
|
||||
path=filename,
|
||||
line=None,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.ERROR,
|
||||
name="DOS newline",
|
||||
original=original,
|
||||
replacement=replacement,
|
||||
description="DOS newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,65 +1,65 @@
|
||||
#include <torch/csrc/jit/codegen/onednn/decompose_silu.h>
|
||||
#include <torch/csrc/jit/codegen/onednn/operator.h>
|
||||
|
||||
#include <ATen/code_template.h>
|
||||
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace fuser {
|
||||
namespace onednn {
|
||||
|
||||
bool shouldDecomposeSilu(Node* node) {
|
||||
if (node->kind() != aten::silu) {
|
||||
return false;
|
||||
}
|
||||
auto inputToSilu = node->input(0)->node();
|
||||
if (inputToSilu->kind() == aten::_convolution) {
|
||||
// TODO: remove transpose check once the bridge supported ConvTranspose
|
||||
bool transposed = Operator::Bool(inputToSilu, 6);
|
||||
return !transposed;
|
||||
}
|
||||
if (inputToSilu->kind() == aten::linear) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void DecomposeSilu(Node* node) {
|
||||
if (shouldDecomposeSilu(node)) {
|
||||
auto dtype = node->input(0)->type()->expect<TensorType>();
|
||||
|
||||
WithInsertPoint guard(node);
|
||||
auto g = node->owningGraph();
|
||||
auto sigmoid = g->insert(aten::sigmoid, {node->input(0)});
|
||||
sigmoid->setType(dtype);
|
||||
|
||||
auto mul = g->insert(aten::mul, {sigmoid, node->input(0)});
|
||||
mul->setType(dtype);
|
||||
|
||||
node->output()->replaceAllUsesWith(mul);
|
||||
}
|
||||
}
|
||||
|
||||
static void DecomposeSilu(Block* block) {
|
||||
for (auto node : block->nodes()) {
|
||||
for (auto sub : node->blocks()) {
|
||||
DecomposeSilu(sub);
|
||||
}
|
||||
|
||||
if (node->kind() == aten::silu) {
|
||||
DecomposeSilu(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph) {
|
||||
DecomposeSilu(graph->block());
|
||||
EliminateDeadCode(graph);
|
||||
}
|
||||
|
||||
} // namespace onednn
|
||||
} // namespace fuser
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
#include <torch/csrc/jit/codegen/onednn/decompose_silu.h>
|
||||
#include <torch/csrc/jit/codegen/onednn/operator.h>
|
||||
|
||||
#include <ATen/code_template.h>
|
||||
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace fuser {
|
||||
namespace onednn {
|
||||
|
||||
bool shouldDecomposeSilu(Node* node) {
|
||||
if (node->kind() != aten::silu) {
|
||||
return false;
|
||||
}
|
||||
auto inputToSilu = node->input(0)->node();
|
||||
if (inputToSilu->kind() == aten::_convolution) {
|
||||
// TODO: remove transpose check once the bridge supported ConvTranspose
|
||||
bool transposed = Operator::Bool(inputToSilu, 6);
|
||||
return !transposed;
|
||||
}
|
||||
if (inputToSilu->kind() == aten::linear) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void DecomposeSilu(Node* node) {
|
||||
if (shouldDecomposeSilu(node)) {
|
||||
auto dtype = node->input(0)->type()->expect<TensorType>();
|
||||
|
||||
WithInsertPoint guard(node);
|
||||
auto g = node->owningGraph();
|
||||
auto sigmoid = g->insert(aten::sigmoid, {node->input(0)});
|
||||
sigmoid->setType(dtype);
|
||||
|
||||
auto mul = g->insert(aten::mul, {sigmoid, node->input(0)});
|
||||
mul->setType(dtype);
|
||||
|
||||
node->output()->replaceAllUsesWith(mul);
|
||||
}
|
||||
}
|
||||
|
||||
static void DecomposeSilu(Block* block) {
|
||||
for (auto node : block->nodes()) {
|
||||
for (auto sub : node->blocks()) {
|
||||
DecomposeSilu(sub);
|
||||
}
|
||||
|
||||
if (node->kind() == aten::silu) {
|
||||
DecomposeSilu(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph) {
|
||||
DecomposeSilu(graph->block());
|
||||
EliminateDeadCode(graph);
|
||||
}
|
||||
|
||||
} // namespace onednn
|
||||
} // namespace fuser
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -1,15 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace fuser {
|
||||
namespace onednn {
|
||||
|
||||
void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph);
|
||||
|
||||
} // namespace onednn
|
||||
} // namespace fuser
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace fuser {
|
||||
namespace onednn {
|
||||
|
||||
void DecomposeSiluForLLGA(std::shared_ptr<Graph>& graph);
|
||||
|
||||
} // namespace onednn
|
||||
} // namespace fuser
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
Reference in New Issue
Block a user