diff --git a/bench/ndarray/matmul.ipynb b/bench/ndarray/matmul.ipynb new file mode 100644 index 000000000..331159786 --- /dev/null +++ b/bench/ndarray/matmul.ipynb @@ -0,0 +1,4941 @@ +{ + "cells": [ + { + "metadata": {}, + "cell_type": "markdown", + "source": "" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-02-19T12:37:40.724835Z", + "start_time": "2025-02-19T12:37:40.720790Z" + } + }, + "cell_type": "code", + "source": [ + "import numpy as np\n", + "import blosc2\n", + "import time\n", + "import plotly.express as px\n", + "import pandas as pd" + ], + "outputs": [], + "execution_count": 8 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-02-19T12:37:43.543794Z", + "start_time": "2025-02-19T12:37:43.539841Z" + } + }, + "cell_type": "code", + "source": [ + "N_tams = [1_000, 2_000, 5_000] #, 10_000]\n", + "sizes_auto = []\n", + "bandwidths_auto = []\n", + "sizes_1000 = []\n", + "bandwidths_1000 = []\n", + "chunksizes = [None, (1_000, 1_000)]\n", + "cparams = blosc2.CParams(codec=blosc2.Codec.LZ4, clevel=1)\n" + ], + "outputs": [], + "execution_count": 9 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-02-19T12:37:57.442611Z", + "start_time": "2025-02-19T12:37:45.378449Z" + } + }, + "cell_type": "code", + "source": [ + "for N in N_tams:\n", + " shape_a = (N, N)\n", + " shape_b = (N, N)\n", + "\n", + " # Generate matrices\n", + " matrix_a_np = np.linspace(0, 10, np.prod(shape_a)).reshape(shape_a)\n", + " matrix_b_np = np.linspace(0, 10, np.prod(shape_b)).reshape(shape_b)\n", + "\n", + " for chunk in chunksizes:\n", + " # Convert NumPy to Blosc2\n", + " matrix_a_blosc2 = blosc2.asarray(matrix_a_np, cparams=cparams, chunks=chunk)\n", + " matrix_b_blosc2 = blosc2.asarray(matrix_b_np, cparams=cparams, chunks=chunk)\n", + "\n", + " # Blosc2 multiplication\n", + " t0 = time.perf_counter()\n", + " result_blosc2 = blosc2.matmul(matrix_a_blosc2, matrix_b_blosc2)\n", + " blosc2_time = time.perf_counter() - t0\n", + "\n", + " # Gather 
the data\n", + " size_mb = (np.prod(shape_a) * 8) / 2**20\n", + " bandwidth = size_mb / blosc2_time\n", + " if chunk is None:\n", + " sizes_auto.append(size_mb)\n", + " bandwidths_auto.append(bandwidth)\n", + " else:\n", + " sizes_1000.append(size_mb)\n", + " bandwidths_1000.append(bandwidth)\n", + "\n", + " print(f\"N={N}, Chunks = {chunk}, Bandwidth = {bandwidth:.2f} MB/s\")" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "N=1000, Chunks = None, Bandwidth = 145.82 MB/s\n", + "N=1000, Chunks = (1000, 1000), Bandwidth = 196.22 MB/s\n", + "N=2000, Chunks = None, Bandwidth = 94.87 MB/s\n", + "N=2000, Chunks = (1000, 1000), Bandwidth = 112.36 MB/s\n", + "N=5000, Chunks = None, Bandwidth = 34.51 MB/s\n", + "N=5000, Chunks = (1000, 1000), Bandwidth = 36.33 MB/s\n" + ] + } + ], + "execution_count": 10 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-02-19T12:37:57.547006Z", + "start_time": "2025-02-19T12:37:57.450620Z" + } + }, + "cell_type": "code", + "source": [ + "df = pd.DataFrame({\n", + " \"Matrix Size (MB)\": sizes_auto + sizes_1000,\n", + " \"Bandwidth (MB/s)\": bandwidths_auto + bandwidths_1000,\n", + " \"Chunk Size\": [\"Auto\" for _ in sizes_auto] + [\"1000x1000\" for _ in sizes_1000]\n", + "})\n", + "\n", + "fig = px.line(df,\n", + " x=\"Matrix Size (MB)\",\n", + " y=\"Bandwidth (MB/s)\",\n", + " color=\"Chunk Size\",\n", + " title=\"Bandwidth of Blosc2 Matrix-Matrix Multiplication\",\n", + " labels={\"value\": \"MB/s\", \"variable\": \"Metric\"})\n", + "\n", + "fig.show()" + ], + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "data": [ + { + "hovertemplate": "Chunk Size=Auto
Matrix Size (MB)=%{x}
Bandwidth (MB/s)=%{y}", + "legendgroup": "Auto", + "line": { + "color": "#636efa", + "dash": "solid" + }, + "marker": { + "symbol": "circle" + }, + "mode": "lines", + "name": "Auto", + "orientation": "v", + "showlegend": true, + "x": { + "dtype": "f8", + "bdata": "AAAAAICEHkAAAAAAgIQ+QAAAAACE12dA" + }, + "xaxis": "x", + "y": { + "dtype": "f8", + "bdata": "PqH1ckA6YkB/qygGxrdXQDGdXyFkQUFA" + }, + "yaxis": "y", + "type": "scatter" + }, + { + "hovertemplate": "Chunk Size=1000x1000
Matrix Size (MB)=%{x}
Bandwidth (MB/s)=%{y}", + "legendgroup": "1000x1000", + "line": { + "color": "#EF553B", + "dash": "solid" + }, + "marker": { + "symbol": "circle" + }, + "mode": "lines", + "name": "1000x1000", + "orientation": "v", + "showlegend": true, + "x": { + "dtype": "f8", + "bdata": "AAAAAICEHkAAAAAAgIQ+QAAAAACE12dA" + }, + "xaxis": "x", + "y": { + "dtype": "f8", + "bdata": "1b3Jh+KGaEC16njDOhdcQMADa9++KkJA" + }, + "yaxis": "y", + "type": "scatter" + } + ], + "layout": { + "template": { + "data": { + "histogram2dcontour": [ + { + "type": "histogram2dcontour", + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0.0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1.0, + "#f0f921" + ] + ] + } + ], + "choropleth": [ + { + "type": "choropleth", + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + } + ], + "histogram2d": [ + { + "type": "histogram2d", + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0.0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1.0, + "#f0f921" + ] + ] + } + ], + "heatmap": [ + { + "type": "heatmap", + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0.0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], 
+ [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1.0, + "#f0f921" + ] + ] + } + ], + "contourcarpet": [ + { + "type": "contourcarpet", + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + } + ], + "contour": [ + { + "type": "contour", + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0.0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1.0, + "#f0f921" + ] + ] + } + ], + "surface": [ + { + "type": "surface", + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0.0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1.0, + "#f0f921" + ] + ] + } + ], + "mesh3d": [ + { + "type": "mesh3d", + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "parcoords": [ + { + "type": "parcoords", + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + } + } + ], + "scatterpolargl": [ + { + "type": "scatterpolargl", + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + } + } + ], + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + 
"color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "scattergeo": [ + { + "type": "scattergeo", + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + } + } + ], + "scatterpolar": [ + { + "type": "scatterpolar", + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + } + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "scattergl": [ + { + "type": "scattergl", + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + } + } + ], + "scatter3d": [ + { + "type": "scatter3d", + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + } + } + ], + "scattermap": [ + { + "type": "scattermap", + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + } + } + ], + "scattermapbox": [ + { + "type": "scattermapbox", + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + } + } + ], + "scatterternary": [ + { + "type": "scatterternary", + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + } + } + ], + "scattercarpet": [ + { + "type": "scattercarpet", + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + } + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ], + 
"barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ] + }, + "layout": { + "autotypenumbers": "strict", + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "hovermode": "closest", + "hoverlabel": { + "align": "left" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "bgcolor": "#E5ECF6", + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "ternary": { + "bgcolor": "#E5ECF6", + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "sequential": [ + [ + 0.0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1.0, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0.0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 
0.8888888888888888, + "#fdca26" + ], + [ + 1.0, + "#f0f921" + ] + ], + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ] + }, + "xaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "automargin": true, + "zerolinewidth": 2 + }, + "yaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "automargin": true, + "zerolinewidth": 2 + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white", + "gridwidth": 2 + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white", + "gridwidth": 2 + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white", + "gridwidth": 2 + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "geo": { + "bgcolor": "white", + "landcolor": "#E5ECF6", + "subunitcolor": "white", + "showland": true, + "showlakes": true, + "lakecolor": "white" + }, + "title": { + "x": 0.05 + }, + "mapbox": { + "style": "light" + } + } + }, + "xaxis": { + "anchor": "y", + "domain": [ + 0.0, + 1.0 + ], + "title": { + "text": "Matrix Size (MB)" + } + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0.0, + 1.0 + ], + "title": { + "text": "Bandwidth (MB/s)" + } + }, + "legend": { + "title": { + "text": "Chunk 
Size" + }, + "tracegroupgap": 0 + }, + "title": { + "text": "Bandwidth of Blosc2 Matrix-Matrix Multiplication" + } + }, + "config": { + "plotlyServerURL": "https://plot.ly" + } + }, + "text/html": [ + "
\n", + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 11 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-02-19T12:37:06.556837Z", + "start_time": "2025-02-19T12:37:06.553836Z" + } + }, + "cell_type": "code", + "source": "", + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/reference/array_operations.rst b/doc/reference/array_operations.rst index 9ebaea351..5df6447d8 100644 --- a/doc/reference/array_operations.rst +++ b/doc/reference/array_operations.rst @@ -6,3 +6,4 @@ Operations with arrays lazy_functions reduction_functions + linear_algebra diff --git a/doc/reference/linear_algebra.rst b/doc/reference/linear_algebra.rst new file mode 100644 index 000000000..fbbda988e --- /dev/null +++ b/doc/reference/linear_algebra.rst @@ -0,0 +1,14 @@ +.. _linear_algebra: + +Linear Algebra +-------------- + +The following functions can be used for computing linear algebra operations with :ref:`NDArray <NDArray>`. + +.. currentmodule:: blosc2 + +.. 
autosummary:: + :toctree: autofiles/operations_with_arrays/ + :nosignatures: + + matmul diff --git a/src/blosc2/__init__.py b/src/blosc2/__init__.py index a28c62888..4fbd82268 100644 --- a/src/blosc2/__init__.py +++ b/src/blosc2/__init__.py @@ -236,6 +236,7 @@ class Tuner(Enum): ones, full, save, + matmul, ) from .c2array import c2context, C2Array, URLPath diff --git a/src/blosc2/ndarray.py b/src/blosc2/ndarray.py index fbfddd59c..98c055377 100644 --- a/src/blosc2/ndarray.py +++ b/src/blosc2/ndarray.py @@ -1845,14 +1845,14 @@ def slice(self, key: int | slice | Sequence[slice], **kwargs: Any) -> NDArray: return ndslice - def squeeze(self) -> None: + def squeeze(self) -> NDArray: """Remove single-dimensional entries from the shape of the array. This method modifies the array in-place, removing any dimensions with size 1. Returns ------- - out: None + out: NDArray Examples -------- @@ -1868,6 +1868,7 @@ def squeeze(self) -> None: (23, 11) """ super().squeeze() + return self def indices(self, order: str | list[str] | None = None, **kwargs: Any) -> NDArray: """ @@ -3643,6 +3644,113 @@ def sort(array: NDArray, order: str | list[str] | None = None, **kwargs: Any) -> return larr.sort(order).compute(**kwargs) +def matmul(x1: NDArray, x2: NDArray, **kwargs: Any) -> NDArray: + """ + Computes the matrix product between two Blosc2 NDArrays. + + Parameters + ---------- + x1: `NDArray` + The first input array. + x2: `NDArray` + The second input array. + kwargs: Any, optional + Keyword arguments that are supported by the :func:`empty` constructor. + + Returns + ------- + out: :ref:`NDArray` + The matrix product of the inputs. This is a scalar only when both x1, + x2 are 1-d vectors. + + Raises + ------ + ValueError + If the last dimension of ``x1`` is not the same size as + the second-to-last dimension of ``x2``. + + If a scalar value is passed in. 
+ + References + ---------- + `numpy.matmul `_ + + Examples + -------- + For 2-D arrays it is the matrix product: + + >>> import numpy as np + >>> import blosc2 + >>> a = np.array([[1, 2], + ... [3, 4]]) + >>> nd_a = blosc2.asarray(a) + >>> b = np.array([[2, 3], + ... [2, 1]]) + >>> nd_b = blosc2.asarray(b) + >>> blosc2.matmul(nd_a, nd_b) + array([[ 6, 5], + [14, 13]]) + + For 2-D mixed with 1-D, the result is the usual. + + >>> a = np.array([[1, 3], + ... [0, 1]]) + >>> nd_a = blosc2.asarray(a) + >>> v = np.array([1, 2]) + >>> nd_v = blosc2.asarray(v) + >>> blosc2.matmul(nd_a, nd_v) + array([7, 2]) + >>> blosc2.matmul(nd_v, nd_a) + array([1, 5]) + + """ + + # Validate arguments are not scalars + if np.isscalar(x1) or np.isscalar(x2): + raise ValueError("Arguments can't be scalars.") + + # Validate arguments are dimension 1 or 2 + if x1.ndim > 2 or x2.ndim > 2: + raise ValueError("Multiplication of arrays with dimension greater than 2 is not supported yet.") + + # Promote 1D arrays to 2D if necessary + x1_is_vector = False + x2_is_vector = False + if x1.ndim == 1: + x1 = x1.reshape((1, x1.shape[0])) # (N,) -> (1, N) + x1_is_vector = True + if x2.ndim == 1: + x2 = x2.reshape((x2.shape[0], 1)) # (M,) -> (M, 1) + x2_is_vector = True + + # Validate matrix multiplication compatibility + if x1.shape[-1] != x2.shape[-2]: + raise ValueError("Shapes are not aligned for matrix multiplication.") + + n, k = x1.shape[-2:] + m = x2.shape[-1] + + p1, q1 = x1.chunks[-2:] + q2 = x2.chunks[-1] + + result = blosc2.zeros((n, m), dtype=np.result_type(x1, x2), **kwargs) + + for row in range(0, n, p1): + row_end = (row + p1) if (row + p1) < n else n + for col in range(0, m, q2): + col_end = (col + q2) if (col + q2) < m else m + for aux in range(0, k, q1): + aux_end = (aux + q1) if (aux + q1) < k else k + bx1 = x1[row:row_end, aux:aux_end] + bx2 = x2[aux:aux_end, col:col_end] + result[row:row_end, col:col_end] += np.matmul(bx1, bx2) + + if x1_is_vector and x2_is_vector: + return 
result[0][0] + + return result.squeeze() + + # Class for dealing with fields in an NDArray # This will allow to access fields by name in the dtype of the NDArray class NDField(Operand): diff --git a/tests/ndarray/test_matmul.py b/tests/ndarray/test_matmul.py new file mode 100644 index 000000000..121be1c9d --- /dev/null +++ b/tests/ndarray/test_matmul.py @@ -0,0 +1,157 @@ +import numpy as np +import pytest + +import blosc2 + + +@pytest.mark.parametrize( + ("ashape", "achunks", "ablocks"), + { + ((12, 10), (7, 5), (3, 3)), + ((10,), (9,), (7,)), + }, +) +@pytest.mark.parametrize( + ("bshape", "bchunks", "bblocks"), + { + ((10,), (4,), (2,)), + ((10, 5), (3, 4), (1, 3)), + ((10, 12), (2, 4), (1, 2)), + }, +) +@pytest.mark.parametrize( + "dtype", + {np.float32, np.float64, np.complex64, np.complex128}, +) +def test_matmul(ashape, achunks, ablocks, bshape, bchunks, bblocks, dtype): + a = blosc2.linspace(0, 10, dtype=dtype, shape=ashape, chunks=achunks, blocks=ablocks) + b = blosc2.linspace(0, 10, dtype=dtype, shape=bshape, chunks=bchunks, blocks=bblocks) + c = blosc2.matmul(a, b) + + na = a[:] + nb = b[:] + nc = np.matmul(na, nb) + + np.testing.assert_allclose(c, nc, rtol=1e-6) + + +@pytest.mark.parametrize( + ("ashape", "achunks", "ablocks"), + { + ((12, 11), (7, 5), (3, 1)), + ((0, 0), (0, 0), (0, 0)), + ((10,), (4,), (2,)), + }, +) +@pytest.mark.parametrize( + ("bshape", "bchunks", "bblocks"), + { + ((1, 5), (1, 4), (1, 3)), + ((4, 6), (2, 4), (1, 3)), + ((5,), (4,), (2,)), + }, +) +def test_shapes(ashape, achunks, ablocks, bshape, bchunks, bblocks): + a = blosc2.linspace(0, 10, shape=ashape, chunks=achunks, blocks=ablocks) + b = blosc2.linspace(0, 10, shape=bshape, chunks=bchunks, blocks=bblocks) + + with pytest.raises(ValueError): + blosc2.matmul(a, b) + + with pytest.raises(ValueError): + blosc2.matmul(b, a) + + +@pytest.mark.parametrize( + "scalar", + { + 5, # int + 5.3, # float + 1 + 2j, # complex + np.int32(5), # NumPy int32 + np.int64(5), # NumPy int64 + 
np.float32(5.3), # NumPy float32 + np.float64(5.3), # NumPy float64 + np.complex64(1 + 2j), # NumPy complex64 + np.complex128(1 + 2j), # NumPy complex128 + }, +) +def test_scalars(scalar): + vector = blosc2.asarray(np.array([1, 2, 3])) + + with pytest.raises(ValueError): + blosc2.matmul(scalar, vector) + + with pytest.raises(ValueError): + blosc2.matmul(vector, scalar) + + with pytest.raises(ValueError): + blosc2.matmul(scalar, scalar) + + +@pytest.mark.parametrize( + "ashape", + [ + (12, 10, 10), + (3, 3, 3), + ], +) +@pytest.mark.parametrize( + "bshape", + [ + (10, 10, 10, 11), + (3, 2), + ], +) +def test_dims(ashape, bshape): + a = blosc2.linspace(0, 10, shape=ashape) + b = blosc2.linspace(0, 1, shape=bshape) + + with pytest.raises(ValueError): + blosc2.matmul(a, b) + + with pytest.raises(ValueError): + blosc2.matmul(b, a) + + +@pytest.mark.parametrize( + ("ashape", "achunks", "ablocks", "adtype"), + { + ((7, 10), (7, 5), (3, 5), np.float32), + ((10,), (9,), (7,), np.complex64), + }, +) +@pytest.mark.parametrize( + ("bshape", "bchunks", "bblocks", "bdtype"), + { + ((10,), (4,), (2,), np.float64), + ((10, 6), (9, 4), (2, 3), np.complex128), + ((10, 12), (2, 4), (1, 2), np.complex128), + }, +) +def test_special_cases(ashape, achunks, ablocks, adtype, bshape, bchunks, bblocks, bdtype): + a = blosc2.linspace(0, 10, dtype=adtype, shape=ashape, chunks=achunks, blocks=ablocks) + b = blosc2.linspace(0, 10, dtype=bdtype, shape=bshape, chunks=bchunks, blocks=bblocks) + c = blosc2.matmul(a, b) + + na = a[:] + nb = b[:] + nc = np.matmul(na, nb) + + np.testing.assert_allclose(c, nc, rtol=1e-6) + + +def test_disk(): + a = blosc2.linspace(0, 1, shape=(3, 4), urlpath="a_test.b2nd", mode="w") + b = blosc2.linspace(0, 1, shape=(4, 2), urlpath="b_test.b2nd", mode="w") + c = blosc2.matmul(a, b, urlpath="c_test.b2nd", mode="w") + + na = a[:] + nb = b[:] + nc = np.matmul(na, nb) + + np.testing.assert_allclose(c, nc, rtol=1e-6) + + blosc2.remove_urlpath("a_test.b2nd") + 
blosc2.remove_urlpath("b_test.b2nd") + blosc2.remove_urlpath("c_test.b2nd")