# Source code for narla.neurons.neuron_settings

from __future__ import annotations

import dataclasses

import narla
from narla.neurons.neuron_types import NeuronTypes
from narla.settings.base_settings import BaseSettings


@dataclasses.dataclass
class NeuronSettings(BaseSettings):
    """Settings describing which Neuron to build and how to configure it.

    Call ``create_neuron`` to turn these settings into a concrete Neuron.
    """

    learning_rate: float = 1e-4
    """Learning rate for Neurons"""

    neuron_type: NeuronTypes = NeuronTypes.POLICY_GRADIENT
    """Type of Neuron that will be used"""

    def create_neuron(self, observation_size: int, number_of_actions: int) -> narla.neurons.Neuron:
        """Build a new Neuron of the kind selected by ``neuron_type``.

        :param observation_size: Forwarded as ``observation_size`` to the Neuron constructor.
        :param number_of_actions: Forwarded as ``number_of_actions`` to the Neuron constructor.
        :return: The newly constructed Neuron, configured with this object's ``learning_rate``.
        """
        # Ask the enum member for the callable it maps to (presumably a Neuron
        # subclass, per the return annotation), then construct it.
        neuron_class = self.neuron_type.to_neuron_type()

        return neuron_class(
            learning_rate=self.learning_rate,
            number_of_actions=number_of_actions,
            observation_size=observation_size,
        )