TorchStateCritic

class maze.core.agent.torch_state_critic.TorchStateCritic(networks: Mapping[Union[str, int], torch.nn.Module], num_policies: int, device: str)

Encapsulates multiple torch state critics for training in structured environments.

Parameters
  • networks – Mapping of value functions (critic) to encapsulate.

  • num_policies – The number of corresponding policies.

  • device – Device the critic networks should be located on (cpu or cuda).

bootstrap_returns(observations: Dict[Union[str, int], Dict[str, torch.Tensor]], rews: numpy.ndarray, dones: numpy.ndarray, gamma: float, gae_lambda: float) → Tuple[Dict[Union[str, int], torch.Tensor], Dict[Union[str, int], torch.Tensor], Dict[Union[str, int], torch.Tensor]]

Bootstrap returns using the value function.

Useful, for example, when implementing PPO or A2C.

Parameters
  • observations – Sub-step observations as tensor dictionary.

  • rews – Array holding the per step rewards.

  • dones – Array indicating if a step is a done step.

  • gamma – Discount factor.

  • gae_lambda – Bias vs. variance trade-off factor for the Generalized Advantage Estimator (GAE).

Returns

Tuple containing the computed returns, the predicted values and the detached predicted values.
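To illustrate the difference between the predicted values and the detached predicted values in the returned tuple, here is a minimal sketch in plain PyTorch. It does not call TorchStateCritic itself. The critic networks, the observation key "obs", and all tensor shapes are hypothetical choices made for this example only.

```python
import torch
from torch import nn

# Hypothetical critics keyed by sub-step id; each maps a 4-dim observation to one value.
critics = {0: nn.Linear(4, 1), 1: nn.Linear(4, 1)}

# Nested dict: sub-step id -> observation name -> tensor of shape (n_steps, n_workers, 4).
observations = {key: {"obs": torch.randn(5, 2, 4)} for key in critics}

values, detached_values = {}, {}
for key, net in critics.items():
    # Drop the trailing feature dimension to get shape (n_steps, n_workers).
    values[key] = net(observations[key]["obs"]).squeeze(-1)
    # Same numbers, but cut off from the autograd graph (no gradients flow back).
    detached_values[key] = values[key].detach()
```

The attached values would typically feed the value loss, while the detached copies can serve as a baseline for advantage estimates without propagating gradients into the critic.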

compute_return(gamma: float, gae_lambda: float, rewards: numpy.ndarray, values: torch.Tensor, dones: numpy.ndarray, deltas: torch.Tensor = None) → torch.Tensor

Compute bootstrapped return from rewards and estimated values.

Parameters
  • gamma – Discount factor.

  • gae_lambda – Bias vs. variance trade-off factor for the Generalized Advantage Estimator (GAE).

  • rewards – Step rewards with shape (n_steps, n_workers)

  • values – Predicted values with shape (n_steps, n_workers)

  • dones – Step dones with shape (n_steps, n_workers)

  • deltas – Predicted value deltas to previous sub-step with shape (n_steps, n_workers)

Returns

Per time step returns.
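As a rough illustration of how gamma and gae_lambda interact, the following NumPy sketch computes GAE-based returns for arrays of shape (n_steps, n_workers). It is a textbook formulation and not necessarily how compute_return is implemented. The function name gae_returns and the last_value argument (the value estimate for the state following the final step) are assumptions introduced for this example, and the optional deltas argument is not modeled.

```python
import numpy as np


def gae_returns(gamma, gae_lambda, rewards, values, dones, last_value):
    """Sketch of GAE returns; rewards, values, dones have shape (n_steps, n_workers)."""
    n_steps, n_workers = rewards.shape
    returns = np.zeros((n_steps, n_workers), dtype=np.float64)
    gae = np.zeros(n_workers, dtype=np.float64)
    # Assumption: last_value (shape (n_workers,)) bootstraps the step after the rollout.
    next_value = np.asarray(last_value, dtype=np.float64)
    for t in reversed(range(n_steps)):
        # A done step cuts off both the bootstrap value and the accumulated advantage.
        not_done = 1.0 - dones[t].astype(np.float64)
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        gae = delta + gamma * gae_lambda * not_done * gae
        # Return = advantage estimate + value baseline.
        returns[t] = gae + values[t]
        next_value = values[t]
    return returns
```

With gae_lambda set to 1 this reduces to discounted returns bootstrapped from last_value, and with gae_lambda set to 0 it reduces to the one-step target r_t + gamma * V(s_{t+1}).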

property device

implementation of TorchModel

eval() → None

implementation of TorchModel

load_state_dict(state_dict: Dict) → None

implementation of TorchModel

abstract property num_critics

Returns the number of critic networks.

parameters() → List[torch.Tensor]

implementation of TorchModel

predict_value(observation: Dict[str, numpy.ndarray], critic_id: Union[int, str]) → torch.Tensor

implementation of StateCritic

state_dict() → Dict

implementation of TorchModel

to(device: str) → None

implementation of TorchModel

train() → None

implementation of TorchModel