Coverage for src\time_series_analyzer\transfer_function.py: 80%

147 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-23 11:57 +0800

1""" 

2传递函数推导引擎 

3 

4基于符号计算自动推导ARIMA模型的传递函数表达式。 

5将时间序列模型转换为关于滞后算子B的多项式比值形式。 

6""" 

7 

8from typing import Dict, Any, Optional, Tuple 

9from sympy import symbols, Poly, simplify, factor, expand, Symbol, Rational 

10from sympy.polys.polyfuncs import interpolate 

11import sympy as sp 

12 

13from .models import ARIMAModel, SeasonalARIMAModel 

14 

15 

16class TransferFunction: 

17 """ 

18 传递函数表示类 

19  

20 表示形式: H(B) = numerator(B) / denominator(B) 

21 其中B是滞后算子 

22 """ 

23 

24 def __init__(self, numerator: Poly, denominator: Poly, lag_operator: Symbol = None): 

25 """ 

26 初始化传递函数 

27  

28 Args: 

29 numerator: 分子多项式 

30 denominator: 分母多项式  

31 lag_operator: 滞后算子符号 

32 """ 

33 if lag_operator is None: 

34 lag_operator = symbols('B') 

35 

36 self.lag_operator = lag_operator 

37 self.numerator = numerator 

38 self.denominator = denominator 

39 

40 # 简化传递函数 

41 self._simplify() 

42 

43 def _simplify(self): 

44 """简化传递函数,约去公因子""" 

45 try: 

46 # 计算最大公约数 

47 gcd_poly = sp.gcd(self.numerator.as_expr(), self.denominator.as_expr()) 

48 

49 if gcd_poly != 1: 

50 # 约去公因子 

51 self.numerator = Poly( 

52 simplify(self.numerator.as_expr() / gcd_poly), 

53 self.lag_operator 

54 ) 

55 self.denominator = Poly( 

56 simplify(self.denominator.as_expr() / gcd_poly), 

57 self.lag_operator 

58 ) 

59 except Exception: # 如果简化失败,保持原样 

60 pass 

61 

62 def evaluate_at_frequency(self, frequency: complex) -> complex: 

63 """ 

64 在特定频率处计算传递函数值 

65  

66 Args: 

67 frequency: 频率值 (通常是 e^{-iω}) 

68  

69 Returns: 

70 传递函数在该频率处的值 

71 """ 

72 num_expr = self.numerator.as_expr().subs(self.lag_operator, frequency) 

73 den_expr = self.denominator.as_expr().subs(self.lag_operator, frequency) 

74 

75 # 确保结果为数值 

76 num_val = complex(num_expr.evalf()) 

77 den_val = complex(den_expr.evalf()) 

78 

79 if abs(den_val) < 1e-12: 

80 raise ValueError(f"分母在频率{frequency}处为零") 

81 

82 return num_val / den_val 

83 

84 def get_poles(self) -> list: 

85 """获取传递函数的极点(分母的根)""" 

86 try: 

87 roots = sp.solve(self.denominator.as_expr(), self.lag_operator) 

88 return [complex(root.evalf()) for root in roots if root.is_finite] 

89 except Exception: 

90 return [] 

91 

92 def get_zeros(self) -> list: 

93 """获取传递函数的零点(分子的根)""" 

94 try: 

95 roots = sp.solve(self.numerator.as_expr(), self.lag_operator) 

96 return [complex(root.evalf()) for root in roots if root.is_finite] 

97 except Exception: 

98 return [] 

99 

100 def is_stable(self) -> bool: 

101 """ 

102 检查系统稳定性 

103 对于离散时间系统,所有极点的模长应小于1 

104 """ 

105 poles = self.get_poles() 

106 return all(abs(pole) < 1 for pole in poles) 

107 

108 def __str__(self) -> str: 

109 """字符串表示""" 

110 return f"H({self.lag_operator}) = ({self.numerator.as_expr()}) / ({self.denominator.as_expr()})" 

111 

112 def __repr__(self) -> str: 

113 return self.__str__() 

114 

115 

116class TransferFunctionDeriver: 

117 """ 

118 传递函数推导器 

119  

120 将ARIMA模型自动转换为传递函数形式 

121 """ 

122 

123 def __init__(self, lag_operator: Symbol = None): 

124 """ 

125 初始化推导器 

126  

127 Args: 

128 lag_operator: 滞后算子符号,默认为B 

129 """ 

130 if lag_operator is None: 

131 lag_operator = symbols('B') 

132 self.lag_operator = lag_operator 

133 

134 def derive_arima_transfer_function(self, model: ARIMAModel) -> TransferFunction: 

135 """ 

136 推导ARIMA模型的传递函数 

137  

138 ARIMA(p,d,q)模型的一般形式: 

139 φ(B)(1-B)^d X_t = θ(B)ε_t 

140  

141 传递函数: H(B) = θ(B) / [φ(B)(1-B)^d] 

142  

143 Args: 

144 model: ARIMA模型 

145  

146 Returns: 

147 传递函数对象 

148 """ 

149 # 获取各个多项式 

150 ar_poly = model.get_ar_polynomial(self.lag_operator) 

151 ma_poly = model.get_ma_polynomial(self.lag_operator) 

152 diff_poly = model.get_difference_polynomial(self.lag_operator) 

153 

154 # 分子:移动平均多项式 θ(B) 

155 numerator = ma_poly 

156 

157 # 分母:自回归多项式 × 差分多项式 φ(B)(1-B)^d 

158 denominator = ar_poly * diff_poly 

159 

160 return TransferFunction(numerator, denominator, self.lag_operator) 

161 

162 def derive_sarima_transfer_function(self, model: SeasonalARIMAModel) -> TransferFunction: 

163 """ 

164 推导季节性ARIMA模型的传递函数 

165  

166 SARIMA(p,d,q)(P,D,Q,m)模型的一般形式: 

167 φ(B)Φ(B^m)(1-B)^d(1-B^m)^D X_t = θ(B)Θ(B^m)ε_t 

168  

169 传递函数: H(B) = θ(B)Θ(B^m) / [φ(B)Φ(B^m)(1-B)^d(1-B^m)^D] 

170  

171 Args: 

172 model: 季节性ARIMA模型 

173  

174 Returns: 

175 传递函数对象 

176 """ 

177 # 获取非季节性多项式 

178 ar_poly = model.get_ar_polynomial(self.lag_operator) 

179 ma_poly = model.get_ma_polynomial(self.lag_operator) 

180 diff_poly = model.get_difference_polynomial(self.lag_operator) 

181 

182 # 获取季节性多项式 

183 seasonal_ar_poly = model.get_seasonal_ar_polynomial(self.lag_operator) 

184 seasonal_ma_poly = model.get_seasonal_ma_polynomial(self.lag_operator) 

185 seasonal_diff_poly = model.get_seasonal_difference_polynomial(self.lag_operator) 

186 

187 # 分子:θ(B)Θ(B^m) 

188 numerator = ma_poly * seasonal_ma_poly 

189 

190 # 分母:φ(B)Φ(B^m)(1-B)^d(1-B^m)^D 

191 denominator = ar_poly * seasonal_ar_poly * diff_poly * seasonal_diff_poly 

192 

193 return TransferFunction(numerator, denominator, self.lag_operator) 

194 

195 def derive_transfer_function(self, model) -> TransferFunction: 

196 """ 

197 通用传递函数推导方法 

198  

199 Args: 

200 model: ARIMA或SARIMA模型 

201  

202 Returns: 

203 传递函数对象 

204 """ 

205 if isinstance(model, SeasonalARIMAModel): 

206 return self.derive_sarima_transfer_function(model) 

207 elif isinstance(model, ARIMAModel): 

208 return self.derive_arima_transfer_function(model) 

209 else: 

210 raise ValueError(f"不支持的模型类型: {type(model)}") 

211 

212 def derive_impulse_response(self, model, max_lag: int = 20) -> Dict[int, Any]: 

213 """ 

214 推导脉冲响应函数 

215  

216 通过传递函数的幂级数展开获得脉冲响应系数 

217  

218 Args: 

219 model: 时间序列模型 

220 max_lag: 最大滞后阶数 

221  

222 Returns: 

223 脉冲响应系数字典 {lag: coefficient} 

224 """ 

225 transfer_func = self.derive_transfer_function(model) 

226 

227 # 计算幂级数展开 

228 try: 

229 # H(B) = num(B) / den(B) = Σ h_j B^j 

230 num_expr = transfer_func.numerator.as_expr() 

231 den_expr = transfer_func.denominator.as_expr() 

232 

233 # 使用sympy的series展开 

234 series = sp.series(num_expr / den_expr, self.lag_operator, 0, max_lag + 1) 

235 

236 impulse_response = {} 

237 for i in range(max_lag + 1): 

238 coeff = series.coeff(self.lag_operator, i) 

239 if coeff is not None: 

240 impulse_response[i] = coeff 

241 else: 

242 impulse_response[i] = 0 

243 

244 return impulse_response 

245 

246 except Exception as e: 

247 # 如果符号计算失败,返回空字典 

248 return {} 

249 

250 def analyze_stability(self, model) -> Dict[str, Any]: 

251 """ 

252 分析模型稳定性 

253  

254 Args: 

255 model: 时间序列模型 

256  

257 Returns: 

258 稳定性分析结果 

259 """ 

260 transfer_func = self.derive_transfer_function(model) 

261 

262 poles = transfer_func.get_poles() 

263 zeros = transfer_func.get_zeros() 

264 is_stable = transfer_func.is_stable() 

265 

266 # 计算极点的模长 

267 pole_magnitudes = [abs(pole) for pole in poles] 

268 

269 return { 

270 "is_stable": is_stable, 

271 "poles": poles, 

272 "zeros": zeros, "pole_magnitudes": pole_magnitudes, 

273 "max_pole_magnitude": max(pole_magnitudes) if pole_magnitudes else 0, 

274 "stability_margin": 1 - max(pole_magnitudes) if pole_magnitudes else 1 

275 } 

276 

277 def get_frequency_response(self, model, frequencies: list, 

278 param_values: dict = None) -> Dict[str, list]: 

279 """ 

280 计算频率响应 

281  

282 Args: 

283 model: 时间序列模型 

284 frequencies: 频率列表 (弧度) 

285 param_values: 模型参数的数值,格式为 {'phi_1': 0.5, 'theta_1': 0.3, ...} 

286 如果为None,将使用默认值 

287  

288 Returns: 

289 频率响应数据 

290 """ 

291 import math 

292 transfer_func = self.derive_transfer_function(model) 

293 

294 # 如果没有提供参数值,使用默认值 

295 if param_values is None: 

296 param_values = self._get_default_params(model) 

297 

298 magnitudes = [] 

299 phases = [] 

300 

301 for omega in frequencies: 

302 # 计算 e^{-iω} 

303 z = complex(math.cos(omega), -math.sin(omega)) 

304 

305 try: 

306 response = self._evaluate_with_params(transfer_func, z, param_values) 

307 magnitudes.append(abs(response)) 

308 phases.append(math.atan2(response.imag, response.real)) 

309 except (ValueError, TypeError): 

310 magnitudes.append(float('inf')) 

311 phases.append(0) 

312 

313 return { 

314 "frequencies": frequencies, 

315 "magnitudes": magnitudes, 

316 "phases": phases, 

317 "magnitude_db": [20 * math.log10(mag) if mag > 0 and math.isfinite(mag) else -float('inf') 

318 for mag in magnitudes] 

319 } 

320 

321 def _get_default_params(self, model) -> dict: 

322 """获取模型的默认参数值""" 

323 params = {} 

324 

325 # AR参数:使用稳定值 

326 if hasattr(model, 'ar_params') and model.ar_params: 

327 for i, param in enumerate(model.ar_params): 

328 if isinstance(param, str): 

329 # 为了稳定性,AR参数应该较小 

330 params[param] = 0.1 * (i + 1) 

331 

332 # MA参数:使用适中值 

333 if hasattr(model, 'ma_params') and model.ma_params: 

334 for i, param in enumerate(model.ma_params): 

335 if isinstance(param, str): 

336 params[param] = 0.2 * (i + 1) 

337 

338 # 季节性参数 

339 if hasattr(model, 'seasonal_ar_params') and model.seasonal_ar_params: 

340 for i, param in enumerate(model.seasonal_ar_params): 

341 if isinstance(param, str): 

342 params[param] = 0.05 * (i + 1) 

343 

344 if hasattr(model, 'seasonal_ma_params') and model.seasonal_ma_params: 

345 for i, param in enumerate(model.seasonal_ma_params): 

346 if isinstance(param, str): 

347 params[param] = 0.1 * (i + 1) 

348 

349 return params 

350 

351 def _evaluate_with_params(self, transfer_func: TransferFunction, 

352 frequency: complex, param_values: dict) -> complex: 

353 """使用给定参数值计算传递函数在特定频率的值""" 

354 # 获取分子和分母表达式 

355 num_expr = transfer_func.numerator.as_expr() 

356 den_expr = transfer_func.denominator.as_expr() 

357 

358 # 替换滞后算子 

359 num_expr = num_expr.subs(transfer_func.lag_operator, frequency) 

360 den_expr = den_expr.subs(transfer_func.lag_operator, frequency) 

361 

362 # 替换参数 

363 for param_name, param_value in param_values.items(): 

364 num_expr = num_expr.subs(symbols(param_name), param_value) 

365 den_expr = den_expr.subs(symbols(param_name), param_value) 

366 

367 # 计算数值 

368 num_val = complex(num_expr.evalf()) 

369 den_val = complex(den_expr.evalf()) 

370 

371 if abs(den_val) < 1e-12: 

372 raise ValueError(f"分母在频率{frequency}处为零") 

373 

374 return num_val / den_val