Coverage for src\time_series_analyzer\parsers.py: 51%

166 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-23 11:57 +0800

1""" 

2模型参数解析模块 

3 

4支持多种输入方式:命令行参数、配置文件(JSON/YAML)、交互式输入等。 

5""" 

6 

7import json 

8import yaml 

9from typing import Dict, Any, Optional, Union, List 

10from pathlib import Path 

11import re 

12 

13from .models import ARIMAModel, SeasonalARIMAModel 

14 

15 

16class ModelParser: 

17 """模型参数解析器""" 

18 

19 @staticmethod 

20 def parse_arima_string(arima_str: str) -> Dict[str, Any]: 

21 """ 

22 解析ARIMA字符串格式 

23  

24 支持格式: 

25 - "ARIMA(2,1,1)" 

26 - "ARIMA(2,1,1,0.5,-0.3,0.2)" # 包含参数 

27 - "SARIMA(2,1,1)(1,1,1,12)" 

28  

29 Args: 

30 arima_str: ARIMA模型字符串 

31  

32 Returns: 

33 解析后的参数字典 

34 """ 

35 arima_str = arima_str.strip().upper() 

36 

37 # 匹配SARIMA格式 

38 sarima_pattern = r'SARIMA\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)' 

39 sarima_match = re.match(sarima_pattern, arima_str) 

40 

41 if sarima_match: 

42 p, d, q, P, D, Q, m = map(int, sarima_match.groups()) 

43 return { 

44 "model_type": "SARIMA", 

45 "p": p, "d": d, "q": q, 

46 "P": P, "D": D, "Q": Q, "m": m 

47 } 

48 

49 # 匹配ARIMA格式 

50 arima_pattern = r'ARIMA\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(.*))?\s*\)' 

51 arima_match = re.match(arima_pattern, arima_str) 

52 

53 if arima_match: 

54 p, d, q = map(int, arima_match.groups()[:3]) 

55 params_str = arima_match.group(4) 

56 

57 result = { 

58 "model_type": "ARIMA", 

59 "p": p, "d": d, "q": q 

60 } 

61 

62 # 解析参数 

63 if params_str: 

64 try: 

65 params = [float(x.strip()) for x in params_str.split(',')] 

66 if len(params) >= p: 

67 result["ar_params"] = params[:p] 

68 if len(params) >= p + q: 

69 result["ma_params"] = params[p:p+q] 

70 if len(params) > p + q: 

71 result["constant"] = params[p+q] 

72 except ValueError: 

73 pass # 忽略参数解析错误 

74 

75 return result 

76 

77 raise ValueError(f"无法解析ARIMA字符串: {arima_str}") 

78 

79 @staticmethod 

80 def parse_json_file(file_path: Union[str, Path]) -> Dict[str, Any]: 

81 """ 

82 从JSON文件解析模型参数 

83  

84 Args: 

85 file_path: JSON文件路径 

86  

87 Returns: 

88 解析后的参数字典 

89 """ 

90 file_path = Path(file_path) 

91 

92 if not file_path.exists(): 

93 raise FileNotFoundError(f"文件不存在: {file_path}") 

94 

95 with open(file_path, 'r', encoding='utf-8') as f: 

96 data = json.load(f) 

97 

98 return ModelParser._validate_config_data(data) 

99 

100 @staticmethod 

101 def parse_yaml_file(file_path: Union[str, Path]) -> Dict[str, Any]: 

102 """ 

103 从YAML文件解析模型参数 

104  

105 Args: 

106 file_path: YAML文件路径 

107  

108 Returns: 

109 解析后的参数字典 

110 """ 

111 file_path = Path(file_path) 

112 

113 if not file_path.exists(): 

114 raise FileNotFoundError(f"文件不存在: {file_path}") 

115 

116 with open(file_path, 'r', encoding='utf-8') as f: 

117 data = yaml.safe_load(f) 

118 

119 return ModelParser._validate_config_data(data) 

120 

121 @staticmethod 

122 def _validate_config_data(data: Dict[str, Any]) -> Dict[str, Any]: 

123 """验证配置数据格式""" 

124 required_fields = ["model_type", "p", "d", "q"] 

125 

126 for field in required_fields: 

127 if field not in data: 

128 raise ValueError(f"配置文件缺少必需字段: {field}") 

129 

130 model_type = data["model_type"].upper() 

131 if model_type not in ["ARIMA", "SARIMA"]: 

132 raise ValueError(f"不支持的模型类型: {model_type}") 

133 

134 if model_type == "SARIMA": 

135 seasonal_fields = ["P", "D", "Q", "m"] 

136 for field in seasonal_fields: 

137 if field not in data: 

138 raise ValueError(f"SARIMA模型缺少必需字段: {field}") 

139 

140 return data 

141 

142 @staticmethod 

143 def interactive_input() -> Dict[str, Any]: 

144 """ 

145 交互式输入模型参数 

146  

147 Returns: 

148 解析后的参数字典 

149 """ 

150 print("=== 时序模型参数输入 ===") 

151 

152 # 选择模型类型 

153 while True: 

154 model_type = input("请选择模型类型 (1: ARIMA, 2: SARIMA): ").strip() 

155 if model_type in ["1", "ARIMA", "arima"]: 

156 model_type = "ARIMA" 

157 break 

158 elif model_type in ["2", "SARIMA", "sarima"]: 

159 model_type = "SARIMA" 

160 break 

161 else: 

162 print("无效输入,请重新选择") 

163 

164 # 输入基本参数 

165 try: 

166 p = int(input("请输入自回归阶数 p: ")) 

167 d = int(input("请输入差分阶数 d: ")) 

168 q = int(input("请输入移动平均阶数 q: ")) 

169 except ValueError: 

170 raise ValueError("参数必须是非负整数") 

171 

172 result = { 

173 "model_type": model_type, 

174 "p": p, "d": d, "q": q 

175 } 

176 

177 # 输入季节性参数 

178 if model_type == "SARIMA": 

179 try: 

180 P = int(input("请输入季节性自回归阶数 P: ")) 

181 D = int(input("请输入季节性差分阶数 D: ")) 

182 Q = int(input("请输入季节性移动平均阶数 Q: ")) 

183 m = int(input("请输入季节性周期 m: ")) 

184 except ValueError: 

185 raise ValueError("季节性参数必须是非负整数") 

186 

187 result.update({"P": P, "D": D, "Q": Q, "m": m}) 

188 

189 # 询问是否输入具体参数值 

190 if input("是否输入具体参数值? (y/n): ").lower().startswith('y'): 

191 if p > 0: 

192 ar_params = [] 

193 for i in range(p): 

194 try: 

195 param = float(input(f"请输入AR参数 φ_{i+1}: ")) 

196 ar_params.append(param) 

197 except ValueError: 

198 print(f"使用符号参数 phi_{i+1}") 

199 ar_params.append(f"phi_{i+1}") 

200 result["ar_params"] = ar_params 

201 

202 if q > 0: 

203 ma_params = [] 

204 for i in range(q): 

205 try: 

206 param = float(input(f"请输入MA参数 θ_{i+1}: ")) 

207 ma_params.append(param) 

208 except ValueError: 

209 print(f"使用符号参数 theta_{i+1}") 

210 ma_params.append(f"theta_{i+1}") 

211 result["ma_params"] = ma_params 

212 

213 if model_type == "SARIMA": 

214 if result.get("P", 0) > 0: 

215 seasonal_ar_params = [] 

216 for i in range(result["P"]): 

217 try: 

218 param = float(input(f"请输入季节性AR参数 Φ_{i+1}: ")) 

219 seasonal_ar_params.append(param) 

220 except ValueError: 

221 print(f"使用符号参数 Phi_{i+1}") 

222 seasonal_ar_params.append(f"Phi_{i+1}") 

223 result["seasonal_ar_params"] = seasonal_ar_params 

224 

225 if result.get("Q", 0) > 0: 

226 seasonal_ma_params = [] 

227 for i in range(result["Q"]): 

228 try: 

229 param = float(input(f"请输入季节性MA参数 Θ_{i+1}: ")) 

230 seasonal_ma_params.append(param) 

231 except ValueError: 

232 print(f"使用符号参数 Theta_{i+1}") 

233 seasonal_ma_params.append(f"Theta_{i+1}") 

234 result["seasonal_ma_params"] = seasonal_ma_params 

235 

236 # 常数项 

237 try: 

238 constant = float(input("请输入常数项 (默认0): ") or "0") 

239 result["constant"] = constant 

240 except ValueError: 

241 result["constant"] = 0 

242 

243 return result 

244 

245 @staticmethod 

246 def create_model_from_dict(data: Dict[str, Any]) -> Union[ARIMAModel, SeasonalARIMAModel]: 

247 """ 

248 从字典创建模型对象 

249  

250 Args: 

251 data: 参数字典 

252  

253 Returns: 

254 模型对象 

255 """ 

256 model_type = data.get("model_type", "ARIMA").upper() 

257 

258 if model_type == "SARIMA": 

259 # 移除model_type字段 

260 sarima_data = {k: v for k, v in data.items() if k != "model_type"} 

261 return SeasonalARIMAModel(**sarima_data) 

262 else: 

263 # 移除SARIMA特有的字段和model_type字段 

264 arima_data = {k: v for k, v in data.items() 

265 if k not in ["model_type", "P", "D", "Q", "m", "seasonal_ar_params", "seasonal_ma_params"]} 

266 return ARIMAModel(**arima_data) 

267 

268 @staticmethod 

269 def parse_from_string(input_str: str) -> Union[ARIMAModel, SeasonalARIMAModel]: 

270 """ 

271 从字符串解析并创建模型 

272  

273 Args: 

274 input_str: 输入字符串 

275  

276 Returns: 

277 模型对象 

278 """ 

279 data = ModelParser.parse_arima_string(input_str) 

280 return ModelParser.create_model_from_dict(data) 

281 

282 @staticmethod 

283 def parse_from_file(file_path: Union[str, Path]) -> Union[ARIMAModel, SeasonalARIMAModel]: 

284 """ 

285 从文件解析并创建模型 

286  

287 Args: 

288 file_path: 文件路径 

289  

290 Returns: 

291 模型对象 

292 """ 

293 file_path = Path(file_path) 

294 

295 if file_path.suffix.lower() == '.json': 

296 data = ModelParser.parse_json_file(file_path) 

297 elif file_path.suffix.lower() in ['.yaml', '.yml']: 

298 data = ModelParser.parse_yaml_file(file_path) 

299 else: 

300 raise ValueError(f"不支持的文件格式: {file_path.suffix}") 

301 

302 return ModelParser.create_model_from_dict(data) 

303 

304 @staticmethod 

305 def parse_interactive() -> Union[ARIMAModel, SeasonalARIMAModel]: 

306 """ 

307 交互式解析并创建模型 

308  

309 Returns: 

310 模型对象 

311 """ 

312 data = ModelParser.interactive_input() 

313 return ModelParser.create_model_from_dict(data)