Source code for narla.rewards.reward_types

from __future__ import annotations

import enum
from typing import List

import narla


class RewardTypes(str, enum.Enum):
    ACTIVE_NEURONS = "active_neurons"
    ACTIVITY_TRACE = "activity_trace"
    LAYER_SPARSITY = "layer_sparsity"
    PREDICTION = "prediction"
    TASK_REWARD = "task_reward"

    @staticmethod
    def biological_reward_types() -> List[RewardTypes]:
        """
        Get the RewardTypes that are biological
        """
        return [
            RewardTypes.ACTIVE_NEURONS,
            RewardTypes.LAYER_SPARSITY,
            RewardTypes.PREDICTION,
            RewardTypes.ACTIVITY_TRACE,
        ]

    def to_reward(self) -> narla.rewards.Reward:
        """
        Convert the RewardType to a Reward object
        """
        if self == RewardTypes.ACTIVE_NEURONS:
            return narla.rewards.ActiveNeurons()
        elif self == RewardTypes.LAYER_SPARSITY:
            return narla.rewards.LayerSparsity()
        elif self == RewardTypes.PREDICTION:
            return narla.rewards.Prediction()
        elif self == RewardTypes.ACTIVITY_TRACE:
            return narla.rewards.ActivityTrace()

        raise NotImplementedError(f"Reward not implemented for type: {self}")
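
A minimal usage sketch follows. It assumes the module is importable under the path shown at the top of this page (narla.rewards.reward_types) and that the narla.rewards classes referenced in to_reward (ActiveNeurons, LayerSparsity, Prediction, ActivityTrace) are available; adjust the import to your installation if the package layout differs.

    import narla
    from narla.rewards.reward_types import RewardTypes

    # RewardTypes subclasses str, so members compare directly against their string values
    assert RewardTypes.TASK_REWARD == "task_reward"

    # Materialize a Reward object for each biological reward type
    for reward_type in RewardTypes.biological_reward_types():
        reward = reward_type.to_reward()
        print(reward_type.value, type(reward).__name__)

    # TASK_REWARD has no corresponding Reward class, so to_reward() raises
    try:
        RewardTypes.TASK_REWARD.to_reward()
    except NotImplementedError as error:
        print(error)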