Coverage for src\time_series_analyzer\models.py: 94%
157 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-23 11:57 +0800
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-23 11:57 +0800
1"""
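Core class definitions for time series models.

Contains the parametric representation, validation, and basic mathematical structure of ARIMA and seasonal ARIMA models.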
5"""
7from typing import Optional, Tuple, List, Dict, Any, Union
8from pydantic import BaseModel, Field, field_validator, model_validator, ConfigDict
9import numpy as np
10from sympy import symbols, Poly, Symbol
class ARIMAModel(BaseModel):
    """
    Parametric representation of an ARIMA(p, d, q) model.

    Attributes:
        p: Autoregressive (AR) order.
        d: Differencing (integration) order.
        q: Moving-average (MA) order.
        ar_params: AR coefficients [φ₁, φ₂, ..., φ_p]; each entry is either a
            number or the name of a symbolic coefficient.
        ma_params: MA coefficients [θ₁, θ₂, ..., θ_q]; each entry is either a
            number or the name of a symbolic coefficient.
        constant: Constant term.
        name: Model name.
    """

    p: int = Field(ge=0, description="自回归阶数")
    d: int = Field(ge=0, description="差分阶数")
    q: int = Field(ge=0, description="移动平均阶数")

    ar_params: Optional[List[Union[float, str]]] = Field(
        default=None,
        description="自回归参数,长度应等于p"
    )
    ma_params: Optional[List[Union[float, str]]] = Field(
        default=None,
        description="移动平均参数,长度应等于q"
    )
    constant: float = Field(default=0.0, description="常数项")
    name: Optional[str] = Field(default=None, description="模型名称")

    model_config = ConfigDict(
        validate_assignment=True,
        extra="forbid"
    )

    @field_validator('ar_params')
    @classmethod
    def validate_ar_params(cls, v, info):
        """Check that ``ar_params``, when given, has exactly ``p`` entries.

        Raises:
            ValueError: If the list length differs from ``p``.
        """
        if v is not None and info.data:
            # ``p`` falls back to 0 when it is absent from the validated data.
            p = info.data.get('p', 0)
            if len(v) != p:
                raise ValueError(f"自回归参数长度({len(v)})必须等于p({p})")
        return v

    @field_validator('ma_params')
    @classmethod
    def validate_ma_params(cls, v, info):
        """Check that ``ma_params``, when given, has exactly ``q`` entries.

        Raises:
            ValueError: If the list length differs from ``q``.
        """
        if v is not None and info.data:
            # ``q`` falls back to 0 when it is absent from the validated data.
            q = info.data.get('q', 0)
            if len(v) != q:
                raise ValueError(f"移动平均参数长度({len(v)})必须等于q({q})")
        return v

    @model_validator(mode='before')
    @classmethod
    def validate_model(cls, values):
        """Whole-model pre-validation: reject the empty model, fill defaults.

        When ``values`` is a dict it is updated in place with default symbolic
        coefficient names (``phi_i`` / ``theta_i``) and a default model name.

        Raises:
            ValueError: If p, d and q are all 0 on a plain ``ARIMAModel``.
        """
        if isinstance(values, dict):
            p, d, q = values.get('p', 0), values.get('d', 0), values.get('q', 0)

            # Only enforced for the plain ARIMA class; the seasonal subclass
            # applies its own rule.
            if cls.__name__ == 'ARIMAModel' and p == 0 and d == 0 and q == 0:
                raise ValueError("p, d, q不能全为0")

            # Generate default symbolic coefficients when none were supplied.
            if values.get('ar_params') is None and p > 0:
                values['ar_params'] = [f"phi_{i+1}" for i in range(p)]

            if values.get('ma_params') is None and q > 0:
                values['ma_params'] = [f"theta_{i+1}" for i in range(q)]

            # Generate a default name.
            if values.get('name') is None:
                values['name'] = f"ARIMA({p},{d},{q})"

        return values

    def get_ar_polynomial(self, lag_operator: Optional[Symbol] = None) -> Poly:
        """
        Return the AR polynomial φ(B) = 1 - φ₁B - φ₂B² - ... - φ_pB^p.

        Args:
            lag_operator: Lag-operator symbol; defaults to ``B``.

        Returns:
            The autoregressive polynomial in ``lag_operator``.
        """
        if lag_operator is None:
            lag_operator = symbols('B')

        if self.p == 0:
            return Poly(1, lag_operator)

        # Poly reads a coefficient list from the highest degree down, so the
        # parameters are walked in reverse and the constant 1 goes last.
        # Symbol() is used instead of symbols() so that a name containing
        # spaces, commas or ':' is never expanded into a tuple of symbols.
        coeffs = [
            -float(param) if isinstance(param, (int, float)) else -Symbol(str(param))
            for param in reversed(self.ar_params)
        ] + [1]

        return Poly(coeffs, lag_operator)

    def get_ma_polynomial(self, lag_operator: Optional[Symbol] = None) -> Poly:
        """
        Return the MA polynomial θ(B) = 1 + θ₁B + θ₂B² + ... + θ_qB^q.

        Args:
            lag_operator: Lag-operator symbol; defaults to ``B``.

        Returns:
            The moving-average polynomial in ``lag_operator``.
        """
        if lag_operator is None:
            lag_operator = symbols('B')

        if self.q == 0:
            return Poly(1, lag_operator)

        # Highest degree first, constant term last (see get_ar_polynomial).
        coeffs = [
            float(param) if isinstance(param, (int, float)) else Symbol(str(param))
            for param in reversed(self.ma_params)
        ] + [1]

        return Poly(coeffs, lag_operator)

    def get_difference_polynomial(self, lag_operator: Optional[Symbol] = None) -> Poly:
        """
        Return the differencing polynomial (1 - B)^d.

        Args:
            lag_operator: Lag-operator symbol; defaults to ``B``.

        Returns:
            The differencing polynomial in ``lag_operator``.
        """
        if lag_operator is None:
            lag_operator = symbols('B')

        if self.d == 0:
            return Poly(1, lag_operator)

        # Coefficients are listed highest degree first, so 1 - B is [-1, 1].
        # ([1, -1] would be B - 1, which flips the sign for every odd d.)
        base_poly = Poly([-1, 1], lag_operator)

        return base_poly.pow(self.d)

    def to_dict(self) -> Dict[str, Any]:
        """Return the model as a plain dictionary."""
        return {
            "model_type": "ARIMA",
            "parameters": {
                "p": self.p,
                "d": self.d,
                "q": self.q
            },
            "ar_params": self.ar_params,
            "ma_params": self.ma_params,
            "constant": self.constant,
            "name": self.name
        }

    def __str__(self) -> str:
        """Return a short human-readable summary of the model orders."""
        return f"{self.name}: AR({self.p}), I({self.d}), MA({self.q})"
class SeasonalARIMAModel(ARIMAModel):
    """
    Seasonal ARIMA model SARIMA(p,d,q)(P,D,Q,m).

    Extends ARIMAModel with the seasonal orders P, D, Q, the seasonal period m
    and the seasonal AR/MA coefficients.
    """

    # Seasonal orders
    P: int = Field(ge=0, description="季节性自回归阶数")
    D: int = Field(ge=0, description="季节性差分阶数")
    Q: int = Field(ge=0, description="季节性移动平均阶数")
    m: int = Field(gt=0, description="季节性周期")

    # Seasonal coefficients
    seasonal_ar_params: Optional[List[Union[float, str]]] = Field(
        default=None,
        description="季节性自回归参数"
    )
    seasonal_ma_params: Optional[List[Union[float, str]]] = Field(
        default=None,
        description="季节性移动平均参数"
    )

    @field_validator('seasonal_ar_params')
    @classmethod
    def validate_seasonal_ar_params(cls, v, info):
        """Check that ``seasonal_ar_params``, when given, has exactly ``P`` entries.

        Raises:
            ValueError: If the list length differs from ``P``.
        """
        if v is not None and info.data:
            # ``P`` falls back to 0 when it is absent from the validated data.
            P = info.data.get('P', 0)
            if len(v) != P:
                raise ValueError(f"季节性自回归参数长度({len(v)})必须等于P({P})")
        return v

    @field_validator('seasonal_ma_params')
    @classmethod
    def validate_seasonal_ma_params(cls, v, info):
        """Check that ``seasonal_ma_params``, when given, has exactly ``Q`` entries.

        Raises:
            ValueError: If the list length differs from ``Q``.
        """
        if v is not None and info.data:
            # ``Q`` falls back to 0 when it is absent from the validated data.
            Q = info.data.get('Q', 0)
            if len(v) != Q:
                raise ValueError(f"季节性移动平均参数长度({len(v)})必须等于Q({Q})")
        return v

    @model_validator(mode='before')
    @classmethod
    def validate_seasonal_model(cls, values):
        """Seasonal pre-validation: reject the empty model, fill defaults.

        When ``values`` is a dict it is updated in place with default symbolic
        coefficient names and the model name. The name is always overwritten
        with the ``SARIMA(...)`` form, including when the caller supplied one.

        Raises:
            ValueError: If p, d, q, P, D and Q are all 0.
        """
        if isinstance(values, dict):
            p, d, q = values.get('p', 0), values.get('d', 0), values.get('q', 0)
            P, D, Q, m = values.get('P', 0), values.get('D', 0), values.get('Q', 0), values.get('m', 1)

            # The non-seasonal part may be all zero, but at least one of the
            # six orders must be non-zero.
            if p == 0 and d == 0 and q == 0 and P == 0 and D == 0 and Q == 0:
                raise ValueError("SARIMA模型的所有参数不能全为0")

            # Default non-seasonal symbolic coefficients.
            if values.get('ar_params') is None and p > 0:
                values['ar_params'] = [f"phi_{i+1}" for i in range(p)]

            if values.get('ma_params') is None and q > 0:
                values['ma_params'] = [f"theta_{i+1}" for i in range(q)]

            # Default seasonal symbolic coefficients.
            if values.get('seasonal_ar_params') is None and P > 0:
                values['seasonal_ar_params'] = [f"Phi_{i+1}" for i in range(P)]

            if values.get('seasonal_ma_params') is None and Q > 0:
                values['seasonal_ma_params'] = [f"Theta_{i+1}" for i in range(Q)]

            # Set the model name (unconditionally).
            values['name'] = f"SARIMA({p},{d},{q})({P},{D},{Q},{m})"

        return values

    def get_seasonal_ar_polynomial(self, lag_operator: Optional[Symbol] = None) -> Poly:
        """
        Return the seasonal AR polynomial
        Φ(B^m) = 1 - Φ₁B^m - Φ₂B^(2m) - ... - Φ_P B^(Pm).

        Args:
            lag_operator: Lag-operator symbol; defaults to ``B``.

        Returns:
            The seasonal autoregressive polynomial in ``lag_operator``.
        """
        if lag_operator is None:
            lag_operator = symbols('B')

        if self.P == 0:
            return Poly(1, lag_operator)

        # ascending[k] is the coefficient of B^k; only multiples of m are set.
        ascending = [0] * (len(self.seasonal_ar_params) * self.m + 1)
        ascending[0] = 1  # constant term

        for k, param in enumerate(self.seasonal_ar_params, start=1):
            # Symbol() rather than symbols(): never expand a name into a tuple.
            ascending[k * self.m] = (
                -float(param) if isinstance(param, (int, float)) else -Symbol(str(param))
            )

        # Poly reads a coefficient list from the highest degree down, so the
        # ascending list has to be reversed before it is handed over.
        return Poly(ascending[::-1], lag_operator)

    def get_seasonal_ma_polynomial(self, lag_operator: Optional[Symbol] = None) -> Poly:
        """
        Return the seasonal MA polynomial
        Θ(B^m) = 1 + Θ₁B^m + Θ₂B^(2m) + ... + Θ_Q B^(Qm).

        Args:
            lag_operator: Lag-operator symbol; defaults to ``B``.

        Returns:
            The seasonal moving-average polynomial in ``lag_operator``.
        """
        if lag_operator is None:
            lag_operator = symbols('B')

        if self.Q == 0:
            return Poly(1, lag_operator)

        # ascending[k] is the coefficient of B^k; only multiples of m are set.
        ascending = [0] * (len(self.seasonal_ma_params) * self.m + 1)
        ascending[0] = 1  # constant term

        for k, param in enumerate(self.seasonal_ma_params, start=1):
            # Symbol() rather than symbols(): never expand a name into a tuple.
            ascending[k * self.m] = (
                float(param) if isinstance(param, (int, float)) else Symbol(str(param))
            )

        # Reverse into Poly's highest-degree-first order.
        return Poly(ascending[::-1], lag_operator)

    def get_seasonal_difference_polynomial(self, lag_operator: Optional[Symbol] = None) -> Poly:
        """
        Return the seasonal differencing polynomial (1 - B^m)^D.

        Args:
            lag_operator: Lag-operator symbol; defaults to ``B``.

        Returns:
            The seasonal differencing polynomial in ``lag_operator``.
        """
        if lag_operator is None:
            lag_operator = symbols('B')

        if self.D == 0:
            return Poly(1, lag_operator)

        # 1 - B^m, listed highest degree first: -1 for B^m, zeros, then 1.
        descending = [0] * (self.m + 1)
        descending[0] = -1   # coefficient of B^m
        descending[-1] = 1   # constant term

        base_poly = Poly(descending, lag_operator)

        return base_poly.pow(self.D)

    def to_dict(self) -> Dict[str, Any]:
        """Return the model as a plain dictionary, including seasonal fields."""
        base_dict = super().to_dict()
        base_dict.update({
            "model_type": "SARIMA",
            "seasonal_parameters": {
                "P": self.P,
                "D": self.D,
                "Q": self.Q,
                "m": self.m
            },
            "seasonal_ar_params": self.seasonal_ar_params,
            "seasonal_ma_params": self.seasonal_ma_params
        })
        return base_dict