MultiStepIMPALA

class maze.train.trainers.impala.impala_trainer.MultiStepIMPALA(model: maze.core.agent.torch_actor_critic.TorchActorCritic, rollout_actors: maze.train.parallelization.distributed_actors.distributed_actors.BaseDistributedActors, eval_env: Union[maze.train.parallelization.distributed_env.distributed_env.BaseDistributedEnv, maze.core.env.structured_env.StructuredEnv, maze.core.env.structured_env_spaces_mixin.StructuredEnvSpacesMixin, maze.core.log_stats.log_stats_env.LogStatsEnv], options: maze.train.trainers.impala.impala_algorithm_config.ImpalaAlgorithmConfig)

Multi-step IMPALA (Importance Weighted Actor-Learner Architecture) trainer.

Parameters
  • model – Structured policy to train

  • rollout_actors – Distributed actors for collection of training rollouts

  • eval_env – Env to run evaluation on

  • options – Algorithm options

evaluate(deterministic: bool, repeats: int) → None

Perform evaluation on eval env.

Parameters
  • deterministic – whether actions are selected deterministically or sampled stochastically

  • repeats – number of evaluation episodes to average over
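The two evaluation modes can be illustrated with a small sketch that does not use Maze itself. `select_action` and `evaluate_policy` are hypothetical helpers, and the one-step toy reward is an assumption made purely for illustration; the trainer's own evaluation runs full episodes on eval_env.

```python
import random
from typing import Callable, Sequence


def select_action(probs: Sequence[float], deterministic: bool, rng: random.Random) -> int:
    """Pick an action index from a categorical distribution (hypothetical helper)."""
    if deterministic:
        # Argmax policy: always take the most probable action.
        return max(range(len(probs)), key=lambda i: probs[i])
    # Stochastic policy: sample an action according to its probability.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def evaluate_policy(probs: Sequence[float],
                    reward_fn: Callable[[int], float],
                    deterministic: bool,
                    repeats: int,
                    seed: int = 0) -> float:
    """Average the reward over `repeats` one-step toy episodes."""
    rng = random.Random(seed)
    rewards = [reward_fn(select_action(probs, deterministic, rng)) for _ in range(repeats)]
    return sum(rewards) / len(rewards)


probs = [0.1, 0.7, 0.2]


def reward_fn(action: int) -> float:
    # Toy reward (assumption): only action 1 is rewarded.
    return 1.0 if action == 1 else 0.0


deterministic_score = evaluate_policy(probs, reward_fn, deterministic=True, repeats=5)
stochastic_score = evaluate_policy(probs, reward_fn, deterministic=False, repeats=200)
```

In deterministic mode every repeat yields the same action, so a larger `repeats` only matters when the environment itself is stochastic; in stochastic mode more repeats reduce the variance of the averaged result.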

load_state(file_path: Union[str, BinaryIO]) → None

Implementation of Trainer.
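The `Union[str, BinaryIO]` annotation means the state can be given either as a file path or as an already opened binary stream. The sketch below shows one way to handle both cases. `read_state` is a hypothetical helper, and `pickle` is only a stand-in, since the serialization format the trainer actually uses is not described in this section.

```python
import io
import pickle
from typing import Any, BinaryIO, Dict, Union


def read_state(file_path: Union[str, BinaryIO]) -> Dict[str, Any]:
    """Hypothetical helper: read a pickled state dict from a path or a binary stream."""
    if isinstance(file_path, str):
        # A string is interpreted as a path on disk.
        with open(file_path, "rb") as stream:
            return pickle.load(stream)
    # Anything else is treated as an already opened binary stream.
    return pickle.load(file_path)


# Round trip through an in-memory stream instead of a file on disk.
buffer = io.BytesIO()
pickle.dump({"model": {"weight": [0.5, -0.5]}, "optimizer": {"lr": 0.001}}, buffer)
buffer.seek(0)
state = read_state(buffer)
```

Accepting a stream as well as a path makes it possible to restore state from memory or from a network source without writing a temporary file.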

load_state_dict(state_dict: Dict) → None

Set the model and optimizer state.

Parameters
  • state_dict – The state dict.

train(n_epochs: int, epoch_length: int, deterministic_eval: bool, eval_repeats: int, patience: Optional[int], model_selection: Optional[maze.train.trainers.common.model_selection.best_model_selection.BestModelSelection]) → None

Wraps the regular training function so that all processes are closed properly.

Parameters
  • n_epochs – number of epochs to train.

  • epoch_length – number of updates per epoch.

  • deterministic_eval – run evaluation in deterministic mode (argmax-policy)

  • eval_repeats – number of evaluation trials

  • patience – number of steps used for early stopping

  • model_selection – Optional model selection class, receives model evaluation results
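The wrapping described for train can be sketched as a try/finally guard around the actual training call. This is a generic illustration and not Maze's implementation: `ToyWorkerPool` and `train_with_cleanup` are hypothetical names, and the sketch assumes the worker processes expose a single `stop()` method.

```python
from typing import Callable


class ToyWorkerPool:
    """Hypothetical stand-in for a set of rollout worker processes."""

    def __init__(self) -> None:
        self.running = True

    def stop(self) -> None:
        # A real pool would terminate and join its processes here.
        self.running = False


def train_with_cleanup(train_fn: Callable[[], None], workers: ToyWorkerPool) -> None:
    """Run `train_fn` and stop the workers afterwards, even if training raises."""
    try:
        train_fn()
    finally:
        workers.stop()


def failing_train() -> None:
    raise RuntimeError("simulated failure during training")


workers = ToyWorkerPool()
caught = None
try:
    train_with_cleanup(failing_train, workers)
except RuntimeError as error:
    # The exception still propagates, but only after the workers were stopped.
    caught = str(error)
```

The `finally` clause ensures the cleanup runs on normal completion, on exceptions, and on keyboard interrupts, so no orphaned worker processes are left behind.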

train_async(n_epochs: int, epoch_length: int, deterministic_eval: bool, eval_repeats: int, patience: Optional[int], model_selection: Optional[maze.train.trainers.common.model_selection.best_model_selection.BestModelSelection]) → None

Train the policy asynchronously using IMPALA.

Parameters
  • n_epochs – number of epochs to train.

  • epoch_length – number of updates per epoch.

  • deterministic_eval – run evaluation in deterministic mode (argmax-policy)

  • eval_repeats – number of evaluation trials

  • patience – number of steps used for early stopping

  • model_selection – Optional model selection class, receives model evaluation results
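How n_epochs, epoch_length and patience interact can be sketched with a plain training loop. This is an illustration under stated assumptions, not the trainer's code: `run_epochs`, `update_fn` and `eval_fn` are hypothetical, and the sketch assumes that patience counts consecutive evaluations without improvement and that `None` disables early stopping.

```python
from typing import Callable, List, Optional


def run_epochs(n_epochs: int,
               epoch_length: int,
               update_fn: Callable[[], None],
               eval_fn: Callable[[int], float],
               patience: Optional[int]) -> List[float]:
    """Run up to `n_epochs` epochs of `epoch_length` updates with optional early stopping."""
    scores: List[float] = []
    best_score = float("-inf")
    epochs_without_improvement = 0
    for epoch in range(n_epochs):
        # One epoch consists of `epoch_length` policy updates.
        for _ in range(epoch_length):
            update_fn()
        # Evaluate once per epoch and track the best score seen so far.
        score = eval_fn(epoch)
        scores.append(score)
        if score > best_score:
            best_score = score
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        # Assumption: stop after `patience` evaluations without improvement.
        if patience is not None and epochs_without_improvement >= patience:
            break
    return scores


counter = {"updates": 0}


def count_update() -> None:
    counter["updates"] += 1


toy_scores = [1.0, 2.0, 1.5, 1.8, 1.9, 3.0]
history = run_epochs(n_epochs=6, epoch_length=4, update_fn=count_update,
                     eval_fn=lambda epoch: toy_scores[epoch], patience=2)
```

With these toy scores the loop stops after the fourth epoch, because the third and fourth evaluations both fail to beat the best score of 2.0; 16 of the 24 possible updates are executed.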