TorchPolicy

class maze.core.agent.torch_policy.TorchPolicy(networks: Mapping[Union[str, int], torch.nn.Module], distribution_mapper: maze.distributions.distribution_mapper.DistributionMapper, device: str)

Encapsulates multiple torch policies along with a distribution mapper for training and rollouts in structured environments.

Parameters
  • networks – Mapping of policy networks to encapsulate

  • distribution_mapper – Distribution mapper associated with the policy mapping.

  • device – Device the policy should be located on (cpu or cuda)
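
A minimal construction sketch (illustrative only): the network class, the observation key "observation", and the action head name "action" are assumptions for this example, and `distribution_mapper` is assumed to be a DistributionMapper already configured for the environment's action space.

    import torch
    import torch.nn as nn
    from maze.core.agent.torch_policy import TorchPolicy

    # Hypothetical policy network: one logits head per action ("action" is an assumed head name).
    class MyPolicyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.body = nn.Linear(4, 16)
            self.action_head = nn.Linear(16, 2)

        def forward(self, obs):
            return {"action": self.action_head(torch.relu(self.body(obs["observation"])))}

    # `distribution_mapper` is assumed to exist already; its construction depends on the
    # action space and is not covered on this page.
    policy = TorchPolicy(
        networks={0: MyPolicyNet()},   # one network per sub-step / policy ID
        distribution_mapper=distribution_mapper,
        device="cpu",
    )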

compute_action(observation: Dict[str, numpy.ndarray], maze_state: Optional[Any] = None, policy_id: Union[str, int] = None, deterministic: bool = False) → Dict[str, Union[int, numpy.ndarray]]

Implementation of the Policy interface.
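
Example (illustrative sketch, assuming `policy` is a constructed TorchPolicy and the observation keys match the network inputs):

    import numpy as np

    # Hypothetical observation dict; keys and shapes depend on the environment.
    observation = {"observation": np.zeros(4, dtype=np.float32)}

    # Sample an action from policy 0; deterministic=True requests a deterministic action instead.
    action = policy.compute_action(observation, maze_state=None, policy_id=0, deterministic=False)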

compute_action_distribution(observation: Any, policy_id: Union[str, int] = None) → Any

Query the policy corresponding to the given ID for the action distribution.

compute_action_logits_entropy_dist(policy_id: Union[str, int], observation: Dict[Union[str, int], torch.Tensor], deterministic: bool, temperature: float) → Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], torch.Tensor, maze.distributions.dict.DictProbabilityDistribution]

Compute the action for the given observation and policy ID and return it together with the logits, the entropy, and the probability distribution.

Parameters
  • policy_id – ID of the policy to query (does not have to be provided if the policies dict contains only one policy)

  • observation – Current observation of the environment

  • deterministic – Specify if the action should be computed deterministically

  • temperature – Controls the sampling behaviour:
      • 1.0 corresponds to unmodified sampling
      • values smaller than 1.0 concentrate the action distribution towards deterministic sampling

Returns

Tuple of (action, logits_dict, entropy, prob_dist)
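
Example (illustrative sketch, assuming `policy` exists; this method operates on torch tensors as indicated by its signature, and the observation key is an assumption):

    import torch

    obs_tensors = {"observation": torch.zeros(8, 4)}  # hypothetical batched observation

    action, logits_dict, entropy, prob_dist = policy.compute_action_logits_entropy_dist(
        policy_id=0,
        observation=obs_tensors,
        deterministic=False,
        temperature=1.0,
    )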

compute_action_with_logits(observation: Any, policy_id: Union[str, int] = None, deterministic: bool = False) → Tuple[Any, Dict[str, torch.Tensor]]

Compute action for the given observation and policy ID and return it together with the logits.

Parameters
  • observation – Current observation of the environment

  • policy_id – ID of the policy to query (does not have to be provided if the policies dict contains only one policy)

  • deterministic – Specify if the action should be computed deterministically

Returns

Tuple of (action, logits_dict)
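
Example (illustrative sketch, assuming `policy` and `observation` from the compute_action example above):

    action, logits_dict = policy.compute_action_with_logits(observation, policy_id=0, deterministic=False)
    # logits_dict maps each action head name to its raw logits tensor, e.g. {"action": tensor([...])}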

compute_logits_dict(observation: Any, policy_id: Union[str, int] = None) → Dict[str, torch.Tensor]

Get the logits for the given observation and policy ID.

Parameters
  • observation – Observation to compute the logits for

  • policy_id – Policy ID this observation corresponds to

Returns

Logits dictionary
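
Example (illustrative sketch, same assumptions as above):

    logits_dict = policy.compute_logits_dict(observation, policy_id=0)
    for head_name, logits in logits_dict.items():
        print(head_name, logits.shape)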

compute_top_action_candidates(observation: Dict[str, numpy.ndarray], num_candidates: int, maze_state: Optional[Any] = None, policy_id: Union[str, int] = None, deterministic: bool = False) → Tuple[Sequence[Dict[str, Union[int, numpy.ndarray]]], Sequence[float]]

Implementation of the Policy interface.
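
Example (illustrative sketch, same assumptions as above):

    candidates, scores = policy.compute_top_action_candidates(
        observation, num_candidates=3, maze_state=None, policy_id=0, deterministic=False
    )
    # `candidates` is a sequence of action dicts and `scores` the matching sequence of floats.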

eval() → None

Implementation of the TorchModel interface.

load_state_dict(state_dict: Dict) → None

Implementation of the TorchModel interface.
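
Example (illustrative sketch of a checkpoint round trip using standard torch serialization; the file name is arbitrary):

    import torch

    # Persist the weights of all encapsulated networks ...
    torch.save(policy.state_dict(), "policy_checkpoint.pt")

    # ... and restore them into a policy with the same network structure.
    policy.load_state_dict(torch.load("policy_checkpoint.pt"))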

logits_dict_to_distribution(logits_dict: Dict[str, torch.Tensor], temperature: float = 1.0) → maze.distributions.dict.DictProbabilityDistribution

Helper function for creation of a dict probability distribution from the given logits dictionary.

Parameters
  • logits_dict – A logits dictionary [action_head: action_logits] to parameterize the distribution from.

  • temperature – Controls the sampling behaviour:
      • 1.0 corresponds to unmodified sampling
      • values smaller than 1.0 concentrate the action distribution towards deterministic sampling

Returns

The corresponding DictProbabilityDistribution instance, parameterized by the given logits.
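
Example (illustrative sketch, same assumptions as above):

    logits_dict = policy.compute_logits_dict(observation, policy_id=0)

    # A temperature below 1.0 sharpens the distribution towards deterministic sampling.
    prob_dist = policy.logits_dict_to_distribution(logits_dict, temperature=0.5)
    # `prob_dist` is a DictProbabilityDistribution; sampling and inspection via its own API are not covered here.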

needs_state() → bool

This policy does not require the maze_state object to compute the action.

parameters() → List[torch.Tensor]

Implementation of the TorchModel interface.
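
Example (illustrative sketch): the returned tensors can be handed directly to a standard torch optimizer (the learning rate is arbitrary).

    import torch

    optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)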

state_dict() → Dict

Implementation of the TorchModel interface.

to(device: str) → None

Implementation of the TorchModel interface.

train() → None

Implementation of the TorchModel interface.
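
Example (illustrative sketch of device and mode management, assuming a CUDA device is available):

    policy.to("cuda")   # move all encapsulated networks to the GPU
    policy.train()      # training mode for optimisation
    # ... training steps ...
    policy.eval()       # evaluation mode for rollouts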