Expectation propagation for symbol detection in large-scale MIMO communications

Pablo M. Olmos (olmos@tsc.uc3m.es)
Joint work with Javier Céspedes (UC3M), Matilde Sánchez-Fernández (UC3M) and Fernando Pérez-Cruz (Bell Labs)

Today
- Probabilistic symbol detection in a MIMO system, combined with low-density parity-check (LDPC) channel coding.
- Approximate inference. State-of-the-art techniques: soft MMSE, Gaussian tree approximations (GTA), message passing (BP), ...
- Our proposal: approximate inference via expectation propagation (EP).
- Simulation results.

Index
1 Problem setup
2 Low-complexity probabilistic detection: Soft MMSE, GTA, Message Passing
3 Approximate inference with EP: MIMO detection with an EP approximation
4 Simulation results

MIMO symbol detection
- Real-valued channel model. $H$ is assumed known at the receiver.
- $x_i$, $i = 1, \dots, n$, are independent amplitude-modulated symbols, $x_i \in \mathcal{A} = \{\pm 1, \pm 2, \dots, \pm M/2\}$ → $M$ symbols.
- Probabilistic detection: having observed $y$, we are interested in computing $p(x_i \mid y)$, $i = 1, \dots, n$.

MIMO symbol detection
$$p(x_i \mid y) \propto \sum_{x_1} \sum_{x_2} \cdots \sum_{x_{i-1}} \sum_{x_{i+1}} \cdots \sum_{x_n} \mathcal{N}(y;\, Hx,\, \sigma I) \prod_{j=1}^{n} p(x_j), \qquad p(x_j) = \begin{cases} 1 & x_j \in \mathcal{A} \\ 0 & \text{otherwise} \end{cases}$$
→ $M^n$ operations! (A brute-force sketch is given at the end of this section.)

Approximate inference in MIMO detection
- Today → approximate inference methods at $O(n^3)$ complexity.
- Large-scale! E.g., $n = 128$ and 64-QAM modulation.
- Alternatives for small $n$, $M$: sphere decoding, MCMC methods, ...
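To make the $M^n$ cost concrete, here is a minimal brute-force sketch of the exact marginalization (Python/NumPy; the function name and toy sizes are illustrative and not part of the original slides). It is only feasible for very small $n$ and $M$:

```python
import itertools

import numpy as np


def exact_marginals(y, H, sigma, A):
    """Brute-force p(x_i | y) by enumerating all M^n symbol vectors.

    y: (d,) observation, H: (d, n) channel matrix, sigma: noise variance of
    N(y; Hx, sigma*I), A: list of allowed amplitudes.  Cost is O(M^n), so
    this is only usable for toy sizes.
    """
    n = H.shape[1]
    marginals = np.zeros((n, len(A)))
    for idx in itertools.product(range(len(A)), repeat=n):
        x = np.array([A[k] for k in idx], dtype=float)
        r = y - H @ x
        w = np.exp(-0.5 * (r @ r) / sigma)  # unnormalized N(y; Hx, sigma*I)
        for i in range(n):
            marginals[i, idx[i]] += w       # accumulate mass for x_i = A[idx[i]]
    return marginals / marginals.sum(axis=1, keepdims=True)


# Toy example: n = 3 symbols, 4-point amplitude alphabet
rng = np.random.default_rng(0)
H = rng.standard_normal((3, 3))
x_true = np.array([1.0, -2.0, 2.0])
y = H @ x_true + np.sqrt(0.1) * rng.standard_normal(3)
print(exact_marginals(y, H, 0.1, [-2.0, -1.0, 1.0, 2.0]))
```

Even moderate $n$ and $M$ make this enumeration hopeless, which is what motivates the $O(n^3)$ approximate methods that follow.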
Soft MMSE
$$p(x \mid y) \propto \mathcal{N}(y;\, Hx,\, \sigma I) \prod_{j=1}^{n} p(x_j), \qquad p(x_j) = \begin{cases} 1 & x_j \in \mathcal{A} \\ 0 & \text{otherwise} \end{cases}$$
Replace the discrete priors by independent Gaussian priors with the same mean/variance:
$$q(x) \propto \mathcal{N}(y;\, Hx,\, \sigma I) \prod_{j=1}^{n} \mathcal{N}(x_j;\, 0,\, E_s)$$

Soft MMSE
$$q(x) = \mathcal{N}(x;\, \mu_{\mathrm{MMSE}},\, \Sigma_{\mathrm{MMSE}}), \qquad \Sigma_{\mathrm{MMSE}} = \left( H^\top H + \frac{\sigma}{E_s} I \right)^{-1}, \qquad \mu_{\mathrm{MMSE}} = \Sigma_{\mathrm{MMSE}} H^\top y$$
$$p(x_i = a \mid y) \approx \frac{q(x_i = a)}{\sum_{c \in \mathcal{A}} q(x_i = c)}, \qquad a \in \mathcal{A}$$
(A NumPy sketch of this detector is given at the end of this section.)

The Gaussian tree approximation
$$p(x \mid y) \propto \mathcal{N}(y;\, Hx,\, \sigma I) \prod_{j=1}^{n} p(x_j), \qquad p(x_j) = \begin{cases} 1 & x_j \in \mathcal{A} \\ 0 & \text{otherwise} \end{cases}$$
Step 1: Ignore the discrete prior (or replace it by a constant term):
$$q(x) \propto \mathcal{N}(y;\, Hx,\, \sigma I) = \mathcal{N}\!\left(x;\, (H^\top H)^{-1} H^\top y,\, \sigma (H^\top H)^{-1}\right)$$

The Gaussian tree approximation
Step 2: Construct a Gaussian tree approximation to $q(x)$:
$$q_{\text{tree}}(x) = \prod_{j=1}^{n} q_{\text{tree}}(x_j \mid \pi(j)) \approx q(x)$$
Both $q(x)$ and $q_{\text{tree}}(x)$ are Gaussian and have the same pairwise marginals.

The Gaussian tree approximation
Step 3: Include the discrete priors and compute marginals using discrete message passing:
$$q^{\circ}_{\text{tree}}(x) \propto q_{\text{tree}}(x) \prod_{j=1}^{n} p(x_j)$$

Message passing and belief propagation
- BP → accurate marginal estimates in sparse factor graphs (such as LDPC code graphs). This is not the case in MIMO detection:
$$p(x \mid y) \propto \mathcal{N}(y;\, Hx,\, \sigma I) \prod_{j=1}^{n} p(x_j), \qquad p(x_j) = \begin{cases} 1 & x_j \in \mathcal{A} \\ 0 & \text{otherwise} \end{cases}$$
- Approximate message passing (AMP): BP messages are approximated using Gaussian pdfs.
- Compressed sensing applications: AMP achieves remarkable accuracy in certain scenarios, despite the high density of the factor graph.

AMP for MIMO detection
- Excellent performance/accuracy for QPSK modulations and large $n$.
- Larger QAM constellations: performance severely degraded.
- Reduce the loading factor: $n/d < 1$ (fewer transmitted symbols than received observations).
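As referenced above, a minimal NumPy sketch of the soft MMSE detector (all names are illustrative; $\sigma$ is the noise variance of $\mathcal{N}(y; Hx, \sigma I)$, and the Gaussian posterior is written in the $(\sigma^{-1}H^\top H + E_s^{-1}I)^{-1}$ form used later in the EP slides, which gives the same mean $\mu_{\mathrm{MMSE}}$):

```python
import numpy as np


def soft_mmse_detector(y, H, sigma, Es, A):
    """Soft MMSE detection.

    The discrete priors are replaced by Gaussians N(0, Es), giving a Gaussian
    posterior q(x) = N(mu, Sigma).  Per-symbol posteriors over the alphabet A
    are then obtained by evaluating q(x_i) at each a in A and renormalizing.
    sigma is the noise variance of N(y; Hx, sigma*I).
    """
    n = H.shape[1]
    # Gaussian posterior (EP-slide form with gamma = 0, Lambda_j = 1/Es)
    Sigma = np.linalg.inv(H.T @ H / sigma + np.eye(n) / Es)
    mu = Sigma @ (H.T @ y / sigma)
    var = np.diag(Sigma)

    A = np.asarray(A, dtype=float)
    # log q(x_i = a) up to a constant, for every symbol i and amplitude a
    logq = -0.5 * (A[None, :] - mu[:, None]) ** 2 / var[:, None]
    probs = np.exp(logq - logq.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)   # approx. p(x_i = a | y)
    return probs, mu, Sigma
```

The last two lines implement the approximation $p(x_i = a \mid y) \approx q(x_i = a) / \sum_{c \in \mathcal{A}} q(x_i = c)$ from the slide.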
Expectation propagation
General-purpose framework for approximating a probability distribution $p(x)$ by a distribution $q(x)$ that belongs to an exponential family $\mathcal{F}$.

Exponential family
A family $\mathcal{F}$ of distributions with densities
$$q(x \mid \theta) = h(x) \exp\!\left(\theta^\top \phi(x) - \Phi(\theta)\right), \qquad \Phi(\theta) = \log \int \exp\!\left(\theta^\top \phi(x)\right) dh(x), \qquad \theta \in \Theta$$
Example: the family $\mathcal{F}$ of Gaussian distributions with diagonal covariance matrix,
$$\phi(x) = \left[x_1, x_2, \dots, x_n, x_1^2, x_2^2, \dots, x_n^2\right]^\top$$
The moment-matching criterion: find $\theta^*$ such that
$$\mathbb{E}_{q(x \mid \theta^*)}[\phi(x)] \approx \mathbb{E}_{p(x)}[\phi(x)]$$
→ EP iterative algorithms.

EP for MIMO symbol detection
$$p(x \mid y) \propto \mathcal{N}(y;\, Hx,\, \sigma I) \prod_{j=1}^{n} p(x_j), \qquad p(x_j) = \begin{cases} 1 & x_j \in \mathcal{A} \\ 0 & \text{otherwise} \end{cases}$$
A family of approximating distributions:
$$q(x \mid \gamma, \Lambda) \propto \mathcal{N}(y;\, Hx,\, \sigma I) \prod_{j=1}^{n} e^{\gamma_j x_j - \frac{1}{2}\Lambda_j x_j^2},$$
where $\gamma \in \mathbb{R}^n$ and $\Lambda \in \mathbb{R}^n_+$. This is a Gaussian:
$$q(x \mid \gamma, \Lambda) = \mathcal{N}(x;\, \mu, \Sigma), \qquad \Sigma = \left(\sigma^{-1} H^\top H + \operatorname{diag}(\Lambda)\right)^{-1}, \qquad \mu = \Sigma\left(\sigma^{-1} H^\top y + \gamma\right)$$

EP for MIMO symbol detection
Note that
$$q(x_m \mid \gamma, \Lambda) = \mathcal{N}(x_m;\, \mu_m, \Sigma_{mm}) \propto e^{\gamma_m x_m - \frac{1}{2}\Lambda_m x_m^2} \int \mathcal{N}(y;\, Hx,\, \sigma I) \prod_{\substack{j=1 \\ j \neq m}}^{n} e^{\gamma_j x_j - \frac{1}{2}\Lambda_j x_j^2}\, dx_{\sim m}$$

EP for MIMO symbol detection
Given the current value of $(\gamma, \Lambda)$ → compute $(\mu, \Sigma)$ ($O(n^3)$) → $q(x_m \mid \gamma, \Lambda)$, $m = 1, \dots, n$.
Parallel update of all $(\gamma_m, \Lambda_m)$ pairs, $m = 1, \dots, n$:
1. Define
$$\hat{q}(x_m) \propto \frac{q(x_m \mid \gamma, \Lambda)}{e^{\gamma_m x_m - \frac{1}{2}\Lambda_m x_m^2}}\, p(x_m) \quad \begin{cases} \neq 0 & x_m \in \mathcal{A} \\ = 0 & \text{otherwise} \end{cases}$$
Equivalently,
$$\hat{q}(x_m) \propto p(x_m) \int \mathcal{N}(y;\, Hx,\, \sigma I) \prod_{\substack{j=1 \\ j \neq m}}^{n} e^{\gamma_j x_j - \frac{1}{2}\Lambda_j x_j^2}\, dx_{\sim m}.$$
2. Compute $\hat{\mu}_m = \mathbb{E}_{\hat{q}(x_m)}[x_m]$ and $\hat{\sigma}_m = \mathbb{E}_{\hat{q}(x_m)}[x_m^2] - \hat{\mu}_m^2$.
3. Find $(\gamma'_m, \Lambda'_m)$ such that
$$\frac{q(x_m \mid \gamma, \Lambda)}{e^{\gamma_m x_m - \frac{1}{2}\Lambda_m x_m^2}}\, e^{\gamma'_m x_m - \frac{1}{2}\Lambda'_m x_m^2}$$
has mean and variance equal to $\hat{\mu}_m$, $\hat{\sigma}_m$.
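Here is a sketch of one refinement of a single pair $(\gamma_m, \Lambda_m)$, following steps 1–3 above. The division by the site factor (step 1) is carried out in natural parameters; the helper name and its interface are illustrative, not part of the slides:

```python
import numpy as np


def ep_site_update(mu_m, Sigma_mm, gamma_m, Lambda_m, A):
    """One refinement of the pair (gamma_m, Lambda_m), following steps 1-3.

    mu_m, Sigma_mm: mean and variance of the current Gaussian marginal
    q(x_m | gamma, Lambda) obtained from the O(n^3) step.
    """
    A = np.asarray(A, dtype=float)

    # Step 1: remove the site factor exp(gamma_m x - 0.5 Lambda_m x^2).
    # In natural parameters this is a subtraction (cavity distribution).
    t_cav = 1.0 / Sigma_mm - Lambda_m      # cavity precision
    h_cav = mu_m / Sigma_mm - gamma_m      # cavity (precision * mean)

    # Multiplying by the discrete prior p(x_m) restricts the cavity to A.
    logq = h_cav * A - 0.5 * t_cav * A ** 2
    q_hat = np.exp(logq - logq.max())
    q_hat /= q_hat.sum()

    # Step 2: moments of the tilted distribution q_hat(x_m).
    mu_hat = q_hat @ A
    var_hat = q_hat @ A ** 2 - mu_hat ** 2

    # Step 3: moment matching -> new site natural parameters.
    Lambda_new = 1.0 / var_hat - t_cav
    gamma_new = mu_hat / var_hat - h_cav
    return gamma_new, Lambda_new
```

Step 3 reduces to this simple algebra because multiplying and dividing Gaussians corresponds to adding and subtracting their natural parameters.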
The EP iterative algorithm
- MMSE initialization: $\gamma_j = 0$, $\Lambda_j = E_s^{-1}$, $j = 1, \dots, n$.
- Cost per iteration dominated by the matrix inversion $\Sigma = \left(\sigma^{-1} H^\top H + \operatorname{diag}(\Lambda)\right)^{-1}$.
- Smooth the parameter update to improve stability. For some $\beta \in [0, 1]$:
$$\gamma_j \leftarrow \beta \gamma'_j + (1 - \beta)\gamma_j, \qquad \Lambda_j \leftarrow \beta \Lambda'_j + (1 - \beta)\Lambda_j$$
- Quick convergence to a stationary point. The number $I$ of required iterations scales with neither $n$ nor $M$.
- Symbol marginals:
$$p(x_i = a \mid y) \approx \frac{q(x_i = a \mid \gamma, \Lambda)}{\sum_{c \in \mathcal{A}} q(x_i = c \mid \gamma, \Lambda)}, \qquad a \in \mathcal{A}$$

Simulation setup
- Binary LDPC channel code, block length $N$. Several channel uses are required to transmit a complete LDPC codeword.
- Channel coefficients are sampled from an independent complex zero-mean unit-variance Gaussian distribution,
$$H = \begin{bmatrix} \Re(H_c) & -\Im(H_c) \\ \Im(H_c) & \Re(H_c) \end{bmatrix}$$
- SNR defined as
$$\mathrm{SNR(dB)} = 10 \log_{10}\!\left( \frac{n\, E_b \log_2(M)\, R}{\sigma} \right)$$

QPSK, 5 × 5 scenario, (3,6)-regular LDPC code, N = 5120 bits
[Figure: BER vs. SNR (6–14 dB) for the optimal detector, EP, GTA, AMP and soft MMSE.]

QPSK, 5 × 5 scenario, (3,6)-regular LDPC code, N = 5120 bits
Calibration curves: $q(x_i)$ vs. $p(x_i \mid y)$.
[Figure: calibration of the approximate posteriors against the true posterior $p(u_i \mid y)$ for GTA, EP, AMP and soft MMSE at SNR = 13.5 dB.]

128 × 128 scenario, (3,6)-regular LDPC code, N = 5120 bits
[Figure: BER vs. SNR (5–30 dB) for QPSK, 16-QAM and 64-QAM; curves for EP, AMP, GTA and MMSE.]

32 × 32 scenario, 16-QAM, (3,6)-regular LDPC code with N = 15k bits and N = 32k bits
[Figure: BER vs. SNR (14–20 dB) for EP, GTA and MMSE. Dashed lines → N = 15k; solid lines → N = 32k.]

32 × 32 scenario, 16-QAM, capacity-achieving LDPC codes, N = 32k bits
[Figure: BER vs. SNR (14–20 dB) for EP, GTA and MMSE. Dashed lines → rate-1/2 irregular LDPC code optimized for the BIAWGN threshold using density evolution; solid lines → rate-0.48 convolutional LDPC code with L = 50 positions.]

Conclusions
- System performance is limited by the detector accuracy. Room for improvement.
- More involved EP approximating families can lead to further gains.
- Extension to partial CSI scenarios, where we only know the statistics of $H$. E.g., $h_{i,j} \sim \mathcal{N}(\hat{h}_{i,j}, \sigma_h)$ with known $\hat{h}_{i,j}$, $\sigma_h$.

Questions?
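Backup: a self-contained sketch of the full EP detector described in "The EP iterative algorithm" (MMSE initialization, parallel moment-matching updates, damping with $\beta$). The default values of n_iter and beta are illustrative, and clipping non-positive $\Lambda'_m$ is a common practical safeguard rather than part of the slides:

```python
import numpy as np


def ep_mimo_detector(y, H, sigma, Es, A, n_iter=10, beta=0.2):
    """EP detector sketch: MMSE initialization, parallel updates, damping.

    Returns the approximate posteriors p(x_i = a | y), a in A, obtained from
    the final Gaussian q(x | gamma, Lambda) evaluated on the alphabet.
    sigma is the noise variance of N(y; Hx, sigma*I).
    """
    A = np.asarray(A, dtype=float)
    n = H.shape[1]
    HtH, Hty = H.T @ H / sigma, H.T @ y / sigma

    gamma = np.zeros(n)             # MMSE initialization: gamma_j = 0
    Lam = np.full(n, 1.0 / Es)      # Lambda_j = 1/Es

    def gaussian_marginals(gamma, Lam):
        # O(n^3) step: Sigma = (sigma^-1 H^T H + diag(Lambda))^-1
        Sigma = np.linalg.inv(HtH + np.diag(Lam))
        mu = Sigma @ (Hty + gamma)
        return mu, np.diag(Sigma)

    for _ in range(n_iter):
        mu, s = gaussian_marginals(gamma, Lam)

        # Cavity parameters for all sites at once (parallel update)
        t_cav = 1.0 / s - Lam
        h_cav = mu / s - gamma

        # Tilted moments on the discrete alphabet
        logq = h_cav[:, None] * A[None, :] - 0.5 * t_cav[:, None] * A[None, :] ** 2
        q_hat = np.exp(logq - logq.max(axis=1, keepdims=True))
        q_hat /= q_hat.sum(axis=1, keepdims=True)
        mu_hat = q_hat @ A
        var_hat = q_hat @ A ** 2 - mu_hat ** 2

        # Moment matching, keeping Lambda in R_+^n (clipping is a practical
        # safeguard, not part of the slides), then damped update with beta.
        Lam_new = 1.0 / var_hat - t_cav
        gam_new = mu_hat / var_hat - h_cav
        ok = (var_hat > 1e-12) & (Lam_new > 0)
        Lam[ok] = beta * Lam_new[ok] + (1 - beta) * Lam[ok]
        gamma[ok] = beta * gam_new[ok] + (1 - beta) * gamma[ok]

    # Final symbol marginals: q(x_i = a | gamma, Lambda), renormalized over A
    mu, s = gaussian_marginals(gamma, Lam)
    logq = -0.5 * (A[None, :] - mu[:, None]) ** 2 / s[:, None]
    probs = np.exp(logq - logq.max(axis=1, keepdims=True))
    return probs / probs.sum(axis=1, keepdims=True)
```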