EpiNow2 Stan Functions
gaussian_process.stan
Go to the documentation of this file.
/**
 * These functions implement approximate Gaussian processes for Stan using
 * Hilbert space basis-function methods. The functions are based on:
 * - https://avehtari.github.io/casestudies/Motorcycle/motorcycle_gpcourse.html (Section 4)
 * - Riutort-Mayol et al., "Practical Hilbert space approximate Bayesian
 *   Gaussian processes for probabilistic programming", Statistics and
 *   Computing (2023), https://doi.org/10.1007/s11222-022-10167-2
 */
7
/**
 * Square root of the spectral density of the Exponentiated Quadratic kernel
 *
 * Evaluated at the basis frequencies w_k = k * pi / (2 * L), k = 1, ..., M:
 * sqrt(S(w_k)) = alpha * sqrt(sqrt(2 * pi) * rho) * exp(-rho^2 * w_k^2 / 4).
 *
 * @param alpha Scaling parameter
 * @param rho Length scale parameter
 * @param L Length of the interval
 * @param M Number of basis functions
 * @return A vector of length M with the square-rooted spectral densities
 *
 * @ingroup estimates_smoothing
 */
vector diagSPD_EQ(real alpha, real rho, real L, int M) {
  // rho times the spacing between consecutive basis frequencies.
  real scaled_step = rho * pi() / 2 / L;
  real decay = -0.25 * scaled_step^2;
  vector[M] k_squared = square(linspaced_vector(M, 1, M));
  return alpha * sqrt(sqrt(2 * pi()) * rho) * exp(decay * k_squared);
}
25
/**
 * Squared basis frequencies shared by the Matern spectral densities
 *
 * Computes w_k^2 with w_k = k * pi / (2 * L), k = 1, ..., M. This term
 * appears in the denominator of every Matern spectral density below. The
 * `linspaced_vector()` call is kept with data-only bounds (1 and M) and is
 * scaled afterwards, so the function compiles under the data-only argument
 * constraint of older Stan versions (e.g. the one shipped with rstan).
 *
 * @param M Number of basis functions
 * @param L Length of the interval
 * @return A vector of length M with the squared basis frequencies
 *
 * @ingroup estimates_smoothing
 */
vector matern_indices(int M, real L) {
  real freq_step = pi() / (2 * L);
  return square(freq_step * linspaced_vector(M, 1, M));
}
45
/**
 * Square root of the spectral density of the 1/2 Matern
 * (Ornstein-Uhlenbeck) kernel
 *
 * With w_k^2 taken from `matern_indices()`, returns
 * alpha * sqrt(2 / (1 / rho + rho * w_k^2)) for k = 1, ..., M.
 *
 * @param alpha Scaling parameter
 * @param rho Length scale parameter
 * @param L Length of the interval
 * @param M Number of basis functions
 * @return A vector of length M with the square-rooted spectral densities
 *
 * @ingroup estimates_smoothing
 */
vector diagSPD_Matern12(real alpha, real rho, real L, int M) {
  return alpha * sqrt(2 ./ (1 / rho + rho * matern_indices(M, L)));
}
61
/**
 * Square root of the spectral density of the 3/2 Matern kernel
 *
 * With w_k^2 taken from `matern_indices()`, returns
 * 2 * alpha * (sqrt(3) / rho)^1.5 / (3 / rho^2 + w_k^2) for k = 1, ..., M.
 *
 * @param alpha Scaling parameter
 * @param rho Length scale parameter
 * @param L Length of the interval
 * @param M Number of basis functions
 * @return A vector of length M with the square-rooted spectral densities
 *
 * @ingroup estimates_smoothing
 */
vector diagSPD_Matern32(real alpha, real rho, real L, int M) {
  vector[M] w_squared = matern_indices(M, L);
  real numerator = 2 * alpha * (sqrt(3) / rho)^1.5;
  return numerator ./ (3 / square(rho) + w_squared);
}
78
/**
 * Square root of the spectral density of the 5/2 Matern kernel
 *
 * With w_k^2 taken from `matern_indices()`, returns
 * alpha * sqrt(16 * (sqrt(5) / rho)^5 / (3 * (5 / rho^2 + w_k^2)^3))
 * for k = 1, ..., M.
 *
 * @param alpha Scaling parameter
 * @param rho Length scale parameter
 * @param L Length of the interval
 * @param M Number of basis functions
 * @return A vector of length M with the square-rooted spectral densities
 *
 * @ingroup estimates_smoothing
 */
vector diagSPD_Matern52(real alpha, real rho, real L, int M) {
  real root5_over_rho = sqrt(5) / rho;
  vector[M] shifted = 5 / square(rho) + matern_indices(M, L);
  return alpha * sqrt(16 * pow(root5_over_rho, 5) ./ (3 * pow(shifted, 3)));
}
95
/**
 * Square roots of the spectral density terms of the periodic kernel
 *
 * With a = 1 / rho^2, computes for k = 1, ..., M
 * q_k = alpha * sqrt(2 * exp(-a) * I_k(a)),
 * where I_k is the modified Bessel function of the first kind. The work is
 * done on the log scale for numerical stability. The vector is returned
 * twice in a row so that it lines up with the cosine and sine columns
 * produced by `PHI_periodic()`.
 *
 * @param alpha Scaling parameter
 * @param rho Length scale parameter
 * @param M Number of basis functions
 * @return A vector of length 2 * M: (q_1, ..., q_M, q_1, ..., q_M)
 *
 * @ingroup estimates_smoothing
 */
vector diagSPD_Periodic(real alpha, real rho, int M) {
  real inv_rho_sq = inv_square(rho);
  vector[M] log_bessel = to_vector(
    log_modified_bessel_first_kind(linspaced_vector(M, 1, M), inv_rho_sq)
  );
  vector[M] sd_per_freq =
    exp(log(alpha) + 0.5 * (log(2) - inv_rho_sq + log_bessel));
  return append_row(sd_per_freq, sd_per_freq);
}
115
/**
 * Basis functions for the (non-periodic) approximate Gaussian process
 *
 * Entry (i, k) is sin(k * pi * (x_i + L) / (2 * L)) / sqrt(L), i.e. the k-th
 * Laplacian eigenfunction on [-L, L] evaluated at x_i.
 *
 * @param N Number of data points (length of `x`)
 * @param M Number of basis functions
 * @param L Length of the interval
 * @param x Vector of input data
 * @return An N x M matrix of basis functions
 *
 * @ingroup estimates_smoothing
 */
matrix PHI(int N, int M, real L, vector x) {
  vector[N] shifted = pi() / (2 * L) * (x + L);
  matrix[N, M] angle = shifted * linspaced_row_vector(M, 1, M);
  return sin(angle) / sqrt(L);
}
132
/**
 * Basis functions for the periodic approximate Gaussian process
 *
 * With phase_{ik} = w0 * x_i * k, the first M columns are cos(phase) and the
 * last M columns are sin(phase).
 *
 * @param N Number of data points (length of `x`)
 * @param M Number of basis functions
 * @param w0 Fundamental frequency
 * @param x Vector of input data
 * @return An N x (2 * M) matrix of basis functions
 *
 * @ingroup estimates_smoothing
 */
matrix PHI_periodic(int N, int M, real w0, vector x) {
  matrix[N, M] phase = (w0 * x) * linspaced_row_vector(M, 1, M);
  matrix[N, 2 * M] basis = append_col(cos(phase), sin(phase));
  return basis;
}
149
/**
 * Setup Gaussian process noise dimensions
 *
 * The base count is `t` when `estimate_r` is not positive. Otherwise it is
 * `ot_h` if `stationary` is positive, or `ot_h - 1` if not. When
 * `future_fixed` is positive, the count is then adjusted by
 * `- horizon + fixed_from`.
 *
 * @param ot_h Observation time horizon
 * @param t Total time points
 * @param horizon Forecast horizon
 * @param estimate_r Indicator if estimating r
 * @param stationary Indicator if stationary
 * @param future_fixed Indicator if future is fixed
 * @param fixed_from Fixed point from
 * @return Number of noise terms
 *
 * @ingroup estimates_smoothing
 */
int setup_noise(int ot_h, int t, int horizon, int estimate_r,
                int stationary, int future_fixed, int fixed_from) {
  int n_time;
  if (estimate_r <= 0) {
    n_time = t;
  } else if (stationary <= 0) {
    n_time = ot_h - 1;
  } else {
    n_time = ot_h;
  }
  if (future_fixed <= 0) {
    return n_time;
  }
  return n_time - horizon + fixed_from;
}
171
/**
 * Setup approximate Gaussian process
 *
 * Builds the design points 1, ..., dimension, centres them on zero and
 * rescales them to [-1, 1]. It then evaluates either the periodic basis
 * (`PHI_periodic`) or the Hilbert space basis (`PHI`) at those points.
 *
 * @param M Number of basis functions
 * @param L Length of the interval
 * @param dimension Dimension of the process (number of design points)
 * @param is_periodic Indicator if the process is periodic
 * @param w0 Fundamental frequency for periodic process
 * @return A matrix of basis functions with `dimension` rows (M columns, or
 *   2 * M columns when periodic)
 *
 * @ingroup estimates_smoothing
 */
matrix setup_gp(int M, real L, int dimension, int is_periodic, real w0) {
  vector[dimension] x = linspaced_vector(dimension, 1, dimension);
  // Centre the design points on zero.
  x = x - mean(x);
  // Rescale to [-1, 1]. The range is dimension - 1, which is 0 for a single
  // point; dividing would then give 0 / 0 = NaN and poison the whole basis.
  // The centred single point is already 0, so it is left as is.
  if (dimension > 1) {
    x = 2 * x / (dimension - 1);
  }
  if (is_periodic) {
    return PHI_periodic(dimension, M, w0, x);
  }
  return PHI(dimension, M, L, x);
}
193
/**
 * Evaluate an approximate Gaussian process from its basis expansion
 *
 * Each coefficient in `eta` is scaled by the square root of the kernel's
 * spectral density at the matching basis frequency. The result is projected
 * through the basis matrix, giving `PHI * (diagSPD .* eta)`.
 *
 * @param PHI Basis functions matrix; it must have M columns, or 2 * M for the
 *   periodic kernel (presumably built by `setup_gp()`)
 * @param M Number of basis functions
 * @param L Length of the interval (not used by the periodic kernel)
 * @param alpha Scaling parameter
 * @param rho Length scale parameter
 * @param eta Vector of basis coefficients; length M, or 2 * M for the
 *   periodic kernel
 * @param type Type of kernel (0: SE, 1: Periodic, 2: Matern)
 * @param nu Smoothness parameter for the Matern kernel (0.5, 1.5, or 2.5);
 *   ignored for the other kernels
 * @return A vector of Gaussian process values, one per row of `PHI`
 *
 * @ingroup estimates_smoothing
 */
vector update_gp(matrix PHI, int M, real L, real alpha,
                 real rho, vector eta, int type, real nu) {
  // Square roots of the spectral density; the periodic kernel has a cosine
  // and a sine term per frequency, hence 2 * M entries.
  vector[type == 1 ? 2 * M : M] diagSPD;

  if (type == 0) {
    diagSPD = diagSPD_EQ(alpha, rho, L, M);
  } else if (type == 1) {
    diagSPD = diagSPD_Periodic(alpha, rho, M);
  } else if (type == 2) {
    if (nu == 0.5) {
      diagSPD = diagSPD_Matern12(alpha, rho, L, M);
    } else if (nu == 1.5) {
      diagSPD = diagSPD_Matern32(alpha, rho, L, M);
    } else if (nu == 2.5) {
      diagSPD = diagSPD_Matern52(alpha, rho, L, M);
    } else {
      reject("nu must be one of 0.5, 1.5, or 2.5; found nu=", nu);
    }
  } else {
    // Without this, diagSPD would stay unassigned (NaN) and the function
    // would silently return a vector of NaNs.
    reject("type must be one of 0 (SE), 1 (Periodic), or 2 (Matern); found type=", type);
  }
  return PHI * (diagSPD .* eta);
}
229
/**
 * Priors for Gaussian process (excluding length scale)
 *
 * Places an independent standard normal prior on every basis coefficient.
 * Constant terms are dropped, exactly as with `eta ~ std_normal()`.
 *
 * @param eta Vector of noise terms (basis coefficients)
 *
 * @ingroup estimates_smoothing
 */
void gaussian_process_lp(vector eta) {
  target += std_normal_lupdf(eta);
}
240
vector update_gp(matrix PHI, int M, real L, real alpha, real rho, vector eta, int type, real nu)
matrix PHI_periodic(int N, int M, real w0, vector x)
vector diagSPD_EQ(real alpha, real rho, real L, int M)
vector diagSPD_Periodic(real alpha, real rho, int M)
vector diagSPD_Matern12(real alpha, real rho, real L, int M)
int setup_noise(int ot_h, int t, int horizon, int estimate_r, int stationary, int future_fixed, int fixed_from)
matrix setup_gp(int M, real L, int dimension, int is_periodic, real w0)
matrix PHI(int N, int M, real L, vector x)
vector diagSPD_Matern52(real alpha, real rho, real L, int M)
vector diagSPD_Matern32(real alpha, real rho, real L, int M)
vector matern_indices(int M, real L)
void gaussian_process_lp(vector eta)