The linear Gaussian state-space model is given by \[\mathbf{s}_t \sim \mathcal{N}\left(\tilde{\mathbf{B}} \mathbf{s}_{t-1},\, \mathbf{Q}\right), \qquad \mathbf{y}_t \sim \mathcal{N}\left(\tilde{\mathbf{A}} \mathbf{s}_{t},\, \mathbf{R}\right),\] where \(\mathbf{s}_t\) are the hidden states, \(\mathbf{y}_t\) are the noisy observations, \(\tilde{\mathbf{B}}\) and \(\tilde{\mathbf{A}}\) are the state transition and observation matrices, and \(\mathbf{Q}\) and \(\mathbf{R}\) are the state transition noise and observation noise covariance matrices.
To make things more interesting, we will use a state-space model that is subject to rotation in three dimensions, i.e. about the x, y, and z axes. This is a common setting for flying aircraft that rotate relative to a ground-based coordinate frame. Once we have the transition and observation matrices for rotation about all three axes, we will combine them into the final state-space model, as sketched below. Eventually, we will perform Bayesian multivariate inference of the random vector \(\mathbf{s}_t\). Note that this state-space model does not capture the complete dynamics of a rotating object; only the location is included.
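As a concrete illustration, elementary rotation matrices about each axis can be composed into a single transition matrix. The helper names `Rx`, `Ry`, `Rz` and the angles below are illustrative assumptions, not fixed by the original text:

```julia
using LinearAlgebra

## Elementary rotation matrices about the x, y, and z axes (angles in radians).
Rx(θ) = [1 0 0; 0 cos(θ) -sin(θ); 0 sin(θ) cos(θ)]
Ry(θ) = [cos(θ) 0 sin(θ); 0 1 0; -sin(θ) 0 cos(θ)]
Rz(θ) = [cos(θ) -sin(θ) 0; sin(θ) cos(θ) 0; 0 0 1]

## One possible transition matrix: a small rotation about each axis per time step.
θx, θy, θz = π/20, π/30, π/15     ## illustrative rotation angles
_B̃ = Rz(θz) * Ry(θy) * Rx(θx)     ## state transition matrix
_Ã = Matrix(1.0I, 3, 3)           ## observation matrix (observe the location directly)
```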
The provision function provides another covariate vector. Because this is a sequential system, the provision function defines the transition from the previous state to the current state. This special case of the provision function is known as a transition function, and it returns a provision/pre-state: \[\tilde{\mathbf{p}}_{t} = f_p(\tilde{\mathbf{s}}_{t-1}) = f_B(\tilde{\mathbf{s}}_{t-1}) = \tilde{\mathbf{B}} \tilde{\mathbf{s}}_{t-1} \]
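In code, the transition function might be written as follows; a minimal sketch consistent with how `fₚ` is called in the simulation code below:

```julia
## Provision (transition) function: pre-state from the previous state.
fₚ(; B̃, s̃ₜ₋₁) = B̃ * s̃ₜ₋₁
```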
The response function provides the response to the covariate vector: \[\tilde{\mathbf{r}}_{t} = f_r(\tilde{\mathbf{s}}_{t}) = f_A(\tilde{\mathbf{s}}_{t}) = \tilde{\mathbf{A}} \tilde{\mathbf{s}}_{t} \]
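Analogously, a sketch of the response function matching how `fᵣ` is called below:

```julia
## Response function: noiseless response from the current state.
fᵣ(; Ã, s̃ₜ) = Ã * s̃ₜ
```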
After combining with observation noise, the observation is produced: \[\mathbf{y}_{t} \sim \mathcal{N}\left(\tilde{\mathbf{r}}_{t},\, \mathbf{R}\right)\]
```julia
## Data comes from either a simulation/lab (sim|lab) OR from the field (fld)
## Data are handled either in batches (batch) OR online as individual points (point)
## Batch data accumulates either
##   along the depth/examples dimension/axis (into the screen/page), OR
##     typical for supervised & unsupervised learning
##   along the time dimension/axis (down the screen/page)
##     typical for sequential decision learning (reinforcement learning & active inference)
function sim_batch_data(rng, T, B̃, Ã, Q̃, R̃) ## simulated batch data
    s̃ₜ₋₁ = mean(_s̃₀) ## start from the mean of the prior _s̃₀ (a distribution; see below)
    p̃ = Vector{Vector{Float64}}(undef, T)
    s̃ = Vector{Vector{Float64}}(undef, T)
    r̃ = Vector{Vector{Float64}}(undef, T)
    y = Vector{Vector{Float64}}(undef, T)
    for t in 1:T
        ## p̃[t] = B̃*s̃ₜ₋₁
        p̃[t] = fₚ(B̃=B̃, s̃ₜ₋₁=s̃ₜ₋₁)
        s̃[t] = rand(rng, MvNormal(p̃[t], Q̃))
        ## r̃[t] = Ã*s̃[t]
        r̃[t] = fᵣ(Ã=Ã, s̃ₜ=s̃[t])
        y[t] = rand(rng, MvNormal(r̃[t], R̃))
        s̃ₜ₋₁ = s̃[t]
    end
    return s̃, y
end
```
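To generate a dataset we can call this function; the seed, time horizon, initial-state prior, and noise covariances below are illustrative assumptions, not values from the original text:

```julia
using LinearAlgebra, Random, Distributions

_rng = MersenneTwister(42)                     ## illustrative seed
_T = 100                                       ## illustrative time horizon
_s̃₀ = MvNormal(zeros(3), Matrix(1.0I, 3, 3))   ## illustrative prior on the initial state
_Q̃ = 0.01 * Matrix(1.0I, 3, 3)                 ## illustrative state noise covariance
_R̃ = 0.1 * Matrix(1.0I, 3, 3)                  ## illustrative observation noise covariance

_s̃, _y = sim_batch_data(_rng, _T, _B̃, _Ã, _Q̃, _R̃)
```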
Let’s visualize the synthesized dataset. Lines represent the hidden states that need to be estimated/inferred. We only have access to the noisy observations, which are represented as dots.
_p =plot(title="Hidden states with noisy observations")_p =plot!(_p, getindex.(_s̃, 1), label="Hidden Signal "* L"\tilde{s}_1", color=:red)_p =scatter!(_p, getindex.(_y, 1), label=false, markersize=2, color=:red)_p =plot!(_p, getindex.(_s̃, 2), label="Hidden Signal "* L"\tilde{s}_2", color=:green)_p =scatter!(_p, getindex.(_y, 2), label=false, markersize=2, color=:green)_p =plot!(_p, getindex.(_s̃, 3), label="Hidden Signal "* L"\tilde{s}_3", color=:blue)_p =scatter!(_p, getindex.(_y, 3), label=false, markersize=2, color=:blue)plot(_p)
## The Generative Model
We now express this generative model in RxInfer:
```julia
@model function rotate_ssm(y, s₀, B, A, Q, R)
    s_prior ~ MvNormalMeanCovariance(mean(s₀), cov(s₀))
    sₜ₋₁ = s_prior
    for t in 1:length(y)
        s[t] ~ MvNormalMeanCovariance(B * sₜ₋₁, Q) ## `s` is a sequence of hidden states
        ##- s[t] ~ MvNormalMeanCovariance(f(B=B, sₜ₋₁=sₜ₋₁), Q) ## `s` is a sequence of hidden states
        y[t] ~ MvNormalMeanCovariance(A * s[t], R) ## `y` is a sequence of "clamped" observations
        sₜ₋₁ = s[t]
    end
end
```
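Because every factor in this model is linear Gaussian, RxInfer can perform exact message passing on the corresponding factor graph, so the resulting posteriors coincide with the classical Kalman smoother solution.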
```julia
## We assume the _B̃, _Ã, _Q̃, _R̃ are known, i.e. not hidden, even though the tildes
## in their names indicate that they are hidden.
_result = infer(
    model = rotate_ssm(s₀ = _s̃₀, B = _B̃, A = _Ã, Q = _Q̃, R = _R̃),
    data = (y = _y,),
    free_energy = true
);
```
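The posterior marginals of the hidden states can then be extracted from the inference result; RxInfer stores them in the `posteriors` field, keyed by variable name:

```julia
_smarginals = _result.posteriors[:s] ## posterior marginals of the hidden state sequence
```

These marginals are what we plot below.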
_p =plot(title="Estimated states from noisy observations")_p =plot!(_p, getindex.(_s̃, 1), label="Hidden Signal "* L"\tilde{s}_1", color=:red, linestyle=:dash)_p =plot!(_p, getindex.(_s̃, 2), label="Hidden Signal "* L"\tilde{s}_2", color=:green, linestyle=:dash)_p =plot!(_p, getindex.(_s̃, 3), label="Hidden Signal "* L"\tilde{s}_3", color=:blue, linestyle=:dash)_p =plot!(_p, getindex.(mean.(_smarginals), 1), ribbon=getindex.(var.(_smarginals), 1) .|> sqrt, fillalpha=0.5, label="Estimated Signal "* L"s_1", color=:pink)_p =plot!(_p, getindex.(mean.(_smarginals), 2), ribbon=getindex.(var.(_smarginals), 2) .|> sqrt, fillalpha=0.5, label="Estimated Signal "* L"s_2", color=:lightgreen)_p =plot!(_p, getindex.(mean.(_smarginals), 3), ribbon=getindex.(var.(_smarginals), 3) .|> sqrt, fillalpha=0.5, label="Estimated Signal "* L"s_3", color=:lightblue)plot(_p)
As we can see from the plot, the estimated signals closely track the real hidden states, with small variance. We may also be interested in the log evidence, computed as the negative of the free energy:
```julia
## Given the analytical solution, the free energy will be equal to the negative log evidence.
_logevidence = -_result.free_energy;
_logevidence
```
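To spell out the reasoning: the variational free energy decomposes as \(F[q] = -\log p(\mathbf{y}_{1:T}) + \mathrm{KL}\left(q \,\|\, p(\mathbf{s}_{1:T} \mid \mathbf{y}_{1:T})\right)\). Because this linear Gaussian model admits exact message passing, the KL term vanishes, and \(-F[q]\) equals the log evidence \(\log p(\mathbf{y}_{1:T})\).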