Using RxInfer to perform Bayesian regression

Learning parameters for a linear regression model

Medical Industry
Bayesian Inference
Bayesian Regression
RxInfer
Julia
Author

Kobus Esterhuysen

Published

October 1, 2024

Modified

October 1, 2024

RxInfer for Bayesian Linear Regression

RxInfer.jl is a Julia package for automated Bayesian inference using reactive message passing on factor graphs[1]. It offers several advantages, both over general-purpose probabilistic programming packages and over traditional frequentist statistical methods:

Efficiency and Accuracy

RxInfer.jl often outperforms general-purpose probabilistic programming packages in computational load, speed, memory usage, and accuracy[1]. It achieves this by exploiting conjugate likelihood-prior pairs, whose posteriors are known analytically to RxInfer.jl. As a result, inference is faster and more precise, especially in models with conjugate structure.

Scalability

The package scales better than other inference methods, particularly in complex models. This makes RxInfer.jl a good choice for large-scale Bayesian inference tasks that might otherwise be computationally prohibitive[1].

Flexibility

While RxInfer.jl excels in conjugate models, it also supports non-conjugate inference and is continuously expanding to accommodate a broader class of models[1]. This flexibility allows researchers and data scientists to tackle a wide range of probabilistic modeling problems within a single framework.

Bayesian Advantages

As a Bayesian tool, RxInfer.jl inherently provides several benefits over frequentist methods:

  1. It allows for the incorporation of prior knowledge into the analysis, which can be particularly useful when dealing with small sample sizes[3].
  2. It provides a more intuitive interpretation of results through posterior distributions, rather than relying on p-values and confidence intervals that are often misinterpreted[3].
  3. It offers a natural framework for updating beliefs as new data becomes available, making it ideal for iterative experimentation and decision-making processes[3].

By leveraging these advantages, RxInfer.jl empowers users to perform sophisticated Bayesian inference tasks with greater ease and efficiency than traditional frequentist tools.

Citations:

[1] https://github.com/ReactiveBayes/RxInfer.jl

[2] https://discourse.datamethods.org/t/are-there-situations-in-which-a-frequentist-approach-is-inherently-better-than-bayesian/7172

[3] https://www.statsig.com/perspectives/bayesian-or-frequentist-choosing-your-statistical-approach

[4] https://www.reddit.com/r/math/comments/kx3sno/frequentist_statistics_vs_bayesian_and_machine/

[5] https://towardsdatascience.com/statistics-are-you-bayesian-or-frequentist-4943f953f21b?gi=dfa0db9146cb

[6] https://amplitude.com/blog/frequentist-vs-bayesian-statistics-methods

This notebook is a brief introduction to the capabilities of RxInfer. We consider a simple linear regression problem with synthetic data and use Bayesian regression in RxInfer to estimate its parameters. We then compare the estimates with the true parameter values used in the data-generation process.

The x values are the heights of a male cohort in centimeters, and the y values are the corresponding weights in kilograms. We assume that the variance of the measurement (observation) noise is known.

More formally:

The data set consists of \(N\) pairs of covariates and observations, \(\mathcal{D} = \{ (x_n, y_n) : n = 1, \dots, N \}\), where each covariate \(x_n \in \mathbb{R}\) has an associated observation \(y_n \in \mathbb{R}\).

The observation model is defined as \[ y = g(x) + v, \qquad g(x) = \beta_1 x + \beta_0, \] where \(v\) is a realization of the observation error \(V \sim \mathcal{N}(v \mid 0, \sigma^2_V)\). The covariates \(x\) are assumed to be exact, i.e., free from observation error.

We therefore have a vector of observations \(\mathbf{y} = [y_1, y_2, \dots, y_N]^T\) that is a realization of the random vector \(\mathbf{Y} = [Y_1, Y_2, \dots, Y_N]^T \sim p(\mathbf{y})\).

We model each weight \(y_n \in \mathbb{R}\) as normally distributed and treat the height \(x_n\) as a fixed, known input:

\[\begin{aligned} p(y_n \mid \beta_1, \beta_0) = \mathcal{N}(y_n \mid \beta_1 x_n + \beta_0, \sigma^2_V) \end{aligned}\]

where \(\sigma^2_V\) is the known observation-noise variance. The simulation and inference below use \(\sigma^2_V = 3.5\).

As before, \(x_n \in \mathbb{R}\) denotes the observed height and \(y_n \in \mathbb{R}\) the recorded weight. In practice, prior beliefs about \(\beta_1\) and \(\beta_0\) could be derived from available medical records. This example uses simple zero-mean Gaussian priors.

The random scalars (aka random variables) are defined as \(B_1 \sim p(\beta_1)\) and \(B_0 \sim p(\beta_0)\) so that

\[\begin{aligned} p(\beta_1) &= \mathcal{N}(\beta_1 \mid m_{B_1}, v_{B_1}) \\ p(\beta_0) &= \mathcal{N}(\beta_0 \mid m_{B_0}, v_{B_0}) \end{aligned}\]

In combination, we have the probabilistic model \[p(\mathbf{y}, \beta_1, \beta_0) = p(\beta_1)p(\beta_0) \prod_{n=1}^N p(y_n \mid \beta_1, \beta_0),\] where the goal is to infer the posterior distributions \(p(\beta_1 \mid \mathbf{y})\) and \(p(\beta_0 \mid \mathbf{y})\).
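The likelihood is linear in \((\beta_1, \beta_0)\) with Gaussian noise, and both priors are Gaussian. The model is therefore conjugate, and the joint posterior has a closed form. The standard result for Bayesian linear regression with known noise variance, written in this document's notation, is given below. It is added here for reference and is not taken from the RxInfer documentation.

\[\begin{aligned} \mathbf{X} &= \begin{bmatrix} x_1 & 1 \\ \vdots & \vdots \\ x_N & 1 \end{bmatrix}, \quad \boldsymbol{\beta} = \begin{bmatrix} \beta_1 \\ \beta_0 \end{bmatrix}, \quad \mathbf{m}_0 = \begin{bmatrix} m_{B_1} \\ m_{B_0} \end{bmatrix}, \quad \mathbf{V}_0 = \begin{bmatrix} v_{B_1} & 0 \\ 0 & v_{B_0} \end{bmatrix} \\ p(\boldsymbol{\beta} \mid \mathbf{y}) &= \mathcal{N}(\boldsymbol{\beta} \mid \mathbf{m}_N, \mathbf{V}_N) \\ \mathbf{V}_N &= \left( \mathbf{V}_0^{-1} + \frac{1}{\sigma^2_V} \mathbf{X}^T \mathbf{X} \right)^{-1} \\ \mathbf{m}_N &= \mathbf{V}_N \left( \mathbf{V}_0^{-1} \mathbf{m}_0 + \frac{1}{\sigma^2_V} \mathbf{X}^T \mathbf{y} \right) \end{aligned}\]

The off-diagonal entries of \(\mathbf{V}_N\) are in general nonzero, so \(\beta_1\) and \(\beta_0\) are correlated a posteriori. The marginals \(p(\beta_1 \mid \mathbf{y})\) and \(p(\beta_0 \mid \mathbf{y})\) are Gaussians. Their means are the entries of \(\mathbf{m}_N\), and their variances are the diagonal entries of \(\mathbf{V}_N\).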

versioninfo() ## Julia version
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 12 × Intel(R) Core(TM) i7-8700B CPU @ 3.20GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake)
Threads: 1 default, 0 interactive, 1 GC (on 12 virtual cores)
Environment:
  JULIA_NUM_THREADS = 
import Pkg
Pkg.add(Pkg.PackageSpec(;name="RxInfer"))
Pkg.add("Plots")
Pkg.add("StableRNGs")
Pkg.add("LinearAlgebra")
Pkg.add("StatsPlots")
Pkg.add("LaTeXStrings")
Pkg.add("DataFrames")
Pkg.add("CSV")
Pkg.add("GLM")

using RxInfer, Random, Plots, StableRNGs, LinearAlgebra, StatsPlots, LaTeXStrings, DataFrames, CSV, GLM
   Resolving package versions...
  No Changes to `~/.julia/environments/v1.10/Project.toml`
  No Changes to `~/.julia/environments/v1.10/Manifest.toml`
Pkg.status()
Status `~/.julia/environments/v1.10/Project.toml`
  [336ed68f] CSV v0.10.14
  [a93c6f00] DataFrames v1.7.0
  [38e38edf] GLM v1.9.0
  [b964fa9f] LaTeXStrings v1.3.1
  [91a5bcdd] Plots v1.40.8
  [86711068] RxInfer v3.7.0
  [860ef19b] StableRNGs v1.0.2
  [f3b207a7] StatsPlots v0.15.7
  [37e2e46d] LinearAlgebra

Solution

Next we define some functions that will allow the creation of the data.

The forward function \(f()\)

The forward function \(f()\) provides the next batch of covariate vectors. For this simple problem a covariate vector consists of a single scalar. In addition, we only create a single batch of covariates.

## forward function: provides the next batch of covariate vectors
## (here each covariate vector is just a scalar)
function f(; start, step, stop)
    return float.(collect(start:step:stop))
end
f(start=50, step=1.0, stop=220)
171-element Vector{Float64}:
  50.0
  51.0
  52.0
  53.0
  54.0
  55.0
  56.0
  57.0
  58.0
  59.0
   ⋮
 212.0
 213.0
 214.0
 215.0
 216.0
 217.0
 218.0
 219.0
 220.0

The response function \(g()\)

The response function \(g()\) provides the next batch of responses to covariate vectors. For this simple problem a covariate vector consists of a single scalar. In addition, we only create a single batch of responses.

## response function: provides the next batch of responses to the covariate vectors
## (here each covariate vector is just a scalar)
function g(a, b, x)
    return a .* x .+ b
end
g (generic function with 1 method)
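Because g() uses the broadcast operators .* and .+, the same function works for a single height or a whole vector of heights. The snippet below repeats g so it runs on its own. With the true parameters used below (0.44 kg/cm and -10 kg), a person who is 150 cm tall has an expected weight of 0.44 × 150 - 10 = 56 kg.

```julia
## response function from above: mean weight [kg] for height(s) x [cm]
g(a, b, x) = a .* x .+ b

## a single height (scalar) and a batch of heights (vector)
println(g(0.44, -10.0, 150.0))
println(g(0.44, -10.0, [100.0, 150.0, 200.0]))
```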

The observation noise function \(v()\)

The observation noise function \(v()\) provides a batch of realizations for the random scalar \(V\). We assume the variance of \(V\) is known and given by \(\tilde{\sigma}^2_V\).

## observation noise function
function v(σ̃²_V, N, rng)
    return randn(rng, N) .* sqrt(σ̃²_V)
end
v (generic function with 1 method)
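randn draws from a standard normal distribution. Multiplying by the standard deviation \(\sqrt{\tilde{\sigma}^2_V}\), not by the variance itself, therefore yields noise with variance \(\tilde{\sigma}^2_V\). The empirical check below repeats v so the snippet runs on its own. It uses Julia's built-in MersenneTwister in place of StableRNG, so the draws differ from the notebook's.

```julia
using Random, Statistics

## observation noise function, as defined above
function v(σ̃²_V, N, rng)
    return randn(rng, N) .* sqrt(σ̃²_V)
end

samples = v(3.5, 200_000, MersenneTwister(42))
println("sample mean:     ", mean(samples))   ## should be close to 0
println("sample variance: ", var(samples))    ## should be close to 3.5
```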

The data generation function

Data can come from a simulation or from the field. Here the data is simulated. A tilde over a symbol denotes a true value, i.e., a fixed quantity rather than an estimate or a random variable.

## naming pattern: <source: sim|lab or fld>_<size: batch or point|obser>_data
function sim_batch_data(β̃₁, β̃₀, σ̃²_V, start, step, stop; rng=StableRNG(1234))
    x = f(start=start, step=step, stop=stop)
    y = g(β̃₁, β̃₀, x) .+ v(σ̃²_V, length(x), rng)
    return x, y
end;

Data generation

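_β̃₁ = 0.44    ## true slope [kg/cm]
_β̃₀ = -10.0   ## true intercept [kg]; type as \beta<tab>\tilde<tab>\_0<tab>
_σ̃²_V = 3.5   ## true observation-noise variance [kg²]
## the number of samples, N = 171, follows from _start, _step and _stop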
_start = 50.0
_step = 1.0
_stop = 220.0
_x_data, _y_data = sim_batch_data(_β̃₁, _β̃₀, _σ̃²_V, _start, _step, _stop)
([50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0  …  211.0, 212.0, 213.0, 214.0, 215.0, 216.0, 217.0, 218.0, 219.0, 220.0], [12.971885307398942, 14.133361372536243, 9.745125025388765, 10.8977755219049, 12.479669440354387, 12.916329854098995, 14.812558469286401, 14.91230442338781, 12.80336314053027, 12.391194810090973  …  84.31595229941021, 85.20236588769724, 84.262139477017, 82.67344945270979, 85.2708118770866, 82.48844353066583, 84.99480845725486, 83.92030204652069, 87.34820763276043, 89.43639490390385])
_x_data
171-element Vector{Float64}:
  50.0
  51.0
  52.0
  53.0
  54.0
  55.0
  56.0
  57.0
  58.0
  59.0
   ⋮
 212.0
 213.0
 214.0
 215.0
 216.0
 217.0
 218.0
 219.0
 220.0
_y_data
171-element Vector{Float64}:
 12.971885307398942
 14.133361372536243
  9.745125025388765
 10.8977755219049
 12.479669440354387
 12.916329854098995
 14.812558469286401
 14.91230442338781
 12.80336314053027
 12.391194810090973
  ⋮
 85.20236588769724
 84.262139477017
 82.67344945270979
 85.2708118770866
 82.48844353066583
 84.99480845725486
 83.92030204652069
 87.34820763276043
 89.43639490390385
scatter(
    _x_data, _y_data, 
    title="Weight vs Height for the male cohort", color="orange", legend=false)
xlabel!("Height [cm]")
ylabel!("Weight [kg]")
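The introduction contrasts Bayesian and frequentist methods. An ordinary least-squares fit gives the frequentist counterpart here: point estimates of the slope and intercept, with no priors and no posterior distributions. The minimal sketch below uses Julia's backslash operator, and ols_fit is a hypothetical helper. The GLM package loaded above offers the same fit via lm(@formula(y ~ x), df) on a DataFrame.

```julia
using LinearAlgebra

## Hypothetical helper: least-squares estimates of (slope, intercept) for y ≈ β₁ x + β₀.
function ols_fit(x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
    X = hcat(x, ones(length(x)))   ## design matrix with an intercept column
    β̂ = X \ y                      ## least-squares solution (QR factorization)
    return β̂[1], β̂[2]
end

## noise-free check: points on y = 2x - 1 are recovered up to rounding
slope, intercept = ols_fit([1.0, 2.0, 3.0, 4.0], [1.0, 3.0, 5.0, 7.0])
println("slope = ", slope, ", intercept = ", intercept)
```

In the notebook, ols_fit(_x_data, _y_data) would fit the simulated cohort. The priors used below are weak relative to 171 observations, so the Bayesian posterior means should be close to these estimates.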

The RxInfer model

@model function linear_regression(x, y, m_B₁, v_B₁, m_B₀, v_B₀, σ̃²_V)
    β₁ ~ Normal(mean = m_B₁, variance = v_B₁)   ## prior on the slope
    β₀ ~ Normal(mean = m_B₀, variance = v_B₀)   ## prior on the intercept
    ## likelihood: each y[n] ~ N(β₁ x[n] + β₀, σ̃²_V), written with broadcasting
    y .~ Normal(mean = β₁ .* x .+ β₀, variance = σ̃²_V)
end

Inference with the RxInfer model

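We enable the free_energy = true option so that we can evaluate the convergence of the algorithm. All observations share \(\beta_1\) and \(\beta_0\), so the factor graph contains loops. Inference therefore runs iteratively and starts from an initial message for \(\beta_0\), specified with @initialization: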

_m_B₁ = 0.0; _v_B₁ = 1.0
_m_B₀ = 0.0; _v_B₀ = 10.0
_its=  100
results = infer(
    model         = linear_regression(m_B₁=_m_B₁, v_B₁=_v_B₁, m_B₀=_m_B₀, v_B₀=_v_B₀, σ̃²_V=_σ̃²_V),
    data          = (y= _y_data, x= _x_data), 
    initialization= @initialization(μ(β₀) = NormalMeanVariance(_m_B₀, _v_B₀)), 
    returnvars    = (β₁= KeepLast(), β₀= KeepLast()),
    iterations    = _its,
    free_energy   = true
)
Inference results:
  Posteriors       | available for (β₀, β₁)
  Free Energy:     | Real[369.749, 607.758, 552.401, 509.591, 476.51, 450.945, 431.187, 415.916, 404.111, 394.986  …  363.761, 363.761, 363.761, 363.761, 363.761, 363.761, 363.761, 363.761, 363.761, 363.761]

Results

The free energy drops quickly after the first iteration and levels off at about 363.76 nats, which indicates that the algorithm has converged:

plot(
    2:_its, ## skip the first iteration, which is dominated by the initialization
    results.free_energy[2:end], 
    title="Free energy", 
    xlabel="Iteration", 
    ylabel="Free energy [nats]", 
    legend=false)
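A numerical check can complement the visual one. The sketch below uses a hypothetical helper, convergence_iteration. It finds the first iteration from which every subsequent change in free energy stays below a chosen tolerance. The tolerance is an assumption that would need tuning per problem.

```julia
## Hypothetical helper: first iteration k such that |F[j+1] - F[j]| < tol for all j ≥ k.
## Returns `nothing` if the last change is still at or above `tol`.
function convergence_iteration(fe::AbstractVector{<:Real}; tol = 1e-6)
    Δ = abs.(diff(fe))                       ## change between consecutive iterations
    last_big = findlast(d -> d >= tol, Δ)    ## last change that is still too large
    last_big === nothing && return 1         ## changes were small from the start
    return last_big == length(Δ) ? nothing : last_big + 1
end

## toy free-energy trace that stops changing from iteration 4 on
println(convergence_iteration([10.0, 6.0, 4.0, 3.5, 3.5, 3.5]))
```

With the results above, something like convergence_iteration(results.free_energy; tol = 1e-3) would report where the trace flattens. Passing results.free_energy[2:end] instead skips the initialization-dominated first iteration (add 1 to the result).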

Next we plot the priors and posteriors of both parameters:

priorβ₁ = plot(
    range(-3, 3, length=1000), 
    (x) -> pdf(NormalMeanVariance(_m_B₁, _v_B₁), x), 
    title=L"Prior for $\beta_1$ parameter", 
    fillalpha=0.3, fillrange=0, label=L"$p(\beta_1)$", color="lightgreen", legend=:topright)
priorβ₁ = vline!(priorβ₁, [ _β̃₁ ], label=L"$\tilde{\beta}_1$ (True)", color="blue", legend=:topright)
postβ₁ = plot(
    range(0.43, 0.46, length=1000), 
    (x) -> pdf(results.posteriors[:β₁], x), 
    title=L"Posterior for $\beta_1$ parameter", 
    fillalpha=0.3, fillrange=0, label=L"$p(\beta_1 \mid \mathbf{y})$", color="green")
postβ₁ = vline!(postβ₁, [ _β̃₁ ], label=L"$\tilde{\beta}_1$ (True)", color="blue")
plot(priorβ₁, postβ₁, size = (1000, 200), xlabel=L"$\beta_1$", ylabel="Density", ylims=[0,Inf])
priorβ₀ = plot(
    range(-15, 15, length=1000), 
    (x) -> pdf(NormalMeanVariance(_m_B₀, _v_B₀), x), 
    title=L"Prior for $\beta_0$ parameter", 
    fillalpha=0.3, fillrange=0, label=L"$p(\beta_0)$", color="lightgreen", legend=:topright)
priorβ₀ = vline!(priorβ₀, [ _β̃₀ ], label=L"$\tilde{\beta}_0$ (True)", color="blue")
postβ₀ = plot(
    range(-15, -5, length=1000), 
    (x) -> pdf(results.posteriors[:β₀], x), 
    title=L"Posterior for $\beta_0$ parameter", 
    fillalpha=0.3, fillrange=0, label=L"$p(\beta_0 \mid \mathbf{y})$", color="green", legend=:topright)
postβ₀ = vline!(postβ₀, [ _β̃₀ ], label=L"$\tilde{\beta}_0$ (True)", color="blue")
plot(priorβ₀, postβ₀, size = (1000, 200), xlabel=L"$\beta_0$", ylabel="Density", ylims=[0, Inf])
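The model is conjugate, so the plotted posteriors can be cross-checked against the exact joint Gaussian posterior. The sketch below computes it with plain linear algebra. exact_posterior is a hypothetical helper written for this comparison, not an RxInfer function. The data are re-simulated with Julia's built-in MersenneTwister, so the numbers will differ from those of the StableRNG-based data above.

```julia
using LinearAlgebra, Random

## Hypothetical helper (not an RxInfer function): exact posterior of [β₁, β₀]
## for y = β₁ x + β₀ + v, v ~ N(0, σ²), with prior [β₁, β₀] ~ N(m₀, V₀).
function exact_posterior(x, y, m₀, V₀, σ²)
    X = hcat(x, ones(length(x)))           ## design matrix: slope and intercept columns
    Λ = inv(V₀) + (X' * X) ./ σ²           ## posterior precision matrix
    Σ = Symmetric(inv(Λ))                  ## posterior covariance matrix
    μ = Σ * (V₀ \ m₀ + (X' * y) ./ σ²)     ## posterior mean vector
    return μ, Σ
end

## re-simulate data with the same true parameters as above (different RNG)
rng = MersenneTwister(1234)
x = collect(50.0:1.0:220.0)
y = 0.44 .* x .- 10.0 .+ sqrt(3.5) .* randn(rng, length(x))

## same priors as the RxInfer run: β₁ ~ N(0, 1), β₀ ~ N(0, 10)
μ, Σ = exact_posterior(x, y, [0.0, 0.0], Diagonal([1.0, 10.0]), 3.5)

println("posterior means (β₁, β₀):     ", μ)
println("posterior variances (β₁, β₀): ", diag(Σ))
println("posterior correlation:        ", Σ[1, 2] / sqrt(Σ[1, 1] * Σ[2, 2]))
```

In the notebook itself, the call would be exact_posterior(_x_data, _y_data, [_m_B₁, _m_B₀], Diagonal([_v_B₁, _v_B₀]), _σ̃²_V). Σ depends only on the covariates, the priors and the noise variance, not on y. Its diagonal can therefore be compared directly with the variances returned by mean_var().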

Finally, we print and compare the true values of the parameters with the estimated parameters:

_β₁ = results.posteriors[:β₁]
_β₀ = results.posteriors[:β₀]
println("Real _β̃₁: ", _β̃₁, " | Estimated β₁: ", mean_var(_β₁), " | Error: ", abs(mean(_β₁) - _β̃₁))
println("Real _β̃₀: ", _β̃₀, " | Estimated β₀: ", mean_var(_β₀), " | Error: ", abs(mean(_β₀) - _β̃₀))
Real _β̃₁: 0.44 | Estimated β₁: (0.44293838447893064, 9.964672943972579e-7) | Error: 0.0029383844789306335
Real _β̃₀: -10.0 | Estimated β₀: (-10.376694954290658, 0.02054664094834393) | Error: 0.3766949542906577
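The posterior means and variances printed above can be summarised as 95% credible intervals of the Gaussian marginals. The sketch below uses only the numbers in that output. gaussian_interval is a hypothetical helper, and its default z = 1.959964 is the 97.5% quantile of the standard normal distribution. With these values, both true parameters fall just outside their intervals: the errors are roughly 2.9 and 2.6 posterior standard deviations. This suggests the reported marginal variances may understate the uncertainty. One possible reason is that the strong posterior correlation between slope and intercept is not reflected in the marginals. A comparison with the exact conjugate posterior is one way to check this.

```julia
## Hypothetical helper: central credible interval of a Gaussian marginal with
## mean m and variance v; the default z gives a 95% interval.
gaussian_interval(m, v; z = 1.959963984540054) = (m - z * sqrt(v), m + z * sqrt(v))

## (mean, variance) pairs copied from the mean_var() output above
β₁_mean, β₁_var = 0.44293838447893064, 9.964672943972579e-7
β₀_mean, β₀_var = -10.376694954290658, 0.02054664094834393

β₁_lo, β₁_hi = gaussian_interval(β₁_mean, β₁_var)
β₀_lo, β₀_hi = gaussian_interval(β₀_mean, β₀_var)

println("95% interval for β₁: [", β₁_lo, ", ", β₁_hi, "]  (true value 0.44)")
println("95% interval for β₀: [", β₀_lo, ", ", β₀_hi, "]  (true value -10.0)")

## distance between each estimate and the true value, in posterior standard deviations
println("β₁ error in sd units: ", abs(β₁_mean - 0.44) / sqrt(β₁_var))
println("β₀ error in sd units: ", abs(β₀_mean + 10.0) / sqrt(β₀_var))
```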