BCLoss
-
class maze.train.trainers.imitation.bc_loss.BCLoss(action_spaces_dict: Dict[Union[int, str], gym.spaces.Dict], loss_discrete: torch.nn.Module = torch.nn.CrossEntropyLoss, loss_box: torch.nn.Module = torch.nn.MSELoss, loss_multi_binary: torch.nn.Module = torch.nn.functional.binary_cross_entropy_with_logits)
Loss function for behavioral cloning. Each action is scored with loss_discrete, loss_box, or loss_multi_binary, depending on whether its action space is discrete, continuous (box), or multi-binary.
-
action_spaces_dict: Dict[Union[int, str], gym.spaces.Dict]
Action spaces we are training on, keyed by sub-step ID (used to determine the appropriate loss function for each action).
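The per-space choice of loss can be illustrated with a minimal sketch. It assumes the policy emits logits for discrete and multi-binary actions and raw values for box actions; the lookup table LOSS_BY_SPACE, the helper loss_for_space, and the string keys (standing in for the gym space types) are hypothetical and not part of Maze's API. Unlike the defaults shown in the signature, the loss modules are instantiated here, because calling a loss class such as torch.nn.CrossEntropyLoss directly constructs a module rather than computing a loss.

```python
import torch
import torch.nn.functional as F

# Hypothetical lookup from action-space kind to loss callable.
# The string keys stand in for gym.spaces.Discrete, gym.spaces.Box and gym.spaces.MultiBinary.
LOSS_BY_SPACE = {
    # Discrete: logits of shape (batch, n) vs. integer class indices of shape (batch,)
    "discrete": torch.nn.CrossEntropyLoss(),
    # Box: continuous predictions vs. continuous expert actions of the same shape
    "box": torch.nn.MSELoss(),
    # MultiBinary: logits vs. 0/1 float targets of the same shape
    "multi_binary": F.binary_cross_entropy_with_logits,
}


def loss_for_space(space_kind: str, output: torch.Tensor, expert_action: torch.Tensor) -> torch.Tensor:
    """Score a policy output against the expert action with the loss for its space kind."""
    return LOSS_BY_SPACE[space_kind](output, expert_action)
```

For example, uniform logits [[0.0, 0.0]] scored against expert index 0 give a cross-entropy of ln 2 (about 0.693).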
-
calculate_loss(policy: maze.core.agent.torch_policy.TorchPolicy, observation_dict: Dict[Union[int, str], Any], action_dict: Dict[Union[int, str], Any], events: maze.train.trainers.imitation.imitation_events.ImitationEvents) → torch.Tensor
Calculate and return the training loss for one step (in structured scenarios, one step comprises multiple sub-steps).
- Parameters
policy – Structured policy to evaluate
observation_dict – Dictionary of observations, keyed by sub-step ID
action_dict – Dictionary of (expert) actions, keyed by sub-step ID
events – Imitation events interface used to record statistics for this step
- Returns
Total training loss over all sub-steps of the step
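How a structured step yields a single total loss can be sketched as follows. This assumes the policy outputs and the expert actions are already available as nested dicts {sub-step ID: {action name: tensor}} and that the total is a plain sum of the per-action losses. The function total_bc_loss and its arguments are hypothetical; the actual calculate_loss evaluates the TorchPolicy itself, may combine the terms differently, and records statistics through ImitationEvents, none of which is reproduced here.

```python
from typing import Callable, Dict, Union

import torch

StepKey = Union[int, str]
LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def total_bc_loss(
    outputs: Dict[StepKey, Dict[str, torch.Tensor]],
    actions: Dict[StepKey, Dict[str, torch.Tensor]],
    loss_fns: Dict[StepKey, Dict[str, LossFn]],
) -> torch.Tensor:
    """Sum the per-action losses over all sub-steps of one (structured) step."""
    total = torch.zeros(())  # scalar accumulator
    for substep_id, substep_actions in actions.items():
        for action_name, expert_action in substep_actions.items():
            # Use the loss matching this action's space to compare the policy output with the expert action
            loss_fn = loss_fns[substep_id][action_name]
            total = total + loss_fn(outputs[substep_id][action_name], expert_action)
    return total
```

Because the sum is built from differentiable torch operations, calling backward() on the returned scalar propagates gradients to every action head that contributed to it.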
-