ESTrainer
-
class maze.train.trainers.es.es_trainer.ESTrainer(algorithm_config: maze.train.trainers.es.es_algorithm_config.ESAlgorithmConfig, policy: maze.core.agent.torch_policy.TorchPolicy, shared_noise: maze.train.trainers.es.es_shared_noise_table.SharedNoiseTable, normalization_stats: Optional[Dict[str, Tuple[numpy.ndarray, numpy.ndarray]]])
Trainer class for OpenAI Evolution Strategies.
- Parameters
algorithm_config – Algorithm parameters.
policy – Multi-step policy encapsulating the policy networks.
shared_noise – The noise table, with the same content for every worker and the master.
normalization_stats – Normalization statistics as calculated by the NormalizeObservationWrapper.
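The idea behind a noise table "with the same content for every worker and the master" can be sketched in plain Python. This is an illustrative stand-in, not the Maze `SharedNoiseTable` API: the class `ToyNoiseTable`, its methods, and the seed and size values are all hypothetical. It assumes the table is filled from a fixed seed, so that any process can reproduce a perturbation from an integer offset alone.

```python
import random
from typing import List


class ToyNoiseTable:
    """Hypothetical stand-in for a shared noise table (not the Maze class)."""

    def __init__(self, count: int, seed: int) -> None:
        # A fixed seed yields identical table content in every process.
        rng = random.Random(seed)
        self.noise: List[float] = [rng.gauss(0.0, 1.0) for _ in range(count)]

    def get(self, index: int, dim: int) -> List[float]:
        """Return the perturbation of length `dim` that starts at `index`."""
        return self.noise[index:index + dim]

    def sample_index(self, rng: random.Random, dim: int) -> int:
        """Pick a start offset that leaves room for `dim` entries."""
        return rng.randrange(0, len(self.noise) - dim + 1)


# Master and worker build their tables independently from the same seed.
master_table = ToyNoiseTable(count=10_000, seed=123)
worker_table = ToyNoiseTable(count=10_000, seed=123)

# A worker draws an offset and reports only that integer to the master.
index = worker_table.sample_index(random.Random(0), dim=8)
master_view = master_table.get(index, dim=8)
worker_view = worker_table.get(index, dim=8)
```

Because both sides hold the same table, exchanging `index` is enough for the master to recover the exact perturbation a worker evaluated, which keeps communication small compared with sending full parameter vectors.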
-
load_state_dict(state_dict: Dict) → None
Set the model and optimizer state.
- Parameters
state_dict – The state dict.
-
train(distributed_rollouts: maze.train.trainers.es.distributed.es_distributed_rollouts.ESDistributedRollouts, model_selection: Optional[maze.train.trainers.common.model_selection.model_selection_base.ModelSelectionBase]) → None
Run the ES training loop.
- Parameters
distributed_rollouts – The distribution interface for experience collection.
model_selection – Optional model selection class, receives model evaluation results.
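For orientation, the core of an evolution-strategies update can be sketched as below. This is a textbook version with mirrored (antithetic) perturbations on a toy objective, not the body of `ESTrainer.train`: it leaves out distributed rollouts, model selection, and observation normalization, and the names `es_step`, `toy_fitness`, and `TARGET`, as well as all hyperparameter values, are assumptions made for illustration.

```python
import random
from typing import Callable, List, Sequence


def es_step(theta: Sequence[float],
            fitness: Callable[[Sequence[float]], float],
            rng: random.Random,
            population: int = 20,
            sigma: float = 0.1,
            learning_rate: float = 0.05) -> List[float]:
    """One textbook ES update using mirrored Gaussian perturbations."""
    grad = [0.0] * len(theta)
    for _ in range(population):
        eps = [rng.gauss(0.0, 1.0) for _ in theta]
        # Evaluate the same perturbation in both directions.
        f_plus = fitness([t + sigma * e for t, e in zip(theta, eps)])
        f_minus = fitness([t - sigma * e for t, e in zip(theta, eps)])
        for i, e in enumerate(eps):
            grad[i] += (f_plus - f_minus) * e
    # Gradient estimate: sum((f_plus - f_minus) * eps) / (2 * population * sigma).
    scale = learning_rate / (2.0 * population * sigma)
    return [t + scale * g for t, g in zip(theta, grad)]


TARGET = [1.0, -2.0, 0.5]  # hypothetical optimum of the toy objective


def toy_fitness(params: Sequence[float]) -> float:
    """Negative squared distance to TARGET; higher is better."""
    return -sum((p - t) ** 2 for p, t in zip(params, TARGET))


rng = random.Random(0)
initial_params = [0.0, 0.0, 0.0]
params = list(initial_params)
for _ in range(200):
    params = es_step(params, toy_fitness, rng)
```

In a distributed setting, the fitness evaluations inside the loop are the part that would be farmed out to workers, while the master aggregates the results and applies the update.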