Coverage for maze_dataset/dataset/configs.py: 100%

21 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-24 00:33 -0600

1"`MAZE_DATASET_CONFIGS` contains some default configs for tests and demos" 

2 

3import copy 

4from typing import Iterator, Mapping 

5 

6from maze_dataset.dataset.maze_dataset import MazeDatasetConfig 

7from maze_dataset.generation.generators import LatticeMazeGenerators 

8 

# default configs, keyed by each config's own `to_fname()` string
# NOTE: two configs producing the same fname would silently collapse into one
# entry (the later one wins), so keep the `name`s distinct
_MAZE_DATASET_CONFIGS_SRC: dict[str, MazeDatasetConfig] = {
    default_cfg.to_fname(): default_cfg
    for default_cfg in (
        # "test": 5 mazes, grid_n=3, plain DFS generator
        MazeDatasetConfig(
            name="test",
            grid_n=3,
            n_mazes=5,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        ),
        # "test-perc": same size as "test", DFS-percolation generator with p=0.7
        MazeDatasetConfig(
            name="test-perc",
            grid_n=3,
            n_mazes=5,
            maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
            maze_ctor_kwargs={"p": 0.7},
        ),
        # "demo_small": 100 mazes, grid_n=3, plain DFS generator
        MazeDatasetConfig(
            name="demo_small",
            grid_n=3,
            n_mazes=100,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        ),
        # "demo": 10000 mazes, grid_n=6, plain DFS generator
        MazeDatasetConfig(
            name="demo",
            grid_n=6,
            n_mazes=10000,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        ),
    )
}

39 

40 

class _MazeDatsetConfigsWrapper(Mapping[str, MazeDatasetConfig]):
    """wrap the default configs in a read-only dict-like object

    Item lookup (`wrapper[key]`, and therefore the inherited `get`), `items()`
    and `values()` all hand out deep copies of the stored configs, so callers
    cannot mutate the shared defaults through what they receive.
    `keys()`, `items()` and `values()` return plain lists rather than views.

    NOTE: the "Datset" typo in the class name is kept for backward compatibility.
    """

    def __init__(self, configs: dict[str, MazeDatasetConfig]) -> None:
        """initialize with a dict of configs

        # Parameters:
         - `configs : dict[str, MazeDatasetConfig]`
            mapping of key -> config; stored by reference (not copied), so
            entries added to this dict later are visible through the wrapper
        """
        self._configs = configs

    def __getitem__(self, item: str) -> MazeDatasetConfig:
        """return a deep copy of the config stored under `item`

        # Raises:
         - `KeyError` : if `item` is not a known key
        """
        # copy (like `items()`/`values()`) so that mutating the returned config
        # cannot alter the shared default held by this wrapper
        return copy.deepcopy(self._configs[item])

    def __contains__(self, item: object) -> bool:
        "check membership against the stored keys, without copying any config"
        # `Mapping.__contains__` would go through `__getitem__`, which now
        # deep-copies; test the underlying dict directly instead
        return item in self._configs

    def __len__(self) -> int:
        "return the number of stored configs"
        return len(self._configs)

    def __iter__(self) -> Iterator[str]:
        "iterate over the keys"
        return iter(self._configs)

    def __repr__(self) -> str:
        "show the available keys"
        return f"{type(self).__name__}({list(self._configs)!r})"

    # TYPING: error: Return type "list[str]" of "keys" incompatible with return type "KeysView[str]" in supertype "Mapping" [override]
    def keys(self) -> list[str]:  # type: ignore[override]
        "return the keys, as a new list"
        return list(self._configs.keys())

    # TYPING: error: Return type "list[tuple[str, MazeDatasetConfig]]" of "items" incompatible with return type "ItemsView[str, MazeDatasetConfig]" in supertype "Mapping" [override]
    def items(self) -> list[tuple[str, MazeDatasetConfig]]:  # type: ignore[override]
        "return the items as a list of `(key, deep copy of config)` pairs"
        return [(k, copy.deepcopy(v)) for k, v in self._configs.items()]

    # TYPING: error: Return type "list[MazeDatasetConfig]" of "values" incompatible with return type "ValuesView[MazeDatasetConfig]" in supertype "Mapping" [override]
    def values(self) -> list[MazeDatasetConfig]:  # type: ignore[override]
        "return the values as a list of deep copies of the stored configs"
        return [copy.deepcopy(v) for v in self._configs.values()]

71 

72 

# public, read-only mapping view over the default configs above
# (keys are the configs' `to_fname()` strings)
MAZE_DATASET_CONFIGS: _MazeDatsetConfigsWrapper = _MazeDatsetConfigsWrapper(
    configs=_MAZE_DATASET_CONFIGS_SRC,
)