InMemoryImitationDataSet
-
class maze.train.trainers.imitation.in_memory_data_set.InMemoryImitationDataSet(*args: Any, **kwargs: Any)
Trajectory data set for imitation learning.
Loads all data on initialization and then keeps it in memory.
- Parameters
trajectory_data_dir – The directory where the trajectory data are stored.
env_factory – Function for creating an environment used for state and action conversion. For Maze envs, the environment configuration (i.e. space interfaces, wrappers, etc.) determines the format of the actions and observations derived from the recorded MazeActions and MazeStates (e.g. multi-step observations/actions).
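The load-once, keep-in-memory pattern described above can be sketched in plain Python. This is a simplified illustration, not Maze's implementation: the class name `TinyInMemoryDataSet` is hypothetical, and it assumes each pickle file already holds a list of converted `(observation, action)` pairs rather than a Maze `EpisodeRecord`, so no `env_factory` is involved.

```python
import pickle
from pathlib import Path
from typing import Any, List, Tuple


class TinyInMemoryDataSet:
    """Hypothetical sketch: load all trajectory files once, then serve samples from memory.

    Assumption: every *.pkl file contains a list of (observation, action) tuples.
    """

    def __init__(self, trajectory_data_dir: str) -> None:
        self.samples: List[Tuple[Any, Any]] = []
        # Load everything eagerly at construction time; no disk access happens later.
        for path in sorted(Path(trajectory_data_dir).glob("*.pkl")):
            with open(path, "rb") as f:
                self.samples.extend(pickle.load(f))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        # Samples are served straight from the in-memory list.
        return self.samples[index]
```

The trade-off is the usual one for in-memory data sets: fast random access during training in exchange for holding the whole data set in RAM.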
-
static get_trajectory_files(trajectory_data_dir: str) → List[pathlib.Path]
List pickle files (“pkl” suffix, used by default for trajectory data storage) in the given directory.
- Parameters
trajectory_data_dir – Where to look for the trajectory records (= pickle files).
- Returns
A list of available pkl files in the given directory.
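A minimal stand-in for this lookup, written with the standard-library `pathlib` module, could look as follows. The function name `get_trajectory_files_sketch` is hypothetical, and it is an assumption that only the top level of the directory is searched (no recursion into subdirectories) and that the result is sorted; the actual Maze method may behave differently.

```python
from pathlib import Path
from typing import List


def get_trajectory_files_sketch(trajectory_data_dir: str) -> List[Path]:
    """Hypothetical stand-in: list regular files with a ".pkl" suffix in a directory.

    Only the top level is searched; results are sorted for a deterministic order.
    """
    return sorted(
        p for p in Path(trajectory_data_dir).iterdir()
        if p.is_file() and p.suffix == ".pkl"
    )
```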
-
static load_episode_record(env: maze.core.env.structured_env.StructuredEnv, episode_record: maze.core.trajectory_recorder.episode_record.EpisodeRecord) → Tuple[List[Dict[Union[int, str], Any]], List[Dict[Union[int, str], Any]]]
Convert an episode trajectory record into lists of observations and actions using the given env.
- Parameters
env – Env to use for conversion of MazeStates and MazeActions into observations and actions
episode_record – Episode record to load
- Returns
The loaded observations and actions as a tuple (observation_list, action_list). Each list contains observation/action dictionaries keyed by the IDs of the structured sub-steps; in non-structured scenarios, each dictionary has just one entry.
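The shape of the returned tuple can be illustrated with a small sketch. It does not use Maze's conversion API: the input layout (one dict per recorded step, mapping a sub-step ID to a `(MazeState, MazeAction)` pair) and the two conversion callables are hypothetical stand-ins for what the env does internally.

```python
from typing import Any, Callable, Dict, List, Tuple, Union

StepKey = Union[int, str]


def load_episode_sketch(
    steps: List[Dict[StepKey, Tuple[Any, Any]]],
    state_to_obs: Callable[[Any], Any],
    maze_action_to_action: Callable[[Any], Any],
) -> Tuple[List[Dict[StepKey, Any]], List[Dict[StepKey, Any]]]:
    """Hypothetical sketch of the (observation_list, action_list) layout.

    `steps` holds one dict per recorded step, keyed by sub-step ID, with a
    (MazeState, MazeAction) pair per sub-step. The two callables stand in
    for the env's state and action conversion.
    """
    observations: List[Dict[StepKey, Any]] = []
    actions: List[Dict[StepKey, Any]] = []
    for step in steps:
        # One dictionary per step, keyed by sub-step ID.
        observations.append({key: state_to_obs(state) for key, (state, _) in step.items()})
        actions.append({key: maze_action_to_action(action) for key, (_, action) in step.items()})
    return observations, actions
```

For a non-structured env, every step uses a single key (here assumed to be `0`), so each dictionary has exactly one entry; a structured env with two sub-steps would yield dictionaries with two keys per step.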
-
random_split(lengths: Sequence[int], generator: torch.Generator = torch.default_generator) → List[torch.utils.data.dataset.Subset]
Randomly split the dataset into non-overlapping new datasets of given lengths.
The split is based on episodes – samples from the same episode will end up in the same subset. Based on the available episode lengths, this might result in subsets of slightly different lengths than specified.
Optionally fix the generator for reproducible results, e.g.:
self.random_split([3, 7], generator=torch.Generator().manual_seed(42))
- Parameters
lengths – Lengths of the splits to produce (best effort; the result might differ depending on the available episode lengths).
generator – Generator used for the random permutation.
- Returns
A list of the data subsets, each with size roughly (!) corresponding to what was specified by lengths.
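The episode-preserving split can be sketched without Maze or PyTorch. This is an assumed greedy strategy, not necessarily the one Maze uses: episodes are shuffled with the standard-library `random.Random` (swapped in for `torch.Generator`), and each subset is filled with whole episodes until it reaches its requested length, with leftovers going to the last subset. The function name `episode_random_split` is hypothetical, and it returns plain index lists rather than `Subset` objects.

```python
import random
from typing import List, Optional, Sequence


def episode_random_split(
    episode_lengths: Sequence[int],
    lengths: Sequence[int],
    seed: Optional[int] = None,
) -> List[List[int]]:
    """Hypothetical sketch: split sample indices into subsets without splitting episodes.

    Samples are numbered consecutively, episode by episode. Episodes are shuffled,
    then each subset receives whole episodes until it reaches its requested length;
    any remaining episodes go to the last subset.
    """
    if not lengths:
        return []

    # Global sample indices belonging to each episode.
    episode_indices: List[List[int]] = []
    start = 0
    for n in episode_lengths:
        episode_indices.append(list(range(start, start + n)))
        start += n

    # Shuffle episode order with a seedable generator for reproducibility.
    order = list(range(len(episode_indices)))
    random.Random(seed).shuffle(order)

    subsets: List[List[int]] = [[] for _ in lengths]
    current = 0
    for ep in order:
        # Move on once the current subset is full (the last subset takes the rest).
        while current < len(lengths) - 1 and len(subsets[current]) >= lengths[current]:
            current += 1
        subsets[current].extend(episode_indices[ep])
    return subsets
```

With episodes of lengths 2, 3 and 5 and requested lengths [3, 7], the first subset ends up with 5 samples if the 5-step episode happens to be drawn first, which is exactly the "roughly (!)" caveat above. The index lists could then be wrapped with `torch.utils.data.Subset(dataset, indices)` to obtain data set views.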