Spaces:
Running
Running
| import os | |
| import json | |
| import pandas as pd | |
| from src.display.utils import ModelType | |
class ModelDetails:
    """Lightweight holder for a model's display name and symbol.

    Exposes a ``name``/``symbol`` attribute pair; the original docstring states
    it is meant to be compatible with the ModelType enum (not verified here).
    """

    def __init__(self, name, symbol):
        # Both values are stored verbatim; no validation or normalization.
        self.name, self.symbol = name, symbol
def read_model_config(config_path):
    """Read a model's JSON config and extract the fields used downstream.

    Args:
        config_path: Path to the JSON config file.

    Returns:
        A dict with keys:
          * ``'model'``: the config's ``model`` value, falling back to the name
            of the directory that contains the config file;
          * ``'model_type'``: the ``model_type`` string passed through
            ``ModelType.from_str(...).to_str()``;
          * ``'strategies'``: always a dict; a missing, ``null`` or non-object
            value becomes ``{}``.
        Returns ``None`` (after printing a message) if the file cannot be read
        or processed.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Treat an explicit JSON null like a missing key so ModelType.from_str
        # gets the same '' it already receives when the key is absent.
        model_type_str = config.get('model_type') or ''
        model_type = ModelType.from_str(model_type_str)

        # Callers iterate strategies with .items(); normalize anything that is
        # not a JSON object so a null/list value cannot crash them later.
        strategies = config.get('strategies')
        if not isinstance(strategies, dict):
            if strategies is not None:
                print(f"配置文件 {config_path} 中的 'strategies' 不是字典,已忽略")
            strategies = {}

        return {
            'model': config.get('model', os.path.basename(os.path.dirname(config_path))),
            'model_type': model_type.to_str(),
            'strategies': strategies,
        }
    except Exception as e:
        # Best-effort loader: any failure (I/O, invalid JSON, non-object JSON,
        # errors from ModelType) is reported and signalled with None.
        print(f"读取配置文件 {config_path} 出错: {str(e)}")
        return None
def read_model_performance(csv_path, metrics=('MAE', 'MSE')):
    """Read a model's performance CSV and flatten the selected metric rows.

    The CSV must contain a ``Metric`` column. Every column other than ``Model``
    and ``Metric`` is treated as a result column.

    Args:
        csv_path: Path to the CSV file.
        metrics: Values of the ``Metric`` column to extract. Defaults to
            ``('MAE', 'MSE')``, which was previously hard-coded.

    Returns:
        A dict mapping ``f"{metric}_{column}"`` to the cell value from the
        first row whose ``Metric`` equals ``metric``. Metrics absent from the
        file are skipped, so the dict may be empty. Returns ``None`` (after
        printing a message) if the file cannot be read or lacks ``Metric``.
    """
    try:
        df = pd.read_csv(csv_path)

        if 'Metric' not in df.columns:
            print(f"CSV文件 {csv_path} 缺少'Metric'列")
            return None

        metric_columns = [col for col in df.columns if col not in ('Model', 'Metric')]

        performance_data = {}
        for metric in metrics:
            matching_rows = df[df['Metric'] == metric]
            if matching_rows.empty:
                continue
            # If a metric appears on several rows, only the first is used.
            first_row = matching_rows.iloc[0]
            for col in metric_columns:
                performance_data[f"{metric}_{col}"] = first_row[col]
        return performance_data
    except Exception as e:
        # Best-effort loader: read/parse failures are reported and yield None.
        print(f"读取性能文件 {csv_path} 出错: {str(e)}")
        return None
def get_all_strategies(results_path):
    """Collect every strategy name declared by any model under ``results_path``.

    Each immediate sub-directory of ``results_path`` is a model folder. If that
    folder's ``config.json`` exists, the keys of its ``strategies`` object are
    collected.

    Args:
        results_path: Directory containing one folder per model.

    Returns:
        Sorted list of unique strategy names. Folders without a config are
        skipped silently. Configs that cannot be opened, are not valid JSON, or
        have a non-object top level or ``strategies`` value are skipped with a
        printed message. A missing or ``null`` ``strategies`` contributes
        nothing.

    Raises:
        OSError: if ``results_path`` itself cannot be listed.
    """
    strategies = set()
    for model_folder in os.listdir(results_path):
        folder_path = os.path.join(results_path, model_folder)
        if not os.path.isdir(folder_path):
            continue

        config_path = os.path.join(folder_path, 'config.json')
        if not os.path.exists(config_path):
            continue

        # open() lives inside the try so an unreadable config (permissions,
        # config.json being a directory, ...) skips this model instead of
        # aborting the whole scan. ValueError covers JSON and Unicode decode errors.
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
        except (OSError, ValueError) as e:
            print(f"读取 {model_folder} 的配置文件时出错: {str(e)}")
            continue

        if not isinstance(config, dict):
            print(f"读取 {model_folder} 的配置文件时出错: 配置不是JSON对象")
            continue

        model_strategies = config.get('strategies') or {}
        if not isinstance(model_strategies, dict):
            print(f"读取 {model_folder} 的配置文件时出错: 'strategies' 不是JSON对象")
            continue
        strategies.update(model_strategies.keys())
    return sorted(strategies)
def load_model_data(model_folder, results_path):
    """Load one model's performance metrics and config and merge them into a flat row.

    Args:
        model_folder: Name of the model's sub-directory inside ``results_path``.
        results_path: Directory containing one folder per model.

    Returns:
        A dict containing ``'model_name'`` and ``'model_type'`` from the
        config, every key produced by ``read_model_performance``, and one
        ``'strategy_<name>'`` key per configured strategy. Returns ``None``
        when the folder is not a directory, no CSV file or ``config.json`` is
        found, or either reader fails or extracts nothing. Messages are
        printed here or by the readers.
    """
    folder_path = os.path.join(results_path, model_folder)
    if not os.path.isdir(folder_path):
        print(f"{folder_path} 不是一个目录,跳过")
        return None

    # Sort so the chosen CSV does not depend on os.listdir's arbitrary,
    # filesystem-specific order. Skip directories that merely end in ".csv".
    csv_files = sorted(
        name for name in os.listdir(folder_path)
        if name.endswith('.csv') and os.path.isfile(os.path.join(folder_path, name))
    )
    if not csv_files:
        print(f"模型文件夹 {model_folder} 中未找到CSV文件")
        return None

    # Only the alphabetically first CSV is used; any others are ignored.
    csv_path = os.path.join(folder_path, csv_files[0])
    performance_data = read_model_performance(csv_path)
    if not performance_data:
        # Covers both a read failure (None) and an empty dict (no expected metrics).
        return None

    config_path = os.path.join(folder_path, 'config.json')
    if not os.path.exists(config_path):
        print(f"模型文件夹 {model_folder} 中未找到config.json")
        return None
    config_data = read_model_config(config_path)
    if not config_data:
        return None

    model_data = {
        'model_name': config_data['model'],
        'model_type': config_data['model_type'],
    }
    model_data.update(performance_data)

    # Strategy keys get a "strategy_" prefix to distinguish them from metric keys.
    # Ignore a null/non-object value instead of crashing on .items().
    strategies = config_data.get('strategies')
    if isinstance(strategies, dict):
        for strategy, value in strategies.items():
            model_data[f'strategy_{strategy}'] = value
    return model_data