TorchDeltaStateCritic

class maze.core.agent.torch_state_critic.TorchDeltaStateCritic(networks: Mapping[Union[str, int], torch.nn.Module], num_policies: int, device: str)

The first sub-step gets a regular critic; each subsequent sub-step predicts a delta w.r.t. the previous critic. Can be instantiated via the DeltaStateCriticComposer.
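The delta scheme can be pictured with the minimal, framework-free sketch below. It is not the Maze implementation. It uses plain floats instead of torch tensors and plain callables instead of torch.nn.Module networks. The names delta_values, base_critic and delta_critics are hypothetical. Passing the previous value into each delta network is also an assumption made only for illustration. The value of sub-step 0 comes from a regular critic, and each later sub-step adds a predicted correction to the value before it.

```python
from typing import Callable, Dict, List

# Hypothetical observation type: plain floats stand in for tensors.
Observation = Dict[str, float]


def delta_values(
    observations: List[Observation],
    base_critic: Callable[[Observation], float],
    delta_critics: List[Callable[[Observation, float], float]],
) -> List[float]:
    """Compose sub-step values: V_0 = base(o_0), V_k = V_{k-1} + delta_k(o_k, V_{k-1})."""
    if len(delta_critics) != len(observations) - 1:
        raise ValueError("need one delta critic per sub-step after the first")
    # Sub-step 0 is handled by a regular (non-delta) critic.
    values = [base_critic(observations[0])]
    for obs, delta in zip(observations[1:], delta_critics):
        prev = values[-1]
        # Later sub-steps only predict a correction to the previous value
        # (feeding `prev` to the delta network is an assumption of this sketch).
        values.append(prev + delta(obs, prev))
    return values


base = lambda obs: 2.0 * obs["x"]
deltas = [lambda obs, prev: obs["x"] - 0.5, lambda obs, prev: -0.5 * prev]
print(delta_values([{"x": 1.0}, {"x": 3.0}, {"x": 0.0}], base, deltas))  # [2.0, 4.5, 2.25]
```

Because every value after the first is expressed relative to its predecessor, the delta networks only have to learn how much the value shifts from one sub-step to the next.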

property num_critics

implementation of TorchStateCritic

predict_value(observation: Dict[str, numpy.ndarray], critic_id: Union[int, str]) → torch.Tensor

Predictions depend on previous sub-steps, so this method is not supported by the delta state critic.
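The hypothetical class below is a framework-free sketch, not Maze code. It illustrates why a single-observation prediction is ill-defined here. With identical observations for sub-step 1, the value for that sub-step still changes when the sub-step 0 observation changes. DeltaCriticSketch and its constructor arguments are invented for illustration, and raising NotImplementedError is this sketch's choice. Its predict_values is simplified as well. It returns a single dict of floats rather than the tuple of tensor dicts in the documented signature, and it assumes the dict's insertion order is the sub-step order.

```python
from typing import Callable, Dict, List, Union

StepKey = Union[int, str]
Observation = Dict[str, float]  # floats stand in for tensors


class DeltaCriticSketch:
    """Hypothetical, framework-free stand-in for a delta state critic."""

    def __init__(self, base: Callable[[Observation], float],
                 deltas: List[Callable[[Observation, float], float]]) -> None:
        self.base = base      # regular critic for the first sub-step
        self.deltas = deltas  # one delta critic per later sub-step

    def predict_value(self, observation: Observation, critic_id: StepKey) -> float:
        # A lone observation lacks the values of earlier sub-steps,
        # which every sub-step after the first builds on.
        raise NotImplementedError(
            "delta critic values depend on previous sub-steps; use predict_values")

    def predict_values(self, observations: Dict[StepKey, Observation]) -> Dict[StepKey, float]:
        values: Dict[StepKey, float] = {}
        prev = 0.0
        # Assumes dict insertion order matches sub-step order.
        for i, (key, obs) in enumerate(observations.items()):
            prev = self.base(obs) if i == 0 else prev + self.deltas[i - 1](obs, prev)
            values[key] = prev
        return values


critic = DeltaCriticSketch(base=lambda o: o["x"], deltas=[lambda o, prev: o["x"]])
print(critic.predict_values({0: {"x": 1.0}, 1: {"x": 2.0}})[1])  # 3.0
print(critic.predict_values({0: {"x": 5.0}, 1: {"x": 2.0}})[1])  # 7.0
```

The sub-step 1 observation is identical in both calls, yet its value differs. A call like predict_value(observation, 1) therefore has no well-defined answer on its own.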

predict_values(observations: Dict[Union[str, int], Dict[str, torch.Tensor]]) → Tuple[Dict[Union[str, int], torch.Tensor], Dict[Union[str, int], torch.Tensor]]

implementation of StateCritic