import matplotlib.pyplot as plt
import torch


def draw(file_name, title, x_label, y_label, y, x=None):
    """Plot the data and save the figure to a file.

    Args:
        file_name: path of the image file to save
        title: title of the plot
        x_label: label for the x-axis
        y_label: label for the y-axis
        y: the data to plot
        x: optional x coordinates; if omitted, the indices of y are used
    """
    plt.figure()
    if x is not None:
        plt.plot(x, y)
    else:
        plt.plot(y)
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.savefig(file_name)
    plt.close()  # free the figure so repeated calls do not accumulate open figures
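
# Example usage (illustrative sketch; the file name and data below are hypothetical):
#
#   losses = [0.9, 0.7, 0.55, 0.42]
#   draw('training_loss.png', 'Training loss', 'Episode', 'Loss', losses)
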
def gen_epsilon_greedy_policy(n_action, epsilon, estimator=None):
    """Generate an epsilon-greedy policy.

    Args:
        n_action: the number of actions
        epsilon: probability mass spread uniformly over all actions
        estimator: optional estimator used to predict Q-values; it must implement
            a predict method that takes a state as its parameter
    Returns:
        the policy function:
            inputs: state, Q (Q may be None when estimator is given)
            output: action
    """
    def policy_function(state, Q=None):
        if estimator is None and Q is None:
            raise ValueError('estimator and Q cannot both be None in the policy function')
        if estimator is not None:
            Q_state = estimator.predict(state)
        else:
            Q_state = Q[state]
        # spread epsilon evenly over all actions, then put the remaining
        # 1 - epsilon probability mass on the greedy action
        probs = torch.ones(n_action) * epsilon / n_action
        best_action = torch.argmax(Q_state).item()
        probs[best_action] += 1 - epsilon
        action = torch.multinomial(probs, 1).item()
        return action
    return policy_function
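
# Example usage (illustrative sketch; the tabular Q below assumes a hypothetical
# environment with 16 states and 4 actions):
#
#   Q = torch.zeros(16, 4)
#   policy = gen_epsilon_greedy_policy(n_action=4, epsilon=0.1)
#   action = policy(0, Q)          # epsilon-greedy action for state 0
#
# With a function approximator instead of a table, pass an object exposing
# predict(state) as the estimator and call policy(state) without Q.
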
def gen_softmax_exploration_policy(tau):
    """Generate a softmax (Boltzmann) exploration policy.

    Args:
        tau (float): temperature controlling exploration vs. exploitation;
            as tau -> 0 the policy approaches the greedy (best) action,
            while larger tau moves it towards uniform exploration
    Returns:
        the policy function:
            inputs: state, Q
            output: action
    """
    def policy_function(state, Q):
        # softmax over the Q-values of the current state, scaled by the temperature
        probs = torch.softmax(Q[state] / tau, dim=-1)
        action = torch.multinomial(probs, 1).item()
        return action
    return policy_function
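
# Example usage (illustrative sketch; the Q table and temperature are hypothetical):
#
#   Q = torch.zeros(16, 4)
#   policy = gen_softmax_exploration_policy(tau=0.5)
#   action = policy(0, Q)          # action sampled from softmax(Q[0] / tau)
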
def upper_confidence_bound(Q, state, action_count, episode):
    """Return the action with the highest upper confidence bound.

    Args:
        Q (torch.Tensor): Q-function
        state (int): the state the environment is currently in
        action_count (torch.Tensor): how many times each action has been taken; should be a FloatTensor
        episode (int): the episode the algorithm is currently on
    Returns:
        the best action (int)
    """
    # exploration bonus sqrt(2 * ln(episode) / N(a)) added to the Q-values of the state
    ucb = torch.sqrt(2 * torch.log(torch.tensor(float(episode))) / action_count) + Q[state]
    return torch.argmax(ucb).item()
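
# Example usage (illustrative sketch; the counts below are hypothetical and start
# at 1 to avoid division by zero in the exploration bonus):
#
#   Q = torch.zeros(16, 4)
#   action_count = torch.ones(4)
#   action = upper_confidence_bound(Q, state=0, action_count=action_count, episode=10)
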
def thompson_sampling(alpha, beta):
    """Return an action chosen by Thompson sampling from per-action Beta distributions.

    Args:
        alpha: alpha parameters of the Beta distributions (one per action)
        beta: beta parameters of the Beta distributions (one per action)
    Returns:
        action (int): the best action
    Note: each Beta distribution should start with alpha=beta=1
    """
    # draw one sample from each action's Beta posterior and act greedily on the samples
    prior_values = torch.distributions.beta.Beta(alpha, beta).sample()
    return torch.argmax(prior_values).item()
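
# Example usage (illustrative sketch; two hypothetical actions, both starting
# from the uninformative prior alpha=beta=1):
#
#   alpha = torch.ones(2)
#   beta = torch.ones(2)
#   action = thompson_sampling(alpha, beta)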