TorchStateCritic
-
class maze.core.agent.torch_state_critic.TorchStateCritic(networks: Mapping[Union[str, int], torch.nn.Module], num_policies: int, device: str)
Encapsulates multiple torch state critics for training in structured environments.
- Parameters
networks – Mapping of value functions (critic) to encapsulate.
num_policies – The number of corresponding policies.
device – Device the critic networks should be located on ("cpu" or "cuda").
-
bootstrap_returns(observations: Dict[Union[str, int], Dict[str, torch.Tensor]], rews: numpy.ndarray, dones: numpy.ndarray, gamma: float, gae_lambda: float) → Tuple[Dict[Union[str, int], torch.Tensor], Dict[Union[str, int], torch.Tensor], Dict[Union[str, int], torch.Tensor]]
Bootstrap returns using the value function.
Useful, for example, when implementing PPO or A2C.
- Parameters
observations – Sub-step observations as tensor dictionary.
rews – Array holding the per step rewards.
dones – Array indicating if a step is a done step.
gamma – Discounting factor
gae_lambda – Bias vs. variance trade-off factor for the Generalized Advantage Estimator (GAE)
- Returns
Tuple containing the computed returns, the predicted values and the detached predicted values.
-
compute_return(gamma: float, gae_lambda: float, rewards: numpy.ndarray, values: torch.Tensor, dones: numpy.ndarray, deltas: torch.Tensor = None) → torch.Tensor
Compute bootstrapped return from rewards and estimated values.
- Parameters
gamma – Discounting factor
gae_lambda – Bias vs. variance trade-off factor for the Generalized Advantage Estimator (GAE)
rewards – Step rewards with shape (n_steps, n_workers)
values – Predicted values with shape (n_steps, n_workers)
dones – Step dones with shape (n_steps, n_workers)
deltas – Predicted value deltas to previous sub-step with shape (n_steps, n_workers)
- Returns
Per time step returns.
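To illustrate the kind of computation described above, here is a minimal NumPy sketch of a GAE-style bootstrapped return. It is not the Maze implementation. The function name `gae_returns`, the explicit `bootstrap_value` argument (the value estimate for the state after the last step), and the convention that a done flag at step t cuts off bootstrapping from step t+1 are all assumptions made for this example. The sub-step `deltas` argument is not modelled.

```python
import numpy as np


def gae_returns(rewards, values, dones, bootstrap_value, gamma, gae_lambda):
    """Hypothetical helper: GAE-style returns for arrays of shape (n_steps, n_workers).

    bootstrap_value has shape (n_workers,) and is an assumption of this sketch.
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    # 1.0 where the episode continues after step t, 0.0 where step t is a done step.
    not_done = 1.0 - np.asarray(dones, dtype=np.float64)
    next_value = np.asarray(bootstrap_value, dtype=np.float64)

    returns = np.zeros_like(values)
    advantage = np.zeros_like(next_value)

    # Walk backwards in time so each step can reuse the advantage of its successor.
    for t in reversed(range(rewards.shape[0])):
        # One-step TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_value * not_done[t] - values[t]
        # Exponentially weighted sum of TD errors, reset at episode boundaries.
        advantage = delta + gamma * gae_lambda * not_done[t] * advantage
        # The return is the advantage estimate plus the value baseline.
        returns[t] = advantage + values[t]
        next_value = values[t]

    return returns
```

Under these assumptions, `gae_lambda=1.0` reduces to the discounted sum of rewards plus the discounted bootstrap value (low bias, high variance), while `gae_lambda=0.0` reduces to the one-step target `r_t + gamma * V(s_{t+1})` (higher bias, lower variance).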
-
property device
Implementation of TorchModel.
-
eval() → None
Implementation of TorchModel.
-
load_state_dict(state_dict: Dict) → None
Implementation of TorchModel.
-
abstract property num_critics
Returns the number of critic networks.
-
parameters() → List[torch.Tensor]
Implementation of TorchModel.
-
predict_value(observation: Dict[str, numpy.ndarray], critic_id: Union[int, str]) → torch.Tensor
Implementation of StateCritic.
-
state_dict() → Dict
Implementation of TorchModel.
-
to(device: str) → None
Implementation of TorchModel.
-
train() → None
Implementation of TorchModel.