Source code for narla.settings.trial_settings

import dataclasses

try:
    from typing import Literal
except:
    from typing_extensions import Literal

import narla
from narla.settings.base_settings import BaseSettings


[docs]@dataclasses.dataclass class TrialSettings(BaseSettings): batch_size: int = 128 """Batch size to use during training""" device: Literal["cpu", "cuda"] = "cpu" """Device to put the network on""" gpu: int = 0 """GPU ID to run on""" maximum_episodes: int = 20_000 """Total number of episodes to run for""" random_seed: int = 0 """Random seed""" results_directory: str = "" """Path to save results""" save_every: int = 1_000 """Save results every n steps""" trial_id: int = 0 """Unique ID of the trial being run, corresponds to the path of data saving <results_directory>/<trial_id>/""" def __post_init__(self): narla.settings.reproducibility(self.random_seed)