Coverage for src\time_series_analyzer\transfer_function.py: 80%
147 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-23 11:57 +0800
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-23 11:57 +0800
1"""
2传递函数推导引擎
4基于符号计算自动推导ARIMA模型的传递函数表达式。
5将时间序列模型转换为关于滞后算子B的多项式比值形式。
6"""
8from typing import Dict, Any, Optional, Tuple
9from sympy import symbols, Poly, simplify, factor, expand, Symbol, Rational
10from sympy.polys.polyfuncs import interpolate
11import sympy as sp
13from .models import ARIMAModel, SeasonalARIMAModel
16class TransferFunction:
17 """
18 传递函数表示类
20 表示形式: H(B) = numerator(B) / denominator(B)
21 其中B是滞后算子
22 """
24 def __init__(self, numerator: Poly, denominator: Poly, lag_operator: Symbol = None):
25 """
26 初始化传递函数
28 Args:
29 numerator: 分子多项式
30 denominator: 分母多项式
31 lag_operator: 滞后算子符号
32 """
33 if lag_operator is None:
34 lag_operator = symbols('B')
36 self.lag_operator = lag_operator
37 self.numerator = numerator
38 self.denominator = denominator
40 # 简化传递函数
41 self._simplify()
43 def _simplify(self):
44 """简化传递函数,约去公因子"""
45 try:
46 # 计算最大公约数
47 gcd_poly = sp.gcd(self.numerator.as_expr(), self.denominator.as_expr())
49 if gcd_poly != 1:
50 # 约去公因子
51 self.numerator = Poly(
52 simplify(self.numerator.as_expr() / gcd_poly),
53 self.lag_operator
54 )
55 self.denominator = Poly(
56 simplify(self.denominator.as_expr() / gcd_poly),
57 self.lag_operator
58 )
59 except Exception: # 如果简化失败,保持原样
60 pass
62 def evaluate_at_frequency(self, frequency: complex) -> complex:
63 """
64 在特定频率处计算传递函数值
66 Args:
67 frequency: 频率值 (通常是 e^{-iω})
69 Returns:
70 传递函数在该频率处的值
71 """
72 num_expr = self.numerator.as_expr().subs(self.lag_operator, frequency)
73 den_expr = self.denominator.as_expr().subs(self.lag_operator, frequency)
75 # 确保结果为数值
76 num_val = complex(num_expr.evalf())
77 den_val = complex(den_expr.evalf())
79 if abs(den_val) < 1e-12:
80 raise ValueError(f"分母在频率{frequency}处为零")
82 return num_val / den_val
84 def get_poles(self) -> list:
85 """获取传递函数的极点(分母的根)"""
86 try:
87 roots = sp.solve(self.denominator.as_expr(), self.lag_operator)
88 return [complex(root.evalf()) for root in roots if root.is_finite]
89 except Exception:
90 return []
92 def get_zeros(self) -> list:
93 """获取传递函数的零点(分子的根)"""
94 try:
95 roots = sp.solve(self.numerator.as_expr(), self.lag_operator)
96 return [complex(root.evalf()) for root in roots if root.is_finite]
97 except Exception:
98 return []
100 def is_stable(self) -> bool:
101 """
102 检查系统稳定性
103 对于离散时间系统,所有极点的模长应小于1
104 """
105 poles = self.get_poles()
106 return all(abs(pole) < 1 for pole in poles)
108 def __str__(self) -> str:
109 """字符串表示"""
110 return f"H({self.lag_operator}) = ({self.numerator.as_expr()}) / ({self.denominator.as_expr()})"
112 def __repr__(self) -> str:
113 return self.__str__()
116class TransferFunctionDeriver:
117 """
118 传递函数推导器
120 将ARIMA模型自动转换为传递函数形式
121 """
123 def __init__(self, lag_operator: Symbol = None):
124 """
125 初始化推导器
127 Args:
128 lag_operator: 滞后算子符号,默认为B
129 """
130 if lag_operator is None:
131 lag_operator = symbols('B')
132 self.lag_operator = lag_operator
134 def derive_arima_transfer_function(self, model: ARIMAModel) -> TransferFunction:
135 """
136 推导ARIMA模型的传递函数
138 ARIMA(p,d,q)模型的一般形式:
139 φ(B)(1-B)^d X_t = θ(B)ε_t
141 传递函数: H(B) = θ(B) / [φ(B)(1-B)^d]
143 Args:
144 model: ARIMA模型
146 Returns:
147 传递函数对象
148 """
149 # 获取各个多项式
150 ar_poly = model.get_ar_polynomial(self.lag_operator)
151 ma_poly = model.get_ma_polynomial(self.lag_operator)
152 diff_poly = model.get_difference_polynomial(self.lag_operator)
154 # 分子:移动平均多项式 θ(B)
155 numerator = ma_poly
157 # 分母:自回归多项式 × 差分多项式 φ(B)(1-B)^d
158 denominator = ar_poly * diff_poly
160 return TransferFunction(numerator, denominator, self.lag_operator)
162 def derive_sarima_transfer_function(self, model: SeasonalARIMAModel) -> TransferFunction:
163 """
164 推导季节性ARIMA模型的传递函数
166 SARIMA(p,d,q)(P,D,Q,m)模型的一般形式:
167 φ(B)Φ(B^m)(1-B)^d(1-B^m)^D X_t = θ(B)Θ(B^m)ε_t
169 传递函数: H(B) = θ(B)Θ(B^m) / [φ(B)Φ(B^m)(1-B)^d(1-B^m)^D]
171 Args:
172 model: 季节性ARIMA模型
174 Returns:
175 传递函数对象
176 """
177 # 获取非季节性多项式
178 ar_poly = model.get_ar_polynomial(self.lag_operator)
179 ma_poly = model.get_ma_polynomial(self.lag_operator)
180 diff_poly = model.get_difference_polynomial(self.lag_operator)
182 # 获取季节性多项式
183 seasonal_ar_poly = model.get_seasonal_ar_polynomial(self.lag_operator)
184 seasonal_ma_poly = model.get_seasonal_ma_polynomial(self.lag_operator)
185 seasonal_diff_poly = model.get_seasonal_difference_polynomial(self.lag_operator)
187 # 分子:θ(B)Θ(B^m)
188 numerator = ma_poly * seasonal_ma_poly
190 # 分母:φ(B)Φ(B^m)(1-B)^d(1-B^m)^D
191 denominator = ar_poly * seasonal_ar_poly * diff_poly * seasonal_diff_poly
193 return TransferFunction(numerator, denominator, self.lag_operator)
195 def derive_transfer_function(self, model) -> TransferFunction:
196 """
197 通用传递函数推导方法
199 Args:
200 model: ARIMA或SARIMA模型
202 Returns:
203 传递函数对象
204 """
205 if isinstance(model, SeasonalARIMAModel):
206 return self.derive_sarima_transfer_function(model)
207 elif isinstance(model, ARIMAModel):
208 return self.derive_arima_transfer_function(model)
209 else:
210 raise ValueError(f"不支持的模型类型: {type(model)}")
212 def derive_impulse_response(self, model, max_lag: int = 20) -> Dict[int, Any]:
213 """
214 推导脉冲响应函数
216 通过传递函数的幂级数展开获得脉冲响应系数
218 Args:
219 model: 时间序列模型
220 max_lag: 最大滞后阶数
222 Returns:
223 脉冲响应系数字典 {lag: coefficient}
224 """
225 transfer_func = self.derive_transfer_function(model)
227 # 计算幂级数展开
228 try:
229 # H(B) = num(B) / den(B) = Σ h_j B^j
230 num_expr = transfer_func.numerator.as_expr()
231 den_expr = transfer_func.denominator.as_expr()
233 # 使用sympy的series展开
234 series = sp.series(num_expr / den_expr, self.lag_operator, 0, max_lag + 1)
236 impulse_response = {}
237 for i in range(max_lag + 1):
238 coeff = series.coeff(self.lag_operator, i)
239 if coeff is not None:
240 impulse_response[i] = coeff
241 else:
242 impulse_response[i] = 0
244 return impulse_response
246 except Exception as e:
247 # 如果符号计算失败,返回空字典
248 return {}
250 def analyze_stability(self, model) -> Dict[str, Any]:
251 """
252 分析模型稳定性
254 Args:
255 model: 时间序列模型
257 Returns:
258 稳定性分析结果
259 """
260 transfer_func = self.derive_transfer_function(model)
262 poles = transfer_func.get_poles()
263 zeros = transfer_func.get_zeros()
264 is_stable = transfer_func.is_stable()
266 # 计算极点的模长
267 pole_magnitudes = [abs(pole) for pole in poles]
269 return {
270 "is_stable": is_stable,
271 "poles": poles,
272 "zeros": zeros, "pole_magnitudes": pole_magnitudes,
273 "max_pole_magnitude": max(pole_magnitudes) if pole_magnitudes else 0,
274 "stability_margin": 1 - max(pole_magnitudes) if pole_magnitudes else 1
275 }
277 def get_frequency_response(self, model, frequencies: list,
278 param_values: dict = None) -> Dict[str, list]:
279 """
280 计算频率响应
282 Args:
283 model: 时间序列模型
284 frequencies: 频率列表 (弧度)
285 param_values: 模型参数的数值,格式为 {'phi_1': 0.5, 'theta_1': 0.3, ...}
286 如果为None,将使用默认值
288 Returns:
289 频率响应数据
290 """
291 import math
292 transfer_func = self.derive_transfer_function(model)
294 # 如果没有提供参数值,使用默认值
295 if param_values is None:
296 param_values = self._get_default_params(model)
298 magnitudes = []
299 phases = []
301 for omega in frequencies:
302 # 计算 e^{-iω}
303 z = complex(math.cos(omega), -math.sin(omega))
305 try:
306 response = self._evaluate_with_params(transfer_func, z, param_values)
307 magnitudes.append(abs(response))
308 phases.append(math.atan2(response.imag, response.real))
309 except (ValueError, TypeError):
310 magnitudes.append(float('inf'))
311 phases.append(0)
313 return {
314 "frequencies": frequencies,
315 "magnitudes": magnitudes,
316 "phases": phases,
317 "magnitude_db": [20 * math.log10(mag) if mag > 0 and math.isfinite(mag) else -float('inf')
318 for mag in magnitudes]
319 }
321 def _get_default_params(self, model) -> dict:
322 """获取模型的默认参数值"""
323 params = {}
325 # AR参数:使用稳定值
326 if hasattr(model, 'ar_params') and model.ar_params:
327 for i, param in enumerate(model.ar_params):
328 if isinstance(param, str):
329 # 为了稳定性,AR参数应该较小
330 params[param] = 0.1 * (i + 1)
332 # MA参数:使用适中值
333 if hasattr(model, 'ma_params') and model.ma_params:
334 for i, param in enumerate(model.ma_params):
335 if isinstance(param, str):
336 params[param] = 0.2 * (i + 1)
338 # 季节性参数
339 if hasattr(model, 'seasonal_ar_params') and model.seasonal_ar_params:
340 for i, param in enumerate(model.seasonal_ar_params):
341 if isinstance(param, str):
342 params[param] = 0.05 * (i + 1)
344 if hasattr(model, 'seasonal_ma_params') and model.seasonal_ma_params:
345 for i, param in enumerate(model.seasonal_ma_params):
346 if isinstance(param, str):
347 params[param] = 0.1 * (i + 1)
349 return params
351 def _evaluate_with_params(self, transfer_func: TransferFunction,
352 frequency: complex, param_values: dict) -> complex:
353 """使用给定参数值计算传递函数在特定频率的值"""
354 # 获取分子和分母表达式
355 num_expr = transfer_func.numerator.as_expr()
356 den_expr = transfer_func.denominator.as_expr()
358 # 替换滞后算子
359 num_expr = num_expr.subs(transfer_func.lag_operator, frequency)
360 den_expr = den_expr.subs(transfer_func.lag_operator, frequency)
362 # 替换参数
363 for param_name, param_value in param_values.items():
364 num_expr = num_expr.subs(symbols(param_name), param_value)
365 den_expr = den_expr.subs(symbols(param_name), param_value)
367 # 计算数值
368 num_val = complex(num_expr.evalf())
369 den_val = complex(den_expr.evalf())
371 if abs(den_val) < 1e-12:
372 raise ValueError(f"分母在频率{frequency}处为零")
374 return num_val / den_val