Gaussian Distributions

Tools to work with Gaussian distributions
from fastcore.test import *
import altair as alt

Normal Parameters

Normal

import torch

source

ListNormal.detach

 ListNormal.detach ()

Detach both mean and std at once
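
For example (assuming, as the docstring suggests, that detach returns a new ListNormal whose tensors no longer track gradients):

ln = ListNormal(torch.rand(3, requires_grad=True), torch.rand(3))
ln.detach() # a ListNormal with mean and std detached from the graph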

ln = ListNormal(torch.rand(10), torch.rand(10))
ln[5]
Normal(mean=tensor(0.2285), std=tensor(0.9300))

Multivariate Normal


ListMultiNormal.detach

 ListMultiNormal.detach ()

Detach both mean and cov at once

ListMNormal(torch.rand(2,10), torch.rand(2,10,10))[1]
MultiNormal(mean=tensor([0.0076, 0.7078, 0.1849, 0.8875, 0.0598, 0.5685, 0.9130, 0.4167, 0.7761,
        0.8239]), cov=tensor([[0.5052, 0.1173, 0.9735, 0.2103, 0.6431, 0.2104, 0.4656, 0.0827, 0.1011,
         0.5391],
        [0.3806, 0.4621, 0.5505, 0.0553, 0.1300, 0.3820, 0.2128, 0.1168, 0.1066,
         0.3580],
        [0.3908, 0.8341, 0.5620, 0.8415, 0.5727, 0.3466, 0.1083, 0.6935, 0.1600,
         0.1613],
        [0.4332, 0.6150, 0.0203, 0.4674, 0.6139, 0.7792, 0.4534, 0.4927, 0.2141,
         0.2313],
        [0.5573, 0.6323, 0.8267, 0.7420, 0.7631, 0.2128, 0.1779, 0.7642, 0.0040,
         0.8567],
        [0.2587, 0.3077, 0.2753, 0.3718, 0.1282, 0.9497, 0.7606, 0.3450, 0.7562,
         0.8713],
        [0.4206, 0.1561, 0.3089, 0.5148, 0.9883, 0.3910, 0.0613, 0.1449, 0.1414,
         0.2631],
        [0.9216, 0.2119, 0.9350, 0.4848, 0.3025, 0.9027, 0.3520, 0.2612, 0.6861,
         0.0352],
        [0.3255, 0.7473, 0.8152, 0.7568, 0.3317, 0.2959, 0.8844, 0.4175, 0.2749,
         0.8595],
        [0.8556, 0.1053, 0.9823, 0.1468, 0.9170, 0.9112, 0.8187, 0.6592, 0.1459,
         0.9400]]))

Positive Definite

The covariance matrices need to be positive definite. These are utility functions to check whether a matrix is positive definite and to make any matrix positive definite.

Other libraries

Most libraries that implement Kalman filters use manually specified parameters, which often avoids issues with the positive definite constraint (e.g. pykalman).

From the statsmodels statespace models:

> Cholesky decomposition […] requires that the matrix be positive definite. While this should generally be true, it may not be in every case. (source)

This seems to mean that they take into account that, during the filter calculations, the matrices may not be positive definite.

A = torch.rand(2,3,3) # batched random matrix used for testing

Symmetry


source

is_symmetric

 is_symmetric (value, atol=1e-05)
is_symmetric(A)
tensor([False, False])

source

symmetric_upto_batched

 symmetric_upto_batched (value, start=-8)

source

symmetric_upto

 symmetric_upto (value, start=-8)
symmetric_upto_batched(A)
tensor([0, 0])
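
The exact semantics of symmetric_upto are not spelled out here; based on the outputs in this notebook, a plausible sketch (an assumption, not the actual implementation) is the smallest exponent e ≥ start such that the matrix is symmetric with atol=10**e:

def symmetric_upto_sketch(value, start=-8):
    # try increasingly loose tolerances until the symmetry check passes
    for e in range(start, 1):
        if is_symmetric(value, atol=10.0**e):
            return e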

is posdef

Default PyTorch check (uses symmetry + Cholesky decomposition)


source

is_posdef

 is_posdef (cov)
is_posdef(A)
tensor([False, False])
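
A minimal sketch of such a check (an assumption about the implementation: the symmetry test combined with torch.linalg.cholesky_ex, whose info code is non-zero when the factorisation fails):

def is_posdef_sketch(cov, atol=1e-5):
    sym = is_symmetric(cov, atol=atol)                  # symmetric up to tolerance
    chol_ok = torch.linalg.cholesky_ex(cov).info == 0   # Cholesky succeeds only for posdef
    return sym & chol_ok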

Check if it is positive definite using eigenvalues: positive definite matrices have all positive eigenvalues.

torch.linalg.eigvalsh(A)
tensor([[0.1633, 0.4556, 1.0072],
        [0.0982, 0.3418, 1.7830]])

source

is_posdef_eigv

 is_posdef_eigv (cov)
is_posdef_eigv(A)
(tensor([True, True]),
 tensor([[0.1633, 0.4556, 1.0072],
         [0.0982, 0.3418, 1.7830]]))
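
A sketch of an eigenvalue-based check consistent with the output above (the actual is_posdef_eigv may add a tolerance):

def is_posdef_eigv_sketch(cov):
    # eigvalsh only reads the lower triangle, which is why a non-symmetric A can still pass
    eigv = torch.linalg.eigvalsh(cov)
    return (eigv > 0).all(-1), eigv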

Note that is_posdef and is_posdef_eigv can return different values; in general is_posdef_eigv is more tolerant.

Pytorch constraint

Transform any matrix \(A\) into a positive definite matrix \(PD\) using the following formula:

\(PD = AA^T + aI\)

where \(AA^T\) is a positive semi-definite matrix and \(a\) is a small positive number added on the diagonal to ensure that the resulting matrix is positive definite (not just semi-definite)
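
A direct translation of this formula, which appears to be what the standalone to_posdef helper used later in this notebook computes (a sketch; the PosDef constraint below parameterises the matrix via a triangular Cholesky factor instead):

def to_posdef_sketch(A, noise=1e-5):
    # PD = A @ A.T + noise * I, batched over the last two dimensions
    eye = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
    return A @ A.mT + noise * eye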

The inverse transformation uses the Cholesky decomposition.

Another approach would be to multiply a lower triangular matrix with its transpose, but that requires a positive diagonal, which is harder to obtain; see https://en.wikipedia.org/wiki/Definite_matrix#Cholesky_decomposition

The API is inspired by gpytorch constraints.

from meteo_imp.utils import *

source

inv_softplus

 inv_softplus (x)
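
Presumably the inverse of softplus, where softplus(x) = log(1 + exp(x)); a numerically stable sketch:

def inv_softplus_sketch(y):
    # log(exp(y) - 1), rewritten to avoid overflow for large y
    return y + torch.log(-torch.expm1(-y))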

source

batch_diag_embed

 batch_diag_embed (x)

source

batch_diag_scatter

 batch_diag_scatter (input, src)

source

batch_diagonal

 batch_diagonal (x)

source

PosDef

 PosDef (min_diag:float=1e-05)

Positive Definite Constraint for PyTorch parameters

Type Default Details
min_diag float 1e-05 min value for diagonal to ensure num stability
constraint = PosDef()

posdef = constraint.transform(A)
A = torch.randn(2, 3,3)
triang = constraint.transform_triangular(A)
p_diag = constraint.transform_pos_diag(triang)
cho_fact = constraint.transform_cho_factor(A)
posdef = constraint.transform(A)
show_as_row(A, triang, p_diag, cho_fact, posdef)

A

tensor([[[-0.0246,  0.4967,  0.1831],
         [-1.5187,  0.5129, -0.9096],
         [ 1.9180,  0.8267, -0.9243]],

        [[-0.1214, -0.5275,  0.6791],
         [ 0.0395,  1.2408,  1.2185],
         [-0.9288, -0.0625, -0.1523]]])

triang

tensor([[[-0.0246,  0.0000,  0.0000],
         [-1.5187,  0.5129,  0.0000],
         [ 1.9180,  0.8267, -0.9243]],

        [[-0.1214,  0.0000,  0.0000],
         [ 0.0395,  1.2408,  0.0000],
         [-0.9288, -0.0625, -0.1523]]])

p_diag

tensor([[[ 0.6809,  0.0000,  0.0000],
         [-1.5187,  0.9821,  0.0000],
         [ 1.9180,  0.8267,  0.3342]],

        [[ 0.6343,  0.0000,  0.0000],
         [ 0.0395,  1.4948,  0.0000],
         [-0.9288, -0.0625,  0.6199]]])

cho_fact

tensor([[[ 0.6809,  0.0000,  0.0000],
         [-1.5187,  0.9821,  0.0000],
         [ 1.9180,  0.8267,  0.3342]],

        [[ 0.6343,  0.0000,  0.0000],
         [ 0.0395,  1.4948,  0.0000],
         [-0.9288, -0.0625,  0.6199]]])

posdef

tensor([[[ 0.4637, -1.0342,  1.3060],
         [-1.0342,  3.2711, -2.1010],
         [ 1.3060, -2.1010,  4.4736]],

        [[ 0.4024,  0.0250, -0.5892],
         [ 0.0250,  2.2361, -0.1301],
         [-0.5892, -0.1301,  1.2509]]])
show_as_row(is_posdef(torch.stack([posdef,A])), is_posdef_eigv(torch.stack([posdef,A])), is_symmetric(torch.stack([posdef,A])))

#0

tensor([[ True,  True],
        [False, False]])

#1

(tensor([[ True,  True],
        [False, False]]),
 tensor([[[ 4.5415e-03,  1.6879e+00,  6.5159e+00],
         [ 9.9948e-02,  1.5298e+00,  2.2596e+00]],

        [[-3.1540e+00,  7.2122e-01,  1.9967e+00],
         [-1.0659e+00,  7.8085e-01,  1.2522e+00]]]))

#2

tensor([[ True,  True],
        [False, False]])
test_eq(is_posdef(posdef).all(), True)
test_close(posdef, constraint.transform(constraint.inverse_transform(posdef)))
symmetric_upto(posdef[0])
-8
is_posdef_eigv(to_posdef(torch.rand(1000, 1000)))[0]
tensor(False)

Fuzzer

run_fuzzer = True # set to False to temporarily disable the fuzzer for performance reasons
def random_posdef(bs=10, n=100, n_range=(0,1), noise=1e-5, **kwargs):
    A = torch.rand(bs,n,n, **kwargs)  * (n_range[1]-n_range[0]) + n_range[0]
    return PosDef(min_diag=noise).transform(A)
# fuzzer
def fuzz_posdef(bs=10,n=100,n_range=(0,1), **kwargs):
    posdef = random_posdef(bs, n, n_range, **kwargs)
    return pd.DataFrame(
        {'n': [n], 'range': str(n_range), 'n_samples': bs,
         'posdef': is_posdef(posdef).sum().item() / bs,
         'sym': is_symmetric(posdef).sum().item() / bs, 
         'posdef_eigv': is_posdef_eigv(posdef)[0].sum().item() / bs
    })
fuzz_posdef()
     n   range  n_samples  posdef  sym  posdef_eigv
0  100  (0, 1)         10     0.9  1.0          0.6
n_min, n_max = -1, 1
A = torch.rand(2,100,100)  * (n_max-n_min) + n_min
is_posdef(to_posdef(A))
tensor([False, False])
ma = torch.tensor([[1., 7],
                   [-3, 4]])
is_posdef(to_posdef(ma))
tensor(True)
fuzz_posdef(device='cuda')
     n   range  n_samples  posdef  sym  posdef_eigv
0  100  (0, 1)         10     1.0  1.0          1.0
# %time fuzz_posdef(bs=100, device='cuda')
rate_posdef = pd.concat([fuzz_posdef(n=n, bs=100, n_range=n_range, device='cuda') 
               for n in [10, 100]
               for n_range in [(-1,1),(0,1)]])
import altair as alt
from altair import datum
rate_posdef.head()
     n    range  n_samples  posdef  sym  posdef_eigv
0   10  (-1, 1)        100    1.00  1.0          1.0
0   10   (0, 1)        100    1.00  1.0          1.0
0  100  (-1, 1)        100    0.98  1.0          1.0
0  100   (0, 1)        100    0.97  1.0          1.0
def _plot_var(df, var, x='n:N', row='range', y_domain=(0,1), height=70, width=50):
    bar = alt.Chart(df).mark_bar().encode(
        x = alt.X('n:N'),
        y = alt.Y(var, scale=alt.Scale(domain=y_domain)),
        color = 'n:N',
    ).properties(height=height, width=width, ) 
    
    text = alt.Chart(df).mark_text(dy=10, color='white').encode(
        x = alt.X('n:N'),
        y = alt.Y(var),
        text = alt.Text(var, format=".2f")
    )
    
    return (bar + text).facet(
        row=row).properties(title=var, )
def _plot_var_box(df, var, x='n:N', row='range', column='noise:N', height=70, width=50, title=''):
    box = alt.Chart(df).mark_boxplot().encode(
        x = alt.X(x),
        y = alt.Y(var),
        color = x,
    ).properties(height=height, width=width) 

    # text = alt.Chart(df).mark_text(dy=10, color='white').encode(
    #     x = alt.X('n:N'),
    #     y = alt.Y(var),
    #     text = alt.Text(var, format=".2f")
    # )
    
    return (box).facet(
        column=column,
        row=row).properties(title=title)
from IPython import display
import vl_convert as vlc
from functools import partial

Generation of random positive definite matrices

def plot_posdef_simulation(n_s, range_s, bs=100, **kwargs):
    if not run_fuzzer: return
    rate_posdef = pd.concat([fuzz_posdef(n=n, bs=bs, n_range=n_range, device='cuda', **kwargs) 
               for n in n_s for n_range in range_s])
    
    print(rate_posdef)
    vl_spec = alt.hconcat(*[_plot_var(rate_posdef, var) for var in ['posdef', 'posdef_eigv']]).to_json()
    # workaround for bug in vegalite see https://github.com/altair-viz/altair/issues/2742
    svg = vlc.vegalite_to_svg(vl_spec, vl_version='v5.3')
    display.display(display.HTML(svg))
plot_posdef_simulation(n_s = [10, 100], range_s = [(-1, 1)], bs=1000)
     n    range  n_samples  posdef  sym  posdef_eigv
0   10  (-1, 1)       1000   1.000  1.0        1.000
0  100  (-1, 1)       1000   0.972  1.0        0.998
[bar charts: posdef and posdef_eigv rates by n, faceted by range]

Let’s go big by using a 1000x1000 matrix

plot_posdef_simulation(n_s = [1000], range_s = [(10, 20)], bs=100)
      n     range  n_samples  posdef  sym  posdef_eigv
0  1000  (10, 20)        100     0.0  1.0          0.0
[bar charts: posdef and posdef_eigv rates by n, faceted by range]

With the default noise on the diagonal, none of the random 1000x1000 matrices pass the positive definite check.

Let’s have a look at one of such matrices

posdef = random_posdef(100, 1000)
not_pd = posdef[torch.argwhere(~is_posdef_eigv(posdef)[0])[0]]

This should be positive definite, but it actually isn’t …

not_pd
tensor([[[4.8334e-01, 5.3647e-01, 5.7627e-01,  ..., 4.4425e-01,
          3.0689e-01, 2.0315e-01],
         [5.3647e-01, 2.2643e+00, 1.1408e+00,  ..., 1.1138e+00,
          4.2245e-01, 3.3727e-01],
         [5.7627e-01, 1.1408e+00, 1.8559e+00,  ..., 1.3060e+00,
          1.3603e+00, 7.1770e-01],
         ...,
         [4.4425e-01, 1.1138e+00, 1.3060e+00,  ..., 3.4489e+02,
          2.5251e+02, 2.4902e+02],
         [3.0689e-01, 4.2245e-01, 1.3603e+00,  ..., 2.5251e+02,
          3.3394e+02, 2.4621e+02],
         [2.0315e-01, 3.3727e-01, 7.1770e-01,  ..., 2.4902e+02,
          2.4621e+02, 3.2393e+02]]])

Trying with float64 (because of memory constraints on the GPU, only using a 700x700 matrix)

plot_posdef_simulation(n_s = [700], range_s = [(-.1, 1)], bs=100)
     n      range  n_samples  posdef  sym  posdef_eigv
0  700  (-0.1, 1)        100     0.0  1.0          0.0
[bar charts: posdef and posdef_eigv rates by n, faceted by range]
plot_posdef_simulation(n_s = [700], range_s = [(-.1, 1)], bs=100, dtype=torch.float64)

All matrices are now positive definite

Multiplication

Check whether the multiplication of matrices breaks the positive definite constraint.

If \(A\) and \(B\) are both positive definite matrices, then \(ABA\) is also positive definite: https://en.wikipedia.org/wiki/Definite_matrix#Multiplication
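
As a quick sanity check before fuzzing (in exact arithmetic this always holds; the fuzzer below measures how quickly floating-point error breaks it):

pd1 = random_posdef(10, 100)
pd2 = random_posdef(10, 100)
is_posdef(pd2 @ pd1 @ pd2).all()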

def fuzz_op(op, # operation that takes 2 pos def matrices and return one pos def matrix
            fn_check = is_posdef,
                  n=100, # size of matrix
                  max_t=1000, # number of multiplications
                  noise=1e-5, # noise to add on diagonal
                  bs=10, # batch size
                  n_range=(0,1), # range of random numbers
                  **kwargs):
    pd1 = random_posdef(bs, n, n_range, noise, **kwargs)
    pd2 = random_posdef(bs, n, n_range, noise, **kwargs)
    stop_times = torch.zeros(bs, **kwargs)
    
    for t in torch.arange(max_t):
        pd1 = op(pd1, pd2)
        check = fn_check(pd1)
        stop_times[torch.logical_and(stop_times == 0, ~check)] = t
        if not check.any(): break
         
    stop_times[stop_times == 0] = t
    return pd.DataFrame(
        {'n': [n], 'noise': f"{noise:.0e}", 'range': str(n_range), 'n_samples': bs, 'last_t': t.item(),
         'mean_stop': stop_times.mean().item(),
         'std_stop': stop_times.std().item(),
         'stop_times': [stop_times.cpu().numpy()]})
fuzz_multiply = partial(fuzz_op, lambda pd1, pd2: pd2 @ pd1 @ pd2)
fuzz_multiply_eigv = partial(fuzz_multiply, fn_check = lambda pd1: is_posdef_eigv(pd1)[0])
def plot_multiply_simulation(n_s, noise_s, max_mult=1000, bs=100, **kwargs):
    mult = pd.concat([fuzz_multiply(n=n, noise=noise, bs=bs, max_t=max_mult, device='cuda', **kwargs) 
               for n in n_s for noise in noise_s]).explode('stop_times')
    
    mult_eigv = pd.concat([fuzz_multiply_eigv(n=n, noise=noise, bs=bs, max_t=max_mult, device='cuda', **kwargs) 
               for n in n_s for noise in noise_s]).explode('stop_times')
    
    vl_spec = alt.hconcat(*[_plot_var_box(df, 'stop_times') for df in [mult, mult_eigv]]).to_json()
    # workaround for bug in vegalite see https://github.com/altair-viz/altair/issues/2742
    svg = vlc.vegalite_to_svg(vl_spec, vl_version='v5.3')
    display.display(display.HTML(svg))
    return (mult, mult_eigv)
plot_multiply_simulation(n_s=[2,3,10, 100], noise_s=[1e-3, 1e-4, 1e-5], bs=100);

Addition

Check whether the addition of matrices breaks the positive definite constraint.

If \(A\) and \(B\) are both positive definite matrices, then \(A+B\) is also positive definite: https://en.wikipedia.org/wiki/Definite_matrix#Addition

pd1 = random_posdef(10, 100)
pd2 = random_posdef(10, 100)
is_posdef(pd1 + pd2).all()
fuzz_add = partial(fuzz_op, lambda pd1, pd2: pd1 + pd2)
def plot_add_simulation(n_s, noise_s, max_ts=[1000], bs=100, **kwargs):
    add = pd.concat([fuzz_add(n=n, noise=noise, bs=bs, max_t=max_t, device='cuda', **kwargs) 
               for n in n_s for noise in noise_s for max_t in max_ts]).explode('stop_times')
    
    vl_spec = _plot_var_box(add, var='stop_times', height=150, width=150).to_json()

    svg = vlc.vegalite_to_svg(vl_spec, vl_version='v5.3')
    display.display(display.HTML(svg))
cache_disk("add_plot")(lambda: plot_add_simulation(n_s=[50, 100, 150], noise_s=[1e-3, 1e-4, 1e-5], bs=100, max_ts=[1e5]))()

Numpy posdef

import numpy as np
arr = np.random.rand(2,3,3)
arr.shape
(arr.transpose(0,2,1) == np.moveaxis(arr, -1, -2)).all()
def to_posdef_np(x, noise=1e-5):
    return x @ np.moveaxis(x, -1, -2) + (noise * np.eye(x.shape[-1], dtype=x.dtype))
to_posdef_np(arr)
# fuzzer
def fuzz_posdef_np(n=100, noise=1e-5, bs=10, n_range=(0,1), dtype=np.float32):
    A = np.random.rand(bs,n,n).astype(dtype)  * (n_range[1]-n_range[0]) + n_range[0]
    posdef = torch.from_numpy(to_posdef_np(A, noise))
    return pd.DataFrame(
        {'n': [n], 'noise': f"{noise:.0e}", 'range': str(n_range), 'n_samples': bs,
         'posdef': is_posdef(posdef).sum().item() / bs,
         'sym': is_symmetric(posdef).sum().item() / bs, 
         'posdef_eigv': is_posdef_eigv(posdef)[0].sum().item() / bs
    })
fuzz_posdef_np(n=1000, dtype=np.float32)

Positive Definite Checker

This helps find matrices that aren’t positive definite and debug the issues. It returns a detailed dataframe row with info about the matrix and optionally logs everything to a global object.


source

CheckPosDef

 CheckPosDef (do_check:bool=False, use_log:bool=True, warning:bool=True)

Check matrices for positive definiteness, optionally keeping a log and warning when a matrix is not positive definite.

Type Default Details
do_check bool False set to True to actually check matrix
use_log bool True keep internal log
warning bool True show a warning if a matrix is not pos def
CheckPosDef(True).check(A)
CheckPosDef(True).check(A[0])
checker = CheckPosDef(True)

checker.check(A, my_arg="my arg") # this will be another col in the log
checker.log
checker.add_args(show="only once")
checker.check(posdef)
checker.check(A)
checker.log
B = torch.rand(2,3,3) # a batch of matrices
is_symmetric(B).shape
checker.check(B)
test_close(B[0] @ A, (B @ A)[0]) # example batched matrix multiplication

Diagonal Positive Definite Constraint

This is a simpler constraint that makes the matrix diagonal and positive definite by forcing it to have positive numbers on the diagonal.

Given a vector \(a\), it is transformed into a diagonal positive definite matrix using:

\(A_{diag\ pos\ def} = \operatorname{diag}(a^2)\)

The inverse transformation is the square root of the diagonal.
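
A sketch matching those two formulas (an assumption, not necessarily the actual implementation). Note that the inverse can only recover the magnitude of \(a\) (cf. the v = -1.2 example below):

def to_diagposdef_sketch(a):
    return torch.diag_embed(a**2)  # diag(a²), batched over the last dimension

def inv_diagposdef_sketch(A):
    return torch.diagonal(A, dim1=-2, dim2=-1).sqrt()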

from meteo_imp.utils import *

source

to_diagposdef

 to_diagposdef (x)

source

DiagPosDef

 DiagPosDef ()

Diagonal Positive Definite Constraint for PyTorch parameters

DiagPosDef().transform(torch.rand(3))
DiagPosDef().transform(torch.rand(2, 3))
v = -1.2 * torch.ones(2,3)
DiagPosDef().inverse_transform(DiagPosDef().transform(v))
to_diagposdef(torch.rand(3,3))
dpd_const = DiagPosDef()
a = torch.rand(3)
dpd_const.transform(a)
test_close(a, dpd_const.inverse_transform(dpd_const.transform(a)))

Conditional Predictions

We need to compute the conditional distribution of a multivariate normal.

\[ X = \left[\begin{array}{c} x \\ o \end{array} \right] \]

\[ p(X) = N\left(\left[ \begin{array}{c} \mu_x \\ \mu_o \end{array} \right], \left[\begin{array}{cc} \Sigma_{xx} & \Sigma_{xo} \\ \Sigma_{ox} & \Sigma_{oo} \end{array} \right]\right)\]

where \(x\) is the vector of variables that need to be predicted and \(o\) is the vector of variables that have been observed.

Then the conditional distribution is:

\[p(x|o) = N(\mu_x + \Sigma_{xo}\Sigma_{oo}^{-1}(o - \mu_o), \Sigma_{xx} - \Sigma_{xo}\Sigma_{oo}^{-1}\Sigma_{ox})\]


source

conditional_guassian

 conditional_guassian (μ:torch.Tensor, Σ:torch.Tensor, obs:torch.Tensor,
                       mask:torch.Tensor)
Type Details
μ Tensor mean with shape [n_vars]
Σ Tensor cov with shape [n_vars, n_vars]
obs Tensor Observations with shape [n_obs], where n_obs = sum(mask)
mask Tensor Boolean tensor specifying for each variable whether it is observed (True) or not (False). Shape [n_vars]
Returns ListMultiNormal Distribution conditioned on observations. shape [n_vars - n_obs]
# example distribution with only 2 variables
μ = torch.tensor([.5, 1.])
Σ = torch.tensor([[1., .5], [.5 ,1.]])


mask = torch.tensor([True, False]) # first variable is the observed one

obs = torch.tensor([5.]) # value of the first variable

gauss_cond = conditional_guassian(μ, Σ, obs, mask)

# hardcoded values to test that the code is working, see also for alternative implementation https://python.quantecon.org/multivariate_normal.html
test_close(3.25, gauss_cond.mean.item())
test_close(.75, gauss_cond.cov.item())

Batches

We cannot have proper batch support, or at least not in a straightforward way, as the shape of the output would differ between elements of the batch.

So we use a for-loop to temporarily fix the situation.
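
A sketch of that loop (a hypothetical helper, not the exported code): conditional_guassian from above is applied to each batch element, with obs indexed by the mask since, as noted below, obs has the same shape as the mask.

def cond_gaussian_batched_sketch(dist, obs, mask):
    # output shapes differ per element, so return a plain list
    return [conditional_guassian(μ, Σ, o[m], m)
            for μ, Σ, o, m in zip(dist.mean, dist.cov, obs, mask)]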


source

cond_gaussian_batched

 cond_gaussian_batched (dist:__main__.ListMultiNormal, obs, mask)
Type Details
dist ListMultiNormal
obs this needs to have the same shape as the mask
mask
Returns List list of distributions, one per element in the batch
reset_seed(10)
mean = torch.rand(2,3) # batch
cov = to_posdef(torch.rand(2,3,3))
mask = torch.rand(2,3) > .3
obs = torch.rand(2,3)
cond_gaussian_batched(ListMultiNormal(mean, cov), obs, mask)
mask.shape, obs.shape
assert mean.shape == mask.shape
assert mean.dim() == 2
obs.shape
mean_x = mean[~mask]
mean_o = mean[mask]
mask
mean_x
cov.shape
cov[~mask]
cov
cov[0][~mask[0], ~mask[0]]
cov[0][mask[0],:][:, mask[0]].shape

Performance

Analysis of the performance of inverting a positive definite matrix.

Use the Cholesky decomposition and cholesky_solve to improve the performance of matrix inversion.

See the Probabilistic Machine Learning course from the University of Tübingen, specifically the code from the Gaussian Regression notebook, for details.

This is the direct implementation of the equations

def _conditional_guassian_base(
                         μ: Tensor, # mean with shape `[n_vars]`
                         Σ: Tensor, # cov with shape `[n_vars, n_vars] `
                         obs: Tensor, # Observations with shape `[n_obs]`
                         idx: Tensor # Boolean tensor specifying for each variable whether it is observed (True) or not (False). Shape `[n_vars]`
                        ) -> ListNormal: # Distribution conditioned on observations
    μ_x = μ[~idx]
    μ_o = μ[idx]
    
    Σ_xx = Σ[~idx,:][:, ~idx]
    Σ_xo = Σ[~idx,:][:, idx]
    Σ_ox = Σ[idx,:][:, ~idx]
    Σ_oo = Σ[idx,:][:, idx]
    
    Σ_oo_inv = torch.linalg.inv(Σ_oo)
    
    mean = μ_x + Σ_xo@Σ_oo_inv@(obs - μ_o)
    cov = Σ_xx - Σ_xo@Σ_oo_inv@Σ_ox
    
    return ListNormal(mean, cov)

A faster version:
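
A sketch of how the faster version can be written (assuming conditional_guassian does something along these lines internally; the test further below only checks that the two versions agree): factor Σ_oo once with Cholesky, then reuse the factor via cholesky_solve instead of forming the explicit inverse.

def _conditional_guassian_chol(μ, Σ, obs, idx):
    μ_x, μ_o = μ[~idx], μ[idx]
    Σ_xx = Σ[~idx,:][:, ~idx]
    Σ_xo = Σ[~idx,:][:, idx]
    
    L = torch.linalg.cholesky(Σ[idx,:][:, idx])  # Σ_oo = L Lᵀ, factored once
    mean = μ_x + Σ_xo @ torch.cholesky_solve((obs - μ_o).unsqueeze(-1), L).squeeze(-1)
    cov = Σ_xx - Σ_xo @ torch.cholesky_solve(Σ_xo.mT, L)  # Σ_ox = Σ_xoᵀ by symmetry
    return ListNormal(mean, cov)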

n_var = 5
mean = torch.rand(n_var, dtype=torch.float64)
cov = to_posdef(torch.rand(n_var, n_var, dtype=torch.float64))
dist = MultivariateNormal(mean, cov)
idx = torch.rand(n_var, dtype=torch.float64) > .5
obs = torch.rand(n_var, dtype=torch.float64)[idx]
torch.linalg.inv(cov)
(torch.linalg.inv(cov) - cholesky_inverse(torch.linalg.cholesky(cov))).max()
test_close(torch.linalg.inv(cov), cholesky_inverse(torch.linalg.cholesky(cov)), eps=1e-2)
reset_seed()
A = to_posdef(torch.rand(1000, 1000, dtype=torch.float64)) + torch.eye(1000) * 1e-3 # noise to ensure is positive definite
is_symmetric(A)
is_posdef(A)

The second version is way faster

test_close(conditional_guassian(mean, cov, obs, idx).mean, _conditional_guassian_base(mean, cov, obs, idx).mean)
B = to_posdef(torch.rand(n_var, n_var, dtype=torch.float64))
B @ torch.inverse(cov)
torch.cholesky_solve(B.mT, torch.linalg.cholesky(cov)).mT # B @ inverse(cov) without the explicit inverse

Helper

cov2std

x = torch.stack([torch.eye(3)*i for i in  range(1,4)])
x
torch.diagonal(x, dim1=1, dim2=2)

source

cov2std

 cov2std (x)

Convert an array of covariance matrices to an array of standard deviations
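
A sketch consistent with the exploration above (the diagonal of a covariance matrix holds the variances, so the standard deviations are their square roots):

def cov2std_sketch(x):
    return torch.diagonal(x, dim1=-2, dim2=-1).sqrt()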

Export