Skip to content

Commit ed00059

Browse files
committed
Implement GEO commands
1 parent 6d729f6 commit ed00059

File tree

3 files changed

+175
-82
lines changed

3 files changed

+175
-82
lines changed

fakeredis/_msgs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,4 @@
6363
FLAG_NO_SCRIPT = 's' # Command not allowed in scripts
6464
FLAG_LEAVE_EMPTY_VAL = 'v'
6565
FLAG_TRANSACTION = 't'
66+
GEO_UNSUPPORTED_UNIT = 'unsupported unit provided. please use M, KM, FT, MI'

fakeredis/commands_mixins/geo_mixin.py

Lines changed: 109 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,94 @@
11
import sys
22
from collections import namedtuple
3-
from typing import List, Optional, Any
3+
from typing import List, Any
44

55
from fakeredis import _msgs as msgs
66
from fakeredis._command_args_parsing import extract_args
7-
from fakeredis._commands import command, Key, Float
7+
from fakeredis._commands import command, Key, Float, CommandItem
88
from fakeredis._helpers import SimpleError
99
from fakeredis._zset import ZSet
1010
from fakeredis.geo import geohash
1111
from fakeredis.geo.haversine import distance
1212

13+
UNIT_TO_M = {'km': 0.001, 'mi': 0.000621371, 'ft': 3.28084, 'm': 1}
14+
1315

1416
def translate_meters_to_unit(unit_arg: bytes) -> float:
15-
unit_str = unit_arg.decode().lower()
16-
if unit_str == 'km':
17-
unit = 0.001
18-
elif unit_str == 'mi':
19-
unit = 0.000621371
20-
elif unit_str == 'ft':
21-
unit = 3.28084
22-
else: # meter
23-
unit = 1
17+
"""number of meters in a unit.
18+
:param unit_arg: unit name (km, mi, ft, m)
19+
:returns: number of meters in unit
20+
"""
21+
unit = UNIT_TO_M.get(unit_arg.decode().lower())
22+
if unit is None:
23+
raise SimpleError(msgs.GEO_UNSUPPORTED_UNIT)
2424
return unit
2525

2626

2727
GeoResult = namedtuple('GeoResult', 'name long lat hash distance')
2828

2929

30+
def _parse_results(
31+
items: List[GeoResult],
32+
withcoord: bool, withdist: bool) -> List[Any]:
33+
"""Parse list of GeoResults to redis response
34+
:param withcoord: include coordinates in response
35+
:param withdist: include distance in response
36+
:returns: Parsed list
37+
"""
38+
res = list()
39+
for item in items:
40+
new_item = [item.name, ]
41+
if withdist:
42+
new_item.append(Float.encode(item.distance, False))
43+
if withcoord:
44+
new_item.append([Float.encode(item.long, False),
45+
Float.encode(item.lat, False)])
46+
if len(new_item) == 1:
47+
new_item = new_item[0]
48+
res.append(new_item)
49+
return res
50+
51+
52+
def _find_near(
53+
zset: ZSet,
54+
lat: float, long: float, radius: float,
55+
conv: float, count: int, count_any: bool, desc: bool) -> List[GeoResult]:
56+
"""Find items within area (lat,long)+radius
57+
:param zset: list of items to check
58+
:param lat: latitude
59+
:param long: longitude
60+
:param radius: radius in whatever units
61+
:param conv: conversion of radius to meters
62+
:param count: number of results to give
63+
:param count_any: should we return any results that match? (vs. sorted)
64+
:param desc: should results be sorted descending order?
65+
:returns: List of GeoResults
66+
"""
67+
results = list()
68+
for name, _hash in zset.items():
69+
p_lat, p_long, _, _ = geohash.decode(_hash)
70+
dist = distance((p_lat, p_long), (lat, long)) * conv
71+
if dist < radius:
72+
results.append(GeoResult(name, p_long, p_lat, _hash, dist))
73+
if count_any and len(results) >= count:
74+
break
75+
results = sorted(results, key=lambda x: x.distance, reverse=desc)
76+
if count:
77+
results = results[:count]
78+
return results
79+
80+
3081
class GeoCommandsMixin:
3182
# TODO
32-
# GEORADIUS, GEORADIUS_RO,
33-
# GEORADIUSBYMEMBER, GEORADIUSBYMEMBER_RO,
3483
# GEOSEARCH, GEOSEARCHSTORE
84+
def _store_geo_results(self, item_name: bytes, geo_results: List[GeoResult], scoredist: bool) -> int:
85+
db_item = CommandItem(item_name, self._db, item=self._db.get(item_name), default=ZSet())
86+
db_item.value = ZSet()
87+
for item in geo_results:
88+
val = item.distance if scoredist else item.hash
89+
db_item.value.add(item.name, val)
90+
db_item.writeback()
91+
return len(geo_results)
3592

3693
@command(name='GEOADD', fixed=(Key(ZSet),), repeat=(bytes,))
3794
def geoadd(self, key, *args):
@@ -83,42 +140,51 @@ def geodist(self, key, m1, m2, *args):
83140
unit = translate_meters_to_unit(args[0]) if len(args) == 1 else 1
84141
return res * unit
85142

86-
def _parse_results(
87-
self, items: List[GeoResult],
88-
withcoord: bool, withdist: bool, withhash: bool,
89-
count: Optional[int], desc: bool) -> List[Any]:
90-
items = sorted(items, key=lambda x: x.distance, reverse=desc)
91-
if count:
92-
items = items[:count]
93-
res = list()
94-
for item in items:
95-
new_item = [item.name, ]
96-
if withdist:
97-
new_item.append(self._encodefloat(item.distance, False))
98-
if withcoord:
99-
new_item.append([self._encodefloat(item.long, False),
100-
self._encodefloat(item.lat, False)])
101-
if len(new_item) == 1:
102-
new_item = new_item[0]
103-
res.append(new_item)
104-
return res
143+
def _search(
144+
self, key, long, lat, radius, conv,
145+
withcoord, withdist, withhash, count, count_any, desc, store, storedist):
146+
zset = key.value
147+
geo_results = _find_near(zset, lat, long, radius, conv, count, count_any, desc)
148+
149+
if store:
150+
self._store_geo_results(store, geo_results, scoredist=False)
151+
return len(geo_results)
152+
if storedist:
153+
self._store_geo_results(storedist, geo_results, scoredist=True)
154+
return len(geo_results)
155+
ret = _parse_results(geo_results, withcoord, withdist)
156+
return ret
157+
158+
@command(name='GEORADIUS_RO', fixed=(Key(ZSet), Float, Float, Float), repeat=(bytes,))
159+
def georadius_ro(self, key, long, lat, radius, *args):
160+
(withcoord, withdist, withhash, count, count_any, desc), left_args = extract_args(
161+
args, ('withcoord', 'withdist', 'withhash', '+count', 'any', 'desc',),
162+
error_on_unexpected=False, left_from_first_unexpected=False)
163+
count = count or sys.maxsize
164+
conv = translate_meters_to_unit(args[0]) if len(args) >= 1 else 1
165+
return self._search(
166+
key, long, lat, radius, conv,
167+
withcoord, withdist, withhash, count, count_any, desc, False, False)
105168

106169
@command(name='GEORADIUS', fixed=(Key(ZSet), Float, Float, Float), repeat=(bytes,))
107170
def georadius(self, key, long, lat, radius, *args):
108-
zset = key.value
109-
results = list()
110171
(withcoord, withdist, withhash, count, count_any, desc, store, storedist), left_args = extract_args(
111172
args, ('withcoord', 'withdist', 'withhash', '+count', 'any', 'desc', '*store', '*storedist'),
112173
error_on_unexpected=False, left_from_first_unexpected=False)
113-
unit = translate_meters_to_unit(args[0]) if len(args) >= 1 else 1
114174
count = count or sys.maxsize
175+
conv = translate_meters_to_unit(args[0]) if len(args) >= 1 else 1
176+
return self._search(
177+
key, long, lat, radius, conv,
178+
withcoord, withdist, withhash, count, count_any, desc, store, storedist)
115179

116-
for name, _hash in zset.items():
117-
p_lat, p_long, _, _ = geohash.decode(_hash)
118-
dist = distance((p_lat, p_long), (lat, long)) * unit
119-
if dist < radius:
120-
results.append(GeoResult(name, p_long, p_lat, _hash, dist))
121-
if count_any and len(results) >= count:
122-
break
180+
@command(name='GEORADIUSBYMEMBER', fixed=(Key(ZSet), bytes, Float), repeat=(bytes,))
181+
def georadiusbymember(self, key, member_name, radius, *args):
182+
member_score = key.value.get(member_name)
183+
lat, long, _, _ = geohash.decode(member_score)
184+
return self.georadius(key, long, lat, radius, *args)
123185

124-
return self._parse_results(results, withcoord, withdist, withhash, count, desc)
186+
@command(name='GEORADIUSBYMEMBER_RO', fixed=(Key(ZSet), bytes, Float), repeat=(bytes,))
187+
def georadiusbymember_ro(self, key, member_name, radius, *args):
188+
member_score = key.value.get(member_name)
189+
lat, long, _, _ = geohash.decode(member_score)
190+
return self.georadius_ro(key, long, lat, radius, *args)

test/test_mixins/test_geo_commands.py

Lines changed: 65 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
from typing import Dict, Any
2+
13
import pytest
24
import redis
35

6+
from test import testtools
7+
48

59
def test_geoadd(r: redis.Redis):
610
values = ((2.1909389952632, 41.433791470673, "place1") +
@@ -10,13 +14,20 @@ def test_geoadd(r: redis.Redis):
1014

1115
values = (2.1909389952632, 41.433791470673, "place1")
1216
assert r.geoadd("a", values) == 1
17+
1318
values = ((2.1909389952632, 31.433791470673, "place1") +
1419
(2.1873744593677, 41.406342043777, "place2",))
1520
assert r.geoadd("a", values, ch=True) == 2
1621
assert r.zrange("a", 0, -1) == [b"place1", b"place2"]
1722

18-
with pytest.raises(redis.RedisError):
23+
with pytest.raises(redis.DataError):
1924
r.geoadd("barcelona", (1, 2))
25+
with pytest.raises(redis.DataError):
26+
r.geoadd("t", values, ch=True, nx=True, xx=True)
27+
with pytest.raises(redis.ResponseError):
28+
testtools.raw_command(r, "geoadd", "barcelona", "1", "2")
29+
with pytest.raises(redis.ResponseError):
30+
testtools.raw_command(r, "geoadd", "barcelona", "nx", "xx", *values, )
2031

2132

2233
def test_geoadd_xx(r: redis.Redis):
@@ -91,56 +102,38 @@ def test_geodist_missing_one_member(r: redis.Redis):
91102
assert r.geodist("barcelona", "place1", "missing_member", "km") is None
92103

93104

94-
def test_georadius(r: redis.Redis):
95-
values = ((2.1909389952632, 41.433791470673, "place1") +
96-
(2.1873744593677, 41.406342043777, b"\x80place2"))
97-
98-
r.geoadd("barcelona", values)
99-
assert r.georadius("barcelona", 2.191, 41.433, 1000) == [b"place1"]
100-
assert r.georadius("barcelona", 2.187, 41.406, 1000) == [b"\x80place2"]
101-
102-
103-
def test_georadius_no_values(r: redis.Redis):
104-
values = ((2.1909389952632, 41.433791470673, "place1") +
105-
(2.1873744593677, 41.406342043777, "place2",))
106-
107-
r.geoadd("barcelona", values)
108-
assert r.georadius("barcelona", 1, 2, 1000) == []
109-
110-
111-
def test_georadius_units(r: redis.Redis):
105+
@pytest.mark.parametrize(
106+
"long,lat,radius,extra,expected", [
107+
(2.191, 41.433, 1000, {}, [b"place1"]),
108+
(2.187, 41.406, 1000, {}, [b"place2"]),
109+
(1, 2, 1000, {}, []),
110+
(2.191, 41.433, 1, {"unit": "km"}, [b"place1"]),
111+
(2.191, 41.433, 3000, {"count": 1}, [b"place1"]),
112+
])
113+
def test_georadius(
114+
r: redis.Redis, long: float, lat: float, radius: float,
115+
extra: Dict[str, Any],
116+
expected):
112117
values = ((2.1909389952632, 41.433791470673, "place1") +
113-
(2.1873744593677, 41.406342043777, "place2",))
114-
118+
(2.1873744593677, 41.406342043777, b"place2"))
115119
r.geoadd("barcelona", values)
116-
assert r.georadius("barcelona", 2.191, 41.433, 1, unit="km") == [b"place1"]
120+
assert r.georadius("barcelona", long, lat, radius, **extra) == expected
117121

118122

119123
def test_georadius_with(r: redis.Redis):
120124
values = ((2.1909389952632, 41.433791470673, "place1") +
121125
(2.1873744593677, 41.406342043777, "place2",))
122126

123127
r.geoadd("barcelona", values)
124-
125-
# test a bunch of combinations to test the parse response
126-
# function.
128+
# test a bunch of combinations to test the parse response function.
127129
res = r.georadius("barcelona", 2.191, 41.433, 1, unit="km", withdist=True, withcoord=True, )
128-
assert res == [pytest.approx([
129-
b"place1",
130-
0.0881,
131-
pytest.approx((2.19093829393386841, 41.43379028184083523), 0.0001)
132-
], 0.001)]
130+
assert res == [pytest.approx([b"place1", 0.0881, pytest.approx((2.1909, 41.4337), 0.0001)], 0.001)]
133131

134132
res = r.georadius("barcelona", 2.191, 41.433, 1, unit="km", withdist=True, withcoord=True)
135-
assert res == [pytest.approx([
136-
b"place1",
137-
0.0881,
138-
pytest.approx((2.19093829393386841, 41.43379028184083523), 0.0001)
139-
], 0.001)]
133+
assert res == [pytest.approx([b"place1", 0.0881, pytest.approx((2.1909, 41.4337), 0.0001)], 0.001)]
140134

141-
assert r.georadius(
142-
"barcelona", 2.191, 41.433, 1, unit="km", withcoord=True
143-
) == [[b"place1", pytest.approx((2.19093829393386841, 41.43379028184083523), 0.0001)]]
135+
res = r.georadius("barcelona", 2.191, 41.433, 1, unit="km", withcoord=True)
136+
assert res == [[b"place1", pytest.approx((2.1909, 41.4337), 0.0001)]]
144137

145138
# test no values.
146139
assert (r.georadius("barcelona", 2, 1, 1, unit="km", withdist=True, withcoord=True, ) == [])
@@ -151,6 +144,39 @@ def test_georadius_count(r: redis.Redis):
151144
(2.1873744593677, 41.406342043777, "place2",))
152145

153146
r.geoadd("barcelona", values)
154-
assert r.georadius("barcelona", 2.191, 41.433, 3000, count=1) == [b"place1"]
147+
148+
assert r.georadius("barcelona", 2.191, 41.433, 3000, count=1, store='barcelona') == 1
149+
assert r.georadius("barcelona", 2.191, 41.433, 3000, store_dist='extract') == 1
150+
assert r.zcard("extract") == 1
155151
res = r.georadius("barcelona", 2.191, 41.433, 3000, count=1, any=True)
156152
assert (res == [b"place2"]) or res == [b'place1']
153+
154+
values = ((13.361389, 38.115556, "Palermo") +
155+
(15.087269, 37.502669, "Catania",))
156+
157+
r.geoadd("Sicily", values)
158+
assert testtools.raw_command(
159+
r, "GEORADIUS", "Sicily", "15", "37", "200", "km",
160+
"STOREDIST", "neardist", "STORE", "near") == 2
161+
assert r.zcard("near") == 2
162+
assert r.zcard("neardist") == 0
163+
164+
165+
def test_georadius_errors(r: redis.Redis):
166+
values = ((13.361389, 38.115556, "Palermo") +
167+
(15.087269, 37.502669, "Catania",))
168+
169+
r.geoadd("Sicily", values)
170+
171+
with pytest.raises(redis.DataError): # Unsupported unit
172+
r.georadius("barcelona", 2.191, 41.433, 3000, unit='dsf')
173+
with pytest.raises(redis.ResponseError): # Unsupported unit
174+
testtools.raw_command(
175+
r, "GEORADIUS", "Sicily", "15", "37", "200", "ddds",
176+
"STOREDIST", "neardist", "STORE", "near")
177+
178+
bad_values = (13.361389, 38.115556, "Palermo", 15.087269, "Catania",)
179+
with pytest.raises(redis.DataError):
180+
r.geoadd('newgroup', bad_values)
181+
with pytest.raises(redis.ResponseError):
182+
testtools.raw_command(r, 'geoadd', 'newgroup', *bad_values)

0 commit comments

Comments
 (0)