Skip to content

Improve data serialization #483

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 75 additions & 33 deletions ipydatagrid/datagrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import pandas as pd
from bqplot.traits import array_from_json, array_to_json
from ipywidgets import CallbackDispatcher, DOMWidget, widget_serialization
from traitlets import (
Bool,
Expand Down Expand Up @@ -76,7 +77,6 @@ def _cell_in_rect(cell, rect):


class SelectionHelper:

"""A Helper Class for processing selections. Provides an iterator
to traverse selected cells.
"""
Expand Down Expand Up @@ -164,14 +164,13 @@ def _get_num_rows(self):
return self._num_rows


# modified from ipywidgets original
def _data_to_json(x):
def _data_to_json(x, _):
if isinstance(x, dict):
return {str(k): _data_to_json(v) for k, v in x.items()}
return {str(k): _data_to_json(v, _) for k, v in x.items()}
if isinstance(x, np.ndarray):
return _data_to_json(x.tolist())
return _data_to_json(x.tolist(), _)
if isinstance(x, (list, tuple)):
return [_data_to_json(v) for v in x]
return [_data_to_json(v, _) for v in x]
if isinstance(x, int):
return x
if isinstance(x, float):
Expand All @@ -193,9 +192,55 @@ def _data_to_json(x):
return str(x)


def _data_serialization_impl(data, _):
if not data:
return {}

serialized_data = {}
for column, value in data["data"].items():
arr = value.to_numpy()
if arr.size == 0:
serialized_data[str(column)] = {
"value": [],
"dtype": str(arr.dtype),
"shape": arr.shape,
"type": None,
}
continue
try:
serialized_data[str(column)] = array_to_json(arr)
except ValueError:
# Column is most likely heterogeneous, sending the column raw
serialized_data[str(column)] = {
"value": _data_to_json(arr, _),
"type": "raw",
}

return {
"data": serialized_data,
"schema": data["schema"],
"fields": _data_to_json(data["fields"], _),
}


def _data_deserialization_impl(data, _): # noqa: U101
if not data:
return {}

deserialized_data = {}
for column, value in data["data"].items():
deserialized_data[column] = array_from_json(value.to_numpy())

return {
"data": deserialized_data,
"schema": data["schema"],
"fields": data["fields"],
}


_data_serialization = {
"from_json": widget_serialization["from_json"],
"to_json": lambda x, _: _data_to_json(x), # noqa: U101
"from_json": _data_deserialization_impl,
"to_json": _data_serialization_impl,
}


Expand All @@ -212,7 +257,6 @@ def _widgets_dict_to_json(x, obj):


class DataGrid(DOMWidget):

"""A Grid Widget with filter, sort and selection capabilities.

Attributes
Expand Down Expand Up @@ -360,7 +404,7 @@ class DataGrid(DOMWidget):
).tag(sync=True)
selections = List(Dict()).tag(sync=True)
editable = Bool(False).tag(sync=True)
column_widths = Dict({}).tag(sync=True, **_data_serialization)
column_widths = Dict({}).tag(sync=True, to_json=_data_to_json)
grid_style = Dict(allow_none=True).tag(
sync=True, **_widgets_dict_serialization
)
Expand All @@ -383,17 +427,15 @@ def __init__(self, dataframe, index_name=None, **kwargs):
def __handle_custom_msg(self, _, content, buffers): # noqa: U101,U100
if content["event_type"] == "cell-changed":
row = content["row"]
column = self._column_index_to_name(
self._data, content["column_index"]
)
column = content["column"]
value = content["value"]
# update data on kernel
self._data["data"][row][column] = value
self._data["data"].loc[row, column] = value
# notify python listeners
self._cell_change_handlers(
{
"row": row,
"column": column,
"column": content["column"],
"column_index": content["column_index"],
"value": value,
}
Expand All @@ -414,7 +456,7 @@ def __handle_custom_msg(self, _, content, buffers): # noqa: U101,U100
@property
def data(self):
trimmed_primary_key = self._data["schema"]["primaryKey"][:-1]
if self._data["data"]:
if "data" in self._data:
df = pd.DataFrame(self._data["data"])
else:
df = pd.DataFrame(
Expand Down Expand Up @@ -460,7 +502,7 @@ def generate_data_object(dataframe, guid_key="ipydguuid", index_name="key"):

schema = pd.io.json.build_table_schema(dataframe)
reset_index_dataframe = dataframe.reset_index()
data = reset_index_dataframe.to_dict(orient="records")
data = reset_index_dataframe

# Check for multiple primary keys
key = reset_index_dataframe.columns[: dataframe.index.nlevels].tolist()
Expand Down Expand Up @@ -522,7 +564,7 @@ def get_cell_value(self, column_name, primary_key_value):
if isinstance(column_name, list):
column_name = tuple(column_name)

return [self._data["data"][row][column_name] for row in row_indices]
return [self._data["data"][column_name][row] for row in row_indices]

def set_cell_value(self, column_name, primary_key_value, new_value):
"""Sets the value for a single cell by column name and primary key.
Expand All @@ -541,9 +583,9 @@ def set_cell_value(self, column_name, primary_key_value, new_value):
# Iterate over all indices
outcome = True
for row_index in row_indices:
has_column = column_name in self._data["data"][row_index]
has_column = column_name in self._data["data"]
if has_column and row_index is not None:
self._data["data"][row_index][column_name] = new_value
self._data["data"].loc[row_index, column_name] = new_value
self._notify_cell_change(row_index, column_name, new_value)
else:
outcome = False
Expand All @@ -565,7 +607,9 @@ def set_row_value(self, primary_key_value, new_value):
column_index = 0
column = DataGrid._column_index_to_name(self._data, column_index)
while column is not None:
self._data["data"][row_index][column] = new_value[column_index]
self._data["data"].loc[row_index, column] = new_value[
column_index
]

column_index = column_index + 1
column = DataGrid._column_index_to_name(
Expand All @@ -577,17 +621,17 @@ def set_row_value(self, primary_key_value, new_value):

def get_cell_value_by_index(self, column_name, row_index):
"""Gets the value for a single cell by column name and row index."""
return self._data["data"][row_index][column_name]
return self._data["data"][column_name][row_index]

def set_cell_value_by_index(self, column_name, row_index, new_value):
"""Sets the value for a single cell by column name and row index.

Note: This method returns a boolean to indicate if the operation
was successful.
"""
has_column = column_name in self._data["data"][row_index]
if has_column and 0 <= row_index < len(self._data["data"]):
self._data["data"][row_index][column_name] = new_value
has_column = column_name in self._data["data"]
if has_column and 0 <= row_index < len(self._data["data"][column_name]):
self._data["data"].loc[row_index, column_name] = new_value
self._notify_cell_change(row_index, column_name, new_value)
return True
return False
Expand Down Expand Up @@ -634,7 +678,7 @@ def get_visible_data(self):
"""Returns a dataframe of the current View."""
data = deepcopy(self._data)
if self._visible_rows:
data["data"] = [data["data"][i] for i in self._visible_rows]
data["data"] = data["data"].reindex(self._visible_rows)

at = self._data["schema"]["primaryKey"]
return_df = pd.DataFrame(data["data"]).set_index(at)
Expand Down Expand Up @@ -852,20 +896,18 @@ def _get_row_index_of_primary_key(self, value):
"as the primary key."
)

row_indices = [
at
for at, row in enumerate(self._data["data"])
if all(row[key[j]] == value[j] for j in range(len(key)))
]
return row_indices
df = self._data["data"]
return pd.RangeIndex(len(df))[
(df[key] == value).all(axis="columns")
].to_list()

@staticmethod
def _get_cell_value_by_numerical_index(data, column_index, row_index):
"""Gets the value for a single cell by column index and row index."""
column = DataGrid._column_index_to_name(data, column_index)
if column is None:
return None
return data["data"][row_index][column]
return data["data"].loc[row_index, column]

def _set_renderer_defaults(self):
# Set sensible default values for renderers that are not completely
Expand Down
3 changes: 2 additions & 1 deletion jest.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ module.exports = {
'^.+\\.tsx?$': 'ts-jest',
'^.+\\.js$': 'babel-jest',
},
transformIgnorePatterns: ['node_modules/?!(@jupyter-widgets)'],
transformIgnorePatterns: ['node_modules/?!(@jupyter-widgets)', 'node_modules/bqplot'],
testPathIgnorePatterns: ['ui-tests-ipw7/', 'ui-tests-ipw8/'],
setupFiles: ['./tests/js/setupFile.js'],
testEnvironment: 'jsdom',
moduleNameMapper: {
'\\.(css|less)$': '<rootDir>/__mocks__/styleMock.js',
"raw-loader!.*": "jest-raw-loader",
},
};
Loading