import logging
import os
import traceback

import gradio as gr
import pandas as pd

from src.populate import get_leaderboard_df, get_filtered_leaderboard
from src.display.filters import get_model_type_choices, get_strategy_choices, get_metric_categories
from src.language import lang

# Results directory, resolved relative to this file.
RESULTS_PATH = os.path.join(os.path.dirname(__file__), "results")
# Create it up front so the directory checks in refresh_leaderboard normally pass.
os.makedirs(RESULTS_PATH, exist_ok=True)

# Logging configuration.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# model_name values longer than this are cut and suffixed with '...'.
MODEL_NAME_MAX_LEN = 25


def _simplify_metric_name(col, prefix, category_name):
    """Return the display name for metric column *col*, matched via *prefix*.

    The renaming rule depends on the category name. Both the Chinese and the
    English category names are checked; they must match the names produced
    by ``get_metric_categories`` exactly (including the "Scadasticity"
    spelling).
    """
    if category_name in ("平稳性", "Stationarity"):
        return col
    if category_name in ("方差特性", "Scadasticity"):
        if "Homo-Scedasticity" in prefix:
            return col.replace(prefix, "Homo_", 1)
        if "Hetero-Scedasticity" in prefix:
            return col.replace(prefix, "Hetero_", 1)
        return col
    if category_name in ("季节数", "Seasonality Count"):
        if "Seasonality_Count" in prefix:
            return col.replace(prefix, "Count", 1)
        return col
    # Default rule: drop the first occurrence of the prefix.
    return col.replace(prefix, "", 1)


def get_category_columns(df, category):
    """Select the columns to display for one metric category.

    Args:
        df: Leaderboard DataFrame.
        category: Mapping with a ``"name"`` and a ``"prefix"`` entry; the
            prefix may be a single string or a list of strings.

    Returns:
        ``(columns, rename_dict)``. ``columns`` lists ``model_name`` and
        ``model_type`` first (when present), followed by every column that
        starts with ``MAE_<prefix>`` or ``MSE_<prefix>`` in ``df`` column
        order. ``rename_dict`` maps each metric column to its display name.
        An empty ``df`` yields ``([], {})``.
    """
    if df.empty:
        return [], {}

    category_name = category["name"]
    prefixes = category["prefix"]
    if not isinstance(prefixes, list):
        prefixes = [prefixes]

    # Identifier columns always come first.
    columns = [c for c in ("model_name", "model_type") if c in df.columns]

    # Insertion-ordered: original column -> display name; the first matching
    # prefix wins and repeated column labels are kept only once.
    metric_names = {}
    for col in df.columns:
        for p in prefixes:
            if col.startswith((f"MAE_{p}", f"MSE_{p}")):
                metric_names.setdefault(col, _simplify_metric_name(col, p, category_name))
                break

    return columns + list(metric_names), dict(metric_names)


def _truncate_model_names(display_df, max_len=MODEL_NAME_MAX_LEN):
    """Cut ``model_name`` values longer than *max_len* characters and append '...'."""
    names = display_df['model_name']
    display_df['model_name'] = names.str.slice(0, max_len) + \
        (names.str.len() > max_len).map({True: '...', False: ''})
    return display_df


def _empty_outputs(message):
    """Return one empty DataFrame per metric category followed by *message*."""
    return [pd.DataFrame() for _ in get_metric_categories()] + [message]


def _build_category_outputs(df, log_prefix):
    """Build one display DataFrame per metric category from *df*.

    Each DataFrame holds the columns chosen by ``get_category_columns``, with
    metric columns renamed and long model names truncated. An empty *df*
    yields one empty DataFrame per category. *log_prefix* is the leading
    text of the per-category column log line.
    """
    categories = get_metric_categories()
    if df.empty:
        return [pd.DataFrame() for _ in categories]

    outputs = []
    for category in categories:
        category_cols, rename_dict = get_category_columns(df, category)
        logger.info(f"{log_prefix} {category['name']} 的列: {category_cols}")
        display_df = df[category_cols].copy().rename(columns=rename_dict)
        if 'model_name' in display_df.columns:
            display_df = _truncate_model_names(display_df)
        outputs.append(display_df)
    return outputs


def _with_visible_status(outputs):
    """Replace the trailing status string of *outputs* with a visibility-aware update.

    ``refresh_leaderboard`` and ``apply_filters`` end their output list with a
    plain message string, but the status Markdown in the UI is created with
    ``visible=False``; a plain string only changes its text, so the message
    would stay hidden. This shows the message when non-empty and hides the
    component when empty.
    """
    *frames, message = outputs
    return frames + [gr.update(value=message, visible=bool(message))]


def refresh_leaderboard():
    """Reload all results from ``RESULTS_PATH``.

    Returns:
        A list with one DataFrame per metric category followed by a status
        message string ("" on success with data). Errors are logged and
        reported through the message instead of being raised.
    """
    logger.info("开始刷新排行榜数据")
    try:
        if not os.path.exists(RESULTS_PATH):
            error_msg = f"❌ 结果目录不存在: {RESULTS_PATH}"
            logger.error(error_msg)
            return _empty_outputs(error_msg)

        if not os.path.isdir(RESULTS_PATH):
            error_msg = f"❌ {RESULTS_PATH} 不是一个目录"
            logger.error(error_msg)
            return _empty_outputs(error_msg)

        df = get_leaderboard_df(RESULTS_PATH)
        logger.info(f"成功加载数据,共 {len(df)} 条记录")

        data_outputs = _build_category_outputs(df, "类别")

        # Explain an empty result: folders present but unreadable vs. no models at all.
        if df.empty:
            model_folders = [f for f in os.listdir(RESULTS_PATH)
                             if os.path.isdir(os.path.join(RESULTS_PATH, f))]
            error_msg = "⚠️ 找到模型文件夹,但无法加载数据" if model_folders else f"⚠️ 未在 {RESULTS_PATH} 中找到模型"
            logger.warning(error_msg)
        else:
            error_msg = ""

        data_outputs.append(error_msg)
        logger.info("刷新排行榜完成")
        return data_outputs
    except Exception as e:
        error_msg = f"❌ 刷新失败: {str(e)}"
        logger.error(f"刷新数据失败: {str(e)}\n{traceback.format_exc()}")
        return _empty_outputs(error_msg)


def apply_filters(model_type, strategies, filter_mode):
    """Load results filtered by model type and strategies.

    Args:
        model_type: Selected model type (the dropdown defaults to "All").
        strategies: List of selected strategies.
        filter_mode: Value of the filter-mode radio. NOTE(review): the radio's
            choices are the localized labels ``lang.get("intersection")`` /
            ``lang.get("union")``, so this is a language-dependent string —
            confirm ``get_filtered_leaderboard`` accepts both languages.

    Returns:
        A list with one DataFrame per metric category followed by a status
        message string. Errors are logged and reported through the message.
    """
    logger.info(f"开始应用筛选: 模型类型={model_type}, 策略={strategies}, 模式={filter_mode}")
    try:
        filtered_df = get_filtered_leaderboard(
            RESULTS_PATH,
            model_type=model_type,
            strategies=strategies,
            filter_mode=filter_mode,
        )
        logger.info(f"筛选完成,得到 {len(filtered_df)} 条记录")

        data_outputs = _build_category_outputs(filtered_df, "筛选后类别")

        empty_msg = "⚠️ 没有找到符合筛选条件的模型。" if filtered_df.empty else ""
        data_outputs.append(empty_msg)
        logger.info("筛选应用完成")
        return data_outputs
    except Exception as e:
        error_msg = f"❌ 筛选失败: {str(e)}"
        logger.error(f"筛选失败: {str(e)}\n{traceback.format_exc()}")
        return _empty_outputs(error_msg)


def create_interface():
    """Build the Gradio leaderboard UI with a language toggle.

    Returns:
        The ``gr.Blocks`` app. Data is loaded on page load and on the refresh
        button; the apply button re-queries with the selected filters.
    """
    try:
        model_type_choices = get_model_type_choices()
    except Exception as e:
        model_type_choices = ["All"]
        logger.error(f"获取模型类型选项时出错: {str(e)}")

    try:
        strategy_choices = get_strategy_choices(RESULTS_PATH)
    except Exception as e:
        strategy_choices = []
        logger.error(f"获取策略选项时出错: {str(e)}")

    categories = get_metric_categories()

    with gr.Blocks(title="Aries 模型评估排行榜") as demo:
        # Meant to inject CSS that pins the model_name column (see the
        # "fixed-column-table" class below). NOTE(review): the HTML payload
        # here is empty, so no CSS is actually injected — confirm whether the
        # stylesheet was lost.
        gr.HTML(""" """)

        # Keep a reference so the title can be re-translated.
        title_markdown = gr.Markdown(f"# {lang.get('title')}")

        with gr.Row():
            lang_btn = gr.Button("Switch to English" if lang.current_lang == "zh" else "切换至中文")

        with gr.Tabs():
            # Leaderboard tab.
            with gr.Tab(lang.get("model_leaderboard")) as leaderboard_tab:
                filter_conditions = gr.Markdown(f"### {lang.get('filter_conditions')}")

                with gr.Row():
                    with gr.Column(scale=1):
                        model_type = gr.Dropdown(
                            choices=model_type_choices,
                            label=lang.get("model_type"),
                            value="All",
                        )
                        strategies = gr.CheckboxGroup(
                            choices=strategy_choices,
                            label=lang.get("strategies"),
                            value=[],
                        )
                        filter_mode = gr.Radio(
                            choices=[lang.get("intersection"), lang.get("union")],
                            label=lang.get("filter_mode"),
                            value=lang.get("intersection"),
                        )
                        with gr.Row():
                            refresh_btn = gr.Button(lang.get("refresh"))
                            apply_btn = gr.Button(lang.get("apply_filters"))

                    with gr.Column(scale=3):
                        # Status / error message; hidden until there is something to say.
                        empty_state = gr.Markdown(visible=False)
                        with gr.Tabs():
                            category_dataframes = []
                            category_tabs_list = []
                            for category in categories:
                                tab = gr.Tab(category["name"])
                                category_tabs_list.append(tab)
                                with tab:
                                    df_component = gr.Dataframe(
                                        interactive=False,
                                        wrap=True,
                                        label=category["description"],
                                        elem_classes="fixed-column-table",
                                    )
                                    category_dataframes.append(df_component)

            # About tab (re-translated on language switch).
            about_tab = gr.Tab(lang.get("about_tab"))
            with about_tab:
                about_markdown = gr.Markdown(
                    f"## {lang.get('about_title')}\n{lang.get('about_content')}"
                )

        data_outputs = category_dataframes + [empty_state]

        def on_refresh():
            """Reload everything and surface the status message."""
            return _with_visible_status(refresh_leaderboard())

        def on_apply(model_type_value, strategies_value, filter_mode_value):
            """Apply the selected filters and surface the status message."""
            return _with_visible_status(
                apply_filters(model_type_value, strategies_value, filter_mode_value)
            )

        apply_btn.click(
            fn=on_apply,
            inputs=[model_type, strategies, filter_mode],
            outputs=data_outputs,
        )
        refresh_btn.click(fn=on_refresh, outputs=data_outputs)
        demo.load(fn=on_refresh, outputs=data_outputs)

        def toggle_language():
            """Switch the UI language and re-translate every labelled component.

            The tables are reloaded through ``refresh_leaderboard``, so any
            previously applied filter is not reflected in the reloaded data,
            and the filter mode is reset to intersection.
            The returned list must match the ``outputs`` order of the
            ``lang_btn.click`` binding below.
            """
            new_lang = lang.switch_language()
            btn_text = "切换至中文" if new_lang == "en" else "Switch to English"

            new_filter_mode_choices = [lang.get("intersection"), lang.get("union")]
            updates = [
                btn_text,
                gr.update(value=f"# {lang.get('title')}"),
                gr.update(label=lang.get("model_leaderboard")),
                gr.update(value=f"### {lang.get('filter_conditions')}"),
                gr.update(label=lang.get("model_type")),
                gr.update(label=lang.get("strategies")),
                # Label, choices and value are all language-dependent.
                gr.update(
                    label=lang.get("filter_mode"),
                    choices=new_filter_mode_choices,
                    value=lang.get("intersection"),
                ),
                gr.update(value=lang.get("refresh")),
                gr.update(value=lang.get("apply_filters")),
            ]

            # Category names/descriptions are localized, so fetch them again.
            new_categories = get_metric_categories()
            *dataframes, status_msg = refresh_leaderboard()

            # Category tab titles.
            for _tab, new_cat in zip(category_tabs_list, new_categories):
                updates.append(gr.update(label=new_cat["name"]))

            # Category tables: new label plus freshly loaded data.
            for _df_comp, new_cat, frame in zip(category_dataframes, new_categories, dataframes):
                updates.append(gr.update(
                    label=new_cat["description"],
                    value=frame,
                    visible=True,
                ))

            updates.append(gr.update(value=status_msg, visible=bool(status_msg)))
            updates.append(gr.update(label=lang.get("about_tab")))
            updates.append(gr.update(value=f"## {lang.get('about_title')}\n{lang.get('about_content')}"))
            return updates

        lang_btn.click(
            fn=toggle_language,
            inputs=[],
            outputs=[
                lang_btn, title_markdown, leaderboard_tab, filter_conditions,
                model_type, strategies, filter_mode, refresh_btn, apply_btn,
            ] + category_tabs_list + category_dataframes + [empty_state, about_tab, about_markdown],
        )

    return demo


if __name__ == "__main__":
    print(f"正在启动Aries模型评估排行榜,结果目录: {RESULTS_PATH}")
    demo = create_interface()
    demo.launch()