Agent Demo Using PyMDP

Adapting a PyMDP demo to show the basic operation of Active Inference

Bayesian Inference
Active Inference
PyMDP
Python
Author

Kobus Esterhuysen

Published

December 18, 2024

Modified

December 19, 2024



We use PyMDP and modify an existing demo:

https://github.com/infer-actively/pymdp/blob/master/examples/agent_demo.ipynb

The modifications are marked in the code with `##.` comments. The original demo’s introduction follows:

Active Inference Demo: Constructing a basic generative model from the “ground up”

This demo notebook provides a full walk-through of how to build a POMDP agent’s generative model and perform the active inference routine (inversion of the generative model) using the Agent() class of pymdp. We build a generative model from the ‘ground up’, directly encoding our own A, B, and C matrices.

Imports

First, import pymdp and the modules we’ll need.

import os
import sys
import pathlib
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import copy
from pprint import pprint ##.

path = pathlib.Path(os.getcwd())
module_path = str(path.parent) + '/'
sys.path.append(module_path)

from pymdp.agent import Agent
from pymdp import utils
from pymdp.maths import softmax

from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D
from matplotlib.ticker import MaxNLocator

The world (as represented by the agent’s generative model)

Hidden states

We assume the agent “represents” (this should make you think: generative model, not generative process) its environment using two latent variables that are statistically independent of one another - we can thus represent them using two hidden state factors.

We refer to these two hidden state factors as

  • GAME_STATE and
  • PLAYING_VS_SAMPLING.

1. GAME_STATE

The first factor is a binary variable representing some ‘reward structure’ that characterises the world. It has two possible values or levels:

  • one level that will lead to rewards with high probability
    • GAME_STATE = 0, a state/level we will call HIGH_REW, and
  • another level that will lead to “punishments” (e.g. losing money) with high probability
    • GAME_STATE = 1, a state/level we will call LOW_REW

You can think of this hidden state factor as describing the ‘pay-off’ structure of e.g. a two-armed bandit or slot-machine with two different settings - one where you’re more likely to win (HIGH_REW), and one where you’re more likely to lose (LOW_REW). Crucially, the agent doesn’t know what the GAME_STATE actually is. It will have to infer it by actively furnishing itself with observations.

2. PLAYING_VS_SAMPLING

The second factor is a ternary (3-valued) variable representing the decision-state or ‘sampling state’ of the agent itself.

  • The first state/level of this hidden state factor is just the
    • starting or initial state of the agent
      • PLAYING_VS_SAMPLING = 0, a state that we can call START
  • the second state/level is the
    • state the agent occupies when “playing” the multi-armed bandit or slot machine
      • PLAYING_VS_SAMPLING = 1, a state that we can call PLAYING
  • the third state/level of this factor is a
    • “sampling state”
      • PLAYING_VS_SAMPLING = 2, a state that we can call SAMPLING
      • This is a decision-state that the agent occupies when it is “sampling” data in order to find out the level of the first hidden state factor - the GAME_STATE.

Observations

The observations are divided into 3 modalities. You can think of these as 3 independent sources of information that the agent has access to. You could think of this in direct perceptual terms - e.g. 3 different sensory organs like eyes, ears, & nose, that give you qualitatively-different kinds of information. Or you can think of it more abstractly - like getting your news from 3 different media sources (online news articles, Twitter feed, and Instagram).

1. Observations of the game state - GAME_STATE_OBS

The first observation modality is the GAME_STATE_OBS modality, and corresponds to observations that give the agent information about the GAME_STATE. There are three possible outcomes within this modality:

  • HIGH_REW_EVIDENCE
    • GAME_STATE_OBS = 0
  • LOW_REW_EVIDENCE
    • GAME_STATE_OBS = 1
  • NO_EVIDENCE
    • GAME_STATE_OBS = 2

So the first outcome can be described as lending evidence to the idea that the GAME_STATE is HIGH_REW; the second outcome can be described as lending evidence to the idea that the GAME_STATE is LOW_REW; and the third outcome within this modality doesn’t tell the agent one way or another whether the GAME_STATE is HIGH_REW or LOW_REW.

2. Reward observations - GAME_OUTCOME

The second observation modality is the GAME_OUTCOME modality, and corresponds to arbitrary observations that are functions of the GAME_STATE. The three outcome levels of this modality are

  • REWARD
    • GAME_OUTCOME = 0
  • PUN
    • GAME_OUTCOME = 1
  • NEUTRAL
    • GAME_OUTCOME = 2

The name REWARD already gives you a hint about how we’ll set up the C matrix (the agent’s “utility function” over outcomes).

By design, we will set up the A matrix such that the REWARD outcome is (expected to be) more likely when the GAME_STATE is HIGH_REW (0) and when the agent is in the PLAYING state, and that the PUN outcome is (expected to be) more likely when the GAME_STATE is LOW_REW (1) and the agent is in the PLAYING state. The NEUTRAL outcome is not expected to occur when the agent is playing the game, but will be expected to occur when the agent is in the SAMPLING state. This NEUTRAL outcome within the GAME_OUTCOME modality is thus a meaningless or ‘null’ observation that the agent gets when it’s not actually playing the game (because an observation has to be sampled nonetheless from all modalities).

3. “Proprioceptive” or self-state observations - ACTION_SELF_OBS

The third observation modality is the ACTION_SELF_OBS modality, and corresponds to the agent observing what level of the PLAYING_VS_SAMPLING state it is currently in. These observations are direct, ‘unambiguous’ mappings to the true PLAYING_VS_SAMPLING state, and simply allow the agent to “know” whether it’s playing the game, sampling information to learn about the game state, or sitting at the START state. The levels of this outcome modality are thus simply

  • START_O,
  • PLAY_O, and
  • SAMPLE_O,

where the _O suffix simply distinguishes them from their corresponding hidden states, for which they provide direct evidence.

Note about the arbitrariness of ‘labelling’ observations, before defining the A and C matrices.

There is a bit of a circularity here, in that we’re “pre-empting” what the A matrix (likelihood mapping) should look like, by giving these observations labels that imply particular roles or meanings. An observation per se doesn’t mean anything - it’s just some discrete index that distinguishes it from another observation. It’s only through its probabilistic relationship to hidden states (encoded in the A matrix, as we’ll see below) that we endow an observation with meaning. For example: by already labelling GAME_STATE_OBS = 0 as HIGH_REW_EVIDENCE, that’s a hint about how we’re going to structure the A matrix for the GAME_STATE_OBS modality.

## ?utils.get_model_dimensions_from_labels
##.rewrite method to allow for shorter names so that the _lab dict can be used more
##   easily when setting up the matrices
def get_model_dimensions_from_labels(model_labels):
    ## modalities = model_labels['observations']
    modalities = model_labels['obs'] ##.
    num_modalities = len(modalities.keys())
    num_obs = [len(modalities[modality]) for modality in modalities.keys()]

    ## factors = model_labels['states']
    factors = model_labels['sta'] ##.
    num_factors = len(factors.keys())
    num_states = [len(factors[factor]) for factor in factors.keys()]

    ## if 'actions' in model_labels.keys():
    if 'ctr' in model_labels.keys(): ##.
        ## controls = model_labels['actions']
        controls = model_labels['ctr'] ##.
        num_control_fac = len(controls.keys())
        num_controls = [len(controls[cfac]) for cfac in controls.keys()]
        return num_obs, num_modalities, num_states, num_factors, num_controls, num_control_fac
    else:
        return num_obs, num_modalities, num_states, num_factors
_lab = { ##.
    "ctr": {
        "NULL": [
            "NULL_ACTION", 
        ],
        "PLAYING_VS_SAMPLING_CONTROL": [
            "START_ACTION", 
            "PLAY_ACTION", 
            "SAMPLE_ACTION"
        ],
    },
    "sta": {
        "GAME_STATE": [
            "HIGH_REW", 
            "LOW_REW"
        ],
        "PLAYING_VS_SAMPLING": [
            "START", 
            "PLAYING", 
            "SAMPLING"
        ],
    },
    "obs": {
        "GAME_STATE_OBS": [
            "HIGH_REW_EVIDENCE",
            "LOW_REW_EVIDENCE",
            "NO_EVIDENCE"            
        ],
        "GAME_OUTCOME": [
            "REWARD",
            "PUN",
            "NEUTRAL"
        ],
        "ACTION_SELF_OBS": [ ##.direct obser of hidden state PLAYING_VS_SAMPLING
            "START_O",
            "PLAY_O",
            "SAMPLE_O"
        ]
    },
}
_car_obser,_num_obser, \
_car_state,_num_state, \
_car_contr,_num_contr = get_model_dimensions_from_labels(_lab) ##.
_car_obser,_num_obser,_car_state,_num_state,_car_contr,_num_contr
([3, 3, 3], 3, [2, 3], 2, [1, 3], 2)
print(f'{_car_obser=}') ##.cardinality of observation modalities
print(f'{_num_obser=}') ##.number of observation modalities
print(f'{_car_state=}') ##.cardinality of state factors
print(f'{_num_state=}') ##.number of state factors
print(f'{_car_contr=}') ##.cardinality of control factors
print(f'{_num_contr=}') ##.number of control factors
_car_obser=[3, 3, 3]
_num_obser=3
_car_state=[2, 3]
_num_state=2
_car_contr=[1, 3]
_num_contr=2
##.
_ctr_fac_names = list(_lab['ctr'].keys()); print(f'{_ctr_fac_names=}') ##.control factor names
_sta_fac_names = list(_lab['sta'].keys()); print(f'{_sta_fac_names=}') ##.state factor names
_obs_mod_names = list(_lab['obs'].keys()); print(f'{_obs_mod_names=}') ##.observation modality names
_ctr_fac_names=['NULL', 'PLAYING_VS_SAMPLING_CONTROL']
_sta_fac_names=['GAME_STATE', 'PLAYING_VS_SAMPLING']
_obs_mod_names=['GAME_STATE_OBS', 'GAME_OUTCOME', 'ACTION_SELF_OBS']

Setting up the observation likelihood matrix - the first main component of the generative model

## A = utils.obj_array_zeros([[o] + _car_state for _, o in enumerate(_car_obser)]) ##.
_A = utils.obj_array_zeros([[car_o] + _car_state for car_o in _car_obser]) ##.
_A
array([array([[[0., 0., 0.],
               [0., 0., 0.]],

              [[0., 0., 0.],
               [0., 0., 0.]],

              [[0., 0., 0.],
               [0., 0., 0.]]]), array([[[0., 0., 0.],
                                        [0., 0., 0.]],

                                       [[0., 0., 0.],
                                        [0., 0., 0.]],

                                       [[0., 0., 0.],
                                        [0., 0., 0.]]]),
       array([[[0., 0., 0.],
               [0., 0., 0.]],

              [[0., 0., 0.],
               [0., 0., 0.]],

              [[0., 0., 0.],
               [0., 0., 0.]]])], dtype=object)

Set up the first modality’s likelihood mapping, corresponding to how "GAME_STATE_OBS" observations (i.e. _obs_mod_names[0]) relate to the hidden states.

## they always get the 'no evidence' observation in the START STATE
_A[0][ ##.
    _lab['obs']['GAME_STATE_OBS'].index('NO_EVIDENCE'), 
    :, 
    _lab['sta']['PLAYING_VS_SAMPLING'].index('START')
] = 1.0
## they always get the 'no evidence' observation in the PLAYING STATE
_A[0][ ##.
    _lab['obs']['GAME_STATE_OBS'].index('NO_EVIDENCE'), 
    :, 
    _lab['sta']['PLAYING_VS_SAMPLING'].index('PLAYING')
] = 1.0
## the agent expects to see the HIGH_REW_EVIDENCE observation with 80% probability, 
##   if the GAME_STATE is HIGH_REW, and the agent is in the SAMPLING state
_A[0][ ##.
    _lab['obs']['GAME_STATE_OBS'].index('HIGH_REW_EVIDENCE'),
    _lab['sta']['GAME_STATE'].index('HIGH_REW'), 
    _lab['sta']['PLAYING_VS_SAMPLING'].index('SAMPLING')
] = 0.8
## the agent expects to see the LOW_REW_EVIDENCE observation with 20% probability, 
##   if the GAME_STATE is HIGH_REW, and the agent is in the SAMPLING state
_A[0][ ##.
    _lab['obs']['GAME_STATE_OBS'].index('LOW_REW_EVIDENCE'),
    _lab['sta']['GAME_STATE'].index('HIGH_REW'), 
    _lab['sta']['PLAYING_VS_SAMPLING'].index('SAMPLING')
] = 0.2

## the agent expects to see the LOW_REW_EVIDENCE observation with 80% probability, 
##   if the GAME_STATE is LOW_REW, and the agent is in the SAMPLING state
_A[0][ ##.
    _lab['obs']['GAME_STATE_OBS'].index('LOW_REW_EVIDENCE'),
    _lab['sta']['GAME_STATE'].index('LOW_REW'), 
    _lab['sta']['PLAYING_VS_SAMPLING'].index('SAMPLING')
] = 0.8
## the agent expects to see the HIGH_REW_EVIDENCE observation with 20% probability, 
##   if the GAME_STATE is LOW_REW, and the agent is in the SAMPLING state
_A[0][ ##.
    _lab['obs']['GAME_STATE_OBS'].index('HIGH_REW_EVIDENCE'),
    _lab['sta']['GAME_STATE'].index('LOW_REW'), 
    _lab['sta']['PLAYING_VS_SAMPLING'].index('SAMPLING')
] = 0.2

## quick way to do this
## A[0][:, :, 0] = 1.0
## A[0][:, :, 1] = 1.0
## A[0][:, :, 2] = np.array([[0.8, 0.2], [0.2, 0.8], [0.0, 0.0]])

_A[0]
array([[[0. , 0. , 0.8],
        [0. , 0. , 0.2]],

       [[0. , 0. , 0.2],
        [0. , 0. , 0.8]],

       [[1. , 1. , 0. ],
        [1. , 1. , 0. ]]])

Set up the second modality’s likelihood mapping, corresponding to how "GAME_OUTCOME" observations (i.e. _obs_mod_names[1]) relate to the hidden states.

## regardless of the game state, if you're at the START, you see the 'neutral' outcome
_A[1][ ##.
    _lab['obs']['GAME_OUTCOME'].index('NEUTRAL'), 
    :, 
    _lab['sta']['PLAYING_VS_SAMPLING'].index('START')
] = 1.0

## regardless of the game state, if you're in the SAMPLING state, you see the 'neutral' outcome
_A[1][ ##.
    _lab['obs']['GAME_OUTCOME'].index('NEUTRAL'), 
    :, 
    _lab['sta']['PLAYING_VS_SAMPLING'].index('SAMPLING')
] = 1.0

## this is the distribution that maps from the "GAME_STATE" to the "GAME_OUTCOME" 
##   observation , in the case that "GAME_STATE" is `HIGH_REW`
_HIGH_REW_MAPPING = softmax(np.array([1.0, 0])) 

## this is the distribution that maps from the "GAME_STATE" to the "GAME_OUTCOME" 
##   observation , in the case that "GAME_STATE" is `LOW_REW`
_LOW_REW_MAPPING = softmax(np.array([0.0, 1.0]))

## fill out the A matrix using the reward probabilities
_A[1][ ##.
    _lab['obs']['GAME_OUTCOME'].index('REWARD'),
    _lab['sta']['GAME_STATE'].index('HIGH_REW'), 
    _lab['sta']['PLAYING_VS_SAMPLING'].index('PLAYING')
] = _HIGH_REW_MAPPING[0]
_A[1][ ##.
    _lab['obs']['GAME_OUTCOME'].index('PUN'),
    _lab['sta']['GAME_STATE'].index('HIGH_REW'), 
    _lab['sta']['PLAYING_VS_SAMPLING'].index('PLAYING')
] = _HIGH_REW_MAPPING[1]
_A[1][ ##.
    _lab['obs']['GAME_OUTCOME'].index('REWARD'),
    _lab['sta']['GAME_STATE'].index('LOW_REW'), 
    _lab['sta']['PLAYING_VS_SAMPLING'].index('PLAYING')
] = _LOW_REW_MAPPING[0]
_A[1][ ##.
    _lab['obs']['GAME_OUTCOME'].index('PUN'),
    _lab['sta']['GAME_STATE'].index('LOW_REW'), 
    _lab['sta']['PLAYING_VS_SAMPLING'].index('PLAYING')
] = _LOW_REW_MAPPING[1]

## quick way to do this
## A[1][2, :, 0] = np.ones(num_states[0])
## A[1][0:2, :, 1] = softmax(np.eye(num_obs[1] - 1)) # relationship of game state to reward observations (mapping between reward-state (first hidden state factor) and rewards (Good vs Bad))
## A[1][2, :, 2] = np.ones(num_states[0])

_A[1]
array([[[0.        , 0.73105858, 0.        ],
        [0.        , 0.26894142, 0.        ]],

       [[0.        , 0.26894142, 0.        ],
        [0.        , 0.73105858, 0.        ]],

       [[1.        , 0.        , 1.        ],
        [1.        , 0.        , 1.        ]]])

Set up the third modality’s likelihood mapping, corresponding to how "ACTION_SELF_OBS" observations (i.e. _obs_mod_names[2]) relate to the hidden states.

_A[2][ ##.
    _lab['obs']['ACTION_SELF_OBS'].index('START_O'), 
    :, 
    _lab['sta']['PLAYING_VS_SAMPLING'].index('START')
] = 1.0
_A[2][ ##.
    _lab['obs']['ACTION_SELF_OBS'].index('PLAY_O'), 
    :, 
    _lab['sta']['PLAYING_VS_SAMPLING'].index('PLAYING')
] = 1.0
_A[2][ ##.
    _lab['obs']['ACTION_SELF_OBS'].index('SAMPLE_O'), 
    :, 
    _lab['sta']['PLAYING_VS_SAMPLING'].index('SAMPLING')
] = 1.0

## quick way to do this
## modality_idx, factor_idx = 2, 2
## for sampling_state_i in num_states[factor_idx]:
##     A[modality_idx][sampling_state_i,:,sampling_state_i] = 1.0

_A[2]
array([[[1., 0., 0.],
        [1., 0., 0.]],

       [[0., 1., 0.],
        [0., 1., 0.]],

       [[0., 0., 1.],
        [0., 0., 1.]]])
print(f'=== _car_state:\n{_car_state}')
print(f'=== _car_obser:\n{_car_obser}')
_A
=== _car_state:
[2, 3]
=== _car_obser:
[3, 3, 3]
array([array([[[0. , 0. , 0.8],
               [0. , 0. , 0.2]],

              [[0. , 0. , 0.2],
               [0. , 0. , 0.8]],

              [[1. , 1. , 0. ],
               [1. , 1. , 0. ]]]),
       array([[[0.        , 0.73105858, 0.        ],
               [0.        , 0.26894142, 0.        ]],

              [[0.        , 0.26894142, 0.        ],
               [0.        , 0.73105858, 0.        ]],

              [[1.        , 0.        , 1.        ],
               [1.        , 0.        , 1.        ]]]),
       array([[[1., 0., 0.],
               [1., 0., 0.]],

              [[0., 1., 0.],
               [0., 1., 0.]],

              [[0., 0., 1.],
               [0., 0., 1.]]])], dtype=object)

Control states

The ‘control state’ factors are the agent’s representation of the control states (or actions) that it believes can influence the dynamics of the hidden states - i.e. hidden state factors that are under the influence of control states are ‘controllable’. In practice, we often encode every hidden state factor as being under the influence of control states, but the ‘uncontrollable’ hidden state factors are driven by a trivially-1-dimensional control state or action-affordance. This trivial action simply ‘maintains the default environmental dynamics as they are’, i.e. does nothing. This will become clearer when we set up the transition model (the B matrices) below.

1. NULL

This reflects the agent’s lack of ability to influence the GAME_STATE using policies or actions. The dimensionality of this control factor is 1, and there is only one action along this control factor:

  • NULL_ACTION or “don’t do anything to the environment”.

This just means that the transition dynamics along the GAME_STATE hidden state factor have their own, uncontrollable dynamics that are not conditioned on this NULL control state - or rather, always conditioned on an unchanging, 1-dimensional NULL_ACTION.

2. PLAYING_VS_SAMPLING_CONTROL

This is a control factor that reflects the agent’s ability to move itself between the START, PLAYING and SAMPLING states of the PLAYING_VS_SAMPLING hidden state factor. The levels/values of this control factor are

  • START_ACTION,
  • PLAY_ACTION, and
  • SAMPLE_ACTION.

When we describe the B matrices below, we will set up the transition dynamics of the PLAYING_VS_SAMPLING hidden state factor, such that they are totally determined by the value of the PLAYING_VS_SAMPLING_CONTROL factor.

(Controllable-) Transition Dynamics

Importantly, some hidden state factors are controllable by the agent, meaning that the probability of being in state \(i\) at \(t+1\) isn’t merely a function of the state at \(t\), but also of actions (or from the generative model’s perspective, control states ). So each transition likelihood or B matrix encodes conditional probability distributions over states at \(t+1\), where the conditioning variables are both the states and the actions at \(t\). This extra conditioning on control states is encoded by a third dimension on each factor-specific B matrix, so the B arrays are technically “tensors”, i.e. arrays of action-conditioned B matrices.

For example, in our case the 2nd hidden state factor (PLAYING_VS_SAMPLING) is under the control of the agent, which means the corresponding transition likelihoods B[1] are indexable by both previous state and action.

## this is the (non-trivial) controllable factor, where there will be a >1-dimensional 
##   control state along this factor
_control_fac_idx = [1] ##.used in Agent constructor
_B = utils.obj_array(_num_state) ##.
_B
array([None, None], dtype=object)
_B[0] = np.zeros((_car_state[0], _car_state[0], _car_contr[0])) ##. 
_B[0]
array([[[0.],
        [0.]],

       [[0.],
        [0.]]])
_p_stoch = 0.0

## we cannot influence factor zero, set up the 'default' stationary dynamics - 
## one state just maps to itself at the next timestep with very high probability, 
## by default. So this means the reward state can change from one to another with 
## some low probability (p_stoch)

_B[0][ ##.
    _lab['sta']['GAME_STATE'].index('HIGH_REW'),
    _lab['sta']['GAME_STATE'].index('HIGH_REW'), 
    _lab['ctr']['NULL'].index('NULL_ACTION')
] = 1.0 - _p_stoch
_B[0][ ##.
    _lab['sta']['GAME_STATE'].index('LOW_REW'),
    _lab['sta']['GAME_STATE'].index('HIGH_REW'), 
    _lab['ctr']['NULL'].index('NULL_ACTION')
] = _p_stoch

_B[0][ ##.
    _lab['sta']['GAME_STATE'].index('LOW_REW'),
    _lab['sta']['GAME_STATE'].index('LOW_REW'), 
    _lab['ctr']['NULL'].index('NULL_ACTION')
] = 1.0 - _p_stoch
_B[0][ ##.
    _lab['sta']['GAME_STATE'].index('HIGH_REW'),
    _lab['sta']['GAME_STATE'].index('LOW_REW'), 
    _lab['ctr']['NULL'].index('NULL_ACTION')
] = _p_stoch

_B[0]
array([[[1.],
        [0.]],

       [[0.],
        [1.]]])
## setup our controllable factor
_B[1] = np.zeros((_car_state[1], _car_state[1], _car_contr[1])) ##. 
_B[1][ ##.
    _lab['sta']['PLAYING_VS_SAMPLING'].index('START'), 
    :, 
    _lab['ctr']['PLAYING_VS_SAMPLING_CONTROL'].index('START_ACTION')
] = 1.0
_B[1][ ##.
    _lab['sta']['PLAYING_VS_SAMPLING'].index('PLAYING'), 
    :, 
    _lab['ctr']['PLAYING_VS_SAMPLING_CONTROL'].index('PLAY_ACTION')
] = 1.0
_B[1][ ##.
    _lab['sta']['PLAYING_VS_SAMPLING'].index('SAMPLING'), 
    :, 
    _lab['ctr']['PLAYING_VS_SAMPLING_CONTROL'].index('SAMPLE_ACTION')
] = 1.0

_B[1]
array([[[1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.]],

       [[0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.]],

       [[0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.]]])
print(f'=== _car_contr:\n{_car_contr}')
print(f'=== _car_state:\n{_car_state}')
_B
=== _car_contr:
[1, 3]
=== _car_state:
[2, 3]
array([array([[[1.],
               [0.]],

              [[0.],
               [1.]]]), array([[[1., 0., 0.],
                                [1., 0., 0.],
                                [1., 0., 0.]],

                               [[0., 1., 0.],
                                [0., 1., 0.],
                                [0., 1., 0.]],

                               [[0., 0., 1.],
                                [0., 0., 1.],
                                [0., 0., 1.]]])], dtype=object)

Prior preferences

Now we parameterise the C vector, or the prior beliefs about observations. This will be used in the expression of the prior over actions, which is technically a softmax function of the negative expected free energy of each action. It is the equivalent of the exponentiated reward function in reinforcement learning treatments.
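
To make that softmax relationship concrete, here is a tiny illustration with made-up numbers (the Agent class computes this internally when infer_policies() is called below; gamma is the policy precision, which defaults to 16.0 in pymdp’s Agent):

## illustration only: hypothetical negative expected free energies for three policies
GNeg_example = np.array([-3.6, -3.5, -3.4])
q_pi_example = softmax(16.0 * GNeg_example)  ## softmax of gamma * (-G)
print(q_pi_example)  ## the policy with the least expected free energy gets the most mass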

_C = utils.obj_array_zeros([car_o for car_o in _car_obser]) ##.
_C
array([array([0., 0., 0.]), array([0., 0., 0.]), array([0., 0., 0.])],
      dtype=object)
## make the observation we've a priori named `REWARD` actually desirable, by building 
##   a high prior expectation of encountering it 
_C[1][ ##.
    _lab['obs']['GAME_OUTCOME'].index('REWARD'),
] = 1.0
## make the observation we've a priori named `PUN` actually aversive, by building a 
##   low prior expectation of encountering it
_C[1][ ##.
    _lab['obs']['GAME_OUTCOME'].index('PUN'),
] = -1.0

## the above code implies the following for the `NEUTRAL` observation:
## we don't need to write this - but it's basically just saying that observing `NEUTRAL` 
##   is in between reward and punishment
_C[1][ ##.
    _lab['obs']['GAME_OUTCOME'].index('NEUTRAL'),
] = 0.0

_C[1]
array([ 1., -1.,  0.])
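
One way to read these numbers (an added aside, not in the original demo): treating the C vector as log-preferences, a softmax over it gives the distribution of outcomes the agent would ‘like’ to observe:

softmax(_C[1])  ## approx. array([0.665, 0.090, 0.245]): REWARD favoured, PUN avoided, NEUTRAL in between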

Initialise an instance of the Agent() class:

All you have to do is call Agent(generative_model_params...), where generative_model_params are your A, B, C’s… and whatever other parameters of the generative model you want to specify.

_agent = Agent(A=_A, B=_B, C=_C, control_fac_idx=_control_fac_idx)
_agent
<pymdp.agent.Agent at 0x7fcd504e4ee0>

Generative process:

An important note: the generative process doesn’t have to be described by A and B matrices - it can just be the arbitrary ‘rules of the game’ that you ‘write in’ as a modeller. But here we just use the same transition/likelihood matrices to make the sampling process straightforward.
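
For concreteness, here is a minimal sketch of what such a rule-based generative process could look like (a hypothetical function, not used below; the probabilities only roughly mirror the A matrices defined above):

## a hand-written environment implementing the same 'rules of the game' without matrices
def step_environment_rules(game_state, action, rng=np.random.default_rng()):
    position = action                          ## 0=START, 1=PLAYING, 2=SAMPLING
    if position == 1:                          ## PLAYING: outcome depends on the hidden game state
        p_reward = 0.73 if game_state == 0 else 0.27
        game_outcome = 0 if rng.random() < p_reward else 1   ## REWARD or PUN
        game_state_obs = 2                     ## NO_EVIDENCE while playing
    elif position == 2:                        ## SAMPLING: noisy evidence about the game state
        game_state_obs = game_state if rng.random() < 0.8 else 1 - game_state
        game_outcome = 2                       ## NEUTRAL
    else:                                      ## START: uninformative observations
        game_state_obs, game_outcome = 2, 2
    return position, [game_state_obs, game_outcome, position]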

## transition/observation matrices characterising the generative process
## _A_gp = copy.deepcopy(_A)
_Ă = copy.deepcopy(_A) ##.
## _B_gp = copy.deepcopy(_B)
_B̆ = copy.deepcopy(_B) ##.

Initialise the simulation

## initial state
_T = 20 ## number of timesteps in the simulation
# _T = 100 ## number of timesteps in the simulation
_obser = [ ## initial observation
    _lab['obs']['GAME_STATE_OBS'].index('NO_EVIDENCE'), 
    _lab['obs']['GAME_OUTCOME'].index('NEUTRAL'),
    _lab['obs']['ACTION_SELF_OBS'].index('START_O')
]; print(f'{_obser=}')

_state = [ ## initial (true) state
    _lab['sta']['GAME_STATE'].index('HIGH_REW'), 
    _lab['sta']['PLAYING_VS_SAMPLING'].index('START')
]; print(f'{_state=}')
_obser=[2, 2, 0]
_state=[0, 0]

Create some string names for the state, observation, and action indices to help with print statements

_state_idx_names = [_lab['sta'][sfn] for sfn in _sta_fac_names]; print(f'{_state_idx_names=}') ##.
_obser_idx_names = [_lab['obs'][omn] for omn in _obs_mod_names]; print(f'{_obser_idx_names=}')
_action_idx_names = [_lab['ctr'][cfn] for cfn in _ctr_fac_names]; print(f'{_action_idx_names=}') ##.
_state_idx_names=[['HIGH_REW', 'LOW_REW'], ['START', 'PLAYING', 'SAMPLING']]
_obser_idx_names=[['HIGH_REW_EVIDENCE', 'LOW_REW_EVIDENCE', 'NO_EVIDENCE'], ['REWARD', 'PUN', 'NEUTRAL'], ['START_O', 'PLAY_O', 'SAMPLE_O']]
_action_idx_names=[['NULL_ACTION'], ['START_ACTION', 'PLAY_ACTION', 'SAMPLE_ACTION']]

Run simulation

_sta_facs = {'GAME_STATE': [], 'PLAYING_VS_SAMPLING': []}
_ctr_facs = {'NULL': [], 'PLAYING_VS_SAMPLING_CONTROL': []}
_bel_facs = {'GAME_STATE': [], 'PLAYING_VS_SAMPLING': []}
_obs_mods = {'GAME_STATE_OBS': [], 'GAME_OUTCOME': [], 'ACTION_SELF_OBS': []}
## min_F = []
_qIpiIs = []
_GNegs = []
for t in range(_T):
    print(f"\nTime {t}:")
    print(f"State: {[(_sta_fac_names[sfi], _state_idx_names[sfi][_state[sfi]]) for sfi in range(len(_sta_fac_names))]}") ##.
    for sfi, sfn in enumerate(_sta_fac_names):
        _sta_facs[sfn].append(_state_idx_names[sfi][_state[sfi]])
    print(f"Obser: {[(_obs_mod_names[omi], _obser_idx_names[omi][_obser[omi]]) for omi in range(len(_obs_mod_names))]}") ##.    
    for omi, omn in enumerate(_obs_mod_names):
        _obs_mods[omn].append(_obser_idx_names[omi][_obser[omi]])

    ## update agent
    belief_state = _agent.infer_states(_obser) ##.
    ## _agent.infer_policies()
    qIpiI, GNeg = _agent.infer_policies() ##.posterior over policies and negative EFE
    print(f'{qIpiI=}')
    print(f'{GNeg=}')
    _qIpiIs.append(qIpiI)
    _GNegs.append(GNeg)
    action = _agent.sample_action()
    ## min_F.append(np.min(_agent.F)) ##.does not have .F
    
    ## update environment
    for sfi, sf in enumerate(_state):
        _state[sfi] = utils.sample(_B̆[sfi][:, sf, int(action[sfi])]) ##.
    for omi, _ in enumerate(_obser): ##.
        _obser[omi] = utils.sample(_Ă[omi][:, _state[0], _state[1]]) ##.
        
    print(f"Beliefs: {[(_sta_fac_names[sfi], belief_state[sfi].round(3).T) for sfi in range(len(_sta_fac_names))]}") ##.
    for sfi, sfn in enumerate(_sta_fac_names):
        _bel_facs[sfn].append( _lab['sta'][sfn][int(np.argmax(belief_state[sfi].round(3).T))] )
    print(f"Action: {[(_ctr_fac_names[a], _action_idx_names[a][int(action[a])]) for a in range(len(_sta_fac_names))]}") ##.
    for cfi, cfn in enumerate(_ctr_fac_names):
        _ctr_facs[cfn].append(_action_idx_names[cfi][int(action[cfi])])

Time 0:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'START')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'START_O')]
qIpiI=array([0.03478899, 0.20528675, 0.75992426])
GNeg=array([-3.60483043, -3.49388625, -3.41208556])
Beliefs: [('GAME_STATE', array([0.5, 0.5])), ('PLAYING_VS_SAMPLING', array([1., 0., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'SAMPLE_ACTION')]

Time 1:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'SAMPLING')]
Obser: [('GAME_STATE_OBS', 'LOW_REW_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'SAMPLE_O')]
qIpiI=array([0.1162517 , 0.00435522, 0.87939308])
GNeg=array([-3.60483043, -3.81010428, -3.47836328])
Beliefs: [('GAME_STATE', array([0.2, 0.8])), ('PLAYING_VS_SAMPLING', array([0., 0., 1.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'SAMPLE_ACTION')]

Time 2:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'SAMPLING')]
Obser: [('GAME_STATE_OBS', 'HIGH_REW_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'SAMPLE_O')]
qIpiI=array([0.03478899, 0.20528675, 0.75992426])
GNeg=array([-3.60483043, -3.49388625, -3.41208556])
Beliefs: [('GAME_STATE', array([0.5, 0.5])), ('PLAYING_VS_SAMPLING', array([0., 0., 1.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'SAMPLE_ACTION')]

Time 3:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'SAMPLING')]
Obser: [('GAME_STATE_OBS', 'LOW_REW_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'SAMPLE_O')]
qIpiI=array([0.1162517 , 0.00435522, 0.87939308])
GNeg=array([-3.60483043, -3.81010428, -3.47836328])
Beliefs: [('GAME_STATE', array([0.2, 0.8])), ('PLAYING_VS_SAMPLING', array([0., 0., 1.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'SAMPLE_ACTION')]

Time 4:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'SAMPLING')]
Obser: [('GAME_STATE_OBS', 'LOW_REW_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'SAMPLE_O')]
qIpiI=array([0.3264811 , 0.00071881, 0.6728001 ])
GNeg=array([-3.60483043, -3.98723886, -3.55963817])
Beliefs: [('GAME_STATE', array([0.059, 0.941])), ('PLAYING_VS_SAMPLING', array([0., 0., 1.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'SAMPLE_ACTION')]

Time 5:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'SAMPLING')]
Obser: [('GAME_STATE_OBS', 'HIGH_REW_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'SAMPLE_O')]
qIpiI=array([0.1162517 , 0.00435522, 0.87939308])
GNeg=array([-3.60483043, -3.81010428, -3.47836328])
Beliefs: [('GAME_STATE', array([0.2, 0.8])), ('PLAYING_VS_SAMPLING', array([0., 0., 1.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'SAMPLE_ACTION')]

Time 6:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'SAMPLING')]
Obser: [('GAME_STATE_OBS', 'HIGH_REW_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'SAMPLE_O')]
qIpiI=array([0.03478899, 0.20528675, 0.75992426])
GNeg=array([-3.60483043, -3.49388625, -3.41208556])
Beliefs: [('GAME_STATE', array([0.5, 0.5])), ('PLAYING_VS_SAMPLING', array([0., 0., 1.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'SAMPLE_ACTION')]

Time 7:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'SAMPLING')]
Obser: [('GAME_STATE_OBS', 'LOW_REW_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'SAMPLE_O')]
qIpiI=array([0.1162517 , 0.00435522, 0.87939308])
GNeg=array([-3.60483043, -3.81010428, -3.47836328])
Beliefs: [('GAME_STATE', array([0.2, 0.8])), ('PLAYING_VS_SAMPLING', array([0., 0., 1.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'SAMPLE_ACTION')]

Time 8:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'SAMPLING')]
Obser: [('GAME_STATE_OBS', 'HIGH_REW_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'SAMPLE_O')]
qIpiI=array([0.03478899, 0.20528675, 0.75992426])
GNeg=array([-3.60483043, -3.49388625, -3.41208556])
Beliefs: [('GAME_STATE', array([0.5, 0.5])), ('PLAYING_VS_SAMPLING', array([0., 0., 1.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'SAMPLE_ACTION')]

Time 9:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'SAMPLING')]
Obser: [('GAME_STATE_OBS', 'HIGH_REW_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'SAMPLE_O')]
qIpiI=array([0.00362533, 0.96895062, 0.02742404])
GNeg=array([-3.60483043, -3.25556369, -3.47836328])
Beliefs: [('GAME_STATE', array([0.8, 0.2])), ('PLAYING_VS_SAMPLING', array([0., 0., 1.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]

Time 10:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'REWARD'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([0.00121169, 0.9954936 , 0.00329471])
GNeg=array([-3.60483043, -3.18537893, -3.54231138])
Beliefs: [('GAME_STATE', array([0.916, 0.084])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]

Time 11:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'PUN'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([0.00362533, 0.96895062, 0.02742404])
GNeg=array([-3.60483043, -3.25556369, -3.47836328])
Beliefs: [('GAME_STATE', array([0.8, 0.2])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]

Time 12:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'REWARD'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([0.00121169, 0.9954936 , 0.00329471])
GNeg=array([-3.60483043, -3.18537893, -3.54231138])
Beliefs: [('GAME_STATE', array([0.916, 0.084])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]

Time 13:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'PUN'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([0.00362533, 0.96895062, 0.02742404])
GNeg=array([-3.60483043, -3.25556369, -3.47836328])
Beliefs: [('GAME_STATE', array([0.8, 0.2])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]

Time 14:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'REWARD'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([0.00121169, 0.9954936 , 0.00329471])
GNeg=array([-3.60483043, -3.18537893, -3.54231138])
Beliefs: [('GAME_STATE', array([0.916, 0.084])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]

Time 15:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'REWARD'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([7.89044066e-04, 9.98014004e-01, 1.19695225e-03])
GNeg=array([-3.60483043, -3.15841165, -3.57878595])
Beliefs: [('GAME_STATE', array([0.967, 0.033])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]

Time 16:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'REWARD'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([6.73423691e-04, 9.98535582e-01, 7.90994555e-04])
GNeg=array([-3.60483043, -3.14847603, -3.59477315])
Beliefs: [('GAME_STATE', array([0.988, 0.012])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]

Time 17:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'REWARD'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([6.35368381e-04, 9.98689779e-01, 6.74852661e-04])
GNeg=array([-3.60483043, -3.14483077, -3.60106234])
Beliefs: [('GAME_STATE', array([0.995, 0.005])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]

Time 18:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'PUN'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([6.73423691e-04, 9.98535582e-01, 7.90994555e-04])
GNeg=array([-3.60483043, -3.14847603, -3.59477315])
Beliefs: [('GAME_STATE', array([0.988, 0.012])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]

Time 19:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'REWARD'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([6.35368381e-04, 9.98689779e-01, 6.74852661e-04])
GNeg=array([-3.60483043, -3.14483077, -3.60106234])
Beliefs: [('GAME_STATE', array([0.995, 0.005])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]
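
Visualise the simulation

Finally, we plot the chosen actions, the true hidden states, the agent’s inferred (most probable) states, and the observations at each timestep.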
## one colour map per subplot panel, in plotting order:
##   actions (NULL, PLAYING_VS_SAMPLING_CONTROL), true states, beliefs, observations
colors = [
    {'NULL_ACTION':'black'},
    {'START_ACTION':'red', 'PLAY_ACTION':'green', 'SAMPLE_ACTION': 'blue'},

    {'HIGH_REW':'orange', 'LOW_REW':'purple'},
    {'START':'red', 'PLAYING':'green', 'SAMPLING': 'blue'},

    {'HIGH_REW':'orange', 'LOW_REW':'purple'},
    {'START':'red', 'PLAYING':'green', 'SAMPLING': 'blue'},

    {'HIGH_REW_EVIDENCE':'orange', 'LOW_REW_EVIDENCE':'purple', 'NO_EVIDENCE':'pink'},
    {'REWARD':'red', 'PUN':'green', 'NEUTRAL': 'blue'},
    {'START_O':'red', 'PLAY_O':'green', 'SAMPLE_O': 'blue'}
]

ylabel_size = 12
msi = 7 ##markersize for Line2D, diameter in points
siz = (msi/2)**2 * np.pi ##size for scatter, area of marker in points squared

fig = plt.figure(figsize=(9, 6))
## gs = GridSpec(6, 1, figure=fig, height_ratios=[1, 3, 1, 3, 3, 1])
gs = GridSpec(9, 1, figure=fig, height_ratios=[1, 1, 1, 1, 1, 1, 1, 1, 1])
ax = [fig.add_subplot(gs[i]) for i in range(9)]

i = 0
ax[i].set_title(f'Agent Demo', fontweight='bold',fontsize=14)
y_pos = 0
for t, s in zip(range(_T), _ctr_facs['NULL']):
    ax[i].scatter(t, y_pos, color=colors[i][s], s=siz)
ax[i].set_ylabel('$a^{NULL}_t$', rotation=0, fontweight='bold', fontsize=ylabel_size)
ax[i].set_yticks([])
ax[i].set_xticklabels([])
leg_items = [
    Line2D([0],[0],marker='o',color='w',markerfacecolor='black',markersize=msi,label='NULL_ACTION')]
ax[i].legend(handles=leg_items, bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, labelspacing=0.1)
ax[i].spines['top'].set_visible(False); ax[i].spines['right'].set_visible(False)

i = 1
y_pos = 0
for t, s in zip(range(_T), _ctr_facs['PLAYING_VS_SAMPLING_CONTROL']):
    ax[i].scatter(t, y_pos, color=colors[i][s], s=siz, label=s)
ax[i].set_ylabel('$a^{PLAYING\_VS\_SAMPLING\_CONTROL}_t$', rotation=0, fontweight='bold', fontsize=ylabel_size)
ax[i].set_yticks([])
leg_items = [
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['START_ACTION'],markersize=msi,label='START_ACTION'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['PLAY_ACTION'],markersize=msi,label='PLAY_ACTION'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['SAMPLE_ACTION'],markersize=msi,label='SAMPLE_ACTION')]
ax[i].legend(handles=leg_items, bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, labelspacing=0.1)
ax[i].spines['top'].set_visible(False); ax[i].spines['right'].set_visible(False)

i = 2
y_pos = 0
for t, s in zip(range(_T), _sta_facs['GAME_STATE']):
    ax[i].scatter(t, y_pos, color=colors[i][s], s=siz)
ax[i].set_ylabel('$s^{GAME\_STATE}_t$', rotation=0, fontweight='bold', fontsize=ylabel_size)
ax[i].set_yticks([])
ax[i].set_xticklabels([])
leg_items = [
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['HIGH_REW'],markersize=msi,label='HIGH_REW'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['LOW_REW'],markersize=msi,label='LOW_REW')]
ax[i].legend(handles=leg_items, bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, labelspacing=0.1)
ax[i].spines['top'].set_visible(False); ax[i].spines['right'].set_visible(False)

i = 3
y_pos = 0
for t, s in zip(range(_T), _sta_facs['PLAYING_VS_SAMPLING']):
    ax[i].scatter(t, y_pos, color=colors[i][s], s=siz, label=s)
ax[i].set_ylabel('$s^{PLAYING\_VS\_SAMPLING}_t$', rotation=0, fontweight='bold', fontsize=ylabel_size)
ax[i].set_yticks([])
leg_items = [
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['START'],markersize=msi,label='START'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['PLAYING'],markersize=msi,label='PLAYING'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['SAMPLING'],markersize=msi,label='SAMPLING')]
ax[i].legend(handles=leg_items, bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, labelspacing=0.1)
ax[i].spines['top'].set_visible(False); ax[i].spines['right'].set_visible(False)

i = 4
y_pos = 0
for t, s in zip(range(_T), _bel_facs['GAME_STATE']): ##.plot the inferred beliefs, not the true states
    ax[i].scatter(t, y_pos, color=colors[i][s], s=siz)
ax[i].set_ylabel('$q(s)^{GAME\_STATE}_t$', rotation=0, fontweight='bold', fontsize=ylabel_size)
ax[i].set_yticks([])
ax[i].set_xticklabels([])
leg_items = [
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['HIGH_REW'],markersize=msi,label='HIGH_REW'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['LOW_REW'],markersize=msi,label='LOW_REW')]
ax[i].legend(handles=leg_items, bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, labelspacing=0.1)
ax[i].spines['top'].set_visible(False); ax[i].spines['right'].set_visible(False)

i = 5
y_pos = 0
for t, s in zip(range(_T), _bel_facs['PLAYING_VS_SAMPLING']): ##.plot the inferred beliefs, not the true states
    ax[i].scatter(t, y_pos, color=colors[i][s], s=siz, label=s)
ax[i].set_ylabel('$q(s)^{PLAYING\_VS\_SAMPLING}_t$', rotation=0, fontweight='bold', fontsize=ylabel_size)
ax[i].set_yticks([])
leg_items = [
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['START'],markersize=msi,label='START'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['PLAYING'],markersize=msi,label='PLAYING'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['SAMPLING'],markersize=msi,label='SAMPLING')]
ax[i].legend(handles=leg_items, bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, labelspacing=0.1)
ax[i].spines['top'].set_visible(False); ax[i].spines['right'].set_visible(False)

i = 6
y_pos = 0
for t, s in zip(range(_T), _obs_mods['GAME_STATE_OBS']):
    ax[i].scatter(t, y_pos, color=colors[i][s], s=siz, label=s)
ax[i].set_ylabel('$o^{GAME\_STATE\_OBS}_t$', rotation=0, fontweight='bold', fontsize=ylabel_size)
ax[i].set_yticks([])
leg_items = [
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['HIGH_REW_EVIDENCE'],markersize=msi,label='HIGH_REW_EVIDENCE'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['LOW_REW_EVIDENCE'],markersize=msi,label='LOW_REW_EVIDENCE'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['NO_EVIDENCE'],markersize=msi,label='NO_EVIDENCE')]
ax[i].legend(handles=leg_items, bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, labelspacing=0.1)
ax[i].spines['top'].set_visible(False); ax[i].spines['right'].set_visible(False)

i = 7
y_pos = 0
for t, s in zip(range(_T), _obs_mods['GAME_OUTCOME']):
    ax[i].scatter(t, y_pos, color=colors[i][s], s=siz, label=s)
ax[i].set_ylabel('$o^{GAME\_OUTCOME}_t$', rotation=0, fontweight='bold', fontsize=ylabel_size)
ax[i].set_yticks([])
leg_items = [
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['REWARD'],markersize=msi,label='REWARD'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['PUN'],markersize=msi,label='PUN'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['NEUTRAL'],markersize=msi,label='NEUTRAL')]
ax[i].legend(handles=leg_items, bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, labelspacing=0.1)
ax[i].spines['top'].set_visible(False); ax[i].spines['right'].set_visible(False)

i = 8
y_pos = 0
for t, s in zip(range(_T), _obs_mods['ACTION_SELF_OBS']):
    ax[i].scatter(t, y_pos, color=colors[i][s], s=siz, label=s)
ax[i].set_ylabel('$o^{ACTION\_SELF\_OBS}_t$', rotation=0, fontweight='bold', fontsize=ylabel_size)
ax[i].set_yticks([])
leg_items = [
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['START_O'],markersize=msi,label='START_O'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['PLAY_O'],markersize=msi,label='PLAY_O'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['SAMPLE_O'],markersize=msi,label='SAMPLE_O')]
ax[i].legend(handles=leg_items, bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, labelspacing=0.1)
ax[i].spines['top'].set_visible(False); ax[i].spines['right'].set_visible(False)

ax[i].xaxis.set_major_locator(MaxNLocator(integer=True))
## ax[i].xaxis.set_major_locator(MaxNLocator(nbins=10, integer=True))
ax[i].set_xlabel('$\mathrm{time,}\ t$', fontweight='bold', fontsize=12)

plt.tight_layout()
plt.subplots_adjust(hspace=0.1) ## Adjust this value as needed
plt.show()