Skip to content

Commit bd33da8

Browse files
authored
Replace ARRAY_CONSTRUCT with Array (#19)
* f * gitignore * f * vbump + lint
1 parent d78ecac commit bd33da8

7 files changed

Lines changed: 2695 additions & 766 deletions

File tree

.gitignore

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,12 @@
11
.vscode
2+
3+
# Pycharm
4+
.idea
5+
.idea/
6+
*.DS_Store
7+
.python-version
8+
9+
# Byte-compiled / optimized / DLL files
10+
__pycache__/
11+
*.py[cod]
12+
*$py.class

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Yellowbox Snowglobe Changelog
2-
## NEXT
2+
## 0.2.6
3+
### Added
4+
* Added support for ARRAY_CONSTRUCT for PostgreSQL
35
### Internal
46
* removed cahce action from cicd
57
* removed autofix from lint script

poetry.lock

Lines changed: 2594 additions & 759 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "yellowbox-snowglobe"
3-
version = "0.2.5"
3+
version = "0.2.6"
44
description = ""
55
authors = ["Biocatch LTD <serverteam@biocatch.com>"]
66
license = "MIT"
@@ -9,7 +9,7 @@ license = "MIT"
99
python = "^3.8"
1010
yellowbox = { version = ">=0.7.0", extras = ["postgresql", "webserver"] }
1111

12-
[tool.poetry.dev-dependencies]
12+
[tool.poetry.group.dev.dependencies]
1313
pytest = "*"
1414
pytest-asyncio = "*"
1515
pytest-cov = "*"

tests/blackbox/test_simple.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_connect(docker_client):
1818
assert results == [(2, "two"), (3, "three")]
1919

2020

21-
@mark.asyncio()
21+
@mark.asyncio
2222
async def test_connect_async(docker_client):
2323
async with SnowGlobeService.arun(docker_client) as service:
2424
with connector.connect(**service.local_connection_kwargs()) as conn:
@@ -32,7 +32,7 @@ async def test_connect_async(docker_client):
3232
assert results == [(2, "two"), (3, "three")]
3333

3434

35-
@mark.asyncio()
35+
@mark.asyncio
3636
async def test_simultaneous_connections(docker_client):
3737
async with SnowGlobeService.arun(docker_client) as service:
3838
with connector.connect(**service.local_connection_kwargs()) as conn:

tests/blackbox/test_syntax.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_lateral_flatten(connection):
1919
)
2020
res = (
2121
connection.cursor()
22-
.execute("select x, root.value from bar, lateral flatten(d) as root" " where root.value*root.value = x;")
22+
.execute("select x, root.value from bar, lateral flatten(d) as root where root.value*root.value = x;")
2323
.fetchall()
2424
)
2525
assert res == [(4, 2), (9, 3)]
@@ -56,7 +56,7 @@ def test_bools(connection):
5656
def test_null_bools_and_dates(connection):
5757
connection.cursor().execute("create table bar (x timestamp, y boolean)")
5858
connection.cursor().execute(
59-
"insert into bar values " "('2014-01-01 16:00:00', true)," " (null, false), " "('2023-01-08 17:00:00', null)"
59+
"insert into bar values ('2014-01-01 16:00:00', true), (null, false), ('2023-01-08 17:00:00', null)"
6060
)
6161
res = connection.cursor().execute("select * from bar;").fetchall()
6262
assert res == [(datetime(2014, 1, 1, 16, 0), True), (None, False), (datetime(2023, 1, 8, 17, 0), None)]
@@ -82,3 +82,31 @@ def test_json(connection, db, query, expected):
8282
connection.cursor().execute("""insert into bar values (1, '{"a":"1", "b":"1"}'), (2, '{"a":"2", "b":"2"}')""")
8383
res = connection.cursor().execute(query).fetchall()
8484
assert res == expected
85+
86+
87+
@mark.parametrize(
88+
("query", "expected"),
89+
[
90+
(
91+
"select ARRAY_CONSTRUCT(NULL, x) from bar;",
92+
[
93+
([None, "hello"],),
94+
],
95+
),
96+
(
97+
"select ARRAY_CONSTRUCT(ARRAY_CONSTRUCT(NULL, x)) from bar;",
98+
[
99+
(
100+
[
101+
[None, "hello"],
102+
],
103+
),
104+
],
105+
),
106+
],
107+
)
108+
def test_array(connection, db, query, expected):
109+
connection.cursor().execute("create table bar (x text)")
110+
connection.cursor().execute("""insert into bar (x) values ('hello')""")
111+
res = connection.cursor().execute(query).fetchall()
112+
assert res == expected

yellowbox_snowglobe/snow_to_post.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,57 @@ def split_sql_to_statements(query: str) -> Iterator[str]:
5757
yield last_bit
5858

5959

60+
def find_matching_paren(text: str, start_pos: int) -> int:
61+
"""
62+
Find the position of the closing parenthesis ')' corresponding to the opening parenthesis '(' at start_pos.
63+
Handle nested parentheses by counting open/closed parentheses.
64+
"""
65+
if start_pos >= len(text) or text[start_pos] != "(":
66+
return -1
67+
depth = 1
68+
pos = start_pos + 1
69+
while pos < len(text) and depth > 0:
70+
if text[pos] == "(":
71+
depth += 1
72+
elif text[pos] == ")":
73+
depth -= 1
74+
pos += 1
75+
return pos - 1 if depth == 0 else -1
76+
77+
78+
def replace_array_construct(text: str) -> str:
79+
"""
80+
Replaces all occurrences of ARRAY_CONSTRUCT(...) with Array[...] while correctly handling nested parentheses.
81+
"""
82+
result = []
83+
i = 0
84+
while i < len(text):
85+
# Search for "ARRAY_CONSTRUCT(" (case-insensitive)
86+
array_match = re.search(r"(?i)\bARRAY_CONSTRUCT\(", text[i:])
87+
if not array_match:
88+
result.append(text[i:])
89+
break
90+
91+
array_start = i + array_match.start()
92+
paren_start = i + array_match.end() - 1 # Position of the '('
93+
94+
# Find the corresponding closing parenthesis
95+
paren_end = find_matching_paren(text, paren_start)
96+
if paren_end == -1:
97+
# Parenthesis not closed, leave as is
98+
result.append(text[i:])
99+
break
100+
101+
# Add text before ARRAY_CONSTRUCT(
102+
result.append(text[i:array_start])
103+
# Add Array[ with the content between parentheses
104+
content = text[paren_start + 1 : paren_end]
105+
result.append(f"Array[{content}]")
106+
i = paren_end + 1
107+
108+
return "".join(result)
109+
110+
60111
"""
61112
A Rule is replacement rule that converts a snowflake-dialect query to a postgresql query.
62113
for example there's a rule that will turn "a..b" into "a.public.b"
@@ -146,6 +197,8 @@ class Rule:
146197
def repl_part(part: Union[str, TextLiteral], rules: Iterable[Rule]) -> str:
147198
if isinstance(part, TextLiteral):
148199
return part.value
200+
# Replace ARRAY_CONSTRUCT() with Array[] before applying the other rules
201+
part = replace_array_construct(part)
149202
ret_parts = []
150203
while part:
151204
best_match = None

0 commit comments

Comments
 (0)