import gym
import matplotlib
import numpy as np
import sys
from collections import defaultdict
import pprint as pp
from matplotlib import pyplot as plt
%matplotlib inline
import itertools
from matplotlib import cm, colors
import myplot  # local helper module with plotting utilities (not shown in this post)
1. Introduction
In a Markov Decision Process (Figure 1) the agent and the environment interact continually.
More details are available in Reinforcement Learning: An Introduction by Sutton and Barto.
The dynamics of the MDP are given by \[ \begin{aligned} p(s',r|s,a) &= Pr\{ S_{t+1}=s',R_{t+1}=r \mid S_t=s,A_t=a \} \end{aligned} \]
The policy of an agent is a mapping from the current state of the environment to an action that the agent needs to take in this state. Formally, a policy is given by \[ \begin{aligned} \pi(a|s) &= Pr\{A_t=a|S_t=s\} \end{aligned} \]
The discounted return is given by \[ \begin{aligned} G_t &= R_{t+1} + \gamma R_{t+2} + \gamma ^2 R_{t+3} + \dots \\ &= \sum_{k=0}^\infty \gamma ^k R_{t+1+k} \end{aligned} \] where \(\gamma\) is the discount factor and \(R\) is the reward.
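For example, if an episode lasts four more steps with a reward of -1 per step and \(\gamma=0.9\), the return can be computed directly:
# discounted return for four remaining rewards with gamma = 0.9 (illustrative values)
rewards = [-1, -1, -1, -1]   # R_{t+1}, R_{t+2}, R_{t+3}, R_{t+4}
gamma = 0.9
G_t = sum(gamma**k * r for k, r in enumerate(rewards))
print(G_t)                   # -3.439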
Most reinforcement learning algorithms involve the estimation of value functions - in our present case, the state-value function. The state-value function maps each state to a measure of “how good it is to be in that state” in terms of expected rewards. Formally, the state-value function, under policy \(\pi\) is given by \[ \begin{aligned} v_\pi(s) &= \mathbb{E}_\pi[G_t|S_t=s] \end{aligned} \]
The Temporal Difference (TD) algorithm discussed in this post will numerically estimate \(v_\pi(s)\).
2. Environment
Figure 2 shows the environment we will use in this series: the Windy Gridworld.
Each episode starts in the start state, S, and the agent tries to reach the goal state, G, in as few steps as possible. The agent can apply one of four actions to the environment:
- up (0)
- right (1)
- down (2)
- left (3)
There is a crosswind running upward through the middle of the grid. The strength of the crosswind is indicated in the center columns of the grid. This strength, taken from the column of the departing state, is added to the vertical displacement of the chosen action. For example, if you are one cell to the right of the goal, the action left will take you to the cell just above the goal.
Tasks are episodic and undiscounted. All rewards are -1 until the goal is reached.
States are numbered from 0 to 69 in this 7 by 10 grid, in a row-wise fashion starting in the top-left corner.
The environment is implemented using the OpenAI Gym library.
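To make the wind dynamics concrete, here is a minimal sketch of the transition logic, assuming the standard column wind strengths 0, 0, 0, 1, 1, 1, 2, 2, 1, 0 from Sutton and Barto; the actual WindyGridworldEnv implementation may differ in detail:
# Minimal sketch of the windy transition (assumed wind strengths; the actual
# WindyGridworldEnv implementation may differ in detail).
WIND = [0, 0, 0, 1, 1, 1, 2, 2, 1, 0]                    # upward push per column
DELTAS = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}  # up, right, down, left

def windy_move(row, col, action):
    d_row, d_col = DELTAS[action]
    new_row = row + d_row - WIND[col]   # wind of the departing column pushes the agent up
    new_col = col + d_col
    new_row = min(max(new_row, 0), 6)   # clip to the 7 by 10 grid
    new_col = min(max(new_col, 0), 9)
    return new_row, new_col

print(windy_move(3, 8, 3))  # left from the cell right of the goal -> (2, 7), just above the goal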
3. Agent
The agent is the decision maker. It needs to select actions that reach the goal in as few steps as possible while compensating for the effect of the wind. After observing the state of the environment, expressed as a number from 0 to 69 that reflects the current position in the grid, the agent can take one of the following actions:
- up (0)
- right (1)
- down (2)
- left (3)
4. Temporal Difference (TD) Estimation of the State-value Function, \(v_\pi(s)\)
Temporal Difference (TD) methods both bootstrap and sample. To bootstrap means that the update is based on an existing estimate of the value function. To sample means that the update uses a single sampled transition rather than an expectation over all possible successor states.
We will now estimate the state-value function for the given policy \(\pi\). Because the task is episodic, we can take \(\gamma=1\) and the sum remains finite:
\[ \large \begin{aligned} v_\pi(s) &= \mathbb{E}_\pi[G_t | S_t=s] \\ &= \mathbb{E}_\pi[R_{t+1} + \gamma G_{t+1} | S_t=s] \\ &= \mathbb{E}_\pi[R_{t+1} + \gamma v_{\pi}(S_{t+1}) | S_t=s] \end{aligned} \]
The TD method uses the current estimate \(V\) instead of the true \(v_{\pi}\). It makes the following update:
\[ \large V(S_t) \leftarrow V(S_t) + \alpha \left[ R_{t+1} + \gamma V(S_{t+1}) - V(S_t) \right] \]
The target in this update is:
\[ \large R_{t+1} + \gamma V(S_{t+1}) \]
The target is an estimate of the return.
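For example, suppose \(V(S_t)=-3\), \(V(S_{t+1})=-4\), \(R_{t+1}=-1\), \(\gamma=1\) and \(\alpha=0.5\). The target is \(-1 + (-4) = -5\), the TD error is \(-5 - (-3) = -2\), and the updated estimate becomes \(V(S_t) = -3 + 0.5 \times (-2) = -4\).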
Like Monte Carlo (MC) methods, TD methods do not require a model of the environment; they only require experience. Unlike MC methods, they do not wait until the end of an episode before making updates. This means TD methods can be incremental: updates can be made after each step, which requires less memory and less peak computation.
5. Implementation
Figure 3 shows the algorithm:
Next, we present the code that implements the algorithm.
env = WindyGridworldEnv()
5.1 Policy
The following function captures the policy used by the agent:
def create_policy_random(n_A):
    def policy_function(observation):
        # probabilities for each action, length n_A:
        action_probs = np.ones(n_A, dtype=float)/n_A
        return np.random.choice(np.arange(len(action_probs)), p=action_probs)
    return policy_function
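As a quick check, the policy can be instantiated for the four Windy Gridworld actions and sampled; since the policy is uniformly random, the exact actions returned will vary:
pi = create_policy_random(4)           # four actions: up, right, down, left
samples = [pi(0) for _ in range(5)]    # sample five actions for state 0
print(samples)                         # e.g. [2, 0, 3, 1, 1]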
5.3 Main loop
The following function implements the main loop of the algorithm. It iterates for n_episodes episodes. It also takes a list of monitored_states for which it records the evolution of the state values. This is handy for showing how the state values converge during the process.
def td_0_prediction(env, n_episodes, gamma=1.0, alpha=0.5, epsilon=0.1, monitored_states=None, diag=False):
    V = defaultdict(float)
    pi = create_policy_random(env.action_space.n)
    monitored_state_values = defaultdict(list)
    stats = myplot.EpisodeStats(
        episode_lengths=np.zeros(n_episodes),
        episode_rewards=np.zeros(n_episodes))
    for i in range(n_episodes):
        if (i + 1)%10 == 0: print("\rEpisode {}/{}".format(i + 1, n_episodes), end=""); sys.stdout.flush()
        print(f'\nepisode {i}:') if diag else None
        #---initialize St
        St = env.reset()
        #---repeat for each step of episode
        for t in itertools.count(): #while True:
            #---choose At from St
            At = pi(St)
            #---take action At, observe Rt+1, St+1
            Stp1, Rtp1, done, _ = env.step(At) # St+1, Rt+1 OR s',r
            print(f"---t={t} St, At, Rt+1, St+1: {St, At, Rtp1, Stp1}") if diag else None
            stats.episode_rewards[i] += Rtp1; stats.episode_lengths[i] = t
            #---update V
            V[St] = V[St] + alpha*( Rtp1 + gamma*V[Stp1] - V[St] ); print(f"V[St]: {V[St]}") if diag else None
            St = Stp1
            if done:
                break
            #---until St is terminal
        if monitored_states:
            for ms in monitored_states:
                # print("\rQ[{}]: {}".format(msa, Q[s][a]), end=""); sys.stdout.flush()
                monitored_state_values[ms].append(V[ms])
    return V, stats, monitored_state_values
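The function relies on a small plotting helper module, myplot, which is not shown in this post. A minimal sketch of the EpisodeStats container it is assumed to provide (the actual helper may differ) is:
from collections import namedtuple

# Assumed container for per-episode statistics; the actual myplot module may differ.
EpisodeStats = namedtuple("EpisodeStats", ["episode_lengths", "episode_rewards"])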
5.4 Monitored states
Let’s pick a number of states to monitor. Each number identifies one of the 70 states.
monitored_states = [0, 7, 57, 68]
V, stats, monitored_state_values = td_0_prediction(
    env,
    n_episodes=10,
    alpha=0.1,
    monitored_states=monitored_states,
    diag=False)
Episode 10/10
V
defaultdict(float,
{0: -137.80033464618046,
1: -142.5790573637218,
2: -153.38627333959087,
3: -164.31568002773105,
4: -176.89251225878337,
5: -183.914514998972,
6: -191.20756834044525,
7: -190.92399180138216,
8: -188.57031090761552,
9: -183.12953773836352,
10: -129.34496871071303,
11: -137.70509013845967,
12: -148.2564075292973,
13: -163.57355671774624,
14: -168.52572248108007,
15: -136.59594857510058,
16: -47.71947159507175,
17: -174.04428610385764,
18: -184.93312959582892,
19: -173.05808744316175,
20: -120.6790406130493,
21: -127.40417214467884,
22: -134.03582218025838,
23: -157.06904394086223,
24: -140.3131307571433,
25: -58.20714953578962,
26: -54.10882697419318,
27: -137.92427634130968,
28: -173.99921218868272,
29: -141.98995706524073,
30: -101.70243257084873,
31: -101.7270991868837,
32: -120.64630789535269,
33: -140.5876115761742,
34: -88.27864427163075,
35: -27.654676797073208,
36: -0.1,
37: 0.0,
38: -144.11639148686803,
39: -105.9540172086496,
40: -73.97040001702122,
41: -87.16453555102125,
42: -94.14254268702543,
43: -111.51260351392207,
44: -44.34677679073506,
45: -7.199784687087242,
47: -26.46341299963032,
48: -82.57726247834209,
49: -57.66605324164337,
50: -61.611925018068305,
51: -68.47009398963326,
52: -59.006804000112034,
53: -51.41626761054323,
54: -35.843329651729235,
57: -10.303183460376614,
58: -30.378465513899624,
59: -25.517464691704166,
60: -50.42665757427782,
61: -51.07727460978276,
62: -47.70493182941376,
63: -44.756861753242674,
68: -12.775760805886526,
69: -4.497100056613684})
V[50]
-61.611925018068305
print(monitored_states[0])
print(monitored_state_values[monitored_states[0]])
0
[-21.474272908702716, -33.72302605938006, -42.47003230048048, -62.50296660467135, -67.12778730329649, -77.20250676785686, -76.99022419718405, -101.71774339987682, -106.62374028592815, -137.80033464618046]
# the last value recorded for a monitored state should equal its entry in V
ms = monitored_states[0]; print('ms:', ms)
monitored_state_values[ms][-1], V[ms]
ms: 0
(-137.80033464618046, -137.80033464618046)
5.5 Run 1
V1, stats, monitored_state_values1 = td_0_prediction(
    env,
    n_episodes=20,
    alpha=0.5,
    monitored_states=monitored_states,
    diag=False)
Episode 20/20
# the last value recorded for a monitored state should equal its entry in V1
ms = monitored_states[0]; print('ms:', ms)
monitored_state_values1[ms][-1], V1[ms]
ms: 0
(-735.1459031454285, -735.1459031454285)
The following chart shows how the estimated values of the monitored states evolve over the episodes:
"figure.figsize"] = (18,10)
plt.rcParams[for ms in monitored_states:
plt.plot(monitored_state_values1[ms])'Estimated $v_\pi(s)$ for some states', fontsize=18)
plt.title('Episodes', fontsize=16)
plt.xlabel('Estimated $v_\pi(s)$', fontsize=16)
plt.ylabel(=16)
plt.legend(monitored_states, fontsize plt.show()
Here are some additional metrics:
myplot.plot_episode_stats(stats);
To make a plot of the state-value function we reshape the values to align with the Windy Gridworld pattern.
# convert V1 to V1p for plotting
states_shape = (7, 10)
nS = np.prod(states_shape)
V1p = {}
for s in range(nS):
    position = np.unravel_index(s, states_shape) #print(f"position: {position}")
    V1p[position] = V1[s]
V1p
{(0, 0): -735.1459031454285,
(0, 1): -730.9210723890925,
(0, 2): -748.20305259413,
(0, 3): -754.9261025764922,
(0, 4): -752.8787838171859,
(0, 5): -750.0883665323886,
(0, 6): -743.9117787359457,
(0, 7): -686.3473071680427,
(0, 8): -662.4342333214856,
(0, 9): -657.5562656363784,
(1, 0): -718.1124016107731,
(1, 1): -731.8165301553454,
(1, 2): -743.5612619011463,
(1, 3): -749.0929971772642,
(1, 4): -747.278666145852,
(1, 5): -721.623343444135,
(1, 6): -686.7081261934021,
(1, 7): -724.3152993242904,
(1, 8): -690.1043378911771,
(1, 9): -629.1439695837937,
(2, 0): -721.0189278360183,
(2, 1): -726.3611748241806,
(2, 2): -736.108040280805,
(2, 3): -744.1853570649814,
(2, 4): -727.6459474757083,
(2, 5): -645.5264132122365,
(2, 6): -658.855984783116,
(2, 7): -698.0281822057283,
(2, 8): -664.3786577812602,
(2, 9): -524.8556599796428,
(3, 0): -701.4558214932435,
(3, 1): -701.3759709309548,
(3, 2): -719.855350075502,
(3, 3): -738.8169361136412,
(3, 4): -724.440157261181,
(3, 5): -576.7938789729867,
(3, 6): -466.6315820284753,
(3, 7): 0.0,
(3, 8): -661.6487411402459,
(3, 9): -470.79898477239567,
(4, 0): -686.8139009088889,
(4, 1): -681.9207135990542,
(4, 2): -696.8290154104961,
(4, 3): -726.154719019099,
(4, 4): -686.6401005770437,
(4, 5): -364.0468750239923,
(4, 6): 0.0,
(4, 7): -298.06036632267455,
(4, 8): -284.00554599611735,
(4, 9): -351.7826173921178,
(5, 0): -671.362142111826,
(5, 1): -659.5054173992531,
(5, 2): -679.8915299579523,
(5, 3): -691.7872475376071,
(5, 4): -585.9983407933428,
(5, 5): 0.0,
(5, 6): 0.0,
(5, 7): -246.14569695121958,
(5, 8): -537.3196032010271,
(5, 9): -369.68583680235787,
(6, 0): -663.6650123721063,
(6, 1): -663.3279431198509,
(6, 2): -663.524783367199,
(6, 3): -668.8126886203731,
(6, 4): 0.0,
(6, 5): 0.0,
(6, 6): 0.0,
(6, 7): 0.0,
(6, 8): -55.79920769151681,
(6, 9): -229.7809440998002}
myplot.plot_state_value_surface(V1p, title='State-Value Function for Windy Gridworld', wireframe=True, azim=-150, elev=60);
myplot.plot_state_value_surface(V1p, title='State-Value Function for Windy Gridworld', wireframe=False, azim=-150, elev=60);
# plot_state_value_heatmap(V1p, title='State-Value Function for Windy Gridworld');
myplot.plot_state_value_heatmap_windy_gridworld(V1p, title='State-Value Function for Windy Gridworld');
5.6 Run 2
V2, stats, monitored_state_values2 = td_0_prediction(
    env,
    n_episodes=200,
    alpha=0.5,
    monitored_states=monitored_states,
    diag=False)
Episode 200/200
# the last value recorded for a monitored state should equal its entry in V2
ms = monitored_states[0]; print('ms:', ms)
monitored_state_values2[ms][-1], V2[ms]
ms: 0
(-4193.810401393425, -4193.810401393425)
"figure.figsize"] = (18,10)
plt.rcParams[for ms in monitored_states:
plt.plot(monitored_state_values2[ms])'Estimated $v_\pi(s)$ for some states', fontsize=18)
plt.title('Episodes', fontsize=16)
plt.xlabel('Estimated $v_\pi(s)$', fontsize=16)
plt.ylabel(=16)
plt.legend(monitored_states, fontsize plt.show()
Here are some additional metrics:
myplot.plot_episode_stats(stats);
# convert V2 to V2p for plotting
states_shape = (7, 10)
nS = np.prod(states_shape)
V2p = {}
for s in range(nS):
    position = np.unravel_index(s, states_shape) #print(f"position: {position}")
    V2p[position] = V2[s]
V2p
{(0, 0): -4193.810401393425,
(0, 1): -4193.814757777407,
(0, 2): -4191.890300247405,
(0, 3): -4185.298602492931,
(0, 4): -4174.72733085042,
(0, 5): -4147.26308505477,
(0, 6): -4118.407005550107,
(0, 7): -4085.405156242255,
(0, 8): -4071.2953094442705,
(0, 9): -4046.459316882357,
(1, 0): -4191.019217298803,
(1, 1): -4191.6556502288895,
(1, 2): -4190.4299296322715,
(1, 3): -4190.965173415778,
(1, 4): -4180.658164880097,
(1, 5): -4126.107376489743,
(1, 6): -3987.0554972487753,
(1, 7): -4071.2527793806516,
(1, 8): -4057.6179390117704,
(1, 9): -4021.9726305212557,
(2, 0): -4186.732672802048,
(2, 1): -4186.617107436847,
(2, 2): -4187.14920440839,
(2, 3): -4187.742698915825,
(2, 4): -4173.760780122282,
(2, 5): -4151.302795294076,
(2, 6): -4043.2356769698185,
(2, 7): -4065.931040123559,
(2, 8): -4060.995750405103,
(2, 9): -3998.1373273491854,
(3, 0): -4181.435865523515,
(3, 1): -4183.7105431386135,
(3, 2): -4185.254196149683,
(3, 3): -4167.789277506801,
(3, 4): -4137.248052519253,
(3, 5): -4129.991220801143,
(3, 6): -4041.255522160846,
(3, 7): 0.0,
(3, 8): -3992.2745542944663,
(3, 9): -3555.425259222465,
(4, 0): -4178.48614972832,
(4, 1): -4177.297804829595,
(4, 2): -4170.504863775563,
(4, 3): -4173.944209869453,
(4, 4): -4106.245000924509,
(4, 5): -4125.774821555984,
(4, 6): 0.0,
(4, 7): -3836.9702498224224,
(4, 8): -654.5525306357154,
(4, 9): -2404.980895769819,
(5, 0): -4176.458161376984,
(5, 1): -4172.186748590077,
(5, 2): -4174.260331451008,
(5, 3): -4151.870491929965,
(5, 4): -4129.774742716047,
(5, 5): 0.0,
(5, 6): 0.0,
(5, 7): -3881.510590980134,
(5, 8): -3668.8570806490907,
(5, 9): -3423.782182477868,
(6, 0): -4174.446152289609,
(6, 1): -4164.759351585928,
(6, 2): -4165.122295462905,
(6, 3): -4158.603379427284,
(6, 4): 0.0,
(6, 5): 0.0,
(6, 6): 0.0,
(6, 7): 0.0,
(6, 8): -3550.3751842154347,
(6, 9): -3407.4561412815788}
myplot.plot_state_value_surface(V2p, title='State-Value Function for Windy Gridworld', wireframe=False, azim=-150, elev=60);
myplot.plot_state_value_heatmap_windy_gridworld(V2p, title='State-Value Function for Windy Gridworld');