diff --git a/connectorx-python/connectorx/__init__.py b/connectorx-python/connectorx/__init__.py
index 64e82d66e2..b58590e80d 100644
--- a/connectorx-python/connectorx/__init__.py
+++ b/connectorx-python/connectorx/__init__.py
@@ -206,6 +206,8 @@ def read_sql(
     """
     if isinstance(query, list) and len(query) == 1:
         query = query[0]
+        query = remove_ending_semicolon(query)
+
 
     if isinstance(conn, dict):
         assert partition_on is None and isinstance(
@@ -214,6 +216,9 @@ def read_sql(
         assert (
             protocol is None
         ), "Federated query does not support specifying protocol for now"
+
+        query = remove_ending_semicolon(query)
+
         result = _read_sql2(query, conn)
         df = reconstruct_arrow(result)
         if return_type == "pandas":
@@ -232,6 +237,9 @@ def read_sql(
         return df
 
     if isinstance(query, str):
+
+        query = remove_ending_semicolon(query)
+
         if partition_on is None:
             queries = [query]
             partition_query = None
@@ -245,7 +253,7 @@ def read_sql(
             }
             queries = None
     elif isinstance(query, list):
-        queries = query
+        queries = [remove_ending_semicolon(subquery) for subquery in query]
         partition_query = None
 
         if partition_on is not None:
@@ -377,3 +385,16 @@ def reconstruct_pandas(df_infos: Dict[str, Any]):
     )
     df = pd.DataFrame(block_manager)
     return df
+
+
+def remove_ending_semicolon(query: str) -> str:
+    """Return ``query`` with one trailing semicolon removed, if present.
+
+    Only the final character is examined, so a query that ends with
+    whitespace after the semicolon is returned unchanged. An empty string
+    is returned as-is instead of raising ``IndexError``.
+    """
+    # str.endswith is False for "", unlike query[-1] which raises.
+    if query.endswith(";"):
+        return query[:-1]
+    return query
diff --git a/connectorx-python/connectorx/tests/test_postgres.py b/connectorx-python/connectorx/tests/test_postgres.py
index 4f636fb020..620c5333b5 100644
--- a/connectorx-python/connectorx/tests/test_postgres.py
+++ b/connectorx-python/connectorx/tests/test_postgres.py
@@ -1138,4 +1138,42 @@ def test_postgres_name_type(postgres_url: str) -> None:
             "test_name": pd.Series(["0", "21", "someName", "101203203-1212323-22131235"]),
         },
     )
+    assert_frame_equal(df, expected, check_names=True)
+
+
+
+def test_postgres_semicolon_support_str_query(postgres_url: str) -> None:
+    query = "SELECT test_name FROM test_types;"
+    df = read_sql(postgres_url, query)
+    expected = pd.DataFrame(
+        data={
+            "test_name": pd.Series(["0", "21", "someName", "101203203-1212323-22131235"]),
+        },
+    )
+    assert_frame_equal(df, expected, check_names=True)
+
+
+def test_postgres_semicolon_list_queries(postgres_url: str) -> None:
+    queries = [
+        "SELECT * FROM test_table WHERE test_int < 2;",
+        "SELECT * FROM test_table WHERE test_int >= 2;",
+    ]
+
+    df = read_sql(postgres_url, query=queries)
+
+    expected = pd.DataFrame(
+        index=range(6),
+        data={
+            "test_int": pd.Series([0, 1, 2, 3, 4, 1314], dtype="Int64"),
+            "test_nullint": pd.Series([5, 3, None, 7, 9, 2], dtype="Int64"),
+            "test_str": pd.Series(
+                ["a", "str1", "str2", "b", "c", None], dtype="object"
+            ),
+            "test_float": pd.Series([3.1, None, 2.2, 3, 7.8, -10], dtype="float64"),
+            "test_bool": pd.Series(
+                [None, True, False, False, None, True], dtype="boolean"
+            ),
+        },
+    )
+    df.sort_values(by="test_int", inplace=True, ignore_index=True)
     assert_frame_equal(df, expected, check_names=True)
\ No newline at end of file