Coverage for src\time_series_analyzer\models.py: 94%

157 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-23 11:57 +0800

1""" 

2时间序列模型的核心类定义 

3 

4包含ARIMA模型和季节性ARIMA模型的参数化表示、验证和基础数学结构。 

5""" 

6 

7from typing import Optional, Tuple, List, Dict, Any, Union 

8from pydantic import BaseModel, Field, field_validator, model_validator, ConfigDict 

9import numpy as np 

10from sympy import symbols, Poly, Symbol 

11 

12 

class ARIMAModel(BaseModel):
    """
    Parameterized representation of an ARIMA(p, d, q) model.

    Each parameter entry is either a number (a numeric coefficient) or a
    string. Numeric strings such as ``"0.5"`` are treated as numbers; any
    other string is used as the name of a SymPy symbol.

    Attributes:
        p: Autoregressive (AR) order.
        d: Differencing (integration) order.
        q: Moving-average (MA) order.
        ar_params: AR parameters [φ₁, φ₂, ..., φ_p]; length must equal ``p``.
            Defaults to the symbols ``phi_1 ... phi_p`` when omitted.
        ma_params: MA parameters [θ₁, θ₂, ..., θ_q]; length must equal ``q``.
            Defaults to the symbols ``theta_1 ... theta_q`` when omitted.
        constant: Constant term.
        name: Model name; defaults to ``"ARIMA(p,d,q)"``.
    """

    p: int = Field(ge=0, description="自回归阶数")
    d: int = Field(ge=0, description="差分阶数")
    q: int = Field(ge=0, description="移动平均阶数")

    ar_params: Optional[List[Union[float, str]]] = Field(
        default=None,
        description="自回归参数,长度应等于p"
    )
    ma_params: Optional[List[Union[float, str]]] = Field(
        default=None,
        description="移动平均参数,长度应等于q"
    )
    constant: float = Field(default=0.0, description="常数项")
    name: Optional[str] = Field(default=None, description="模型名称")

    model_config = ConfigDict(
        validate_assignment=True,
        extra="forbid"
    )

    @field_validator('ar_params')
    @classmethod
    def validate_ar_params(cls, v, info):
        """Check that the AR parameter list has exactly ``p`` entries."""
        # If ``p`` failed its own validation it is absent from info.data;
        # skip the length check so only the real error is reported.
        p = info.data.get('p')
        if v is not None and p is not None and len(v) != p:
            raise ValueError(f"自回归参数长度({len(v)})必须等于p({p})")
        return v

    @field_validator('ma_params')
    @classmethod
    def validate_ma_params(cls, v, info):
        """Check that the MA parameter list has exactly ``q`` entries."""
        # Same reasoning as validate_ar_params: skip when ``q`` is invalid.
        q = info.data.get('q')
        if v is not None and q is not None and len(v) != q:
            raise ValueError(f"移动平均参数长度({len(v)})必须等于q({q})")
        return v

    @model_validator(mode='before')
    @classmethod
    def validate_model(cls, values):
        """
        Validate the raw input as a whole, before field validation.

        Rejects an all-zero (p, d, q) for non-seasonal models and fills in
        default symbolic parameters and a default name.

        Raises:
            ValueError: If p, d and q are all 0 for a non-seasonal model.
        """
        if not isinstance(values, dict):
            return values

        # Work on a copy so a dict handed to ``model_validate`` by the caller
        # is not mutated as a side effect of validation.
        values = dict(values)
        p, d, q = values.get('p', 0), values.get('d', 0), values.get('q', 0)

        # Seasonal models run their own check, which also considers the
        # seasonal orders. SeasonalARIMAModel is defined later in this module
        # and is only looked up when validation runs. Any other subclass of
        # ARIMAModel keeps this check.
        if p == 0 and d == 0 and q == 0 and not issubclass(cls, SeasonalARIMAModel):
            raise ValueError("p, d, q不能全为0")

        # Default to symbolic parameters when none were provided.
        if values.get('ar_params') is None and p > 0:
            values['ar_params'] = [f"phi_{i+1}" for i in range(p)]

        if values.get('ma_params') is None and q > 0:
            values['ma_params'] = [f"theta_{i+1}" for i in range(q)]

        if values.get('name') is None:
            values['name'] = f"ARIMA({p},{d},{q})"

        return values

    @staticmethod
    def _to_coefficient(param: Union[float, str]) -> Any:
        """
        Convert one stored parameter into a polynomial coefficient.

        Numbers and numeric strings (e.g. ``"0.5"``) become ``float``; any
        other string becomes a SymPy ``Symbol`` with exactly that name.
        """
        if isinstance(param, (int, float)):
            return float(param)
        text = str(param)
        try:
            return float(text)
        except ValueError:
            # Symbol() keeps the whole string as one name. symbols() would
            # instead split on commas/whitespace and expand ':' ranges.
            return Symbol(text)

    def get_ar_polynomial(self, lag_operator: Optional[Symbol] = None) -> Poly:
        """
        Return the AR polynomial φ(B) = 1 - φ₁B - φ₂B² - ... - φ_p B^p.

        Args:
            lag_operator: Lag-operator symbol; defaults to ``B``.

        Returns:
            The AR polynomial in ``lag_operator``.
        """
        if lag_operator is None:
            lag_operator = symbols('B')

        if self.p == 0:
            return Poly(1, lag_operator)

        # ar_params may have been reset to None by assignment. Fall back to
        # the same symbolic defaults the model validator generates.
        params = self.ar_params
        if params is None:
            params = [f"phi_{i+1}" for i in range(self.p)]

        # Poly takes dense coefficients from the highest degree down:
        # [-φ_p, ..., -φ₁, 1].
        coeffs = [-self._to_coefficient(param) for param in reversed(params)]
        return Poly(coeffs + [1], lag_operator)

    def get_ma_polynomial(self, lag_operator: Optional[Symbol] = None) -> Poly:
        """
        Return the MA polynomial θ(B) = 1 + θ₁B + θ₂B² + ... + θ_q B^q.

        Args:
            lag_operator: Lag-operator symbol; defaults to ``B``.

        Returns:
            The MA polynomial in ``lag_operator``.
        """
        if lag_operator is None:
            lag_operator = symbols('B')

        if self.q == 0:
            return Poly(1, lag_operator)

        # ma_params may have been reset to None by assignment. Fall back to
        # the same symbolic defaults the model validator generates.
        params = self.ma_params
        if params is None:
            params = [f"theta_{i+1}" for i in range(self.q)]

        # Highest degree first: [θ_q, ..., θ₁, 1].
        coeffs = [self._to_coefficient(param) for param in reversed(params)]
        return Poly(coeffs + [1], lag_operator)

    def get_difference_polynomial(self, lag_operator: Optional[Symbol] = None) -> Poly:
        """
        Return the differencing polynomial (1 - B)^d.

        Args:
            lag_operator: Lag-operator symbol; defaults to ``B``.

        Returns:
            The differencing polynomial in ``lag_operator``.
        """
        if lag_operator is None:
            lag_operator = symbols('B')

        if self.d == 0:
            return Poly(1, lag_operator)

        # 1 - B as dense coefficients, highest degree first: [-1, 1].
        # ([1, -1] would be B - 1, which flips the sign for odd d.)
        base_poly = Poly([-1, 1], lag_operator)
        return base_poly ** self.d

    def to_dict(self) -> Dict[str, Any]:
        """
        Return a plain-dict summary of the model.

        The parameter lists are copied, so mutating the result cannot change
        the model behind pydantic's validation.
        """
        return {
            "model_type": "ARIMA",
            "parameters": {
                "p": self.p,
                "d": self.d,
                "q": self.q
            },
            "ar_params": None if self.ar_params is None else list(self.ar_params),
            "ma_params": None if self.ma_params is None else list(self.ma_params),
            "constant": self.constant,
            "name": self.name
        }

    def __str__(self) -> str:
        """Human-readable summary: name followed by the AR, I and MA orders."""
        return f"{self.name}: AR({self.p}), I({self.d}), MA({self.q})"

180 

181 

class SeasonalARIMAModel(ARIMAModel):
    """
    Seasonal ARIMA model SARIMA(p,d,q)(P,D,Q,m).

    Extends ARIMAModel with seasonal orders P, D, Q, the seasonal period m
    and the seasonal parameters.

    Attributes:
        P: Seasonal AR order.
        D: Seasonal differencing order.
        Q: Seasonal MA order.
        m: Seasonal period (must be > 0).
        seasonal_ar_params: Seasonal AR parameters [Φ₁, ..., Φ_P]; length must
            equal ``P``. Defaults to the symbols ``Phi_1 ... Phi_P``.
        seasonal_ma_params: Seasonal MA parameters [Θ₁, ..., Θ_Q]; length must
            equal ``Q``. Defaults to the symbols ``Theta_1 ... Theta_Q``.
    """

    # Seasonal orders and period
    P: int = Field(ge=0, description="季节性自回归阶数")
    D: int = Field(ge=0, description="季节性差分阶数")
    Q: int = Field(ge=0, description="季节性移动平均阶数")
    m: int = Field(gt=0, description="季节性周期")

    # Seasonal coefficients
    seasonal_ar_params: Optional[List[Union[float, str]]] = Field(
        default=None,
        description="季节性自回归参数"
    )
    seasonal_ma_params: Optional[List[Union[float, str]]] = Field(
        default=None,
        description="季节性移动平均参数"
    )

    @field_validator('seasonal_ar_params')
    @classmethod
    def validate_seasonal_ar_params(cls, v, info):
        """Check that the seasonal AR parameter list has exactly ``P`` entries."""
        # If ``P`` failed its own validation it is absent from info.data;
        # skip the length check so only the real error is reported.
        P = info.data.get('P')
        if v is not None and P is not None and len(v) != P:
            raise ValueError(f"季节性自回归参数长度({len(v)})必须等于P({P})")
        return v

    @field_validator('seasonal_ma_params')
    @classmethod
    def validate_seasonal_ma_params(cls, v, info):
        """Check that the seasonal MA parameter list has exactly ``Q`` entries."""
        # Same reasoning as validate_seasonal_ar_params: skip when Q is invalid.
        Q = info.data.get('Q')
        if v is not None and Q is not None and len(v) != Q:
            raise ValueError(f"季节性移动平均参数长度({len(v)})必须等于Q({Q})")
        return v

    @model_validator(mode='before')
    @classmethod
    def validate_seasonal_model(cls, values):
        """
        Validate the raw seasonal input as a whole, before field validation.

        The non-seasonal part may be all zero, but at least one of
        p, d, q, P, D, Q must be non-zero. Fills in default symbolic
        parameters and, unless the caller named the model, a SARIMA name.

        Raises:
            ValueError: If all six orders are 0.
        """
        if not isinstance(values, dict):
            return values

        # Work on a copy so the caller's dict is not mutated.
        values = dict(values)
        p, d, q = values.get('p', 0), values.get('d', 0), values.get('q', 0)
        P, D, Q, m = values.get('P', 0), values.get('D', 0), values.get('Q', 0), values.get('m', 1)

        if p == 0 and d == 0 and q == 0 and P == 0 and D == 0 and Q == 0:
            raise ValueError("SARIMA模型的所有参数不能全为0")

        # Default non-seasonal symbolic parameters
        if values.get('ar_params') is None and p > 0:
            values['ar_params'] = [f"phi_{i+1}" for i in range(p)]

        if values.get('ma_params') is None and q > 0:
            values['ma_params'] = [f"theta_{i+1}" for i in range(q)]

        # Default seasonal symbolic parameters
        if values.get('seasonal_ar_params') is None and P > 0:
            values['seasonal_ar_params'] = [f"Phi_{i+1}" for i in range(P)]

        if values.get('seasonal_ma_params') is None and Q > 0:
            values['seasonal_ma_params'] = [f"Theta_{i+1}" for i in range(Q)]

        # Keep a caller-supplied name. The plain "ARIMA(p,d,q)" default that
        # ARIMAModel.validate_model fills in is not a caller choice, so it is
        # replaced too. This keeps the outcome independent of the order in
        # which pydantic runs the two before-validators.
        name = values.get('name')
        if name is None or name == f"ARIMA({p},{d},{q})":
            values['name'] = f"SARIMA({p},{d},{q})({P},{D},{Q},{m})"

        return values

    @staticmethod
    def _to_coefficient(param: Union[float, str]) -> Any:
        """
        Convert one stored parameter into a polynomial coefficient.

        Numbers and numeric strings (e.g. ``"0.5"``) become ``float``; any
        other string becomes a SymPy ``Symbol`` with exactly that name.
        """
        if isinstance(param, (int, float)):
            return float(param)
        text = str(param)
        try:
            return float(text)
        except ValueError:
            # Symbol() keeps the whole string as one name. symbols() would
            # split on commas/whitespace and expand ':' ranges.
            return Symbol(text)

    def _seasonal_polynomial(self, params, sign: int, lag_operator: Symbol) -> Poly:
        """
        Build 1 + sign·(c₁B^m + c₂B^(2m) + ... + c_k B^(km)) from params c₁..c_k.
        """
        # Coefficients indexed by power of B (ascending). Poly expects the
        # highest degree first, so the list is reversed when building it.
        by_power = [0] * (len(params) * self.m + 1)
        by_power[0] = 1
        for k, param in enumerate(params, start=1):
            by_power[k * self.m] = sign * self._to_coefficient(param)
        return Poly(by_power[::-1], lag_operator)

    def get_seasonal_ar_polynomial(self, lag_operator: Optional[Symbol] = None) -> Poly:
        """
        Return Φ(B^m) = 1 - Φ₁B^m - Φ₂B^(2m) - ... - Φ_P B^(Pm).

        Args:
            lag_operator: Lag-operator symbol; defaults to ``B``.
        """
        if lag_operator is None:
            lag_operator = symbols('B')

        if self.P == 0:
            return Poly(1, lag_operator)

        # The list may have been reset to None by assignment; use the same
        # symbolic defaults the model validator generates.
        params = self.seasonal_ar_params
        if params is None:
            params = [f"Phi_{i+1}" for i in range(self.P)]
        return self._seasonal_polynomial(params, -1, lag_operator)

    def get_seasonal_ma_polynomial(self, lag_operator: Optional[Symbol] = None) -> Poly:
        """
        Return Θ(B^m) = 1 + Θ₁B^m + Θ₂B^(2m) + ... + Θ_Q B^(Qm).

        Args:
            lag_operator: Lag-operator symbol; defaults to ``B``.
        """
        if lag_operator is None:
            lag_operator = symbols('B')

        if self.Q == 0:
            return Poly(1, lag_operator)

        params = self.seasonal_ma_params
        if params is None:
            params = [f"Theta_{i+1}" for i in range(self.Q)]
        return self._seasonal_polynomial(params, 1, lag_operator)

    def get_seasonal_difference_polynomial(self, lag_operator: Optional[Symbol] = None) -> Poly:
        """
        Return the seasonal differencing polynomial (1 - B^m)^D.

        Args:
            lag_operator: Lag-operator symbol; defaults to ``B``.
        """
        if lag_operator is None:
            lag_operator = symbols('B')

        if self.D == 0:
            return Poly(1, lag_operator)

        # 1 - B^m as dense coefficients, highest degree first:
        # [-1, 0, ..., 0, 1]. (The reverse order would give B^m - 1.)
        dense = [0] * (self.m + 1)
        dense[0] = -1
        dense[-1] = 1
        return Poly(dense, lag_operator) ** self.D

    def to_dict(self) -> Dict[str, Any]:
        """Return the base summary extended with the seasonal orders and parameters."""
        base_dict = super().to_dict()
        base_dict.update({
            "model_type": "SARIMA",
            "seasonal_parameters": {
                "P": self.P,
                "D": self.D,
                "Q": self.Q,
                "m": self.m
            },
            # Copies, so mutating the result cannot alter the model.
            "seasonal_ar_params": (
                None if self.seasonal_ar_params is None else list(self.seasonal_ar_params)
            ),
            "seasonal_ma_params": (
                None if self.seasonal_ma_params is None else list(self.seasonal_ma_params)
            )
        })
        return base_dict