Offline next-frame prediction world model using Gymnasium and PyTorch. This repo provides a full runnable scaffold with dataset generation, training, and evaluation (including videos and plots). The current environment setup uses Procgen, but core modeling logic is environment-agnostic.
```sh
./scripts/setup_venv.sh
source .venv/bin/activate
```

`scripts/setup_venv.sh` installs core + dev dependencies, plus Procgen when the Python version is compatible (Procgen currently requires Python <3.11).

Later, you can activate with:

```sh
./scripts/activate_venv.sh
```

Generate dataset (random policy):
```sh
python -m src.dataset.generate_dataset --game coinrun --steps 300000 --out_dir data/coinrun
```

Note: CoinRun/Procgen commands require Procgen support (Python <3.11).
Train (uses `config.yaml`):

```sh
python -m src.train
```

Override config values with dot notation:

```sh
python -m src.train optimizer.lr=1e-4 train.max_steps=5
```

Evaluate:
```sh
python -m src.eval --checkpoint runs/.../ckpt.pt --game coinrun
```

Enable W&B logging for training and evaluation:
```sh
wandb login
python -m src.train
python -m src.eval --checkpoint runs/.../ckpt.pt --game coinrun
```

To disable logging:
```sh
python -m src.train experiment.wandb.mode=disabled
python -m src.eval --checkpoint runs/.../ckpt.pt --game coinrun --wandb_mode disabled
```

- Dataset shards: `data/coinrun/shard_*_{frames,actions,done}.npy`
- Manifest: `data/coinrun/manifest.json`
- Training run: `runs/<timestamp>_<game>/metrics.csv`, `runs/<timestamp>_<game>/val_metrics.csv`, `runs/<timestamp>_<game>/images/`, `runs/<timestamp>_<game>/videos/`, `runs/<timestamp>_<game>/checkpoints/`, `runs/<timestamp>_<game>/resolved_config.yaml`
- Eval artifacts: `runs/<timestamp>_<game>_eval/videos/` (MP4 + GIF), `runs/<timestamp>_<game>_eval/plots/mse_vs_horizon.png`
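The shard files can be inspected directly with NumPy. A minimal sketch, assuming each frames shard is a `(N, 84, 84)` uint8 array (the exact dtype and per-shard layout may differ from the real generator); a throwaway shard is written here so the snippet is self-contained:

```python
import os
import tempfile

import numpy as np

# Stand-in shard mimicking the on-disk naming above; real shards live under
# data/coinrun/ and are produced by src.dataset.generate_dataset.
tmp = tempfile.mkdtemp()
np.save(os.path.join(tmp, "shard_000_frames.npy"),
        np.zeros((100, 84, 84), dtype=np.uint8))   # assumed frame layout/dtype
np.save(os.path.join(tmp, "shard_000_actions.npy"),
        np.zeros((100,), dtype=np.int64))
np.save(os.path.join(tmp, "shard_000_done.npy"),
        np.zeros((100,), dtype=bool))

# Memory-map so random access reads only the requested sample from disk.
frames = np.load(os.path.join(tmp, "shard_000_frames.npy"), mmap_mode="r")
sample = np.asarray(frames[3], dtype=np.float32) / 255.0  # one frame, scaled to [0, 1]
```

Memory-mapping keeps resident memory small even for multi-gigabyte shards, which is why the DataLoader can pull single samples at random.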
- Inputs are `(B, 4, 84, 84)` float32 in `[0, 1]`; actions are `(B,)` int64.
- Targets are `(B, 1, 84, 84)` float32 in `[0, 1]`.
- Offline dataset shards are memory-mapped `.npy` files; random access pulls one sample at a time.
- DataLoader settings are explicit in `config.yaml` (`num_workers`, `prefetch_factor`, `persistent_workers`, `pin_memory`); set `prefetch_factor: null` when `num_workers: 0`.
- Data config is validated at startup and will raise if values are inconsistent.
- Train/val split is controlled by `data.val_ratio`, and validation runs every `train.val_every_steps`.
- At `train.log_every`, an image strip of the 4 input frames plus the predicted next frame is saved (and logged to W&B if enabled).
- Each validation run can optionally trigger an open-loop rollout video (left = ground truth, right = prediction), controlled by `train.val_rollout_*`.
- Default loss: Huber. Use `train.loss=mse` for MSE.
- CPU-only execution is supported via `train.cpu=true`.
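The input/target shapes above come from a sliding window over consecutive frames. A minimal sketch of one sample, assuming frames are stored as `(N, 84, 84)` uint8 and that the action paired with a window is the one taken at its last input frame (the real loader also respects episode boundaries via the `done` array, omitted here):

```python
import numpy as np

def make_sample(frames: np.ndarray, actions: np.ndarray, i: int):
    """Build one (input, action, target) triple from raw shard arrays.

    frames: (N, 84, 84) uint8, actions: (N,) int64 (assumed layout).
    """
    x = frames[i : i + 4].astype(np.float32) / 255.0      # (4, 84, 84), float32 in [0, 1]
    a = np.int64(actions[i + 3])                          # action at last input frame (assumed convention)
    y = frames[i + 4 : i + 5].astype(np.float32) / 255.0  # (1, 84, 84) next-frame target
    return x, a, y

frames = np.random.randint(0, 256, size=(10, 84, 84), dtype=np.uint8)
actions = np.random.randint(0, 15, size=(10,), dtype=np.int64)
x, a, y = make_sample(frames, actions, 0)
```

Batching such samples through a DataLoader yields the `(B, 4, 84, 84)` inputs, `(B,)` actions, and `(B, 1, 84, 84)` targets described above.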
- `src/models/world_model.py`: action conditioning and model variants.
- `src/eval.py`: rollout logic, horizon aggregation, and metrics.
- `src/utils/metrics.py`: add additional metrics or perceptual losses.
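To give a sense of the horizon aggregation behind `mse_vs_horizon.png`: a rough sketch, assuming predictions and ground truth are collected as `(num_rollouts, horizon, 1, 84, 84)` float32 arrays (the actual layout in `src/eval.py` may differ):

```python
import numpy as np

def mse_vs_horizon(pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
    """MSE per open-loop rollout step.

    pred, gt: (num_rollouts, horizon, 1, 84, 84) float32 in [0, 1] (assumed).
    Returns a (horizon,) curve: squared error averaged over rollouts and pixels.
    """
    err = (pred - gt) ** 2
    return err.mean(axis=(0, 2, 3, 4))  # keep only the horizon axis

# Toy check: constant offset of 0.5 gives MSE 0.25 at every step.
pred = np.zeros((2, 5, 1, 84, 84), dtype=np.float32)
gt = np.full((2, 5, 1, 84, 84), 0.5, dtype=np.float32)
curve = mse_vs_horizon(pred, gt)  # shape (5,)
```

For a trained model the curve typically rises with horizon, since open-loop prediction errors compound step over step.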