diff --git a/cmd_geo.go b/cmd_geo.go index 6702e96..8767b8a 100644 --- a/cmd_geo.go +++ b/cmd_geo.go @@ -3,6 +3,7 @@ package miniredis import ( + "errors" "fmt" "sort" "strconv" @@ -20,6 +21,7 @@ func commandsGeo(m *Miniredis) { m.srv.Register("GEORADIUS_RO", m.cmdGeoradius) m.srv.Register("GEORADIUSBYMEMBER", m.cmdGeoradiusbymember) m.srv.Register("GEORADIUSBYMEMBER_RO", m.cmdGeoradiusbymember) + m.srv.Register("GEOSEARCH", m.cmdGeosearch) } // GEOADD @@ -607,3 +609,186 @@ func parseUnit(u string) float64 { return 0 } } + +type tuple struct { + a float64 + b float64 +} + +type geosearchOpts struct { + key string + withFromMember bool + fromMember string + withFromLonLat bool + fromLonLat tuple + withByRadius bool + byRadius float64 + withByBox bool + byBox tuple + direction direction // unsorted + count int + withAny bool + withCoord bool + withDist bool + withHash bool +} + +func geosearchParse(cmd string, args []string) (*geosearchOpts, error) { + var opts geosearchOpts + + opts.key, args = args[0], args[1:] + + fmt.Printf("args: %v\n", args) + + switch strings.ToUpper(args[0]) { + case "FROMMEMBER": + if len(args) < 2 { + return nil, errors.New(errWrongNumber(cmd)) + } + opts.withFromMember = true + opts.fromMember = args[1] + args = args[2:] + case "FROMLONLAT": + if len(args) < 3 { + return nil, errors.New(errWrongNumber(cmd)) + } + opts.withFromLonLat = true + if err := optFloat(args[1], &opts.fromLonLat.a); err != nil { + return nil, err + } + if err := optFloat(args[2], &opts.fromLonLat.b); err != nil { + return nil, err + } + args = args[3:] + default: + return nil, errors.New(errWrongNumber(cmd)) + } + + if len(args) < 3 { + return nil, errors.New(errWrongNumber(cmd)) + } + switch strings.ToUpper(args[0]) { + case "BYRADIUS": + if len(args) < 3 { + return nil, errors.New(errWrongNumber(cmd)) + } + opts.withByRadius = true + if err := optFloat(args[1], &opts.byRadius); err != nil { + return nil, err + } + toMeter := parseUnit(args[2]) + if toMeter == 0 { + return nil, errors.New(errWrongNumber(cmd)) + } + opts.byRadius *= toMeter + args = args[3:] + case "BYBOX": + if len(args) < 4 { + return nil, errors.New(errWrongNumber(cmd)) + } + opts.withByBox = true + if err := optFloat(args[1], &opts.byBox.a); err != nil { + return nil, err + } + if err := optFloat(args[2], &opts.byBox.b); err != nil { + return nil, err + } + toMeter := parseUnit(args[3]) + if toMeter == 0 { + return nil, errors.New(errWrongNumber(cmd)) + } + opts.byBox.a *= toMeter + opts.byBox.b *= toMeter + args = args[4:] + default: + return nil, errors.New(errWrongNumber(cmd)) + } + + // FIXME: ASC|DESC + // FIXME: COUNT n ANY + // FIXME: WITHCOORD + // FIXME: WITHDIST + // FIXME: WITHHASH + + return &opts, nil +} + +// GEOSEARCH +func (m *Miniredis) cmdGeosearch(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + if m.checkPubsub(c, cmd) { + return + } + + opts, err := geosearchParse(cmd, args) + if err != nil { + setDirty(c) + c.WriteError(err.Error()) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + members := db.ssetElements(opts.key) + + if !opts.withFromLonLat { + panic("wip") + } + if !opts.withByRadius { + panic("wip") + } + matches := withinRadius(members, opts.fromLonLat.a, opts.fromLonLat.b, opts.byRadius) + + /* + // deal with ASC/DESC + if opts.direction != unsorted { + sort.Slice(matches, func(i, j int) bool { + if opts.direction == desc { + return matches[i].Distance > matches[j].Distance + } + return matches[i].Distance < matches[j].Distance + }) + } + + // deal with COUNT + if opts.count > 0 && len(matches) > opts.count { + matches = matches[:opts.count] + } + */ + + c.WriteLen(len(matches)) + for _, member := range matches { + // if !opts.withDist && !opts.withCoord { + c.WriteBulk(member.Name) + continue + // } + + /* + len := 1 + if opts.withDist { + len++ + } + if opts.withCoord { + len++ + } + c.WriteLen(len) + c.WriteBulk(member.Name) + if opts.withDist { + c.WriteBulk(fmt.Sprintf("%.4f", member.Distance/toMeter)) + } + if opts.withCoord { + c.WriteLen(2) + c.WriteBulk(fmt.Sprintf("%f", member.Longitude)) + c.WriteBulk(fmt.Sprintf("%f", member.Latitude)) + } + */ + } + }) +} diff --git a/integration/geo_test.go b/integration/geo_test.go index ca8b863..0bf6e6f 100644 --- a/integration/geo_test.go +++ b/integration/geo_test.go @@ -937,3 +937,63 @@ func TestGeo(t *testing.T) { c.DoLoosely("ZRANGE", "resbymemd", "0", "-1", "WITHSCORES") }) } + +func TestGeosearch(t *testing.T) { + skip(t) + t.Run("basic", func(t *testing.T) { + testRaw(t, func(c *client) { + c.Do("GEOADD", + "stations", + "-73.99106999861966", "40.73005400028978", "Astor Pl", + "-74.00019299927328", "40.71880300107709", "Canal St", + "-73.98384899986625", "40.76172799961419", "50th St", + ) + c.Do("GEOSEARCH", "stations", "FROMLONLAT", "-73.9718893", "40.7728773", "BYRADIUS", "4", "km") + c.Do("GEOSEARCH", "stations", "FROMLONLAT", "-73.9718893", "40.7728773", "BYRADIUS", "4", "KM") // case of KM + c.Do("GEOSEARCH", "stations", "FROMLONLAT", "1.0", "1.0", "BYRADIUS", "1", "km") + // c.Do("GEOSEARCH", "stations", "FROMLONLAT", "-73.9718893", "40.7728773", "BYRADIUS", "4", "ft", "WITHDIST") + // c.Do("GEORADIUS", "stations", "FROMLONLAT", "-73.9718893", "40.7728773", "BYRADIUS", "4", "m", "WITHDIST") + + /* + // redis has more precision in the coords + c.Do("GEORADIUS", "stations", "-73.9718893", "40.7728773", "4", "m", "WITHCOORD") + c.DoRounded(3, "GEORADIUS", "stations", "-73.9718893", "40.7728773", "400", "km", "WITHDIST", "WITHCOORD") + c.DoRounded(3, "GEORADIUS", "stations", "-73.9718893", "40.7728773", "400", "km", "WITHCOORD", "WITHDIST") + c.DoRounded(3, "GEORADIUS", "stations", "-73.9718893", "40.7728773", "400", "km", "WITHCOORD", "WITHCOORD", "WITHCOORD") + c.DoRounded(3, "GEORADIUS", "stations", "-73.9718893", "40.7728773", "400", "km", "WITHDIST", "WITHDIST", "WITHDIST") + // FIXME: the distances don't quite match for miles or km + c.DoRounded(3, "GEORADIUS", "stations", "-73.9718893", "40.7728773", "400", "mi", "WITHDIST") + c.DoRounded(3, "GEORADIUS", "stations", "-73.9718893", "40.7728773", "400", "km", "WITHDIST") + + // Sorting + c.Do("GEORADIUS", "stations", "-73.9718893", "40.7728773", "400", "km", "DESC") + c.Do("GEORADIUS", "stations", "-73.9718893", "40.7728773", "400", "km", "ASC") + c.Do("GEORADIUS", "stations", "-73.9718893", "40.7728773", "400", "km", "ASC", "DESC", "ASC") + + // COUNT + c.DoRounded(3, "GEORADIUS", "stations", "-73.9718893", "40.7728773", "400", "km", "ASC", "COUNT", "1") + c.DoRounded(3, "GEORADIUS", "stations", "-73.9718893", "40.7728773", "400", "km", "ASC", "COUNT", "2") + c.DoRounded(3, "GEORADIUS", "stations", "-73.9718893", "40.7728773", "400", "km", "ASC", "COUNT", "999") + c.Error("syntax error", "GEORADIUS", "stations", "-73.9718893", "40.7728773", "400", "km", "COUNT") + c.Error("COUNT must", "GEORADIUS", "stations", "-73.9718893", "40.7728773", "400", "km", "COUNT", "0") + c.Error("COUNT must", "GEORADIUS", "stations", "-73.9718893", "40.7728773", "400", "km", "COUNT", "-12") + c.Error("not an integer", "GEORADIUS", "stations", "-73.9718893", "40.7728773", "400", "km", "COUNT", "foobar") + + // non-existing key + c.Do("GEORADIUS", "foo", "-73.9718893", "40.7728773", "4", "km") + + // no error in redis, for some reason + // c.Do("GEORADIUS", "foo", "-73.9718893", "40.7728773", "4", "km", "FOOBAR") + c.Error("syntax error", "GEORADIUS", "stations", "-73.9718893", "40.7728773", "400", "km", "ASC", "FOOBAR") + + // GEORADIUS_RO + c.Do("GEORADIUS_RO", "stations", "-73.9718893", "40.7728773", "4", "km") + c.Do("GEORADIUS_RO", "stations", "1.0", "1.0", "1", "km") + c.Error("syntax error", "GEORADIUS_RO", "stations", "-73.9718893", "40.7728773", "4", "km", "STORE", "bar") + c.Error("syntax error", "GEORADIUS_RO", "stations", "-73.9718893", "40.7728773", "4", "km", "STOREDIST", "bar") + c.Error("syntax error", "GEORADIUS_RO", "stations", "-73.9718893", "40.7728773", "4", "km", "STORE") + c.Error("syntax error", "GEORADIUS_RO", "stations", "-73.9718893", "40.7728773", "4", "km", "STOREDIST") + */ + }) + }) +} diff --git a/opts.go b/opts.go index 666ace7..ee87916 100644 --- a/opts.go +++ b/opts.go @@ -1,6 +1,7 @@ package miniredis import ( + "errors" "math" "strconv" "time" @@ -47,3 +48,12 @@ func optDuration(c *server.Peer, src string, dest *time.Duration) bool { *dest = time.Duration(n*1_000_000) * time.Microsecond return true } + +func optFloat(src string, dest *float64) error { + n, err := strconv.ParseFloat(src, 64) + if err != nil { + return errors.New(msgInvalidInt) // FIXME + } + *dest = n + return nil +}