In [1]:
from pathlib import Path
import numpy as np
# import pysr before torch to avoid
# UserWarning: torch was imported before juliacall. This may cause a segfault. To avoid this, import juliacall before importing torch. For updates, see https://github.com/pytorch/pytorch/issues/78829.
import pysr # noqa: F401
from zanj import ZANJ
from maze_dataset import LatticeMazeGenerators, MazeDataset, MazeDatasetConfig
from maze_dataset.benchmark.config_sweep import (
SweepResult,
dataset_success_fraction,
full_percolation_analysis,
plot_grouped,
)
from maze_dataset.benchmark.sweep_fit import sweep_fit
Detected IPython. Loading juliacall extension. See https://juliapy.github.io/PythonCall.jl/stable/compat/#IPython
run a basic analysis¶
In [2]:
# Run the analysis
results: SweepResult = SweepResult.analyze(
configs=[
MazeDatasetConfig(
name=f"g{grid_n}-perc",
grid_n=grid_n,
n_mazes=32,
maze_ctor=LatticeMazeGenerators.gen_percolation,
maze_ctor_kwargs=dict(),
endpoint_kwargs=dict(
deadend_start=False,
deadend_end=False,
endpoints_not_equal=False,
except_on_no_valid_endpoint=False,
),
)
for grid_n in [2, 4, 6]
],
param_values=np.linspace(0.0, 1.0, 16).tolist(),
param_key="maze_ctor_kwargs.p",
analyze_func=dataset_success_fraction,
parallel=False,
)
# Plot results
results.plot(save_path=None, cfg_keys=["n_mazes", "endpoint_kwargs"])
tqdm_allowed_kwargs = {'write_bytes', 'file', 'desc', 'unit_divisor', 'leave', 'position', 'colour', 'nrows', 'mininterval', 'dynamic_ncols', 'bar_format', 'disable', 'smoothing', 'postfix', 'ascii', 'self', 'lock_args', 'initial', 'unit', 'delay', 'gui', 'total', 'maxinterval', 'miniters', 'iterable', 'unit_scale', 'ncols'} mapped_kwargs = {'total': 3, 'desc': 'Processing 3 items'}
Processing 3 items: 100%|██████████| 3/3 [00:00<00:00, 3.09it/s]
Out[2]:
<Axes: title={'center': "maze_ctor_kwargs.p vs dataset_success_fraction\nMazeDatasetConfig(n_mazes=32, endpoint_kwargs={'deadend_start': False, 'deadend_end': False, 'endpoints_not_equal': False, 'except_on_no_valid_endpoint': False})"}, xlabel='maze_ctor_kwargs.p', ylabel='dataset_success_fraction'>
check saving/loading¶
In [3]:
path = Path("../tests/_temp/dataset_frac_sweep/results_small.zanj")
results.save(path)
ZANJ().read(path).plot(cfg_keys=["n_mazes", "endpoint_kwargs"])
Out[3]:
<Axes: title={'center': "maze_ctor_kwargs.p vs dataset_success_fraction\nMazeDatasetConfig(n_mazes=32, endpoint_kwargs={'deadend_start': False, 'deadend_end': False, 'endpoints_not_equal': False, 'except_on_no_valid_endpoint': False})"}, xlabel='maze_ctor_kwargs.p', ylabel='dataset_success_fraction'>
sweep acrossall endpoint kwargs and generator funcs¶
In [4]:
results_sweep: SweepResult = full_percolation_analysis(
n_mazes=16,
p_val_count=11,
grid_sizes=[2, 4, 6],
parallel=False,
save_dir=Path("tests/_temp/dataset_frac_sweep"),
)
tqdm_allowed_kwargs = {'write_bytes', 'file', 'desc', 'unit_divisor', 'leave', 'position', 'colour', 'nrows', 'mininterval', 'dynamic_ncols', 'bar_format', 'disable', 'smoothing', 'postfix', 'ascii', 'self', 'lock_args', 'initial', 'unit', 'delay', 'gui', 'total', 'maxinterval', 'miniters', 'iterable', 'unit_scale', 'ncols'} mapped_kwargs = {'total': 18, 'desc': 'Processing 18 items'}
Processing 18 items: 100%|██████████| 18/18 [00:05<00:00, 3.54it/s]
Saving results to tests/_temp/dataset_frac_sweep/result-n16-c18-p11.zanj
In [5]:
results_medium: SweepResult = SweepResult.read(
"../docs/benchmarks/percolation_fractions/medium/result-n128-c42-p50.zanj",
# "../docs/benchmarks/percolation_fractions/large/result-n256-c54-p100.zanj"
)
In [6]:
plot_grouped(
results_medium,
predict_fn=lambda x: x.success_fraction_estimate(),
prediction_density=100,
)
perform a pysr regression on a dataset we load¶
In [7]:
DATA_PATH_DIR: Path = Path("../docs/benchmarks/percolation_fractions/")
# DATA_PATH: str = DATA_PATH_DIR / "large/result-n256-c54-p100.zanj"
# DATA_PATH: str = DATA_PATH_DIR / "medium/result-n128-c42-p50.zanj"
DATA_PATH: str = DATA_PATH_DIR / "small/result-n64-c30-p25.zanj"
# DATA_PATH: str = DATA_PATH_DIR / "test/result-n16-c12-p16.zanj"
sweep_fit(
DATA_PATH,
Path("tests/_temp/fit_plots/"),
niterations=3,
)
loaded data: data.summary() = {'len(configs)': 30, 'len(param_values)': 25, 'len(result_values)': 30, 'param_key': 'maze_ctor_kwargs.p', 'analyze_func': 'dataset_success_fraction'} training data extracted: x.shape = (750, 5), y.shape = (750,) Compiling Julia backend...
/home/miv/projects/mazes/maze-dataset/.venv/lib/python3.12/site-packages/pysr/sr.py:2780: UserWarning: Note: it looks like you are running in Jupyter. The progress bar will be turned off. warnings.warn( /home/miv/projects/mazes/maze-dataset/.venv/lib/python3.12/site-packages/pysr/sr.py:84: UserWarning: You are using the `^` operator, but have not set up `constraints` for it. This may lead to overly complex expressions. One typical constraint is to use `constraints={..., '^': (-1, 1)}`, which will allow arbitrary-complexity base (-1) but only powers such as a constant or variable (1). For more tips, please see https://ai.damtp.cam.ac.uk/pysr/tuning/ warnings.warn( [ Info: Started! [ Info: Final population: [ Info: Results saved to:
─────────────────────────────────────────────────────────────────────────────────────────────────── Complexity Loss Score Equation 1 1.163e-01 1.594e+01 y = 0.64637 3 1.025e-01 6.322e-02 y = 0.61868 ^ x₂ 4 9.422e-02 8.405e-02 y = 1.3004 - sigmoid(x₂) 5 9.202e-02 2.358e-02 y = -0.13331 + (0.66953 ^ x₂) 6 9.202e-02 -0.000e+00 y = (0.66953 ^ square(x₂)) + -0.13331 7 8.721e-02 5.365e-02 y = square(sigmoid(1.5228 - x₀) ^ x₂) 8 8.338e-02 4.497e-02 y = sigmoid((x₁ * (0.016938 ^ x₂)) ^ x₀) 11 8.066e-02 1.107e-02 y = sigmoid(exp(cube((x₀ / (-0.49991 + square(x₂))) / -1.0... 781))) 12 6.949e-02 1.489e-01 y = (exp(cube(x₄ * (x₀ * -1.6204))) * sigmoid(x₄)) ^ x₂ 14 6.808e-02 1.032e-02 y = (exp(x₀ * (x₄ * cube(x₀ * -1.6204))) * sigmoid(x₄)) ^ ... x₂ 19 6.313e-02 1.509e-02 y = (sigmoid(2.5487 - (exp(x₄ + (sigmoid(x₁ - -0.082637) ^... -0.97417)) ^ square(x₀))) ^ x₂) + -0.11761 ─────────────────────────────────────────────────────────────────────────────────────────────────── Equations saved to: equations_file = PosixPath('tests/_temp/fit_plots/equations.txt') Best PySR Equation: model.get_best()['equation'] = '(exp(cube(x4 * (x0 * -1.6203635))) * sigmoid(x4)) ^ x2' predict_fn =PySRFunction(X=>(exp(-4.25439055041077*x0**3*x4**3)/(1 + exp(-x4)))**x2) Saving plot to tests/_temp/fit_plots/ep_any.svg Saving plot to tests/_temp/fit_plots/ep_deadends.svg Saving plot to tests/_temp/fit_plots/ep_deadends_unique.svg - outputs/20250311_015932_WZSLJY/hall_of_fame.csv
interactive plots for figuring out maze_dataset.math.soft_step()
¶
In [ ]:
# Run the interactive visualization if in a Jupyter notebook
if "__vsc_ipynb_file__" in globals():
from maze_dataset.benchmark.sweep_fit import create_interactive_plot
create_interactive_plot(True)
VBox(children=(VBox(children=(Label(value='Adjust parameters:'), HBox(children=(FloatSlider(value=0.5, descrip…
In [9]:
cfg = MazeDatasetConfig(
name="test",
seed=3,
grid_n=5,
n_mazes=10,
maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
maze_ctor_kwargs=dict(p=0.7),
endpoint_kwargs=dict(
deadend_start=True,
# deadend_end=True,
endpoints_not_equal=True,
except_on_no_valid_endpoint=False,
),
)
print(f"{cfg.success_fraction_estimate() = }")
cfg_new = cfg.success_fraction_compensate()
print(f"{cfg_new.n_mazes = }")
/home/miv/projects/mazes/maze-dataset/maze_dataset/dataset/dataset.py:95: UserWarning: in GPTDatasetConfig self.name='test', self.seed=3 is trying to override muutils.mlutils.GLOBAL_SEED=42 which has already been changed elsewhere from muutils.mlutils.DEFAULT_SEED=42 warnings.warn(
cfg.success_fraction_estimate() = np.float64(0.45037493871086454) cfg_new.n_mazes = 27
In [10]:
len(MazeDataset.from_config(cfg_new))
Out[10]:
12