Coverage for src\time_series_analyzer\parsers.py: 51%
166 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-23 11:57 +0800
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-23 11:57 +0800
1"""
2模型参数解析模块
4支持多种输入方式:命令行参数、配置文件(JSON/YAML)、交互式输入等。
5"""
7import json
8import yaml
9from typing import Dict, Any, Optional, Union, List
10from pathlib import Path
11import re
13from .models import ARIMAModel, SeasonalARIMAModel
class ModelParser:
    """Parse ARIMA / SARIMA model specifications from several input sources.

    Supported sources are compact specification strings such as
    ``"ARIMA(2,1,1)"``, JSON / YAML configuration files, and interactive
    prompts on standard input.  Every ``parse_*`` / ``interactive_input``
    helper produces a plain parameter dictionary; ``create_model_from_dict``
    turns such a dictionary into a model object.
    """

    # Compiled once at class creation instead of on every call.
    _SARIMA_PATTERN = re.compile(
        r'SARIMA\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)'
        r'\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)'
    )
    _ARIMA_PATTERN = re.compile(
        r'ARIMA\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(.*))?\s*\)'
    )

    @staticmethod
    def parse_arima_string(arima_str: str) -> Dict[str, Any]:
        """Parse a compact ARIMA / SARIMA specification string.

        Accepted forms (case-insensitive, surrounding whitespace ignored):

        - ``"ARIMA(2,1,1)"``
        - ``"ARIMA(2,1,1,0.5,-0.3,0.2)"`` -- orders followed by coefficients
        - ``"SARIMA(2,1,1)(1,1,1,12)"``

        In the coefficient form the first ``p`` numbers become
        ``ar_params``, the next ``q`` become ``ma_params`` and one further
        number, if present, becomes ``constant``.  Coefficients that cannot
        be converted to ``float`` are ignored on purpose (best effort): the
        orders are still returned.

        Args:
            arima_str: The specification string.

        Returns:
            A dictionary with ``model_type`` and the integer orders, plus
            ``ar_params`` / ``ma_params`` / ``constant`` when enough
            coefficients were supplied.

        Raises:
            ValueError: If the whole string does not match one of the
                supported forms.
        """
        spec = arima_str.strip().upper()

        # fullmatch (not match): text after the closing parenthesis must make
        # the parse fail instead of being dropped silently, e.g. a seasonal
        # part written after "ARIMA(...)" without the leading "S".
        sarima_match = ModelParser._SARIMA_PATTERN.fullmatch(spec)
        if sarima_match:
            result: Dict[str, Any] = {"model_type": "SARIMA"}
            order_names = ("p", "d", "q", "P", "D", "Q", "m")
            result.update(zip(order_names, map(int, sarima_match.groups())))
            return result

        arima_match = ModelParser._ARIMA_PATTERN.fullmatch(spec)
        if arima_match is None:
            raise ValueError(f"无法解析ARIMA字符串: {spec}")

        p, d, q = (int(group) for group in arima_match.group(1, 2, 3))
        result = {"model_type": "ARIMA", "p": p, "d": d, "q": q}

        params_str = arima_match.group(4)
        if not params_str:
            return result

        try:
            params = [float(token.strip()) for token in params_str.split(',')]
        except ValueError:
            # Deliberate best effort: malformed coefficients are dropped and
            # only the orders are returned.
            return result

        if len(params) >= p:
            result["ar_params"] = params[:p]
        if len(params) >= p + q:
            result["ma_params"] = params[p:p + q]
        if len(params) > p + q:
            result["constant"] = params[p + q]
        return result

    @staticmethod
    def parse_json_file(file_path: Union[str, Path]) -> Dict[str, Any]:
        """Load and validate model parameters from a JSON file.

        Args:
            file_path: Path of the JSON file.

        Returns:
            The validated parameter dictionary.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            ValueError: If the content fails ``_validate_config_data``.
        """
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"文件不存在: {file_path}")

        with open(file_path, 'r', encoding='utf-8') as handle:
            data = json.load(handle)

        return ModelParser._validate_config_data(data)

    @staticmethod
    def parse_yaml_file(file_path: Union[str, Path]) -> Dict[str, Any]:
        """Load and validate model parameters from a YAML file.

        Args:
            file_path: Path of the YAML file.

        Returns:
            The validated parameter dictionary.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            ValueError: If the content fails ``_validate_config_data``
                (this includes an empty document, which loads as ``None``).
        """
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"文件不存在: {file_path}")

        with open(file_path, 'r', encoding='utf-8') as handle:
            data = yaml.safe_load(handle)

        return ModelParser._validate_config_data(data)

    @staticmethod
    def _validate_config_data(data: Dict[str, Any]) -> Dict[str, Any]:
        """Check that loaded configuration data describes a supported model.

        Args:
            data: The object loaded from a configuration file.

        Returns:
            ``data`` itself, unchanged, when it is valid.

        Raises:
            ValueError: If ``data`` is not a mapping, a required field is
                missing, ``model_type`` is not a string, or ``model_type``
                is neither ``ARIMA`` nor ``SARIMA`` (case-insensitive).
        """
        # An empty YAML document loads as None and a JSON file may hold a
        # list or scalar; report those as invalid configuration rather than
        # letting the membership tests below fail with a TypeError.
        if not isinstance(data, dict):
            raise ValueError(
                f"配置数据必须是键值映射, 实际类型为: {type(data).__name__}"
            )

        for field in ("model_type", "p", "d", "q"):
            if field not in data:
                raise ValueError(f"配置文件缺少必需字段: {field}")

        raw_type = data["model_type"]
        if not isinstance(raw_type, str):
            raise ValueError(f"不支持的模型类型: {raw_type}")

        model_type = raw_type.upper()
        if model_type not in ("ARIMA", "SARIMA"):
            raise ValueError(f"不支持的模型类型: {model_type}")

        if model_type == "SARIMA":
            for field in ("P", "D", "Q", "m"):
                if field not in data:
                    raise ValueError(f"SARIMA模型缺少必需字段: {field}")

        return data

    @staticmethod
    def _read_orders(prompts: List[str], error_message: str) -> List[int]:
        """Prompt once per entry in ``prompts`` for a non-negative integer.

        Raises:
            ValueError: With ``error_message`` as soon as an answer is not
                an integer or is negative.
        """
        orders = []
        for prompt in prompts:
            try:
                value = int(input(prompt))
            except ValueError as err:
                raise ValueError(error_message) from err
            if value < 0:
                # The messages promise non-negative integers; enforce it.
                raise ValueError(error_message)
            orders.append(value)
        return orders

    @staticmethod
    def _read_coefficients(count: int, label: str,
                           fallback_prefix: str) -> List[Union[float, str]]:
        """Prompt for ``count`` coefficients, numbered from 1.

        An answer that is not a valid float is replaced by the symbolic
        placeholder ``f"{fallback_prefix}_{index}"``.
        """
        coefficients: List[Union[float, str]] = []
        for index in range(1, count + 1):
            try:
                coefficients.append(float(input(f"请输入{label}_{index}: ")))
            except ValueError:
                symbol = f"{fallback_prefix}_{index}"
                print(f"使用符号参数 {symbol}")
                coefficients.append(symbol)
        return coefficients

    @staticmethod
    def interactive_input() -> Dict[str, Any]:
        """Collect model parameters interactively from standard input.

        Asks for the model type (repeating until the answer is valid), the
        orders, and optionally concrete coefficient values and a constant.

        Returns:
            The collected parameter dictionary.

        Raises:
            ValueError: If an order is not a non-negative integer.
        """
        print("=== 时序模型参数输入 ===")

        # Keep asking until a recognised model type is entered.
        choices = {
            "1": "ARIMA", "ARIMA": "ARIMA", "arima": "ARIMA",
            "2": "SARIMA", "SARIMA": "SARIMA", "sarima": "SARIMA",
        }
        while True:
            answer = input("请选择模型类型 (1: ARIMA, 2: SARIMA): ").strip()
            if answer in choices:
                model_type = choices[answer]
                break
            print("无效输入,请重新选择")

        p, d, q = ModelParser._read_orders(
            [
                "请输入自回归阶数 p: ",
                "请输入差分阶数 d: ",
                "请输入移动平均阶数 q: ",
            ],
            "参数必须是非负整数",
        )
        result: Dict[str, Any] = {
            "model_type": model_type,
            "p": p, "d": d, "q": q,
        }

        is_seasonal = model_type == "SARIMA"
        if is_seasonal:
            P, D, Q, m = ModelParser._read_orders(
                [
                    "请输入季节性自回归阶数 P: ",
                    "请输入季节性差分阶数 D: ",
                    "请输入季节性移动平均阶数 Q: ",
                    "请输入季节性周期 m: ",
                ],
                "季节性参数必须是非负整数",
            )
            result.update({"P": P, "D": D, "Q": Q, "m": m})

        # Concrete coefficient values are optional.
        if not input("是否输入具体参数值? (y/n): ").lower().startswith('y'):
            return result

        if p > 0:
            result["ar_params"] = ModelParser._read_coefficients(
                p, "AR参数 φ", "phi")
        if q > 0:
            result["ma_params"] = ModelParser._read_coefficients(
                q, "MA参数 θ", "theta")

        if is_seasonal:
            if result["P"] > 0:
                result["seasonal_ar_params"] = ModelParser._read_coefficients(
                    result["P"], "季节性AR参数 Φ", "Phi")
            if result["Q"] > 0:
                result["seasonal_ma_params"] = ModelParser._read_coefficients(
                    result["Q"], "季节性MA参数 Θ", "Theta")

        # Constant term: empty input means 0, unparsable input also gives 0.
        try:
            result["constant"] = float(input("请输入常数项 (默认0): ") or "0")
        except ValueError:
            result["constant"] = 0

        return result

    @staticmethod
    def create_model_from_dict(
            data: Dict[str, Any]) -> "Union[ARIMAModel, SeasonalARIMAModel]":
        """Build a model object from a parameter dictionary.

        ``model_type`` (default ``"ARIMA"``, case-insensitive) selects the
        class.  For SARIMA every other entry is passed as a keyword
        argument; for any other type the seasonal entries are dropped as
        well before calling ``ARIMAModel``.

        Args:
            data: The parameter dictionary.

        Returns:
            The constructed model object.
        """
        model_type = data.get("model_type", "ARIMA").upper()

        if model_type == "SARIMA":
            excluded = {"model_type"}
            model_cls = SeasonalARIMAModel
        else:
            # The plain ARIMA constructor is not given seasonal settings.
            excluded = {
                "model_type", "P", "D", "Q", "m",
                "seasonal_ar_params", "seasonal_ma_params",
            }
            model_cls = ARIMAModel

        kwargs = {key: value for key, value in data.items()
                  if key not in excluded}
        return model_cls(**kwargs)

    @staticmethod
    def parse_from_string(
            input_str: str) -> "Union[ARIMAModel, SeasonalARIMAModel]":
        """Parse a specification string and build the model it describes.

        Args:
            input_str: The specification string, see ``parse_arima_string``.

        Returns:
            The constructed model object.
        """
        return ModelParser.create_model_from_dict(
            ModelParser.parse_arima_string(input_str))

    @staticmethod
    def parse_from_file(
            file_path: Union[str, Path]
    ) -> "Union[ARIMAModel, SeasonalARIMAModel]":
        """Load a configuration file and build the model it describes.

        The format is chosen by file extension (case-insensitive):
        ``.json`` for JSON, ``.yaml`` / ``.yml`` for YAML.

        Args:
            file_path: Path of the configuration file.

        Returns:
            The constructed model object.

        Raises:
            ValueError: If the extension is not supported or the content is
                invalid.
            FileNotFoundError: If the file does not exist.
        """
        file_path = Path(file_path)
        suffix = file_path.suffix.lower()

        if suffix == '.json':
            data = ModelParser.parse_json_file(file_path)
        elif suffix in ('.yaml', '.yml'):
            data = ModelParser.parse_yaml_file(file_path)
        else:
            raise ValueError(f"不支持的文件格式: {file_path.suffix}")

        return ModelParser.create_model_from_dict(data)

    @staticmethod
    def parse_interactive() -> "Union[ARIMAModel, SeasonalARIMAModel]":
        """Collect parameters interactively and build the resulting model.

        Returns:
            The constructed model object.
        """
        return ModelParser.create_model_from_dict(
            ModelParser.interactive_input())