|
18 | 18 |
|
19 | 19 | from unittest import mock
|
20 | 20 |
|
| 21 | +import pandas as pd |
21 | 22 | import pytest
|
22 | 23 |
|
23 | 24 | from neo4j import (
|
|
30 | 31 | Version,
|
31 | 32 | )
|
32 | 33 | from neo4j._async_compat.util import AsyncUtil
|
33 |
| -from neo4j.data import DataHydrator |
| 34 | +from neo4j.data import ( |
| 35 | + DataHydrator, |
| 36 | + Node, |
| 37 | +) |
34 | 38 | from neo4j.exceptions import ResultNotSingleError
|
| 39 | +from neo4j.packstream import Structure |
35 | 40 |
|
36 | 41 | from ...._async_compat import mark_async_test
|
37 | 42 |
|
38 | 43 |
|
39 | 44 | class Records:
|
40 | 45 | def __init__(self, fields, records):
|
41 |
| - assert all(len(fields) == len(r) for r in records) |
42 |
| - self.fields = fields |
43 |
| - # self.records = [{"record_values": r} for r in records] |
44 |
| - self.records = records |
| 46 | + self.fields = tuple(fields) |
| 47 | + self.records = tuple(records) |
| 48 | + assert all(len(self.fields) == len(r) for r in self.records) |
45 | 49 |
|
46 | 50 | def __len__(self):
|
47 | 51 | return self.records.__len__()
|
@@ -469,3 +473,46 @@ async def test_data(num_records):
|
469 | 473 | assert await result.data("hello", "world") == expected_data
|
470 | 474 | for record in records:
|
471 | 475 | assert record.data.called_once_with("hello", "world")
|
| 476 | + |
| 477 | + |
| 478 | +@pytest.mark.parametrize( |
| 479 | + ("keys", "values", "types", "instances"), |
| 480 | + ( |
| 481 | + (["i"], zip(range(5)), ["int64"], None), |
| 482 | + (["x"], zip((n - .5) / 5 for n in range(5)), ["float64"], None), |
| 483 | + (["s"], zip(("foo", "bar", "baz", "foobar")), ["object"], None), |
| 484 | + (["l"], zip(([1, 2], [3, 4])), ["object"], None), |
| 485 | + ( |
| 486 | + ["n"], |
| 487 | + zip(( |
| 488 | + Structure(b"N", 0, ["LABEL_A"], {"a": 1, "b": 2}), |
| 489 | + Structure(b"N", 2, ["LABEL_B"], {"a": 1, "c": 1.2}), |
| 490 | + Structure(b"N", 1, ["LABEL_A", "LABEL_B"], {"a": [1, "a"]}), |
| 491 | + )), |
| 492 | + ["object"], |
| 493 | + [Node] |
| 494 | + ), |
| 495 | + ) |
| 496 | +) |
| 497 | +@mark_async_test |
| 498 | +async def test_to_df(keys, values, types, instances): |
| 499 | + values = list(values) |
| 500 | + connection = AsyncConnectionStub(records=Records(keys, values)) |
| 501 | + result = AsyncResult(connection, DataHydrator(), 1, noop, noop) |
| 502 | + await result._run("CYPHER", {}, None, None, "r", None) |
| 503 | + df = await result.to_df() |
| 504 | + |
| 505 | + assert isinstance(df, pd.DataFrame) |
| 506 | + assert df.keys().to_list() == keys |
| 507 | + assert len(df) == len(values) |
| 508 | + assert df.dtypes.to_list() == types |
| 509 | + |
| 510 | + expected_df = pd.DataFrame( |
| 511 | + {k: [v[i] for v in values] for i, k in enumerate(keys)} |
| 512 | + ) |
| 513 | + |
| 514 | + if instances: |
| 515 | + for i, k in enumerate(keys): |
| 516 | + assert all(isinstance(v, instances[i]) for v in df[k]) |
| 517 | + else: |
| 518 | + assert df.equals(expected_df) |
0 commit comments