TrainingRunner

class maze.train.trainers.common.training_runner.TrainingRunner(state_dict_dump_file: str, spaces_config_dump_file: str, normalization_samples: int)

Base class for training runner implementations.

normalization_samples: int

Number of samples (i.e., environment steps) collected at the beginning of training to estimate normalization statistics.
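To illustrate what estimating normalization statistics from a fixed number of steps involves, the sketch below computes a per-dimension mean and (population) standard deviation from `normalization_samples` observations and uses them to normalize an observation. This is a minimal sketch, not Maze's implementation: `estimate_normalization_stats`, `normalize` and the `sample_observation` callable (which stands in for stepping an environment) are hypothetical, and observations are assumed to be flat lists of floats.

```python
import math
from typing import Callable, List, Tuple


def estimate_normalization_stats(
    sample_observation: Callable[[], List[float]],
    normalization_samples: int,
) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: per-dimension mean and std over a fixed number of steps."""
    if normalization_samples < 1:
        raise ValueError("normalization_samples must be >= 1")
    # Collect exactly `normalization_samples` observations (one per step).
    observations = [sample_observation() for _ in range(normalization_samples)]
    dims = len(observations[0])
    mean = [sum(obs[d] for obs in observations) / normalization_samples for d in range(dims)]
    # Population standard deviation per dimension.
    std = [
        math.sqrt(sum((obs[d] - mean[d]) ** 2 for obs in observations) / normalization_samples)
        for d in range(dims)
    ]
    return mean, std


def normalize(obs: List[float], mean: List[float], std: List[float], eps: float = 1e-8) -> List[float]:
    """Hypothetical helper: (x - mean) / std, flooring std at eps to avoid division by zero."""
    return [(x - m) / max(s, eps) for x, m, s in zip(obs, mean, std)]
```

For example, `estimate_normalization_stats(lambda: [random.gauss(0.0, 1.0)], 1000)` would return a mean close to 0 and a standard deviation close to 1 for a one-dimensional observation. Flooring the standard deviation keeps constant dimensions from producing a division by zero.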

run(cfg: omegaconf.DictConfig) → None

While this method is designed to be overridden by individual subclasses, it provides the following functionality, which is useful for most training runners:

  • Building the env factory (environment plus wrappers)

  • Estimating normalization statistics from the env

  • If the statistics were estimated successfully, wrapping the env factory so that envs are built with the statistics already applied

  • Building the model composer from model config and env spaces config

  • Serializing the env spaces configuration (so that the model composer can be reloaded for later rollouts)

  • Setting up logging
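The split described above, where the base `run` performs shared setup and subclasses extend it, can be sketched as a template-method pattern. The classes below are hypothetical stand-ins, not Maze classes: `cfg` is a plain dict instead of an `omegaconf.DictConfig` so the sketch stays self-contained, each setup step only records its name, and treating zero samples as "no statistics" is an assumption made for illustration.

```python
from typing import Any, Dict, List, Optional, Tuple


class TrainingRunnerSketch:
    """Hypothetical stand-in for a base training runner (not a Maze class)."""

    def __init__(self, normalization_samples: int) -> None:
        self.normalization_samples = normalization_samples
        self.steps_done: List[str] = []

    def _estimate_normalization_stats(self) -> Optional[Tuple[float, float]]:
        # Assumption for this sketch: zero samples means no statistics are estimated.
        if self.normalization_samples <= 0:
            return None
        self.steps_done.append("estimate_normalization_stats")
        return (0.0, 1.0)  # placeholder (mean, std)

    def run(self, cfg: Dict[str, Any]) -> None:
        # Common setup shared by all subclasses; each step is only recorded by name.
        self.steps_done.append("build_env_factory")
        stats = self._estimate_normalization_stats()
        if stats is not None:
            # Only wrap the env factory if statistics were actually estimated.
            self.steps_done.append("wrap_env_factory_with_stats")
        self.steps_done.append("build_model_composer")
        self.steps_done.append("dump_spaces_config")
        self.steps_done.append("setup_logging")


class MyTrainingRunner(TrainingRunnerSketch):
    """Hypothetical subclass: reuses the shared setup, then adds its own training."""

    def run(self, cfg: Dict[str, Any]) -> None:
        super().run(cfg)  # common setup first
        self.steps_done.append(f"train_for_{cfg['n_epochs']}_epochs")
```

Calling `MyTrainingRunner(normalization_samples=100).run({"n_epochs": 3})` records the shared setup steps followed by `train_for_3_epochs`; calling `super().run(cfg)` first is what lets each subclass reuse the common setup without duplicating it.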

Parameters

cfg – Full Hydra job config of the run

spaces_config_dump_file: str

Path of the file to which the env spaces configuration is saved (the output directory is handled by Hydra).
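A minimal sketch of dumping and re-loading a spaces configuration follows. It assumes the configuration is a picklable dict and that a relative `spaces_config_dump_file` resolves against the current working directory, which Hydra may set to the run's output directory depending on its configuration. The helper names and the dict layout are hypothetical, not Maze's serialization format.

```python
import pickle
from pathlib import Path
from typing import Any, Dict


def dump_spaces_config(spaces_config: Dict[str, Any], spaces_config_dump_file: str) -> Path:
    """Hypothetical helper: pickle the spaces config; relative paths resolve against the CWD."""
    path = Path(spaces_config_dump_file)
    with path.open("wb") as f:
        pickle.dump(spaces_config, f)
    return path.resolve()


def load_spaces_config(spaces_config_dump_file: str) -> Dict[str, Any]:
    """Hypothetical helper: re-load the spaces config, e.g. to rebuild a model composer for rollout."""
    with Path(spaces_config_dump_file).open("rb") as f:
        return pickle.load(f)
```

Because only the spaces configuration (not a live environment) is stored, a later rollout can rebuild the model composer from this file without re-instantiating the training environment.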

state_dict_dump_file: str

Path of the file to which the state of the best model is saved (the output directory is handled by Hydra).
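Saving "the best model" implies comparing each new evaluation result against the best one seen so far and overwriting the dump file only on improvement. The sketch below shows that idea with a hypothetical `BestModelSaver` class; it uses `pickle` to stay dependency-free, whereas a PyTorch model's `state_dict()` would typically be written with `torch.save`. Treating a higher score as better is an assumption.

```python
import pickle
from pathlib import Path
from typing import Any, Dict, Optional


class BestModelSaver:
    """Hypothetical helper: overwrite the dump file whenever the evaluation score improves."""

    def __init__(self, state_dict_dump_file: str) -> None:
        self.path = Path(state_dict_dump_file)
        self.best_score: Optional[float] = None

    def maybe_save(self, state_dict: Dict[str, Any], score: float) -> bool:
        """Save `state_dict` if `score` beats the best so far (higher is better); return True if saved."""
        if self.best_score is not None and score <= self.best_score:
            return False
        self.best_score = score
        with self.path.open("wb") as f:
            pickle.dump(state_dict, f)
        return True
```

Keeping a single file that is overwritten on improvement means the dump file always holds the best model seen so far, even if training is interrupted.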