Control of a Cart-Pole Dynamic System with Ray RLlib

Reinforcement Learning (RL) to control the balancing of a pole on a moving cart

Control
Reinforcement Learning
Ray RLlib
Author

Kobus Esterhuysen

Published

November 29, 2021

1. Introduction

The cart-pole problem can be considered the “Hello World” problem of Reinforcement Learning (RL). It was described by Barto, Sutton, and Anderson (1983). The physics of the system is as follows:

  • All motion happens in a vertical plane
  • A hinged pole is attached to a cart
  • The cart slides horizontally on a track in an effort to balance the pole vertically
  • The system has four state variables:

\(x\): displacement of the cart

\(\theta\): angle of the pole from the vertical

\(\dot{x}\): velocity of the cart

\(\dot{\theta}\): angular velocity of the pole

Here is a graphical representation of the system:

Cartpole environment
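
For reference, here is a minimal sketch (assuming Gym’s standard “CartPole-v1” observation layout of \([x, \dot{x}, \theta, \dot{\theta}]\)) that maps a raw observation onto these four state variables:

import gym
# assumption: Gym orders the CartPole observation as [x, x_dot, theta, theta_dot]
env = gym.make("CartPole-v1")
x, x_dot, theta, theta_dot = env.reset()
print(f"cart position x        = {x:+.4f}")
print(f"cart velocity x_dot    = {x_dot:+.4f}")
print(f"pole angle theta       = {theta:+.4f} rad")
print(f"angular vel. theta_dot = {theta_dot:+.4f} rad/s")
env.close()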

2. Purpose

The purpose of this blog post is to construct and train an entity, which we will call a controller, that manages the horizontal motion of the cart so that the pole stays as close to vertical as possible. The controlled entity is, of course, the cart-and-pole system.

3. RLlib Setup

We will use the Ray RLlib framework. In addition, this notebook will be run in Google Colab.

!pip install ray[rllib]
Collecting ray[rllib]
  Downloading ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl (54.7 MB)
     |████████████████████████████████| 54.7 MB 162 kB/s 
Requirement already satisfied: grpcio>=1.28.1 in /usr/local/lib/python3.7/dist-packages (from ray[rllib]) (1.42.0)
Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.7/dist-packages (from ray[rllib]) (7.1.2)
Requirement already satisfied: protobuf>=3.15.3 in /usr/local/lib/python3.7/dist-packages (from ray[rllib]) (3.17.3)
Requirement already satisfied: attrs in /usr/local/lib/python3.7/dist-packages (from ray[rllib]) (21.2.0)
Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from ray[rllib]) (1.0.2)
Collecting redis>=3.5.0
  Downloading redis-4.0.2-py3-none-any.whl (119 kB)
     |████████████████████████████████| 119 kB 46.0 MB/s 
Requirement already satisfied: jsonschema in /usr/local/lib/python3.7/dist-packages (from ray[rllib]) (2.6.0)
Requirement already satisfied: numpy>=1.16 in /usr/local/lib/python3.7/dist-packages (from ray[rllib]) (1.19.5)
Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from ray[rllib]) (3.13)
Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from ray[rllib]) (3.4.0)
Requirement already satisfied: scikit-image in /usr/local/lib/python3.7/dist-packages (from ray[rllib]) (0.18.3)
Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from ray[rllib]) (1.1.5)
Requirement already satisfied: matplotlib!=3.4.3 in /usr/local/lib/python3.7/dist-packages (from ray[rllib]) (3.2.2)
Collecting lz4
  Downloading lz4-3.1.10-cp37-cp37m-manylinux2010_x86_64.whl (1.8 MB)
     |████████████████████████████████| 1.8 MB 43.9 MB/s 
Requirement already satisfied: dm-tree in /usr/local/lib/python3.7/dist-packages (from ray[rllib]) (0.1.6)
Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from ray[rllib]) (1.4.1)
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from ray[rllib]) (2.23.0)
Requirement already satisfied: gym in /usr/local/lib/python3.7/dist-packages (from ray[rllib]) (0.17.3)
Collecting tensorboardX>=1.9
  Downloading tensorboardX-2.4.1-py2.py3-none-any.whl (124 kB)
     |████████████████████████████████| 124 kB 60.2 MB/s 
Requirement already satisfied: tabulate in /usr/local/lib/python3.7/dist-packages (from ray[rllib]) (0.8.9)
Requirement already satisfied: six>=1.5.2 in /usr/local/lib/python3.7/dist-packages (from grpcio>=1.28.1->ray[rllib]) (1.15.0)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib!=3.4.3->ray[rllib]) (0.11.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib!=3.4.3->ray[rllib]) (1.3.2)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib!=3.4.3->ray[rllib]) (3.0.6)
Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib!=3.4.3->ray[rllib]) (2.8.2)
Collecting deprecated
  Downloading Deprecated-1.2.13-py2.py3-none-any.whl (9.6 kB)
Requirement already satisfied: wrapt<2,>=1.10 in /usr/local/lib/python3.7/dist-packages (from deprecated->redis>=3.5.0->ray[rllib]) (1.13.3)
Requirement already satisfied: pyglet<=1.5.0,>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from gym->ray[rllib]) (1.5.0)
Requirement already satisfied: cloudpickle<1.7.0,>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from gym->ray[rllib]) (1.3.0)
Requirement already satisfied: future in /usr/local/lib/python3.7/dist-packages (from pyglet<=1.5.0,>=1.4.0->gym->ray[rllib]) (0.16.0)
Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->ray[rllib]) (2018.9)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->ray[rllib]) (2021.10.8)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->ray[rllib]) (2.10)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->ray[rllib]) (1.24.3)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->ray[rllib]) (3.0.4)
Requirement already satisfied: pillow!=7.1.0,!=7.1.1,>=4.3.0 in /usr/local/lib/python3.7/dist-packages (from scikit-image->ray[rllib]) (7.1.2)
Requirement already satisfied: PyWavelets>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from scikit-image->ray[rllib]) (1.2.0)
Requirement already satisfied: imageio>=2.3.0 in /usr/local/lib/python3.7/dist-packages (from scikit-image->ray[rllib]) (2.4.1)
Requirement already satisfied: networkx>=2.0 in /usr/local/lib/python3.7/dist-packages (from scikit-image->ray[rllib]) (2.6.3)
Requirement already satisfied: tifffile>=2019.7.26 in /usr/local/lib/python3.7/dist-packages (from scikit-image->ray[rllib]) (2021.11.2)
Installing collected packages: deprecated, redis, tensorboardX, ray, lz4
Successfully installed deprecated-1.2.13 lz4-3.1.10 ray-1.8.0 redis-4.0.2 tensorboardX-2.4.1
import ray
import ray.rllib.agents.ppo as ppo
import pandas as pd
import json
import os
import shutil
import sys
!pip list | grep ^ray
ray                           1.8.0
# set up a folder for checkpoints
# (base_dir is assumed to point to a writable location such as the mounted
#  Google Drive folder "/content/gdrive/My Drive/RLlib/", as reflected in the
#  checkpoint paths printed during training)
CHECKPOINT_ROOT = base_dir + "checkpoints/ppo/cart"
# start Ray
ray.init(ignore_reinit_error=True)
{'metrics_export_port': 63167,
 'node_id': '4b86f136e3f35b82039e3e2e778975223b356bbe2f3ae0bad76bbdd4',
 'node_ip_address': '172.28.0.2',
 'object_store_address': '/tmp/ray/session_2021-11-30_01-07-21_219554_60/sockets/plasma_store',
 'raylet_ip_address': '172.28.0.2',
 'raylet_socket_name': '/tmp/ray/session_2021-11-30_01-07-21_219554_60/sockets/raylet',
 'redis_address': '172.28.0.2:6379',
 'session_dir': '/tmp/ray/session_2021-11-30_01-07-21_219554_60',
 'webui_url': None}
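
If the notebook is re-run, ray.init(ignore_reinit_error=True) tolerates an already-running session. Alternatively, the previous session can be torn down first; a small sketch:

# optional: shut down any existing Ray session before starting a fresh one
ray.shutdown()
ray.init(ignore_reinit_error=True)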

4. Hyperparameters

Here we specify all the hyperparameters for the problem:

N_ITERATIONS = 10 #number of training runs
config = ppo.DEFAULT_CONFIG.copy()
config["log_level"] = "WARN"
config["num_workers"] = 1 #use > 1 for using more CPU cores, including over a cluster
config["num_sgd_iter"] = 10 #number of SGD (stochastic gradient descent) iterations per training minibatch
config["sgd_minibatch_size"] = 250
config["model"]["fcnet_hiddens"] = [100, 50]
config["num_cpus_per_worker"] = 0 #avoids running out of resources in the notebook environment when this cell is re-executed

5. Environment

Let’s start with the controlled entity. In Reinforcement Learning, the controlled entity is known as an environment. We make use of an environment provided by the OpenAI Gym framework, known as “CartPole-v1”.

import gym
env = gym.make("CartPole-v1")

Input to Environment

Actions sent to the environment come from a discrete action space of size 2.

env.action_space
Discrete(2)

The environment interprets the action on the cart as follows (the sketch after this list makes the convention concrete):

  • 0 pushes the cart to the LEFT
  • 1 pushes the cart to the RIGHT
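
To see this, the following sketch pushes the cart a few steps with explicit actions instead of sampled ones (observation index 1 is the cart velocity):

# apply explicit actions: 0 pushes the cart LEFT, 1 pushes it RIGHT
env.reset()
for action in [0, 0, 0, 1, 1, 1]:
  observation, reward, done, info = env.step(action)
  print("action", action, "-> cart velocity", round(observation[1], 4))
env.close()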

Evolution of the Environment

Each action arriving at the environment’s input updates its state; this is how the environment evolves. The env.step method takes an action, applies it to the environment, and advances the state by one time step.

The next fragment of code drives the environment through 30 steps by applying random actions:

env.reset()
for i in range(30):
  observation, reward, done, info = env.step(env.action_space.sample())
  print("step", i, observation, reward, done, info)
env.close()
step 0 [-0.02002099  0.18795453  0.01057995 -0.31008693] 1.0 False {}
step 1 [-0.0162619   0.38292416  0.00437821 -0.59941456] 1.0 False {}
step 2 [-0.00860342  0.57798458 -0.00761008 -0.8907152 ] 1.0 False {}
step 3 [ 0.00295627  0.3829667  -0.02542438 -0.60043419] 1.0 False {}
step 4 [ 0.01061561  0.18820948 -0.03743307 -0.31586674] 1.0 False {}
step 5 [ 0.01437979 -0.00635983 -0.0437504  -0.03521998] 1.0 False {}
step 6 [ 0.0142526  -0.20082797 -0.0444548   0.24334459] 1.0 False {}
step 7 [ 0.01023604 -0.00510016 -0.03958791 -0.06302248] 1.0 False {}
step 8 [ 0.01013404 -0.19963282 -0.04084836  0.21691208] 1.0 False {}
step 9 [ 0.00614138 -0.39414773 -0.03651012  0.49643498] 1.0 False {}
step 10 [-0.00174158 -0.58873635 -0.02658142  0.7773918 ] 1.0 False {}
step 11 [-0.0135163  -0.39325913 -0.01103358  0.47646554] 1.0 False {}
step 12 [-0.02138148 -0.58822357 -0.00150427  0.76565058] 1.0 False {}
step 13 [-0.03314596 -0.78332477  0.01380874  1.05785981] 1.0 False {}
step 14 [-0.04881245 -0.97862694  0.03496594  1.35484476] 1.0 False {}
step 15 [-0.06838499 -0.78396084  0.06206283  1.073302  ] 1.0 False {}
step 16 [-0.08406421 -0.97984574  0.08352887  1.3847984 ] 1.0 False {}
step 17 [-0.10366112 -0.78585886  0.11122484  1.11936153] 1.0 False {}
step 18 [-0.1193783  -0.59235766  0.13361207  0.86353596] 1.0 False {}
step 19 [-0.13122545 -0.39928297  0.15088279  0.61567036] 1.0 False {}
step 20 [-0.13921111 -0.20655524  0.1631962   0.37405461] 1.0 False {}
step 21 [-0.14334222 -0.01408201  0.17067729  0.13694784] 1.0 False {}
step 22 [-0.14362386  0.17823658  0.17341625 -0.0974026 ] 1.0 False {}
step 23 [-0.14005913 -0.01889181  0.1714682   0.24458412] 1.0 False {}
step 24 [-0.14043696  0.17341923  0.17635988  0.01051282] 1.0 False {}
step 25 [-0.13696858 -0.02373578  0.17657013  0.35323964] 1.0 False {}
step 26 [-0.13744329  0.1684936   0.18363493  0.12102105] 1.0 False {}
step 27 [-0.13407342 -0.02871936  0.18605535  0.46555246] 1.0 False {}
step 28 [-0.13464781 -0.2259156   0.1953664   0.81062711] 1.0 False {}
step 29 [-0.13916612 -0.03392968  0.21157894  0.58519959] 1.0 True {}
# 
# install dependencies needed for recording videos
!apt-get install -y xvfb x11-utils
!pip install pyvirtualdisplay==0.2.*
Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following additional packages will be installed:
  libxxf86dga1
Suggested packages:
  mesa-utils
The following NEW packages will be installed:
  libxxf86dga1 x11-utils xvfb
0 upgraded, 3 newly installed, 0 to remove and 37 not upgraded.
Need to get 994 kB of archives.
After this operation, 2,981 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu bionic/main amd64 libxxf86dga1 amd64 2:1.1.4-1 [13.7 kB]
Get:2 http://archive.ubuntu.com/ubuntu bionic/main amd64 x11-utils amd64 7.7+3build1 [196 kB]
Get:3 http://archive.ubuntu.com/ubuntu bionic-updates/universe amd64 xvfb amd64 2:1.19.6-1ubuntu4.9 [784 kB]
Fetched 994 kB in 1s (839 kB/s)
Selecting previously unselected package libxxf86dga1:amd64.
(Reading database ... 155222 files and directories currently installed.)
Preparing to unpack .../libxxf86dga1_2%3a1.1.4-1_amd64.deb ...
Unpacking libxxf86dga1:amd64 (2:1.1.4-1) ...
Selecting previously unselected package x11-utils.
Preparing to unpack .../x11-utils_7.7+3build1_amd64.deb ...
Unpacking x11-utils (7.7+3build1) ...
Selecting previously unselected package xvfb.
Preparing to unpack .../xvfb_2%3a1.19.6-1ubuntu4.9_amd64.deb ...
Unpacking xvfb (2:1.19.6-1ubuntu4.9) ...
Setting up xvfb (2:1.19.6-1ubuntu4.9) ...
Setting up libxxf86dga1:amd64 (2:1.1.4-1) ...
Setting up x11-utils (7.7+3build1) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...
Processing triggers for libc-bin (2.27-3ubuntu1.3) ...
/sbin/ldconfig.real: /usr/local/lib/python3.7/dist-packages/ideep4py/lib/libmkldnn.so.0 is not a symbolic link

Collecting pyvirtualdisplay==0.2.*
  Downloading PyVirtualDisplay-0.2.5-py2.py3-none-any.whl (13 kB)
Collecting EasyProcess
  Downloading EasyProcess-0.3-py2.py3-none-any.whl (7.9 kB)
Installing collected packages: EasyProcess, pyvirtualdisplay
Successfully installed EasyProcess-0.3 pyvirtualdisplay-0.2.5
from pyvirtualdisplay import Display
display = Display(visible=False, size=(1400, 900))
_ = display.start()
from gym.wrappers.monitoring.video_recorder import VideoRecorder
before_training = "before_training.mp4"
video = VideoRecorder(env, before_training)
# returns an initial observation
env.reset()
# take 200 random actions; note that the loop does not reset when done=True,
# which triggers the Gym warning shown below
for i in range(200):
  env.render()
  video.capture_frame()
  observation, reward, done, info = env.step(env.action_space.sample())
video.close()
env.close()
/usr/local/lib/python3.7/dist-packages/gym/logger.py:30: UserWarning: WARN: You are calling 'step()' even though this environment has already returned done = True. You should always call 'reset()' once you receive 'done = True' -- any further steps are undefined behavior.
  warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))
from base64 import b64encode
def render_mp4(videopath: str) -> str:
  """
  Returns a string containing a base64-encoded version of the MP4 video
  at the specified path, wrapped in an HTML <video> element.
  """
  mp4 = open(videopath, 'rb').read()
  base64_encoded_mp4 = b64encode(mp4).decode()
  return f'<video width=400 controls><source src="data:video/mp4;' \
         f'base64,{base64_encoded_mp4}" type="video/mp4"></video>'
from IPython.display import HTML
html = render_mp4(before_training)
HTML(html)

Output from Environment

Each call to env.step returns a tuple containing:

  • the next observation of the environment
  • the reward earned on this step
  • a flag indicating whether the episode is done
  • a dictionary with auxiliary diagnostic information
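
Putting these pieces together, here is a minimal sketch of a complete episode loop that unpacks this tuple and stops as soon as the done flag is set (which avoids the “step() after done” warning seen earlier):

# run one full episode with random actions, stopping when done=True
observation = env.reset()
done = False
total_reward = 0.0
while not done:
  observation, reward, done, info = env.step(env.action_space.sample())
  total_reward += reward
print("episode finished with total reward", total_reward)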

6. Agent

The controller in our problem is the algorithm that solves it. In RL parlance the controller is known as an Agent. RLlib provides implementations of a variety of Agents.

For our problem we will use the Proximal Policy Optimization (PPO) agent.

The fundamental task of an Agent is to find the best next action to submit to the environment.
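
In practice this means the Agent’s learned policy maps the current observation to an action. A small sketch of that interaction, assuming the trained PPOTrainer named agent that is created in the next section (and using compute_single_action, the non-deprecated form of the call used later in this post):

# the policy maps an observation to the next action
observation = env.reset()
action = agent.compute_single_action(observation)  # agent is the PPOTrainer built in section 7
observation, reward, done, info = env.step(action)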

7. Train the Agent

ENV = "CartPole-v1" #OpenAI Gym environment for Cart Pole
agent = ppo.PPOTrainer(config, env=ENV)
results = []
episode_data = []
episode_json = []
for n in range(N_ITERATIONS):
    result = agent.train()
    results.append(result)
    episode = {'n': n, 
               'episode_reward_min': result['episode_reward_min'], 
               'episode_reward_mean': result['episode_reward_mean'], 
               'episode_reward_max': result['episode_reward_max'],  
               'episode_len_mean': result['episode_len_mean']}
    episode_data.append(episode)
    episode_json.append(json.dumps(episode))
    file_name = agent.save(CHECKPOINT_ROOT)
    print(f'{n:3d}: Min/Mean/Max reward: {result["episode_reward_min"]:8.4f}/{result["episode_reward_mean"]:8.4f}/{result["episode_reward_max"]:8.4f}. Checkpoint saved to {file_name}')
2021-11-30 01:09:32,414 INFO trainer.py:753 -- Tip: set framework=tfe or the --eager flag to enable TensorFlow eager execution
2021-11-30 01:09:32,416 INFO ppo.py:167 -- In multi-agent mode, policies will be optimized sequentially by the multi-GPU optimizer. Consider setting simple_optimizer=True if this doesn't work for you.
2021-11-30 01:09:32,418 INFO trainer.py:772 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
(pid=482) 2021-11-30 01:09:37,858   WARNING deprecation.py:39 -- DeprecationWarning: `SampleBatch['is_training']` has been deprecated. Use `SampleBatch.is_training` instead. This will raise an error in the future!
2021-11-30 01:09:39,395 WARNING deprecation.py:39 -- DeprecationWarning: `SampleBatch['is_training']` has been deprecated. Use `SampleBatch.is_training` instead. This will raise an error in the future!
2021-11-30 01:09:40,446 WARNING trainer_template.py:186 -- `execution_plan` functions should accept `trainer`, `workers`, and `config` as args!
2021-11-30 01:09:40,457 WARNING util.py:57 -- Install gputil for GPU system monitoring.
2021-11-30 01:09:46,319 WARNING deprecation.py:39 -- DeprecationWarning: `slice` has been deprecated. Use `SampleBatch[start:stop]` instead. This will raise an error in the future!
  0: Min/Mean/Max reward:   9.0000/ 21.9725/ 95.0000. Checkpoint saved to /content/gdrive/My Drive/RLlib/checkpoints/ppo/cart/checkpoint_000001/checkpoint-1
  1: Min/Mean/Max reward:   9.0000/ 29.3971/161.0000. Checkpoint saved to /content/gdrive/My Drive/RLlib/checkpoints/ppo/cart/checkpoint_000002/checkpoint-2
  2: Min/Mean/Max reward:  11.0000/ 42.4400/126.0000. Checkpoint saved to /content/gdrive/My Drive/RLlib/checkpoints/ppo/cart/checkpoint_000003/checkpoint-3
  3: Min/Mean/Max reward:  13.0000/ 53.9900/161.0000. Checkpoint saved to /content/gdrive/My Drive/RLlib/checkpoints/ppo/cart/checkpoint_000004/checkpoint-4
  4: Min/Mean/Max reward:  14.0000/ 69.1700/200.0000. Checkpoint saved to /content/gdrive/My Drive/RLlib/checkpoints/ppo/cart/checkpoint_000005/checkpoint-5
  5: Min/Mean/Max reward:  21.0000/ 90.1300/278.0000. Checkpoint saved to /content/gdrive/My Drive/RLlib/checkpoints/ppo/cart/checkpoint_000006/checkpoint-6
  6: Min/Mean/Max reward:  21.0000/112.2200/391.0000. Checkpoint saved to /content/gdrive/My Drive/RLlib/checkpoints/ppo/cart/checkpoint_000007/checkpoint-7
  7: Min/Mean/Max reward:  21.0000/134.9200/419.0000. Checkpoint saved to /content/gdrive/My Drive/RLlib/checkpoints/ppo/cart/checkpoint_000008/checkpoint-8
  8: Min/Mean/Max reward:  21.0000/159.2900/419.0000. Checkpoint saved to /content/gdrive/My Drive/RLlib/checkpoints/ppo/cart/checkpoint_000009/checkpoint-9
  9: Min/Mean/Max reward:  25.0000/187.1200/500.0000. Checkpoint saved to /content/gdrive/My Drive/RLlib/checkpoints/ppo/cart/checkpoint_000010/checkpoint-10
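
Because a checkpoint is written at every iteration, the trained policy can later be restored without retraining. A hedged sketch (the exact checkpoint path differs per run; file_name still holds the path returned by the last agent.save call):

# restore the latest checkpoint into a fresh trainer
restored_agent = ppo.PPOTrainer(config, env=ENV)
restored_agent.restore(file_name)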
df = pd.DataFrame(data=episode_data)
df
   n  episode_reward_min  episode_reward_mean  episode_reward_max  episode_len_mean
0  0                 9.0            21.972527                95.0         21.972527
1  1                 9.0            29.397059               161.0         29.397059
2  2                11.0            42.440000               126.0         42.440000
3  3                13.0            53.990000               161.0         53.990000
4  4                14.0            69.170000               200.0         69.170000
5  5                21.0            90.130000               278.0         90.130000
6  6                21.0           112.220000               391.0        112.220000
7  7                21.0           134.920000               419.0        134.920000
8  8                21.0           159.290000               419.0        159.290000
9  9                25.0           187.120000               500.0        187.120000
df.plot(x="n", y=["episode_reward_mean", "episode_reward_min", "episode_reward_max"], secondary_y=True);

import pprint
policy = agent.get_policy()
model = policy.model
pprint.pprint(model.variables())
pprint.pprint(model.value_function())
print(model.base_model.summary())
[<tf.Variable 'default_policy/fc_1/kernel:0' shape=(4, 100) dtype=float32>,
 <tf.Variable 'default_policy/fc_1/bias:0' shape=(100,) dtype=float32>,
 <tf.Variable 'default_policy/fc_value_1/kernel:0' shape=(4, 100) dtype=float32>,
 <tf.Variable 'default_policy/fc_value_1/bias:0' shape=(100,) dtype=float32>,
 <tf.Variable 'default_policy/fc_2/kernel:0' shape=(100, 50) dtype=float32>,
 <tf.Variable 'default_policy/fc_2/bias:0' shape=(50,) dtype=float32>,
 <tf.Variable 'default_policy/fc_value_2/kernel:0' shape=(100, 50) dtype=float32>,
 <tf.Variable 'default_policy/fc_value_2/bias:0' shape=(50,) dtype=float32>,
 <tf.Variable 'default_policy/fc_out/kernel:0' shape=(50, 2) dtype=float32>,
 <tf.Variable 'default_policy/fc_out/bias:0' shape=(2,) dtype=float32>,
 <tf.Variable 'default_policy/value_out/kernel:0' shape=(50, 1) dtype=float32>,
 <tf.Variable 'default_policy/value_out/bias:0' shape=(1,) dtype=float32>]
<tf.Tensor 'Reshape:0' shape=(?,) dtype=float32>
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 observations (InputLayer)      [(None, 4)]          0           []                               
                                                                                                  
 fc_1 (Dense)                   (None, 100)          500         ['observations[0][0]']           
                                                                                                  
 fc_value_1 (Dense)             (None, 100)          500         ['observations[0][0]']           
                                                                                                  
 fc_2 (Dense)                   (None, 50)           5050        ['fc_1[0][0]']                   
                                                                                                  
 fc_value_2 (Dense)             (None, 50)           5050        ['fc_value_1[0][0]']             
                                                                                                  
 fc_out (Dense)                 (None, 2)            102         ['fc_2[0][0]']                   
                                                                                                  
 value_out (Dense)              (None, 1)            51          ['fc_value_2[0][0]']             
                                                                                                  
==================================================================================================
Total params: 11,253
Trainable params: 11,253
Non-trainable params: 0
__________________________________________________________________________________________________
None

Visualization After Training

after_training = "after_training.mp4"
after_video = VideoRecorder(env, after_training)
observation = env.reset()
done = False
while not done:
  env.render()
  after_video.capture_frame()
  # compute_action is deprecated in newer RLlib versions in favor of
  # compute_single_action (see the warning below)
  action = agent.compute_action(observation)
  observation, reward, done, info = env.step(action)
after_video.close()
env.close()
# You should get a video similar to the one below. 
html = render_mp4(after_training)
HTML(html)
2021-11-30 01:13:24,634 WARNING deprecation.py:39 -- DeprecationWarning: `compute_action` has been deprecated. Use `compute_single_action` instead. This will raise an error in the future!