|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import sqlite3 |
| 4 | +from datetime import datetime, timedelta, timezone |
| 5 | + |
| 6 | +import pandas as pd |
| 7 | +import plotly.express as px |
| 8 | +import scoredata |
| 9 | +from plotly_streaming import render_plotly_streaming |
| 10 | +from shinywidgets import output_widget |
| 11 | + |
| 12 | +import shiny.experimental as x |
| 13 | +from shiny import App, Inputs, Outputs, Session, reactive, render, ui |
| 14 | + |
# Score thresholds and the colours of the reference lines drawn at each of them.
# A score above THRESHOLD_MID is shown as "success", above THRESHOLD_LOW as
# "warning", and anything else as "danger".
THRESHOLD_LOW = 0.5
THRESHOLD_LOW_COLOR = "rgb(193, 0, 0)"
THRESHOLD_MID = 0.85
THRESHOLD_MID_COLOR = "rgb(0, 137, 26)"

# Start a background thread that writes fake data to the SQLite database every
# second. This runs before the connection below is opened.
scoredata.begin()

# Module-level connection used by every query in this file.
con = sqlite3.connect(scoredata.SQLITE_DB_URI, uri=True)
| 24 | + |
| 25 | + |
def last_modified(con):
    """
    Return the largest `timestamp` value stored in the `accuracy_scores` table.

    This is a single aggregate query, intended to be cheap enough to poll with:
    SQLite offers no push notification when the database changes, so callers
    compare successive results of this function to detect new rows. The result is
    `None` when the table has no rows.
    """
    cursor = con.execute("select max(timestamp) from accuracy_scores")
    (latest,) = cursor.fetchone()
    return latest
| 33 | + |
| 34 | + |
@reactive.poll(lambda: last_modified(con))
def df():
    """
    Return the 150 most recent rows of `accuracy_scores`, oldest first.

    @reactive.poll calls a cheap query (`last_modified()`) every 1 second to check
    if this expensive query should be re-run and downstream calculations updated.

    Because this reactive object is declared at the top level of the script rather
    than inside the server function, all sessions share the same object, so the
    expensive query is only run once no matter how many users are connected.
    """
    newest_first = pd.read_sql(
        "select * from accuracy_scores order by timestamp desc, model desc limit ?",
        con,
        params=[150],
    )

    # SQLite has no native datetime type, so parse the column once and reuse the
    # parsed values for the short "HH:MM:SS" label as well.
    parsed = pd.to_datetime(newest_first["timestamp"], utc=True)
    newest_first["timestamp"] = parsed
    newest_first["time"] = parsed.dt.strftime("%H:%M:%S")

    # The query returns newest rows first; flip them into chronological order.
    return newest_first.iloc[::-1]
| 59 | + |
| 60 | + |
def read_time_period(from_time, to_time):
    """
    Return every `accuracy_scores` row whose timestamp lies between `from_time`
    and `to_time` (SQL `between`, so both ends are inclusive), ordered by
    timestamp and then model.

    The `timestamp` column is parsed into timezone-aware (UTC) datetimes and a
    `time` column holding an "HH:MM:SS" label is added.
    """
    query = (
        "select * from accuracy_scores where timestamp between ? and ? order by timestamp, model"
    )
    rows = pd.read_sql(query, con, params=[from_time, to_time])

    # Parse once, then derive the display label from the parsed values.
    stamps = pd.to_datetime(rows["timestamp"], utc=True)
    rows["timestamp"] = stamps
    rows["time"] = stamps.dt.strftime("%H:%M:%S")

    return rows
| 72 | + |
| 73 | + |
# Models offered in the sidebar checkbox group.
model_names = ["model_1", "model_2", "model_3", "model_4"]

# One fixed colour per model, taken in order from Plotly's D3 palette. zip() stops
# at the shorter input, so only as many colours as there are models are used.
model_colors = dict(zip(model_names, px.colors.qualitative.D3))
| 79 | + |
| 80 | + |
def app_ui(req):
    """
    Build the page layout: a sidebar of controls next to the dashboard content.

    `req` is accepted but not used here. The time-range slider's initial bounds
    are computed on each call as the one-minute window ending at the current UTC
    time.
    """
    window_end = datetime.now(timezone.utc)
    window_start = window_end - timedelta(minutes=1)

    # Refresh-interval choices, keyed by the number of seconds between refreshes.
    refresh_choices = {
        0: "Realtime",
        5: "5 seconds",
        15: "15 seconds",
        30: "30 seconds",
        60 * 5: "5 minutes",
        60 * 15: "15 minutes",
    }

    model_picker = ui.input_checkbox_group(
        "models", "Models", model_names, selected=model_names
    )
    timeframe_picker = ui.input_radio_buttons(
        "timeframe",
        "Timeframe",
        ["Latest", "Specific timeframe"],
        selected="Latest",
    )
    # Only shown while the "Latest" timeframe is selected.
    refresh_panel = ui.panel_conditional(
        "input.timeframe === 'Latest'",
        ui.input_selectize("refresh", "Refresh interval", refresh_choices),
    )
    # Only shown while a specific timeframe is selected.
    range_panel = ui.panel_conditional(
        "input.timeframe !== 'Latest'",
        ui.input_slider(
            "timerange",
            "Time range",
            min=window_start,
            max=window_end,
            value=[window_start, window_end],
            step=timedelta(seconds=1),
            time_format="%H:%M:%S",
        ),
    )
    sidebar = x.ui.sidebar(
        model_picker,
        timeframe_picker,
        refresh_panel,
        range_panel,
    )

    dashboard = ui.div(
        ui.h1("Model monitoring dashboard"),
        ui.p(
            x.ui.output_ui("value_boxes"),
        ),
        x.ui.card(output_widget("plot_timeseries")),
        x.ui.card(output_widget("plot_dist")),
        style="max-width: 800px;",
    )

    return x.ui.page_sidebar(sidebar, dashboard, fillable=False)
| 135 | + |
| 136 | + |
def server(input: Inputs, output: Outputs, session: Session):
    """
    Per-session server logic: derive the data frame to display from the user's
    inputs, render the value boxes and the two plots from it, and keep the
    time-range slider's bounds in step with the database.
    """

    def theme_for(score):
        """Map a score to the value-box theme colour for its threshold band."""
        if score > THRESHOLD_MID:
            return "success"
        if score > THRESHOLD_LOW:
            return "warning"
        return "danger"

    @reactive.Calc
    def recent_df():
        """
        Returns the most recent rows from the database, at the refresh interval
        requested by the user. If the refresh interval is 0, go at maximum speed.
        """
        refresh = int(input.refresh())
        if refresh == 0:
            return df()
        else:
            # This approach works well if you know that input.refresh() is likely to be
            # a longer interval than the underlying changing data source (df()). If not,
            # then this can cause downstream reactives to be invalidated when they
            # didn't need to be.
            reactive.invalidate_later(refresh)
            # Read df() without taking a reactive dependency on it, so that only the
            # timer above (and input.refresh()) triggers a re-run.
            with reactive.isolate():
                return df()

    @reactive.Calc
    def timeframe_df():
        """
        Returns rows from the database within the specified time range. Notice that we
        implement the business logic as a separate function (read_time_period), so it's
        easier to reason about and test.
        """
        start, end = input.timerange()
        return read_time_period(start, end)

    @reactive.Calc
    def filtered_df():
        """
        Return the data frame that should be displayed in the app, based on the user's
        input. This will be either the latest rows, or a specific time range. Also
        filter out rows for models that the user has deselected.
        """
        data = recent_df() if input.timeframe() == "Latest" else timeframe_df()

        # Filter the rows so we only include the desired models
        return data[data["model"].isin(input.models())]

    @reactive.Calc
    def filtered_model_names():
        """
        Return the distinct model names present in filtered_df(). Used as the
        recreate_key of both plots below.
        """
        return filtered_df()["model"].unique()

    @output
    @render.ui
    def value_boxes():
        """
        Render one value box per model in filtered_df(), showing that model's
        last score in the frame (rounded to 2 decimal places) and themed by
        which threshold band the score falls in.
        """
        data = filtered_df()
        models = data["model"].unique().tolist()
        # NOTE: the loop variable is deliberately not called `x`, which is the
        # module alias for shiny.experimental used just below.
        scores_by_model = {
            name: round(data[data["model"] == name].iloc[-1]["score"], 2)
            for name in models
        }

        return x.ui.layout_column_wrap(
            "135px",
            *[
                x.ui.value_box(
                    model,
                    ui.h2(score),
                    theme_color=theme_for(score),
                )
                for model, score in scores_by_model.items()
            ],
            fixed_width=True,
        )

    @output
    @render_plotly_streaming(recreate_key=filtered_model_names, update="data")
    def plot_timeseries():
        """
        Returns a Plotly Figure visualization. Streams new data to the Plotly widget in
        the browser whenever filtered_df() updates, and completely recreates the figure
        when filtered_model_names() changes (see recreate_key=... above).
        """
        fig = px.line(
            filtered_df(),
            x="time",
            y="score",
            labels=dict(score="accuracy"),
            color="model",
            color_discrete_map=model_colors,
            # The default for render_mode is "auto", which switches between
            # type="scatter" and type="scattergl" depending on the number of data
            # points. Switching that value breaks streaming updates, as the type
            # property is read-only. Setting it to "webgl" keeps the type consistent.
            render_mode="webgl",
            template="simple_white",
        )

        # Dashed horizontal reference lines at the two score thresholds.
        fig.add_hline(
            THRESHOLD_LOW,
            line_dash="dash",
            line=dict(color=THRESHOLD_LOW_COLOR, width=2),
            opacity=0.3,
        )
        fig.add_hline(
            THRESHOLD_MID,
            line_dash="dash",
            line=dict(color=THRESHOLD_MID_COLOR, width=2),
            opacity=0.3,
        )

        fig.update_yaxes(range=[0, 1], fixedrange=True)
        fig.update_xaxes(fixedrange=True, tickangle=60)

        return fig

    @output
    @render_plotly_streaming(recreate_key=filtered_model_names, update="data")
    def plot_dist():
        """
        Returns a Plotly histogram of the scores in filtered_df(), one facet row
        per model, using the same recreate_key as plot_timeseries().
        """
        fig = px.histogram(
            filtered_df(),
            facet_row="model",
            nbins=20,
            x="score",
            labels=dict(score="accuracy"),
            color="model",
            color_discrete_map=model_colors,
            template="simple_white",
        )

        # Dashed vertical reference lines at the two score thresholds.
        fig.add_vline(
            THRESHOLD_LOW,
            line_dash="dash",
            line=dict(color=THRESHOLD_LOW_COLOR, width=2),
            opacity=0.3,
        )
        fig.add_vline(
            THRESHOLD_MID,
            line_dash="dash",
            line=dict(color=THRESHOLD_MID_COLOR, width=2),
            opacity=0.3,
        )

        # From https://plotly.com/python/facet-plots/#customizing-subplot-figure-titles
        # Keep only the text after the last "=" in each facet annotation.
        fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))

        fig.update_yaxes(matches=None)
        fig.update_xaxes(range=[0, 1], fixedrange=True)
        fig.layout.height = 500

        return fig

    @reactive.Effect
    def update_time_range():
        """
        Every 15 seconds, update the custom time range slider's min and max values to
        reflect the current min and max values in the database.
        """

        reactive.invalidate_later(15)
        oldest, newest = con.execute(
            "select min(timestamp), max(timestamp) from accuracy_scores"
        ).fetchone()
        if oldest is None or newest is None:
            # The aggregates are NULL when the table has no rows yet; leave the
            # slider as it is and try again on the next tick.
            return

        min_time, max_time = pd.to_datetime([oldest, newest], utc=True)
        ui.update_slider(
            "timerange",
            min=min_time.replace(tzinfo=timezone.utc),
            max=max_time.replace(tzinfo=timezone.utc),
        )
| 308 | + |
| 309 | + |
# The application object: `app_ui` builds the page, `server` holds the
# per-session logic.
app = App(ui=app_ui, server=server)
0 commit comments