MultiStepIMPALA
-
class maze.train.trainers.impala.impala_trainer.MultiStepIMPALA(model: maze.core.agent.torch_actor_critic.TorchActorCritic, rollout_actors: maze.train.parallelization.distributed_actors.distributed_actors.BaseDistributedActors, eval_env: Union[maze.train.parallelization.distributed_env.distributed_env.BaseDistributedEnv, maze.core.env.structured_env.StructuredEnv, maze.core.env.structured_env_spaces_mixin.StructuredEnvSpacesMixin, maze.core.log_stats.log_stats_env.LogStatsEnv], options: maze.train.trainers.impala.impala_algorithm_config.ImpalaAlgorithmConfig)
Multi-step IMPALA (Importance Weighted Actor-Learner Architecture) actor-critic trainer.
- Parameters
model – Structured policy to train
rollout_actors – Distributed actors for collection of training rollouts
eval_env – Env to run evaluation on
options – Algorithm options
-
evaluate(deterministic: bool, repeats: int) → None
Perform evaluation on the eval env.
- Parameters
deterministic – if True, select actions deterministically; otherwise sample them stochastically
repeats – number of evaluation episodes to average over
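The difference between the two evaluation modes can be sketched in plain Python. This is a minimal illustration, not Maze code: `select_action`, `evaluate_sketch`, and `toy_episode` are hypothetical names, and the sketch returns the mean reward for clarity, whereas the documented `evaluate` returns None.

```python
import random
from statistics import mean
from typing import Callable, Sequence


def select_action(probs: Sequence[float], deterministic: bool, rng: random.Random) -> int:
    """Pick an action index from a categorical distribution (hypothetical helper)."""
    if deterministic:
        # Deterministic mode: always take the most probable action (argmax policy).
        return max(range(len(probs)), key=lambda i: probs[i])
    # Stochastic mode: sample an action in proportion to its probability.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def evaluate_sketch(run_episode: Callable[[bool], float], deterministic: bool, repeats: int) -> float:
    """Run `repeats` evaluation episodes and average their rewards."""
    return mean(run_episode(deterministic) for _ in range(repeats))


def toy_episode(deterministic: bool, rng: random.Random) -> float:
    """One-step toy episode: the reward depends only on the chosen action."""
    probs = (0.2, 0.5, 0.3)
    rewards = (0.0, 1.0, 2.0)
    return rewards[select_action(probs, deterministic, rng)]
```

Calling `evaluate_sketch(lambda det: toy_episode(det, random.Random(0)), deterministic=True, repeats=3)` always selects the second action in this toy setup, while `deterministic=False` yields a reward that varies with the sampled actions.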
-
load_state_dict(state_dict: Dict) → None
Set the model and optimizer state.
- Parameters
state_dict – The state dict.
-
train(n_epochs: int, epoch_length: int, deterministic_eval: bool, eval_repeats: int, patience: Optional[int], model_selection: Optional[maze.train.trainers.common.model_selection.best_model_selection.BestModelSelection]) → None
Training entry point that wraps the regular training function to ensure all processes are closed properly.
- Parameters
n_epochs – number of epochs to train.
epoch_length – number of updates per epoch.
deterministic_eval – run evaluation in deterministic mode (argmax-policy)
eval_repeats – number of evaluation trials
patience – number of steps used for early stopping
model_selection – Optional model selection class, receives model evaluation results
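The wrapping pattern described for `train` is commonly written with `try`/`finally`. The sketch below shows that idea only; `FakeActors` and `train_with_cleanup` are hypothetical stand-ins, and the actual Maze implementation and its method names may differ.

```python
from typing import Callable


class FakeActors:
    """Hypothetical stand-in for distributed rollout actors (not a Maze class)."""

    def __init__(self) -> None:
        self.running = False

    def start(self) -> None:
        self.running = True

    def stop(self) -> None:
        self.running = False


def train_with_cleanup(actors: FakeActors, train_fn: Callable[[], None]) -> None:
    """Run `train_fn` and stop the actors afterwards, even if training raises."""
    actors.start()
    try:
        train_fn()
    finally:
        # Executed on normal completion and on exceptions alike,
        # so no worker processes are left running.
        actors.stop()
```

With this structure an exception raised inside the training function still propagates to the caller, but only after the actors have been stopped.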
-
train_async(n_epochs: int, epoch_length: int, deterministic_eval: bool, eval_repeats: int, patience: Optional[int], model_selection: Optional[maze.train.trainers.common.model_selection.best_model_selection.BestModelSelection]) → None
Train the policy asynchronously using IMPALA.
- Parameters
n_epochs – number of epochs to train.
epoch_length – number of updates per epoch.
deterministic_eval – run evaluation in deterministic mode (argmax-policy)
eval_repeats – number of evaluation trials
patience – number of steps used for early stopping
model_selection – Optional model selection class, receives model evaluation results
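How `n_epochs`, `epoch_length`, and `patience` interact can be sketched as a plain loop. This is an illustration under stated assumptions, not the Maze implementation: `train_loop_sketch`, `update_fn`, and `eval_fn` are hypothetical, and the sketch assumes that patience counts consecutive evaluations without improvement, which the parameter description above does not spell out.

```python
from typing import Callable, Optional


def train_loop_sketch(
    n_epochs: int,
    epoch_length: int,
    update_fn: Callable[[], None],
    eval_fn: Callable[[], float],
    patience: Optional[int],
) -> int:
    """Run up to `n_epochs` epochs and return how many were actually completed."""
    best_score = float("-inf")
    epochs_without_improvement = 0
    epochs_run = 0
    for _ in range(n_epochs):
        # One epoch consists of `epoch_length` policy updates.
        for _ in range(epoch_length):
            update_fn()
        epochs_run += 1

        # Evaluate once per epoch and track the best score seen so far.
        score = eval_fn()
        if score > best_score:
            best_score = score
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        # Assumed early-stopping rule: stop after `patience` evaluations in a row
        # that fail to improve on the best score. `None` disables early stopping.
        if patience is not None and epochs_without_improvement >= patience:
            break
    return epochs_run
```

In this sketch the total number of updates is at most `n_epochs * epoch_length`, and fewer when early stopping triggers.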