MultiCategoricalProbabilityDistribution
-
class maze.distributions.multi_categorical.MultiCategoricalProbabilityDistribution(logits: torch.Tensor, action_space: gym.spaces.MultiDiscrete, temperature: float)
Multi-categorical probability distribution.
The respective functions return properties aggregated across the sub-distributions using a reduce_fun such as mean or sum.
- Parameters
logits – The concatenated action selection logits for all sub spaces.
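As a rough illustration of what "concatenated logits" and reduce_fun aggregation could mean, the sketch below splits one flat logits vector into one chunk per sub-space and reduces the per-chunk entropies. It uses plain Python rather than torch to stay dependency-free. The helper names (split_logits, categorical_entropy, aggregated_entropy), the nvec argument, and the assumption that logits are laid out sub-space by sub-space are all hypothetical and not taken from the Maze source.

```python
import math
from statistics import mean
from typing import Callable, List, Sequence


def split_logits(logits: Sequence[float], nvec: Sequence[int]) -> List[List[float]]:
    """Cut one flat logits vector into one chunk per sub-space (assumed layout)."""
    assert len(logits) == sum(nvec), "logits length must equal the sum of sub-space sizes"
    chunks, start = [], 0
    for n in nvec:
        chunks.append(list(logits[start:start + n]))
        start += n
    return chunks


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one chunk of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_entropy(logits: Sequence[float]) -> float:
    """Entropy (in nats) of a single categorical sub-distribution."""
    return -sum(p * math.log(p) for p in softmax(logits) if p > 0.0)


def aggregated_entropy(logits: Sequence[float], nvec: Sequence[int],
                       reduce_fun: Callable = mean) -> float:
    """Reduce the per-sub-space entropies with reduce_fun (e.g. mean or sum)."""
    return reduce_fun([categorical_entropy(c) for c in split_logits(logits, nvec)])


# Hypothetical MultiDiscrete space with sub-space sizes [3, 2] and uniform logits.
logits = [0.0, 0.0, 0.0, 0.0, 0.0]
entropy_sum = aggregated_entropy(logits, nvec=[3, 2], reduce_fun=sum)
entropy_mean = aggregated_entropy(logits, nvec=[3, 2])
```

With reduce_fun=sum the result is the entropy of the joint distribution under the assumption that the sub-actions are independent; with a mean it is the average entropy per sub-space.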
-
deterministic_sample() → Dict[str, torch.Tensor]
Implementation of the TorchProbabilityDistribution interface.
-
entropy(reduce_fun: callable = torch.mean) → torch.Tensor
Implementation of the TorchProbabilityDistribution interface.
-
kl(other: maze.distributions.multi_categorical.MultiCategoricalProbabilityDistribution, reduce_fun: callable = torch.mean) → torch.Tensor
Implementation of the TorchProbabilityDistribution interface.
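A minimal sketch of how a KL divergence between two multi-categorical distributions might be aggregated: compute KL per sub-space, then reduce. The direction KL(self || other), the helper names, and the pre-split logit chunks are assumptions made for illustration; the actual method may differ.

```python
import math
from statistics import mean
from typing import Callable, List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one chunk of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def categorical_kl(p_logits: Sequence[float], q_logits: Sequence[float]) -> float:
    """KL(p || q) between two categoricals given as logits."""
    log_p, log_q = log_softmax(p_logits), log_softmax(q_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def aggregated_kl(p_chunks, q_chunks, reduce_fun: Callable = mean) -> float:
    """Reduce the per-sub-space KL values with reduce_fun (e.g. mean or sum)."""
    return reduce_fun([categorical_kl(p, q) for p, q in zip(p_chunks, q_chunks)])


# Hypothetical per-sub-space logits for two distributions over sizes [3, 2].
p_chunks = [[0.0, 0.0, 0.0], [1.0, -1.0]]
q_chunks = [[0.0, 0.0, 0.0], [0.0, 0.0]]
kl_self = aggregated_kl(p_chunks, p_chunks)                  # identical distributions
kl_pq = aggregated_kl(p_chunks, q_chunks, reduce_fun=sum)    # differs in second sub-space
```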
-
log_prob(actions: List[torch.Tensor]) → torch.Tensor
Implementation of the TorchProbabilityDistribution interface.
-
neg_log_prob(actions: List[torch.Tensor]) → torch.Tensor
Implementation of the TorchProbabilityDistribution interface.
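For intuition on log_prob and neg_log_prob: if the sub-actions are treated as independent, the log-probability of a full multi-discrete action is the sum of the per-sub-space log-probabilities, and the negative log-probability is its negation. The sketch below shows that relationship in plain Python; whether the library aggregates in exactly this way is an assumption, and joint_log_prob is a hypothetical helper.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one chunk of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def joint_log_prob(logit_chunks: Sequence[Sequence[float]], actions: Sequence[int]) -> float:
    """Sum of per-sub-space log-probabilities (sub-actions assumed independent)."""
    return sum(log_softmax(chunk)[a] for chunk, a in zip(logit_chunks, actions))


# Uniform logits over hypothetical sub-spaces of size 3 and 2; one action index per sub-space.
logit_chunks = [[0.0, 0.0, 0.0], [0.0, 0.0]]
actions = [2, 1]
log_p = joint_log_prob(logit_chunks, actions)
neg_log_p = -log_p
```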
-
classmethod required_logits_shape(action_space: gym.spaces.MultiDiscrete) → Sequence[int]
Implementation of the TorchProbabilityDistribution interface.
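Since the constructor expects the logits of all sub-spaces concatenated, a plausible required shape is a single dimension whose length is the sum of the sub-space sizes. The sketch below illustrates that reading with a plain list standing in for the nvec attribute of a gym.spaces.MultiDiscrete space; it is an assumption about the behavior, not the library's code.

```python
from typing import Sequence, Tuple


def required_logits_shape(nvec: Sequence[int]) -> Tuple[int, ...]:
    """One logit per choice in every sub-space, concatenated along a single axis (assumed)."""
    return (sum(int(n) for n in nvec),)


# Hypothetical MultiDiscrete space with sub-space sizes 3, 2 and 4.
shape = required_logits_shape([3, 2, 4])
```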
-
sample() → List[torch.Tensor]
Implementation of the TorchProbabilityDistribution interface.
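To contrast sample() with deterministic_sample(), the sketch below draws one index per sub-space from the softmax of its logits and, for the deterministic variant, takes the argmax per sub-space. Dividing the logits by temperature before the softmax is a common convention and is assumed here; the source does not state how temperature is applied. The function names mirror the methods above but are standalone, hypothetical helpers that return plain lists.

```python
import math
import random
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Softmax of logits / temperature (temperature handling is an assumption)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample(logit_chunks: Sequence[Sequence[float]], temperature: float,
           rng: random.Random) -> List[int]:
    """Draw one action index per sub-space."""
    return [
        rng.choices(range(len(chunk)), weights=softmax(chunk, temperature))[0]
        for chunk in logit_chunks
    ]


def deterministic_sample(logit_chunks: Sequence[Sequence[float]]) -> List[int]:
    """Pick the highest-logit action index in each sub-space."""
    return [max(range(len(chunk)), key=chunk.__getitem__) for chunk in logit_chunks]


# Hypothetical per-sub-space logits for sub-space sizes [3, 2].
logit_chunks = [[0.1, 2.0, -1.0], [0.5, 0.2]]
rng = random.Random(0)  # seeded only to make the sketch reproducible
sampled = sample(logit_chunks, temperature=1.0, rng=rng)
greedy = deterministic_sample(logit_chunks)
```

Under this convention, a temperature below 1.0 sharpens each sub-distribution toward the greedy choice, while a temperature above 1.0 flattens it.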