diff --git a/geoalchemy2/admin/dialects/mysql.py b/geoalchemy2/admin/dialects/mysql.py index 82a5391b..7513353d 100644 --- a/geoalchemy2/admin/dialects/mysql.py +++ b/geoalchemy2/admin/dialects/mysql.py @@ -157,9 +157,7 @@ def after_drop(table, bind, **kw): return -_MYSQL_FUNCTIONS = { - "ST_AsEWKB": "ST_AsBinary", -} +_MYSQL_FUNCTIONS = {"ST_AsEWKB": "ST_AsBinary", "ST_SetSRID": "ST_SRID"} def _compiles_mysql(cls, fn): diff --git a/tests/schema_fixtures.py b/tests/schema_fixtures.py index b868868d..8410ca7e 100644 --- a/tests/schema_fixtures.py +++ b/tests/schema_fixtures.py @@ -11,6 +11,7 @@ from geoalchemy2 import Geography from geoalchemy2 import Geometry from geoalchemy2 import Raster +from geoalchemy2.elements import WKTElement @pytest.fixture @@ -126,7 +127,22 @@ def column_expression(self, col): ) def bind_expression(self, bindvalue): - return func.ST_Transform(self.impl.bind_expression(bindvalue), self.db_srid) + return func.ST_Transform(func.ST_GeomFromText(bindvalue, self.app_srid), self.db_srid) + + def bind_processor(self, dialect): + """Specific bind_processor that automatically process spatial elements. + + Here we only use WKT representations. + """ + + def process(bindvalue): + bindvalue = WKTElement(bindvalue) + bindvalue = bindvalue.as_wkt() + if bindvalue.srid <= 0: + bindvalue.srid = self.srid + return bindvalue.desc + + return process @pytest.fixture diff --git a/tests/test_functional.py b/tests/test_functional.py index 0ac14aa2..c3d7193e 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -775,19 +775,18 @@ def test_WKBElement(self, session, Lake, setup_tables, dialect_name): assert srid == 4326 @test_only_with_dialects("postgresql", "mysql", "sqlite-spatialite3", "sqlite-spatialite4") - def test_transform(self, session, LocalPoint, setup_tables): - if session.bind.dialect.name == "mysql": - # Explicitly skip MySQL dialect to show that there is an issue - pytest.skip( - reason=( - "The SRID is not properly retrieved so an exception is raised. TODO: This " - "should be fixed later" - ) - ) + def test_transform(self, session, LocalPoint, setup_tables, dialect_name): # Create new point instance p = LocalPoint() - p.geom = "SRID=4326;POINT(5 45)" # Insert geometry with wrong SRID - p.managed_geom = "SRID=4326;POINT(5 45)" # Insert geometry with wrong SRID + if dialect_name in ["mysql", "mariadb"]: + expected_x = 45 + expected_y = 5 + else: + expected_x = 5 + expected_y = 45 + ewkt = f"SRID=4326;POINT({expected_x} {expected_y})" + p.geom = ewkt # Insert geometry with wrong SRID + p.managed_geom = ewkt # Insert geometry with wrong SRID # Insert point session.add(p) @@ -800,11 +799,11 @@ def test_transform(self, session, LocalPoint, setup_tables): assert pt.geom.srid == 4326 assert pt.managed_geom.srid == 4326 pt_wkb = to_shape(pt.geom) - assert round(pt_wkb.x, 5) == 5 - assert round(pt_wkb.y, 5) == 45 + assert round(pt_wkb.x, 5) == expected_x + assert round(pt_wkb.y, 5) == expected_y pt_wkb = to_shape(pt.managed_geom) - assert round(pt_wkb.x, 5) == 5 - assert round(pt_wkb.y, 5) == 45 + assert round(pt_wkb.x, 5) == expected_x + assert round(pt_wkb.y, 5) == expected_y # Check that the data is correct in DB using raw query q = text( @@ -818,7 +817,10 @@ def test_transform(self, session, LocalPoint, setup_tables): for i in [res_q.geom, res_q.managed_geom]: x, y = re.match(r"POINT\((\d+\.\d*) (\d+\.\d*)\)", i).groups() assert round(float(x), 3) == 857581.899 - assert round(float(y), 3) == 6435414.748 + if dialect_name in ["mysql", "mariadb"]: + assert round(float(y), 3) == 6434180.796 + else: + assert round(float(y), 3) == 6435414.748 class TestUpdateORM: