BaseWorkerOutput

class maze.train.parallelization.base_worker.BaseWorkerOutput(observations: Dict[Union[str, int], Dict[str, torch.Tensor]], actions_taken: Dict[Union[str, int], Dict[str, torch.Tensor]], rewards: torch.Tensor, dones: torch.Tensor, infos: List[Any])

Base class for outputs generated by the agent.

Parameters
  • observations – Observations collected during the rollout.

  • actions_taken – Actions taken during the rollout.

  • rewards – Rewards collected during the rollout.

  • dones – Done flags collected during the rollout.

  • infos – Info objects collected during the rollout.
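The nested type of observations and actions_taken, Dict[Union[str, int], Dict[str, torch.Tensor]], is easier to read from a concrete value. The sketch below uses a hypothetical stand-in dataclass, WorkerOutputSketch, with the same five fields. It is not the Maze class. Plain Python lists replace torch.Tensor so it runs without PyTorch, and the keys 0, "observation", and "action" are made-up examples.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Union

# Hypothetical stand-in mirroring the BaseWorkerOutput constructor signature.
# Plain lists play the role of torch.Tensor so the sketch has no dependencies.
Tensor = List[float]


@dataclass
class WorkerOutputSketch:
    # Outer key: a str or int identifier; inner key: a named entry such as an observation name.
    observations: Dict[Union[str, int], Dict[str, Tensor]]
    actions_taken: Dict[Union[str, int], Dict[str, Tensor]]
    rewards: Tensor
    dones: Tensor
    infos: List[Any]


# A made-up three-step rollout with a single outer key (0).
output = WorkerOutputSketch(
    observations={0: {"observation": [0.1, 0.2, 0.3]}},
    actions_taken={0: {"action": [1.0, 0.0, 1.0]}},
    rewards=[1.0, 0.5, 0.0],
    dones=[0.0, 0.0, 1.0],
    infos=[{}, {}, {"episode_end": True}],
)

# Dict-dict fields are indexed by the outer key first, then by the inner name.
print(output.observations[0]["observation"])  # [0.1, 0.2, 0.3]
```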

static get_dict_dict_obj_attr_names() → List[str]

Retrieve the attribute names of the actor output fields that have a dict-of-dicts structure.

Returns

A list of the names of all attributes with a dict-of-dicts structure.

static get_list_obj_attr_names() → List[str]

Retrieve the attribute names of the actor output fields that have a list structure.

Returns

A list of the names of all attributes with a list structure.

static get_tensor_obj_attr_names() → List[str]

Retrieve the attribute names of the actor output fields that have a tensor structure.

Returns

A list of the names of all attributes with a tensor structure.
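The three static helpers group the fields by structure, which lets generic code walk an output without hard-coding field names. The sketch below is a minimal illustration under stated assumptions. OutputSketch is a hypothetical stand-in, not the Maze class. Its helpers return names inferred from the constructor's type annotations, which is an assumption about what the real helpers return. The function iter_tensor_leaves is also hypothetical. Plain lists stand in for tensors.

```python
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Tuple, Union


@dataclass
class OutputSketch:
    """Hypothetical stand-in for BaseWorkerOutput; not the Maze implementation."""

    observations: Dict[Union[str, int], Dict[str, Any]]
    actions_taken: Dict[Union[str, int], Dict[str, Any]]
    rewards: Any
    dones: Any
    infos: List[Any]

    # Assumed return values, inferred from the constructor's type annotations.
    @staticmethod
    def get_dict_dict_obj_attr_names() -> List[str]:
        return ["observations", "actions_taken"]

    @staticmethod
    def get_list_obj_attr_names() -> List[str]:
        return ["infos"]

    @staticmethod
    def get_tensor_obj_attr_names() -> List[str]:
        return ["rewards", "dones"]


def iter_tensor_leaves(output: OutputSketch) -> Iterator[Tuple[tuple, Any]]:
    """Yield (path, tensor) pairs, dispatching on the attribute groups."""
    # Tensor-structured fields are leaves themselves.
    for name in output.get_tensor_obj_attr_names():
        yield (name,), getattr(output, name)
    # Dict-dict fields hold their tensors two levels down.
    for name in output.get_dict_dict_obj_attr_names():
        for outer_key, inner in getattr(output, name).items():
            for inner_key, value in inner.items():
                yield (name, outer_key, inner_key), value
    # List-structured fields (get_list_obj_attr_names) may hold arbitrary objects,
    # so this sketch does not yield them.


output = OutputSketch(
    observations={0: {"observation": [0.1, 0.2]}},
    actions_taken={0: {"action": [1.0, 0.0]}},
    rewards=[1.0, 0.0],
    dones=[0.0, 1.0],
    infos=[{}, {}],
)
for path, value in iter_tensor_leaves(output):
    print(path, value)
```

In this sketch, a subclass that adds another tensor field would only need to extend the corresponding name list for iter_tensor_leaves to pick it up.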

to(device: str) → None

Move all elements to the given device.

Parameters

device – The device to move the output to (e.g. "cpu" or "cuda").
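Because to returns None, it presumably reassigns the object's attributes in place, since torch.Tensor.to returns a new tensor rather than modifying the original. The sketch below shows that pattern under stated assumptions. OutputSketch is a hypothetical stand-in, not the Maze class. Leaving infos unchanged is an assumption. FakeTensor is a minimal hypothetical substitute for torch.Tensor so the sketch runs without PyTorch. With real tensors the same .to(device) calls apply.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Union


class FakeTensor:
    """Minimal stand-in for torch.Tensor: only tracks which device it lives on."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Like torch.Tensor.to, return a new object rather than mutating self.
        return FakeTensor(device)


@dataclass
class OutputSketch:
    """Hypothetical stand-in for BaseWorkerOutput; not the Maze implementation."""

    observations: Dict[Union[str, int], Dict[str, Any]]
    actions_taken: Dict[Union[str, int], Dict[str, Any]]
    rewards: Any
    dones: Any
    infos: List[Any]

    def to(self, device: str) -> None:
        """Move every tensor to `device` in place; infos are assumed device-free and kept as-is."""
        # Tensor fields: replace each with its moved copy.
        for name in ("rewards", "dones"):
            setattr(self, name, getattr(self, name).to(device))
        # Dict-dict fields: rebuild the nesting with moved leaves.
        for name in ("observations", "actions_taken"):
            nested = getattr(self, name)
            setattr(self, name, {
                outer: {key: value.to(device) for key, value in inner.items()}
                for outer, inner in nested.items()
            })


output = OutputSketch(
    observations={0: {"observation": FakeTensor()}},
    actions_taken={0: {"action": FakeTensor()}},
    rewards=FakeTensor(),
    dones=FakeTensor(),
    infos=[{}],
)
output.to("cuda")  # returns None; attributes are reassigned in place
print(output.rewards.device, output.observations[0]["observation"].device)  # cuda cuda
```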