TorchPolicy

class maze.core.agent.torch_policy.TorchPolicy(networks: Mapping[Union[str, int], torch.nn.Module], distribution_mapper: maze.distributions.distribution_mapper.DistributionMapper, device: str)

Encapsulates multiple torch policies along with a distribution mapper for training and rollouts in structured environments.
Parameters

    networks – Mapping of policy networks to encapsulate.
    distribution_mapper – Distribution mapper associated with the policy mapping.
    device – Device the policy should be located on ("cpu" or "cuda").
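A minimal construction sketch is shown below. The network layout, the dict action space, and the DistributionMapper arguments are illustrative assumptions; consult the DistributionMapper documentation for its actual configuration options.

    import torch.nn as nn
    from gym import spaces

    from maze.core.agent.torch_policy import TorchPolicy
    from maze.distributions.distribution_mapper import DistributionMapper


    class PolicyNet(nn.Module):
        """Hypothetical policy network: consumes a dict observation and returns
        a dict of logits, one entry per action head (a single head here)."""

        def __init__(self):
            super().__init__()
            self.body = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

        def forward(self, obs):
            return {"action": self.body(obs["observation"])}


    # Assumed dict action space; an empty mapper config selects the default
    # probability distribution for each action space entry.
    action_space = spaces.Dict({"action": spaces.Discrete(4)})
    distribution_mapper = DistributionMapper(action_space=action_space,
                                             distribution_mapper_config=[])

    # One network per policy ID (sub-step) of the structured environment.
    policy = TorchPolicy(networks={0: PolicyNet()},
                         distribution_mapper=distribution_mapper,
                         device="cpu")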
compute_action(observation: Dict[str, numpy.ndarray], maze_state: Optional[Any] = None, policy_id: Union[str, int] = None, deterministic: bool = False) → Dict[str, Union[int, numpy.ndarray]]

Implementation of Policy.
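Typical rollout usage, continuing the construction sketch above (the observation key and shape are illustrative):

    import numpy as np

    # Sample an action for a single observation; with deterministic=True the
    # most likely action would be returned instead.
    obs = {"observation": np.zeros(16, dtype=np.float32)}
    action = policy.compute_action(obs, maze_state=None, policy_id=0, deterministic=False)
    # `action` is a dict keyed by action head, e.g. {"action": 2}.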
compute_action_distribution(observation: Any, policy_id: Union[str, int] = None) → Any

Query the policy corresponding to the given ID for the action distribution.
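For example, to inspect the full distribution rather than a sampled action (same assumptions as above):

    # Returns the probability distribution object instead of a sampled action.
    dist = policy.compute_action_distribution(obs, policy_id=0)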
compute_action_logits_entropy_dist(policy_id: Union[str, int], observation: Dict[Union[str, int], torch.Tensor], deterministic: bool, temperature: float) → Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], torch.Tensor, maze.distributions.dict.DictProbabilityDistribution]

Compute the action for the given observation and policy ID and return it together with the logits, the entropy, and the underlying probability distribution.

Parameters

    policy_id – ID of the policy to query (does not have to be provided if the policies dict contains only one policy).
    observation – Current observation of the environment.
    deterministic – Specify if the action should be computed deterministically.
    temperature – Controls the sampling behaviour: 1.0 corresponds to unmodified sampling; values smaller than 1.0 concentrate the action distribution towards deterministic sampling.

Returns

    Tuple of (action, logits_dict, entropy, prob_dist).
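A training-time sketch (batch size, observation key and tensor shapes are assumptions):

    import torch

    # Batched torch observations, as used during training.
    torch_obs = {"observation": torch.zeros(8, 16)}
    action, logits_dict, entropy, prob_dist = policy.compute_action_logits_entropy_dist(
        policy_id=0, observation=torch_obs, deterministic=False, temperature=1.0)
    # `entropy` can serve as an exploration bonus in the loss, while `prob_dist`
    # provides log-probabilities of the sampled actions for policy-gradient terms.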
compute_action_with_logits(observation: Any, policy_id: Union[str, int] = None, deterministic: bool = False) → Tuple[Any, Dict[str, torch.Tensor]]

Compute the action for the given observation and policy ID and return it together with the logits.

Parameters

    observation – Current observation of the environment.
    policy_id – ID of the policy to query (does not have to be provided if the policies dict contains only one policy).
    deterministic – Specify if the action should be computed deterministically.

Returns

    Tuple of (action, logits_dict).
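For example, to keep the raw logits alongside the sampled action for logging or debugging (same assumptions as above):

    action, logits_dict = policy.compute_action_with_logits(obs, policy_id=0,
                                                            deterministic=False)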
compute_logits_dict(observation: Any, policy_id: Union[str, int] = None) → Dict[str, torch.Tensor]

Get the logits for the given observation and policy ID.

Parameters

    observation – Observation to compute the logits for.
    policy_id – Policy ID this observation corresponds to.

Returns

    Logits dictionary.
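A short sketch, reusing the observation from above:

    # One logits tensor per action head.
    logits_dict = policy.compute_logits_dict(obs, policy_id=0)
    for action_head, logits in logits_dict.items():
        print(action_head, logits.shape)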
compute_top_action_candidates(observation: Dict[str, numpy.ndarray], num_candidates: int, maze_state: Optional[Any] = None, policy_id: Union[str, int] = None, deterministic: bool = False) → Tuple[Sequence[Dict[str, Union[int, numpy.ndarray]]], Sequence[float]]

Implementation of Policy.
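Usage sketch, e.g. as a building block for search or action filtering (argument values are illustrative):

    # The three most promising actions together with their scores.
    candidates, scores = policy.compute_top_action_candidates(
        obs, num_candidates=3, maze_state=None, policy_id=0, deterministic=False)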
eval() → None

Implementation of TorchModel.
load_state_dict(state_dict: Dict) → None

Implementation of TorchModel.
logits_dict_to_distribution(logits_dict: Dict[str, torch.Tensor], temperature: float = 1.0)

Helper function for creating a dict probability distribution from the given logits dictionary.

Parameters

    logits_dict – A logits dictionary [action_head: action_logits] to parameterize the distribution from.
    temperature – Controls the sampling behaviour: 1.0 corresponds to unmodified sampling; values smaller than 1.0 concentrate the action distribution towards deterministic sampling.

Returns

    The respective instance of a DictProbabilityDistribution.
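For example, re-creating a sharpened distribution from the logits_dict computed in the compute_logits_dict sketch above:

    # A temperature below 1.0 concentrates probability mass on the most likely actions.
    prob_dist = policy.logits_dict_to_distribution(logits_dict, temperature=0.5)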
parameters() → List[torch.Tensor]

Implementation of TorchModel.
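The flat parameter list plugs directly into a torch optimizer, e.g.:

    import torch

    optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)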
state_dict() → Dict

Implementation of TorchModel.
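Together with load_state_dict above, this supports a plain checkpointing round trip (the file name is illustrative):

    import torch

    # Persist and later restore the policy weights.
    torch.save(policy.state_dict(), "policy_checkpoint.pt")
    policy.load_state_dict(torch.load("policy_checkpoint.pt"))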
to(device: str) → None

Implementation of TorchModel.
train() → None

Implementation of TorchModel.
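A typical device/mode workflow combining to, train and eval (assuming a CUDA device is available):

    policy.to("cuda")  # move all encapsulated networks to the GPU
    policy.train()     # training mode: dropout, batch-norm updates, etc.
    # ... optimization steps ...
    policy.eval()      # evaluation mode for rollouts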