Source code for narla.history.history

from __future__ import annotations

import random
from typing import Dict, List

import torch
import pandas as pd

import narla


class History:
    """
    History object contains arbitrary data from training

    :param storage_size: Max storage size
    """

    def __init__(self, storage_size: int = 10_000):
        self.storage_size = storage_size
        self._history: Dict[str, List[torch.Tensor | float]] = {}

    def clear(self):
        """
        Clear History
        """
        self._history = {}

    def get(self, key: str, stack: bool = False) -> list | torch.Tensor:
        """
        Access an element from the History

        :param key: Name of element to access
        :param stack: If ``True`` will stack the values into a Tensor
        """
        tensors = self._history.get(key)
        if stack:
            tensors = torch.stack(tensors)

        return tensors

    def record(self, **kwargs):
        """
        Store all the keyword arguments

        :param kwargs: If the keyword doesn't yet exist an internal list will be created and the value appended to it
        """
        for key, value in kwargs.items():
            if key not in self._history:
                self._history[key] = []

                if key == narla.history.saved_data.NEXT_OBSERVATION:
                    self._history[key] = [None]

            elif key == narla.history.saved_data.OBSERVATION:
                self.record(**{narla.history.saved_data.NEXT_OBSERVATION: value})

            if key == narla.history.saved_data.NEXT_OBSERVATION:
                self._history[narla.history.saved_data.NEXT_OBSERVATION][-1] = value
                self._history[narla.history.saved_data.NEXT_OBSERVATION].append(None)

            else:
                self._history[key].append(value)
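
    # Illustrative sketch (an assumption about intended usage, not part of the class):
    # recording three observations in sequence, e.g.
    #     history.record(**{narla.history.saved_data.OBSERVATION: o1})
    #     history.record(**{narla.history.saved_data.OBSERVATION: o2})
    #     history.record(**{narla.history.saved_data.OBSERVATION: o3})
    # leaves the internal storage offset by one step:
    #     OBSERVATION      -> [o1, o2, o3]
    #     NEXT_OBSERVATION -> [o2, o3, None]
    # so NEXT_OBSERVATION[i] is the observation that followed OBSERVATION[i].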

    def sample(self, names: List[str], sample_size: int, from_most_recent: int = 10_000) -> List[List[torch.Tensor]]:
        """
        Sample stored values for each of the given keys

        :param names: Keys of the values to sample from the History
        :param sample_size: Number of elements to sample for each key
        :param from_most_recent: Only sample from this many of the most recently recorded elements
        """
        sample = []
        for name in names:
            items = self.get(name)[-from_most_recent:]
            # A fresh fixed-seed Random is created per key so the same indices are drawn
            # for every key, keeping the sampled elements aligned across keys
            sample.append(random.Random(123).sample(items, sample_size))

        return sample

    def stack(self, *keys, dim: int = -1) -> torch.Tensor:
        """
        Stack a series of Tensors

        :param keys: Keys of the values to access from the History
        :param dim: Dimension of the Tensors to stack on
        """
        tensors = [self.get(key, stack=True).squeeze() for key in keys]

        return torch.stack(tensors, dim=dim)
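
    # Illustrative sketch (assumes "rewards" and "actions" were recorded as 1-element tensors):
    # history.stack("rewards", "actions") squeezes each stacked (N, 1) tensor to shape (N,)
    # and returns a single tensor of shape (N, 2).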

    def to_data_frame(self, keys: List[str] = ()) -> pd.DataFrame:
        """
        Convert a History object to a DataFrame

        :param keys: Keys to use as columns in the DataFrame
        """
        if len(keys) == 0:
            keys = sorted(self._history.keys())

        data = {key: self.get(key) for key in keys}
        data_frame = pd.DataFrame(data=data, columns=keys)

        return data_frame

    def __len__(self) -> int:
        # Length of the History is the length of its shortest stored list
        if not self._history:
            return 0

        return min(len(history_record) for history_record in self._history.values())
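

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library's API surface. The key names
    # "rewards" and "actions" are arbitrary examples and are unrelated to the
    # narla.history.saved_data constants handled specially by record().
    history = History(storage_size=1_000)
    for step in range(5):
        history.record(
            rewards=torch.tensor([float(step)]),
            actions=torch.tensor([float(step % 2)]),
        )

    print(len(history))                               # -> 5
    print(history.get("rewards", stack=True).shape)   # -> torch.Size([5, 1])
    print(history.stack("rewards", "actions").shape)  # -> torch.Size([5, 2])
    print(history.to_data_frame(keys=["rewards", "actions"]))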