|
1 | 1 | import sys
|
2 | 2 | from collections import namedtuple
|
3 |
| -from typing import List, Optional, Any |
| 3 | +from typing import List, Any |
4 | 4 |
|
5 | 5 | from fakeredis import _msgs as msgs
|
6 | 6 | from fakeredis._command_args_parsing import extract_args
|
7 |
| -from fakeredis._commands import command, Key, Float |
| 7 | +from fakeredis._commands import command, Key, Float, CommandItem |
8 | 8 | from fakeredis._helpers import SimpleError
|
9 | 9 | from fakeredis._zset import ZSet
|
10 | 10 | from fakeredis.geo import geohash
|
11 | 11 | from fakeredis.geo.haversine import distance
|
12 | 12 |
|
| 13 | +UNIT_TO_M = {'km': 0.001, 'mi': 0.000621371, 'ft': 3.28084, 'm': 1} |
| 14 | + |
13 | 15 |
|
14 | 16 | def translate_meters_to_unit(unit_arg: bytes) -> float:
|
15 |
| - unit_str = unit_arg.decode().lower() |
16 |
| - if unit_str == 'km': |
17 |
| - unit = 0.001 |
18 |
| - elif unit_str == 'mi': |
19 |
| - unit = 0.000621371 |
20 |
| - elif unit_str == 'ft': |
21 |
| - unit = 3.28084 |
22 |
| - else: # meter |
23 |
| - unit = 1 |
| 17 | + """number of meters in a unit. |
| 18 | + :param unit_arg: unit name (km, mi, ft, m) |
| 19 | + :returns: number of meters in unit |
| 20 | + """ |
| 21 | + unit = UNIT_TO_M.get(unit_arg.decode().lower()) |
| 22 | + if unit is None: |
| 23 | + raise SimpleError(msgs.GEO_UNSUPPORTED_UNIT) |
24 | 24 | return unit
|
25 | 25 |
|
26 | 26 |
|
27 | 27 | GeoResult = namedtuple('GeoResult', 'name long lat hash distance')
|
28 | 28 |
|
29 | 29 |
|
| 30 | +def _parse_results( |
| 31 | + items: List[GeoResult], |
| 32 | + withcoord: bool, withdist: bool) -> List[Any]: |
| 33 | + """Parse list of GeoResults to redis response |
| 34 | + :param withcoord: include coordinates in response |
| 35 | + :param withdist: include distance in response |
| 36 | + :returns: Parsed list |
| 37 | + """ |
| 38 | + res = list() |
| 39 | + for item in items: |
| 40 | + new_item = [item.name, ] |
| 41 | + if withdist: |
| 42 | + new_item.append(Float.encode(item.distance, False)) |
| 43 | + if withcoord: |
| 44 | + new_item.append([Float.encode(item.long, False), |
| 45 | + Float.encode(item.lat, False)]) |
| 46 | + if len(new_item) == 1: |
| 47 | + new_item = new_item[0] |
| 48 | + res.append(new_item) |
| 49 | + return res |
| 50 | + |
| 51 | + |
| 52 | +def _find_near( |
| 53 | + zset: ZSet, |
| 54 | + lat: float, long: float, radius: float, |
| 55 | + conv: float, count: int, count_any: bool, desc: bool) -> List[GeoResult]: |
| 56 | + """Find items within area (lat,long)+radius |
| 57 | + :param zset: list of items to check |
| 58 | + :param lat: latitude |
| 59 | + :param long: longitude |
| 60 | + :param radius: radius in whatever units |
| 61 | + :param conv: conversion of radius to meters |
| 62 | + :param count: number of results to give |
| 63 | + :param count_any: should we return any results that match? (vs. sorted) |
| 64 | + :param desc: should results be sorted descending order? |
| 65 | + :returns: List of GeoResults |
| 66 | + """ |
| 67 | + results = list() |
| 68 | + for name, _hash in zset.items(): |
| 69 | + p_lat, p_long, _, _ = geohash.decode(_hash) |
| 70 | + dist = distance((p_lat, p_long), (lat, long)) * conv |
| 71 | + if dist < radius: |
| 72 | + results.append(GeoResult(name, p_long, p_lat, _hash, dist)) |
| 73 | + if count_any and len(results) >= count: |
| 74 | + break |
| 75 | + results = sorted(results, key=lambda x: x.distance, reverse=desc) |
| 76 | + if count: |
| 77 | + results = results[:count] |
| 78 | + return results |
| 79 | + |
| 80 | + |
30 | 81 | class GeoCommandsMixin:
|
31 | 82 | # TODO
|
32 |
| - # GEORADIUS, GEORADIUS_RO, |
33 |
| - # GEORADIUSBYMEMBER, GEORADIUSBYMEMBER_RO, |
34 | 83 | # GEOSEARCH, GEOSEARCHSTORE
|
| 84 | + def _store_geo_results(self, item_name: bytes, geo_results: List[GeoResult], scoredist: bool) -> int: |
| 85 | + db_item = CommandItem(item_name, self._db, item=self._db.get(item_name), default=ZSet()) |
| 86 | + db_item.value = ZSet() |
| 87 | + for item in geo_results: |
| 88 | + val = item.distance if scoredist else item.hash |
| 89 | + db_item.value.add(item.name, val) |
| 90 | + db_item.writeback() |
| 91 | + return len(geo_results) |
35 | 92 |
|
36 | 93 | @command(name='GEOADD', fixed=(Key(ZSet),), repeat=(bytes,))
|
37 | 94 | def geoadd(self, key, *args):
|
@@ -83,42 +140,51 @@ def geodist(self, key, m1, m2, *args):
|
83 | 140 | unit = translate_meters_to_unit(args[0]) if len(args) == 1 else 1
|
84 | 141 | return res * unit
|
85 | 142 |
|
86 |
| - def _parse_results( |
87 |
| - self, items: List[GeoResult], |
88 |
| - withcoord: bool, withdist: bool, withhash: bool, |
89 |
| - count: Optional[int], desc: bool) -> List[Any]: |
90 |
| - items = sorted(items, key=lambda x: x.distance, reverse=desc) |
91 |
| - if count: |
92 |
| - items = items[:count] |
93 |
| - res = list() |
94 |
| - for item in items: |
95 |
| - new_item = [item.name, ] |
96 |
| - if withdist: |
97 |
| - new_item.append(self._encodefloat(item.distance, False)) |
98 |
| - if withcoord: |
99 |
| - new_item.append([self._encodefloat(item.long, False), |
100 |
| - self._encodefloat(item.lat, False)]) |
101 |
| - if len(new_item) == 1: |
102 |
| - new_item = new_item[0] |
103 |
| - res.append(new_item) |
104 |
| - return res |
| 143 | + def _search( |
| 144 | + self, key, long, lat, radius, conv, |
| 145 | + withcoord, withdist, withhash, count, count_any, desc, store, storedist): |
| 146 | + zset = key.value |
| 147 | + geo_results = _find_near(zset, lat, long, radius, conv, count, count_any, desc) |
| 148 | + |
| 149 | + if store: |
| 150 | + self._store_geo_results(store, geo_results, scoredist=False) |
| 151 | + return len(geo_results) |
| 152 | + if storedist: |
| 153 | + self._store_geo_results(storedist, geo_results, scoredist=True) |
| 154 | + return len(geo_results) |
| 155 | + ret = _parse_results(geo_results, withcoord, withdist) |
| 156 | + return ret |
| 157 | + |
| 158 | + @command(name='GEORADIUS_RO', fixed=(Key(ZSet), Float, Float, Float), repeat=(bytes,)) |
| 159 | + def georadius_ro(self, key, long, lat, radius, *args): |
| 160 | + (withcoord, withdist, withhash, count, count_any, desc), left_args = extract_args( |
| 161 | + args, ('withcoord', 'withdist', 'withhash', '+count', 'any', 'desc',), |
| 162 | + error_on_unexpected=False, left_from_first_unexpected=False) |
| 163 | + count = count or sys.maxsize |
| 164 | + conv = translate_meters_to_unit(args[0]) if len(args) >= 1 else 1 |
| 165 | + return self._search( |
| 166 | + key, long, lat, radius, conv, |
| 167 | + withcoord, withdist, withhash, count, count_any, desc, False, False) |
105 | 168 |
|
106 | 169 | @command(name='GEORADIUS', fixed=(Key(ZSet), Float, Float, Float), repeat=(bytes,))
|
107 | 170 | def georadius(self, key, long, lat, radius, *args):
|
108 |
| - zset = key.value |
109 |
| - results = list() |
110 | 171 | (withcoord, withdist, withhash, count, count_any, desc, store, storedist), left_args = extract_args(
|
111 | 172 | args, ('withcoord', 'withdist', 'withhash', '+count', 'any', 'desc', '*store', '*storedist'),
|
112 | 173 | error_on_unexpected=False, left_from_first_unexpected=False)
|
113 |
| - unit = translate_meters_to_unit(args[0]) if len(args) >= 1 else 1 |
114 | 174 | count = count or sys.maxsize
|
| 175 | + conv = translate_meters_to_unit(args[0]) if len(args) >= 1 else 1 |
| 176 | + return self._search( |
| 177 | + key, long, lat, radius, conv, |
| 178 | + withcoord, withdist, withhash, count, count_any, desc, store, storedist) |
115 | 179 |
|
116 |
| - for name, _hash in zset.items(): |
117 |
| - p_lat, p_long, _, _ = geohash.decode(_hash) |
118 |
| - dist = distance((p_lat, p_long), (lat, long)) * unit |
119 |
| - if dist < radius: |
120 |
| - results.append(GeoResult(name, p_long, p_lat, _hash, dist)) |
121 |
| - if count_any and len(results) >= count: |
122 |
| - break |
| 180 | + @command(name='GEORADIUSBYMEMBER', fixed=(Key(ZSet), bytes, Float), repeat=(bytes,)) |
| 181 | + def georadiusbymember(self, key, member_name, radius, *args): |
| 182 | + member_score = key.value.get(member_name) |
| 183 | + lat, long, _, _ = geohash.decode(member_score) |
| 184 | + return self.georadius(key, long, lat, radius, *args) |
123 | 185 |
|
124 |
| - return self._parse_results(results, withcoord, withdist, withhash, count, desc) |
| 186 | + @command(name='GEORADIUSBYMEMBER_RO', fixed=(Key(ZSet), bytes, Float), repeat=(bytes,)) |
| 187 | + def georadiusbymember_ro(self, key, member_name, radius, *args): |
| 188 | + member_score = key.value.get(member_name) |
| 189 | + lat, long, _, _ = geohash.decode(member_score) |
| 190 | + return self.georadius_ro(key, long, lat, radius, *args) |
0 commit comments