%load_ext autoreload
%autoreload 2
Plotting for results
This notebook produces all the results plots. It generates gaps in the data, fills them with each method (Kalman filter, MDS, …), computes metrics, and then makes all the relevant plots.
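The core of the evaluation can be sketched as follows (an illustrative stand-in, not the actual meteo_imp API: rmse_on_gap and the interpolation imputer are hypothetical): mask a block of values, impute it, and score the gap with RMSE.

import numpy as np
import pandas as pd

def rmse_on_gap(series: pd.Series, gap_start: int, gap_len: int, impute) -> float:
    "Mask gap_len steps starting at gap_start, fill them with impute, return the RMSE over the gap."
    truth = series.iloc[gap_start:gap_start + gap_len]
    gapped = series.copy()
    gapped.iloc[gap_start:gap_start + gap_len] = np.nan
    pred = impute(gapped).iloc[gap_start:gap_start + gap_len]
    return float(np.sqrt(((pred - truth) ** 2).mean()))

# e.g., once hai is loaded below, with linear interpolation as the method:
# rmse_on_gap(hai['TA'], gap_start=1000, gap_len=48, impute=lambda s: s.interpolate())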
import altair as alt
from meteo_imp.kalman.results import *
from meteo_imp.data import *
from meteo_imp.utils import *
import pandas as pd
import numpy as np
from pyprojroot import here
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
from IPython.display import SVG, Image
from meteo_imp.kalman.results import _plot_timeseries, _get_labels
from functools import partial
from contextlib import redirect_stderr
import io
import polars as pl
from fastai.vision.data import get_grid
import cairosvg
Generating a proper PDF is complex: the Altair renderer doesn't support XOffset, so plots are first rendered to SVG using vl-convert and then converted to PDF using cairosvg. However, this last method doesn't support negative numbers …
Due to the high number of samples, the browser renderer cannot be used in the notebook either, so vl-convert renders to a PNG for the visualization in the notebook.
import vl_convert as vlc
from pyprojroot import here

base_path_img = here("manuscript/Master Thesis - Evaluation of Kalman filter for meteorological time series imputation for Eddy Covariance applications - Simone Massaro/images/")
base_path_tbl = here("manuscript/Master Thesis - Evaluation of Kalman filter for meteorological time series imputation for Eddy Covariance applications - Simone Massaro/tables/")

base_path_img.mkdir(exist_ok=True), base_path_tbl.mkdir(exist_ok=True)
def save_show_plot(plot,
                   path,
                   altair_render=False # use altair render for pdf?
                  ):
    plt_json = plot.to_json()
    if not altair_render:
        svg_data = vlc.vegalite_to_svg(vl_spec=plt_json)
        with open(base_path_img / (path + ".svg"), 'w') as f:
            f.write(svg_data)
        cairosvg.svg2pdf(file_obj=open(base_path_img / (path + ".svg")), write_to=str(base_path_img / (path + ".pdf")))
    else:
        # save svg version anyway
        svg_data = vlc.vegalite_to_svg(vl_spec=plt_json)
        with open(base_path_img / (path + ".svg"), 'w') as f:
            f.write(svg_data)
        # convert to pdf using altair
        with redirect_stderr(io.StringIO()):
            plot.save(base_path_img / (path + ".pdf"))
    # render to image for displaying in notebook
    png_data = vlc.vegalite_to_png(vl_spec=plot.to_json(), scale=1)
    return Image(png_data)
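For example, save_show_plot(p, "my_plot") writes my_plot.svg (and the corresponding PDF) to the manuscript images folder and returns a PNG rendering that displays inline in the notebook ("my_plot" is just an illustrative name).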
reset_seed()
n_rep = 500

hai = pd.read_parquet(hai_big_path).reindex(columns=var_type.categories)
hai_era = pd.read_parquet(hai_era_big_path)
# it is safe to do so as the plots are rendered using vl-convert and then shown as images
alt.data_transformers.disable_max_rows()

DataTransformerRegistry.enable('default')
Correlation
import matplotlib.pyplot as plt
import statsmodels.api as sm

def auto_corr_df(data, nlags=96):
    "Autocorrelation for each column, up to nlags half-hourly lags (96 lags = 48 h)."
    autocorr = {}
    for col in data.columns:
        autocorr[col] = sm.tsa.acf(data[col], nlags=nlags)
    return pd.DataFrame(autocorr)
auto_corr = auto_corr_df(hai).reset_index(names="gap_len").melt(id_vars="gap_len")
auto_corr.gap_len = auto_corr.gap_len / 2  # lags are half-hourly, convert to hours
auto_corr
|     | gap_len | variable | value    |
|-----|---------|----------|----------|
| 0   | 0.0     | TA       | 1.000000 |
| 1   | 0.5     | TA       | 0.998595 |
| 2   | 1.0     | TA       | 0.995814 |
| 3   | 1.5     | TA       | 0.992141 |
| 4   | 2.0     | TA       | 0.987630 |
| ... | ...     | ...      | ...      |
| 868 | 46.0    | TS       | 0.959680 |
| 869 | 46.5    | TS       | 0.961116 |
| 870 | 47.0    | TS       | 0.962085 |
| 871 | 47.5    | TS       | 0.962551 |
| 872 | 48.0    | TS       | 0.962480 |

873 rows × 3 columns
p = (alt.Chart(auto_corr).mark_line().encode(
        x = alt.X('gap_len', title="Gap length [h]", axis = alt.Axis(values= [12, 24, 36, 48])),
        y = alt.Y("value", title="correlation"),
        color=alt.Color("variable", scale=meteo_scale, title="Variable"),
        facet=alt.Facet('variable', columns=3, sort = meteo_scale.domain, title=None,
                        header = alt.Header(labelFontWeight="bold", labelFontSize=20))
    )
    .properties(height=120, width=250)
    .resolve_scale(y='independent', x = 'independent')
    .pipe(plot_formatter))
save_show_plot(p, "temporal_autocorrelation")
axes = get_grid(1,1,1, figsize=(10,8))
sns.set(font_scale=1.25)
sns.heatmap(hai.corr(), annot=True, vmin=-1, vmax=1, center=0,
            cmap=sns.diverging_palette(20, 220, n=200), ax=axes[0], square=True, cbar=True)
# axes[0].set(xlabel="Variable", ylabel="Variable", title="Inter-variable Correlation");
# size_old = plt.rcParams["axes.labelsize"]
# w_old = plt.rcParams["axes.labelweight"]
# plt.rcParams["axes.labelsize"] = 30
# plt.rcParams["axes.labelweight"] = 'bold'
plt.tight_layout()
plt.xticks(weight = 'bold')
plt.yticks(weight = 'bold')

with matplotlib.rc_context({"axes.labelsize": 30}):
    plt.savefig(base_path_img / "correlation.pdf")

plt.show()
# plt.rcParams["axes.labelsize"] = size_old
# plt.rcParams["axes.labelweight"] = w_old
Comparison of imputation methods
= here("analysis/results/trained_models") base_path
def l_model(x, base_path=base_path): return torch.load(base_path / x)
models_var = pd.DataFrame.from_records([
    {'var': 'TA', 'model': l_model("TA_specialized_gap_6-336_v3_0.pickle", base_path)},
    {'var': 'SW_IN', 'model': l_model("SW_IN_specialized_gap_6-336_v2_0.pickle", base_path)},
    {'var': 'LW_IN', 'model': l_model("LW_IN_specialized_gap_6-336_v1.pickle", base_path)},
    {'var': 'VPD', 'model': l_model("VPD_specialized_gap_6-336_v2_0.pickle", base_path)},
    {'var': 'WS', 'model': l_model("WS_specialized_gap_6-336_v1.pickle", base_path)},
    {'var': 'PA', 'model': l_model("PA_specialized_gap_6-336_v3_0.pickle", base_path)},
    {'var': 'P', 'model': l_model("1_gap_varying_6-336_v3.pickle", base_path)},
    {'var': 'TS', 'model': l_model("TS_specialized_gap_6-336_v2_0.pickle", base_path)},
    {'var': 'SWC', 'model': l_model("SWC_specialized_gap_6-336_v2_1.pickle", base_path)},
])
@cache_disk(cache_dir / "the_results")
def get_the_results(n_rep=20):
reset_seed()= ImpComparison(models = models_var, df = hai, control = hai_era, block_len = 446, time_series=False)
comp_Av = comp_Av.compare(gap_len = [12,24, 48, 336], var=list(hai.columns), n_rep=n_rep)
results_Av return results_Av
= get_the_results(n_rep) results_Av
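The expensive comparisons in this notebook are memoized on disk with cache_disk, so re-running the notebook loads the saved results instead of recomputing hundreds of repetitions. A simplified sketch of the idea (the real implementation lives in meteo_imp.utils and differs in details):

import pickle
from pathlib import Path

def cache_disk_sketch(cache_file: Path):
    "Decorator: cache a function's return value on disk, keyed by its positional arguments."
    def decorator(func):
        cache = pickle.loads(cache_file.read_bytes()) if cache_file.exists() else {}
        def wrapper(*args):
            if args not in cache:
                cache[args] = func(*args)  # compute once
                cache_file.write_bytes(pickle.dumps(cache))  # persist for later runs
            return cache[args]
        return wrapper
    return decorator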
State of the art
The first plot is a time series using only the state-of-the-art methods.
reset_seed()
comp_ts = ImpComparison(models = models_var, df = hai, control = hai_era, block_len = 48+100, time_series=True, rmse=False)
results_ts = comp_ts.compare(gap_len = [48], var=list(hai.columns), n_rep=1)

res_ts = results_ts.query("method != 'Kalman Filter'")
res_ts_plot = pd.concat([unnest_predictions(row, ctx_len=72) for _, row in res_ts.iterrows()])

scale_sota = alt.Scale(domain=["ERA-I", "MDS"], range=list(sns.color_palette('Dark2', 3).as_hex())[1:])

p = (facet_wrap(res_ts_plot, partial(_plot_timeseries, scale_color=scale_sota, err_band = False), col="var",
                y_labels = _get_labels(res_ts_plot, 'mean', None),
               )
     .pipe(plot_formatter)
    )
save_show_plot(p, "timeseries_sota", altair_render=True)
Percentage improvement
results_Av.method.unique()
['Kalman Filter', 'ERA-I', 'MDS']
Categories (3, object): ['Kalman Filter' < 'ERA-I' < 'MDS']
all_res = results_Av.query('var != "P"').groupby(['method']).agg({'rmse_stand': 'mean'}).T
all_res
| method | Kalman Filter | ERA-I | MDS |
|---|---|---|---|
| rmse_stand | 0.204628 | 0.307361 | 0.482837 |
Percentage improvement across all variables:

(all_res["ERA-I"] - all_res["Kalman Filter"]) / all_res["ERA-I"] * 100

rmse_stand    33.42398
dtype: float64

(all_res["MDS"] - all_res["Kalman Filter"]) / all_res["MDS"] * 100

rmse_stand    57.619542
dtype: float64
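These percentages can be checked by hand from the rounded means in the table above:

kf, era, mds = 0.204628, 0.307361, 0.482837
print(f"{(era - kf) / era * 100:.2f}%")  # 33.42% improvement over ERA-I
print(f"{(mds - kf) / mds * 100:.2f}%")  # 57.62% improvement over MDS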
res_var = results_Av.groupby(['method', 'var']).agg({'rmse_stand': 'mean'})
res_var = res_var.reset_index().pivot(columns='method', values='rmse_stand', index='var')
pd.DataFrame({'ERA': (res_var["ERA-I"] - res_var["Kalman Filter"]) / res_var["ERA-I"] * 100,
              'MDS': (res_var["MDS"] - res_var["Kalman Filter"]) / res_var["MDS"] * 100})
| var | ERA | MDS |
|---|---|---|
| TA | 54.540802 | 77.713711 |
| SW_IN | 12.004508 | 35.516142 |
| LW_IN | 5.166063 | 52.289627 |
| VPD | 44.402821 | 65.407769 |
| WS | 21.064305 | 40.321732 |
| PA | 28.784191 | 90.751559 |
| P | -18.544370 | -22.084360 |
| SWC | NaN | 41.543006 |
| TS | NaN | 25.772326 |
res_var2 = results_Av.groupby(['method', 'var', 'gap_len']).agg({'rmse_stand': 'mean'})
res_var2 = res_var2.reset_index().pivot(columns='method', values='rmse_stand', index=['var', 'gap_len'])
pd.DataFrame({'ERA': (res_var2["ERA-I"] - res_var2["Kalman Filter"]) / res_var2["ERA-I"] * 100,
              'MDS': (res_var2["MDS"] - res_var2["Kalman Filter"]) / res_var2["MDS"] * 100})
| var | gap_len | ERA | MDS |
|---|---|---|---|
| TA | 6 | 69.897582 | 85.052698 |
| | 12 | 58.766166 | 79.376385 |
| | 24 | 51.538443 | 75.395970 |
| | 168 | 41.823614 | 73.000401 |
| SW_IN | 6 | 9.519984 | 29.746651 |
| | 12 | 11.165399 | 30.639223 |
| | 24 | 14.232051 | 34.811941 |
| | 168 | 12.305658 | 42.651906 |
| LW_IN | 6 | 21.023524 | 59.136518 |
| | 12 | 9.110040 | 52.211404 |
| | 24 | -3.553292 | 50.720632 |
| | 168 | -4.260023 | 48.223005 |
| VPD | 6 | 66.980942 | 79.449579 |
| | 12 | 47.785633 | 69.081018 |
| | 24 | 33.663749 | 56.728120 |
| | 168 | 32.272332 | 57.702579 |
| WS | 6 | 32.402977 | 45.724043 |
| | 12 | 25.209162 | 43.275430 |
| | 24 | 15.543672 | 37.142502 |
| | 168 | 12.735569 | 36.436106 |
| PA | 6 | 39.823585 | 91.511486 |
| | 12 | 30.995845 | 90.532461 |
| | 24 | 24.727301 | 89.319180 |
| | 168 | 20.691181 | 91.421434 |
| P | 6 | -18.485009 | -13.917879 |
| | 12 | -28.935358 | -37.127331 |
| | 24 | -24.423076 | -29.998707 |
| | 168 | -7.725322 | -11.587796 |
| SWC | 6 | NaN | 61.302664 |
| | 12 | NaN | 47.976950 |
| | 24 | NaN | 42.535719 |
| | 168 | NaN | 23.301469 |
| TS | 6 | NaN | 64.264901 |
| | 12 | NaN | 46.699870 |
| | 24 | NaN | 27.050291 |
| | 168 | NaN | -15.268479 |
Main plot
from itertools import product
import altair as alt
p = the_plot(results_Av)
save_show_plot(p, "the_plot")
p = the_plot_stand(results_Av)
save_show_plot(p, "the_plot_stand")
Table
t = the_table(results_Av)
the_table_latex(t, base_path_tbl / "the_table.tex", label="tbl:the_table",
                caption="\\CapTheTable")
t
RMSE (mean and std per method):

| Variable | Gap | Kalman Filter (mean) | Kalman Filter (std) | ERA-I (mean) | ERA-I (std) | MDS (mean) | MDS (std) |
|---|---|---|---|---|---|---|---|
| TA | 6 h | 0.405453 | 0.258301 | 1.346910 | 0.997843 | 2.712546 | 1.896914 |
| | 12 h | 0.606836 | 0.400849 | 1.471695 | 0.900611 | 2.942435 | 1.748131 |
| | 1 day (24 h) | 0.741275 | 0.368468 | 1.529614 | 0.800256 | 3.012819 | 1.611311 |
| | 1 week (168 h) | 1.020608 | 0.444591 | 1.754334 | 0.643160 | 3.780087 | 1.315472 |
| SW_IN | 6 h | 44.636609 | 40.464629 | 49.333113 | 66.241975 | 63.536627 | 85.401585 |
| | 12 h | 48.155186 | 33.868178 | 54.207691 | 49.769296 | 69.427115 | 68.936352 |
| | 1 day (24 h) | 56.564277 | 30.042752 | 65.950367 | 40.930505 | 86.770917 | 59.603564 |
| | 1 week (168 h) | 61.582820 | 25.740161 | 70.224393 | 34.883199 | 107.384249 | 53.606111 |
| LW_IN | 6 h | 10.902409 | 7.736087 | 13.804628 | 12.987987 | 26.680077 | 15.022366 |
| | 12 h | 13.421656 | 7.734502 | 14.766929 | 12.584725 | 28.085478 | 13.457335 |
| | 1 day (24 h) | 14.593819 | 7.840046 | 14.093052 | 12.227900 | 29.614461 | 12.416763 |
| | 1 week (168 h) | 17.062880 | 6.425136 | 16.365697 | 11.129569 | 32.954558 | 8.833972 |
| VPD | 6 h | 0.428187 | 0.363168 | 1.296787 | 1.547397 | 2.083592 | 2.149288 |
| | 12 h | 0.660623 | 0.504761 | 1.265213 | 1.288794 | 2.136626 | 2.095549 |
| | 1 day (24 h) | 0.827563 | 0.501975 | 1.247527 | 1.032319 | 1.912472 | 1.605013 |
| | 1 week (168 h) | 1.125680 | 0.633392 | 1.662069 | 1.127314 | 2.661345 | 1.965431 |
| WS | 6 h | 0.616774 | 0.316972 | 0.912428 | 0.508295 | 1.136367 | 0.783146 |
| | 12 h | 0.715412 | 0.350974 | 0.956550 | 0.524247 | 1.261203 | 0.796744 |
| | 1 day (24 h) | 0.801851 | 0.343378 | 0.949427 | 0.446912 | 1.275665 | 0.608630 |
| | 1 week (168 h) | 0.950211 | 0.363124 | 1.088887 | 0.348541 | 1.494891 | 0.615371 |
| PA | 6 h | 0.045046 | 0.034294 | 0.074856 | 0.061726 | 0.530665 | 0.441476 |
| | 12 h | 0.053359 | 0.041613 | 0.077328 | 0.058476 | 0.563603 | 0.427426 |
| | 1 day (24 h) | 0.059481 | 0.038666 | 0.079021 | 0.051491 | 0.556899 | 0.404451 |
| | 1 week (168 h) | 0.066325 | 0.047544 | 0.083628 | 0.053654 | 0.773143 | 0.384029 |
| P | 6 h | 0.134093 | 0.274033 | 0.113173 | 0.315504 | 0.117710 | 0.305539 |
| | 12 h | 0.178871 | 0.295419 | 0.138729 | 0.297227 | 0.130442 | 0.281377 |
| | 1 day (24 h) | 0.206231 | 0.253588 | 0.165750 | 0.288432 | 0.158641 | 0.265257 |
| | 1 week (168 h) | 0.239885 | 0.173820 | 0.222682 | 0.201782 | 0.214975 | 0.197499 |
| SWC | 6 h | 0.508379 | 0.487342 | NaN | NaN | 1.313730 | 1.556829 |
| | 12 h | 0.664855 | 0.471849 | NaN | NaN | 1.278001 | 1.323011 |
| | 1 day (24 h) | 0.779066 | 0.640996 | NaN | NaN | 1.355740 | 1.472185 |
| | 1 week (168 h) | 1.493784 | 0.947799 | NaN | NaN | 1.947605 | 1.488284 |
| TS | 6 h | 0.341080 | 0.431992 | NaN | NaN | 0.954469 | 0.889126 |
| | 12 h | 0.534363 | 0.783787 | NaN | NaN | 1.002555 | 0.876784 |
| | 1 day (24 h) | 0.786670 | 0.851931 | NaN | NaN | 1.078373 | 0.856964 |
| | 1 week (168 h) | 1.659875 | 1.077782 | NaN | NaN | 1.440008 | 0.764040 |
t = the_table(results_Av, 'rmse_stand', y_name="Stand. RMSE")
the_table_latex(t, base_path_tbl / "the_table_stand.tex", stand = True, label="tbl:the_table_stand",
                caption = "\\CapTheTableStand")
t
Standardized RMSE (mean and std per method):

| Variable | Gap | Kalman Filter (mean) | Kalman Filter (std) | ERA-I (mean) | ERA-I (std) | MDS (mean) | MDS (std) |
|---|---|---|---|---|---|---|---|
| TA | 6 h | 0.051164 | 0.032595 | 0.169965 | 0.125917 | 0.342294 | 0.239370 |
| | 12 h | 0.076576 | 0.050583 | 0.185712 | 0.113647 | 0.371303 | 0.220595 |
| | 1 day (24 h) | 0.093541 | 0.046497 | 0.193021 | 0.100984 | 0.380185 | 0.203330 |
| | 1 week (168 h) | 0.128790 | 0.056103 | 0.221378 | 0.081160 | 0.477006 | 0.165998 |
| SW_IN | 6 h | 0.218804 | 0.198354 | 0.241826 | 0.324711 | 0.311450 | 0.418630 |
| | 12 h | 0.236052 | 0.166018 | 0.265721 | 0.243964 | 0.340325 | 0.337919 |
| | 1 day (24 h) | 0.277272 | 0.147267 | 0.323282 | 0.200637 | 0.425342 | 0.292171 |
| | 1 week (168 h) | 0.301873 | 0.126176 | 0.344233 | 0.170994 | 0.526387 | 0.262772 |
| LW_IN | 6 h | 0.259855 | 0.184387 | 0.329028 | 0.309564 | 0.635910 | 0.358053 |
| | 12 h | 0.319900 | 0.184349 | 0.351964 | 0.299952 | 0.669407 | 0.320751 |
| | 1 day (24 h) | 0.347838 | 0.186865 | 0.335903 | 0.291448 | 0.705850 | 0.295949 |
| | 1 week (168 h) | 0.406688 | 0.153141 | 0.390071 | 0.265269 | 0.785460 | 0.210555 |
| VPD | 6 h | 0.098019 | 0.083135 | 0.296855 | 0.354224 | 0.476967 | 0.492006 |
| | 12 h | 0.151227 | 0.115548 | 0.289627 | 0.295025 | 0.489108 | 0.479704 |
| | 1 day (24 h) | 0.189442 | 0.114910 | 0.285579 | 0.236314 | 0.437795 | 0.367413 |
| | 1 week (168 h) | 0.257686 | 0.144994 | 0.380474 | 0.258060 | 0.609224 | 0.449918 |
| WS | 6 h | 0.379454 | 0.195008 | 0.561347 | 0.312715 | 0.699120 | 0.481810 |
| | 12 h | 0.440138 | 0.215927 | 0.588492 | 0.322529 | 0.775922 | 0.490176 |
| | 1 day (24 h) | 0.493318 | 0.211254 | 0.584110 | 0.274951 | 0.784819 | 0.374443 |
| | 1 week (168 h) | 0.584592 | 0.223403 | 0.669909 | 0.214431 | 0.919692 | 0.378591 |
| PA | 6 h | 0.052675 | 0.040103 | 0.087534 | 0.072180 | 0.620545 | 0.516250 |
| | 12 h | 0.062397 | 0.048661 | 0.090425 | 0.068381 | 0.659061 | 0.499820 |
| | 1 day (24 h) | 0.069556 | 0.045215 | 0.092405 | 0.060212 | 0.651223 | 0.472953 |
| | 1 week (168 h) | 0.077558 | 0.055597 | 0.097793 | 0.062741 | 0.904092 | 0.449073 |
| P | 6 h | 0.478431 | 0.977725 | 0.403790 | 1.125691 | 0.419979 | 1.090136 |
| | 12 h | 0.638197 | 1.054031 | 0.494974 | 1.060481 | 0.465404 | 1.003928 |
| | 1 day (24 h) | 0.735816 | 0.904779 | 0.591382 | 1.029100 | 0.566018 | 0.946414 |
| | 1 week (168 h) | 0.855891 | 0.620173 | 0.794512 | 0.719941 | 0.767011 | 0.704660 |
| SWC | 6 h | 0.057037 | 0.054677 | NaN | NaN | 0.147393 | 0.174667 |
| | 12 h | 0.074593 | 0.052939 | NaN | NaN | 0.143384 | 0.148434 |
| | 1 day (24 h) | 0.087407 | 0.071916 | NaN | NaN | 0.152106 | 0.165171 |
| | 1 week (168 h) | 0.167594 | 0.106338 | NaN | NaN | 0.218510 | 0.166977 |
| TS | 6 h | 0.060276 | 0.076342 | NaN | NaN | 0.168674 | 0.157127 |
| | 12 h | 0.094433 | 0.138512 | NaN | NaN | 0.177172 | 0.154946 |
| | 1 day (24 h) | 0.139021 | 0.150554 | NaN | NaN | 0.190571 | 0.151443 |
| | 1 week (168 h) | 0.293335 | 0.190466 | NaN | NaN | 0.254479 | 0.135022 |
Timeseries
@cache_disk(cache_dir / "the_results_ts")
def get_the_results_ts():
reset_seed()= ImpComparison(models = models_var, df = hai, control = hai_era, block_len = 446, time_series=True, rmse=False)
comp_Av = comp_Av.compare(gap_len = [12,24, 336], var=list(hai.columns), n_rep=4)
results_Av return results_Av
= get_the_results_ts() results_ts
ts = plot_timeseries(results_ts.query("var in ['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS']"), idx_rep=0)
save_show_plot(ts, "timeseries_1", altair_render=True)
%time ts = plot_timeseries(results_ts.query("var in ['PA', 'P', 'TS', 'SWC']"), idx_rep=0)
%time save_show_plot(ts, "timeseries_2", altair_render=True)
CPU times: user 2.82 s, sys: 765 µs, total: 2.82 s
Wall time: 2.84 s
CPU times: user 9.53 s, sys: 127 ms, total: 9.66 s
Wall time: 12.8 s
from tqdm.auto import tqdm
# @cache_disk(cache_dir / "ts_plot", rm_cache=True)
def plot_additional_ts():
    for idx in tqdm(results_ts.idx_rep.unique()):
        if idx == 0: continue # skip the first plot as it is done above
        ts1 = plot_timeseries(results_ts.query("var in ['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS']"), idx_rep=idx)
        save_show_plot(ts1, f"timeseries_1_{idx}", altair_render=True)
        ts2 = plot_timeseries(results_ts.query("var in ['PA', 'P', 'TS', 'SWC']"), idx_rep=idx)
        save_show_plot(ts2, f"timeseries_2_{idx}", altair_render=True)
plot_additional_ts()
Kalman Filter analysis
Gap length
@cache_disk(cache_dir / "gap_len")
def get_g_len(n_rep=n_rep):
reset_seed()return KalmanImpComparison(models_var, hai, hai_era, block_len=48*7+100).compare(gap_len = [2,6,12,24,48,48*2, 48*3, 48*7], var=list(hai.columns), n_rep=n_rep)
= get_g_len(n_rep) gap_len
p = plot_gap_len(gap_len, hai, hai_era)
save_show_plot(p, "gap_len")

t = table_gap_len(gap_len)
table_gap_len_latex(t, base_path_tbl / "gap_len.tex", label="gap_len",
                    caption="\\CapGapLen")
t
g_len_agg = gap_len.groupby('gap_len').agg({'rmse_stand': 'mean'})
(g_len_agg.iloc[0]) / g_len_agg.iloc[-1]

g_len_agg = gap_len.groupby(['gap_len', 'var']).agg({'rmse_stand': 'mean'})
(g_len_agg.loc[1.]) / g_len_agg.loc[168.]
g_len_agg

g_len_agg_std = gap_len.groupby('gap_len').agg({'rmse_stand': 'std'})
(g_len_agg_std.iloc[0]) / g_len_agg_std.iloc[-1]

(gap_len.groupby(['gap_len', 'var']).agg({'rmse_stand': 'std'})
 .unstack("var")
 .droplevel(0, 1)
 .plot(subplots=True, layout=(3,3), figsize=(10,10)))
# with open(base_path_tbl / "gap_len.tex") as f:
# print(f.readlines())
Control
models_nc = pd.DataFrame({'model': [l_model("1_gap_varying_336_no_control_v1.pickle"), l_model("1_gap_varying_6-336_v3.pickle")],
                          'type': ['No Control', 'Use Control']})
@cache_disk(cache_dir / "use_control")
def get_control(n_rep=n_rep):
reset_seed()
= KalmanImpComparison(models_nc, hai, hai_era, block_len=100+48*7)
kcomp_control
= kcomp_control.compare(n_rep =n_rep, gap_len = [12, 24, 48, 48*7], var = list(hai.columns))
k_results_control
return k_results_control
from time import sleep
= get_control(n_rep) k_results_control
k_results_control
p = plot_compare(k_results_control, 'type', y='rmse', scale_domain=["Use Control", "No Control"])
save_show_plot(p, "use_control")
p
from functools import partial

t = table_compare(k_results_control, 'type')
table_compare_latex(t, base_path_tbl / "control.tex", label="tbl:control",
                    caption="\\CapControl")
t
Gap in Other variables
models_gap_single = pd.DataFrame.from_records([
    {'Gap': 'All variables', 'gap_single_var': False, 'model': l_model("all_varying_gap_varying_len_6-30_v3.pickle")},
    {'Gap': 'Only one var', 'gap_single_var': True, 'model': l_model("all_varying_gap_varying_len_6-30_v3.pickle")},
])
@cache_disk(cache_dir / "gap_single")
def get_gap_single(n_rep):
= KalmanImpComparison(models_gap_single, hai, hai_era, block_len=130)
kcomp_single
return kcomp_single.compare(n_rep =n_rep, gap_len = [6, 12, 24, 30], var = list(hai.columns))
= get_gap_single(n_rep) res_single
p = plot_compare(res_single, "Gap", y='rmse', scale_domain=["Only one var", "All variables"])
save_show_plot(p, "gap_single_var")

t = table_compare(res_single, 'Gap')
table_compare_latex(t, base_path_tbl / "gap_single_var.tex", caption="\\CapGapSingle", label="tbl:gap_single_var")
t
res_singl_perc = res_single.groupby(['Gap', 'var', 'gap_len']).agg({'rmse_stand': 'mean'}).reset_index().pivot(columns='Gap', values='rmse_stand', index=['var', 'gap_len'])
pd.DataFrame({'Only one var': (res_singl_perc["All variables"] - res_singl_perc["Only one var"]) / res_singl_perc["All variables"] * 100})

res_singl_perc = res_single.groupby(['Gap', 'var']).agg({'rmse_stand': 'mean'}).reset_index().pivot(columns='Gap', values='rmse_stand', index=['var'])
pd.DataFrame({'Only one var': (res_singl_perc["All variables"] - res_singl_perc["Only one var"]) / res_singl_perc["All variables"] * 100})
Generic vs Specialized
models_generic = models_var.copy()
models_generic.model = l_model("1_gap_varying_6-336_v3.pickle")
models_generic['type'] = 'Generic'
models_generic

models_var['type'] = 'Fine-tuned one var'
models_gen_vs_spec = pd.concat([models_generic, models_var])
models_gen_vs_spec
@cache_disk(cache_dir / "generic")
def get_generic(n_rep=n_rep):
reset_seed()
= KalmanImpComparison(models_gen_vs_spec, hai, hai_era, block_len=100+48*7)
comp_generic
return comp_generic.compare(n_rep =n_rep, gap_len = [12, 24, 48, 48*7], var = list(hai.columns))
= get_generic(n_rep) k_results_generic
plot_formatter.legend_symbol_size = 300

p = plot_compare(k_results_generic, 'type', y='rmse', scale_domain=["Fine-tuned one var", "Generic"])
save_show_plot(p, "generic")
p
t = table_compare(k_results_generic, 'type')
table_compare_latex(t, base_path_tbl / "generic.tex", label='tbl:generic', caption="\\CapGeneric")
t

res_singl_perc = k_results_generic.groupby(['type', 'var']).agg({'rmse_stand': 'mean'}).reset_index().pivot(columns='type', values='rmse_stand', index=['var'])
(res_singl_perc["Generic"] - res_singl_perc["Fine-tuned one var"]) / res_singl_perc["Generic"] * 100
Training
models_train = pd.DataFrame.from_records([
    # {'Train': 'All variables', 'model': l_model("All_gap_all_30_v1.pickle")},
    {'Train': 'Only one var', 'model': l_model("1_gap_varying_6-336_v3.pickle")},
    {'Train': 'Multi vars', 'model': l_model("all_varying_gap_varying_len_6-30_v3.pickle")},
    {'Train': 'Random params', 'model': l_model("rand_all_varying_gap_varying_len_6-30_v4.pickle")},
])
models_train
@cache_disk(cache_dir / "train")
def get_train(n_rep):
reset_seed()= KalmanImpComparison(models_train, hai, hai_era, block_len=130)
kcomp
return kcomp.compare(n_rep =n_rep, gap_len = [6, 12, 24, 30], var = list(hai.columns))
= get_train(n_rep) res_train
res_train_agg = res_train.groupby(['Train', 'gap_len']).agg({'rmse_stand': 'mean'}).reset_index()
res_train_agg
p = plot_compare(res_train, "Train", y='rmse', scale_domain=["Multi vars", "Only one var", "Random params"])
save_show_plot(p, "train_compare")

t = table_compare3(res_train, 'Train')
table_compare3_latex(t, base_path_tbl / "train.tex", label="tbl:train_compare", caption="\\CapTrain")
t
Extra results
Standard deviations
hai_std = hai.std().to_frame(name='std')
hai_std.index.name = "Variable"
hai_std = hai_std.reset_index().assign(unit=[f"\\si{{{unit}}}" for unit in units_big.values()])
hai_std

latex = hai_std.style.hide(axis="index").format(precision=3).to_latex(hrules=True, caption="\\CapStd", label="tbl:hai_std", position_float="centering")

with open(base_path_tbl / "hai_std.tex", 'w') as f:
    f.write(latex)
Gap distribution
= here("../fluxnet/gap_stat") out_dir
= pl.read_parquet(out_dir / "../site_info.parquet").select([
site_info "start").cast(pl.Utf8).str.strptime(pl.Datetime, "%Y%m%d%H%M"),
pl.col("end").cast(pl.Utf8).str.strptime(pl.Datetime, "%Y%m%d%H%M"),
pl.col("site").cast(pl.Categorical).sort()
pl.col( ])
def duration_n_obs(duration):
    "converts a duration into a n of fluxnet observations"
    return abs(int(duration.total_seconds() / (30 * 60)))
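Since FLUXNET observations are half-hourly, one hour corresponds to two observations and one day to 48. A quick illustrative check (not in the original notebook):

from datetime import timedelta

assert duration_n_obs(timedelta(hours=1)) == 2
assert duration_n_obs(timedelta(days=1)) == 48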
files = out_dir.ls()
# need to sort to match the site_info
files.sort()
sites = []
for i, path in enumerate(files):
    sites.append(pl.scan_parquet(path).with_columns([
        pl.lit(site_info[i, "site"]).alias("site"),
        pl.lit(duration_n_obs(site_info[i, "start"] - site_info[i, "end"])).alias("total_obs"),
        pl.col("TIMESTAMP_END").cast(pl.Utf8).str.strptime(pl.Datetime, "%Y%m%d%H%M").alias("end"),
    ]).drop("TIMESTAMP_END"))

gap_stat = pl.concat(sites)
pl.read_parquet(files[0])
gap_stat.head().collect()
def plot_var_dist(var, small=False, ax=None):
    if ax is None: ax = get_grid(1)[0]
    ta_gaps = gap_stat.filter(
        (pl.col("variable") == var)
    ).filter(
        pl.col("gap_len") < 200 if small else True
    ).with_column(
        pl.col("gap_len") / (24 * 2 * 7)
    ).collect().to_pandas().hist("gap_len", bins=50, ax=ax)
    ax.set_title(f"{var} - {'gaps < 200' if small else 'all gaps'}")
    if not small: ax.set_yscale('log')
    ax.set_xlabel("gap length (weeks)")
    ax.set_ylabel(f"{'Log' if not small else ''} n gaps")
    # plt.xscale('log')

plot_var_dist('TA_F_QC')
color_map = dict(zip(scale_meteo.domain, list(sns.color_palette('Set2', n_colors=len(hai.columns)).as_hex())))
qc_map = {
    'TA': 'TA_F_QC',
    'SW_IN': 'SW_IN_F_QC',
    'LW_IN': 'LW_IN_F_QC',
    'VPD': 'VPD_F_QC',
    'WS': 'WS_F_QC',
    'PA': 'PA_F_QC',
    'P': 'P_F_QC',
    'TS': 'TS_F_MDS_1_QC',
    'SWC': 'SWC_F_MDS_1_QC',
}
def pl_in(col, values):
    expr = False
    for val in values:
        expr |= pl.col(col) == val
    return expr
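pl_in builds one boolean expression by OR-chaining equality tests. In current polars the built-in is_in should be equivalent (an assumption that depends on the installed polars version):

# hypothetical equivalent using the built-in expression:
expr = pl.col('variable').is_in(list(qc_map.values()))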
(gap_stat.filter(pl_in('variable', qc_map.values()))
 .with_columns([
     pl.when(pl.col("gap_len") < 48*7).then(True).otherwise(False).alias("short"),
     pl.count().alias("total"),                       # total number of gaps
     pl.col("gap_len").sum().alias("total len"),      # total gap length
 ])
 .groupby("short")
 .agg([
     (pl.col("gap_len").count() / pl.col("total")).alias("frac_num"),
     (pl.col("gap_len").sum() / pl.col("total len")).alias("frac_len")
 ]).collect())
frac_miss = gap_stat.filter(
    pl_in('variable', qc_map.values())
).groupby(["site", "variable"]).agg([
    pl.col("gap_len").mean().alias("mean"),
    (pl.col("gap_len").sum() / pl.col("total_obs").first()).alias("frac_gap")
])
frac_miss.groupby('variable').agg([
    pl.col("frac_gap").max().alias("max"),
    pl.col("frac_gap").min().alias("min"),
    pl.col("frac_gap").std().alias("std"),
    pl.col("frac_gap").mean().alias("mean"),
]).collect()
"frac_gap", reverse=True).collect() frac_miss.sort(
filter((pl.col("site") == "US-LWW")) site_info.
filter((pl.col("site") == "US-LWW") & (pl.col("variable") == "LW_IN_F_QC" )).collect() gap_stat.
import matplotlib

matplotlib.rcParams.update({'font.size': 22})
sns.set_style("whitegrid")
def plot_var_dist(var, ax=None, small=True):
    if ax is None: ax = get_grid(1)[0]
    color = color_map[var]
    var_qc = qc_map[var]
    ta_gaps = gap_stat.filter(
        (pl.col("variable") == var_qc)
    ).filter(
        pl.col("gap_len") < (24 * 2 * 7) if small else True
    ).with_column(
        pl.col("gap_len") / (2 if small else 48 * 7)
    ).collect().to_pandas().hist("gap_len", bins=50, ax=ax, edgecolor="white", color=color)
    ax.set_title(f"{var} - {'gap length < 1 week' if small else 'all gaps'}")
    ax.set_xlabel(f"Gap length ({'hour' if small else 'week'})")
    ax.set_ylabel("Log n gaps")
    ax.set_yscale('log')
    plt.tight_layout()
vars = gap_stat.select(pl.col("variable").unique()).collect()
vars.filter(pl.col("variable").str.contains("TA"))
for ax, var in zip(get_grid(9,3,3, figsize=(15,12), sharey=False), list(var_type.categories)):
    plot_var_dist(var, ax=ax)
plt.savefig(base_path_img / "gap_len_dist_small.pdf")

for ax, var in zip(get_grid(9,3,3, figsize=(15,12), sharey=False), list(var_type.categories)):
    plot_var_dist(var, ax=ax, small=False)
plt.savefig(base_path_img / "gap_len_dist.pdf")
Square Root Filter
Numerical Stability
from meteo_imp.kalman.performance import *
= cache_disk(cache_dir / "fuzz_sr")(fuzz_filter_SR)(100, 110) # this is already handling the random seed err
= plot_err_sr_filter(err)
p "numerical_stability") save_show_plot(p,
Performance
@cache_disk(cache_dir / "perf_sr")
def get_perf_sr():
reset_seed()return perf_comb_params('filter', use_sr_filter=[True, False], rep=range(100),
= 100,
n_obs =9,
n_dim_obs=18,
n_dim_state=14,
n_dim_contr=0,
p_missing=20 ,
bs= 'local_slope'
init_method )
perf1 = (get_perf_sr()
    .groupby('use_sr_filter')
    .agg(pl.col("time").mean())
    .with_column(
        pl.when(pl.col("use_sr_filter"))
        .then(pl.lit("Square Root Filter"))
        .otherwise(pl.lit("Standard Filter"))
        .alias("Filter type")
    ))
perf1

# percent difference in mean runtime between the two filter types
(perf1[0, 'time'] - perf1[1, 'time']) / perf1[1, 'time'] * 100
plot_perf_sr = alt.Chart(perf1.to_pandas()).mark_bar(size = 50).encode(
    x=alt.X('Filter type', axis=alt.Axis(labelAngle=0)),
    y=alt.Y('time', scale=alt.Scale(zero=False), title="time [s]"),
    color=alt.Color('Filter type',
                    scale = alt.Scale(scheme = 'accent'))
).properties(width=300)

save_show_plot(plot_perf_sr, "perf_sr")