mirror of
https://github.com/huggingface/trl.git
synced 2025-10-21 02:53:59 +08:00
170 lines
5.6 KiB
Python
170 lines
5.6 KiB
Python
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
from datetime import date
|
|
from pathlib import Path
|
|
|
|
from tabulate import tabulate
|
|
|
|
|
|
# Slack rejects section-block text longer than 3000 characters; keep a safety margin below that.
MAX_LEN_MESSAGE = 2900

parser = argparse.ArgumentParser(
    description="Summarize the JSON-lines test logs (*.log) in the current directory and post failures to Slack."
)
parser.add_argument(
    "--slack_channel_name",
    default="trl-push-ci",
    help="Slack channel (without the leading '#') that receives the failure report.",
)

# Set up logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
|
|
|
|
def process_log_file(log):
    """Parse one JSON-lines test log into failed and passed test rows.

    Each non-blank line is decoded as a JSON object (presumably pytest ``--report-log`` records —
    confirm against the CI job). Every record with a non-empty ``nodeid`` becomes a row
    ``[nodeid, duration, group]``:

    * ``duration`` is the record's numeric ``duration`` formatted to four decimals, or ``"N/A"``
      when it is missing or not a number;
    * ``group`` is the part of the file stem before the first underscore
      (``tests_foo.log`` -> ``"tests"``).

    A record whose ``outcome`` is ``"failed"`` is counted as a failure; any other outcome is listed
    as passed.

    Args:
        log: Path (``pathlib.Path`` or ``str``) of the log file.

    Returns:
        ``(failed_tests, passed_tests, section_num_failed)``. Undecodable or non-object lines are
        skipped with a warning; if the file is missing or reading fails, the error is logged and
        whatever was parsed so far is returned.
    """
    failed_tests = []
    passed_tests = []
    section_num_failed = 0
    log = Path(log)
    group = log.stem.split("_")[0]

    try:
        # JSON text is UTF-8; don't depend on the platform's default encoding.
        with open(log, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # blank/trailing lines are not records
                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    logging.warning(f"Could not decode line in {log}: {e}")
                    continue
                if not isinstance(data, dict):
                    # A stray non-object record must not abort parsing of the remaining lines.
                    logging.warning(f"Skipping non-object record in {log}: {line.strip()[:80]}")
                    continue

                test_name = data.get("nodeid", "")
                if not test_name:
                    continue

                raw_duration = data.get("duration")
                duration = f"{raw_duration:.4f}" if isinstance(raw_duration, (int, float)) else "N/A"
                row = [test_name, duration, group]

                if data.get("outcome", "") == "failed":
                    section_num_failed += 1
                    failed_tests.append(row)
                else:
                    passed_tests.append(row)

    except FileNotFoundError as e:
        logging.error(f"Log file {log} not found: {e}")
    except Exception as e:
        # Best effort: a broken log must not stop the report for the other logs.
        logging.error(f"Error processing log file {log}: {e}")

    return failed_tests, passed_tests, section_num_failed
|
|
|
|
|
|
def _format_failure_summary(group_info, max_name_len=30):
    """Render the Slack mrkdwn summary of failed tests and empty log files.

    Args:
        group_info: Sequence of ``(log_name, num_failed, failed_tests, is_empty)`` entries; each
            ``failed_tests`` row starts with the test's ``nodeid`` (as built by ``process_log_file``).
        max_name_len: Test names longer than this are cut to this length and suffixed with ``".."``.

    Returns:
        The summary text. It is not length-limited here; the caller enforces ``MAX_LEN_MESSAGE``.
    """
    parts = []
    for name, num_failed, failed_tests, is_empty in group_info:
        if num_failed > 0:
            rows = []
            for test in failed_tests:
                # "path::Class::test" -> ("path::Class", "test"): one cell per header column.
                location, _, test_name = test[0].rpartition("::")
                if len(test_name) > max_name_len:
                    test_name = test_name[:max_name_len] + ".."
                rows.append([location, test_name])
            table = tabulate(rows, headers=["Test Location", "Test Name"], tablefmt="grid")
            parts.append(f"*{name}: {num_failed} failed test(s)*\n")
            parts.append(f"\n```\n{table}\n```\n")

        # Warn only for the log that itself produced no test records.
        if is_empty:
            parts.append(f"\n*{name}: Warning! Empty file - check GitHub action job*\n")
    return "".join(parts)


def main(slack_channel_name):
    """Summarize the ``*.log`` test logs in the working directory and post failures to Slack.

    Every ``*.log`` file in the current directory is parsed with ``process_log_file`` and then
    deleted. When at least one test failed, a Block Kit message is posted to
    ``#<slack_channel_name>`` using the ``SLACK_API_TOKEN`` environment variable; ``TEST_TYPE``
    labels the message and ``GITHUB_RUN_ID`` (required on that path) builds the link to the run.
    When nothing failed, results are only logged.

    Args:
        slack_channel_name: Slack channel name without the leading ``#``.
    """
    # Sorted so the report lists logs in a stable order across runs.
    log_files = sorted(Path().glob("*.log"))
    if not log_files:
        logging.info("No log files found.")
        return

    group_info = []
    total_num_failed = 0
    for log in log_files:
        failed, passed, section_num_failed = process_log_file(log)
        total_num_failed += section_num_failed
        # A log with no test records at all is flagged instead of being treated as a pass.
        is_empty = not failed and not passed
        group_info.append((str(log), section_num_failed, failed, is_empty))

        # Clean up log file
        try:
            os.remove(log)
        except OSError as e:
            logging.warning(f"Could not remove log file {log}: {e}")

    # Default to "" so neither the header nor the context line renders "None".
    test_type = os.environ.get("TEST_TYPE", "")

    # Prepare Slack message payload
    payload = [
        {
            "type": "header",
            "text": {"type": "plain_text", "text": f"🤗 Results of the {test_type} TRL tests."},
        },
    ]

    if total_num_failed > 0:
        message = _format_failure_summary(group_info)

        # Logging
        logging.info(f"Total failed tests: {total_num_failed}")
        print(f"### {message}")

        if len(message) > MAX_LEN_MESSAGE:
            message = (
                f"❌ There are {total_num_failed} failed tests in total! Please check the action results directly."
            )

        payload.append({"type": "section", "text": {"type": "mrkdwn", "text": message}})
        payload.append(
            {
                "type": "section",
                "text": {"type": "mrkdwn", "text": "*For more details:*"},
                "accessory": {
                    "type": "button",
                    "text": {"type": "plain_text", "text": "Check Action results"},
                    "url": f"https://github.com/huggingface/trl/actions/runs/{os.environ['GITHUB_RUN_ID']}",
                },
            }
        )
        payload.append(
            {
                "type": "context",
                "elements": [
                    {
                        "type": "plain_text",
                        "text": f"On Push main {test_type} results for {date.today()}",
                    }
                ],
            }
        )

        # Send to Slack (imported here, so only the failure path requires slack_sdk)
        from slack_sdk import WebClient

        slack_client = WebClient(token=os.environ.get("SLACK_API_TOKEN"))
        slack_client.chat_postMessage(channel=f"#{slack_channel_name}", text=message, blocks=payload)

    else:
        # Empty logs would otherwise vanish behind the "all passed" message below.
        for name, _, _, is_empty in group_info:
            if is_empty:
                logging.warning(f"{name}: Warning! Empty file - check GitHub action job")

        # NOTE(review): this success payload is built but never posted to Slack; only the log line
        # below is emitted. Confirm whether a success notification is wanted.
        payload.append(
            {
                "type": "section",
                "text": {
                    "type": "plain_text",
                    "text": "✅ No failures! All tests passed successfully.",
                    "emoji": True,
                },
            }
        )
        logging.info("All tests passed. No errors detected.")
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: read the CLI flags and hand the selected channel to main().
    cli_args = parser.parse_args()
    main(slack_channel_name=cli_args.slack_channel_name)
|