Introduction to Gym toolkit

Gym Environments

The centerpiece of Gym is the environment, which defines the “game” in which your reinforcement learning algorithm will compete. An environment does not need to be a game; however, it describes the following game-like features:

  • action space: The set of actions we can take on the environment at each step to alter it.

  • observation space: The portion of the environment’s state that we can observe at each step. In simple environments, such as the classic control tasks, we can observe the entire state; other environments are only partially observable.

Before we begin to look at Gym, it is essential to understand some of the terminology used by this library.

  • Agent - The machine learning program or model that controls the actions.

  • Step - One round of issuing actions that affect the observation space.

  • Episode - A collection of steps that terminates when the agent fails to meet the environment’s objective, or the episode reaches the maximum number of allowed steps.

  • Render - Gym can render one frame for display after each step.

  • Reward - The feedback returned after each step, once the agent acts; positive rewards reinforce the behavior that produced them.

  • Nondeterministic - For some environments, randomness is a factor in deciding what effects actions have on reward and changes to the observation space.
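
The helper function below creates an environment by name and prints the key properties exposed by its specification and its action and observation spaces: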

import gym

def query_environment(name):
  env = gym.make(name)
  spec = gym.spec(name)
  print(f"Action Space: {env.action_space}")
  print(f"Observation Space: {env.observation_space}")
  print(f"Max Episode Steps: {spec.max_episode_steps}")
  print(f"Nondeterministic: {spec.nondeterministic}")
  print(f"Reward Range: {env.reward_range}")
  print(f"Reward Threshold: {spec.reward_threshold}")
query_environment("CartPole-v1")
Action Space: Discrete(2)
Observation Space: Box(-3.4028234663852886e+38, 3.4028234663852886e+38, (4,), float32)
Max Episode Steps: 500
Nondeterministic: False
Reward Range: (-inf, inf)
Reward Threshold: 475.0

The CartPole-v1 environment challenges the agent to move a cart while keeping a pole balanced. The environment has an observation space of 4 continuous numbers:

  • Cart Position

  • Cart Velocity

  • Pole Angle

  • Pole Velocity At Tip

To achieve this goal, the agent can take the following actions:

  • Push cart to the left

  • Push cart to the right

Gym also includes the classic MountainCar environment, in which an underpowered car must build momentum to climb a hill, and there is a continuous variant of the mountain car. This version does not simply have the motor on or off: for the continuous car, the action is a single floating-point number that specifies how much forward or backward force is applied.
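
As a quick check, the query_environment helper defined above can be pointed at both versions (assuming the classic MountainCar-v0 and MountainCarContinuous-v0 ids are registered in your Gym installation); the continuous variant reports a Box action space instead of a Discrete one:

query_environment("MountainCar-v0")
# Expect Action Space: Discrete(3) -- push left, no push, push right
query_environment("MountainCarContinuous-v0")
# Expect an action space along the lines of Box(-1.0, 1.0, (1,), float32):
# a single continuous value giving the applied force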

Simple
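
The short program below builds a toy environment from scratch, without Gym: each episode lasts ten steps, every action yields a random reward, and the agent simply chooses random actions and accumulates the total reward. It ties together the agent, step, episode, and reward terminology introduced above.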

import random
from typing import List


class Environment:
    def __init__(self):
        self.steps_left = 10

    def get_observation(self) -> List[float]:
        return [0.0, 0.0, 0.0]

    def get_actions(self) -> List[int]:
        return [0, 1]

    def is_done(self) -> bool:
        return self.steps_left == 0

    def action(self, action: int) -> float:
        if self.is_done():
            raise Exception("Game is over")
        self.steps_left -= 1
        return random.random()


class Agent:
    def __init__(self):
        self.total_reward = 0.0

    def step(self, env: Environment):
        current_obs = env.get_observation()
        actions = env.get_actions()
        reward = env.action(random.choice(actions))
        self.total_reward += reward


if __name__ == "__main__":
    env = Environment()
    agent = Agent()

    while not env.is_done():
        agent.step(env)

    print("Total reward got: %.4f" % agent.total_reward)
Total reward got: 4.6979

FrozenLake

import gym
env = gym.make("FrozenLake-v0")
env.render()
SFFF
FHFH
FFFH
HFFG
print(env.observation_space)
Discrete(16)
print(env.action_space)
Discrete(4)

Number    Action
0         Left
1         Down
2         Right
3         Up

We can obtain the transition probabilities and the reward function by indexing env.P[state][action]. For example, to obtain the transition probabilities for moving from state S to the other states by performing the action right, we would look up env.P[S][right]. However, states and actions are encoded as numbers: state S is encoded as 0 and the action right as 2, so we type env.P[0][2].

print(env.P[0][2])
[(0.3333333333333333, 4, 0.0, False), (0.3333333333333333, 1, 0.0, False), (0.3333333333333333, 0, 0.0, False)]

Each entry of the output has the form (transition probability, next state, reward, is terminal state?).
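
Since env.P[0][2] is simply a list of such tuples, we can loop over it and print each field by name (a small illustrative snippet, not part of the Gym API):

for prob, next_state, reward, terminal in env.P[0][2]:
    print(f"prob={prob:.2f} next_state={next_state} reward={reward} terminal={terminal}")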

state = env.reset()
env.step(1)
(1, 0.0, False, {'prob': 0.3333333333333333})
(next_state, reward, done, info) = env.step(1)
  • next_state represents the next state.

  • reward represents the obtained reward.

  • done indicates whether the episode has ended: if the next state is a terminal state, done is True; otherwise it is False.

  • info — auxiliary information, such as the transition probability of the sampled transition, that is mainly useful for debugging.
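
Rather than hard-coding an action, we can sample a random action from the environment’s action space and pass it to step():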

random_action = env.action_space.sample()
next_state, reward, done, info = env.step(random_action)

Generating an episode

An episode is the agent-environment interaction from the initial state to a terminal state. The agent interacts with the environment by performing an action in each state, and the episode ends when the agent reaches a terminal state. In the Frozen Lake environment, this means reaching either a hole state (H) or the goal state (G).

import gym

env = gym.make("FrozenLake-v0")
state = env.reset()
print('Time Step 0 :')
env.render()
num_timesteps = 20

for t in range(num_timesteps):
  random_action = env.action_space.sample()
  new_state, reward, done, info = env.step(random_action)
  print ('Time Step {} :'.format(t+1))
  env.render()
  if done:
    break
Time Step 0 :

SFFF
FHFH
FFFH
HFFG
Time Step 1 :
  (Right)
SFFF
FHFH
FFFH
HFFG
Time Step 2 :
  (Right)
SFFF
FHFH
FFFH
HFFG
Time Step 3 :
  (Left)
SFFF
FHFH
FFFH
HFFG
Time Step 4 :
  (Right)
SFFF
FHFH
FFFH
HFFG
Time Step 5 :
  (Up)
SFFF
FHFH
FFFH
HFFG

Instead of generating a single episode, we can also generate a series of episodes by taking a random action in each state:

import gym
env = gym.make("FrozenLake-v0")
num_episodes = 10
num_timesteps = 20 
for i in range(num_episodes):
    
    state = env.reset()
    print('Time Step 0 :')
    env.render()
    
    for t in range(num_timesteps):
        random_action = env.action_space.sample()
        new_state, reward, done, info = env.step(random_action)
        print ('Time Step {} :'.format(t+1))
        env.render()
        if done:
            break

CartPole

env = gym.make("CartPole-v0")
print(env.observation_space)
Box(-3.4028234663852886e+38, 3.4028234663852886e+38, (4,), float32)

Note that all of these values are continuous, that is:

  • The value of the cart position ranges from -4.8 to 4.8.

  • The value of the cart velocity ranges from -Inf to +Inf (−∞ to +∞).

  • The value of the pole angle ranges from -0.418 radians to 0.418 radians.

  • The value of the pole velocity at the tip ranges from -Inf to Inf.
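
Calling env.reset() returns the initial observation (an array of these four values), and the bounds of the observation space can be inspected through observation_space.high and observation_space.low: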

print(env.reset())
[-0.03805974 -0.00851157 -0.00346854 -0.03263184]
print(env.observation_space.high)
[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38]

It implies that:

  1. The maximum value of the cart position is 4.8.

  2. We learned that the maximum value of the cart velocity is +Inf; since infinity is not really a number, it is represented using the largest positive 32-bit floating-point value, 3.4028235e+38.

  3. The maximum value of the pole angle is 0.418 radians.

  4. The maximum value of the pole velocity at the tip is +Inf, so it is likewise represented using 3.4028235e+38.

print(env.observation_space.low)
[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38]

It states that:

  1. The minimum value of the cart position is -4.8.

  2. We learned that the minimum value of the cart velocity is -Inf; since infinity is not really a number, it is represented using the most negative 32-bit floating-point value, -3.4028235e+38.

  3. The minimum value of the pole angle is -0.418 radians.

  4. The minimum value of the pole velocity at the tip is -Inf, so it is likewise represented using -3.4028235e+38.

print(env.action_space)
Discrete(2)

Number    Action
0         Push cart to the left
1         Push cart to the right
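
The following program runs a single episode with a purely random policy, sampling an action at every step and accumulating the reward until the environment reports done: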

import gym


if __name__ == "__main__":
    env = gym.make("CartPole-v0")

    total_reward = 0.0
    total_steps = 0
    obs = env.reset()

    while True:
        action = env.action_space.sample()
        obs, reward, done, _ = env.step(action)
        total_reward += reward
        total_steps += 1
        if done:
            break

    print("Episode done in %d steps, total reward %.2f" % (
        total_steps, total_reward))
Episode done in 20 steps, total reward 20.00

Wrappers

Very frequently, you will want to extend the environment’s functionality in some generic way. For example, imagine an environment gives you some observations, but you want to accumulate them in a buffer and provide the agent with the N last observations. This is a common scenario for dynamic computer games, where a single frame is not enough to capture the full game state. Another example is when you want to crop or preprocess an image’s pixels to make them easier for the agent to digest, or to normalize reward scores. There are many situations with the same structure – you want to “wrap” the existing environment and add some extra logic. Gym provides a convenient framework for these situations – the Wrapper class.
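
As a sketch of the observation-buffering case described above (the class name BufferWrapper and the n_steps parameter are illustrative, not part of Gym), an ObservationWrapper can keep a deque of the last N observations and return them stacked:

import collections

import gym
import numpy as np


class BufferWrapper(gym.ObservationWrapper):
    """Return the last n_steps observations stacked along a new leading axis."""
    def __init__(self, env, n_steps=4):
        super(BufferWrapper, self).__init__(env)
        old = env.observation_space
        self.n_steps = n_steps
        # repeat the original bounds n_steps times so the space matches the stacked shape
        self.observation_space = gym.spaces.Box(
            low=np.repeat(old.low[np.newaxis], n_steps, axis=0),
            high=np.repeat(old.high[np.newaxis], n_steps, axis=0),
            dtype=old.dtype)
        self.buffer = collections.deque(maxlen=n_steps)

    def reset(self):
        self.buffer.clear()
        obs = self.env.reset()
        for _ in range(self.n_steps):
            self.buffer.append(obs)      # pad the buffer with the first observation
        return np.array(self.buffer)

    def observation(self, observation):
        self.buffer.append(observation)  # the deque drops the oldest frame automatically
        return np.array(self.buffer)


env = BufferWrapper(gym.make("CartPole-v0"), n_steps=4)
print(env.reset().shape)  # (4, 4): the 4 most recent 4-dimensional observations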

Random action wrapper
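
The wrapper below subclasses gym.ActionWrapper and overrides its action() method: with probability epsilon (10% by default) it discards the agent’s chosen action and substitutes a random sample from the action space, a simple way of injecting exploration.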

import gym
from typing import TypeVar
import random

Action = TypeVar('Action')


class RandomActionWrapper(gym.ActionWrapper):
    def __init__(self, env, epsilon=0.1):
        super(RandomActionWrapper, self).__init__(env)
        self.epsilon = epsilon

    def action(self, action: Action) -> Action:
        if random.random() < self.epsilon:
            print("Random!")
            return self.env.action_space.sample()
        return action


if __name__ == "__main__":
    env = RandomActionWrapper(gym.make("CartPole-v0"))

    obs = env.reset()
    total_reward = 0.0

    while True:
        obs, reward, done, _ = env.step(0)
        total_reward += reward
        if done:
            break

    print("Reward got: %.2f" % total_reward)
Reward got: 9.00
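
Because this test always issues action 0 (push the cart to the left), the pole falls quickly and the episode ends after only a few steps; the wrapper’s occasional "Random!" substitutions are the only source of variation in the actions.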

Atari GAN
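
In this example Gym is used purely as a source of images: random agents play several Atari games, and the observed screen frames (after resizing and filtering out nearly black ones) are used to train a DCGAN-style generator/discriminator pair that learns to produce similar screenshots. The first listing ("Normal") uses a hand-written PyTorch training loop; the second ("Ignite") re-implements the same loop with PyTorch Ignite. The Atari ROMs must be installed first: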

! wget http://www.atarimania.com/roms/Roms.rar
! mkdir /content/ROM/
! unrar e /content/Roms.rar /content/ROM/
! python -m atari_py.import_roms /content/ROM/

Normal

import random
import argparse
import cv2

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import torchvision.utils as vutils

import gym
import gym.spaces

import numpy as np

log = gym.logger
log.set_level(gym.logger.INFO)

LATENT_VECTOR_SIZE = 100
DISCR_FILTERS = 64
GENER_FILTERS = 64
BATCH_SIZE = 16

# dimensions to which the input image will be rescaled
IMAGE_SIZE = 64

LEARNING_RATE = 0.0001
REPORT_EVERY_ITER = 100
SAVE_IMAGE_EVERY_ITER = 1000


class InputWrapper(gym.ObservationWrapper):
    """
    Preprocessing of the input numpy array:
    1. resize the image to a predefined size
    2. move the color channel axis to the first position
    """
    def __init__(self, *args):
        super(InputWrapper, self).__init__(*args)
        assert isinstance(self.observation_space, gym.spaces.Box)
        old_space = self.observation_space
        self.observation_space = gym.spaces.Box(
            self.observation(old_space.low),
            self.observation(old_space.high),
            dtype=np.float32)

    def observation(self, observation):
        # resize image
        new_obs = cv2.resize(
            observation, (IMAGE_SIZE, IMAGE_SIZE))
        # move the channel axis to the front: (64, 64, 3) -> (3, 64, 64)
        new_obs = np.moveaxis(new_obs, 2, 0)
        return new_obs.astype(np.float32)


class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        # this pipe converts the image into a single number
        self.conv_pipe = nn.Sequential(
            nn.Conv2d(in_channels=input_shape[0], out_channels=DISCR_FILTERS,
                      kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS, out_channels=DISCR_FILTERS*2,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS*2),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 2, out_channels=DISCR_FILTERS * 4,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS * 4),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 4, out_channels=DISCR_FILTERS * 8,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS * 8),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 8, out_channels=1,
                      kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        conv_out = self.conv_pipe(x)
        return conv_out.view(-1, 1).squeeze(dim=1)


class Generator(nn.Module):
    def __init__(self, output_shape):
        super(Generator, self).__init__()
        # pipe deconvolves input vector into (3, 64, 64) image
        self.pipe = nn.Sequential(
            nn.ConvTranspose2d(in_channels=LATENT_VECTOR_SIZE, out_channels=GENER_FILTERS * 8,
                               kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(GENER_FILTERS * 8),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 8, out_channels=GENER_FILTERS * 4,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS * 4),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 4, out_channels=GENER_FILTERS * 2,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS * 2),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 2, out_channels=GENER_FILTERS,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS, out_channels=output_shape[0],
                               kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.pipe(x)


def iterate_batches(envs, batch_size=BATCH_SIZE):
    batch = [e.reset() for e in envs]
    env_gen = iter(lambda: random.choice(envs), None)

    while True:
        e = next(env_gen)
        obs, reward, is_done, _ = e.step(e.action_space.sample())
        if np.mean(obs) > 0.01:
            batch.append(obs)
        if len(batch) == batch_size:
            # Normalising input to the range -1 to 1
            batch_np = np.array(batch, dtype=np.float32) * 2.0 / 255.0 - 1.0
            yield torch.tensor(batch_np)
            batch.clear()
        if is_done:
            e.reset()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cuda", default=False, action='store_true',
        help="Enable cuda computation")
    args = parser.parse_args(args={})

    device = torch.device("cuda" if args.cuda else "cpu")
    envs = [
        InputWrapper(gym.make(name))
        for name in ('Breakout-v0', 'AirRaid-v0', 'Pong-v0')
    ]
    input_shape = envs[0].observation_space.shape

    net_discr = Discriminator(input_shape=input_shape).to(device)
    net_gener = Generator(output_shape=input_shape).to(device)

    objective = nn.BCELoss()
    gen_optimizer = optim.Adam(
        params=net_gener.parameters(), lr=LEARNING_RATE,
        betas=(0.5, 0.999))
    dis_optimizer = optim.Adam(
        params=net_discr.parameters(), lr=LEARNING_RATE,
        betas=(0.5, 0.999))
    writer = SummaryWriter()

    gen_losses = []
    dis_losses = []
    iter_no = 0

    true_labels_v = torch.ones(BATCH_SIZE, device=device)
    fake_labels_v = torch.zeros(BATCH_SIZE, device=device)

    for batch_v in iterate_batches(envs):
        # fake samples, input is 4D: batch, filters, x, y
        gen_input_v = torch.FloatTensor(
            BATCH_SIZE, LATENT_VECTOR_SIZE, 1, 1)
        gen_input_v.normal_(0, 1)
        gen_input_v = gen_input_v.to(device)
        batch_v = batch_v.to(device)
        gen_output_v = net_gener(gen_input_v)

        # train discriminator
        dis_optimizer.zero_grad()
        dis_output_true_v = net_discr(batch_v)
        dis_output_fake_v = net_discr(gen_output_v.detach())
        dis_loss = objective(dis_output_true_v, true_labels_v) + \
                   objective(dis_output_fake_v, fake_labels_v)
        dis_loss.backward()
        dis_optimizer.step()
        dis_losses.append(dis_loss.item())

        # train generator
        gen_optimizer.zero_grad()
        dis_output_v = net_discr(gen_output_v)
        gen_loss_v = objective(dis_output_v, true_labels_v)
        gen_loss_v.backward()
        gen_optimizer.step()
        gen_losses.append(gen_loss_v.item())

        iter_no += 1
        if iter_no % REPORT_EVERY_ITER == 0:
            log.info("Iter %d: gen_loss=%.3e, dis_loss=%.3e",
                     iter_no, np.mean(gen_losses),
                     np.mean(dis_losses))
            writer.add_scalar(
                "gen_loss", np.mean(gen_losses), iter_no)
            writer.add_scalar(
                "dis_loss", np.mean(dis_losses), iter_no)
            gen_losses = []
            dis_losses = []
        if iter_no % SAVE_IMAGE_EVERY_ITER == 0:
            writer.add_image("fake", vutils.make_grid(
                gen_output_v.data[:64], normalize=True), iter_no)
            writer.add_image("real", vutils.make_grid(
                batch_v.data[:64], normalize=True), iter_no)
INFO: Making new env: Breakout-v0
INFO: Making new env: AirRaid-v0
INFO: Making new env: Pong-v0
INFO: Iter 100: gen_loss=5.454e+00, dis_loss=5.009e-02
INFO: Iter 200: gen_loss=7.054e+00, dis_loss=4.306e-03
INFO: Iter 300: gen_loss=7.568e+00, dis_loss=2.140e-03
INFO: Iter 400: gen_loss=7.842e+00, dis_loss=1.272e-03
INFO: Iter 500: gen_loss=8.155e+00, dis_loss=1.019e-03
INFO: Iter 600: gen_loss=8.442e+00, dis_loss=6.918e-04
INFO: Iter 700: gen_loss=8.560e+00, dis_loss=5.483e-04
INFO: Iter 800: gen_loss=9.014e+00, dis_loss=4.792e-04
INFO: Iter 900: gen_loss=7.517e+00, dis_loss=2.132e-01
INFO: Iter 1000: gen_loss=7.375e+00, dis_loss=1.050e-01
INFO: Iter 1100: gen_loss=6.722e+00, dis_loss=1.718e-02
INFO: Iter 1200: gen_loss=6.346e+00, dis_loss=6.303e-03
INFO: Iter 1300: gen_loss=6.636e+00, dis_loss=6.348e-03
INFO: Iter 1400: gen_loss=6.612e+00, dis_loss=7.664e-02
INFO: Iter 1500: gen_loss=6.028e+00, dis_loss=7.801e-03
INFO: Iter 1600: gen_loss=6.665e+00, dis_loss=3.651e-03
INFO: Iter 1700: gen_loss=7.290e+00, dis_loss=5.616e-02
INFO: Iter 1800: gen_loss=6.314e+00, dis_loss=7.723e-02
INFO: Iter 1900: gen_loss=5.940e+00, dis_loss=3.784e-01
INFO: Iter 2000: gen_loss=5.053e+00, dis_loss=2.623e-01
INFO: Iter 2100: gen_loss=5.465e+00, dis_loss=9.114e-02
INFO: Iter 2200: gen_loss=5.480e+00, dis_loss=3.963e-01
INFO: Iter 2300: gen_loss=4.549e+00, dis_loss=2.361e-01
INFO: Iter 2400: gen_loss=5.407e+00, dis_loss=1.310e-01
INFO: Iter 2500: gen_loss=5.766e+00, dis_loss=5.550e-02
INFO: Iter 2600: gen_loss=5.816e+00, dis_loss=1.418e-01
INFO: Iter 2700: gen_loss=6.737e+00, dis_loss=5.231e-02
INFO: Iter 2800: gen_loss=7.147e+00, dis_loss=1.491e-01
INFO: Iter 2900: gen_loss=6.541e+00, dis_loss=2.155e-02
INFO: Iter 3000: gen_loss=7.072e+00, dis_loss=1.127e-01
INFO: Iter 3100: gen_loss=6.137e+00, dis_loss=6.138e-02
INFO: Iter 3200: gen_loss=7.406e+00, dis_loss=3.540e-02
INFO: Iter 3300: gen_loss=7.850e+00, dis_loss=5.691e-03
INFO: Iter 3400: gen_loss=8.614e+00, dis_loss=7.228e-03
INFO: Iter 3500: gen_loss=8.885e+00, dis_loss=3.191e-03
INFO: Iter 3600: gen_loss=5.367e+00, dis_loss=5.296e-01
INFO: Iter 3700: gen_loss=4.176e+00, dis_loss=3.335e-01
INFO: Iter 3800: gen_loss=5.174e+00, dis_loss=2.732e-01
INFO: Iter 3900: gen_loss=5.492e+00, dis_loss=1.298e-01
INFO: Iter 4000: gen_loss=6.570e+00, dis_loss=1.961e-02
INFO: Iter 4100: gen_loss=7.011e+00, dis_loss=2.517e-02
INFO: Iter 4200: gen_loss=8.362e+00, dis_loss=4.330e-03
INFO: Iter 4300: gen_loss=6.908e+00, dis_loss=2.161e-01
INFO: Iter 4400: gen_loss=5.226e+00, dis_loss=2.762e-01
INFO: Iter 4500: gen_loss=4.998e+00, dis_loss=2.893e-01
INFO: Iter 4600: gen_loss=5.078e+00, dis_loss=3.962e-01
INFO: Iter 4700: gen_loss=4.886e+00, dis_loss=1.932e-01
INFO: Iter 4800: gen_loss=6.110e+00, dis_loss=7.615e-02
INFO: Iter 4900: gen_loss=5.402e+00, dis_loss=1.634e-01
INFO: Iter 5000: gen_loss=5.336e+00, dis_loss=1.919e-01
INFO: Iter 5100: gen_loss=5.749e+00, dis_loss=8.817e-02
INFO: Iter 5200: gen_loss=5.879e+00, dis_loss=1.182e-01
INFO: Iter 5300: gen_loss=5.417e+00, dis_loss=1.651e-01
INFO: Iter 5400: gen_loss=6.747e+00, dis_loss=3.846e-02
INFO: Iter 5500: gen_loss=5.133e+00, dis_loss=1.996e-01
INFO: Iter 5600: gen_loss=6.116e+00, dis_loss=2.946e-01
INFO: Iter 5700: gen_loss=5.858e+00, dis_loss=2.152e-02

Ignite
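
The Ignite version wraps the same models and batch generator in an ignite Engine: process_batch becomes the engine’s update function, RunningAverage metrics track the generator and discriminator losses, a TensorboardLogger handler writes them out every iteration, and the losses are also logged to the console every REPORT_EVERY_ITER iterations.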

import random
import argparse
import cv2

import torch
import torch.nn as nn
import torch.optim as optim
from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import tensorboard_logger as tb_logger

import torchvision.utils as vutils

import gym
import gym.spaces

import numpy as np

log = gym.logger
log.set_level(gym.logger.INFO)

LATENT_VECTOR_SIZE = 100
DISCR_FILTERS = 64
GENER_FILTERS = 64
BATCH_SIZE = 16

# dimensions to which the input image will be rescaled
IMAGE_SIZE = 64

LEARNING_RATE = 0.0001
REPORT_EVERY_ITER = 100
SAVE_IMAGE_EVERY_ITER = 1000


class InputWrapper(gym.ObservationWrapper):
    """
    Preprocessing of the input numpy array:
    1. resize the image to a predefined size
    2. move the color channel axis to the first position
    """
    def __init__(self, *args):
        super(InputWrapper, self).__init__(*args)
        assert isinstance(self.observation_space, gym.spaces.Box)
        old_space = self.observation_space
        self.observation_space = gym.spaces.Box(self.observation(old_space.low), self.observation(old_space.high),
                                                dtype=np.float32)

    def observation(self, observation):
        # resize image
        new_obs = cv2.resize(observation, (IMAGE_SIZE, IMAGE_SIZE))
        # move the channel axis to the front: (64, 64, 3) -> (3, 64, 64)
        new_obs = np.moveaxis(new_obs, 2, 0)
        return new_obs.astype(np.float32)


class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        # this pipe converts the image into a single number
        self.conv_pipe = nn.Sequential(
            nn.Conv2d(in_channels=input_shape[0], out_channels=DISCR_FILTERS,
                      kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS, out_channels=DISCR_FILTERS*2,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS*2),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 2, out_channels=DISCR_FILTERS * 4,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS * 4),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 4, out_channels=DISCR_FILTERS * 8,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS * 8),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 8, out_channels=1,
                      kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        conv_out = self.conv_pipe(x)
        return conv_out.view(-1, 1).squeeze(dim=1)


class Generator(nn.Module):
    def __init__(self, output_shape):
        super(Generator, self).__init__()
        # pipe deconvolves input vector into (3, 64, 64) image
        self.pipe = nn.Sequential(
            nn.ConvTranspose2d(in_channels=LATENT_VECTOR_SIZE, out_channels=GENER_FILTERS * 8,
                               kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(GENER_FILTERS * 8),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 8, out_channels=GENER_FILTERS * 4,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS * 4),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 4, out_channels=GENER_FILTERS * 2,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS * 2),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 2, out_channels=GENER_FILTERS,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS, out_channels=output_shape[0],
                               kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.pipe(x)


def iterate_batches(envs, batch_size=BATCH_SIZE):
    batch = [e.reset() for e in envs]
    env_gen = iter(lambda: random.choice(envs), None)

    while True:
        e = next(env_gen)
        obs, reward, is_done, _ = e.step(e.action_space.sample())
        if np.mean(obs) > 0.01:
            batch.append(obs)
        if len(batch) == batch_size:
            # Normalising input to the range -1 to 1
            batch_np = np.array(batch, dtype=np.float32) * 2.0 / 255.0 - 1.0
            yield torch.tensor(batch_np)
            batch.clear()
        if is_done:
            e.reset()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", default=False, action='store_true', help="Enable cuda computation")
    args = parser.parse_args(args={})

    device = torch.device("cuda" if args.cuda else "cpu")
    # envs = [InputWrapper(gym.make(name)) for name in ('Breakout-v0', 'AirRaid-v0', 'Pong-v0')]
    envs = [InputWrapper(gym.make(name)) for name in ['Breakout-v0']]
    input_shape = envs[0].observation_space.shape

    net_discr = Discriminator(input_shape=input_shape).to(device)
    net_gener = Generator(output_shape=input_shape).to(device)

    objective = nn.BCELoss()
    gen_optimizer = optim.Adam(params=net_gener.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    dis_optimizer = optim.Adam(params=net_discr.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

    true_labels_v = torch.ones(BATCH_SIZE, device=device)
    fake_labels_v = torch.zeros(BATCH_SIZE, device=device)

    def process_batch(trainer, batch):
        gen_input_v = torch.FloatTensor(
            BATCH_SIZE, LATENT_VECTOR_SIZE, 1, 1)
        gen_input_v.normal_(0, 1)
        gen_input_v = gen_input_v.to(device)
        batch_v = batch.to(device)
        gen_output_v = net_gener(gen_input_v)

        # train discriminator
        dis_optimizer.zero_grad()
        dis_output_true_v = net_discr(batch_v)
        dis_output_fake_v = net_discr(gen_output_v.detach())
        dis_loss = objective(dis_output_true_v, true_labels_v) + \
                   objective(dis_output_fake_v, fake_labels_v)
        dis_loss.backward()
        dis_optimizer.step()

        # train generator
        gen_optimizer.zero_grad()
        dis_output_v = net_discr(gen_output_v)
        gen_loss = objective(dis_output_v, true_labels_v)
        gen_loss.backward()
        gen_optimizer.step()

        if trainer.state.iteration % SAVE_IMAGE_EVERY_ITER == 0:
            fake_img = vutils.make_grid(
                gen_output_v.data[:64], normalize=True)
            trainer.tb.writer.add_image(
                "fake", fake_img, trainer.state.iteration)
            real_img = vutils.make_grid(
                batch_v.data[:64], normalize=True)
            trainer.tb.writer.add_image(
                "real", real_img, trainer.state.iteration)
            trainer.tb.writer.flush()
        return dis_loss.item(), gen_loss.item()

    engine = Engine(process_batch)
    tb = tb_logger.TensorboardLogger(log_dir=None)
    engine.tb = tb
    RunningAverage(output_transform=lambda out: out[1]).\
        attach(engine, "avg_loss_gen")
    RunningAverage(output_transform=lambda out: out[0]).\
        attach(engine, "avg_loss_dis")

    handler = tb_logger.OutputHandler(tag="train",
        metric_names=['avg_loss_gen', 'avg_loss_dis'])
    tb.attach(engine, log_handler=handler,
              event_name=Events.ITERATION_COMPLETED)

    @engine.on(Events.ITERATION_COMPLETED)
    def log_losses(trainer):
        if trainer.state.iteration % REPORT_EVERY_ITER == 0:
            log.info("%d: gen_loss=%f, dis_loss=%f",
                     trainer.state.iteration,
                     trainer.state.metrics['avg_loss_gen'],
                     trainer.state.metrics['avg_loss_dis'])

    engine.run(data=iterate_batches(envs))
INFO: Making new env: Breakout-v0
INFO: 100: gen_loss=5.327549, dis_loss=0.200626
INFO: 200: gen_loss=6.850880, dis_loss=0.028281
INFO: 300: gen_loss=7.435633, dis_loss=0.004672
INFO: 400: gen_loss=7.708136, dis_loss=0.001331
INFO: 500: gen_loss=8.000729, dis_loss=0.000699
INFO: 600: gen_loss=8.314868, dis_loss=0.000474
INFO: 700: gen_loss=8.620416, dis_loss=0.000328
INFO: 800: gen_loss=8.779677, dis_loss=0.000286
INFO: 900: gen_loss=8.907359, dis_loss=0.000267
INFO: 1000: gen_loss=6.822098, dis_loss=0.558477
INFO: 1100: gen_loss=6.491067, dis_loss=0.079029
INFO: 1200: gen_loss=6.794054, dis_loss=0.012632
INFO: 1300: gen_loss=7.230944, dis_loss=0.002747
INFO: 1400: gen_loss=7.698738, dis_loss=0.000962
INFO: 1500: gen_loss=8.162801, dis_loss=0.000497
INFO: 1600: gen_loss=8.546710, dis_loss=0.000326
INFO: 1700: gen_loss=8.939020, dis_loss=0.000224
INFO: 1800: gen_loss=9.027502, dis_loss=0.000216
INFO: 1900: gen_loss=9.230650, dis_loss=0.000172
INFO: 2000: gen_loss=9.495270, dis_loss=0.000129
INFO: 2100: gen_loss=9.700210, dis_loss=0.000104
INFO: 2200: gen_loss=9.862649, dis_loss=0.000086
INFO: 2300: gen_loss=10.042667, dis_loss=0.000075
INFO: 2400: gen_loss=10.333560, dis_loss=0.000052
INFO: 2500: gen_loss=10.437976, dis_loss=0.000045
INFO: 2600: gen_loss=10.592011, dis_loss=0.000040
INFO: 2700: gen_loss=10.633485, dis_loss=0.000039
INFO: 2800: gen_loss=10.627324, dis_loss=0.000036
INFO: 2900: gen_loss=10.665850, dis_loss=0.000036
INFO: 3000: gen_loss=10.712931, dis_loss=0.000036
INFO: 3100: gen_loss=10.853663, dis_loss=0.000030
INFO: 3200: gen_loss=10.868406, dis_loss=0.000030
INFO: 3300: gen_loss=10.904878, dis_loss=0.000027
INFO: 3400: gen_loss=11.031057, dis_loss=0.000022
INFO: 3500: gen_loss=11.114413, dis_loss=0.000022
INFO: 3600: gen_loss=11.334650, dis_loss=0.000018
INFO: 3700: gen_loss=11.537755, dis_loss=0.000013
INFO: 3800: gen_loss=11.573673, dis_loss=0.000015
INFO: 3900: gen_loss=11.594438, dis_loss=0.000013
INFO: 4000: gen_loss=11.650991, dis_loss=0.000012
INFO: 4100: gen_loss=11.350557, dis_loss=0.000023
INFO: 4200: gen_loss=11.715774, dis_loss=0.000012
INFO: 4300: gen_loss=11.970108, dis_loss=0.000008
INFO: 4400: gen_loss=12.142686, dis_loss=0.000007
INFO: 4500: gen_loss=12.200508, dis_loss=0.000007
INFO: 4600: gen_loss=12.209455, dis_loss=0.000006
INFO: 4700: gen_loss=12.215595, dis_loss=0.000007
INFO: 4800: gen_loss=12.352226, dis_loss=0.000006
INFO: 4900: gen_loss=12.434466, dis_loss=0.000006
INFO: 5000: gen_loss=12.517082, dis_loss=0.000005
INFO: 5100: gen_loss=12.604175, dis_loss=0.000005
INFO: 5200: gen_loss=12.744095, dis_loss=0.000004
INFO: 5300: gen_loss=12.880165, dis_loss=0.000004
INFO: 5400: gen_loss=12.999031, dis_loss=0.000003

Render environments in Colab

Alternative 1

It is possible to visualize the game your agent is playing, even in Colab. This section shows how to generate a video in Colab of an episode of the game your agent is playing. This video process is based on suggestions found here.

Begin by installing pyvirtualdisplay and python-opengl, along with the supporting packages (xvfb, ffmpeg) used for virtual rendering.

!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1

!apt-get update > /dev/null 2>&1
!apt-get install cmake > /dev/null 2>&1
!pip install --upgrade setuptools 2>&1
!pip install ez_setup > /dev/null 2>&1
!pip install gym[atari] > /dev/null 2>&1

!wget http://www.atarimania.com/roms/Roms.rar
!mkdir /content/ROM/
!unrar e /content/Roms.rar /content/ROM/
!python -m atari_py.import_roms /content/ROM/
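
The wrap_env helper below attaches Gym’s Monitor wrapper, which records each episode as an MP4 file in the ./video directory; show_video then finds the recorded file and embeds it in the notebook output:
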
import gym
from gym.wrappers import Monitor
import glob
import io
import base64
from IPython.display import HTML
from pyvirtualdisplay import Display
from IPython import display as ipythondisplay

display = Display(visible=0, size=(1400, 900))
display.start()

"""
Utility functions to enable video recording of gym environment 
and displaying it.
To enable video, just do "env = wrap_env(env)""
"""

def show_video():
  mp4list = glob.glob('video/*.mp4')
  if len(mp4list) > 0:
    mp4 = mp4list[0]
    video = io.open(mp4, 'r+b').read()
    encoded = base64.b64encode(video)
    ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))
  else: 
    print("Could not find video")
    

def wrap_env(env):
  env = Monitor(env, './video', force=True)
  return env
#env = wrap_env(gym.make("MountainCar-v0"))
env = wrap_env(gym.make("Atlantis-v0"))

observation = env.reset()

while True:
    env.render()
    #your agent goes here
    action = env.action_space.sample() 
    observation, reward, done, info = env.step(action)
    if done: 
      break
            
env.close()
show_video()

Alternative 2

!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1
!pip install colabgymrender
import gym
from colabgymrender.recorder import Recorder

env = gym.make("Breakout-v0")
directory = './video'
env = Recorder(env, directory)

observation = env.reset()
terminal = False
while not terminal:
  action = env.action_space.sample()
  observation, reward, terminal, info = env.step(action)

env.play()