import os
import sys
import pathlib
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import copy
from pprint import pprint ##.
path = pathlib.Path(os.getcwd())
module_path = str(path.parent) + '/'
sys.path.append(module_path)
from pymdp.agent import Agent
from pymdp import utils
from pymdp.maths import softmax
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D
from matplotlib.ticker import MaxNLocator
We use PyMDP and modify an existing demo:
https://github.com/infer-actively/pymdp/blob/master/examples/agent_demo.ipynb
The modifications are:
- Do some restructuring of the contents
- Use some of my preferred symbols
- Use the _car prefix (for cardinality of factors and modalities)
- Add a _lab dict for all labels
- Use implied label indices from _lab to set values of matrices
- Add visualization
- Prefix globals with _
- Lines where changes were made usually contain ##.
Active Inference Demo: Constructing a basic generative model from the “ground up”
This demo notebook provides a full walk-through of how to build a POMDP agent’s generative model and perform the active inference routine (inversion of the generative model) using the Agent() class of pymdp. We build the generative model from the ‘ground up’, directly encoding our own A, B, and C matrices.
Imports
First, import pymdp and the modules we’ll need.
The world (as represented by the agent’s generative model)
Observations
The observation modalities themselves are divided into 3 modalities. You can think of these as 3 independent sources of information that the agent has access to. You could think of this in direct perceptual terms - e.g. 3 different sensory organs like eyes, ears, & nose, that give you qualitatively-different kinds of information. Or you can think of it more abstractly - like getting your news from 3 different media sources (online news articles, Twitter feed, and Instagram).
1. Observations of the game state - GAME_STATE_OBS
The first observation modality is the GAME_STATE_OBS modality, and corresponds to observations that give the agent information about the GAME_STATE. There are three possible outcomes within this modality:
- HIGH_REW_EVIDENCE (GAME_STATE_OBS = 0)
- LOW_REW_EVIDENCE (GAME_STATE_OBS = 1)
- NO_EVIDENCE (GAME_STATE_OBS = 2)
So the first outcome can be described as lending evidence to the idea that the GAME_STATE is HIGH_REW; the second outcome can be described as lending evidence to the idea that the GAME_STATE is LOW_REW; and the third outcome within this modality doesn’t tell the agent one way or another whether the GAME_STATE is HIGH_REW or LOW_REW.
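To make the index convention concrete, here is a small sketch (an addition, not part of the original demo) showing that an outcome’s integer code is just its position in a label list - the same convention the _lab dict below relies on:
## Sketch (addition): the integer codes above are just positions in a label list.
game_state_obs_labels = ['HIGH_REW_EVIDENCE', 'LOW_REW_EVIDENCE', 'NO_EVIDENCE']
print(game_state_obs_labels.index('NO_EVIDENCE'))  ## -> 2
## Later in the notebook the same idea is used via _lab['obs']['GAME_STATE_OBS'].index(...)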
2. Reward observations - GAME_OUTCOME
The second observation modality is the GAME_OUTCOME modality, and corresponds to arbitrary observations that are functions of the GAME_STATE. We call the first outcome level of this modality REWARD (GAME_OUTCOME = 0), which gives you a hint about how we’ll set up the C matrix (the agent’s “utility function” over outcomes). We call the second outcome level of this modality PUN (GAME_OUTCOME = 1), and the third outcome level NEUTRAL (GAME_OUTCOME = 2).
By design, we will set up the A matrix such that the REWARD outcome is (expected to be) more likely when the GAME_STATE is HIGH_REW (0) and when the agent is in the PLAYING state, and that the PUN outcome is (expected to be) more likely when the GAME_STATE is LOW_REW (1) and the agent is in the PLAYING state. The NEUTRAL outcome is not expected to occur when the agent is playing the game, but will be expected to occur when the agent is in the SAMPLING state. This NEUTRAL outcome within the GAME_OUTCOME modality is thus a meaningless or ‘null’ observation that the agent gets when it’s not actually playing the game (because an observation has to be sampled nonetheless from all modalities).
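As a concrete preview (a sketch added here, not part of the original demo; the actual assignments appear in the A-matrix code below), the PLAYING column of the GAME_OUTCOME likelihood will be filled from a softmax mapping:
## Sketch: the reward/punishment probabilities used later come from a softmax over
## a two-element vector, e.g. for GAME_STATE = HIGH_REW while PLAYING:
import numpy as np
from pymdp.maths import softmax
high_rew_mapping = softmax(np.array([1.0, 0.0]))  ## ~[0.731, 0.269]
## i.e. P(REWARD | HIGH_REW, PLAYING) ~ 0.73 and P(PUN | HIGH_REW, PLAYING) ~ 0.27,
## while P(NEUTRAL | ., PLAYING) = 0.
print(high_rew_mapping)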
3. “Proprioceptive” or self-state observations - ACTION_SELF_OBS
The third observation modality is the ACTION_SELF_OBS modality, and corresponds to the agent observing what level of the PLAYING_VS_SAMPLING state it is currently in. These observations are direct, ‘unambiguous’ mappings to the true PLAYING_VS_SAMPLING state, and simply allow the agent to “know” whether it’s playing the game, sampling information to learn about the game state, or sitting at the START state. The levels of this outcome are thus START_O, PLAY_O, and SAMPLE_O, where the _O suffix simply distinguishes them from their corresponding hidden states, for which they provide direct evidence.
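Since this mapping is deterministic, each game-state slice of the corresponding likelihood array ends up as an identity matrix over the three positions; a tiny sketch (an addition) of the shape this takes:
## Sketch (assumes the shapes used below): for every value of GAME_STATE, the
## ACTION_SELF_OBS likelihood is an identity mapping from PLAYING_VS_SAMPLING
## states to their corresponding *_O observations.
import numpy as np
a2_slice = np.eye(3)  ## rows: START_O, PLAY_O, SAMPLE_O; cols: START, PLAYING, SAMPLING
print(a2_slice)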
Note about the arbitrariness of ‘labelling’ observations, before defining the A and C matrices.
There is a bit of a circularity here, in that we’re “pre-empting” what the A matrix (likelihood mapping) should look like, by giving these observations labels that imply particular roles or meanings. An observation per se doesn’t mean anything, it’s just some discrete index that distinguishes it from another observation. It’s only through its probabilistic relationship to hidden states (encoded in the A matrix, as we’ll see below) that we endow an observation with meaning. For example: by already labelling GAME_STATE_OBS=0 as HIGH_REW_EVIDENCE, that’s a hint about how we’re going to structure the A matrix for the GAME_STATE_OBS modality.
## ?utils.get_model_dimensions_from_labels
##.rewrite method to allow for shorter names so that the _lab dict can be used more
## easily when setting up the matrices
def get_model_dimensions_from_labels(model_labels):
    ## modalities = model_labels['observations']
    modalities = model_labels['obs'] ##.
    num_modalities = len(modalities.keys())
    num_obs = [len(modalities[modality]) for modality in modalities.keys()]
    ## factors = model_labels['states']
    factors = model_labels['sta'] ##.
    num_factors = len(factors.keys())
    num_states = [len(factors[factor]) for factor in factors.keys()]
    ## if 'actions' in model_labels.keys():
    if 'ctr' in model_labels.keys(): ##.
        ## controls = model_labels['actions']
        controls = model_labels['ctr'] ##.
        num_control_fac = len(controls.keys())
        num_controls = [len(controls[cfac]) for cfac in controls.keys()]
        return num_obs, num_modalities, num_states, num_factors, num_controls, num_control_fac
    else:
        return num_obs, num_modalities, num_states, num_factors
_lab = { ##.
    "ctr": {
        "NULL": [
            "NULL_ACTION",
        ],
        "PLAYING_VS_SAMPLING_CONTROL": [
            "START_ACTION",
            "PLAY_ACTION",
            "SAMPLE_ACTION"
        ],
    },
    "sta": {
        "GAME_STATE": [
            "HIGH_REW",
            "LOW_REW"
        ],
        "PLAYING_VS_SAMPLING": [
            "START",
            "PLAYING",
            "SAMPLING"
        ],
    },
    "obs": {
        "GAME_STATE_OBS": [
            "HIGH_REW_EVIDENCE",
            "LOW_REW_EVIDENCE",
            "NO_EVIDENCE"
        ],
        "GAME_OUTCOME": [
            "REWARD",
            "PUN",
            "NEUTRAL"
        ],
        "ACTION_SELF_OBS": [ ##.direct obser of hidden state PLAYING_VS_SAMPLING
            "START_O",
            "PLAY_O",
            "SAMPLE_O"
        ]
    },
}
_car_obser, _num_obser, \
_car_state, _num_state, \
_car_contr, _num_contr = get_model_dimensions_from_labels(_lab) ##.
_car_obser, _num_obser, _car_state, _num_state, _car_contr, _num_contr
([3, 3, 3], 3, [2, 3], 2, [1, 3], 2)
print(f'{_car_obser=}') ##.cardinality of observation modalities
print(f'{_num_obser=}') ##.number of observation modalities
print(f'{_car_state=}') ##.cardinality of state factors
print(f'{_num_state=}') ##.number of state factors
print(f'{_car_contr=}') ##.cardinality of control factors
print(f'{_num_contr=}') ##.number of control factors
_car_obser=[3, 3, 3]
_num_obser=3
_car_state=[2, 3]
_num_state=2
_car_contr=[1, 3]
_num_contr=2
##.
_ctr_fac_names = list(_lab['ctr'].keys()); print(f'{_ctr_fac_names=}') ##.control factor names
_sta_fac_names = list(_lab['sta'].keys()); print(f'{_sta_fac_names=}') ##.state factor names
_obs_mod_names = list(_lab['obs'].keys()); print(f'{_obs_mod_names=}') ##.observation modality names
_ctr_fac_names=['NULL', 'PLAYING_VS_SAMPLING_CONTROL']
_sta_fac_names=['GAME_STATE', 'PLAYING_VS_SAMPLING']
_obs_mod_names=['GAME_STATE_OBS', 'GAME_OUTCOME', 'ACTION_SELF_OBS']
Setting up observation likelihood matrix - first main component of generative model
## A = utils.obj_array_zeros([[o] + _car_state for _, o in enumerate(_car_obser)]) ##.
_A = utils.obj_array_zeros([[car_o] + _car_state for car_o in _car_obser]) ##.
_A
array([array([[[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.]]]), array([[[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.]]]),
array([[[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.]]])], dtype=object)
Set up the first modality’s likelihood mapping, corresponding to how "GAME_STATE_OBS" observations (i.e. modality_names[0]) are related to hidden states.
## they always get the 'no evidence' observation in the START STATE
_A[0][ ##.
    _lab['obs']['GAME_STATE_OBS'].index('NO_EVIDENCE'),
    :,
    _lab['sta']['PLAYING_VS_SAMPLING'].index('START')
] = 1.0
## they always get the 'no evidence' observation in the PLAYING STATE
_A[0][ ##.
    _lab['obs']['GAME_STATE_OBS'].index('NO_EVIDENCE'),
    :,
    _lab['sta']['PLAYING_VS_SAMPLING'].index('PLAYING')
] = 1.0
## the agent expects to see the HIGH_REW_EVIDENCE observation with 80% probability,
## if the GAME_STATE is HIGH_REW, and the agent is in the SAMPLING state
_A[0][ ##.
    _lab['obs']['GAME_STATE_OBS'].index('HIGH_REW_EVIDENCE'),
    _lab['sta']['GAME_STATE'].index('HIGH_REW'),
    _lab['sta']['PLAYING_VS_SAMPLING'].index('SAMPLING')
] = 0.8
## the agent expects to see the LOW_REW_EVIDENCE observation with 20% probability,
## if the GAME_STATE is HIGH_REW, and the agent is in the SAMPLING state
_A[0][ ##.
    _lab['obs']['GAME_STATE_OBS'].index('LOW_REW_EVIDENCE'),
    _lab['sta']['GAME_STATE'].index('HIGH_REW'),
    _lab['sta']['PLAYING_VS_SAMPLING'].index('SAMPLING')
] = 0.2

## the agent expects to see the LOW_REW_EVIDENCE observation with 80% probability,
## if the GAME_STATE is LOW_REW, and the agent is in the SAMPLING state
_A[0][ ##.
    _lab['obs']['GAME_STATE_OBS'].index('LOW_REW_EVIDENCE'),
    _lab['sta']['GAME_STATE'].index('LOW_REW'),
    _lab['sta']['PLAYING_VS_SAMPLING'].index('SAMPLING')
] = 0.8
## the agent expects to see the HIGH_REW_EVIDENCE observation with 20% probability,
## if the GAME_STATE is LOW_REW, and the agent is in the SAMPLING state
_A[0][ ##.
    _lab['obs']['GAME_STATE_OBS'].index('HIGH_REW_EVIDENCE'),
    _lab['sta']['GAME_STATE'].index('LOW_REW'),
    _lab['sta']['PLAYING_VS_SAMPLING'].index('SAMPLING')
] = 0.2

## quick way to do this
## A[0][:, :, 0] = 1.0
## A[0][:, :, 1] = 1.0
## A[0][:, :, 2] = np.array([[0.8, 0.2], [0.2, 0.8], [0.0, 0.0]])

_A[0]
array([[[0. , 0. , 0.8],
[0. , 0. , 0.2]],
[[0. , 0. , 0.2],
[0. , 0. , 0.8]],
[[1. , 1. , 0. ],
[1. , 1. , 0. ]]])
Set up the second modality’s likelihood mapping, corresponding to how "GAME_OUTCOME" observations (i.e. modality_names[1]) are related to hidden states.
## regardless of the game state, if you're at the START, you see the 'neutral' outcome
_A[1][ ##.
    _lab['obs']['GAME_OUTCOME'].index('NEUTRAL'),
    :,
    _lab['sta']['PLAYING_VS_SAMPLING'].index('START')
] = 1.0

## regardless of the game state, if you're in the SAMPLING state, you see the 'neutral' outcome
_A[1][ ##.
    _lab['obs']['GAME_OUTCOME'].index('NEUTRAL'),
    :,
    _lab['sta']['PLAYING_VS_SAMPLING'].index('SAMPLING')
] = 1.0

## this is the distribution that maps from the "GAME_STATE" to the "GAME_OUTCOME"
## observation, in the case that "GAME_STATE" is `HIGH_REW`
_HIGH_REW_MAPPING = softmax(np.array([1.0, 0]))

## this is the distribution that maps from the "GAME_STATE" to the "GAME_OUTCOME"
## observation, in the case that "GAME_STATE" is `LOW_REW`
_LOW_REW_MAPPING = softmax(np.array([0.0, 1.0]))

## fill out the A matrix using the reward probabilities
_A[1][ ##.
    _lab['obs']['GAME_OUTCOME'].index('REWARD'),
    _lab['sta']['GAME_STATE'].index('HIGH_REW'),
    _lab['sta']['PLAYING_VS_SAMPLING'].index('PLAYING')
] = _HIGH_REW_MAPPING[0]
_A[1][ ##.
    _lab['obs']['GAME_OUTCOME'].index('PUN'),
    _lab['sta']['GAME_STATE'].index('HIGH_REW'),
    _lab['sta']['PLAYING_VS_SAMPLING'].index('PLAYING')
] = _HIGH_REW_MAPPING[1]
_A[1][ ##.
    _lab['obs']['GAME_OUTCOME'].index('REWARD'),
    _lab['sta']['GAME_STATE'].index('LOW_REW'),
    _lab['sta']['PLAYING_VS_SAMPLING'].index('PLAYING')
] = _LOW_REW_MAPPING[0]
_A[1][ ##.
    _lab['obs']['GAME_OUTCOME'].index('PUN'),
    _lab['sta']['GAME_STATE'].index('LOW_REW'),
    _lab['sta']['PLAYING_VS_SAMPLING'].index('PLAYING')
] = _LOW_REW_MAPPING[1]

## quick way to do this
## A[1][2, :, 0] = np.ones(num_states[0])
## A[1][0:2, :, 1] = softmax(np.eye(num_obs[1] - 1)) # relationship of game state to reward observations (mapping between reward-state (first hidden state factor) and rewards (Good vs Bad))
## A[1][2, :, 2] = np.ones(num_states[0])

_A[1]
array([[[0. , 0.73105858, 0. ],
[0. , 0.26894142, 0. ]],
[[0. , 0.26894142, 0. ],
[0. , 0.73105858, 0. ]],
[[1. , 0. , 1. ],
[1. , 0. , 1. ]]])
Set up the third modality’s likelihood mapping, corresponding to how "ACTION_SELF_OBS" observations (i.e. modality_names[2]) are related to hidden states.
_A[2][ ##.
    _lab['obs']['ACTION_SELF_OBS'].index('START_O'),
    :,
    _lab['sta']['PLAYING_VS_SAMPLING'].index('START')
] = 1.0
_A[2][ ##.
    _lab['obs']['ACTION_SELF_OBS'].index('PLAY_O'),
    :,
    _lab['sta']['PLAYING_VS_SAMPLING'].index('PLAYING')
] = 1.0
_A[2][ ##.
    _lab['obs']['ACTION_SELF_OBS'].index('SAMPLE_O'),
    :,
    _lab['sta']['PLAYING_VS_SAMPLING'].index('SAMPLING')
] = 1.0

## quick way to do this
## modality_idx, factor_idx = 2, 2
## for sampling_state_i in range(num_states[factor_idx]):
##     A[modality_idx][sampling_state_i,:,sampling_state_i] = 1.0

_A[2]
array([[[1., 0., 0.],
[1., 0., 0.]],
[[0., 1., 0.],
[0., 1., 0.]],
[[0., 0., 1.],
[0., 0., 1.]]])
print(f'=== _car_state:\n{_car_state}')
print(f'=== _car_obser:\n{_car_obser}')
_A
=== _car_state:
[2, 3]
=== _car_obser:
[3, 3, 3]
array([array([[[0. , 0. , 0.8],
[0. , 0. , 0.2]],
[[0. , 0. , 0.2],
[0. , 0. , 0.8]],
[[1. , 1. , 0. ],
[1. , 1. , 0. ]]]),
array([[[0. , 0.73105858, 0. ],
[0. , 0.26894142, 0. ]],
[[0. , 0.26894142, 0. ],
[0. , 0.73105858, 0. ]],
[[1. , 0. , 1. ],
[1. , 0. , 1. ]]]),
array([[[1., 0., 0.],
[1., 0., 0.]],
[[0., 1., 0.],
[0., 1., 0.]],
[[0., 0., 1.],
[0., 0., 1.]]])], dtype=object)
Control states
The ‘control state’ factors are the agent’s representation of the control states (or actions) that it believes can influence the dynamics of the hidden states - i.e. hidden state factors that are under the influence of control states are ‘controllable’. In practice, we often encode every hidden state factor as being under the influence of control states, but the ‘uncontrollable’ hidden state factors are driven by a trivially-1-dimensional control state or action-affordance. This trivial action simply ‘maintains the default environmental dynamics as they are’, i.e. does nothing. This will become clearer when we set up the transition model (the B matrices) below.
1. NULL
This reflects the agent’s lack of ability to influence the GAME_STATE using policies or actions. The dimensionality of this control factor is 1, and there is only one action along this control factor: NULL_ACTION, or “don’t do anything to the environment”. This just means that the transition dynamics along the GAME_STATE hidden state factor have their own, uncontrollable dynamics that are not conditioned on this NULL control state - or rather, are always conditioned on an unchanging, 1-dimensional NULL_ACTION.
2. PLAYING_VS_SAMPLING_CONTROL
This is a control factor that reflects the agent’s ability to move itself between the START, PLAYING and SAMPLING states of the PLAYING_VS_SAMPLING hidden state factor. The levels/values of this control factor are START_ACTION, PLAY_ACTION, and SAMPLE_ACTION. When we describe the B matrices below, we will set up the transition dynamics of the PLAYING_VS_SAMPLING hidden state factor such that they are totally determined by the value of the PLAYING_VS_SAMPLING_CONTROL factor.
(Controllable-) Transition Dynamics
Importantly, some hidden state factors are controllable by the agent, meaning that the probability of being in state \(i\) at \(t+1\) isn’t merely a function of the state at \(t\), but also of actions (or, from the generative model’s perspective, control states). So each transition likelihood or B matrix encodes conditional probability distributions over states at \(t+1\), where the conditioning variables are both the states at \(t\) and the actions at \(t\). This extra conditioning on control states is encoded by a third, lagging dimension on each factor-specific B matrix. So they are technically B “tensors” or an array of action-conditioned B matrices.
For example, in our case the 2nd hidden state factor (PLAYING_VS_SAMPLING) is under the control of the agent, which means the corresponding transition likelihoods B[1] are index-able by both previous state and action.
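As a sketch of what this indexing means in practice (an addition; it assumes the _B and _lab objects constructed below), fixing the action index of B[1] yields an ordinary state-to-state transition matrix:
## Sketch (assumes _B and _lab as constructed below): fixing the action index turns
## the action-conditioned B "tensor" into a plain transition matrix over PLAYING_VS_SAMPLING.
play_idx = _lab['ctr']['PLAYING_VS_SAMPLING_CONTROL'].index('PLAY_ACTION')
transition_under_play = _B[1][:, :, play_idx]
print(transition_under_play)  ## every column maps to the PLAYING state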
## this is the (non-trivial) controllable factor, where there will be a >1-dimensional
## control state along this factor
_control_fac_idx = [1] ##.used in Agent constructor
_B = utils.obj_array(_num_state) ##.
_B
array([None, None], dtype=object)
_B[0] = np.zeros((_car_state[0], _car_state[0], _car_contr[0])) ##.
_B[0]
array([[[0.],
[0.]],
[[0.],
[0.]]])
_p_stoch = 0.0

## we cannot influence factor zero, set up the 'default' stationary dynamics -
## one state just maps to itself at the next timestep with very high probability,
## by default. So this means the reward state can change from one to another with
## some low probability (p_stoch)
_B[0][ ##.
    _lab['sta']['GAME_STATE'].index('HIGH_REW'),
    _lab['sta']['GAME_STATE'].index('HIGH_REW'),
    _lab['ctr']['NULL'].index('NULL_ACTION')
] = 1.0 - _p_stoch
_B[0][ ##.
    _lab['sta']['GAME_STATE'].index('LOW_REW'),
    _lab['sta']['GAME_STATE'].index('HIGH_REW'),
    _lab['ctr']['NULL'].index('NULL_ACTION')
] = _p_stoch

_B[0][ ##.
    _lab['sta']['GAME_STATE'].index('LOW_REW'),
    _lab['sta']['GAME_STATE'].index('LOW_REW'),
    _lab['ctr']['NULL'].index('NULL_ACTION')
] = 1.0 - _p_stoch
_B[0][ ##.
    _lab['sta']['GAME_STATE'].index('HIGH_REW'),
    _lab['sta']['GAME_STATE'].index('LOW_REW'),
    _lab['ctr']['NULL'].index('NULL_ACTION')
] = _p_stoch

_B[0]
array([[[1.],
[0.]],
[[0.],
[1.]]])
## setup our controllable factor
_B[1] = np.zeros((_car_state[1], _car_state[1], _car_contr[1])) ##.
_B[1][ ##.
    _lab['sta']['PLAYING_VS_SAMPLING'].index('START'),
    :,
    _lab['ctr']['PLAYING_VS_SAMPLING_CONTROL'].index('START_ACTION')
] = 1.0
_B[1][ ##.
    _lab['sta']['PLAYING_VS_SAMPLING'].index('PLAYING'),
    :,
    _lab['ctr']['PLAYING_VS_SAMPLING_CONTROL'].index('PLAY_ACTION')
] = 1.0
_B[1][ ##.
    _lab['sta']['PLAYING_VS_SAMPLING'].index('SAMPLING'),
    :,
    _lab['ctr']['PLAYING_VS_SAMPLING_CONTROL'].index('SAMPLE_ACTION')
] = 1.0

_B[1]
array([[[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.]],
[[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.]],
[[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.]]])
print(f'=== _car_contr:\n{_car_contr}')
print(f'=== _car_state:\n{_car_state}')
_B
=== _car_contr:
[1, 3]
=== _car_state:
[2, 3]
array([array([[[1.],
[0.]],
[[0.],
[1.]]]), array([[[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.]],
[[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.]],
[[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.]]])], dtype=object)
Prior preferences
Now we parameterise the C vector, or the prior beliefs about observations. This will be used in the expression of the prior over actions, which is technically a softmax function of the negative expected free energy of each action. It is the equivalent of the exponentiated reward function in reinforcement learning treatments.
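To make "softmax of the negative expected free energy" concrete, here is a minimal numeric sketch (an illustration only; pymdp’s own update additionally supports a habit prior and other options). The numbers are made up, but the form matches the qIpiI/GNeg pairs printed during the simulation below:
## Sketch: turning (negative) expected free energies into a distribution over policies.
import numpy as np
from pymdp.maths import softmax
neg_G = np.array([-3.6, -3.2, -3.5])  ## negative EFE of three policies (higher is better)
q_pi = softmax(neg_G * 16.0)          ## 16.0 stands in for the policy precision gamma
print(q_pi)                           ## policies with lower EFE get more probability mass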
_C = utils.obj_array_zeros([car_o for car_o in _car_obser]) ##.
_C
array([array([0., 0., 0.]), array([0., 0., 0.]), array([0., 0., 0.])],
dtype=object)
## make the observation we've a priori named `REWARD` actually desirable, by building
## a high prior expectation of encountering it
_C[1][ ##.
    _lab['obs']['GAME_OUTCOME'].index('REWARD'),
] = 1.0
## make the observation we've a priori named `PUN` actually aversive, by building a
## low prior expectation of encountering it
_C[1][ ##.
    _lab['obs']['GAME_OUTCOME'].index('PUN'),
] = -1.0

## the above code implies the following for the `NEUTRAL` observation:
## we don't need to write this - but it's basically just saying that observing `NEUTRAL`
## is in between reward and punishment
_C[1][ ##.
    _lab['obs']['GAME_OUTCOME'].index('NEUTRAL'),
] = 0.0

_C[1]
array([ 1., -1., 0.])
Initialise an instance of the Agent() class:
All you have to do is call Agent(generative_model_params...) where generative_model_params are your A, B, C’s... and whatever parameters of the generative model you want to specify.
_agent = Agent(A=_A, B=_B, C=_C, control_fac_idx=_control_fac_idx)
_agent
<pymdp.agent.Agent at 0x7fcd504e4ee0>
Generative process:
An important note: the generative process doesn’t have to be described by A and B matrices - it can just be the arbitrary ‘rules of the game’ that you ‘write in’ as a modeller. But here we just use the same transition/likelihood matrices to make the sampling process straightforward.
## transition/observation matrices characterising the generative process
## _A_gp = copy.deepcopy(_A)
_Ă = copy.deepcopy(_A) ##.
## _B_gp = copy.deepcopy(_B)
_B̆ = copy.deepcopy(_B) ##.
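For contrast, here is a minimal sketch (a hypothetical helper, not used anywhere in this notebook) of what such ‘arbitrary rules’ could look like written directly as code instead of as matrices:
## Hypothetical sketch only - this notebook keeps the matrix-based generative process above.
def env_step_rules(state, action):
    game_state, _ = state        ## the game state never changes here (p_stoch = 0)
    position = int(action[1])    ## START/PLAY/SAMPLE action moves the agent to that state
    return [game_state, position]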
Initialise the simulation
## initial state
_T = 20 ## number of timesteps in the simulation
# _T = 100 ## number of timesteps in the simulation

_obser = [ ## initial observation
    _lab['obs']['GAME_STATE_OBS'].index('NO_EVIDENCE'),
    _lab['obs']['GAME_OUTCOME'].index('NEUTRAL'),
    _lab['obs']['ACTION_SELF_OBS'].index('START_O')
]; print(f'{_obser=}')

_state = [ ## initial (true) state
    _lab['sta']['GAME_STATE'].index('HIGH_REW'),
    _lab['sta']['PLAYING_VS_SAMPLING'].index('START')
]; print(f'{_state=}')
_obser=[2, 2, 0]
_state=[0, 0]
Create some string names for the state, observation, and action indices to help with print statements
_state_idx_names = [_lab['sta'][sfn] for sfn in _sta_fac_names]; print(f'{_state_idx_names=}') ##.
_obser_idx_names = [_lab['obs'][omn] for omn in _obs_mod_names]; print(f'{_obser_idx_names=}')
_action_idx_names = [_lab['ctr'][cfn] for cfn in _ctr_fac_names]; print(f'{_action_idx_names=}') ##.
_state_idx_names=[['HIGH_REW', 'LOW_REW'], ['START', 'PLAYING', 'SAMPLING']]
_obser_idx_names=[['HIGH_REW_EVIDENCE', 'LOW_REW_EVIDENCE', 'NO_EVIDENCE'], ['REWARD', 'PUN', 'NEUTRAL'], ['START_O', 'PLAY_O', 'SAMPLE_O']]
_action_idx_names=[['NULL_ACTION'], ['START_ACTION', 'PLAY_ACTION', 'SAMPLE_ACTION']]
Run simulation
_sta_facs = {'GAME_STATE': [], 'PLAYING_VS_SAMPLING': []}
_ctr_facs = {'NULL': [], 'PLAYING_VS_SAMPLING_CONTROL': []}
_bel_facs = {'GAME_STATE': [], 'PLAYING_VS_SAMPLING': []}
_obs_mods = {'GAME_STATE_OBS': [], 'GAME_OUTCOME': [], 'ACTION_SELF_OBS': []}
## min_F = []
_qIpiIs = []
_GNegs = []
for t in range(_T):
    print(f"\nTime {t}:")

    print(f"State: {[(_sta_fac_names[sfi], _state_idx_names[sfi][_state[sfi]]) for sfi in range(len(_sta_fac_names))]}") ##.
    for sfi, sfn in enumerate(_sta_fac_names):
        _sta_facs[sfn].append(_state_idx_names[sfi][_state[sfi]])
    print(f"Obser: {[(_obs_mod_names[omi], _obser_idx_names[omi][_obser[omi]]) for omi in range(len(_obs_mod_names))]}") ##.
    for omi, omn in enumerate(_obs_mod_names):
        _obs_mods[omn].append(_obser_idx_names[omi][_obser[omi]])

    ## update agent
    belief_state = _agent.infer_states(_obser) ##.
    ## _agent.infer_policies()
    qIpiI, GNeg = _agent.infer_policies() ##.posterior over policies and negative EFE
    print(f'{qIpiI=}')
    print(f'{GNeg=}')
    _qIpiIs.append(qIpiI)
    _GNegs.append(GNeg)
    action = _agent.sample_action()
    ## min_F.append(np.min(_agent.F)) ##.does not have .F

    ## update environment
    for sfi, sf in enumerate(_state):
        _state[sfi] = utils.sample(_B̆[sfi][:, sf, int(action[sfi])]) ##.
    for omi, _ in enumerate(_obser): ##.
        _obser[omi] = utils.sample(_Ă[omi][:, _state[0], _state[1]]) ##.

    print(f"Beliefs: {[(_sta_fac_names[sfi], belief_state[sfi].round(3).T) for sfi in range(len(_sta_fac_names))]}") ##.
    for sfi, sfn in enumerate(_sta_fac_names):
        _bel_facs[sfn].append(_lab['sta'][sfn][int(np.argmax(belief_state[sfi].round(3).T))])
    print(f"Action: {[(_ctr_fac_names[a], _action_idx_names[a][int(action[a])]) for a in range(len(_sta_fac_names))]}") ##.
    for cfi, cfn in enumerate(_ctr_fac_names):
        _ctr_facs[cfn].append(_action_idx_names[cfi][int(action[cfi])])
Time 0:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'START')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'START_O')]
qIpiI=array([0.03478899, 0.20528675, 0.75992426])
GNeg=array([-3.60483043, -3.49388625, -3.41208556])
Beliefs: [('GAME_STATE', array([0.5, 0.5])), ('PLAYING_VS_SAMPLING', array([1., 0., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'SAMPLE_ACTION')]
Time 1:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'SAMPLING')]
Obser: [('GAME_STATE_OBS', 'LOW_REW_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'SAMPLE_O')]
qIpiI=array([0.1162517 , 0.00435522, 0.87939308])
GNeg=array([-3.60483043, -3.81010428, -3.47836328])
Beliefs: [('GAME_STATE', array([0.2, 0.8])), ('PLAYING_VS_SAMPLING', array([0., 0., 1.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'SAMPLE_ACTION')]
Time 2:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'SAMPLING')]
Obser: [('GAME_STATE_OBS', 'HIGH_REW_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'SAMPLE_O')]
qIpiI=array([0.03478899, 0.20528675, 0.75992426])
GNeg=array([-3.60483043, -3.49388625, -3.41208556])
Beliefs: [('GAME_STATE', array([0.5, 0.5])), ('PLAYING_VS_SAMPLING', array([0., 0., 1.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'SAMPLE_ACTION')]
Time 3:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'SAMPLING')]
Obser: [('GAME_STATE_OBS', 'LOW_REW_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'SAMPLE_O')]
qIpiI=array([0.1162517 , 0.00435522, 0.87939308])
GNeg=array([-3.60483043, -3.81010428, -3.47836328])
Beliefs: [('GAME_STATE', array([0.2, 0.8])), ('PLAYING_VS_SAMPLING', array([0., 0., 1.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'SAMPLE_ACTION')]
Time 4:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'SAMPLING')]
Obser: [('GAME_STATE_OBS', 'LOW_REW_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'SAMPLE_O')]
qIpiI=array([0.3264811 , 0.00071881, 0.6728001 ])
GNeg=array([-3.60483043, -3.98723886, -3.55963817])
Beliefs: [('GAME_STATE', array([0.059, 0.941])), ('PLAYING_VS_SAMPLING', array([0., 0., 1.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'SAMPLE_ACTION')]
Time 5:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'SAMPLING')]
Obser: [('GAME_STATE_OBS', 'HIGH_REW_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'SAMPLE_O')]
qIpiI=array([0.1162517 , 0.00435522, 0.87939308])
GNeg=array([-3.60483043, -3.81010428, -3.47836328])
Beliefs: [('GAME_STATE', array([0.2, 0.8])), ('PLAYING_VS_SAMPLING', array([0., 0., 1.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'SAMPLE_ACTION')]
Time 6:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'SAMPLING')]
Obser: [('GAME_STATE_OBS', 'HIGH_REW_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'SAMPLE_O')]
qIpiI=array([0.03478899, 0.20528675, 0.75992426])
GNeg=array([-3.60483043, -3.49388625, -3.41208556])
Beliefs: [('GAME_STATE', array([0.5, 0.5])), ('PLAYING_VS_SAMPLING', array([0., 0., 1.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'SAMPLE_ACTION')]
Time 7:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'SAMPLING')]
Obser: [('GAME_STATE_OBS', 'LOW_REW_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'SAMPLE_O')]
qIpiI=array([0.1162517 , 0.00435522, 0.87939308])
GNeg=array([-3.60483043, -3.81010428, -3.47836328])
Beliefs: [('GAME_STATE', array([0.2, 0.8])), ('PLAYING_VS_SAMPLING', array([0., 0., 1.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'SAMPLE_ACTION')]
Time 8:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'SAMPLING')]
Obser: [('GAME_STATE_OBS', 'HIGH_REW_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'SAMPLE_O')]
qIpiI=array([0.03478899, 0.20528675, 0.75992426])
GNeg=array([-3.60483043, -3.49388625, -3.41208556])
Beliefs: [('GAME_STATE', array([0.5, 0.5])), ('PLAYING_VS_SAMPLING', array([0., 0., 1.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'SAMPLE_ACTION')]
Time 9:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'SAMPLING')]
Obser: [('GAME_STATE_OBS', 'HIGH_REW_EVIDENCE'), ('GAME_OUTCOME', 'NEUTRAL'), ('ACTION_SELF_OBS', 'SAMPLE_O')]
qIpiI=array([0.00362533, 0.96895062, 0.02742404])
GNeg=array([-3.60483043, -3.25556369, -3.47836328])
Beliefs: [('GAME_STATE', array([0.8, 0.2])), ('PLAYING_VS_SAMPLING', array([0., 0., 1.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]
Time 10:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'REWARD'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([0.00121169, 0.9954936 , 0.00329471])
GNeg=array([-3.60483043, -3.18537893, -3.54231138])
Beliefs: [('GAME_STATE', array([0.916, 0.084])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]
Time 11:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'PUN'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([0.00362533, 0.96895062, 0.02742404])
GNeg=array([-3.60483043, -3.25556369, -3.47836328])
Beliefs: [('GAME_STATE', array([0.8, 0.2])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]
Time 12:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'REWARD'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([0.00121169, 0.9954936 , 0.00329471])
GNeg=array([-3.60483043, -3.18537893, -3.54231138])
Beliefs: [('GAME_STATE', array([0.916, 0.084])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]
Time 13:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'PUN'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([0.00362533, 0.96895062, 0.02742404])
GNeg=array([-3.60483043, -3.25556369, -3.47836328])
Beliefs: [('GAME_STATE', array([0.8, 0.2])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]
Time 14:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'REWARD'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([0.00121169, 0.9954936 , 0.00329471])
GNeg=array([-3.60483043, -3.18537893, -3.54231138])
Beliefs: [('GAME_STATE', array([0.916, 0.084])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]
Time 15:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'REWARD'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([7.89044066e-04, 9.98014004e-01, 1.19695225e-03])
GNeg=array([-3.60483043, -3.15841165, -3.57878595])
Beliefs: [('GAME_STATE', array([0.967, 0.033])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]
Time 16:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'REWARD'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([6.73423691e-04, 9.98535582e-01, 7.90994555e-04])
GNeg=array([-3.60483043, -3.14847603, -3.59477315])
Beliefs: [('GAME_STATE', array([0.988, 0.012])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]
Time 17:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'REWARD'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([6.35368381e-04, 9.98689779e-01, 6.74852661e-04])
GNeg=array([-3.60483043, -3.14483077, -3.60106234])
Beliefs: [('GAME_STATE', array([0.995, 0.005])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]
Time 18:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'PUN'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([6.73423691e-04, 9.98535582e-01, 7.90994555e-04])
GNeg=array([-3.60483043, -3.14847603, -3.59477315])
Beliefs: [('GAME_STATE', array([0.988, 0.012])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]
Time 19:
State: [('GAME_STATE', 'HIGH_REW'), ('PLAYING_VS_SAMPLING', 'PLAYING')]
Obser: [('GAME_STATE_OBS', 'NO_EVIDENCE'), ('GAME_OUTCOME', 'REWARD'), ('ACTION_SELF_OBS', 'PLAY_O')]
qIpiI=array([6.35368381e-04, 9.98689779e-01, 6.74852661e-04])
GNeg=array([-3.60483043, -3.14483077, -3.60106234])
Beliefs: [('GAME_STATE', array([0.995, 0.005])), ('PLAYING_VS_SAMPLING', array([0., 1., 0.]))]
Action: [('NULL', 'NULL_ACTION'), ('PLAYING_VS_SAMPLING_CONTROL', 'PLAY_ACTION')]
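The loop stored the policy posteriors and negative expected free energies in _qIpiIs and _GNegs but never visualized them; a minimal sketch (an addition, not in the original demo) of how they could be inspected:
## Addition (sketch): inspect the stored negative expected free energies over time.
_GNegs_arr = np.array(_GNegs)                  ## shape (T, num_policies)
for pi in range(_GNegs_arr.shape[1]):
    plt.plot(_GNegs_arr[:, pi], label=f'policy {pi}')
plt.xlabel('time t'); plt.ylabel('negative EFE'); plt.legend(); plt.show()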
colors = [
    {'NULL_ACTION':'black'},
    {'START_ACTION':'red', 'PLAY_ACTION':'green', 'SAMPLE_ACTION': 'blue'},

    {'HIGH_REW':'orange', 'LOW_REW':'purple'},
    {'START':'red', 'PLAYING':'green', 'SAMPLING': 'blue'},

    {'HIGH_REW':'orange', 'LOW_REW':'purple'},
    {'START':'red', 'PLAYING':'green', 'SAMPLING': 'blue'},

    {'HIGH_REW_EVIDENCE':'orange', 'LOW_REW_EVIDENCE':'purple', 'NO_EVIDENCE':'pink'},
    {'REWARD':'red', 'PUN':'green', 'NEUTRAL': 'blue'},
    {'START_O':'red', 'PLAY_O':'green', 'SAMPLE_O': 'blue'}
]

ylabel_size = 12
msi = 7 ##markersize for Line2D, diameter in points
siz = (msi/2)**2 * np.pi ##size for scatter, area of marker in points squared
fig = plt.figure(figsize=(9, 6))
## gs = GridSpec(6, 1, figure=fig, height_ratios=[1, 3, 1, 3, 3, 1])
gs = GridSpec(9, 1, figure=fig, height_ratios=[1, 1, 1, 1, 1, 1, 1, 1, 1])
ax = [fig.add_subplot(gs[i]) for i in range(9)]

i = 0
ax[i].set_title(f'Agent Demo', fontweight='bold', fontsize=14)
y_pos = 0
for t, s in zip(range(_T), _ctr_facs['NULL']):
    ax[i].scatter(t, y_pos, color=colors[i][s], s=siz)
ax[i].set_ylabel('$a^{NULL}_t$', rotation=0, fontweight='bold', fontsize=ylabel_size)
ax[i].set_yticks([])
ax[i].set_xticklabels([])
leg_items = [
    Line2D([0],[0],marker='o',color='w',markerfacecolor='black',markersize=msi,label='NULL_ACTION')]
ax[i].legend(handles=leg_items, bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, labelspacing=0.1)
ax[i].spines['top'].set_visible(False); ax[i].spines['right'].set_visible(False)

i = 1
y_pos = 0
for t, s in zip(range(_T), _ctr_facs['PLAYING_VS_SAMPLING_CONTROL']):
    ax[i].scatter(t, y_pos, color=colors[i][s], s=siz, label=s)
ax[i].set_ylabel('$a^{PLAYING\_VS\_SAMPLING\_CONTROL}_t$', rotation=0, fontweight='bold', fontsize=ylabel_size)
ax[i].set_yticks([])
leg_items = [
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['START_ACTION'],markersize=msi,label='START_ACTION'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['PLAY_ACTION'],markersize=msi,label='PLAY_ACTION'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['SAMPLE_ACTION'],markersize=msi,label='SAMPLE_ACTION')]
ax[i].legend(handles=leg_items, bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, labelspacing=0.1)
ax[i].spines['top'].set_visible(False); ax[i].spines['right'].set_visible(False)

i = 2
y_pos = 0
for t, s in zip(range(_T), _sta_facs['GAME_STATE']):
    ax[i].scatter(t, y_pos, color=colors[i][s], s=siz)
ax[i].set_ylabel('$s^{GAME\_STATE}_t$', rotation=0, fontweight='bold', fontsize=ylabel_size)
ax[i].set_yticks([])
ax[i].set_xticklabels([])
leg_items = [
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['HIGH_REW'],markersize=msi,label='HIGH_REW'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['LOW_REW'],markersize=msi,label='LOW_REW')]
ax[i].legend(handles=leg_items, bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, labelspacing=0.1)
ax[i].spines['top'].set_visible(False); ax[i].spines['right'].set_visible(False)

i = 3
y_pos = 0
for t, s in zip(range(_T), _sta_facs['PLAYING_VS_SAMPLING']):
    ax[i].scatter(t, y_pos, color=colors[i][s], s=siz, label=s)
ax[i].set_ylabel('$s^{PLAYING\_VS\_SAMPLING}_t$', rotation=0, fontweight='bold', fontsize=ylabel_size)
ax[i].set_yticks([])
leg_items = [
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['START'],markersize=msi,label='START'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['PLAYING'],markersize=msi,label='PLAYING'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['SAMPLING'],markersize=msi,label='SAMPLING')]
ax[i].legend(handles=leg_items, bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, labelspacing=0.1)
ax[i].spines['top'].set_visible(False); ax[i].spines['right'].set_visible(False)
i = 4
y_pos = 0
for t, s in zip(range(_T), _bel_facs['GAME_STATE']): ##.plot the inferred belief (argmax of posterior), not the true state
    ax[i].scatter(t, y_pos, color=colors[i][s], s=siz)
ax[i].set_ylabel('$q(s)^{GAME\_STATE}_t$', rotation=0, fontweight='bold', fontsize=ylabel_size)
ax[i].set_yticks([])
ax[i].set_xticklabels([])
leg_items = [
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['HIGH_REW'],markersize=msi,label='HIGH_REW'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['LOW_REW'],markersize=msi,label='LOW_REW')]
ax[i].legend(handles=leg_items, bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, labelspacing=0.1)
ax[i].spines['top'].set_visible(False); ax[i].spines['right'].set_visible(False)

i = 5
y_pos = 0
for t, s in zip(range(_T), _bel_facs['PLAYING_VS_SAMPLING']): ##.plot the inferred belief (argmax of posterior), not the true state
    ax[i].scatter(t, y_pos, color=colors[i][s], s=siz, label=s)
ax[i].set_ylabel('$q(s)^{PLAYING\_VS\_SAMPLING}_t$', rotation=0, fontweight='bold', fontsize=ylabel_size)
ax[i].set_yticks([])
leg_items = [
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['START'],markersize=msi,label='START'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['PLAYING'],markersize=msi,label='PLAYING'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['SAMPLING'],markersize=msi,label='SAMPLING')]
ax[i].legend(handles=leg_items, bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, labelspacing=0.1)
ax[i].spines['top'].set_visible(False); ax[i].spines['right'].set_visible(False)
i = 6
y_pos = 0
for t, s in zip(range(_T), _obs_mods['GAME_STATE_OBS']):
    ax[i].scatter(t, y_pos, color=colors[i][s], s=siz, label=s)
ax[i].set_ylabel('$o^{GAME\_STATE\_OBS}_t$', rotation=0, fontweight='bold', fontsize=ylabel_size)
ax[i].set_yticks([])
leg_items = [
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['HIGH_REW_EVIDENCE'],markersize=msi,label='HIGH_REW_EVIDENCE'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['LOW_REW_EVIDENCE'],markersize=msi,label='LOW_REW_EVIDENCE'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['NO_EVIDENCE'],markersize=msi,label='NO_EVIDENCE')]
ax[i].legend(handles=leg_items, bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, labelspacing=0.1)
ax[i].spines['top'].set_visible(False); ax[i].spines['right'].set_visible(False)

i = 7
y_pos = 0
for t, s in zip(range(_T), _obs_mods['GAME_OUTCOME']):
    ax[i].scatter(t, y_pos, color=colors[i][s], s=siz, label=s)
ax[i].set_ylabel('$o^{GAME\_OUTCOME}_t$', rotation=0, fontweight='bold', fontsize=ylabel_size)
ax[i].set_yticks([])
leg_items = [
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['REWARD'],markersize=msi,label='REWARD'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['PUN'],markersize=msi,label='PUN'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['NEUTRAL'],markersize=msi,label='NEUTRAL')]
ax[i].legend(handles=leg_items, bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, labelspacing=0.1)
ax[i].spines['top'].set_visible(False); ax[i].spines['right'].set_visible(False)

i = 8
y_pos = 0
for t, s in zip(range(_T), _obs_mods['ACTION_SELF_OBS']):
    ax[i].scatter(t, y_pos, color=colors[i][s], s=siz, label=s)
ax[i].set_ylabel('$o^{ACTION\_SELF\_OBS}_t$', rotation=0, fontweight='bold', fontsize=ylabel_size)
ax[i].set_yticks([])
leg_items = [
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['START_O'],markersize=msi,label='START_O'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['PLAY_O'],markersize=msi,label='PLAY_O'),
    Line2D([0],[0],marker='o',color='w',markerfacecolor=colors[i]['SAMPLE_O'],markersize=msi,label='SAMPLE_O')]
ax[i].legend(handles=leg_items, bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, labelspacing=0.1)
ax[i].spines['top'].set_visible(False); ax[i].spines['right'].set_visible(False)

ax[i].xaxis.set_major_locator(MaxNLocator(integer=True))
## ax[i].xaxis.set_major_locator(MaxNLocator(nbins=10, integer=True))
ax[i].set_xlabel('$\mathrm{time,}\ t$', fontweight='bold', fontsize=12)

plt.tight_layout()
plt.subplots_adjust(hspace=0.1) ## Adjust this value as needed
plt.show()