RxInfer.jl is a Julia package designed for automated Bayesian inference using reactive message passing on factor graphs[1]. This powerful tool offers several advantages over traditional frequentist statistical methods:
Efficiency and Accuracy
RxInfer.jl often outperforms general-purpose probabilistic programming packages in terms of computational load, speed, memory usage, and accuracy[1]. It achieves this by leveraging conjugate likelihood-prior pairings in models, which have analytical posteriors known to RxInfer.jl. This approach allows for faster and more precise inference, especially in models with conjugate relationships.
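To make the role of conjugacy concrete, here is a minimal sketch in plain Julia; it does not call RxInfer. For a Normal prior on an unknown mean and a Normal likelihood with known variance, the posterior is available in closed form. The function name `normal_posterior`, the prior, and the observations are illustrative assumptions. The grid approximation is included only as a check on the closed form.

```julia
## Closed-form Normal-Normal update for an unknown mean with known noise variance.
## `normal_posterior` and all numbers are illustrative, not part of RxInfer.
function normal_posterior(μ₀, σ₀², σ², y)
    λₙ = 1 / σ₀² + length(y) / σ²            ## posterior precision
    μₙ = (μ₀ / σ₀² + sum(y) / σ²) / λₙ       ## precision-weighted mean
    return μₙ, 1 / λₙ
end

μ₀, σ₀², σ² = 70.0, 25.0, 4.0                ## assumed prior and noise variance
y = [71.2, 69.8, 70.5, 72.1]                 ## made-up observations
μₙ, σₙ² = normal_posterior(μ₀, σ₀², σ², y)

## Brute-force check: evaluate the unnormalized log-posterior on a grid.
grid = range(50.0, 90.0; length = 4001)
logpost(μ) = -(μ - μ₀)^2 / (2 * σ₀²) - sum((y .- μ) .^ 2) / (2 * σ²)
w = exp.(logpost.(grid) .- maximum(logpost.(grid)))
grid_mean = sum(grid .* w) / sum(w)

println("closed form: ", μₙ, "  grid: ", grid_mean)
```

The closed form costs a handful of arithmetic operations, whereas the grid needs thousands of likelihood evaluations to reach a comparable answer for even a single parameter.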
Scalability
The package demonstrates superior scalability compared to other methods, particularly in complex models. This makes RxInfer.jl an excellent choice for handling large-scale Bayesian inference tasks that might be computationally prohibitive with traditional frequentist approaches[1].
Flexibility
While RxInfer.jl excels in conjugate models, it also supports non-conjugate inference and is continuously expanding to accommodate a broader class of models[1]. This flexibility allows researchers and data scientists to tackle a wide range of probabilistic modeling problems within a single framework.
Bayesian Advantages
As a Bayesian tool, RxInfer.jl inherently provides several benefits over frequentist methods:
It allows for the incorporation of prior knowledge into the analysis, which can be particularly useful when dealing with small sample sizes[3].
It provides a more intuitive interpretation of results through posterior distributions, rather than relying on p-values and confidence intervals that are often misinterpreted[3].
It offers a natural framework for updating beliefs as new data becomes available, making it ideal for iterative experimentation and decision-making processes[3].
By leveraging these advantages, RxInfer.jl empowers users to perform sophisticated Bayesian inference tasks with greater ease and efficiency than traditional frequentist tools.
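The belief-updating idea can be sketched in a few lines of plain Julia, assuming a Normal prior on an unknown mean and a known noise variance. The name `observe`, the prior, and the observations are hypothetical and chosen only for illustration.

```julia
## Sequential Bayesian updating for a Normal mean with known noise variance.
## A belief is a (mean, variance) tuple; `observe` is an illustrative name.
function observe(belief, yᵢ; σ² = 4.0)
    μ, v = belief
    vₙ = 1 / (1 / v + 1 / σ²)        ## precisions add
    μₙ = vₙ * (μ / v + yᵢ / σ²)      ## precision-weighted average
    return (μₙ, vₙ)
end

prior = (70.0, 25.0)                 ## assumed prior mean and variance
y = [71.2, 69.8, 70.5, 72.1]         ## made-up observations arriving one at a time

## The posterior after each observation becomes the prior for the next one.
beliefs = accumulate(observe, y; init = prior)
foreach(b -> println(b), beliefs)
```

The final belief equals the one obtained by processing all four observations in a single batch, since precisions add, and the variance shrinks with every new observation.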
This notebook is a brief introduction to the capabilities of RxInfer. We consider a simple linear regression problem on synthetic data: we estimate the regression parameters with Bayesian regression in RxInfer and compare the estimates with the true parameter values used in the data generation process.
The s values (‘states’) are the heights of a male cohort measured in centimeters, and the y values are the associated observed weights in kilograms. We assume that the measurement/observation noise is known.
More formally:
The data set consists of \(N\) pairs of covariates (albeit scalars) and observations
\(\mathcal{D} = \{ (s_n, y_n), \forall n \in \{1:N\} \}\), where a covariate \(s_n \in \mathbb{R}\), and its associated observation \(y_n \in \mathbb{R}\).
The observation model is defined as \[
y = f_r(s) + v
\] where \(V \sim \mathcal{N}(v \mid 0, \sigma^2_V)\) and \(v\) is referred to as the observation error. \(f_r\) is the response function. Covariates \(s\) are assumed to be exact and free from any system errors, usually denoted as \(w\).
We therefore have a set of observations \(\{y_1, y_2, ..., y_N \}\) that are realizations of the random variables \(\mathbf{Y} = [Y_1, Y_2, ..., Y_N]^T \sim p(\mathbf y)\).
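We model the weight \(y_n\in\mathbb{R}\) as normally distributed around the regression line and treat \(s_n\) as a fixed hyperparameter: \[p(y_n \mid \beta_1, \beta_0) = \mathcal{N}(y_n \mid \beta_1 s_n + \beta_0, \sigma^2_V).\]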
The height is denoted as \(s_n \in \mathbb{R}\) and the recorded/observed weight as \(y_n \in \mathbb{R}\). Prior beliefs on \(\beta_1\) and \(\beta_0\) are derived from available medical records.
The random scalars (aka random variables) are defined as \(B_1 \sim p(\beta_1)\) and \(B_0 \sim p(\beta_0)\) so that
In combination, we have the probabilistic model \[p(\mathbf{y}, \beta_1, \beta_0) = p(\beta_1)p(\beta_0) \prod_{n=1}^N p(y_n \mid \beta_1, \beta_0),\] where the goal is to infer the posterior distributions \(p(\beta_1 \mid \mathbf{y})\) and \(p(\beta_0 \mid \mathbf{y})\).
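For reference, with Gaussian priors and a known noise variance this model has a closed-form joint posterior over \((\beta_1, \beta_0)\). The sketch below computes it with plain linear algebra; it is not RxInfer's message-passing algorithm. The prior mean `m₀`, the prior covariance `S₀`, the function name `regression_posterior`, and the use of `MersenneTwister` are assumptions made for illustration, since the notebook does not state the prior values at this point.

```julia
using LinearAlgebra, Random

## Closed-form Gaussian posterior over θ = [β₁, β₀] for y = β₁ s + β₀ + v,
## v ~ N(0, σ²), under an assumed prior θ ~ N(m₀, S₀).
function regression_posterior(s, y, σ², m₀, S₀)
    X = hcat(s, ones(length(s)))            ## design matrix [s 1]
    Λ₀ = inv(S₀)                            ## prior precision
    Λₙ = Λ₀ + (X' * X) / σ²                 ## posterior precision
    mₙ = Λₙ \ (Λ₀ * m₀ + (X' * y) / σ²)     ## posterior mean
    return mₙ, inv(Λₙ)                      ## mean and covariance
end

rng = MersenneTwister(1234)                 ## assumption: any seeded RNG will do
s = collect(50.0:1.0:220.0)                 ## heights [cm]
y = 0.44 .* s .- 10.0 .+ sqrt(3.5) .* randn(rng, length(s))   ## weights [kg]

m₀, S₀ = [0.0, 0.0], Diagonal([1.0, 100.0]) ## assumed prior, not from the notebook
mₙ, Sₙ = regression_posterior(s, y, 3.5, m₀, S₀)
println("posterior mean of [β₁, β₀]: ", mₙ)
println("posterior std of [β₁, β₀]:  ", sqrt.(diag(Sₙ)))
```

The marginals \(p(\beta_1 \mid \mathbf{y})\) and \(p(\beta_0 \mid \mathbf{y})\) are the univariate Gaussians given by the entries of `mₙ` and the diagonal of `Sₙ`. As the prior becomes flat, the posterior mean approaches the ordinary least-squares estimate.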
versioninfo() ## Julia version
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 12 × Intel(R) Core(TM) i7-8700B CPU @ 3.20GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, skylake)
Threads: 1 default, 0 interactive, 1 GC (on 12 virtual cores)
Environment:
JULIA_NUM_THREADS =
Updating registry at `~/.julia/registries/General.toml`
Resolving package versions...
No Changes to `~/.julia/environments/v1.10/Project.toml`
No Changes to `~/.julia/environments/v1.10/Manifest.toml`
Pkg.status()
Status `~/.julia/environments/v1.10/Project.toml`
[336ed68f] CSV v0.10.14
[a93c6f00] DataFrames v1.7.0
[38e38edf] GLM v1.9.0
[b964fa9f] LaTeXStrings v1.3.1
[91a5bcdd] Plots v1.40.8
⌃ [86711068] RxInfer v3.7.0
[860ef19b] StableRNGs v1.0.2
[f3b207a7] StatsPlots v0.15.7
[37e2e46d] LinearAlgebra
Info Packages marked with ⌃ have new versions available and may be upgradable.
Solution
Next we define some functions that will allow the creation of the data.
_β̃₁ = 0.44 ## 0.44 kg/cm
_β̃₀ = -10.0 ##kg \beta[tab]\tilde[tab]\_0[tab]
_σ̃²_V = 3.5 ## _N = 171 ## number of samples
_start = 50.0
_step = 1.0
_stop = 220.0
220.0
Provision function (\(f_p\))
The provision function provides the next covariate vector. For this simple problem a covariate vector consists of a single scalar, and we create only a single batch of covariates.
The tildes indicate that the parameters and variables are hidden and not observed.
## provision function, provides next batch of covariate vectors
## covariate vector here is just a scalar
function fₚ(; start, step, stop)
    return float.(collect(start:step:stop))
    ## tmp = float.(collect(1:N))
    ## return [[i] for i in tmp]
end
fₚ(start=50, step=1.0, stop=220)
Response function (\(f_r\))
The response function provides the response to a covariate vector. For this simple problem a covariate vector consists of a single scalar, and we create only a single batch of responses.
\[\tilde{r}_{i} = f_{r}(\tilde{s}_{i}) = \tilde{\beta}_1 \tilde{s}_{i} + \tilde{\beta}_0\]\[y_{i} = f_{r}(\tilde{s}_{i}) + \tilde{v} = \tilde{\beta}_1 \tilde{s}_{i} + \tilde{\beta}_0 + \tilde{v}\] where the observation noise is \(\tilde{v} \sim \mathcal{N}(0, \tilde{\sigma}^2_V)\).
## response function, provides next batch of responses to covariate vectors
## covariate vector here is just a scalar
function fᵣ(β̃₁, β̃₀, s̃)
    return β̃₁ .* s̃ .+ β̃₀
    ## return [β̃₁*i[1] + β̃₀ for i in s̃]
end
fᵣ (generic function with 1 method)
The data generation function
Data can be sourced from a simulation or the field. Here the data is simulated/synthesized. A tilde over a symbol indicates a real/true value, i.e. not estimated/random.
## sim|lab OR fld _ batch OR point|obser
## batch is along 'depth/examples' dimension
function sim_batch_data(β̃₁, β̃₀, σ̃²_V, start, step, stop; rng=StableRNG(1234))
    p̃ₜ = fₚ(start=start, step=step, stop=stop)
    s̃ₜ = p̃ₜ ## no system noise
    r̃ₜ = fᵣ(β̃₁, β̃₀, s̃ₜ)
    v = randn(rng, length(s̃ₜ)) .* sqrt(σ̃²_V)
    yₜ = r̃ₜ + v
    return s̃ₜ, yₜ
end;
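As a usage sketch, the three functions can be combined to draw one batch of synthetic data and to sanity-check it with an ordinary least-squares fit. To keep the snippet dependent only on the standard library, it swaps `StableRNG` for `MersenneTwister` from `Random`, so the draws will differ from the notebook's. The least-squares check and the names `β̂₁` and `β̂₀` are assumptions added for illustration.

```julia
using Random

## Compact restatements of the provision, response and simulation functions.
fₚ(; start, step, stop) = float.(collect(start:step:stop))
fᵣ(β̃₁, β̃₀, s̃) = β̃₁ .* s̃ .+ β̃₀

function sim_batch_data(β̃₁, β̃₀, σ̃²_V, start, step, stop; rng = MersenneTwister(1234))
    s̃ₜ = fₚ(start = start, step = step, stop = stop)   ## no system noise
    yₜ = fᵣ(β̃₁, β̃₀, s̃ₜ) .+ randn(rng, length(s̃ₜ)) .* sqrt(σ̃²_V)
    return s̃ₜ, yₜ
end

## One batch with the true parameters used in this notebook.
_s̃ₜ, _yₜ = sim_batch_data(0.44, -10.0, 3.5, 50.0, 1.0, 220.0)

## Least-squares sanity check: the fit should land near the true slope and intercept.
β̂₁, β̂₀ = hcat(_s̃ₜ, ones(length(_s̃ₜ))) \ _yₜ
println("samples: ", length(_s̃ₜ), "  slope: ", β̂₁, "  intercept: ", β̂₀)
```

Passing the random number generator as a keyword argument keeps the simulation reproducible while still allowing a different seed or generator to be supplied per call.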
scatter(_s̃ₜ, _yₜ, title="Weight vs Height for the male cohort", color="orange", legend=false)
## scatter([i[1] for i in _s̃ₜ], _yₜ, title="Weight vs Height for the male cohort", legend=false)
xlabel!("Height [cm]")
ylabel!("Weight [kg]")
The free energy plot shows good convergence performance:
plot(2:_its, ## drop first iteration because it is influenced by 'initmessages'
    results.free_energy[2:end],
    title="Free energy", xlabel="Iteration", ylabel="Free energy [nats]", legend=false)
Next we plot the priors and posteriors of both parameters: