Coverage for maze_dataset/dataset/maze_dataset_config.py: 24%

118 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-24 00:33 -0600

1"implements `MazeDatasetConfig` which is used to generate or load a dataset" 

2 

3import hashlib 

4import importlib.metadata 

5import json 

6import typing 

7import warnings 

8from typing import Callable 

9 

10import numpy as np 

11from jaxtyping import Float 

12from muutils.json_serialize import ( 

13 serializable_dataclass, 

14 serializable_field, 

15) 

16from muutils.json_serialize.util import ( 

17 safe_getsource, 

18 string_as_lines, 

19) 

20from muutils.misc import sanitize_fname, shorten_numerical_to_str 

21 

22from maze_dataset.constants import Coord, CoordTup 

23from maze_dataset.dataset.dataset import ( 

24 GPTDatasetConfig, 

25) 

26from maze_dataset.dataset.success_predict_math import cfg_success_predict_fn 

27from maze_dataset.generation.generators import _GENERATORS_PERCOLATED, GENERATORS_MAP 

28 

29SERIALIZE_MINIMAL_THRESHOLD: int | None = 100 

30"""If `n_mazes>=SERIALIZE_MINIMAL_THRESHOLD`, then the MazeDataset will use `serialize_minimal`. 

31Setting to None means that `serialize_minimal` will never be used. 

32Set to -1 to make calls to `read` use `MazeDataset._load_legacy`. Used for profiling only.""" 

33 

34MAZEDATASETCONFIG_FNAME_HASH_LENGTH: int = 5 

35"length of the has, in characters, of the hash in the fname of a `MazeDatasetConfig`" 

36 

37_PercolationSuccessArray = Float[ 

38 np.ndarray, 

39 "p/grid_n/deadends/endpoints_not_equal/generator_func=5", 

40] 

41 

42 

43class NoPercolationInConfigError(ValueError): 

44 """raised when trying to predict the success fraction of a config that doesn't have percolation""" 

45 

46 pass 

47 

48 

49class SuccessChanceTooSmallError(ValueError): 

50 """raised when the success fraction is below the threshold in `MazeDatasetConfig.success_fraction_compensate`""" 

51 

52 pass 

53 

54 

55def set_serialize_minimal_threshold(threshold: int | None) -> None: 

56 "get the global SERIALIZE_MINIMAL_THRESHOLD" 

57 global SERIALIZE_MINIMAL_THRESHOLD # noqa: PLW0603 

58 SERIALIZE_MINIMAL_THRESHOLD = threshold 

59 

60 

61def _load_maze_ctor(maze_ctor_serialized: str | dict) -> Callable: 

62 "get the maze constructor from `GENERATORS_MAP`" 

63 if isinstance(maze_ctor_serialized, dict): 

64 # this is both the new and old version of the serialization 

65 return GENERATORS_MAP[maze_ctor_serialized["__name__"]] 

66 elif isinstance(maze_ctor_serialized, str): 

67 # this is a version I switched to for a while but now we are switching back 

68 warnings.warn( 

69 "you are loading an old model/config in `_load_maze_ctor()`!!! this should not be happening, please report: " 

70 "https://github.com/understanding-search/maze-dataset/issues/new", 

71 ) 

72 return GENERATORS_MAP[maze_ctor_serialized] 

73 else: 

74 err_msg: str = f"maze_ctor_serialized is of type {type(maze_ctor_serialized) = }, expected str or dict\n{maze_ctor_serialized = }" 

75 raise TypeError(err_msg) 

76 

77 

78EndpointKwargsType = dict[ 

79 typing.Literal[ 

80 "allowed_start", 

81 "allowed_end", 

82 "deadend_start", 

83 "deadend_end", 

84 "endpoints_not_equal", 

85 "except_on_no_valid_endpoint", 

86 ], 

87 bool | None | list[tuple[int, int]], 

88] 

89"type hint for `MazeDatasetConfig.endpoint_kwargs`" 

90 

91 

92def _load_endpoint_kwargs(data: dict) -> EndpointKwargsType: 

93 if data.get("endpoint_kwargs") is None: 

94 return dict() 

95 

96 else: 

97 return { 

98 k: ( 

99 # bools and Nones are fine 

100 v 

101 if (isinstance(v, bool) or v is None) 

102 # assume its a CoordList 

103 else [tuple(x) for x in v] # muutils/zanj saves tuples as lists 

104 ) 

105 for k, v in data["endpoint_kwargs"].items() 

106 } 

107 

108 

109@serializable_dataclass(kw_only=True, properties_to_serialize=["grid_shape"]) 

110class _MazeDatasetConfig_base(GPTDatasetConfig): # noqa: N801 

111 """base config -- we serialize, dump to json, and hash this to get the fname. all actual variables we want to be hashed are here""" 

112 

113 # NOTE: type: ignore[misc] is because it tells us non-default attributes aren't allowed after ones with defaults, but everything is kw_only 

114 

115 grid_n: int = serializable_field() # type: ignore[misc] 

116 

117 # not comparing n_mazes is done primarily to avoid conflicts which happen during `from_config` when we have applied filters 

118 n_mazes: int = serializable_field(compare=False) # type: ignore[misc] 

119 

120 maze_ctor: Callable = serializable_field( 

121 default=GENERATORS_MAP["gen_dfs"], 

122 serialization_fn=lambda gen_func: { 

123 "__name__": gen_func.__name__, 

124 "__module__": gen_func.__module__, 

125 # NOTE: this was causing hashing issues on 3.13 vs older versions because somehow, 

126 # the `__doc__` variable is different across versions??????? WHY???????? IT TREATS WHITESPACE DIFFERENTLY 

127 # so we just uh. strip it all now. 

128 # see: 

129 # https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080746?pr=53 

130 # https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080742?pr=53 

131 # https://www.diffchecker.com/tqIMSevy/ 

132 # update: we also need to filter for empty lines. B) 

133 "__doc__": [ 

134 line.strip() 

135 for line in string_as_lines(gen_func.__doc__) 

136 if line.strip() 

137 ], 

138 "source_code": safe_getsource(gen_func), 

139 }, 

140 loading_fn=lambda data: _load_maze_ctor(data["maze_ctor"]), 

141 assert_type=False, # TODO: check the type here once muutils supports checking Callable signatures 

142 ) 

143 

144 maze_ctor_kwargs: dict = serializable_field( 

145 default_factory=dict, 

146 serialization_fn=lambda kwargs: kwargs, 

147 loading_fn=lambda data: ( 

148 dict() 

149 if data.get("maze_ctor_kwargs", None) 

150 is None # this should handle the backwards compatibility 

151 else data["maze_ctor_kwargs"] 

152 ), 

153 ) 

154 

155 endpoint_kwargs: EndpointKwargsType = serializable_field( 

156 default_factory=dict, 

157 serialization_fn=lambda kwargs: kwargs, 

158 loading_fn=_load_endpoint_kwargs, 

159 assert_type=False, 

160 ) 

161 

162 # NOTE: this part is very hacky. the way muutils works is that it iterates over the *keys in the serialized data*, 

163 # and so we need to save an `None` here or this wont load the `fname` field on load 

164 # this is a total mess, and very confusing, and entirely my fault 

165 _fname_loaded: str | None = serializable_field( 

166 default=None, 

167 compare=False, 

168 serialization_fn=lambda _: None, 

169 loading_fn=lambda data: data.get("fname", None), 

170 ) 

171 

172 @property 

173 def grid_shape(self) -> CoordTup: 

174 """return the shape of the grid as a tuple""" 

175 return (self.grid_n, self.grid_n) 

176 

177 @property 

178 def grid_shape_np(self) -> Coord: 

179 """return the shape of the grid as a numpy array""" 

180 return np.array(self.grid_shape) 

181 

182 @property 

183 def max_grid_n(self) -> int: 

184 """return the maximum of the grid shape""" 

185 return max(self.grid_shape) 

186 

187 def _serialize_base( 

188 self, applied_filters__skip__collect_generation_meta: bool = True 

189 ) -> dict: 

190 """serialize the base config for user in `stable_hash_cfg()` and `to_fname()` 

191 

192 - note that the _fname_loaded will always be `None` to avoid infinite recursion 

193 - note that we **do not** by default include information about metadata collection here, 

194 since otherwise loading a dataset that we minified by collecting the metadata would be impossible 

195 but for comparing things, we do store it when serializing properly by setting 

196 `applied_filters__skip__collect_generation_meta=False` 

197 """ 

198 serialized: dict = _MazeDatasetConfig_base.serialize(self) 

199 if applied_filters__skip__collect_generation_meta: 

200 serialized["applied_filters"] = [ 

201 x 

202 for x in serialized["applied_filters"] 

203 if x.get("name", None) != "collect_generation_meta" 

204 ] 

205 return serialized 

206 

207 def _stable_str_dump(self) -> str: 

208 return json.dumps( 

209 self._serialize_base(), 

210 sort_keys=True, 

211 indent=None, 

212 ) 

213 

214 def stable_hash_cfg(self) -> int: 

215 """return a stable hash of the config""" 

216 return int.from_bytes( 

217 hashlib.md5( # noqa: S324 

218 bytes(self._stable_str_dump(), "ascii") 

219 ).digest(), 

220 "big", 

221 ) 

222 

223 def to_fname(self) -> str: 

224 """return a unique identifier (valid as a filename) for this config""" 

225 n_mazes_str: str = shorten_numerical_to_str(self.n_mazes) 

226 maze_ctor_name: str = self.maze_ctor.__name__.removeprefix("gen_") 

227 hash_id: int = self.stable_hash_cfg() % 10**MAZEDATASETCONFIG_FNAME_HASH_LENGTH 

228 return sanitize_fname( 

229 f"{self.name}-g{self.grid_n}-n{n_mazes_str}-a_{maze_ctor_name}-h{hash_id}", 

230 ) 

231 

232 

233# NOTE: type: ignore[misc] is because it tells us non-default attributes aren't allowed after ones with defaults, but everything is kw_only 

234@serializable_dataclass(kw_only=True, methods_no_override=["serialize"]) 

235class MazeDatasetConfig(_MazeDatasetConfig_base): # type: ignore[misc] 

236 """config object which is passed to `MazeDataset.from_config` to generate or load a dataset""" 

237 

238 @property 

239 def config_version(self) -> str: 

240 """return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config""" 

241 return "1.0" 

242 

243 @property 

244 def versions(self) -> dict: 

245 """return the versions of the config and the maze_dataset""" 

246 return dict( 

247 config=self.config_version, 

248 maze_dataset=importlib.metadata.version("maze_dataset"), 

249 ) 

250 

251 def serialize(self) -> dict: 

252 "serialize the MazeDatasetConfig with all fields and fname" 

253 return { 

254 **self._serialize_base( 

255 applied_filters__skip__collect_generation_meta=False 

256 ), 

257 "fname": self.to_fname(), 

258 "versions": self.versions, 

259 } 

260 

261 def summary(self) -> dict: 

262 """return a summary of the config""" 

263 # do we run this to make sure it doesn't error? 

264 super_summary: dict = super().summary() 

265 assert super_summary 

266 self_ser: dict = self.serialize() 

267 return dict( 

268 name=self.name, 

269 fname=self.to_fname(), 

270 sdc_hash=self.stable_hash_cfg(), 

271 seed=self.seed, 

272 seq_len_min=self.seq_len_min, 

273 seq_len_max=self.seq_len_max, 

274 applied_filters=self.applied_filters, 

275 grid_n=self_ser["grid_n"], 

276 n_mazes=self_ser["n_mazes"], 

277 maze_ctor_name=self_ser["maze_ctor"]["__name__"], 

278 maze_ctor_kwargs=self_ser["maze_ctor_kwargs"], 

279 endpoint_kwargs=self_ser["endpoint_kwargs"], 

280 ) 

281 

282 def _to_ps_array(self) -> _PercolationSuccessArray: 

283 """Convert this config to a [p, grid_n, deadends, endpoints_not_equal, generator_func] vector. 

284 

285 used in predicting the success rate 

286 """ 

287 try: 

288 assert self.maze_ctor.__name__ in _GENERATORS_PERCOLATED, ( 

289 f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }" 

290 ) 

291 assert "p" in self.maze_ctor_kwargs, ( 

292 f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }" 

293 ) 

294 assert not self.endpoint_kwargs.get("except_on_no_valid_endpoint", True), ( 

295 f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }" 

296 ) 

297 except AssertionError as e: 

298 err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }" 

299 raise NoPercolationInConfigError( 

300 err_msg, 

301 ) from e 

302 

303 endpoints_unique_flag: int = int( 

304 # we are pretty sure it will be an int or bool here 

305 self.endpoint_kwargs.get("endpoints_not_equal", True), # type: ignore[arg-type] 

306 ) 

307 

308 # adjustment for bknutson0 

309 if not ( 

310 self.endpoint_kwargs.get("deadend_start", False) 

311 and self.endpoint_kwargs.get("deadend_end", False) 

312 ): 

313 # we didnt train on this, but if either endpoint is not required to be in a dead end 

314 # then requiring the endpoints to be unique does not really affect the success rate 

315 # (except for very small percolation values, pure percolation generation) 

316 endpoints_unique_flag = 0 

317 

318 return np.array( 

319 [ 

320 float(self.maze_ctor_kwargs["p"]), 

321 float(self.grid_n), 

322 float( 

323 int( 

324 self.endpoint_kwargs.get("deadend_start", False) # type: ignore[arg-type] 

325 or self.endpoint_kwargs.get("deadend_end", False), 

326 ), 

327 ), 

328 float(endpoints_unique_flag), 

329 float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)), 

330 ], 

331 dtype=np.float64, 

332 ) 

333 

334 @classmethod 

335 def _from_ps_array( 

336 cls, 

337 arr: _PercolationSuccessArray, 

338 name: str = "predict", 

339 n_mazes: int = 100, 

340 **kwargs, 

341 ) -> "MazeDatasetConfig": 

342 """Reconstruct a config from an array [p, grid_n, deadends, endpoints_not_equal, generator_func] and other config parameters. 

343 

344 # Returns: 

345 - `MazeDatasetConfig` 

346 Config corresponding to `arr` 

347 """ 

348 return cls( 

349 name=name, 

350 grid_n=int(arr[1]), 

351 n_mazes=n_mazes, 

352 maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]], 

353 maze_ctor_kwargs={"p": float(arr[0])}, 

354 endpoint_kwargs=dict( 

355 deadend_start=bool(arr[2]), 

356 deadend_end=bool(arr[2]), 

357 endpoints_not_equal=bool(arr[3]), 

358 except_on_no_valid_endpoint=False, 

359 ), 

360 **kwargs, 

361 ) 

362 

363 def success_fraction_estimate( 

364 self, 

365 except_if_all_success_expected: bool = False, 

366 ) -> float: 

367 """Estimate the success fraction of this config. 

368 

369 only valid when the generator is a percolation generator, 

370 and endpoints are enforced to be dead ends 

371 

372 this estimate comes from `estimate_dataset_fractions.ipynb` and `maze_dataset.benchmarks.sweep_fit` 

373 

374 # Parameters: 

375 - `except_if_all_success_expected : bool` 

376 if `True`, don't raise an error if the success fraction is below the threshold. 

377 will always return `1.0` if the config is not expected to fail 

378 

379 # Returns: 

380 - `float` 

381 estimated success fraction 

382 

383 # Raises: 

384 - `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `False` 

385 """ 

386 try: 

387 return cfg_success_predict_fn(self) 

388 

389 except NoPercolationInConfigError as e: 

390 if except_if_all_success_expected: 

391 return 1.0 

392 else: 

393 raise e # noqa: TRY201 

394 

395 def success_fraction_compensate( 

396 self, 

397 safety_margin: float = 1.2, 

398 except_if_all_success_expected: bool = False, 

399 epsilon: float = 1e-2, 

400 ) -> "MazeDatasetConfig": 

401 """return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction 

402 

403 # Parameters: 

404 - `safety_margin : float` 

405 safety margin to apply to the success fraction estimate 

406 (defaults to `1.2`, or 20% more mazes than estimated) 

407 - `except_if_all_success_expected : bool` 

408 if `True`, don't raise an error if the success fraction is below the threshold. 

409 this is passed to `MazeDatasetConfig.success_fraction_estimate`. 

410 if your config isn't expected to fail, passing this might mean you generate more mazes than needed 

411 since `safety_margin` is still applied. 

412 (defaults to `False`) 

413 - `epsilon : float` 

414 raise `SuccessChanceTooSmallError` if the success fraction is below this threshold 

415 (defaults to `1e-2`) 

416 

417 # Returns: 

418 - `MazeDatasetConfig` 

419 new config with adjusted `n_mazes` 

420 

421 # Raises: 

422 - `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon` 

423 """ 

424 # compute and check the success fraction 

425 success_fraction: float = self.success_fraction_estimate( 

426 except_if_all_success_expected=except_if_all_success_expected, 

427 ) 

428 if success_fraction < epsilon: 

429 err_msg: str = ( 

430 f"{success_fraction = } is below the threshold of {epsilon = }" 

431 ) 

432 raise SuccessChanceTooSmallError( 

433 err_msg, 

434 ) 

435 

436 # compute the new number of mazes 

437 n_mazes: int = self.n_mazes 

438 new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1 

439 

440 # put it in a new config and return 

441 cfg_dict: dict = self.serialize() 

442 cfg_dict["n_mazes"] = new_n_mazes 

443 return MazeDatasetConfig.load(cfg_dict)