|
| 1 | +/- |
| 2 | +Copyright (c) 2026 Rémy Degenne. All rights reserved. |
| 3 | +Released under Apache 2.0 license as described in the file LICENSE. |
| 4 | +Authors: Rémy Degenne, Etienne Marion |
| 5 | +-/ |
| 6 | +module |
| 7 | + |
| 8 | +public import Mathlib.Analysis.CStarAlgebra.Matrix |
| 9 | +public import Mathlib.MeasureTheory.Measure.CharacteristicFunction.Basic |
| 10 | +public import Mathlib.Probability.Distributions.Gaussian.Basic |
| 11 | +public import Mathlib.Probability.Moments.CovarianceBilin |
| 12 | + |
| 13 | +import Mathlib.Probability.Distributions.Gaussian.CharFun |
| 14 | +import Mathlib.Probability.Distributions.Gaussian.Fernique |
| 15 | + |
| 16 | +/-! |
| 17 | +# Multivariate Gaussian distributions |
| 18 | +
|
| 19 | +In this file we define the standard Gaussian distribution over a finite-dimensional real inner |
| 20 | +product space and multivariate Gaussian distributions over `EuclideanSpace ℝ ι`. |
| 21 | +
|
| 22 | +## Main definitions |
| 23 | +
|
| 24 | +* `stdGaussian E`: Standard Gaussian distribution on a finite-dimensional real inner product space |
| 25 | + `E`. This is the law of a random vector whose coordinates in an orthonormal basis are independent |
| 26 | + standard Gaussian random variables. |
| 27 | +* `multivariateGaussian μ S`: The multivariate Gaussian distribution on `EuclideanSpace ℝ ι` |
| 28 | + with mean `μ` and covariance matrix `S`, when `S` is a positive semidefinite matrix. |
| 29 | +
|
| 30 | +## TODO |
| 31 | +
|
| 32 | +- Generalize `multivariateGaussian μ S` to the case where `S` is a symmetric trace class operator |
| 33 | + over a Hilbert space. |
| 34 | +
|
| 35 | +## Tags |
| 36 | +
|
| 37 | +multivariate Gaussian distribution |
| 38 | +
|
| 39 | +-/ |
| 40 | + |
| 41 | +@[expose] public section |
| 42 | + |
| 43 | + |
| 44 | +open MeasureTheory Matrix WithLp Module Complex |
| 45 | +open scoped RealInnerProductSpace MatrixOrder |
| 46 | + |
| 47 | +namespace ProbabilityTheory |
| 48 | + |
| 49 | +variable {ι : Type*} [Fintype ι] |
| 50 | + |
| 51 | +section stdGaussian |
| 52 | + |
| 53 | +/-! ### Standard Gaussian measure over a Euclidean space -/ |
| 54 | + |
| 55 | +variable {E : Type*} [NormedAddCommGroup E] [InnerProductSpace ℝ E] [FiniteDimensional ℝ E] |
| 56 | + [MeasurableSpace E] |
| 57 | + |
| 58 | +variable (E) in |
| 59 | +/-- Standard Gaussian distribution on a finite-dimensional real inner product space `E`. |
| 60 | +This is the law of a random vector whose coordinates in an orthonormal basis are independent |
| 61 | +standard real Gaussian random variables. |
| 62 | +
|
| 63 | +The definition pushes forward a product of `gaussianReal 0 1` along `stdOrthonormalBasis ℝ E`, but |
| 64 | +does not actually depend on the basis, see `stdGaussian_eq_map_pi_orthonormalBasis`. -/ |
| 65 | +noncomputable |
| 66 | +def stdGaussian : Measure E := |
| 67 | + (Measure.pi (fun _ : Fin (Module.finrank ℝ E) ↦ gaussianReal 0 1)).map |
| 68 | + (fun x ↦ ∑ i, x i • stdOrthonormalBasis ℝ E i) |
| 69 | + |
| 70 | +variable [BorelSpace E] |
| 71 | + |
| | +/-- The standard Gaussian measure `stdGaussian E` is a probability measure. -/ |
| 72 | +instance isProbabilityMeasure_stdGaussian : IsProbabilityMeasure (stdGaussian E) := |
| 73 | + Measure.isProbabilityMeasure_map (Measurable.aemeasurable (by fun_prop)) |
| 74 | + |
| | +/-- The standard Gaussian measure is centered: the integral of the identity is `0`. -/ |
| 75 | +@[simp] |
| 76 | +lemma integral_id_stdGaussian : ∫ x, x ∂(stdGaussian E) = 0 := by |
| 77 | + rw [stdGaussian, integral_map _ (by fun_prop), integral_finset_sum] |
| | + -- Main goal: the sum of the integrals of the summands `x i • stdOrthonormalBasis ℝ E i` is `0`. |
| 78 | + · simp [integral_smul_const, integral_eval] |
| | + -- Side goal: each summand is integrable. |
| 79 | + · exact fun i _ ↦ Integrable.smul_const (integrable_eval IsGaussian.integrable_id) _ |
| | + -- Side goal: the map being pushed forward is a.e.-measurable. |
| 80 | + · exact (Finset.measurable_sum _ (by fun_prop)).aemeasurable |
| 81 | + |
| | +/-- Under the standard Gaussian measure, the variance of a continuous linear functional `L` |
| | +is `‖L‖ ^ 2`. -/ |
| 82 | +lemma variance_dual_stdGaussian (L : StrongDual ℝ E) : |
| 83 | + Var[L; stdGaussian E] = ‖L‖ ^ 2 := by |
| 84 | + rw [stdGaussian, variance_map L.continuous.aemeasurable (Measurable.aemeasurable (by fun_prop))] |
| | + -- Rewrite `L` composed with the basis expansion as a sum of scaled coordinate functions. |
| 85 | + have : L ∘ (fun x : Fin (Module.finrank ℝ E) → ℝ ↦ ∑ i, x i • stdOrthonormalBasis ℝ E i) = |
| 86 | + ∑ i, (fun x : Fin (Module.finrank ℝ E) → ℝ ↦ L (stdOrthonormalBasis ℝ E i) * x i) := by |
| 87 | + ext x; simp [mul_comm] |
| 88 | + rw [this, variance_sum_pi] |
| 89 | + · change ∑ i, Var[fun x ↦ _ * (id x); gaussianReal 0 1] = _ |
| 90 | + simp_rw [variance_const_mul, variance_id_gaussianReal, (stdOrthonormalBasis ℝ E).norm_dual] |
| 91 | + simp |
| 92 | + · exact fun i ↦ IsGaussian.memLp_two_id.const_mul _ |
| 93 | + |
| 94 | +set_option backward.isDefEq.respectTransparency false in |
| | +/-- The characteristic function of `stdGaussian E` at `t` is `exp (- ‖t‖ ^ 2 / 2)`. -/ |
| 95 | +lemma charFun_stdGaussian (t : E) : |
| 96 | + charFun (stdGaussian E) t = exp (- ‖t‖ ^ 2 / 2) := by |
| 97 | + rw [charFun_apply, stdGaussian, integral_map (Measurable.aemeasurable (by fun_prop)) |
| 98 | + (Measurable.aestronglyMeasurable (by fun_prop))] |
| | + -- Turn the integral over the product measure into a product of one-dimensional integrals, |
| | + -- each of which is a characteristic function of `gaussianReal 0 1`. |
| 99 | + simp_rw [sum_inner, ofReal_sum, Finset.sum_mul, exp_sum, |
| 100 | + integral_fintype_prod_eq_prod (f := fun i x ↦ exp (⟪x • stdOrthonormalBasis ℝ E i, t⟫ * I)), |
| 101 | + real_inner_smul_left, mul_comm _ (⟪_, _⟫), ofReal_mul, ← charFun_apply_real, |
| 102 | + charFun_gaussianReal] |
| 103 | + simp only [ofReal_zero, mul_zero, zero_mul, NNReal.coe_one, ofReal_one, one_mul, |
| 104 | + zero_sub] |
| | + -- Recombine the product of exponentials and use `sum_sq_inner_right` for the orthonormal basis. |
| 105 | + simp_rw [← exp_sum, Finset.sum_neg_distrib, ← Finset.sum_div, ← ofReal_pow, |
| 106 | + ← ofReal_sum, (stdOrthonormalBasis ℝ E).sum_sq_inner_right, neg_div] |
| 107 | + |
| 108 | +set_option backward.isDefEq.respectTransparency false in |
| | +/-- The standard Gaussian measure is a Gaussian measure. The proof goes through |
| | +`isGaussian_iff_gaussian_charFun`, supplying `0` and `innerSL ℝ` as witnesses. -/ |
| 109 | +instance isGaussian_stdGaussian : IsGaussian (stdGaussian E) := by |
| 110 | + refine isGaussian_iff_gaussian_charFun.2 ⟨0, innerSL ℝ, |
| 111 | + LinearMap.BilinForm.isPosSemidef_iff.2 isPosSemidef_inner, ?_⟩ |
| 112 | + simp [charFun_stdGaussian, neg_div, innerSL_apply_apply ℝ] |
| 113 | + |
| | +/-- Every continuous linear functional has integral `0` against the standard Gaussian measure. -/ |
| 114 | +@[simp] |
| 115 | +lemma integral_strongDual_stdGaussian (L : StrongDual ℝ E) : (stdGaussian E)[L] = 0 := by |
| 116 | + rw [L.integral_comp_id_comm IsGaussian.integrable_id, integral_id_stdGaussian, map_zero] |
| 117 | + |
| 118 | +set_option backward.isDefEq.respectTransparency false in |
| | +/-- The characteristic function of `stdGaussian E` at a continuous linear functional `L` is |
| | +`exp (- ‖L‖ ^ 2 / 2)`. -/ |
| 119 | +lemma charFunDual_stdGaussian (L : StrongDual ℝ E) : |
| 120 | + charFunDual (stdGaussian E) L = exp (- ‖L‖ ^ 2 / 2) := by |
| 121 | + simp [IsGaussian.charFunDual_eq, integral_complex_ofReal, variance_dual_stdGaussian, neg_div] |
| 122 | + |
| 123 | +set_option backward.isDefEq.respectTransparency false in |
| | +/-- The covariance bilinear form of the standard Gaussian measure is the inner product, |
| | +as the continuous bilinear map `innerSL ℝ`. -/ |
| 124 | +lemma covarianceBilin_stdGaussian : |
| 125 | + covarianceBilin (stdGaussian E) = innerSL ℝ := by |
| 126 | + refine gaussian_charFun_congr 0 _ ?_ ?_ |>.2.symm |
| 127 | + · exact LinearMap.BilinForm.isPosSemidef_iff.2 isPosSemidef_inner |
| 128 | + · simp [charFun_stdGaussian, neg_div, innerSL_apply_apply ℝ] |
| 129 | + |
| | +/-- The image of the standard Gaussian measure on `E` under a linear isometric equivalence |
| | +`f : E ≃ₗᵢ[ℝ] F` is the standard Gaussian measure on `F`. -/ |
| 130 | +lemma stdGaussian_map {F : Type*} [NormedAddCommGroup F] [InnerProductSpace ℝ F] [MeasurableSpace F] |
| 131 | + [BorelSpace F] (f : E ≃ₗᵢ[ℝ] F) : |
| 132 | + haveI := f.finiteDimensional; (stdGaussian E).map f = stdGaussian F := by |
| 133 | + have := f.finiteDimensional |
| | + -- Compare the two measures through their characteristic functions on the dual. |
| 134 | + apply Measure.ext_of_charFunDual |
| 135 | + ext L |
| 136 | + simp_rw [show ⇑f = f.toLinearIsometry.toContinuousLinearMap from rfl, charFunDual_map, |
| 137 | + charFunDual_stdGaussian, L.opNorm_comp_linearIsometryEquiv] |
| 138 | + |
| | +/-- The product of the measures `gaussianReal 0 1` indexed by `ι`, mapped to `EuclideanSpace ℝ ι` |
| | +by `toLp 2`, is the standard Gaussian measure on `EuclideanSpace ℝ ι`. -/ |
| 139 | +lemma map_pi_eq_stdGaussian : |
| 140 | + (Measure.pi (fun _ ↦ gaussianReal 0 1)).map (toLp 2) = stdGaussian (EuclideanSpace ℝ ι) := by |
| | + -- Compare the two measures through their characteristic functions. |
| 141 | + apply Measure.ext_of_charFun (E := EuclideanSpace ℝ ι) |
| 142 | + ext t |
| 143 | + simp_rw [charFun_stdGaussian, charFun_pi, charFun_gaussianReal, ← exp_sum, ← ofReal_pow, |
| 144 | + EuclideanSpace.real_norm_sq_eq] |
| 145 | + simp [Finset.sum_div, neg_div] |
| 146 | + |
| 147 | +/-- The definition of `stdGaussian` does not depend on the choice of orthonormal basis `b`. -/ |
| 148 | +lemma stdGaussian_eq_map_pi_orthonormalBasis (b : OrthonormalBasis ι ℝ E) : |
| 149 | + stdGaussian E = (Measure.pi fun _ : ι ↦ gaussianReal 0 1).map (fun x ↦ ∑ i, x i • b i) := by |
| | + -- Factor the basis expansion as `toLp 2` followed by an `OrthonormalBasis.equiv` map. |
| 150 | + have : (fun (x : ι → ℝ) ↦ ∑ i, x i • b i) = |
| 151 | + ⇑((EuclideanSpace.basisFun ι ℝ).equiv b (Equiv.refl ι)) ∘ (toLp 2) := by |
| 152 | + simp_rw [← b.equiv_apply_euclideanSpace] |
| 153 | + rfl |
| 154 | + rw [this, ← Measure.map_map, map_pi_eq_stdGaussian, stdGaussian_map] |
| 155 | + all_goals fun_prop |
| 156 | + |
| 157 | +end stdGaussian |
| 158 | + |
| 159 | +section multivariateGaussian |
| 160 | + |
| 161 | +/-! ### Multivariate Gaussian measures over `ℝⁿ` -/ |
| 162 | + |
| 163 | +variable [DecidableEq ι] |
| 164 | + |
| 165 | +set_option backward.isDefEq.respectTransparency false in |
| 166 | +/-- Multivariate Gaussian measure on `EuclideanSpace ℝ ι` with mean `μ` and covariance |
| 167 | +matrix `S`, defined as the image of the standard Gaussian measure under the map |
| | +`x ↦ μ + toEuclideanCLM (CFC.sqrt S) x`. This only makes sense when `S` is positive semidefinite, |
| 168 | +as then `CFC.sqrt S * CFC.sqrt S = S`. Otherwise `CFC.sqrt S = 0`, and |
| 169 | +`multivariateGaussian μ S = Measure.dirac μ` (see `multivariateGaussian_of_not_posSemidef`). -/ |
| 170 | +noncomputable |
| 171 | +def multivariateGaussian (μ : EuclideanSpace ℝ ι) (S : Matrix ι ι ℝ) : |
| 172 | + Measure (EuclideanSpace ℝ ι) := |
| 173 | + (stdGaussian (EuclideanSpace ℝ ι)).map (fun x ↦ μ + toEuclideanCLM (𝕜 := ℝ) (CFC.sqrt S) x) |
| 174 | + |
| 175 | +set_option backward.isDefEq.respectTransparency false in |
| | +/-- If `S` is not positive semidefinite, then `multivariateGaussian μ S` is the Dirac measure |
| | +at `μ`. -/ |
| 176 | +lemma multivariateGaussian_of_not_posSemidef (μ : EuclideanSpace ℝ ι) {S : Matrix ι ι ℝ} |
| 177 | + (hS : ¬ S.PosSemidef) : multivariateGaussian μ S = .dirac μ := by |
| 178 | + rw [multivariateGaussian, CFC.sqrt, cfcₙ_apply_of_not_predicate] |
| 179 | + · simp |
| | + -- Remaining goal: the predicate required by `cfcₙ` fails for `S`; restate it via `hS`. |
| 180 | + change ¬ (S - 0).PosSemidef |
| 181 | + simpa |
| 182 | + |
| 183 | +set_option backward.isDefEq.respectTransparency false in |
| | +/-- The multivariate Gaussian measure with mean `0` and covariance matrix `1` is the standard |
| | +Gaussian measure on `EuclideanSpace ℝ ι`. -/ |
| 184 | +@[simp] |
| 185 | +lemma multivariateGaussian_zero_one : |
| 186 | + multivariateGaussian 0 (1 : Matrix ι ι ℝ) = stdGaussian (EuclideanSpace ℝ ι) := by |
| 187 | + simp [multivariateGaussian] |
| 188 | + |
| 189 | +variable {μ : EuclideanSpace ℝ ι} {S : Matrix ι ι ℝ} |
| 190 | + |
| 191 | +set_option backward.isDefEq.respectTransparency false in |
| | +/-- `multivariateGaussian μ S` is a Gaussian measure. No assumption on `S` is needed. -/ |
| 192 | +instance isGaussian_multivariateGaussian : IsGaussian (multivariateGaussian μ S) := by |
| | + -- Split the defining map into the continuous linear map followed by translation by `μ`. |
| 193 | + have h : (fun x ↦ μ + (toEuclideanCLM (𝕜 := ℝ) (CFC.sqrt S)) x) = |
| 194 | + (fun x ↦ μ + x) ∘ ((toEuclideanCLM (𝕜 := ℝ) (CFC.sqrt S))) := rfl |
| 195 | + simp only [multivariateGaussian] |
| 196 | + rw [h, ← Measure.map_map (measurable_const_add μ) (by fun_prop)] |
| 197 | + infer_instance |
| 198 | + |
| | +/-- The mean of `multivariateGaussian μ S` is `μ`. -/ |
| 199 | +@[simp] |
| 200 | +lemma integral_id_multivariateGaussian : ∫ x, x ∂(multivariateGaussian μ S) = μ := by |
| 201 | + rw [multivariateGaussian, integral_map (by fun_prop) (by fun_prop), |
| 202 | + integral_add (integrable_const _), integral_const] |
| 203 | + · simp [ContinuousLinearMap.integral_comp_comm _ IsGaussian.integrable_fun_id] |
| | + -- Side goal: the linear part is integrable under the standard Gaussian measure. |
| 204 | + · exact IsGaussian.integrable_id.comp_measurable (by fun_prop) |
| 205 | + |
| | +/-- Variant of `integral_id_multivariateGaussian` stated for the function `id`. -/ |
| 206 | +lemma integral_id_multivariateGaussian' : (multivariateGaussian μ S)[id] = μ := by simp |
| 207 | + |
| 208 | +set_option backward.isDefEq.respectTransparency false in |
| | +/-- For a positive semidefinite matrix `S`, the covariance bilinear form of |
| | +`multivariateGaussian μ S` evaluated at `x` and `y` is `x ⬝ᵥ S *ᵥ y`. -/ |
| 209 | +lemma covarianceBilin_multivariateGaussian (hS : S.PosSemidef) (x y : EuclideanSpace ℝ ι) : |
| 210 | + covarianceBilin (multivariateGaussian μ S) x y = x ⬝ᵥ S *ᵥ y := by |
| | + -- Split the defining map into the continuous linear map followed by translation by `μ`. |
| 211 | + have h : (fun x ↦ μ + x) ∘ ((toEuclideanCLM (𝕜 := ℝ) (CFC.sqrt S))) = |
| 212 | + (fun x ↦ μ + (toEuclideanCLM (𝕜 := ℝ) (CFC.sqrt S)) x) := rfl |
| 213 | + simp only [multivariateGaussian] |
| 214 | + rw [← h, ← Measure.map_map (measurable_const_add μ) (by fun_prop), covarianceBilin_map_const_add, |
| 215 | + covarianceBilin_map, covarianceBilin_stdGaussian, innerSL_apply_apply, |
| 216 | + ContinuousLinearMap.adjoint_inner_left, IsSelfAdjoint.adjoint_eq, |
| 217 | + ← ContinuousLinearMap.comp_apply, ← ContinuousLinearMap.mul_def, ← map_mul, |
| 218 | + CFC.sqrt_mul_sqrt_self _ hS.nonneg, inner_toEuclideanCLM] |
| | + -- Side goal: the image of `CFC.sqrt S` is self-adjoint. |
| 219 | + · exact (CFC.sqrt_nonneg S).isSelfAdjoint.map _ |
| 220 | + · exact IsGaussian.memLp_two_id |
| 221 | + |
| 222 | +set_option backward.isDefEq.respectTransparency false in |
| | +/-- For a positive semidefinite matrix `S`, the covariance of the coordinates `i` and `j` under |
| | +`multivariateGaussian μ S` is the entry `S i j`. -/ |
| 223 | +lemma covariance_eval_multivariateGaussian (hS : S.PosSemidef) (i j : ι) : |
| 224 | + cov[fun x ↦ x i, fun x ↦ x j; multivariateGaussian μ S] = S i j := by |
| | + -- Write the `i`-th coordinate as the inner product with the `i`-th standard basis vector. |
| 225 | + have (i : ι) : (fun x : EuclideanSpace ℝ ι ↦ x i) = |
| 226 | + fun x ↦ ⟪EuclideanSpace.basisFun ι ℝ i, x⟫ := by ext; simp [PiLp.inner_apply] |
| 227 | + rw [this, this, ← covarianceBilin_apply_eq_cov, covarianceBilin_multivariateGaussian hS] |
| 228 | + · simp |
| 229 | + · exact IsGaussian.memLp_two_id |
| 230 | + |
| | +/-- For a positive semidefinite matrix `S`, the variance of the coordinate `i` under |
| | +`multivariateGaussian μ S` is the diagonal entry `S i i`. -/ |
| 231 | +lemma variance_eval_multivariateGaussian (hS : S.PosSemidef) (i : ι) : |
| 232 | + Var[fun x ↦ x i; multivariateGaussian μ S] = S i i := by |
| 233 | + rw [← covariance_self, covariance_eval_multivariateGaussian hS] |
| 234 | + exact Measurable.aemeasurable <| by fun_prop |
| 235 | + |
| | +/-- For a positive semidefinite matrix `S`, the evaluation at coordinate `i` is measure preserving |
| | +from `multivariateGaussian μ S` to `gaussianReal (μ i) (S i i).toNNReal`. -/ |
| 236 | +lemma measurePreserving_eval_multivariateGaussian (hS : S.PosSemidef) {i : ι} : |
| 237 | + MeasurePreserving (fun x ↦ x i) (multivariateGaussian μ S) |
| 238 | + (gaussianReal (μ i) (S i i).toNNReal) where |
| 239 | + measurable := by fun_prop |
| 240 | + map_eq := by |
| 241 | + rw [← EuclideanSpace.coe_proj, IsGaussian.map_eq_gaussianReal, |
| 242 | + ContinuousLinearMap.integral_comp_id_comm] |
| 243 | + · simp [variance_eval_multivariateGaussian hS] |
| 244 | + exact IsGaussian.integrable_id |
| 245 | + |
| | +/-- For a positive semidefinite matrix `S`, the characteristic function of |
| | +`multivariateGaussian μ S` at `x` is `exp (⟪x, μ⟫ * I - x ⬝ᵥ S *ᵥ x / 2)`. -/ |
| 246 | +lemma charFun_multivariateGaussian (hS : S.PosSemidef) (x : EuclideanSpace ℝ ι) : |
| 247 | + charFun (multivariateGaussian μ S) x = |
| 248 | + exp (⟪x, μ⟫ * I - x ⬝ᵥ S *ᵥ x / 2) := by |
| 249 | + simp [IsGaussian.charFun_eq', covarianceBilin_multivariateGaussian hS] |
| 250 | + |
| 251 | +set_option backward.isDefEq.respectTransparency false in |
| 252 | +/-- If one restricts a multivariate Gaussian measure indexed by a finite set `I` to |
| 253 | +coordinates indexed by `J ⊆ I`, one obtains the multivariate Gaussian measure whose mean is the |
| 254 | +restriction of the mean to `J` and whose covariance matrix is the corresponding submatrix. -/ |
| 255 | +lemma measurePreserving_restrict₂_multivariateGaussian {ι : Type*} [DecidableEq ι] {I J : Finset ι} |
| 256 | + {μ : EuclideanSpace ℝ I} {S : Matrix I I ℝ} (hS : S.PosSemidef) (hJI : J ⊆ I) : |
| 257 | + MeasurePreserving (EuclideanSpace.restrict₂ hJI) (multivariateGaussian μ S) |
| 258 | + (multivariateGaussian (μ.restrict₂ hJI) |
| 259 | + (S.submatrix (fun i : J ↦ ⟨i.1, hJI i.2⟩) (fun i : J ↦ ⟨i.1, hJI i.2⟩))) where |
| 260 | + measurable := by fun_prop |
| 261 | + map_eq := by |
| 262 | + apply IsGaussian.ext |
| | + -- First goal: the integrals of the identity under both measures agree. |
| 263 | + · simp only [id_eq, integral_id_multivariateGaussian] |
| 264 | + rw [ContinuousLinearMap.integral_id_map, integral_id_multivariateGaussian] |
| 265 | + exact IsGaussian.integrable_id |
| | + -- Second goal: the covariance bilinear forms agree; check this on the standard basis of `J`. |
| 266 | + rw [← ContinuousLinearMap.toBilinForm_inj] |
| 267 | + refine LinearMap.BilinForm.ext_basis (EuclideanSpace.basisFun J ℝ).toBasis fun i j ↦ ?_ |
| 268 | + rw [ContinuousLinearMap.toBilinForm_apply, ContinuousLinearMap.toBilinForm_apply, |
| 269 | + covarianceBilin_apply_eq_cov, covariance_map] |
| 270 | + · have (i : J) : (fun u ↦ ⟪(EuclideanSpace.basisFun J ℝ).toBasis i, u⟫) ∘ |
| 271 | + EuclideanSpace.restrict₂ hJI = fun u ↦ u ⟨i.1, hJI i.2⟩ := by ext; simp [PiLp.inner_apply] |
| 272 | + simp_rw [this, covariance_eval_multivariateGaussian hS, |
| 273 | + covarianceBilin_multivariateGaussian (hS.submatrix _)] |
| 274 | + simp |
| 275 | + any_goals exact Measurable.aestronglyMeasurable (by fun_prop) |
| 276 | + · fun_prop |
| 277 | + · exact IsGaussian.memLp_two_id |
| 278 | + |
| 279 | +end multivariateGaussian |
| 280 | + |
| 281 | +end ProbabilityTheory |
0 commit comments