Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated to latest tensorflow version and added requirements file #90

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 85 additions & 142 deletions deep_q_network.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
from __future__ import print_function

import os
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.layers import Lambda
import cv2
import sys
sys.path.append("game/")
Expand All @@ -10,180 +10,132 @@
import numpy as np
from collections import deque

GAME = 'bird' # the name of the game being played for log files
ACTIONS = 2 # number of valid actions
GAMMA = 0.99 # decay rate of past observations
OBSERVE = 100000. # timesteps to observe before training
EXPLORE = 2000000. # frames over which to anneal epsilon
FINAL_EPSILON = 0.0001 # final value of epsilon
INITIAL_EPSILON = 0.0001 # starting value of epsilon
REPLAY_MEMORY = 50000 # number of previous transitions to remember
BATCH = 32 # size of minibatch
GAME = 'bird' # the name of the game being played for log files
ACTIONS = 2 # number of valid actions
GAMMA = 0.99 # decay rate of past observations
OBSERVE = 100000 # timesteps to observe before training
EXPLORE = 2000000 # frames over which to anneal epsilon
FINAL_EPSILON = 0.0001 # final value of epsilon
INITIAL_EPSILON = 0.1 # starting value of epsilon
REPLAY_MEMORY = 50000 # number of previous transitions to remember
BATCH = 32 # size of minibatch
FRAME_PER_ACTION = 1

def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev = 0.01)
return tf.Variable(initial)

def bias_variable(shape):
initial = tf.constant(0.01, shape = shape)
return tf.Variable(initial)

def conv2d(x, W, stride):
return tf.nn.conv2d(x, W, strides = [1, stride, stride, 1], padding = "SAME")

def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = "SAME")

def createNetwork():
# network weights
W_conv1 = weight_variable([8, 8, 4, 32])
b_conv1 = bias_variable([32])

W_conv2 = weight_variable([4, 4, 32, 64])
b_conv2 = bias_variable([64])

W_conv3 = weight_variable([3, 3, 64, 64])
b_conv3 = bias_variable([64])

W_fc1 = weight_variable([1600, 512])
b_fc1 = bias_variable([512])

W_fc2 = weight_variable([512, ACTIONS])
b_fc2 = bias_variable([ACTIONS])

# input layer
s = tf.placeholder("float", [None, 80, 80, 4])

# hidden layers
h_conv1 = tf.nn.relu(conv2d(s, W_conv1, 4) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2, 2) + b_conv2)
#h_pool2 = max_pool_2x2(h_conv2)

h_conv3 = tf.nn.relu(conv2d(h_conv2, W_conv3, 1) + b_conv3)
#h_pool3 = max_pool_2x2(h_conv3)

#h_pool3_flat = tf.reshape(h_pool3, [-1, 256])
h_conv3_flat = tf.reshape(h_conv3, [-1, 1600])

h_fc1 = tf.nn.relu(tf.matmul(h_conv3_flat, W_fc1) + b_fc1)

# readout layer
readout = tf.matmul(h_fc1, W_fc2) + b_fc2

return s, readout, h_fc1

def trainNetwork(s, readout, h_fc1, sess):
# define the cost function
a = tf.placeholder("float", [None, ACTIONS])
y = tf.placeholder("float", [None])
readout_action = tf.reduce_sum(tf.multiply(readout, a), reduction_indices=1)
cost = tf.reduce_mean(tf.square(y - readout_action))
train_step = tf.train.AdamOptimizer(1e-6).minimize(cost)

# open up a game state to communicate with emulator
def build_model(input_shape, action_space):
inputs = layers.Input(shape=input_shape)
layer = layers.Conv2D(32, (8, 8), strides=(4, 4), padding='same', activation='relu')(inputs)
layer = layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(layer)
layer = layers.Conv2D(64, (4, 4), strides=(2, 2), padding='same', activation='relu')(layer)
layer = layers.Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu')(layer)
layer = layers.Flatten()(layer)
layer = layers.Dense(512, activation='relu')(layer)

value_fc = layers.Dense(1)(layer)
advantage_fc = layers.Dense(action_space)(layer)

def dueling_dqn(value, advantage):
mean_advantage = Lambda(lambda x: tf.reduce_mean(x, axis=1, keepdims=True))(advantage)
return value + (advantage - mean_advantage)

policy = dueling_dqn(value_fc, advantage_fc)

model = models.Model(inputs=inputs, outputs=policy)
model.compile(optimizer=optimizers.Adam(learning_rate=1e-6), loss='mse')
return model

def preprocess(image):
image = cv2.cvtColor(cv2.resize(image, (80, 80)), cv2.COLOR_BGR2GRAY)
_, image = cv2.threshold(image, 1, 255, cv2.THRESH_BINARY)
return image

def trainNetwork(model, target_model):
game_state = game.GameState()

# store the previous observations in replay memory
D = deque()

# printing
a_file = open("logs_" + GAME + "/readout.txt", 'w')
h_file = open("logs_" + GAME + "/hidden.txt", 'w')

# get the first state by doing nothing and preprocess the image to 80x80x4
do_nothing = np.zeros(ACTIONS)
do_nothing[0] = 1
x_t, r_0, terminal = game_state.frame_step(do_nothing)
x_t = cv2.cvtColor(cv2.resize(x_t, (80, 80)), cv2.COLOR_BGR2GRAY)
ret, x_t = cv2.threshold(x_t,1,255,cv2.THRESH_BINARY)
x_t = preprocess(x_t)
s_t = np.stack((x_t, x_t, x_t, x_t), axis=2)

# saving and loading networks
saver = tf.train.Saver()
sess.run(tf.initialize_all_variables())
checkpoint = tf.train.get_checkpoint_state("saved_networks")
if checkpoint and checkpoint.model_checkpoint_path:
saver.restore(sess, checkpoint.model_checkpoint_path)
print("Successfully loaded:", checkpoint.model_checkpoint_path)
s_t = np.expand_dims(s_t, axis=0)

checkpoint_path = "saved_networks/" + GAME + "-dqn"
checkpoint_dir = os.path.dirname(checkpoint_path)
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)

checkpoint = tf.train.Checkpoint(model=model)
checkpoint_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
if checkpoint_manager.latest_checkpoint:
checkpoint.restore(checkpoint_manager.latest_checkpoint)
print("Successfully loaded:", checkpoint_manager.latest_checkpoint)
else:
print("Could not find old network weights")

# start training
epsilon = INITIAL_EPSILON
t = 0
while "flappy bird" != "angry bird":
# choose an action epsilon greedily
readout_t = readout.eval(feed_dict={s : [s_t]})[0]
while True:
readout_t = model.predict(s_t)[0]
a_t = np.zeros([ACTIONS])
action_index = 0
if t % FRAME_PER_ACTION == 0:
if random.random() <= epsilon:
print("----------Random Action----------")
action_index = random.randrange(ACTIONS)
a_t[random.randrange(ACTIONS)] = 1
a_t[action_index] = 1
else:
action_index = np.argmax(readout_t)
a_t[action_index] = 1
else:
a_t[0] = 1 # do nothing
a_t[0] = 1 # do nothing

# scale down epsilon
if epsilon > FINAL_EPSILON and t > OBSERVE:
epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE

# run the selected action and observe next state and reward
x_t1_colored, r_t, terminal = game_state.frame_step(a_t)
x_t1 = cv2.cvtColor(cv2.resize(x_t1_colored, (80, 80)), cv2.COLOR_BGR2GRAY)
ret, x_t1 = cv2.threshold(x_t1, 1, 255, cv2.THRESH_BINARY)
x_t1 = preprocess(x_t1_colored)
x_t1 = np.reshape(x_t1, (80, 80, 1))
#s_t1 = np.append(x_t1, s_t[:,:,1:], axis = 2)
s_t1 = np.append(x_t1, s_t[:, :, :3], axis=2)
s_t1 = np.append(s_t[:, :, :, 1:], np.expand_dims(x_t1, axis=0), axis=3)

# store the transition in D
D.append((s_t, a_t, r_t, s_t1, terminal))
if len(D) > REPLAY_MEMORY:
D.popleft()

# only train if done observing
if t > OBSERVE:
# sample a minibatch to train on
minibatch = random.sample(D, BATCH)

# get the batch variables
s_j_batch = [d[0] for d in minibatch]
a_batch = [d[1] for d in minibatch]
r_batch = [d[2] for d in minibatch]
s_j1_batch = [d[3] for d in minibatch]
s_j_batch = np.array([d[0] for d in minibatch])
a_batch = np.array([d[1] for d in minibatch])
r_batch = np.array([d[2] for d in minibatch])
s_j1_batch = np.array([d[3] for d in minibatch])

y_batch = []
readout_j1_batch = readout.eval(feed_dict = {s : s_j1_batch})
for i in range(0, len(minibatch)):
readout_j1_batch = model.predict(s_j1_batch)
readout_j1_target_batch = target_model.predict(s_j1_batch)
for i in range(len(minibatch)):
terminal = minibatch[i][4]
# if terminal, only equals reward
if terminal:
y_batch.append(r_batch[i])
else:
y_batch.append(r_batch[i] + GAMMA * np.max(readout_j1_batch[i]))
max_action = np.argmax(readout_j1_batch[i])
y_batch.append(r_batch[i] + GAMMA * readout_j1_target_batch[i][max_action])

y_batch = np.array(y_batch)
a_batch = np.array(a_batch)
target_f = model.predict(s_j_batch)
for i in range(len(minibatch)):
target_f[i][np.argmax(a_batch[i])] = y_batch[i]

# perform gradient step
train_step.run(feed_dict = {
y : y_batch,
a : a_batch,
s : s_j_batch}
)
model.fit(s_j_batch, target_f, epochs=1, verbose=0)

# update the old values
s_t = s_t1
t += 1

# save progress every 10000 iterations
if t % 10000 == 0:
saver.save(sess, 'saved_networks/' + GAME + '-dqn', global_step = t)
checkpoint_manager.save()
target_model.set_weights(model.get_weights())

# print info
state = ""
if t <= OBSERVE:
state = "observe"
Expand All @@ -192,24 +144,15 @@ def trainNetwork(s, readout, h_fc1, sess):
else:
state = "train"

print("TIMESTEP", t, "/ STATE", state, \
"/ EPSILON", epsilon, "/ ACTION", action_index, "/ REWARD", r_t, \
"/ Q_MAX %e" % np.max(readout_t))
# write info to files
'''
if t % 10000 <= 100:
a_file.write(",".join([str(x) for x in readout_t]) + '\n')
h_file.write(",".join([str(x) for x in h_fc1.eval(feed_dict={s:[s_t]})[0]]) + '\n')
cv2.imwrite("logs_tetris/frame" + str(t) + ".png", x_t1)
'''
print(f"TIMESTEP {t} / STATE {state} / EPSILON {epsilon} / ACTION {action_index} / REWARD {r_t} / Q_MAX {np.max(readout_t)}")

def playGame():
sess = tf.InteractiveSession()
s, readout, h_fc1 = createNetwork()
trainNetwork(s, readout, h_fc1, sess)

def main():
playGame()
input_shape = (80, 80, 4)
action_space = ACTIONS
model = build_model(input_shape, action_space)
target_model = build_model(input_shape, action_space)
target_model.set_weights(model.get_weights())
trainNetwork(model, target_model)

if __name__ == "__main__":
main()
playGame()
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
tensorflow==2.16.2
opencv-python
pygame
Binary file removed saved_networks/bird-dqn-2880000
Binary file not shown.
Binary file removed saved_networks/bird-dqn-2880000.meta
Binary file not shown.
Binary file removed saved_networks/bird-dqn-2890000
Binary file not shown.
Binary file removed saved_networks/bird-dqn-2890000.meta
Binary file not shown.
Binary file removed saved_networks/bird-dqn-2900000
Binary file not shown.
Binary file removed saved_networks/bird-dqn-2900000.meta
Binary file not shown.
Binary file removed saved_networks/bird-dqn-2910000
Binary file not shown.
Binary file removed saved_networks/bird-dqn-2910000.meta
Binary file not shown.
Binary file removed saved_networks/bird-dqn-2920000
Binary file not shown.
Binary file removed saved_networks/bird-dqn-2920000.meta
Binary file not shown.
10 changes: 4 additions & 6 deletions saved_networks/checkpoint
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
model_checkpoint_path: "bird-dqn-2920000"
all_model_checkpoint_paths: "bird-dqn-2880000"
all_model_checkpoint_paths: "bird-dqn-2890000"
all_model_checkpoint_paths: "bird-dqn-2900000"
all_model_checkpoint_paths: "bird-dqn-2910000"
all_model_checkpoint_paths: "bird-dqn-2920000"
model_checkpoint_path: "ckpt-1"
all_model_checkpoint_paths: "ckpt-1"
all_model_checkpoint_timestamps: 1721205841.8255534
last_preserved_timestamp: 1721205413.3230736
Binary file added saved_networks/ckpt-1.data-00000-of-00001
Binary file not shown.
Binary file added saved_networks/ckpt-1.index
Binary file not shown.
Binary file removed saved_networks/pretrained_model/bird-dqn-policy
Binary file not shown.