Temporal Difference 2 – Estimation of the Action-value Function in Reinforcement Learning

Find action values under a given policy

Reinforcement Learning
TD
OpenAI Gym
Author

Kobus Esterhuysen

Published

February 2, 2022



1. Introduction

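In a Markov Decision Process (MDP, Figure 1), the agent and the environment interact continuously. At each time step \(t\), the agent observes a state \(S_t\) and selects an action \(A_t\). The environment responds with a reward \(R_{t+1}\) and a new state \(S_{t+1}\).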

Figure 1 Agent/Environment interaction in an MDP

More details are available in Reinforcement Learning: An Introduction by Sutton and Barto.

The dynamics of the MDP are given by \[ \begin{aligned} p(s',r|s,a) &= Pr\{ S_{t+1}=s',R_{t+1}=r | S_t=s,A_t=a \} \\ \end{aligned} \]

The policy of an agent maps each state of the environment to a probability distribution over the actions the agent can take in that state. Formally, a policy is given by \[ \begin{aligned} \pi(a|s) &= Pr\{A_t=a|S_t=s\} \end{aligned} \]

The discounted return is given by \[ \begin{aligned} G_t &= R_{t+1} + \gamma R_{t+2} + \gamma ^2 R_{t+3} + \cdots + \gamma^{T-t-1} R_T \\ &= \sum_{k=0}^{T-t-1} \gamma ^k R_{t+1+k} \end{aligned} \] where \(\gamma\) is the discount factor, \(R\) is the reward, and \(T\) is the final time step of the episode.
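The following minimal Python sketch illustrates the formula. It computes \(G_0\) for a single episode from its list of rewards, working backwards with the recursion \(G_t = R_{t+1} + \gamma G_{t+1}\). The name discounted_return is our own and does not come from the post. When \(\gamma=1\) and every step yields a reward of -1, as in the Windy Gridworld below, the return is simply minus the number of steps in the episode.

```python
def discounted_return(rewards, gamma=1.0):
    """Return G_0 for one episode, given its rewards [R_1, R_2, ..., R_T]."""
    g = 0.0
    # accumulate from the last reward backwards: G_t = R_{t+1} + gamma * G_{t+1}
    for r in reversed(rewards):
        g = r + gamma * g
    return g

print(discounted_return([-1, -1, -1, -1, -1]))       # undiscounted, 5 steps
print(discounted_return([-1, -1, -1], gamma=0.9))    # discounted, 3 steps
```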

Most reinforcement learning algorithms involve the estimation of value functions; in our present case, the action-value function. The action-value function maps each state-action pair to a measure of how good it is to take that action in that state, in terms of the expected return. Formally, the action-value function under policy \(\pi\) is given by \[ \begin{aligned} q_\pi(s,a) &= \mathbb{E}_\pi[G_t|S_t=s, A_t=a] \end{aligned} \]

The Temporal Difference (TD) algorithm discussed in this post will numerically estimate \(q_\pi(s,a)\).

2. Environment

Figure 2 shows the environment we will use in this series, the Windy Gridworld:

Figure 2 Windy Gridworld Environment

Each episode begins in the start state S. The agent tries to reach the goal state G in as few steps as possible. It can apply four movements, or actions, to the environment:

  • up (0)
  • right (1)
  • down (2)
  • left (3)

A crosswind blows upward through the middle of the grid, and the figure shows its strength for each of the center columns. Whenever the agent moves, the wind shifts it upward by the strength of the column it is leaving. For example, from the cell one to the right of the goal, the action left takes the agent to the cell just above the goal.

Tasks are episodic and undiscounted. All rewards are -1 until the goal is reached.

States are numbered from 0 to 69 in this 7 by 10 grid, in a row-wise fashion starting in the top-left corner.

The environment is implemented using the OpenAI Gym library.
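The wind rule is easier to see in a small sketch of a single transition that does not use Gym. The sketch assumes the layout of Example 6.5 in Sutton and Barto, which we take to match Figure 2. In that layout the wind strengths per column are 0, 0, 0, 1, 1, 1, 2, 2, 1, 0, and the goal is at row 3, column 7 (state 37). The names windy_step, to_state, and to_position are hypothetical and do not belong to the post's WindyGridworldEnv. Any move that would leave the grid is clipped to the border.

```python
# Hypothetical constants, assuming the Sutton & Barto (Example 6.5) layout
N_ROWS, N_COLS = 7, 10
WIND = [0, 0, 0, 1, 1, 1, 2, 2, 1, 0]        # upward push in each column
GOAL = 37                                    # cell (3, 7)
MOVES = [(-1, 0), (0, 1), (1, 0), (0, -1)]   # up, right, down, left as (d_row, d_col)

def to_state(row, col):
    """Row-wise state number, starting at 0 in the top-left corner."""
    return row * N_COLS + col

def to_position(state):
    """Inverse of to_state: state number -> (row, col)."""
    return divmod(state, N_COLS)

def windy_step(state, action):
    """Return (next_state, reward, done) for one move."""
    row, col = to_position(state)
    d_row, d_col = MOVES[action]
    # rows grow downward, so the upward wind of the departing column
    # is subtracted from the row index; clipping keeps the agent on the grid
    row = min(max(row + d_row - WIND[col], 0), N_ROWS - 1)
    col = min(max(col + d_col, 0), N_COLS - 1)
    next_state = to_state(row, col)
    return next_state, -1, next_state == GOAL

# One cell to the right of the goal, action left (3):
print(windy_step(to_state(3, 8), 3))   # (27, -1, False)
```

The printed transition reproduces the example in the text. Starting from state 38, one cell to the right of the goal, the action left takes the agent to state 27, directly above the goal.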

3. Agent

The agent is the decision maker. It must choose actions that reach the goal in as few steps as possible while compensating for the effect of the wind. The agent observes the state of the environment as a number from 0 to 69 that encodes its current position in the grid. It can then take one of the following actions:

  • up (0)
  • right (1)
  • down (2)
  • left (3)

4. Temporal Difference (TD) Estimation of the Action-value Function, \(q_\pi(s,a)\)

We now estimate the action-value function for the given policy \(\pi\). The task is episodic, so every episode ends after a finite number of steps. The return is therefore a finite sum, and we can take \(\gamma=1\).

In this case the TD method will make the following update:

\[ \large Q(S_t,A_t) \leftarrow Q(S_t,A_t) + \alpha\left[R_{t+1} + \gamma Q(S_{t+1},A_{t+1}) - Q(S_t,A_t)\right] \]

The target in this update is:

\[ \large R_{t+1} + \gamma Q(S_{t+1},A_{t+1}) \]

Algorithms making use of this kind of update rule are known as SARSA algorithms. This acronym derives from the fact that the following five elements are needed for an update:

\[ \large S_t, A_t, R_{t+1}, S_{t+1}, A_{t+1} \]
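Here is one update worked through with numbers. In this sketch the table Q is a NumPy array with one row per state and one column per action. The estimate \(Q(31, \text{right}) = -2\) is a hypothetical value chosen only for illustration. With \(\alpha=0.5\), \(\gamma=1\), and \(R_{t+1}=-1\), the target is \(-1 + (-2) = -3\). \(Q(30, \text{right})\) therefore moves halfway from 0 towards -3, ending at -1.5. Following Sutton and Barto, the sketch takes the value of a terminal state to be 0. The function name sarsa_update is hypothetical.

```python
import numpy as np

def sarsa_update(Q, s, a, r, s_next, a_next, alpha=0.5, gamma=1.0, done=False):
    """Apply one TD (Sarsa) update to the table Q (n_states x n_actions)."""
    # the value of a terminal state is 0, so at episode end the target is just r
    target = r if done else r + gamma * Q[s_next, a_next]
    Q[s, a] += alpha * (target - Q[s, a])
    return Q[s, a]

Q = np.zeros((70, 4))      # 70 states, 4 actions
Q[31, 1] = -2.0            # hypothetical current estimate for (S_{t+1}, A_{t+1})
new_value = sarsa_update(Q, s=30, a=1, r=-1.0, s_next=31, a_next=1)
print(new_value)
```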

5. Implementation

We implement the TD prediction algorithm in its Sarsa form, which estimates \(q_\pi(s,a)\) for a fixed policy \(\pi\).

Figure 3 shows the algorithm:

Figure 3 TD Prediction (Sarsa), for estimating the action-value function

Next, we present the code that implements the algorithm.

import gym
import matplotlib
import numpy as np
import sys
from collections import defaultdict
import pprint as pp
from matplotlib import pyplot as plt
%matplotlib inline
import itertools
from matplotlib import cm, colors
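# WindyGridworldEnv and the plotting module myplot (EpisodeStats and the
# plot_* helpers) are not shown in this post; they must be defined or
# imported before the code below is run.
env = WindyGridworldEnv()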

5.1 Policy

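The agent follows the equiprobable random policy: in every state, each of the n_A actions is selected with probability 1/n_A. The following function creates this policy: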

def create_policy_random(n_A):
  def policy_function(observation):
    #probabilities for each action, length n_A:
    action_probs = np.ones(n_A, dtype=float)/n_A
    return np.random.choice(np.arange(len(action_probs)), p=action_probs)
  return policy_function      
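One way to confirm that this policy is equiprobable is to sample it many times and inspect the action frequencies. The sketch below repeats create_policy_random so that it can run by itself. The seed and the sample size are arbitrary choices. The observation argument is ignored, because the random policy does not depend on the state.

```python
import numpy as np

def create_policy_random(n_A):
    # same equiprobable policy as above: every action has probability 1/n_A
    def policy_function(observation):
        action_probs = np.ones(n_A, dtype=float) / n_A
        return np.random.choice(np.arange(len(action_probs)), p=action_probs)
    return policy_function

np.random.seed(0)                       # fixed seed for a reproducible check
pi = create_policy_random(4)
samples = [pi(observation=0) for _ in range(10_000)]
freqs = np.bincount(samples, minlength=4) / len(samples)
print(freqs)                            # each entry should be close to 0.25
```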

5.3 Main loop

The following function implements the main loop of the algorithm and iterates for n_episodes. It also accepts a list, monitored_state_actions, and records the action value of each of these state-actions at the end of every episode. These records are useful for showing how the action-value estimates evolve during learning.

def td_prediction_sarsa(env, n_episodes, gamma=1.0, alpha=0.5, epsilon=0.1, monitored_state_actions=None, diag=False):
  # epsilon is not used: actions always come from the equiprobable random policy
  Q = defaultdict(lambda: np.zeros(env.action_space.n))
  pi = create_policy_random(env.action_space.n)
  monitored_state_action_values = defaultdict(list)
  stats = myplot.EpisodeStats(
    episode_lengths=np.zeros(n_episodes),
    episode_rewards=np.zeros(n_episodes))
  for i in range(n_episodes):
    if (i + 1) % 10 == 0:
      print("\rEpisode {}/{}".format(i + 1, n_episodes), end=""); sys.stdout.flush()
    if diag: print(f'\nepisode {i}:')
    #---initialize St
    St = env.reset()
    #---choose At from St, using given policy
    At = pi(St)
    #---repeat for each step of episode
    for t in itertools.count(): #while True:
      #---take action At, observe Rt+1, St+1
      Stp1, Rtp1, done, _ = env.step(At) # St+1, Rt+1 OR s',r
      #---choose At+1 from St+1, using given policy
      Atp1 = pi(Stp1)
      if diag: print(f"---t={t} St, At, Rt+1, St+1, At+1: {St, At, Rtp1, Stp1, Atp1}")
      stats.episode_rewards[i] += Rtp1
      stats.episode_lengths[i] = t + 1 # steps taken so far (t starts at 0)
      #---update Q; the terminal state never appears as St, so its Q stays 0
      #---and the target at the end of an episode reduces to Rt+1
      Q[St][At] = Q[St][At] + alpha*(Rtp1 + gamma*Q[Stp1][Atp1] - Q[St][At])
      if diag: print(f"Q[St][At]: {Q[St][At]}")
      St = Stp1; At = Atp1
      #---until St is terminal
      if done:
        break
    #---record the monitored action values at the end of the episode
    if monitored_state_actions:
      for msa in monitored_state_actions:
        s, a = msa
        monitored_state_action_values[msa].append(Q[s][a])
  return Q, stats, monitored_state_action_values
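The loop above depends on WindyGridworldEnv and myplot, which this post does not list, so it cannot run on its own. The following is a minimal sketch of the same TD (Sarsa) prediction loop without Gym. It assumes the Sutton and Barto layout of the Windy Gridworld: start at state 30, goal at state 37, and wind strengths 0, 0, 0, 1, 1, 1, 2, 2, 1, 0 per column. All function and variable names are hypothetical. The random policy is sampled with a NumPy Generator rather than np.random.choice, and the episode statistics are left out.

```python
import numpy as np
from collections import defaultdict

# Assumed Windy Gridworld layout (Sutton & Barto, Example 6.5)
N_ROWS, N_COLS = 7, 10
WIND = [0, 0, 0, 1, 1, 1, 2, 2, 1, 0]        # upward push per column
START, GOAL = 30, 37                         # cells (3, 0) and (3, 7)
MOVES = [(-1, 0), (0, 1), (1, 0), (0, -1)]   # up, right, down, left

def step(state, action):
    """Return (next_state, reward, done); the departing column's wind applies."""
    row, col = divmod(state, N_COLS)
    d_row, d_col = MOVES[action]
    row = min(max(row + d_row - WIND[col], 0), N_ROWS - 1)
    col = min(max(col + d_col, 0), N_COLS - 1)
    next_state = row * N_COLS + col
    return next_state, -1.0, next_state == GOAL

def td_prediction_sarsa_sketch(n_episodes, gamma=1.0, alpha=0.5,
                               monitored=(), seed=0):
    """Estimate q_pi for the equiprobable random policy with Sarsa updates."""
    rng = np.random.default_rng(seed)
    n_actions = len(MOVES)
    Q = defaultdict(lambda: np.zeros(n_actions))
    history = defaultdict(list)      # (s, a) -> estimate after each episode
    for _ in range(n_episodes):
        s = START
        a = int(rng.integers(n_actions))            # A_t ~ pi(.|S_t)
        done = False
        while not done:
            s_next, r, done = step(s, a)
            a_next = int(rng.integers(n_actions))   # A_{t+1} ~ pi(.|S_{t+1})
            # the terminal state's value is 0, so the final target is just r
            q_next = 0.0 if done else Q[s_next][a_next]
            Q[s][a] += alpha * (r + gamma * q_next - Q[s][a])
            s, a = s_next, a_next
        for (ms, ma) in monitored:
            history[(ms, ma)].append(Q[ms][ma])
    return Q, history

Q, history = td_prediction_sarsa_sketch(20, monitored=[(0, 1), (57, 2)])
print(len(history[(0, 1)]), Q[START])
```

The history dictionary plays the role of monitored_state_action_values and can be plotted in the same way.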

5.4 Monitored state-actions

Let’s pick a number of state-actions to monitor. Each tuple holds a state number (0 to 69) and an action (0 up, 1 right, 2 down, 3 left).

monitored_state_actions = [(0, 1), (7, 3), (57, 2), (68, 0)]
Q,stats,monitored_state_action_values = td_prediction_sarsa(
  env, 
  n_episodes=10, 
  alpha=0.5,
  monitored_state_actions=monitored_state_actions,  
  diag=False)
Episode 10/10
Q
defaultdict(<function __main__.td_prediction_sarsa.<locals>.<lambda>>,
            {0: array([-130.21765987, -143.5432508 , -119.89485328, -133.88642559]),
             1: array([-143.25586199, -144.70357216, -141.7840175 , -137.20258816]),
             2: array([-151.44191022, -170.21283604, -148.33133697, -137.08449294]),
             3: array([-164.76415063, -174.77154891, -154.4987116 , -154.79278312]),
             4: array([-181.62326961, -184.40197609, -183.45675144, -164.50839718]),
             5: array([-188.44514126, -189.42454014, -187.66972352, -181.48821112]),
             6: array([-187.85897908, -188.70534129, -188.48478915, -187.04432567]),
             7: array([-188.82039454, -182.29671548, -188.23317708, -187.92342654]),
             8: array([-182.68005963, -178.9240007 , -179.33804605, -185.96499726]),
             9: array([-168.6376302 , -183.74847608, -133.41859813, -184.85356041]),
             10: array([-128.70861097, -139.0681346 , -129.06622311, -133.00609635]),
             11: array([-147.7567565 , -150.46251265, -127.52902579, -136.7273636 ]),
             12: array([-154.96216268, -162.6326481 , -144.9227521 , -131.21919519]),
             13: array([-163.08944615, -173.89644405, -166.61877017, -160.52108247]),
             14: array([-164.61709914, -171.63120901, -158.84578362, -157.88680435]),
             15: array([-149.3642964 , -120.16430885,  -45.07792077,  -76.78454458]),
             16: array([ -78.944795  ,  -76.16542744, -129.04716579,  -83.25377903]),
             17: array([-137.36657407, -150.9882114 , -177.37220363, -149.10277867]),
             18: array([-174.91305027, -177.74871199, -178.44852154, -182.10046244]),
             19: array([-177.16533324, -173.50184903, -111.27057527, -177.47499287]),
             20: array([-136.26843505, -123.14952555, -102.40774147, -130.05535791]),
             21: array([-129.22470755, -122.7225615 , -101.26803638, -125.00831896]),
             22: array([-150.91532888, -158.106315  , -116.53872736, -118.62548079]),
             23: array([-161.47142422, -157.80460354, -154.00444374, -137.13353016]),
             24: array([-154.85927344,  -92.34942078,  -79.50755276, -140.07444859]),
             25: array([-122.8983022 ,  -51.73995068,  -13.65733695, -101.61408625]),
             26: array([  0.        , -30.74928712, -13.86143982, -50.15577582]),
             27: array([-169.10620314, -149.19676308, -127.43972177,  -78.25542231]),
             28: array([-167.08030531, -175.43875759, -159.11300754, -169.19508907]),
             29: array([-166.70837493, -165.29801628,  -67.28195202, -166.57272592]),
             30: array([-129.51881087, -119.69581108,  -98.29786383,  -99.70238096]),
             31: array([-132.31577688, -116.27862717,  -87.33164007, -109.87119308]),
             32: array([-140.37953922, -139.64172542, -104.17113408, -121.47213232]),
             33: array([-157.20389098, -125.11021188, -142.78629176, -132.27103827]),
             34: array([-138.14280894,  -35.42179794,  -64.62854511, -133.03154987]),
             35: array([ 0.  , -0.75, -1.  ,  0.  ]),
             36: array([  0.        , -24.19502755,  -0.5       ,   0.        ]),
             37: array([0., 0., 0., 0.]),
             38: array([-154.79212798, -154.1070795 , -120.93259191, -133.73185606]),
             39: array([-139.2662964 ,  -99.6882193 ,  -15.97640896, -138.75404355]),
             40: array([ -99.54789193, -102.22702182,  -87.40372696,  -88.40134742]),
             41: array([-101.98254871,  -98.38565691,  -69.37589338,  -94.84172441]),
             42: array([-123.37304589, -118.0907383 ,  -76.02251519,  -92.93773595]),
             43: array([-143.24224429, -101.64108766,  -83.57553191, -107.27090594]),
             44: array([-44.25477232,  -1.        , -16.57231825, -80.95996634]),
             45: array([-16.93394358,  -0.75      ,   0.        , -22.09487522]),
             47: array([ -6.31149394, -79.30964076,  -0.5       ,   0.        ]),
             48: array([-90.62522153, -84.38265235,  -2.4569325 ,  -0.99804688]),
             49: array([-83.84590394, -20.69362764, -16.20883841, -10.86993516]),
             50: array([-107.01682414,  -95.61187957,  -66.11468079,  -80.64601024]),
             51: array([-99.75658923, -68.41114379, -76.37980858, -65.42847011]),
             52: array([-91.97907049, -75.16249195, -71.64981997, -93.51761803]),
             53: array([-111.96337608,  -20.58037682,  -63.18864706,  -95.77321694]),
             54: array([-57.40446941,  -3.81152555,  -6.60539235, -53.90241747]),
             57: array([-17.49193515,   0.        ,   0.        ,   0.        ]),
             58: array([-63.44279336, -28.64442117, -17.38872588,  -0.875     ]),
             59: array([-25.35226593,  -7.61031499,  -3.00378394,  -5.31714515]),
             60: array([-96.13571125, -65.74693201, -68.28867688, -71.9065069 ]),
             61: array([-82.16606589, -55.63741037, -64.05846489, -72.63316174]),
             62: array([-66.96678607, -42.29627695, -68.47674789, -62.88820874]),
             63: array([-122.38559097,  -14.33604434,  -12.47281612,  -48.94850071]),
             68: array([ 0.  , -0.75, -0.5 , -0.5 ]),
             69: array([-28.60080249,  -2.66784644, -14.90611466,  -0.75      ])})
Q[50]
array([-107.01682414,  -95.61187957,  -66.11468079,  -80.64601024])
Q[7][0], Q[7][1], Q[7][2], Q[7][3]
(-188.82039454324692,
 -182.2967154789809,
 -188.23317708192786,
 -187.92342653635168)
print(monitored_state_actions[0])
print(monitored_state_action_values[monitored_state_actions[0]])
(0, 1)
[-9.675669392355946, -15.124011334242805, -26.71110515544056, -38.7291227291568, -76.69174707511156, -73.38136859322967, -73.38136859322967, -121.03548085187168, -142.42781434742943, -143.54325080210805]
# the last recorded value of a monitored state-action should equal its final entry in Q
msa = monitored_state_actions[0]; print('msa:', msa)
s = msa[0]; print('s:', s)
a = msa[1]; print('a:', a)
monitored_state_action_values[msa][-1], Q[s][a] # the record is keyed by the (s, a) tuple; Q is indexed as Q[s][a]
msa: (0, 1)
s: 0
a: 1
(-143.54325080210805, -143.54325080210805)

5.5 Run 1

Q1,stats,monitored_state_action_values1 = td_prediction_sarsa(
  env, 
  n_episodes=20, 
  alpha=0.5,
  monitored_state_actions=monitored_state_actions,  
  diag=False)
Episode 20/20
# the last recorded value of a monitored state-action should equal its final entry in Q1
msa = monitored_state_actions[0]; print('msa:', msa)
s = msa[0]; print('s:', s)
a = msa[1]; print('a:', a)
monitored_state_action_values1[msa][-1], Q1[s][a] # the record is keyed by the (s, a) tuple; Q1 is indexed as Q1[s][a]
msa: (0, 1)
s: 0
a: 1
(-349.70479826983944, -349.70479826983944)

5.5.1 Monitored state-actions

The following chart shows how the estimated values of the 4 monitored state-actions evolve over the episodes:

plt.rcParams["figure.figsize"] = (18,10)
for msa in monitored_state_actions:
  plt.plot(monitored_state_action_values1[msa])
plt.title(r'Estimated $q_\pi(s,a)$ for some state-actions', fontsize=18)
plt.xlabel('Episodes', fontsize=16)
plt.ylabel(r'Estimated $q_\pi(s,a)$', fontsize=16)
plt.legend(monitored_state_actions, fontsize=16)
plt.show()

5.5.2 Other metrics

Here are some additional per-episode metrics, based on the episode lengths and rewards recorded in stats:

myplot.plot_episode_stats(stats);

5.5.3 State-value function

To plot the state-value function, we first derive state values from the action-value function. For each state we take the largest action value, \(\max_a Q(s,a)\):

# create state-value function from action-value function 
V1 = defaultdict(float)
for state, actions in Q1.items():
    action_value = np.max(actions)
    V1[state] = action_value
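The maximum over actions, \(\max_a Q(s,a)\), estimates the value of taking the best action once and following \(\pi\) thereafter. It is not \(v_\pi(s)\) for the random policy itself. Under the equiprobable random policy, \(v_\pi(s) = \sum_a \pi(a|s)\, q_\pi(s,a)\), which is simply the mean of the four action values. The sketch below computes both quantities from a dictionary of action-value arrays. As an example it uses the values printed for state 50 after the first run in Section 5.4. The function name state_values is hypothetical.

```python
import numpy as np
from collections import defaultdict

def state_values(Q, policy_probs=None):
    """Return (v_pi, v_greedy) from action values Q[s] -> array of q(s, a).

    v_pi uses v_pi(s) = sum_a pi(a|s) q(s, a); by default pi is equiprobable.
    v_greedy uses max_a q(s, a), the quantity plotted in this post.
    """
    v_pi, v_greedy = defaultdict(float), defaultdict(float)
    for s, q in Q.items():
        probs = (np.full(len(q), 1.0 / len(q)) if policy_probs is None
                 else policy_probs(s))
        v_pi[s] = float(np.dot(probs, q))
        v_greedy[s] = float(np.max(q))
    return v_pi, v_greedy

# Action values printed for state 50 in Section 5.4
Q_demo = {50: np.array([-107.01682414, -95.61187957, -66.11468079, -80.64601024])}
v_pi, v_greedy = state_values(Q_demo)
print(v_pi[50], v_greedy[50])
```

For a policy other than the random one, pass a function that returns the action probabilities for a given state as policy_probs.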

To plot the state-value function, we map each state number to its (row, column) position in the 7 by 10 Windy Gridworld grid.

# convert V1 to V1p for plotting
states_shape = (7, 10)
nS = np.prod(states_shape)
V1p = {}
for s in range(nS):
  position = np.unravel_index(s, states_shape); #print(f"position: {position}")
  V1p[position] = V1[s]
V1p
{(0, 0): -327.6071922498886,
 (0, 1): -323.86069090792967,
 (0, 2): -328.85814327332844,
 (0, 3): -344.9640026141235,
 (0, 4): -367.7745154409497,
 (0, 5): -371.3770111178868,
 (0, 6): -375.48137244500117,
 (0, 7): -366.93211438939494,
 (0, 8): -361.72724799049274,
 (0, 9): -351.35391456286993,
 (1, 0): -320.53422131352875,
 (1, 1): -325.2242150311729,
 (1, 2): -324.239772198814,
 (1, 3): -350.04577508652443,
 (1, 4): -361.45496680515873,
 (1, 5): -286.1746103140105,
 (1, 6): -109.44140511014332,
 (1, 7): -360.94038190900784,
 (1, 8): -348.67781390674645,
 (1, 9): -290.36373655105893,
 (2, 0): -302.1787655552366,
 (2, 1): -296.58457760037595,
 (2, 2): -316.05581872302525,
 (2, 3): -344.02910502918513,
 (2, 4): -298.41429286385596,
 (2, 5): -153.72312343131915,
 (2, 6): 0.0,
 (2, 7): -295.20593059832555,
 (2, 8): -355.92922004811703,
 (2, 9): -184.34839030676795,
 (3, 0): -290.92803360750634,
 (3, 1): -282.7727544690155,
 (3, 2): -282.5124479431155,
 (3, 3): -332.95139123228034,
 (3, 4): -164.1041169313297,
 (3, 5): -47.48642145752244,
 (3, 6): 0.0,
 (3, 7): 0.0,
 (3, 8): -306.0836033845137,
 (3, 9): -92.8672806085619,
 (4, 0): -268.4850531128284,
 (4, 1): -268.4262610494959,
 (4, 2): -270.63921862328755,
 (4, 3): -296.94096824805774,
 (4, 4): -136.134255227891,
 (4, 5): -8.21472230878848,
 (4, 6): 0.0,
 (4, 7): 0.0,
 (4, 8): -0.9999990463256836,
 (4, 9): -24.033126624036612,
 (5, 0): -264.4047862316787,
 (5, 1): -255.932044558136,
 (5, 2): -261.9487198201093,
 (5, 3): -253.26877938647434,
 (5, 4): -59.39333489551983,
 (5, 5): 0.0,
 (5, 6): 0.0,
 (5, 7): 0.0,
 (5, 8): -140.81414257329675,
 (5, 9): -72.23483755767484,
 (6, 0): -252.9062288835798,
 (6, 1): -251.5652242952258,
 (6, 2): -254.9386564133718,
 (6, 3): -127.75548151219893,
 (6, 4): 0.0,
 (6, 5): 0.0,
 (6, 6): 0.0,
 (6, 7): 0.0,
 (6, 8): -0.875,
 (6, 9): -39.643911746737146}
myplot.plot_state_value_surface(V1p, title='State-Value Function for Windy Gridworld', wireframe=True, azim=-150, elev=60);

myplot.plot_state_value_surface(V1p, title='State-Value Function for Windy Gridworld', wireframe=False, azim=-150, elev=60);

myplot.plot_state_value_heatmap_windy_gridworld(V1p, title='State-Value Function for Windy Gridworld');

5.5.4 Policy function

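For each state we select the action with the largest estimated action value, \(\arg\max_a Q(s,a)\). This yields the greedy policy with respect to the current estimates.

# create policy function from action-value function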
P1 = defaultdict(float)
for state, actions in Q1.items():
    action = np.argmax(actions)
    P1[state] = action
# convert P1 to P1p for plotting
states_shape = (7, 10)
nS = np.prod(states_shape)
P1p = {}
for s in range(nS):
  position = np.unravel_index(s, states_shape); #print(f"position: {position}")
  P1p[position] = P1[s]
myplot.plot_policy_windy_gridworld(P1p, title='Policy Function for Windy Gridworld');
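The policy obtained this way is greedy with respect to the estimated action values. By the policy improvement theorem, a policy that is greedy with respect to the exact \(q_\pi\) is at least as good as \(\pi\). The code above has two behaviors worth noting. First, states that were never visited are absent from Q1, so P1 (a defaultdict(float)) assigns them action 0 (up). Second, np.argmax breaks ties in favor of the lowest action index. The hypothetical helper below prints a greedy policy as a text grid of arrows. It marks unvisited states with '.' and the goal, assumed here to be state 37, with 'G'.

```python
import numpy as np

ARROWS = "↑→↓←"  # actions 0, 1, 2, 3 = up, right, down, left

def greedy_policy_grid(Q, shape=(7, 10), goal=37):
    """Render arg max_a Q(s, a) for every cell as a grid of arrows."""
    n_rows, n_cols = shape
    lines = []
    for row in range(n_rows):
        cells = []
        for col in range(n_cols):
            s = row * n_cols + col          # row-wise state number
            if s == goal:
                cells.append("G")
            elif s in Q:                    # 'in' does not create defaultdict keys
                cells.append(ARROWS[int(np.argmax(Q[s]))])
            else:
                cells.append(".")           # never visited, no estimate
        lines.append("".join(cells))
    return "\n".join(lines)

# Hypothetical action values for two states, for illustration only
Q_demo = {0: np.array([-3.0, -1.0, -2.0, -4.0]),
          48: np.array([-5.0, -4.0, -3.0, -1.0])}
print(greedy_policy_grid(Q_demo))
```

Calling it as print(greedy_policy_grid(Q1)) should give a quick text view of the same policy that is plotted above.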

5.6 Run 2

Q2,stats,monitored_state_action_values2 = td_prediction_sarsa(
    env, 
    n_episodes=200, 
    alpha=0.5,    
    monitored_state_actions=monitored_state_actions,
    diag=False)
Episode 200/200
# the last recorded value of a monitored state-action should equal its final entry in Q2
msa = monitored_state_actions[0]; print('msa:', msa)
s = msa[0]; print('s:', s)
a = msa[1]; print('a:', a)
monitored_state_action_values2[msa][-1], Q2[s][a] # the record is keyed by the (s, a) tuple; Q2 is indexed as Q2[s][a]
msa: (0, 1)
s: 0
a: 1
(-1977.1839173706244, -1977.1839173706244)

5.6.1 Monitored state-actions

The following chart shows how the estimated values of the monitored state-actions evolve over the 200 episodes:

plt.rcParams["figure.figsize"] = (18,10)
for msa in monitored_state_actions:
  plt.plot(monitored_state_action_values2[msa])
plt.title(r'Estimated $q_\pi(s,a)$ for some state-actions', fontsize=18)
plt.xlabel('Episodes', fontsize=16)
plt.ylabel(r'Estimated $q_\pi(s,a)$', fontsize=16)
plt.legend(monitored_state_actions, fontsize=16)
plt.show()

5.6.2 Other metrics

Here are some additional per-episode metrics, based on the episode lengths and rewards recorded in stats:

myplot.plot_episode_stats(stats);

5.6.3 State-value function

As before, we derive the state values from the action values by taking the maximum over actions. We then map each state number to its (row, column) position in the grid.

# create state-value function from action-value function 
V2 = defaultdict(float)
for state, actions in Q2.items():
    action_value = np.max(actions)
    V2[state] = action_value
# convert V2 to V2p for plotting
states_shape = (7, 10)
nS = np.prod(states_shape)
V2p = {}
for s in range(nS):
  position = np.unravel_index(s, states_shape); #print(f"position: {position}")
  V2p[position] = V2[s]
V2p
{(0, 0): -1976.6067124687627,
 (0, 1): -1977.0994015067404,
 (0, 2): -1972.8905470947368,
 (0, 3): -1975.9459020505924,
 (0, 4): -1956.4970137697956,
 (0, 5): -1961.3755964291297,
 (0, 6): -1939.7875789658233,
 (0, 7): -1907.4095018879138,
 (0, 8): -1866.1975642206223,
 (0, 9): -1861.7647547127203,
 (1, 0): -1971.2271431010415,
 (1, 1): -1970.6705744665903,
 (1, 2): -1968.5256728794718,
 (1, 3): -1977.9058270219132,
 (1, 4): -1967.3140943562007,
 (1, 5): -1920.1807571550594,
 (1, 6): -1703.3399089623265,
 (1, 7): -1901.4968212332117,
 (1, 8): -1905.9896006723643,
 (1, 9): -1873.6567681088377,
 (2, 0): -1960.6615239439993,
 (2, 1): -1953.1921260894733,
 (2, 2): -1967.237191187847,
 (2, 3): -1972.6071747637784,
 (2, 4): -1945.9910533887498,
 (2, 5): -1822.6856722595462,
 (2, 6): -1250.4065633273499,
 (2, 7): -1891.0524309422026,
 (2, 8): -1884.284928372635,
 (2, 9): -1850.3809755131779,
 (3, 0): -1949.4001251122963,
 (3, 1): -1943.796407642659,
 (3, 2): -1950.1484344415687,
 (3, 3): -1962.4487740979143,
 (3, 4): -1923.5071557185,
 (3, 5): -1464.773115720776,
 (3, 6): -917.2905366213927,
 (3, 7): 0.0,
 (3, 8): -1788.6565982098004,
 (3, 9): -1050.423941819396,
 (4, 0): -1933.41404031188,
 (4, 1): -1935.0016460831494,
 (4, 2): -1926.0721447363062,
 (4, 3): -1943.2236150470028,
 (4, 4): -1839.2678704754803,
 (4, 5): -746.3710223675707,
 (4, 6): 0.0,
 (4, 7): -0.9999997615814209,
 (4, 8): -1.0,
 (4, 9): -407.27969256197025,
 (5, 0): -1926.2015230746654,
 (5, 1): -1928.4798241719166,
 (5, 2): -1919.2685658757823,
 (5, 3): -1879.5951191937493,
 (5, 4): -1706.369764419835,
 (5, 5): 0.0,
 (5, 6): 0.0,
 (5, 7): -370.5368901541118,
 (5, 8): -1411.688140001559,
 (5, 9): -1065.7778767656464,
 (6, 0): -1925.2053053323307,
 (6, 1): -1915.3297274089157,
 (6, 2): -1924.8061736735706,
 (6, 3): -1885.2362140996274,
 (6, 4): 0.0,
 (6, 5): 0.0,
 (6, 6): 0.0,
 (6, 7): 0.0,
 (6, 8): -396.9539138131017,
 (6, 9): -1013.9091447979373}
myplot.plot_state_value_surface(V2p, title='State-Value Function for Windy Gridworld', wireframe=False, azim=-150, elev=60);

myplot.plot_state_value_heatmap_windy_gridworld(V2p, title='State-Value Function for Windy Gridworld');

5.6.4 Policy function

# create policy function from action-value function 
P2 = defaultdict(float)
for state, actions in Q2.items():
    action = np.argmax(actions)
    P2[state] = action
# convert P2 to P2p for plotting
states_shape = (7, 10)
nS = np.prod(states_shape)
P2p = {}
for s in range(nS):
  position = np.unravel_index(s, states_shape); #print(f"position: {position}")
  P2p[position] = P2[s]
# myplot.plot_state_value_heatmap_windy_gridworld(P2p, title='Policy Function for Windy Gridworld');
myplot.plot_policy_windy_gridworld(P2p, title='Policy Function for Windy Gridworld');