Batch-Constrained Deep Q-Learning

Current off-policy deep reinforcement learning algorithms fail to address extrapolation error: they select actions with respect to a learned value estimate without considering the accuracy of that estimate. As a result, certain out-of-distribution actions can be erroneously extrapolated to higher values. However, the value of an off-policy agent can be accurately evaluated in regions where data is available.

Batch-Constrained deep Q-learning (BCQ) uses a state-conditioned generative model to produce only previously seen actions. This generative model is combined with a Q-network to select the highest-valued action that is similar to the data in the batch. Unlike previous continuous-control deep reinforcement learning algorithms, BCQ is able to learn successfully without interacting with the environment, because it accounts for extrapolation error.

BCQ is based on a simple idea: to avoid extrapolation error, a policy should induce a state-action visitation similar to that of the batch. We call policies that satisfy this notion batch-constrained. To optimize off-policy learning for a given batch, batch-constrained policies are trained to select actions with respect to three objectives (formalized below):

  1. Minimize the distance of selected actions to the data in the batch.

  2. Lead to states where familiar data can be observed.

  3. Maximize the value function.
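
For the discrete-action setting implemented in this notebook, these objectives reduce to a thresholded argmax: a behaviour-cloning ("imitation") model G_w(a|s) estimates which actions appear in the batch, and the policy maximizes Q only over actions whose relative probability under G_w exceeds the `BCQ_threshold` hyper-parameter tau defined below (this is exactly the rule implemented later in `discrete_BCQ.select_action`):

$$\pi(s) = \underset{a \,:\, G_\omega(a|s)/\max_{\hat a} G_\omega(\hat a|s) > \tau}{\operatorname{argmax}} \; Q_\theta(s, a)$$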

Setup

!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1

!apt-get update > /dev/null 2>&1
!apt-get install cmake > /dev/null 2>&1
!pip install --upgrade setuptools 2>&1
!pip install ez_setup > /dev/null 2>&1
!pip install gym[atari] > /dev/null 2>&1

!wget http://www.atarimania.com/roms/Roms.rar
!mkdir /content/ROM/
!unrar e /content/Roms.rar /content/ROM/
!python -m atari_py.import_roms /content/ROM/

Restart the runtime after the installs above complete. This is required before continuing.

import gym
from gym.wrappers import Monitor
import glob
import io
import base64
from IPython.display import HTML
from pyvirtualdisplay import Display
from IPython import display as ipythondisplay

display = Display(visible=0, size=(1400, 900))
display.start()

"""
Utility functions to enable video recording of gym environment 
and displaying it.
To enable video, just do "env = wrap_env(env)""
"""

def show_video():
  mp4list = glob.glob('video/*.mp4')
  if len(mp4list) > 0:
    mp4 = mp4list[0]
    video = io.open(mp4, 'r+b').read()
    encoded = base64.b64encode(video)
    ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))
  else: 
    print("Could not find video")
    

def wrap_env(env):
  env = Monitor(env, './video', force=True)
  return env
import cv2
import numpy as np
import torch
import importlib
import json
import os

import copy
import torch.nn as nn
import torch.nn.functional as F

Params

class Args:

    # env = "PongNoFrameskip-v0" # OpenAI gym environment name
    env = "CartPole-v0" # OpenAI gym environment name
    seed = 0 # Sets Gym, PyTorch and Numpy seeds
    buffer_name = "Default" # Prepends name to filename
    max_timesteps = 1e4 # Max time steps to run environment or train for
    BCQ_threshold = 0.3 # Threshold hyper-parameter for BCQ
    low_noise_p = 0.2 # Probability of a low noise episode when generating buffer
    rand_action_p = 0.2 # Probability of taking a random action when generating buffer, during non-low noise episode

    # Atari Specific
    atari_preprocessing = {
        "frame_skip": 4,
        "frame_size": 84,
        "state_history": 4,
        "done_on_life_loss": False,
        "reward_clipping": True,
        "max_episode_timesteps": 27e3
    }
    
    atari_parameters = {
		# Exploration
		"start_timesteps": 2e4,
		"initial_eps": 1,
		"end_eps": 1e-2,
		"eps_decay_period": 25e4,
		# Evaluation
		"eval_freq": 5e4,
		"eval_eps": 1e-3,
		# Learning
		"discount": 0.99,
		"buffer_size": 1e6,
		"batch_size": 32,
		"optimizer": "Adam",
		"optimizer_parameters": {
			"lr": 0.0000625,
			"eps": 0.00015
		},
		"train_freq": 4,
		"polyak_target_update": False,
		"target_update_freq": 8e3,
		"tau": 1
	}
    
    regular_parameters = {
		# Exploration
		"start_timesteps": 1e3,
		"initial_eps": 0.1,
		"end_eps": 0.1,
		"eps_decay_period": 1,
		# Evaluation
		"eval_freq": 5e3,
		"eval_eps": 0,
		# Learning
		"discount": 0.99,
		"buffer_size": 1e6,
		"batch_size": 64,
		"optimizer": "Adam",
		"optimizer_parameters": {
			"lr": 3e-4
		},
		"train_freq": 1,
		"polyak_target_update": True,
		"target_update_freq": 1,
		"tau": 0.005
	}


args = Args()
if not os.path.exists("./results"):
    os.makedirs("./results")

if not os.path.exists("./models"):
    os.makedirs("./models")

if not os.path.exists("./buffers"):
    os.makedirs("./buffers")
# Set seeds
torch.manual_seed(args.seed)
np.random.seed(args.seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Replay buffer

def ReplayBuffer(state_dim, is_atari, atari_preprocessing, batch_size, buffer_size, device):
	if is_atari: 
		return AtariBuffer(state_dim, atari_preprocessing, batch_size, buffer_size, device)
	else: 
		return StandardBuffer(state_dim, batch_size, buffer_size, device)
  

class AtariBuffer(object):
	def __init__(self, state_dim, atari_preprocessing, batch_size, buffer_size, device):
		self.batch_size = batch_size
		self.max_size = int(buffer_size)
		self.device = device

		self.state_history = atari_preprocessing["state_history"]

		self.ptr = 0
		self.crt_size = 0

		self.state = np.zeros((
			self.max_size + 1,
			atari_preprocessing["frame_size"],
			atari_preprocessing["frame_size"]
		), dtype=np.uint8)

		self.action = np.zeros((self.max_size, 1), dtype=np.int64)
		self.reward = np.zeros((self.max_size, 1))
		
		# not_done only treats the episode as "done" if it terminates due to a failure condition;
		# if the episode terminates due to the time limit, the transition is not added to the buffer
		self.not_done = np.zeros((self.max_size, 1))
		self.first_timestep = np.zeros(self.max_size, dtype=np.uint8)


	def add(self, state, action, next_state, reward, done, env_done, first_timestep):
		# If dones don't match, env has reset due to timelimit
		# and we don't add the transition to the buffer
		if done != env_done:
			return

		self.state[self.ptr] = state[0]
		self.action[self.ptr] = action
		self.reward[self.ptr] = reward
		self.not_done[self.ptr] = 1. - done
		self.first_timestep[self.ptr] = first_timestep

		self.ptr = (self.ptr + 1) % self.max_size
		self.crt_size = min(self.crt_size + 1, self.max_size)


	def sample(self):
		ind = np.random.randint(0, self.crt_size, size=self.batch_size)

		# Note + is concatenate here
		state = np.zeros(((self.batch_size, self.state_history) + self.state.shape[1:]), dtype=np.uint8)
		next_state = np.array(state)

		state_not_done = 1.
		next_not_done = 1.
		for i in range(self.state_history):

			# Wrap around if the buffer is filled
			if self.crt_size == self.max_size:
				j = (ind - i) % self.max_size
				k = (ind - i + 1) % self.max_size
			else:
				j = ind - i
				k = (ind - i + 1).clip(min=0)
				# If j == -1, then we set state_not_done to 0.
				state_not_done *= (j + 1).clip(min=0, max=1).reshape(-1, 1, 1) #np.where(j < 0, state_not_done * 0, state_not_done)
				j = j.clip(min=0)

			# State should be all 0s if the episode terminated previously
			state[:, i] = self.state[j] * state_not_done
			next_state[:, i] = self.state[k] * next_not_done

			# If this was the first timestep, make everything previous = 0
			next_not_done *= state_not_done
			state_not_done *= (1. - self.first_timestep[j]).reshape(-1, 1, 1)

		return (
			torch.ByteTensor(state).to(self.device).float(),
			torch.LongTensor(self.action[ind]).to(self.device),
			torch.ByteTensor(next_state).to(self.device).float(),
			torch.FloatTensor(self.reward[ind]).to(self.device),
			torch.FloatTensor(self.not_done[ind]).to(self.device)
		)


	def save(self, save_folder, chunk=int(1e5)):
		np.save(f"{save_folder}_action.npy", self.action[:self.crt_size])
		np.save(f"{save_folder}_reward.npy", self.reward[:self.crt_size])
		np.save(f"{save_folder}_not_done.npy", self.not_done[:self.crt_size])
		np.save(f"{save_folder}_first_timestep.npy", self.first_timestep[:self.crt_size])
		np.save(f"{save_folder}_replay_info.npy", [self.ptr, chunk])

		crt = 0
		end = min(chunk, self.crt_size + 1)
		while crt < self.crt_size + 1:
			np.save(f"{save_folder}_state_{end}.npy", self.state[crt:end])
			crt = end
			end = min(end + chunk, self.crt_size + 1)


	def load(self, save_folder, size=-1):
		reward_buffer = np.load(f"{save_folder}_reward.npy")
		
		# Adjust crt_size if we're using a custom size
		size = min(int(size), self.max_size) if size > 0 else self.max_size
		self.crt_size = min(reward_buffer.shape[0], size)

		self.action[:self.crt_size] = np.load(f"{save_folder}_action.npy")[:self.crt_size]
		self.reward[:self.crt_size] = reward_buffer[:self.crt_size]
		self.not_done[:self.crt_size] = np.load(f"{save_folder}_not_done.npy")[:self.crt_size]
		self.first_timestep[:self.crt_size] = np.load(f"{save_folder}_first_timestep.npy")[:self.crt_size]

		self.ptr, chunk = np.load(f"{save_folder}_replay_info.npy")

		crt = 0
		end = min(chunk, self.crt_size + 1)
		while crt < self.crt_size + 1:
			self.state[crt:end] = np.load(f"{save_folder}_state_{end}.npy")
			crt = end
			end = min(end + chunk, self.crt_size + 1)


# Generic replay buffer for standard gym tasks
class StandardBuffer(object):
	def __init__(self, state_dim, batch_size, buffer_size, device):
		self.batch_size = batch_size
		self.max_size = int(buffer_size)
		self.device = device

		self.ptr = 0
		self.crt_size = 0

		self.state = np.zeros((self.max_size, state_dim))
		self.action = np.zeros((self.max_size, 1))
		self.next_state = np.array(self.state)
		self.reward = np.zeros((self.max_size, 1))
		self.not_done = np.zeros((self.max_size, 1))


	def add(self, state, action, next_state, reward, done, episode_done, episode_start):
		self.state[self.ptr] = state
		self.action[self.ptr] = action
		self.next_state[self.ptr] = next_state
		self.reward[self.ptr] = reward
		self.not_done[self.ptr] = 1. - done

		self.ptr = (self.ptr + 1) % self.max_size
		self.crt_size = min(self.crt_size + 1, self.max_size)


	def sample(self):
		ind = np.random.randint(0, self.crt_size, size=self.batch_size)
		return (
			torch.FloatTensor(self.state[ind]).to(self.device),
			torch.LongTensor(self.action[ind]).to(self.device),
			torch.FloatTensor(self.next_state[ind]).to(self.device),
			torch.FloatTensor(self.reward[ind]).to(self.device),
			torch.FloatTensor(self.not_done[ind]).to(self.device)
		)


	def save(self, save_folder):
		np.save(f"{save_folder}_state.npy", self.state[:self.crt_size])
		np.save(f"{save_folder}_action.npy", self.action[:self.crt_size])
		np.save(f"{save_folder}_next_state.npy", self.next_state[:self.crt_size])
		np.save(f"{save_folder}_reward.npy", self.reward[:self.crt_size])
		np.save(f"{save_folder}_not_done.npy", self.not_done[:self.crt_size])
		np.save(f"{save_folder}_ptr.npy", self.ptr)


	def load(self, save_folder, size=-1):
		reward_buffer = np.load(f"{save_folder}_reward.npy")
		
		# Adjust crt_size if we're using a custom size
		size = min(int(size), self.max_size) if size > 0 else self.max_size
		self.crt_size = min(reward_buffer.shape[0], size)

		self.state[:self.crt_size] = np.load(f"{save_folder}_state.npy")[:self.crt_size]
		self.action[:self.crt_size] = np.load(f"{save_folder}_action.npy")[:self.crt_size]
		self.next_state[:self.crt_size] = np.load(f"{save_folder}_next_state.npy")[:self.crt_size]
		self.reward[:self.crt_size] = reward_buffer[:self.crt_size]
		self.not_done[:self.crt_size] = np.load(f"{save_folder}_not_done.npy")[:self.crt_size]

		print(f"Replay Buffer loaded with {self.crt_size} elements.")

Atari preprocessing

# Atari Preprocessing
# Code is based on https://github.com/openai/gym/blob/master/gym/wrappers/atari_preprocessing.py
class AtariPreprocessing(object):
	def __init__(
		self,
		env,
		frame_skip=4,
		frame_size=84,
		state_history=4,
		done_on_life_loss=False,
		reward_clipping=True, # Clips to a range of -1,1
		max_episode_timesteps=27000
	):
		self.env = env.env
		self.done_on_life_loss = done_on_life_loss
		self.frame_skip = frame_skip
		self.frame_size = frame_size
		self.reward_clipping = reward_clipping
		self._max_episode_steps = max_episode_timesteps
		self.observation_space = np.zeros((frame_size, frame_size))
		self.action_space = self.env.action_space

		self.lives = 0
		self.episode_length = 0

		# Tracks previous 2 frames
		self.frame_buffer = np.zeros(
			(2,
			self.env.observation_space.shape[0],
			self.env.observation_space.shape[1]),
			dtype=np.uint8
		)
		# Tracks previous 4 states
		self.state_buffer = np.zeros((state_history, frame_size, frame_size), dtype=np.uint8)


	def reset(self):
		self.env.reset()
		self.lives = self.env.ale.lives()
		self.episode_length = 0
		self.env.ale.getScreenGrayscale(self.frame_buffer[0])
		self.frame_buffer[1] = 0

		self.state_buffer[0] = self.adjust_frame()
		self.state_buffer[1:] = 0
		return self.state_buffer


	# Takes a single action, repeated for frame_skip frames (usually 4).
	# Reward is accumulated over those frames.
	def step(self, action):
		total_reward = 0.
		self.episode_length += 1

		for frame in range(self.frame_skip):
			_, reward, done, _ = self.env.step(action)
			total_reward += reward

			if self.done_on_life_loss:
				crt_lives = self.env.ale.lives()
				done = True if crt_lives < self.lives else done
				self.lives = crt_lives

			if done: 
				break

			# Second last and last frame
			f = frame + 2 - self.frame_skip 
			if f >= 0:
				self.env.ale.getScreenGrayscale(self.frame_buffer[f])

		self.state_buffer[1:] = self.state_buffer[:-1]
		self.state_buffer[0] = self.adjust_frame()

		done_float = float(done)
		if self.episode_length >= self._max_episode_steps:
			done = True

		return self.state_buffer, total_reward, done, [np.clip(total_reward, -1, 1), done_float]


	def adjust_frame(self):
		# Take maximum over last two frames
		np.maximum(
			self.frame_buffer[0],
			self.frame_buffer[1],
			out=self.frame_buffer[0]
		)

		# Resize
		image = cv2.resize(
			self.frame_buffer[0],
			(self.frame_size, self.frame_size),
			interpolation=cv2.INTER_AREA
		)
		return np.array(image, dtype=np.uint8)


	def seed(self, seed):
		self.env.seed(seed)
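
As an aside, the element-wise max over the last two frames in adjust_frame is the standard trick for removing Atari sprite flicker; a toy illustration with hypothetical 2x2 "frames":

# Objects drawn on alternating frames survive the element-wise max,
# so the agent always sees them.
f0 = np.array([[0, 255], [0, 0]], dtype=np.uint8)   # sprite drawn only on this frame
f1 = np.array([[0, 0], [0, 255]], dtype=np.uint8)   # sprite drawn only on this frame
print(np.maximum(f0, f1))                           # both sprites present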

Create Environment

# Create the environment, add the Atari wrapper if necessary, and return its properties
def make_env(env_name, atari_preprocessing):
	env = wrap_env(gym.make(env_name))
	
	is_atari = gym.envs.registry.spec(env_name).entry_point == 'gym.envs.atari:AtariEnv'
	env = AtariPreprocessing(env, **atari_preprocessing) if is_atari else env

	state_dim = (
		atari_preprocessing["state_history"], 
		atari_preprocessing["frame_size"], 
		atari_preprocessing["frame_size"]
	) if is_atari else env.observation_space.shape[0]

	return (
		env,
		is_atari,
		state_dim,
		env.action_space.n
	)

DQN

# Make env and determine properties
env, is_atari, state_dim, num_actions = make_env(args.env, args.atari_preprocessing)
parameters = args.atari_parameters if is_atari else args.regular_parameters


# Set seeds
env.seed(args.seed)
env.action_space.seed(args.seed)


# Initialize buffer
replay_buffer = ReplayBuffer(state_dim, is_atari, args.atari_preprocessing, parameters["batch_size"], parameters["buffer_size"], device)
# Used for Atari
class Conv_Q(nn.Module):
	def __init__(self, frames, num_actions):
		super(Conv_Q, self).__init__()
		self.c1 = nn.Conv2d(frames, 32, kernel_size=8, stride=4)
		self.c2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
		self.c3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
		self.l1 = nn.Linear(3136, 512)
		self.l2 = nn.Linear(512, num_actions)


	def forward(self, state):
		q = F.relu(self.c1(state))
		q = F.relu(self.c2(q))
		q = F.relu(self.c3(q))
		q = F.relu(self.l1(q.reshape(-1, 3136)))
		return self.l2(q)
# Used for Box2D / Toy problems
class FC_Q(nn.Module):
	def __init__(self, state_dim, num_actions):
		super(FC_Q, self).__init__()
		self.l1 = nn.Linear(state_dim, 256)
		self.l2 = nn.Linear(256, 256)
		self.l3 = nn.Linear(256, num_actions)


	def forward(self, state):
		q = F.relu(self.l1(state))
		q = F.relu(self.l2(q))
		return self.l3(q)
class DQN(object):
	def __init__(
		self, 
		is_atari,
		num_actions,
		state_dim,
		device,
		discount=0.99,
		optimizer="Adam",
		optimizer_parameters={},
		polyak_target_update=False,
		target_update_frequency=8e3,
		tau=0.005,
		initial_eps = 1,
		end_eps = 0.001,
		eps_decay_period = 25e4,
		eval_eps=0.001,
	):
	
		self.device = device

		# Determine network type
		self.Q = Conv_Q(state_dim[0], num_actions).to(self.device) if is_atari else FC_Q(state_dim, num_actions).to(self.device)
		self.Q_target = copy.deepcopy(self.Q)
		self.Q_optimizer = getattr(torch.optim, optimizer)(self.Q.parameters(), **optimizer_parameters)

		self.discount = discount

		# Target update rule
		self.maybe_update_target = self.polyak_target_update if polyak_target_update else self.copy_target_update
		self.target_update_frequency = target_update_frequency
		self.tau = tau

		# Decay for eps
		self.initial_eps = initial_eps
		self.end_eps = end_eps
		self.slope = (self.end_eps - self.initial_eps) / eps_decay_period

		# Evaluation hyper-parameters
		self.state_shape = (-1,) + state_dim if is_atari else (-1, state_dim)
		self.eval_eps = eval_eps
		self.num_actions = num_actions

		# Number of training iterations
		self.iterations = 0


	def select_action(self, state, eval=False):
		eps = self.eval_eps if eval \
			else max(self.slope * self.iterations + self.initial_eps, self.end_eps)

		# Select action according to policy with probability (1-eps)
		# otherwise, select random action
		if np.random.uniform(0,1) > eps:
			with torch.no_grad():
				state = torch.FloatTensor(state).reshape(self.state_shape).to(self.device)
				return int(self.Q(state).argmax(1))
		else:
			return np.random.randint(self.num_actions)


	def train(self, replay_buffer):
		# Sample replay buffer (the last element returned is not_done, i.e. 1 - done)
		state, action, next_state, reward, done = replay_buffer.sample()

		# Compute the target Q value
		with torch.no_grad():
			target_Q = reward + done * self.discount * self.Q_target(next_state).max(1, keepdim=True)[0]

		# Get current Q estimate
		current_Q = self.Q(state).gather(1, action)

		# Compute Q loss
		Q_loss = F.smooth_l1_loss(current_Q, target_Q)

		# Optimize the Q
		self.Q_optimizer.zero_grad()
		Q_loss.backward()
		self.Q_optimizer.step()

		# Update target network by polyak or full copy every X iterations.
		self.iterations += 1
		self.maybe_update_target()


	def polyak_target_update(self):
		for param, target_param in zip(self.Q.parameters(), self.Q_target.parameters()):
		   target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)


	def copy_target_update(self):
		if self.iterations % self.target_update_frequency == 0:
			 self.Q_target.load_state_dict(self.Q.state_dict())


	def save(self, filename):
		torch.save(self.Q.state_dict(), filename + "_Q")
		torch.save(self.Q_optimizer.state_dict(), filename + "_optimizer")


	def load(self, filename):
		self.Q.load_state_dict(torch.load(filename + "_Q"))
		self.Q_target = copy.deepcopy(self.Q)
		self.Q_optimizer.load_state_dict(torch.load(filename + "_optimizer"))
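
Restating the exploration schedule and bootstrap target implemented above (gamma is `discount`, d the stored done flag, so (1 - d) is the not_done term sampled from the buffer):

$$\epsilon(t) = \max\!\Big(\epsilon_0 + \tfrac{\epsilon_{end}-\epsilon_0}{T_{decay}}\,t,\ \epsilon_{end}\Big), \qquad y = r + \gamma\,(1-d)\,\max_{a'} Q_{target}(s', a')$$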
# Runs policy for X episodes and returns average reward
# A fixed seed is used for the eval environment
def eval_policy(policy, env_name, seed, eval_episodes=10):
	eval_env, _, _, _ = make_env(env_name, args.atari_preprocessing)
	eval_env.seed(seed + 100)

	avg_reward = 0.
	for _ in range(eval_episodes):
		state, done = eval_env.reset(), False
		while not done:
			action = policy.select_action(np.array(state), eval=True)
			state, reward, done, _ = eval_env.step(action)
			avg_reward += reward

	avg_reward /= eval_episodes

	print("---------------------------------------")
	print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}")
	print("---------------------------------------")
	return avg_reward
# For saving files
setting = f"{args.env}_{args.seed}"
buffer_name = f"{args.buffer_name}_{setting}"

# Initialize and load policy
policy = DQN(
    is_atari,
    num_actions,
    state_dim,
    device,
    parameters["discount"],
    parameters["optimizer"],
    parameters["optimizer_parameters"],
    parameters["polyak_target_update"],
    parameters["target_update_freq"],
    parameters["tau"],
    parameters["initial_eps"],
    parameters["end_eps"],
    parameters["eps_decay_period"],
    parameters["eval_eps"],
)

evaluations = []

state, done = env.reset(), False
episode_start = True
episode_reward = 0
episode_timesteps = 0
episode_num = 0
low_noise_ep = np.random.uniform(0,1) < args.low_noise_p
max_episode_steps = gym.make(args.env)._max_episode_steps

# Interact with the environment for max_timesteps
for t in range(int(args.max_timesteps)):

    episode_timesteps += 1

    if t < parameters["start_timesteps"]:
        action = env.action_space.sample()
    else:
        action = policy.select_action(np.array(state))

    # Perform action and log results
    next_state, reward, done, info = env.step(action)
    episode_reward += reward

    # Only consider "done" if episode terminates due to failure condition
    done_float = float(done) if episode_timesteps < max_episode_steps else 0

    # For atari, info[0] = clipped reward, info[1] = done_float
    if is_atari:
        reward = info[0]
        done_float = info[1]
        
    # Store data in replay buffer
    replay_buffer.add(state, action, next_state, reward, done_float, done, episode_start)
    state = copy.copy(next_state)
    episode_start = False

    # Train agent after collecting sufficient data
    if t >= parameters["start_timesteps"] and (t+1) % parameters["train_freq"] == 0:
        policy.train(replay_buffer)

    if done:
        # +1 to account for 0 indexing. +0 on ep_timesteps since it will increment +1 even if done=True
        print(f"Total T: {t+1} Episode Num: {episode_num+1} Episode T: {episode_timesteps} Reward: {episode_reward:.3f}")
        # Reset environment
        state, done = env.reset(), False
        episode_start = True
        episode_reward = 0
        episode_timesteps = 0
        episode_num += 1
        low_noise_ep = np.random.uniform(0,1) < args.low_noise_p

    # Evaluate episode
    if (t + 1) % parameters["eval_freq"] == 0:
        evaluations.append(eval_policy(policy, args.env, args.seed))
        np.save(f"./results/behavioral_{setting}", evaluations)
        policy.save(f"./models/behavioral_{setting}")

# Save final policy
policy.save(f"./models/behavioral_{setting}")
Total T: 52 Episode Num: 1 Episode T: 52 Reward: 52.000
Total T: 107 Episode Num: 2 Episode T: 55 Reward: 55.000
Total T: 136 Episode Num: 3 Episode T: 29 Reward: 29.000
Total T: 156 Episode Num: 4 Episode T: 20 Reward: 20.000
Total T: 168 Episode Num: 5 Episode T: 12 Reward: 12.000
Total T: 187 Episode Num: 6 Episode T: 19 Reward: 19.000
Total T: 207 Episode Num: 7 Episode T: 20 Reward: 20.000
Total T: 220 Episode Num: 8 Episode T: 13 Reward: 13.000
Total T: 249 Episode Num: 9 Episode T: 29 Reward: 29.000
Total T: 268 Episode Num: 10 Episode T: 19 Reward: 19.000
Total T: 286 Episode Num: 11 Episode T: 18 Reward: 18.000
Total T: 306 Episode Num: 12 Episode T: 20 Reward: 20.000
Total T: 321 Episode Num: 13 Episode T: 15 Reward: 15.000
Total T: 368 Episode Num: 14 Episode T: 47 Reward: 47.000
Total T: 405 Episode Num: 15 Episode T: 37 Reward: 37.000
Total T: 429 Episode Num: 16 Episode T: 24 Reward: 24.000
Total T: 461 Episode Num: 17 Episode T: 32 Reward: 32.000
Total T: 475 Episode Num: 18 Episode T: 14 Reward: 14.000
Total T: 500 Episode Num: 19 Episode T: 25 Reward: 25.000
Total T: 519 Episode Num: 20 Episode T: 19 Reward: 19.000
Total T: 528 Episode Num: 21 Episode T: 9 Reward: 9.000
Total T: 546 Episode Num: 22 Episode T: 18 Reward: 18.000
Total T: 566 Episode Num: 23 Episode T: 20 Reward: 20.000
Total T: 596 Episode Num: 24 Episode T: 30 Reward: 30.000
Total T: 613 Episode Num: 25 Episode T: 17 Reward: 17.000
Total T: 627 Episode Num: 26 Episode T: 14 Reward: 14.000
Total T: 642 Episode Num: 27 Episode T: 15 Reward: 15.000
Total T: 670 Episode Num: 28 Episode T: 28 Reward: 28.000
Total T: 682 Episode Num: 29 Episode T: 12 Reward: 12.000
Total T: 693 Episode Num: 30 Episode T: 11 Reward: 11.000
Total T: 707 Episode Num: 31 Episode T: 14 Reward: 14.000
Total T: 727 Episode Num: 32 Episode T: 20 Reward: 20.000
Total T: 754 Episode Num: 33 Episode T: 27 Reward: 27.000
Total T: 784 Episode Num: 34 Episode T: 30 Reward: 30.000
Total T: 819 Episode Num: 35 Episode T: 35 Reward: 35.000
Total T: 849 Episode Num: 36 Episode T: 30 Reward: 30.000
Total T: 870 Episode Num: 37 Episode T: 21 Reward: 21.000
Total T: 885 Episode Num: 38 Episode T: 15 Reward: 15.000
Total T: 910 Episode Num: 39 Episode T: 25 Reward: 25.000
Total T: 925 Episode Num: 40 Episode T: 15 Reward: 15.000
Total T: 936 Episode Num: 41 Episode T: 11 Reward: 11.000
Total T: 949 Episode Num: 42 Episode T: 13 Reward: 13.000
Total T: 966 Episode Num: 43 Episode T: 17 Reward: 17.000
Total T: 987 Episode Num: 44 Episode T: 21 Reward: 21.000
Total T: 1001 Episode Num: 45 Episode T: 14 Reward: 14.000
Total T: 1011 Episode Num: 46 Episode T: 10 Reward: 10.000
Total T: 1021 Episode Num: 47 Episode T: 10 Reward: 10.000
Total T: 1029 Episode Num: 48 Episode T: 8 Reward: 8.000
Total T: 1056 Episode Num: 49 Episode T: 27 Reward: 27.000
Total T: 1065 Episode Num: 50 Episode T: 9 Reward: 9.000
Total T: 1075 Episode Num: 51 Episode T: 10 Reward: 10.000
Total T: 1085 Episode Num: 52 Episode T: 10 Reward: 10.000
Total T: 1096 Episode Num: 53 Episode T: 11 Reward: 11.000
Total T: 1104 Episode Num: 54 Episode T: 8 Reward: 8.000
Total T: 1112 Episode Num: 55 Episode T: 8 Reward: 8.000
Total T: 1121 Episode Num: 56 Episode T: 9 Reward: 9.000
Total T: 1131 Episode Num: 57 Episode T: 10 Reward: 10.000
Total T: 1142 Episode Num: 58 Episode T: 11 Reward: 11.000
Total T: 1154 Episode Num: 59 Episode T: 12 Reward: 12.000
Total T: 1166 Episode Num: 60 Episode T: 12 Reward: 12.000
Total T: 1176 Episode Num: 61 Episode T: 10 Reward: 10.000
Total T: 1185 Episode Num: 62 Episode T: 9 Reward: 9.000
Total T: 1195 Episode Num: 63 Episode T: 10 Reward: 10.000
Total T: 1205 Episode Num: 64 Episode T: 10 Reward: 10.000
Total T: 1216 Episode Num: 65 Episode T: 11 Reward: 11.000
Total T: 1225 Episode Num: 66 Episode T: 9 Reward: 9.000
Total T: 1234 Episode Num: 67 Episode T: 9 Reward: 9.000
Total T: 1244 Episode Num: 68 Episode T: 10 Reward: 10.000
Total T: 1252 Episode Num: 69 Episode T: 8 Reward: 8.000
Total T: 1262 Episode Num: 70 Episode T: 10 Reward: 10.000
Total T: 1272 Episode Num: 71 Episode T: 10 Reward: 10.000
Total T: 1282 Episode Num: 72 Episode T: 10 Reward: 10.000
Total T: 1291 Episode Num: 73 Episode T: 9 Reward: 9.000
Total T: 1300 Episode Num: 74 Episode T: 9 Reward: 9.000
Total T: 1310 Episode Num: 75 Episode T: 10 Reward: 10.000
Total T: 1320 Episode Num: 76 Episode T: 10 Reward: 10.000
Total T: 1330 Episode Num: 77 Episode T: 10 Reward: 10.000
Total T: 1343 Episode Num: 78 Episode T: 13 Reward: 13.000
Total T: 1352 Episode Num: 79 Episode T: 9 Reward: 9.000
Total T: 1361 Episode Num: 80 Episode T: 9 Reward: 9.000
Total T: 1371 Episode Num: 81 Episode T: 10 Reward: 10.000
Total T: 1380 Episode Num: 82 Episode T: 9 Reward: 9.000
Total T: 1390 Episode Num: 83 Episode T: 10 Reward: 10.000
Total T: 1403 Episode Num: 84 Episode T: 13 Reward: 13.000
Total T: 1416 Episode Num: 85 Episode T: 13 Reward: 13.000
Total T: 1426 Episode Num: 86 Episode T: 10 Reward: 10.000
Total T: 1438 Episode Num: 87 Episode T: 12 Reward: 12.000
Total T: 1449 Episode Num: 88 Episode T: 11 Reward: 11.000
Total T: 1457 Episode Num: 89 Episode T: 8 Reward: 8.000
Total T: 1468 Episode Num: 90 Episode T: 11 Reward: 11.000
Total T: 1478 Episode Num: 91 Episode T: 10 Reward: 10.000
Total T: 1498 Episode Num: 92 Episode T: 20 Reward: 20.000
Total T: 1508 Episode Num: 93 Episode T: 10 Reward: 10.000
Total T: 1519 Episode Num: 94 Episode T: 11 Reward: 11.000
Total T: 1528 Episode Num: 95 Episode T: 9 Reward: 9.000
Total T: 1549 Episode Num: 96 Episode T: 21 Reward: 21.000
Total T: 1560 Episode Num: 97 Episode T: 11 Reward: 11.000
Total T: 1569 Episode Num: 98 Episode T: 9 Reward: 9.000
Total T: 1590 Episode Num: 99 Episode T: 21 Reward: 21.000
Total T: 1642 Episode Num: 100 Episode T: 52 Reward: 52.000
Total T: 1692 Episode Num: 101 Episode T: 50 Reward: 50.000
Total T: 1828 Episode Num: 102 Episode T: 136 Reward: 136.000
Total T: 1977 Episode Num: 103 Episode T: 149 Reward: 149.000
Total T: 2112 Episode Num: 104 Episode T: 135 Reward: 135.000
Total T: 2259 Episode Num: 105 Episode T: 147 Reward: 147.000
Total T: 2400 Episode Num: 106 Episode T: 141 Reward: 141.000
Total T: 2588 Episode Num: 107 Episode T: 188 Reward: 188.000
Total T: 2761 Episode Num: 108 Episode T: 173 Reward: 173.000
Total T: 2961 Episode Num: 109 Episode T: 200 Reward: 200.000
Total T: 3112 Episode Num: 110 Episode T: 151 Reward: 151.000
Total T: 3305 Episode Num: 111 Episode T: 193 Reward: 193.000
Total T: 3505 Episode Num: 112 Episode T: 200 Reward: 200.000
Total T: 3674 Episode Num: 113 Episode T: 169 Reward: 169.000
Total T: 3837 Episode Num: 114 Episode T: 163 Reward: 163.000
Total T: 4005 Episode Num: 115 Episode T: 168 Reward: 168.000
Total T: 4161 Episode Num: 116 Episode T: 156 Reward: 156.000
Total T: 4361 Episode Num: 117 Episode T: 200 Reward: 200.000
Total T: 4561 Episode Num: 118 Episode T: 200 Reward: 200.000
Total T: 4735 Episode Num: 119 Episode T: 174 Reward: 174.000
Total T: 4919 Episode Num: 120 Episode T: 184 Reward: 184.000
---------------------------------------
Evaluation over 10 episodes: 195.400
---------------------------------------
Total T: 5119 Episode Num: 121 Episode T: 200 Reward: 200.000
Total T: 5292 Episode Num: 122 Episode T: 173 Reward: 173.000
Total T: 5454 Episode Num: 123 Episode T: 162 Reward: 162.000
Total T: 5606 Episode Num: 124 Episode T: 152 Reward: 152.000
Total T: 5806 Episode Num: 125 Episode T: 200 Reward: 200.000
Total T: 5980 Episode Num: 126 Episode T: 174 Reward: 174.000
Total T: 6155 Episode Num: 127 Episode T: 175 Reward: 175.000
Total T: 6351 Episode Num: 128 Episode T: 196 Reward: 196.000
Total T: 6537 Episode Num: 129 Episode T: 186 Reward: 186.000
Total T: 6707 Episode Num: 130 Episode T: 170 Reward: 170.000
Total T: 6849 Episode Num: 131 Episode T: 142 Reward: 142.000
Total T: 7014 Episode Num: 132 Episode T: 165 Reward: 165.000
Total T: 7189 Episode Num: 133 Episode T: 175 Reward: 175.000
Total T: 7382 Episode Num: 134 Episode T: 193 Reward: 193.000
Total T: 7554 Episode Num: 135 Episode T: 172 Reward: 172.000
Total T: 7754 Episode Num: 136 Episode T: 200 Reward: 200.000
Total T: 7922 Episode Num: 137 Episode T: 168 Reward: 168.000
Total T: 8092 Episode Num: 138 Episode T: 170 Reward: 170.000
Total T: 8292 Episode Num: 139 Episode T: 200 Reward: 200.000
Total T: 8459 Episode Num: 140 Episode T: 167 Reward: 167.000
Total T: 8650 Episode Num: 141 Episode T: 191 Reward: 191.000
Total T: 8724 Episode Num: 142 Episode T: 74 Reward: 74.000
Total T: 8924 Episode Num: 143 Episode T: 200 Reward: 200.000
Total T: 9074 Episode Num: 144 Episode T: 150 Reward: 150.000
Total T: 9264 Episode Num: 145 Episode T: 190 Reward: 190.000
Total T: 9430 Episode Num: 146 Episode T: 166 Reward: 166.000
Total T: 9630 Episode Num: 147 Episode T: 200 Reward: 200.000
Total T: 9820 Episode Num: 148 Episode T: 190 Reward: 190.000
---------------------------------------
Evaluation over 10 episodes: 196.700
---------------------------------------

Generate Buffer

# Make env and determine properties
env, is_atari, state_dim, num_actions = make_env(args.env, args.atari_preprocessing)
parameters = args.atari_parameters if is_atari else args.regular_parameters


# Set seeds
env.seed(args.seed)
env.action_space.seed(args.seed)


# Initialize buffer
replay_buffer = ReplayBuffer(state_dim, is_atari, args.atari_preprocessing, parameters["batch_size"], parameters["buffer_size"], device)
setting = f"{args.env}_{args.seed}"
buffer_name = f"{args.buffer_name}_{setting}"

# Initialize and load policy
policy = DQN(
    is_atari,
    num_actions,
    state_dim,
    device,
    parameters["discount"],
    parameters["optimizer"],
    parameters["optimizer_parameters"],
    parameters["polyak_target_update"],
    parameters["target_update_freq"],
    parameters["tau"],
    parameters["initial_eps"],
    parameters["end_eps"],
    parameters["eps_decay_period"],
    parameters["eval_eps"],
)

policy.load(f"./models/behavioral_{setting}")

evaluations = []

state, done = env.reset(), False
episode_start = True
episode_reward = 0
episode_timesteps = 0
episode_num = 0
low_noise_ep = np.random.uniform(0,1) < args.low_noise_p
max_episode_steps = gym.make(args.env)._max_episode_steps

# Interact with the environment for max_timesteps
for t in range(int(args.max_timesteps)):

    episode_timesteps += 1

    # If generating the buffer, episode is low noise with p=low_noise_p.
    # If policy is low noise, we take random actions with p=eval_eps.
    # If the policy is high noise, we take random actions with p=rand_action_p.
    if not low_noise_ep and np.random.uniform(0,1) < args.rand_action_p - parameters["eval_eps"]:
        action = env.action_space.sample()
    else:
        action = policy.select_action(np.array(state), eval=True)

    # Perform action and log results
    next_state, reward, done, info = env.step(action)
    episode_reward += reward

    # Only consider "done" if episode terminates due to failure condition
    done_float = float(done) if episode_timesteps < max_episode_steps else 0

    # For atari, info[0] = clipped reward, info[1] = done_float
    if is_atari:
        reward = info[0]
        done_float = info[1]
        
    # Store data in replay buffer
    replay_buffer.add(state, action, next_state, reward, done_float, done, episode_start)
    state = copy.copy(next_state)
    episode_start = False

    if done:
        # +1 to account for 0 indexing. +0 on ep_timesteps since it will increment +1 even if done=True
        print(f"Total T: {t+1} Episode Num: {episode_num+1} Episode T: {episode_timesteps} Reward: {episode_reward:.3f}")
        # Reset environment
        state, done = env.reset(), False
        episode_start = True
        episode_reward = 0
        episode_timesteps = 0
        episode_num += 1
        low_noise_ep = np.random.uniform(0,1) < args.low_noise_p

# Save final buffer and performance
evaluations.append(eval_policy(policy, args.env, args.seed))
np.save(f"./results/buffer_performance_{setting}", evaluations)
replay_buffer.save(f"./buffers/{buffer_name}")
Total T: 200 Episode Num: 1 Episode T: 200 Reward: 200.000
Total T: 400 Episode Num: 2 Episode T: 200 Reward: 200.000
Total T: 600 Episode Num: 3 Episode T: 200 Reward: 200.000
Total T: 721 Episode Num: 4 Episode T: 121 Reward: 121.000
Total T: 878 Episode Num: 5 Episode T: 157 Reward: 157.000
Total T: 1078 Episode Num: 6 Episode T: 200 Reward: 200.000
Total T: 1254 Episode Num: 7 Episode T: 176 Reward: 176.000
Total T: 1423 Episode Num: 8 Episode T: 169 Reward: 169.000
Total T: 1623 Episode Num: 9 Episode T: 200 Reward: 200.000
Total T: 1817 Episode Num: 10 Episode T: 194 Reward: 194.000
Total T: 2017 Episode Num: 11 Episode T: 200 Reward: 200.000
Total T: 2212 Episode Num: 12 Episode T: 195 Reward: 195.000
Total T: 2412 Episode Num: 13 Episode T: 200 Reward: 200.000
Total T: 2445 Episode Num: 14 Episode T: 33 Reward: 33.000
Total T: 2616 Episode Num: 15 Episode T: 171 Reward: 171.000
Total T: 2655 Episode Num: 16 Episode T: 39 Reward: 39.000
Total T: 2855 Episode Num: 17 Episode T: 200 Reward: 200.000
Total T: 2934 Episode Num: 18 Episode T: 79 Reward: 79.000
Total T: 3030 Episode Num: 19 Episode T: 96 Reward: 96.000
Total T: 3228 Episode Num: 20 Episode T: 198 Reward: 198.000
Total T: 3428 Episode Num: 21 Episode T: 200 Reward: 200.000
Total T: 3450 Episode Num: 22 Episode T: 22 Reward: 22.000
Total T: 3642 Episode Num: 23 Episode T: 192 Reward: 192.000
Total T: 3842 Episode Num: 24 Episode T: 200 Reward: 200.000
Total T: 3855 Episode Num: 25 Episode T: 13 Reward: 13.000
Total T: 4055 Episode Num: 26 Episode T: 200 Reward: 200.000
Total T: 4251 Episode Num: 27 Episode T: 196 Reward: 196.000
Total T: 4420 Episode Num: 28 Episode T: 169 Reward: 169.000
Total T: 4584 Episode Num: 29 Episode T: 164 Reward: 164.000
Total T: 4784 Episode Num: 30 Episode T: 200 Reward: 200.000
Total T: 4975 Episode Num: 31 Episode T: 191 Reward: 191.000
Total T: 5136 Episode Num: 32 Episode T: 161 Reward: 161.000
Total T: 5147 Episode Num: 33 Episode T: 11 Reward: 11.000
Total T: 5347 Episode Num: 34 Episode T: 200 Reward: 200.000
Total T: 5541 Episode Num: 35 Episode T: 194 Reward: 194.000
Total T: 5741 Episode Num: 36 Episode T: 200 Reward: 200.000
Total T: 5926 Episode Num: 37 Episode T: 185 Reward: 185.000
Total T: 6113 Episode Num: 38 Episode T: 187 Reward: 187.000
Total T: 6311 Episode Num: 39 Episode T: 198 Reward: 198.000
Total T: 6504 Episode Num: 40 Episode T: 193 Reward: 193.000
Total T: 6704 Episode Num: 41 Episode T: 200 Reward: 200.000
Total T: 6904 Episode Num: 42 Episode T: 200 Reward: 200.000
Total T: 7104 Episode Num: 43 Episode T: 200 Reward: 200.000
Total T: 7304 Episode Num: 44 Episode T: 200 Reward: 200.000
Total T: 7504 Episode Num: 45 Episode T: 200 Reward: 200.000
Total T: 7703 Episode Num: 46 Episode T: 199 Reward: 199.000
Total T: 7903 Episode Num: 47 Episode T: 200 Reward: 200.000
Total T: 8090 Episode Num: 48 Episode T: 187 Reward: 187.000
Total T: 8290 Episode Num: 49 Episode T: 200 Reward: 200.000
Total T: 8461 Episode Num: 50 Episode T: 171 Reward: 171.000
Total T: 8661 Episode Num: 51 Episode T: 200 Reward: 200.000
Total T: 8861 Episode Num: 52 Episode T: 200 Reward: 200.000
Total T: 9061 Episode Num: 53 Episode T: 200 Reward: 200.000
Total T: 9261 Episode Num: 54 Episode T: 200 Reward: 200.000
Total T: 9410 Episode Num: 55 Episode T: 149 Reward: 149.000
Total T: 9420 Episode Num: 56 Episode T: 10 Reward: 10.000
Total T: 9613 Episode Num: 57 Episode T: 193 Reward: 193.000
Total T: 9813 Episode Num: 58 Episode T: 200 Reward: 200.000
Total T: 9981 Episode Num: 59 Episode T: 168 Reward: 168.000
---------------------------------------
Evaluation over 10 episodes: 196.700
---------------------------------------

Discrete BCQ

# Used for Atari
class Conv_Q(nn.Module):
	def __init__(self, frames, num_actions):
		super(Conv_Q, self).__init__()
		self.c1 = nn.Conv2d(frames, 32, kernel_size=8, stride=4)
		self.c2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
		self.c3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

		self.q1 = nn.Linear(3136, 512)
		self.q2 = nn.Linear(512, num_actions)

		self.i1 = nn.Linear(3136, 512)
		self.i2 = nn.Linear(512, num_actions)


	def forward(self, state):
		c = F.relu(self.c1(state))
		c = F.relu(self.c2(c))
		c = F.relu(self.c3(c))

		q = F.relu(self.q1(c.reshape(-1, 3136)))
		i = F.relu(self.i1(c.reshape(-1, 3136)))
		i = self.i2(i)
		return self.q2(q), F.log_softmax(i, dim=1), i
# Used for Box2D / Toy problems
class FC_Q(nn.Module):
	def __init__(self, state_dim, num_actions):
		super(FC_Q, self).__init__()
		self.q1 = nn.Linear(state_dim, 256)
		self.q2 = nn.Linear(256, 256)
		self.q3 = nn.Linear(256, num_actions)

		self.i1 = nn.Linear(state_dim, 256)
		self.i2 = nn.Linear(256, 256)
		self.i3 = nn.Linear(256, num_actions)		


	def forward(self, state):
		q = F.relu(self.q1(state))
		q = F.relu(self.q2(q))

		i = F.relu(self.i1(state))
		i = F.relu(self.i2(i))
		i = self.i3(i)
		return self.q3(q), F.log_softmax(i, dim=1), i
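
A quick shape check for the two-headed network (a throwaway sketch with a hypothetical batch of 5 CartPole states; the second head is the imitation/behaviour-cloning head whose log-softmax feeds the threshold mask used below):

toy_net = FC_Q(state_dim=4, num_actions=2).to(device)
q, log_imt, i = toy_net(torch.randn(5, 4).to(device))
print(q.shape, log_imt.shape, i.shape)  # each expected: torch.Size([5, 2])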
class discrete_BCQ(object):
	def __init__(
		self, 
		is_atari,
		num_actions,
		state_dim,
		device,
		BCQ_threshold=0.3,
		discount=0.99,
		optimizer="Adam",
		optimizer_parameters={},
		polyak_target_update=False,
		target_update_frequency=8e3,
		tau=0.005,
		initial_eps = 1,
		end_eps = 0.001,
		eps_decay_period = 25e4,
		eval_eps=0.001,
	):
	
		self.device = device

		# Determine network type
		self.Q = Conv_Q(state_dim[0], num_actions).to(self.device) if is_atari else FC_Q(state_dim, num_actions).to(self.device)
		self.Q_target = copy.deepcopy(self.Q)
		self.Q_optimizer = getattr(torch.optim, optimizer)(self.Q.parameters(), **optimizer_parameters)

		self.discount = discount

		# Target update rule
		self.maybe_update_target = self.polyak_target_update if polyak_target_update else self.copy_target_update
		self.target_update_frequency = target_update_frequency
		self.tau = tau

		# Decay for eps
		self.initial_eps = initial_eps
		self.end_eps = end_eps
		self.slope = (self.end_eps - self.initial_eps) / eps_decay_period

		# Evaluation hyper-parameters
		self.state_shape = (-1,) + state_dim if is_atari else (-1, state_dim)
		self.eval_eps = eval_eps
		self.num_actions = num_actions

		# Threshold for "unlikely" actions
		self.threshold = BCQ_threshold

		# Number of training iterations
		self.iterations = 0


	def select_action(self, state, eval=False):
		# Select action according to policy with probability (1-eps)
		# otherwise, select random action
		if np.random.uniform(0,1) > self.eval_eps:
			with torch.no_grad():
				state = torch.FloatTensor(state).reshape(self.state_shape).to(self.device)
				q, imt, i = self.Q(state)
				imt = imt.exp()
				imt = (imt/imt.max(1, keepdim=True)[0] > self.threshold).float()
				# Use large negative number to mask actions from argmax
				return int((imt * q + (1. - imt) * -1e8).argmax(1))
		else:
			return np.random.randint(self.num_actions)


	def train(self, replay_buffer):
		# Sample replay buffer (the last element returned is not_done, i.e. 1 - done)
		state, action, next_state, reward, done = replay_buffer.sample()

		# Compute the target Q value
		with torch.no_grad():
			q, imt, i = self.Q(next_state)
			imt = imt.exp()
			imt = (imt/imt.max(1, keepdim=True)[0] > self.threshold).float()

			# Use large negative number to mask actions from argmax
			next_action = (imt * q + (1 - imt) * -1e8).argmax(1, keepdim=True)

			q, imt, i = self.Q_target(next_state)
			target_Q = reward + done * self.discount * q.gather(1, next_action).reshape(-1, 1)

		# Get current Q estimate
		current_Q, imt, i = self.Q(state)
		current_Q = current_Q.gather(1, action)

		# Compute Q loss
		q_loss = F.smooth_l1_loss(current_Q, target_Q)
		i_loss = F.nll_loss(imt, action.reshape(-1))

		Q_loss = q_loss + i_loss + 1e-2 * i.pow(2).mean()

		# Optimize the Q
		self.Q_optimizer.zero_grad()
		Q_loss.backward()
		self.Q_optimizer.step()

		# Update target network by polyak or full copy every X iterations.
		self.iterations += 1
		self.maybe_update_target()


	def polyak_target_update(self):
		for param, target_param in zip(self.Q.parameters(), self.Q_target.parameters()):
		   target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)


	def copy_target_update(self):
		if self.iterations % self.target_update_frequency == 0:
			 self.Q_target.load_state_dict(self.Q.state_dict())
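
To see what BCQ_threshold does inside select_action, here is a minimal sketch with made-up numbers: the action with the largest Q-value is rejected because the imitation model considers it unlikely under the batch.

imt = torch.tensor([[0.80, 0.15, 0.05]])   # hypothetical imitation probabilities G(a|s)
q = torch.tensor([[1.0, 5.0, 9.0]])        # Q-values; highest value on the least likely action
threshold = 0.3

# Actions below threshold * max probability are masked out of the argmax
mask = (imt / imt.max(1, keepdim=True)[0] > threshold).float()   # tensor([[1., 0., 0.]])
print(int((mask * q + (1. - mask) * -1e8).argmax(1)))            # 0, not 2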
# Make env and determine properties
env, is_atari, state_dim, num_actions = make_env(args.env, args.atari_preprocessing)
parameters = args.atari_parameters if is_atari else args.regular_parameters


# Set seeds
env.seed(args.seed)
env.action_space.seed(args.seed)


# Initialize buffer
replay_buffer = ReplayBuffer(state_dim, is_atari, args.atari_preprocessing, parameters["batch_size"], parameters["buffer_size"], device)
# For saving files
setting = f"{args.env}_{args.seed}"
buffer_name = f"{args.buffer_name}_{setting}"

# Initialize and load policy
policy = discrete_BCQ(
    is_atari,
    num_actions,
    state_dim,
    device,
    args.BCQ_threshold,
    parameters["discount"],
    parameters["optimizer"],
    parameters["optimizer_parameters"],
    parameters["polyak_target_update"],
    parameters["target_update_freq"],
    parameters["tau"],
    parameters["initial_eps"],
    parameters["end_eps"],
    parameters["eps_decay_period"],
    parameters["eval_eps"]
)

# Load replay buffer	
replay_buffer.load(f"./buffers/{buffer_name}")

evaluations = []
episode_num = 0
done = True 
training_iters = 0

while training_iters < args.max_timesteps: 
    
    for _ in range(int(parameters["eval_freq"])):
        policy.train(replay_buffer)

    evaluations.append(eval_policy(policy, args.env, args.seed))
    np.save(f"./results/BCQ_{setting}", evaluations)

    training_iters += int(parameters["eval_freq"])
    print(f"Training iterations: {training_iters}")
Replay Buffer loaded with 10000 elements.
---------------------------------------
Evaluation over 10 episodes: 175.300
---------------------------------------
Training iterations: 5000
---------------------------------------
Evaluation over 10 episodes: 198.500
---------------------------------------
Training iterations: 10000
!apt-get -qq install tree
!tree --du -h -C .
.
├── [864K]  buffers
│   ├── [ 78K]  Default_CartPole-v0_0_action.npy
│   ├── [313K]  Default_CartPole-v0_0_next_state.npy
│   ├── [ 78K]  Default_CartPole-v0_0_not_done.npy
│   ├── [ 136]  Default_CartPole-v0_0_ptr.npy
│   ├── [ 78K]  Default_CartPole-v0_0_reward.npy
│   └── [313K]  Default_CartPole-v0_0_state.npy
├── [801K]  models
│   ├── [531K]  behavioral_CartPole-v0_0_optimizer
│   └── [266K]  behavioral_CartPole-v0_0_Q
├── [4.4K]  results
│   ├── [ 144]  BCQ_CartPole-v0_0.npy
│   ├── [ 144]  behavioral_CartPole-v0_0.npy
│   └── [ 136]  buffer_performance_CartPole-v0_0.npy
├── [ 19M]  ROM
│   ├── [ 11M]  HC ROMS.zip
│   └── [7.8M]  ROMS.zip
├── [ 11M]  Roms.rar
├── [ 54M]  sample_data
│   ├── [1.7K]  anscombe.json
│   ├── [294K]  california_housing_test.csv
│   ├── [1.6M]  california_housing_train.csv
│   ├── [ 17M]  mnist_test.csv
│   ├── [ 35M]  mnist_train_small.csv
│   └── [ 930]  README.md
└── [ 65K]  video
    ├── [ 491]  openaigym.episode_batch.8.2302.stats.json
    ├── [ 406]  openaigym.manifest.8.2302.manifest.json
    ├── [2.0K]  openaigym.video.8.2302.video000000.meta.json
    ├── [ 18K]  openaigym.video.8.2302.video000000.mp4
    ├── [2.0K]  openaigym.video.8.2302.video000001.meta.json
    ├── [ 18K]  openaigym.video.8.2302.video000001.mp4
    ├── [2.0K]  openaigym.video.8.2302.video000008.meta.json
    └── [ 19K]  openaigym.video.8.2302.video000008.mp4

  86M used in 6 directories, 28 files
show_video()