Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

[to#811]Fix special characters in PG url and Mysql connection reconnect #812

Merged
merged 8 commits into from
Dec 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion data_diff/databases/mysql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, ClassVar, Dict, Type
from typing import Any, ClassVar, Dict, Type, Union

import attrs

Expand All @@ -20,6 +20,7 @@
import_helper,
ConnectError,
BaseDialect,
ThreadLocalInterpreter,
)
from data_diff.databases.base import (
MD5_HEXDIGITS,
Expand Down Expand Up @@ -148,3 +149,11 @@ def create_connection(self):
elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR:
raise ConnectError("Database does not exist") from e
raise ConnectError(*e.args) from e

def _query_in_worker(self, sql_code: Union[str, ThreadLocalInterpreter]):
"This method runs in a worker thread"
if self._init_error:
raise self._init_error
if not self.thread_local.conn.is_connected():
self.thread_local.conn.ping(reconnect=True, attempts=3, delay=5)
return self._query_conn(self.thread_local.conn, sql_code)
3 changes: 2 additions & 1 deletion data_diff/databases/postgresql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Any, ClassVar, Dict, List, Type

from urllib.parse import unquote
import attrs

from data_diff.abcs.database_types import (
Expand Down Expand Up @@ -168,6 +168,7 @@ def create_connection(self):

pg = import_postgresql()
try:
self._args["password"] = unquote(self._args["password"])
self._conn = pg.connect(
**self._args, keepalives=1, keepalives_idle=5, keepalives_interval=2, keepalives_count=2
)
Expand Down
42 changes: 41 additions & 1 deletion tests/test_postgresql.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import unittest

from urllib.parse import quote
from data_diff.queries.api import table, commit
from data_diff import TableSegment, HashDiffer
from data_diff import databases as db
from tests.common import get_conn, random_table_suffix
from tests.common import get_conn, random_table_suffix, connect
from data_diff import connect_to_table


class TestUUID(unittest.TestCase):
Expand Down Expand Up @@ -113,3 +115,41 @@ def test_100_fields(self):
id_ = diff[0][1][0]
result = (id_,) + tuple("1" for x in range(100))
self.assertEqual(diff, [("-", result)])


class TestSpecialCharacterPassword(unittest.TestCase):
def setUp(self) -> None:
self.connection = get_conn(db.PostgreSQL)

table_suffix = random_table_suffix()

self.table_name = f"table{table_suffix}"
self.table = table(self.table_name)

def test_special_char_password(self):
password = "passw!!!@rd"
# Setup user with special character '@' in password
self.connection.query("DROP USER IF EXISTS test;", None)
self.connection.query(f"CREATE USER test WITH PASSWORD '{password}';", None)

password_quoted = quote(password)
db_config = {
"driver": "postgresql",
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "test",
"password": password_quoted,
}

# verify pythonic connection method
connect_to_table(
db_config,
self.table_name,
)

# verify connection method with URL string unquoted after it's verified
db_url = f"postgresql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['dbname']}"

connection_verified = connect(db_url)
assert connection_verified._args.get("password") == password