Skip to content

Commit c58b567

Browse files
[Feature] Add Chart To econometrics.correlation_matrix (#6750)
* add chart to correlation matrix * lint * handling for data.results without chart --------- Co-authored-by: Igor Radovanovic <[email protected]>
1 parent 737e5d4 commit c58b567

File tree

9 files changed

+301
-8
lines changed

9 files changed

+301
-8
lines changed

openbb_platform/extensions/econometrics/integration/test_econometrics_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ def get_data(menu: Literal["equity", "crypto"]):
7171
@pytest.mark.parametrize(
7272
"params, data_type",
7373
[
74-
({"data": ""}, "equity"),
75-
({"data": ""}, "crypto"),
74+
({"data": "", "method": "pearson"}, "equity"),
75+
({"data": "", "method": "pearson"}, "crypto"),
7676
],
7777
)
7878
@pytest.mark.integration

openbb_platform/extensions/econometrics/integration/test_econometrics_python.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ def get_data(menu: Literal["equity", "crypto"]):
6666
@parametrize(
6767
"params, data_type",
6868
[
69-
({"data": ""}, "equity"),
70-
({"data": ""}, "crypto"),
69+
({"data": "", "method": "pearson"}, "equity"),
70+
({"data": "", "method": "pearson"}, "crypto"),
7171
],
7272
)
7373
@pytest.mark.integration

openbb_platform/extensions/econometrics/openbb_econometrics/econometrics_router.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
APIEx(parameters={"data": APIEx.mock_data("timeseries")}),
2626
],
2727
)
28-
def correlation_matrix(data: List[Data]) -> OBBject[List[Data]]:
28+
def correlation_matrix(
29+
data: List[Data], method: Literal["pearson", "kendall", "spearman"] = "pearson"
30+
) -> OBBject[List[Data]]:
2931
"""Get the correlation matrix of an input dataset.
3032
3133
The correlation matrix provides a view of how different variables in your dataset relate to one another.
@@ -37,6 +39,11 @@ def correlation_matrix(data: List[Data]) -> OBBject[List[Data]]:
3739
----------
3840
data : List[Data]
3941
Input dataset.
42+
method : Literal["pearson", "kendall", "spearman"]
43+
Method to use for correlation calculation. Default is "pearson".
44+
pearson : standard correlation coefficient
45+
kendall : Kendall Tau correlation coefficient
46+
spearman : Spearman rank correlation
4047
4148
Returns
4249
-------
@@ -49,9 +56,14 @@ def correlation_matrix(data: List[Data]) -> OBBject[List[Data]]:
4956

5057
df = basemodel_to_df(data)
5158
# remove non float columns from the dataframe to perform the correlation
52-
df = df.select_dtypes(include=["float64"])
5359

54-
corr = df.corr()
60+
if "symbol" in df.columns and len(df.symbol.unique()) > 1 and "close" in df.columns:
61+
df = df.pivot(
62+
columns="symbol",
63+
values="close",
64+
)
65+
66+
corr = df.corr(method=method, numeric_only=True)
5567

5668
# replace nan values with None to allow for json serialization
5769
corr = corr.replace(np.NaN, None)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""Views for the Econometrics Extension."""
2+
3+
from typing import TYPE_CHECKING, Any
4+
5+
if TYPE_CHECKING:
6+
from openbb_charting.core.openbb_figure import (
7+
OpenBBFigure,
8+
)
9+
10+
11+
class EconometricsViews:
12+
"""Econometrics Views."""
13+
14+
@staticmethod
15+
def econometrics_correlation_matrix( # noqa: PLR0912
16+
**kwargs,
17+
) -> tuple["OpenBBFigure", dict[str, Any]]:
18+
"""Correlation Matrix Chart.
19+
20+
Parameters
21+
----------
22+
data : Union[list[Data], DataFrame]
23+
Input dataset.
24+
method : Literal["pearson", "kendall", "spearman"]
25+
Method to use for correlation calculation. Default is "pearson".
26+
pearson : standard correlation coefficient
27+
kendall : Kendall Tau correlation coefficient
28+
spearman : Spearman rank correlation
29+
colorscale : str
30+
Plotly colorscale to use for the heatmap. Default is "RdBu".
31+
title : str
32+
Title of the chart. Default is "Asset Correlation Matrix".
33+
layout_kwargs : Dict[str, Any]
34+
Additional keyword arguments to apply with figure.update_layout(), by default None.
35+
"""
36+
# pylint: disable=import-outside-toplevel
37+
from openbb_charting.charts.correlation_matrix import correlation_matrix
38+
39+
return correlation_matrix(**kwargs)

openbb_platform/extensions/econometrics/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,6 @@ build-backend = "poetry.core.masonry.api"
2121

2222
[tool.poetry.plugins."openbb_core_extension"]
2323
econometrics = "openbb_econometrics.econometrics_router:router"
24+
25+
[tool.poetry.plugins."openbb_charting_extension"]
26+
econometrics = "openbb_econometrics.econometrics_views:EconometricsViews"

openbb_platform/obbject_extensions/charting/integration/test_charting_api.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,3 +897,44 @@ def test_charting_economy_survey_bls_series(params, headers):
897897
assert chart
898898
assert not fig
899899
assert list(chart.keys()) == ["content", "format"]
900+
901+
902+
@parametrize(
903+
"params",
904+
[
905+
(
906+
{
907+
"data": "",
908+
"method": "pearson",
909+
"chart": True,
910+
}
911+
)
912+
],
913+
)
914+
@pytest.mark.integration
915+
def test_charting_econometrics_correlation_matrix(params, headers):
916+
"""Test chart econometrics correlation matrix."""
917+
# pylint:disable=import-outside-toplevel
918+
from pandas import DataFrame
919+
920+
url = "http://0.0.0.0:8000/api/v1/equity/price/historical?symbol=AAPL,MSFT,GOOG&provider=yfinance"
921+
result = requests.get(url, headers=headers, timeout=10)
922+
df = DataFrame(result.json()["results"])
923+
df = df.pivot(index="date", columns="symbol", values="close").reset_index()
924+
body = df.to_dict(orient="records")
925+
926+
params = {p: v for p, v in params.items() if v}
927+
928+
query_str = get_querystring(params, [])
929+
url = f"http://0.0.0.0:8000/api/v1/econometrics/correlation_matrix?{query_str}"
930+
result = requests.post(url, headers=headers, timeout=10, data=json.dumps(body))
931+
932+
assert isinstance(result, requests.Response)
933+
assert result.status_code == 200
934+
935+
chart = result.json()["chart"]
936+
fig = chart.pop("fig", {})
937+
938+
assert chart
939+
assert not fig
940+
assert list(chart.keys()) == ["content", "format"]

openbb_platform/obbject_extensions/charting/integration/test_charting_python.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,3 +728,34 @@ def test_charting_economy_survey_bls_series(params, obb):
728728
assert len(result.results) > 0
729729
assert result.chart.content
730730
assert isinstance(result.chart.fig, OpenBBFigure)
731+
732+
733+
@parametrize(
734+
"params",
735+
[
736+
(
737+
{
738+
"data": "",
739+
"method": "pearson",
740+
"chart": True,
741+
}
742+
)
743+
],
744+
)
745+
@pytest.mark.integration
746+
def test_charting_econometrics_correlation_matrix(params, obb):
747+
"""Test chart econometrics correlation matrix."""
748+
749+
symbols = "XRT,XLB,XLI,XLH,XLC,XLY,XLU,XLK".split(",")
750+
params["data"] = (
751+
obb.equity.price.historical(symbol=symbols, provider="yfinance")
752+
.to_df()
753+
.pivot(columns="symbol", values="close")
754+
.filter(items=symbols, axis=1)
755+
)
756+
result = obb.econometrics.correlation_matrix(**params)
757+
assert result
758+
assert isinstance(result, OBBject)
759+
assert len(result.results) > 0
760+
assert result.chart.content
761+
assert isinstance(result.chart.fig, OpenBBFigure)

openbb_platform/obbject_extensions/charting/openbb_charting/charting.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Charting Class implementation."""
22

3-
# pylint: disable=too-many-arguments,unused-argument
3+
# pylint: disable=too-many-arguments,unused-argument,too-many-positional-arguments
44

55
from typing import (
66
TYPE_CHECKING,
@@ -57,6 +57,8 @@ class Charting:
5757
Create a line chart from external data.
5858
create_bar_chart
5959
Create a bar chart, on a single x-axis with one or more values for the y-axis, from external data.
60+
create_correlation_matrix
61+
Create a correlation matrix from external data.
6062
toggle_chart_style
6163
Toggle the chart style, of an existing chart, between light and dark mode.
6264
"""
@@ -367,6 +369,54 @@ def create_bar_chart(
367369

368370
return fig
369371

372+
def create_correlation_matrix(
373+
self,
374+
data: Union[
375+
list[Data],
376+
"DataFrame",
377+
],
378+
method: Literal["pearson", "kendall", "spearman"] = "pearson",
379+
colorscale: str = "RdBu",
380+
title: str = "Asset Correlation Matrix",
381+
layout_kwargs: Optional[Dict[str, Any]] = None,
382+
):
383+
"""Create a correlation matrix from external data.
384+
385+
Parameters
386+
----------
387+
data : Union[list[Data], DataFrame]
388+
Input dataset.
389+
method : Literal["pearson", "kendall", "spearman"]
390+
Method to use for correlation calculation. Default is "pearson".
391+
pearson : standard correlation coefficient
392+
kendall : Kendall Tau correlation coefficient
393+
spearman : Spearman rank correlation
394+
colorscale : str
395+
Plotly colorscale to use for the heatmap. Default is "RdBu".
396+
title : str
397+
Title of the chart. Default is "Asset Correlation Matrix".
398+
layout_kwargs : Dict[str, Any]
399+
Additional keyword arguments to apply with figure.update_layout(), by default None.
400+
401+
Returns
402+
-------
403+
OpenBBFigure
404+
The OpenBBFigure object.
405+
"""
406+
# pylint: disable=import-outside-toplevel
407+
from openbb_charting.charts.correlation_matrix import correlation_matrix
408+
409+
kwargs = {
410+
"data": data,
411+
"method": method,
412+
"colorscale": colorscale,
413+
"title": title,
414+
"layout_kwargs": layout_kwargs,
415+
}
416+
fig, _ = correlation_matrix(**kwargs)
417+
fig = self._set_chart_style(fig)
418+
return fig
419+
370420
# pylint: disable=inconsistent-return-statements
371421
def show(self, render: bool = True, **kwargs):
372422
"""Display chart and save it to the OBBject."""
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""Correlation Matrix Chart."""
2+
3+
from typing import TYPE_CHECKING, Any, Union
4+
5+
if TYPE_CHECKING:
6+
from plotly.graph_objs import Figure # noqa
7+
from openbb_charting.core.openbb_figure import OpenBBFigure # noqa
8+
9+
10+
def correlation_matrix( # noqa: PLR0912
11+
**kwargs,
12+
) -> tuple[Union["OpenBBFigure", "Figure"], dict[str, Any]]:
13+
"""Correlation Matrix Chart."""
14+
# pylint: disable=import-outside-toplevel
15+
from numpy import ones_like, triu # noqa
16+
from openbb_core.app.utils import basemodel_to_df # noqa
17+
from openbb_charting.core.openbb_figure import OpenBBFigure
18+
from openbb_charting.core.chart_style import ChartStyle
19+
from plotly.graph_objs import Figure, Heatmap, Layout
20+
from pandas import DataFrame
21+
22+
if "data" in kwargs and isinstance(kwargs["data"], DataFrame):
23+
corr = kwargs["data"]
24+
elif "data" in kwargs and isinstance(kwargs["data"], list):
25+
corr = basemodel_to_df(kwargs["data"], index=kwargs.get("index", "date")) # type: ignore
26+
else:
27+
corr = basemodel_to_df(
28+
kwargs["obbject_item"], index=kwargs.get("index", "date") # type: ignore
29+
)
30+
if (
31+
"symbol" in corr.columns
32+
and len(corr.symbol.unique()) > 1
33+
and "close" in corr.columns
34+
):
35+
corr = corr.pivot(
36+
columns="symbol",
37+
values="close",
38+
)
39+
40+
method = kwargs.get("method") or "pearson"
41+
corr = corr.corr(method=method, numeric_only=True)
42+
43+
X = corr.columns.to_list()
44+
x_replace = X[-1]
45+
Y = X.copy()
46+
y_replace = Y[0]
47+
X = [x if x != x_replace else "" for x in X]
48+
Y = [y if y != y_replace else "" for y in Y]
49+
mask = triu(ones_like(corr, dtype=bool))
50+
df = corr.mask(mask)
51+
title = kwargs.get("title") or "Asset Correlation Matrix"
52+
text_color = "white" if ChartStyle().plt_style == "dark" else "black"
53+
colorscale = kwargs.get("colorscale") or "RdBu"
54+
heatmap = Heatmap(
55+
z=df,
56+
x=X,
57+
y=Y,
58+
xgap=1,
59+
ygap=1,
60+
colorscale=colorscale,
61+
colorbar=dict(
62+
orientation="v",
63+
x=0.9,
64+
y=0.45,
65+
xanchor="left",
66+
yanchor="middle",
67+
len=0.75,
68+
bgcolor="rgba(0,0,0,0)" if text_color == "white" else "rgba(255,255,255,0)",
69+
),
70+
text=df.fillna(""),
71+
texttemplate="%{text:.4f}",
72+
hoverinfo="skip",
73+
)
74+
layout = Layout(
75+
title=title,
76+
title_x=0.5,
77+
title_y=0.95,
78+
xaxis=dict(
79+
showgrid=False,
80+
showline=False,
81+
ticklen=0,
82+
tickfont=dict(size=16),
83+
ticklabelstandoff=10,
84+
domain=[0.05, 1],
85+
),
86+
yaxis=dict(
87+
showgrid=False,
88+
side="left",
89+
autorange="reversed",
90+
showline=False,
91+
ticklen=0,
92+
tickfont=dict(size=16),
93+
ticklabelstandoff=15,
94+
domain=[0.05, 1],
95+
),
96+
margin=dict(r=20, t=0, b=50),
97+
dragmode=False,
98+
)
99+
fig = Figure(data=[heatmap], layout=layout)
100+
figure = OpenBBFigure(fig=fig)
101+
figure.update_layout(
102+
font=dict(color=text_color),
103+
paper_bgcolor=(
104+
"rgba(0,0,0,0)" if text_color == "white" else "rgba(255,255,255,0)"
105+
),
106+
plot_bgcolor=(
107+
"rgba(0,0,0,0)" if text_color == "white" else "rgba(255,255,255,0)"
108+
),
109+
)
110+
layout_kwargs = kwargs.get("layout_kwargs", {})
111+
112+
if layout_kwargs:
113+
figure.update_layout(**layout_kwargs)
114+
115+
content = figure.show(external=True).to_plotly_json() # type: ignore
116+
117+
return figure, content

0 commit comments

Comments
 (0)