Temporal Difference 1 – Estimation of the State-value Function in Reinforcement Learning

Find state values under a given policy

Reinforcement Learning
TD
OpenAI Gym
Author

Kobus Esterhuysen

Published

February 1, 2022



1. Introduction

In a Markov Decision Process (MDP) (Figure 1), the agent and the environment interact continuously.

Figure 1 Agent/Environment interaction in an MDP

More details are available in Reinforcement Learning: An Introduction by Sutton and Barto.

The dynamics of the MDP are given by \[ \begin{aligned} p(s',r|s,a) &= Pr\{ S_{t+1}=s',R_{t+1}=r | S_t=s,A_t=a \} \end{aligned} \]

The policy of an agent is a mapping from the current state of the environment to the action that the agent takes in that state. In general a policy is stochastic and assigns a probability to each action. Formally, a policy is given by \[ \begin{aligned} \pi(a|s) &= Pr\{A_t=a|S_t=s\} \end{aligned} \]

The discounted return is given by \[ \begin{aligned} G_t &= R_{t+1} + \gamma R_{t+2} + \gamma ^2 R_{t+3} + \dots + \gamma^{T-t-1} R_T \\ &= \sum_{k=0}^{T-t-1} \gamma ^k R_{t+1+k} \end{aligned} \] where \(\gamma\) is the discount factor, \(R\) is the reward, and \(T\) is the final time step of the episode.
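As a quick numerical check of the two forms of \(G_t\), the following sketch computes the return of a short episode both as the explicit discounted sum and through the recursion \(G_t = R_{t+1} + \gamma G_{t+1}\) that the TD derivation in Section 4 relies on. The function names are illustrative only; they are not part of the post's code.

```python
def discounted_return(rewards, gamma):
    """G_0 for a finite reward sequence [R_1, R_2, ..., R_T]."""
    return sum(gamma**k * r for k, r in enumerate(rewards))


def returns_backward(rewards, gamma):
    """[G_0, G_1, ..., G_{T-1}] via the recursion G_t = R_{t+1} + gamma * G_{t+1}."""
    G = 0.0                      # the return after the final step is zero
    out = []
    for r in reversed(rewards):  # walk backwards from the end of the episode
        G = r + gamma * G
        out.append(G)
    return out[::-1]


rewards = [-1.0, -1.0, -1.0]                   # three steps, each rewarded -1
print(discounted_return(rewards, gamma=1.0))   # -3.0
print(discounted_return(rewards, gamma=0.5))   # -1.75
print(returns_backward(rewards, gamma=1.0))    # [-3.0, -2.0, -1.0]
```

With \(\gamma = 1\) the return of a Windy Gridworld episode is simply minus the number of steps taken.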

Most reinforcement learning algorithms involve the estimation of value functions; in our present case, the state-value function. The state-value function maps each state to a measure of “how good it is to be in that state”, namely the expected return when starting in that state and following the policy thereafter. Formally, the state-value function under policy \(\pi\) is given by \[ \begin{aligned} v_\pi(s) &= \mathbb{E}_\pi[G_t|S_t=s] \end{aligned} \]

The Temporal Difference (TD) algorithm discussed in this post will numerically estimate \(v_\pi(s)\).

2. Environment

Figure 2 shows the environment we will use in this series: The Windy Gridworld:

Figure 2 Windy Gridworld Environment

Each episode starts in the start state S, and the agent tries to reach the goal state G in as few steps as possible. There are four movements or actions that the agent can apply to the environment:

  • up (0)
  • right (1)
  • down (2)
  • left (3)

There is a crosswind running upward through the middle of the grid. The strength of the crosswind is indicated for each of the center columns of the grid. After a move, the agent is shifted upward by the number of cells given by the wind strength of the column it departs from, in addition to the displacement caused by the action itself. For example, if you are one cell to the right of the goal, the action left will take you to the cell just above the goal.
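The wind rule can be written as a small transition function. This is a sketch, not the Gym environment used later: since the figure with the exact values is not reproduced here, it assumes the wind strengths, start cell (3, 0) and goal cell (3, 7) of the Windy Gridworld example in Sutton and Barto, and the names `windy_step`, `WIND` and `MOVES` are hypothetical. It reproduces the example above: moving left from one cell to the right of the goal ends just above the goal.

```python
# Assumed layout (Sutton & Barto's Windy Gridworld): 7 rows x 10 columns,
# upward wind strength per column, start at (3, 0), goal at (3, 7).
SHAPE = (7, 10)
WIND = [0, 0, 0, 1, 1, 1, 2, 2, 1, 0]
# Action index -> (row delta, column delta): up (0), right (1), down (2), left (3).
MOVES = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}


def windy_step(position, action):
    """Return the next (row, col) after taking `action` from `position`."""
    row, col = position
    d_row, d_col = MOVES[action]
    # The wind of the *departing* column pushes the agent up (smaller row index).
    new_row = row + d_row - WIND[col]
    new_col = col + d_col
    # Moves that would leave the grid are clipped to its edge.
    new_row = min(max(new_row, 0), SHAPE[0] - 1)
    new_col = min(max(new_col, 0), SHAPE[1] - 1)
    return new_row, new_col


print(windy_step((3, 8), 3))  # left from one cell right of the goal -> (2, 7)
print(windy_step((3, 0), 1))  # right from the start, no wind in column 0 -> (3, 1)
```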

Tasks are episodic and undiscounted. All rewards are -1 until the goal is reached.

States are numbered from 0 to 69 in this 7 by 10 grid, in a row-wise fashion starting in the top-left corner.
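The row-wise numbering corresponds to \(s = 10 \cdot \text{row} + \text{col}\). The helpers below (hypothetical names) convert in both directions; `np.unravel_index`, used in Section 5 for plotting, performs the same inverse mapping. The start and goal cells in the comment again follow the assumed Sutton and Barto layout.

```python
SHAPE = (7, 10)  # rows, columns


def to_state(row, col, shape=SHAPE):
    """Row-major state index, counted from the top-left cell."""
    return row * shape[1] + col


def to_position(state, shape=SHAPE):
    """Inverse of to_state: state index -> (row, col)."""
    return divmod(state, shape[1])


print(to_state(3, 0), to_state(3, 7))  # 30 37 (assumed start and goal cells)
print(to_position(69))                 # (6, 9), the bottom-right cell
```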

The environment is implemented using the OpenAI Gym library.

3. Agent

The agent is the decision maker. It chooses the actions that should bring it to the goal in as few steps as possible while compensating for the effect of the wind. After observing the state of the environment, expressed as a number from 0 to 69 that reflects the current position in the grid, the agent can take one of the following actions:

  • up (0)
  • right (1)
  • down (2)
  • left (3)

4. Temporal Difference (TD) Estimation of the State-value Function, \(v_\pi(s)\)

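Temporal Difference (TD) methods both bootstrap and sample. To bootstrap means that the update target uses an existing estimate of the value function (the estimated value of the next state). To sample means that the update uses a single observed transition instead of an expected value over all possible next states and rewards.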
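The difference between a sampled and an expected update can be seen on a toy example. The sketch below assumes a made-up one-step model (two successor states A and B, each reached with probability 0.5 and reward -1) and made-up estimates for their values. Both targets bootstrap, because they use the estimated value of the successor, but only the expected target needs the transition probabilities. The sampled target uses one observed transition, and its average over many samples approaches the expected target.

```python
import random

# Hypothetical one-step model from some state s under the policy:
# (probability, next_state, reward) triples.
transitions = [(0.5, "A", -1.0), (0.5, "B", -1.0)]
V = {"A": -2.0, "B": -4.0}   # made-up current estimates of the successors
gamma = 1.0

# Expected target: averages over all successors, so it needs the model.
expected_target = sum(p * (r + gamma * V[s_next]) for p, s_next, r in transitions)


def sample_target(rng):
    """Sampled (TD-style) target from one transition.

    The weights only simulate the environment; the target itself uses
    nothing but the observed reward and next state.
    """
    weights = [p for p, _, _ in transitions]
    _, s_next, r = rng.choices(transitions, weights=weights, k=1)[0]
    return r + gamma * V[s_next]


rng = random.Random(0)
samples = [sample_target(rng) for _ in range(10_000)]
print(expected_target)               # -4.0
print(sum(samples) / len(samples))   # close to -4.0; each sample is -3.0 or -5.0
```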

We now estimate the state-value function for the given policy \(\pi\). Because the task is episodic, the return is a finite sum, so we can take \(\gamma=1\):

\[ \large \begin{aligned} v_\pi(s) &= \mathbb{E}_\pi[G_t | S_t=s] \\ &= \mathbb{E}_\pi[R_{t+1} + \gamma G_{t+1} | S_t=s] \\ &= \mathbb{E}_\pi[R_{t+1} + \gamma v_{\pi}(S_{t+1}) | S_t=s] \end{aligned} \]

The TD method uses the current estimate \(V\) instead of the true \(v_{\pi}\). It makes the following update:

\[ \large V(S_t) \leftarrow V(S_t) + \alpha \left[ R_{t+1} + \gamma V(S_{t+1}) - V(S_t) \right] \]

The target in this update is:

\[ \large R_{t+1} + \gamma V(S_{t+1}) \]

The target is an estimate of the return.
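A single TD(0) update can be traced by hand. In this sketch the state numbers and current estimates are hypothetical. With \(\alpha = 0.5\), \(\gamma = 1\), \(V(30) = -10\), \(V(31) = -12\) and a reward of -1, the target is \(-1 + (-12) = -13\), the TD error is \(\delta_t = -13 - (-10) = -3\), and the new estimate is \(-10 + 0.5 \times (-3) = -11.5\).

```python
from collections import defaultdict


def td0_update(V, s, r, s_next, alpha, gamma, terminal=False):
    """Apply one TD(0) update to V in place and return the TD error."""
    v_next = 0.0 if terminal else V[s_next]  # a terminal state has value 0
    target = r + gamma * v_next              # R_{t+1} + gamma * V(S_{t+1})
    td_error = target - V[s]                 # delta_t
    V[s] += alpha * td_error                 # V(S_t) <- V(S_t) + alpha * delta_t
    return td_error


V = defaultdict(float, {30: -10.0, 31: -12.0})  # hypothetical current estimates
delta = td0_update(V, s=30, r=-1.0, s_next=31, alpha=0.5, gamma=1.0)
print(delta, V[30])  # -3.0 -11.5
```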

Like Monte Carlo (MC) methods, TD methods do not require a model of the environment; they learn from experience alone. Unlike MC methods, they do not wait until the end of an episode before making updates. TD methods are therefore fully incremental: an update can be made after every step, which requires less memory and less peak computation.
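To make the timing difference concrete, here is a sketch that applies constant-\(\alpha\) (every-visit) Monte Carlo and TD(0) to the same recorded episode. The episode format, a list of \((S_t, R_{t+1}, S_{t+1})\) triples with `None` marking the terminal state, and the function names are assumptions for illustration. MC must store the episode and process it after termination; the TD loop body could just as well run inside the interaction loop, one step at a time.

```python
from collections import defaultdict


def mc_updates(V, episode, alpha, gamma):
    """Constant-alpha MC: each update needs the full return G_t."""
    G = 0.0
    # Walk backwards so that G_t = R_{t+1} + gamma * G_{t+1} is available.
    for s, r, _ in reversed(episode):
        G = r + gamma * G
        V[s] += alpha * (G - V[s])


def td_updates(V, episode, alpha, gamma):
    """TD(0): each update needs only (S_t, R_{t+1}, S_{t+1})."""
    for s, r, s_next in episode:
        v_next = 0.0 if s_next is None else V[s_next]  # None = terminal state
        V[s] += alpha * (r + gamma * v_next - V[s])


episode = [(30, -1.0, 31), (31, -1.0, 32), (32, -1.0, None)]  # hypothetical
V_mc, V_td = defaultdict(float), defaultdict(float)
mc_updates(V_mc, episode, alpha=0.5, gamma=1.0)
td_updates(V_td, episode, alpha=0.5, gamma=1.0)
print(dict(V_mc))  # {32: -0.5, 31: -1.0, 30: -1.5}
print(dict(V_td))  # {30: -0.5, 31: -0.5, 32: -0.5}
```

Starting from zero estimates, one episode moves each MC estimate toward its full return, while TD(0) has only propagated each immediate reward one step back; later rewards reach earlier states over subsequent episodes.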

5. Implementation

Figure 3 shows the algorithm:

Figure 3 TD(0) Prediction, for estimating the state-value function

Next, we present the code that implements the algorithm.

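import itertools
import sys
from collections import defaultdict
import pprint as pp

import gym
import matplotlib
import numpy as np
from matplotlib import cm, colors
from matplotlib import pyplot as plt
%matplotlib inline

# WindyGridworldEnv (the Gym environment described in Section 2) and the
# plotting helpers used below (myplot.EpisodeStats, myplot.plot_episode_stats,
# myplot.plot_state_value_surface, ...) are defined elsewhere in this series
# and are assumed to be available here.
env = WindyGridworldEnv()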

5.1 Policy

The following function captures the policy used by the agent:

def create_policy_random(n_A):
  """Return a policy that picks each of the n_A actions with equal probability."""
  def policy_function(observation):
    # The uniform random policy ignores the observation.
    action_probs = np.ones(n_A, dtype=float)/n_A  # probabilities for each action, length n_A
    return np.random.choice(n_A, p=action_probs)
  return policy_function

5.3 Main loop

The following function implements the main loop of the algorithm. It runs for n_episodes episodes. It also takes a list of monitored_states and records the estimated value of each of these states at the end of every episode. This is handy for showing how the state values evolve during the process.

def td_0_prediction(env, n_episodes, gamma=1.0, alpha=0.5, epsilon=0.1, monitored_states=None, diag=False):
  # epsilon is not used: the policy is uniform random.
  V = defaultdict(float)  # V(s) starts at 0 for every state, including the terminal one
  pi = create_policy_random(env.action_space.n)
  monitored_state_values = defaultdict(list)
  stats = myplot.EpisodeStats(
    episode_lengths=np.zeros(n_episodes),
    episode_rewards=np.zeros(n_episodes))
  for i in range(n_episodes):
    if (i + 1) % 10 == 0:
      print("\rEpisode {}/{}".format(i + 1, n_episodes), end="")
      sys.stdout.flush()
    if diag:
      print(f'\nepisode {i}:')
    #---initialize St (classic Gym API: reset() returns only the observation)
    St = env.reset()
    #---repeat for each step of episode
    for t in itertools.count():
      #---choose At from St
      At = pi(St)
      #---take action At, observe Rt+1, St+1 (classic 4-tuple step API)
      Stp1, Rtp1, done, _ = env.step(At)
      if diag:
        print(f"---t={t} St, At, Rt+1, St+1: {St, At, Rtp1, Stp1}")
      stats.episode_rewards[i] += Rtp1
      stats.episode_lengths[i] = t + 1  # t starts at 0, so t + 1 steps so far
      #---update V; the terminal state is never St, so its value stays 0
      V[St] += alpha*(Rtp1 + gamma*V[Stp1] - V[St])
      if diag:
        print(f"V[St]: {V[St]}")
      St = Stp1
      #---until St is terminal
      if done:
        break
    #---record the monitored state values at the end of the episode
    if monitored_states:
      for ms in monitored_states:
        monitored_state_values[ms].append(V[ms])
  return V, stats, monitored_state_values
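The function above depends on the WindyGridworldEnv class and the myplot module, which are not shown in this post. For experimenting without them, here is a self-contained sketch of the same TD(0) loop on a minimal re-implementation of the environment. The wind strengths, start and goal follow the Windy Gridworld example in Sutton and Barto and are assumptions, as are the names `step` and `td0_prediction_minimal`; the sketch collects only episode lengths and does no plotting.

```python
import random
from collections import defaultdict

# Assumed Windy Gridworld layout (Sutton & Barto); no Gym needed.
N_ROWS, N_COLS = 7, 10
WIND = [0, 0, 0, 1, 1, 1, 2, 2, 1, 0]       # upward push per column
START, GOAL = (3, 0), (3, 7)
MOVES = [(-1, 0), (0, 1), (1, 0), (0, -1)]  # up (0), right (1), down (2), left (3)


def step(state, action):
    """Return (next_state, reward, done); states are numbered row-major 0..69."""
    row, col = divmod(state, N_COLS)
    d_row, d_col = MOVES[action]
    # The wind of the departing column shifts the row upward; clip to the grid.
    row = min(max(row + d_row - WIND[col], 0), N_ROWS - 1)
    col = min(max(col + d_col, 0), N_COLS - 1)
    return row * N_COLS + col, -1.0, (row, col) == GOAL


def td0_prediction_minimal(n_episodes, alpha=0.5, gamma=1.0, seed=0):
    """Tabular TD(0) estimate of v_pi under the uniform random policy."""
    rng = random.Random(seed)
    V = defaultdict(float)  # V(s) = 0 initially; the goal is never updated
    episode_lengths = []
    for _ in range(n_episodes):
        s = START[0] * N_COLS + START[1]
        n_steps, done = 0, False
        while not done:
            a = rng.randrange(len(MOVES))        # uniform random policy
            s_next, r, done = step(s, a)
            v_next = 0.0 if done else V[s_next]  # terminal value is zero
            V[s] += alpha * (r + gamma * v_next - V[s])
            s = s_next
            n_steps += 1
        episode_lengths.append(n_steps)
    return V, episode_lengths


V_demo, lengths = td0_prediction_minimal(n_episodes=20, alpha=0.5)
print(len(V_demo), lengths[:5])
```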

5.4 Monitored states

Let’s pick a number of states to monitor. Each number identifies one of the 70 states.

monitored_states = [0, 7, 57, 68]
V,stats,monitored_state_values = td_0_prediction(
  env, 
  n_episodes=10, 
  alpha=0.1,
  monitored_states=monitored_states,  
  diag=False)
Episode 10/10
V
defaultdict(float,
            {0: -137.80033464618046,
             1: -142.5790573637218,
             2: -153.38627333959087,
             3: -164.31568002773105,
             4: -176.89251225878337,
             5: -183.914514998972,
             6: -191.20756834044525,
             7: -190.92399180138216,
             8: -188.57031090761552,
             9: -183.12953773836352,
             10: -129.34496871071303,
             11: -137.70509013845967,
             12: -148.2564075292973,
             13: -163.57355671774624,
             14: -168.52572248108007,
             15: -136.59594857510058,
             16: -47.71947159507175,
             17: -174.04428610385764,
             18: -184.93312959582892,
             19: -173.05808744316175,
             20: -120.6790406130493,
             21: -127.40417214467884,
             22: -134.03582218025838,
             23: -157.06904394086223,
             24: -140.3131307571433,
             25: -58.20714953578962,
             26: -54.10882697419318,
             27: -137.92427634130968,
             28: -173.99921218868272,
             29: -141.98995706524073,
             30: -101.70243257084873,
             31: -101.7270991868837,
             32: -120.64630789535269,
             33: -140.5876115761742,
             34: -88.27864427163075,
             35: -27.654676797073208,
             36: -0.1,
             37: 0.0,
             38: -144.11639148686803,
             39: -105.9540172086496,
             40: -73.97040001702122,
             41: -87.16453555102125,
             42: -94.14254268702543,
             43: -111.51260351392207,
             44: -44.34677679073506,
             45: -7.199784687087242,
             47: -26.46341299963032,
             48: -82.57726247834209,
             49: -57.66605324164337,
             50: -61.611925018068305,
             51: -68.47009398963326,
             52: -59.006804000112034,
             53: -51.41626761054323,
             54: -35.843329651729235,
             57: -10.303183460376614,
             58: -30.378465513899624,
             59: -25.517464691704166,
             60: -50.42665757427782,
             61: -51.07727460978276,
             62: -47.70493182941376,
             63: -44.756861753242674,
             68: -12.775760805886526,
             69: -4.497100056613684})
V[50]
-61.611925018068305
print(monitored_states[0])
print(monitored_state_values[monitored_states[0]])
0
[-21.474272908702716, -33.72302605938006, -42.47003230048048, -62.50296660467135, -67.12778730329649, -77.20250676785686, -76.99022419718405, -101.71774339987682, -106.62374028592815, -137.80033464618046]
# last value in monitored_states should be value in V
ms = monitored_states[0]; print('ms:', ms)
monitored_state_values[ms][-1], V[ms]
ms: 0
(-137.80033464618046, -137.80033464618046)

5.5 Run 1

V1,stats,monitored_state_values1 = td_0_prediction(
  env, 
  n_episodes=20, 
  alpha=0.5,
  monitored_states=monitored_states,
  diag=False)
Episode 20/20
# last value in monitored_states should be value in V
ms = monitored_states[0]; print('ms:', ms)
monitored_state_values1[ms][-1], V1[ms]
ms: 0
(-735.1459031454285, -735.1459031454285)

The following chart shows how the estimated values of the monitored states evolve from episode to episode:

plt.rcParams["figure.figsize"] = (18,10)
for ms in monitored_states:
  plt.plot(monitored_state_values1[ms])
# raw strings keep the backslash in \pi without an invalid-escape warning
plt.title(r'Estimated $v_\pi(s)$ for some states', fontsize=18)
plt.xlabel('Episodes', fontsize=16)
plt.ylabel(r'Estimated $v_\pi(s)$', fontsize=16)
plt.legend(monitored_states, fontsize=16)
plt.show()

Here are some additional metrics:

myplot.plot_episode_stats(stats);

To plot the state-value function, we map each state index to its (row, column) position in the 7 by 10 Windy Gridworld.

# convert V1 to V1p for plotting: key each value by its (row, col) position
states_shape = (7, 10)
nS = np.prod(states_shape)
V1p = {}
for s in range(nS):
  position = np.unravel_index(s, states_shape)
  # V1 is a defaultdict: a never-visited state reads (and is stored) as 0.0
  V1p[position] = V1[s]
V1p
{(0, 0): -735.1459031454285,
 (0, 1): -730.9210723890925,
 (0, 2): -748.20305259413,
 (0, 3): -754.9261025764922,
 (0, 4): -752.8787838171859,
 (0, 5): -750.0883665323886,
 (0, 6): -743.9117787359457,
 (0, 7): -686.3473071680427,
 (0, 8): -662.4342333214856,
 (0, 9): -657.5562656363784,
 (1, 0): -718.1124016107731,
 (1, 1): -731.8165301553454,
 (1, 2): -743.5612619011463,
 (1, 3): -749.0929971772642,
 (1, 4): -747.278666145852,
 (1, 5): -721.623343444135,
 (1, 6): -686.7081261934021,
 (1, 7): -724.3152993242904,
 (1, 8): -690.1043378911771,
 (1, 9): -629.1439695837937,
 (2, 0): -721.0189278360183,
 (2, 1): -726.3611748241806,
 (2, 2): -736.108040280805,
 (2, 3): -744.1853570649814,
 (2, 4): -727.6459474757083,
 (2, 5): -645.5264132122365,
 (2, 6): -658.855984783116,
 (2, 7): -698.0281822057283,
 (2, 8): -664.3786577812602,
 (2, 9): -524.8556599796428,
 (3, 0): -701.4558214932435,
 (3, 1): -701.3759709309548,
 (3, 2): -719.855350075502,
 (3, 3): -738.8169361136412,
 (3, 4): -724.440157261181,
 (3, 5): -576.7938789729867,
 (3, 6): -466.6315820284753,
 (3, 7): 0.0,
 (3, 8): -661.6487411402459,
 (3, 9): -470.79898477239567,
 (4, 0): -686.8139009088889,
 (4, 1): -681.9207135990542,
 (4, 2): -696.8290154104961,
 (4, 3): -726.154719019099,
 (4, 4): -686.6401005770437,
 (4, 5): -364.0468750239923,
 (4, 6): 0.0,
 (4, 7): -298.06036632267455,
 (4, 8): -284.00554599611735,
 (4, 9): -351.7826173921178,
 (5, 0): -671.362142111826,
 (5, 1): -659.5054173992531,
 (5, 2): -679.8915299579523,
 (5, 3): -691.7872475376071,
 (5, 4): -585.9983407933428,
 (5, 5): 0.0,
 (5, 6): 0.0,
 (5, 7): -246.14569695121958,
 (5, 8): -537.3196032010271,
 (5, 9): -369.68583680235787,
 (6, 0): -663.6650123721063,
 (6, 1): -663.3279431198509,
 (6, 2): -663.524783367199,
 (6, 3): -668.8126886203731,
 (6, 4): 0.0,
 (6, 5): 0.0,
 (6, 6): 0.0,
 (6, 7): 0.0,
 (6, 8): -55.79920769151681,
 (6, 9): -229.7809440998002}
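One caveat when reading these values: V1 is a defaultdict, so V1[s] returns (and stores) 0.0 for a state that was never visited. Under the TD(0) update with rewards of -1 and zero initial values, every state that is visited at least once ends with a strictly negative value, so the exact 0.0 entries above, apart from the terminal goal state, are cells the random policy never entered during this run rather than learned values. A sketch of a helper (the name `value_grid` is hypothetical) that marks such states as NaN instead; it should be applied to the dictionary returned by td_0_prediction before any V1[s] lookups fill in the gaps:

```python
import numpy as np


def value_grid(V, shape=(7, 10)):
    """Arrange state values row-major into a 2-D array for plotting.

    States missing from V (never visited) become NaN instead of 0.0, so they
    are not confused with the terminal goal state, whose value really is 0.
    """
    n_states = shape[0] * shape[1]
    values = [V.get(s, np.nan) for s in range(n_states)]  # .get does not insert
    return np.array(values, dtype=float).reshape(shape)
```

The resulting array can be passed directly to plt.imshow, which leaves NaN cells blank by default.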
myplot.plot_state_value_surface(V1p, title='State-Value Function for Windy Gridworld', wireframe=True, azim=-150, elev=60);

myplot.plot_state_value_surface(V1p, title='State-Value Function for Windy Gridworld', wireframe=False, azim=-150, elev=60);

myplot.plot_state_value_heatmap_windy_gridworld(V1p, title='State-Value Function for Windy Gridworld');

5.6 Run 2

V2,stats,monitored_state_values2 = td_0_prediction(
  env, 
  n_episodes=200,
  alpha=0.5,
  monitored_states=monitored_states,
  diag=False)
Episode 200/200
# last value in monitored_states should be value in V
ms = monitored_states[0]; print('ms:', ms)
monitored_state_values2[ms][-1], V2[ms]
ms: 0
(-4193.810401393425, -4193.810401393425)
plt.rcParams["figure.figsize"] = (18,10)
for ms in monitored_states:
  plt.plot(monitored_state_values2[ms])
# raw strings keep the backslash in \pi without an invalid-escape warning
plt.title(r'Estimated $v_\pi(s)$ for some states', fontsize=18)
plt.xlabel('Episodes', fontsize=16)
plt.ylabel(r'Estimated $v_\pi(s)$', fontsize=16)
plt.legend(monitored_states, fontsize=16)
plt.show()

Here are some additional metrics:

myplot.plot_episode_stats(stats);

# convert V2 to V2p for plotting: key each value by its (row, col) position
states_shape = (7, 10)
nS = np.prod(states_shape)
V2p = {}
for s in range(nS):
  position = np.unravel_index(s, states_shape)
  # V2 is a defaultdict: a never-visited state reads (and is stored) as 0.0
  V2p[position] = V2[s]
V2p
{(0, 0): -4193.810401393425,
 (0, 1): -4193.814757777407,
 (0, 2): -4191.890300247405,
 (0, 3): -4185.298602492931,
 (0, 4): -4174.72733085042,
 (0, 5): -4147.26308505477,
 (0, 6): -4118.407005550107,
 (0, 7): -4085.405156242255,
 (0, 8): -4071.2953094442705,
 (0, 9): -4046.459316882357,
 (1, 0): -4191.019217298803,
 (1, 1): -4191.6556502288895,
 (1, 2): -4190.4299296322715,
 (1, 3): -4190.965173415778,
 (1, 4): -4180.658164880097,
 (1, 5): -4126.107376489743,
 (1, 6): -3987.0554972487753,
 (1, 7): -4071.2527793806516,
 (1, 8): -4057.6179390117704,
 (1, 9): -4021.9726305212557,
 (2, 0): -4186.732672802048,
 (2, 1): -4186.617107436847,
 (2, 2): -4187.14920440839,
 (2, 3): -4187.742698915825,
 (2, 4): -4173.760780122282,
 (2, 5): -4151.302795294076,
 (2, 6): -4043.2356769698185,
 (2, 7): -4065.931040123559,
 (2, 8): -4060.995750405103,
 (2, 9): -3998.1373273491854,
 (3, 0): -4181.435865523515,
 (3, 1): -4183.7105431386135,
 (3, 2): -4185.254196149683,
 (3, 3): -4167.789277506801,
 (3, 4): -4137.248052519253,
 (3, 5): -4129.991220801143,
 (3, 6): -4041.255522160846,
 (3, 7): 0.0,
 (3, 8): -3992.2745542944663,
 (3, 9): -3555.425259222465,
 (4, 0): -4178.48614972832,
 (4, 1): -4177.297804829595,
 (4, 2): -4170.504863775563,
 (4, 3): -4173.944209869453,
 (4, 4): -4106.245000924509,
 (4, 5): -4125.774821555984,
 (4, 6): 0.0,
 (4, 7): -3836.9702498224224,
 (4, 8): -654.5525306357154,
 (4, 9): -2404.980895769819,
 (5, 0): -4176.458161376984,
 (5, 1): -4172.186748590077,
 (5, 2): -4174.260331451008,
 (5, 3): -4151.870491929965,
 (5, 4): -4129.774742716047,
 (5, 5): 0.0,
 (5, 6): 0.0,
 (5, 7): -3881.510590980134,
 (5, 8): -3668.8570806490907,
 (5, 9): -3423.782182477868,
 (6, 0): -4174.446152289609,
 (6, 1): -4164.759351585928,
 (6, 2): -4165.122295462905,
 (6, 3): -4158.603379427284,
 (6, 4): 0.0,
 (6, 5): 0.0,
 (6, 6): 0.0,
 (6, 7): 0.0,
 (6, 8): -3550.3751842154347,
 (6, 9): -3407.4561412815788}
myplot.plot_state_value_surface(V2p, title='State-Value Function for Windy Gridworld', wireframe=False, azim=-150, elev=60);

myplot.plot_state_value_heatmap_windy_gridworld(V2p, title='State-Value Function for Windy Gridworld');