Skip to content

Commit ca122cf

Browse files
authored
Model score example (#650)
1 parent c0cfe9d commit ca122cf

File tree

9 files changed

+18652
-2
lines changed

9 files changed

+18652
-2
lines changed

examples/model-score/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
data

examples/model-score/app.py

Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
from __future__ import annotations
2+
3+
import sqlite3
4+
from datetime import datetime, timedelta, timezone
5+
6+
import pandas as pd
7+
import plotly.express as px
8+
import scoredata
9+
from plotly_streaming import render_plotly_streaming
10+
from shinywidgets import output_widget
11+
12+
import shiny.experimental as x
13+
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
14+
15+
THRESHOLD_MID = 0.85
16+
THRESHOLD_MID_COLOR = "rgb(0, 137, 26)"
17+
THRESHOLD_LOW = 0.5
18+
THRESHOLD_LOW_COLOR = "rgb(193, 0, 0)"
19+
20+
# Start a background thread that writes fake data to the SQLite database every second
21+
scoredata.begin()
22+
23+
con = sqlite3.connect(scoredata.SQLITE_DB_URI, uri=True)
24+
25+
26+
def last_modified(con):
27+
"""
28+
Fast-executing call to get the timestamp of the most recent row in the database.
29+
We will poll against this in absence of a way to receive a push notification when
30+
our SQLite database changes.
31+
"""
32+
return con.execute("select max(timestamp) from accuracy_scores").fetchone()[0]
33+
34+
35+
@reactive.poll(lambda: last_modified(con))
36+
def df():
37+
"""
38+
@reactive.poll calls a cheap query (`last_modified()`) every 1 second to check if
39+
the expensive query (`df()`) should be run and downstream calculations should be
40+
updated.
41+
42+
By declaring this reactive object at the top-level of the script instead of in the
43+
server function, all sessions are sharing the same object, so the expensive query is
44+
only run once no matter how many users are connected.
45+
"""
46+
tbl = pd.read_sql(
47+
"select * from accuracy_scores order by timestamp desc, model desc limit ?",
48+
con,
49+
params=[150],
50+
)
51+
# Convert timestamp to datetime object, which SQLite doesn't support natively
52+
tbl["timestamp"] = pd.to_datetime(tbl["timestamp"], utc=True)
53+
# Create a short label for readability
54+
tbl["time"] = tbl["timestamp"].dt.strftime("%H:%M:%S")
55+
# Reverse order of rows
56+
tbl = tbl.iloc[::-1]
57+
58+
return tbl
59+
60+
61+
def read_time_period(from_time, to_time):
62+
tbl = pd.read_sql(
63+
"select * from accuracy_scores where timestamp between ? and ? order by timestamp, model",
64+
con,
65+
params=[from_time, to_time],
66+
)
67+
# Treat timestamp as a continuous variable
68+
tbl["timestamp"] = pd.to_datetime(tbl["timestamp"], utc=True)
69+
tbl["time"] = tbl["timestamp"].dt.strftime("%H:%M:%S")
70+
71+
return tbl
72+
73+
74+
model_names = ["model_1", "model_2", "model_3", "model_4"]
75+
model_colors = {
76+
name: color
77+
for name, color in zip(model_names, px.colors.qualitative.D3[0 : len(model_names)])
78+
}
79+
80+
81+
def app_ui(req):
82+
end_time = datetime.now(timezone.utc)
83+
start_time = end_time - timedelta(minutes=1)
84+
85+
return x.ui.page_sidebar(
86+
x.ui.sidebar(
87+
ui.input_checkbox_group(
88+
"models", "Models", model_names, selected=model_names
89+
),
90+
ui.input_radio_buttons(
91+
"timeframe",
92+
"Timeframe",
93+
["Latest", "Specific timeframe"],
94+
selected="Latest",
95+
),
96+
ui.panel_conditional(
97+
"input.timeframe === 'Latest'",
98+
ui.input_selectize(
99+
"refresh",
100+
"Refresh interval",
101+
{
102+
0: "Realtime",
103+
5: "5 seconds",
104+
15: "15 seconds",
105+
30: "30 seconds",
106+
60 * 5: "5 minutes",
107+
60 * 15: "15 minutes",
108+
},
109+
),
110+
),
111+
ui.panel_conditional(
112+
"input.timeframe !== 'Latest'",
113+
ui.input_slider(
114+
"timerange",
115+
"Time range",
116+
min=start_time,
117+
max=end_time,
118+
value=[start_time, end_time],
119+
step=timedelta(seconds=1),
120+
time_format="%H:%M:%S",
121+
),
122+
),
123+
),
124+
ui.div(
125+
ui.h1("Model monitoring dashboard"),
126+
ui.p(
127+
x.ui.output_ui("value_boxes"),
128+
),
129+
x.ui.card(output_widget("plot_timeseries")),
130+
x.ui.card(output_widget("plot_dist")),
131+
style="max-width: 800px;",
132+
),
133+
fillable=False,
134+
)
135+
136+
137+
def server(input: Inputs, output: Outputs, session: Session):
138+
@reactive.Calc
139+
def recent_df():
140+
"""
141+
Returns the most recent rows from the database, at the refresh interval
142+
requested by the user. If the refresh interview is 0, go at maximum speed.
143+
"""
144+
refresh = int(input.refresh())
145+
if refresh == 0:
146+
return df()
147+
else:
148+
# This approach works well if you know that input.refresh() is likely to be
149+
# a longer interval than the underlying changing data source (df()). If not,
150+
# then this can cause downstream reactives to be invalidated when they
151+
# didn't need to be.
152+
reactive.invalidate_later(refresh)
153+
with reactive.isolate():
154+
return df()
155+
156+
@reactive.Calc
157+
def timeframe_df():
158+
"""
159+
Returns rows from the database within the specified time range. Notice that we
160+
implement the business logic as a separate function (read_time_period), so it's
161+
easier to reason about and test.
162+
"""
163+
start, end = input.timerange()
164+
return read_time_period(start, end)
165+
166+
@reactive.Calc
167+
def filtered_df():
168+
"""
169+
Return the data frame that should be displayed in the app, based on the user's
170+
input. This will be either the latest rows, or a specific time range. Also
171+
filter out rows for models that the user has deselected.
172+
"""
173+
data = recent_df() if input.timeframe() == "Latest" else timeframe_df()
174+
175+
# Filter the rows so we only include the desired models
176+
return data[data["model"].isin(input.models())]
177+
178+
@reactive.Calc
179+
def filtered_model_names():
180+
return filtered_df()["model"].unique()
181+
182+
@output
183+
@render.ui
184+
def value_boxes():
185+
data = filtered_df()
186+
models = data["model"].unique().tolist()
187+
scores_by_model = {
188+
x: data[data["model"] == x].iloc[-1]["score"] for x in models
189+
}
190+
# Round scores to 2 decimal places
191+
scores_by_model = {x: round(y, 2) for x, y in scores_by_model.items()}
192+
193+
return x.ui.layout_column_wrap(
194+
"135px",
195+
*[
196+
# For each model, return a value_box with the score, colored based on
197+
# how high the score is.
198+
x.ui.value_box(
199+
model,
200+
ui.h2(score),
201+
theme_color="success"
202+
if score > THRESHOLD_MID
203+
else "warning"
204+
if score > THRESHOLD_LOW
205+
else "danger",
206+
)
207+
for model, score in scores_by_model.items()
208+
],
209+
fixed_width=True,
210+
)
211+
212+
@output
213+
@render_plotly_streaming(recreate_key=filtered_model_names, update="data")
214+
def plot_timeseries():
215+
"""
216+
Returns a Plotly Figure visualization. Streams new data to the Plotly widget in
217+
the browser whenever filtered_df() updates, and completely recreates the figure
218+
when filtered_model_names() changes (see recreate_key=... above).
219+
"""
220+
fig = px.line(
221+
filtered_df(),
222+
x="time",
223+
y="score",
224+
labels=dict(score="accuracy"),
225+
color="model",
226+
color_discrete_map=model_colors,
227+
# The default for render_mode is "auto", which switches between
228+
# type="scatter" and type="scattergl" depending on the number of data
229+
# points. Switching that value breaks streaming updates, as the type
230+
# property is read-only. Setting it to "webgl" keeps the type consistent.
231+
render_mode="webgl",
232+
template="simple_white",
233+
)
234+
235+
fig.add_hline(
236+
THRESHOLD_LOW,
237+
line_dash="dash",
238+
line=dict(color=THRESHOLD_LOW_COLOR, width=2),
239+
opacity=0.3,
240+
)
241+
fig.add_hline(
242+
THRESHOLD_MID,
243+
line_dash="dash",
244+
line=dict(color=THRESHOLD_MID_COLOR, width=2),
245+
opacity=0.3,
246+
)
247+
248+
fig.update_yaxes(range=[0, 1], fixedrange=True)
249+
fig.update_xaxes(fixedrange=True, tickangle=60)
250+
251+
return fig
252+
253+
@output
254+
@render_plotly_streaming(recreate_key=filtered_model_names, update="data")
255+
def plot_dist():
256+
fig = px.histogram(
257+
filtered_df(),
258+
facet_row="model",
259+
nbins=20,
260+
x="score",
261+
labels=dict(score="accuracy"),
262+
color="model",
263+
color_discrete_map=model_colors,
264+
template="simple_white",
265+
)
266+
267+
fig.add_vline(
268+
THRESHOLD_LOW,
269+
line_dash="dash",
270+
line=dict(color=THRESHOLD_LOW_COLOR, width=2),
271+
opacity=0.3,
272+
)
273+
fig.add_vline(
274+
THRESHOLD_MID,
275+
line_dash="dash",
276+
line=dict(color=THRESHOLD_MID_COLOR, width=2),
277+
opacity=0.3,
278+
)
279+
280+
# From https://plotly.com/python/facet-plots/#customizing-subplot-figure-titles
281+
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
282+
283+
fig.update_yaxes(matches=None)
284+
fig.update_xaxes(range=[0, 1], fixedrange=True)
285+
fig.layout.height = 500
286+
287+
return fig
288+
289+
@reactive.Effect
290+
def update_time_range():
291+
"""
292+
Every 5 seconds, update the custom time range slider's min and max values to
293+
reflect the current min and max values in the database.
294+
"""
295+
296+
reactive.invalidate_later(15)
297+
min_time, max_time = pd.to_datetime(
298+
con.execute(
299+
"select min(timestamp), max(timestamp) from accuracy_scores"
300+
).fetchone(),
301+
utc=True,
302+
)
303+
ui.update_slider(
304+
"timerange",
305+
min=min_time.replace(tzinfo=timezone.utc),
306+
max=max_time.replace(tzinfo=timezone.utc),
307+
)
308+
309+
310+
app = App(app_ui, server)

0 commit comments

Comments
 (0)