Source code for narla.environments.gym_environment

from __future__ import annotations

from typing import Tuple

import numpy as np
import torch
import gymnasium as gym

import narla
from narla.environments import Environment


class GymEnvironment(Environment):
    """
    Wrapper on `Gymnasium Environments <https://gymnasium.farama.org/>`_

    :param name: Name of the environment
    :param render: If ``True`` will visualize the environment
    """

    def __init__(self, name: narla.environments.AvailableEnvironments, render: bool = False):
        super().__init__(name=name, render=render)

        self._gym_environment = self._build_gym_environment(name=name, render=render)
        self._action_space = narla.environments.ActionSpace(
            number_of_actions=self._gym_environment.action_space.n,
        )

    @staticmethod
    def _build_gym_environment(name: narla.environments.AvailableEnvironments, render: bool) -> gym.Env:
        # Gymnasium only renders when a render_mode is set at construction time
        render_mode = "human" if render else None

        return gym.make(id=name.value, render_mode=render_mode)
    def has_been_solved(self, episode_rewards: list) -> bool:
        """
        Check if the environment has been solved

        :param episode_rewards: Rewards from previous episodes
        """
        if self._name == narla.environments.GymEnvironments.CART_POLE:
            # Solved when the mean reward over the last 100 episodes exceeds 300
            return np.mean(episode_rewards[-100:]) > 300

        return False
    @property
    def observation_size(self) -> int:
        """
        Size of a single observation from the environment
        """
        observation = self.reset()

        return observation.shape[-1]
    def reset(self) -> torch.Tensor:
        """
        Reset the environment, returning the initial observation
        """
        self._episode_reward = 0

        observation, info = self._gym_environment.reset()
        observation = self._cast_observation(observation)

        return observation
    def step(self, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, bool]:
        """
        Execute a single action in the environment

        :param action: Action to execute
        :return: Tuple of (observation, reward, terminated)
        """
        action = int(action.item())
        observation, reward, terminated, truncated, info = self._gym_environment.step(action)
        self._episode_reward += reward

        observation = self._cast_observation(observation)
        reward = self._cast_reward(reward)

        # Treat truncation (e.g. hitting the time limit) the same as termination
        return observation, reward, terminated or truncated
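
# A minimal usage sketch, assuming the imports above and that
# narla.environments.GymEnvironments.CART_POLE is a valid environment name
# (it is referenced in has_been_solved above). Actions are sampled uniformly
# at random; CartPole's discrete action space has two actions.
if __name__ == "__main__":
    environment = GymEnvironment(name=narla.environments.GymEnvironments.CART_POLE)

    observation = environment.reset()
    terminated = False
    steps = 0
    while not terminated:
        # step expects a torch.Tensor and converts it internally with action.item()
        action = torch.randint(low=0, high=2, size=(1,))
        observation, reward, terminated = environment.step(action)
        steps += 1

    print(f"Episode finished after {steps} steps")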