Skip to content

ResultSet interface to simplify mocking and testing #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,14 @@ func (db *Db) Keyspace(keyspace string) (*gocql.KeyspaceMetadata, error) {

// Keyspaces Retrieves all the keyspace names
func (db *Db) Keyspaces() ([]string, error) {
iter := db.session.ExecuteIter("SELECT keyspace_name FROM system_schema.keyspaces", nil)
iter, err := db.session.ExecuteIter("SELECT keyspace_name FROM system_schema.keyspaces", nil)
if err != nil {
return nil, err
}

var keyspaces []string

var name string
for iter.Scan(&name) {
keyspaces = append(keyspaces, name)
}
if err := iter.Close(); err != nil {
return nil, err
for _, row := range iter.Values() {
keyspaces = append(keyspaces, *row["keyspace_name"].(*string))
}

return keyspaces, nil
Expand Down
70 changes: 53 additions & 17 deletions db/db_session.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package db

import "github.com/gocql/gocql"
import (
"encoding/hex"
"github.com/gocql/gocql"
)

type QueryOptions struct {
UserOrRole string
UserOrRole string
Consistency gocql.Consistency
}

Expand All @@ -28,51 +31,84 @@ type DbSession interface {
Execute(query string, options *QueryOptions, values ...interface{}) error

// ExecuteIterSimple executes a statement and returns iterator to the result set
ExecuteIter(query string, options *QueryOptions, values ...interface{}) ResultIterator
ExecuteIter(query string, options *QueryOptions, values ...interface{}) (ResultSet, error)

//TODO: Extract metadata methods from interface into another interface
KeyspaceMetadata(keyspaceName string) (*gocql.KeyspaceMetadata, error)
}

type ResultIterator interface {
Close() error
Columns() []gocql.ColumnInfo
Scanner() gocql.Scanner
PageState() []byte
Scan(dest ...interface{}) bool
MapScan(m map[string]interface{}) bool
type ResultSet interface {
PageState() string
Values() []map[string]interface{}
}

func (r *goCqlResultIterator) PageState() string {
return hex.EncodeToString(r.pageState)
}

func (r *goCqlResultIterator) Values() []map[string]interface{} {
return r.values
}

type goCqlResultIterator struct {
pageState []byte
values []map[string]interface{}
}

func newResultIterator(iter *gocql.Iter) (*goCqlResultIterator, error) {
columns := iter.Columns()
scanner := iter.Scanner()

items := make([]map[string]interface{}, 0)

for scanner.Next() {
row, err := mapScan(scanner, columns)
if err != nil {
return nil, err
}
items = append(items, row)
}

if err := iter.Close(); err != nil {
return nil, err
}

return &goCqlResultIterator{
pageState: iter.PageState(),
values: items,
}, nil
}

type GoCqlSession struct {
ref *gocql.Session
}

func (db *Db) Execute(query string, options *QueryOptions, values ...interface{}) ResultIterator {
func (db *Db) Execute(query string, options *QueryOptions, values ...interface{}) (ResultSet, error) {
return db.session.ExecuteIter(query, options, values...)
}

func (db *Db) ExecuteNoResult(query string, options* QueryOptions, values ...interface{}) error {
func (db *Db) ExecuteNoResult(query string, options *QueryOptions, values ...interface{}) error {
return db.session.Execute(query, options, values)
}

func (session *GoCqlSession) Execute(query string, options *QueryOptions, values ...interface{}) error {
return session.ExecuteIter(query, options, values...).Close()
_, err := session.ExecuteIter(query, options, values...)
return err
}

func (session *GoCqlSession) ExecuteIter(query string, options *QueryOptions, values ...interface{}) ResultIterator {
func (session *GoCqlSession) ExecuteIter(query string, options *QueryOptions, values ...interface{}) (ResultSet, error) {
q := session.ref.Query(query, values...)
if options != nil {
q.Consistency(options.Consistency)
if options.UserOrRole != "" {
q.CustomPayload(map[string][]byte {
q.CustomPayload(map[string][]byte{
"ProxyExecute": []byte(options.UserOrRole),
})
}
}
return q.Iter()
return newResultIterator(q.Iter())
}

func (session *GoCqlSession) KeyspaceMetadata(keyspaceName string) (*gocql.KeyspaceMetadata, error) {
return session.ref.KeyspaceMetadata(keyspaceName)
}

31 changes: 3 additions & 28 deletions db/query_generators.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package db

import (
"encoding/hex"
"fmt"
"github.com/gocql/gocql"
"github.com/iancoleman/strcase"
"github.com/riptano/data-endpoints/types"
"reflect"
"strings"
Expand Down Expand Up @@ -86,13 +84,13 @@ func mapScan(scanner gocql.Scanner, columns []gocql.ColumnInfo) (map[string]inte
value = reflect.Indirect(reflect.ValueOf(value)).Interface()
}

mapped[strcase.ToLowerCamel(column.Name)] = value
mapped[column.Name] = value
}

return mapped, nil
}

func (db *Db) Select(info *SelectInfo, options *QueryOptions) (*types.QueryResult, error) {
func (db *Db) Select(info *SelectInfo, options *QueryOptions) (ResultSet, error) {
values := make([]interface{}, 0, len(info.Columns))
whereClause := ""
for i := 0; i < len(info.Columns); i++ {
Expand Down Expand Up @@ -122,30 +120,7 @@ func (db *Db) Select(info *SelectInfo, options *QueryOptions) (*types.QueryResul
}
}

iter := db.session.ExecuteIter(query, options, values...)

pageState := hex.EncodeToString(iter.PageState())
columns := iter.Columns()
scanner := iter.Scanner()

items := make([]map[string]interface{}, 0)

for scanner.Next() {
row, err := mapScan(scanner, columns)
if err != nil {
return nil, err
}
items = append(items, row)
}

if err := iter.Close(); err != nil {
return nil, err
}

return &types.QueryResult{
PageState: pageState,
Values: items,
}, nil
return db.session.ExecuteIter(query, options, values...)
}

func (db *Db) Insert(info *InsertInfo, options *QueryOptions) (*types.ModificationResult, error) {
Expand Down
55 changes: 10 additions & 45 deletions db/query_generators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,9 @@ func TestInsertGeneration(t *testing.T) {

func TestSelectGeneration(t *testing.T) {
resultMock := &ResultMock{}
scannerMock := &ScannerMock{}
resultMock.
On("PageState").Return([]byte{}).
On("Columns").Return([]gocql.ColumnInfo{}).
On("Scanner").Return(scannerMock).
On("Close").Return(nil)
scannerMock.On("Next").Return(false)
On("PageState").Return("").
On("Values").Return([]map[string]interface{}{}, nil)

items := []struct {
columnNames []string
Expand All @@ -113,7 +109,7 @@ func TestSelectGeneration(t *testing.T) {
db := &Db{
session: &sessionMock,
}
sessionMock.On("ExecuteIter", mock.Anything, mock.Anything, mock.Anything).Return(resultMock)
sessionMock.On("ExecuteIter", mock.Anything, mock.Anything, mock.Anything).Return(resultMock, nil)
queryParams := make([]interface{}, 0)

for _, v := range item.values {
Expand Down Expand Up @@ -147,9 +143,9 @@ func (o *SessionMock) Execute(query string, options *QueryOptions, values ...int
return args.Error(0)
}

func (o *SessionMock) ExecuteIter(query string, options *QueryOptions, values ...interface{}) ResultIterator {
func (o *SessionMock) ExecuteIter(query string, options *QueryOptions, values ...interface{}) (ResultSet, error) {
args := o.Called(query, options, values)
return args.Get(0).(ResultIterator)
return args.Get(0).(ResultSet), args.Error(1)
}

func (o *SessionMock) KeyspaceMetadata(keyspaceName string) (*gocql.KeyspaceMetadata, error) {
Expand All @@ -161,42 +157,11 @@ type ResultMock struct {
mock.Mock
}

type ScannerMock struct {
mock.Mock
}

func (o ScannerMock) Next() bool {
return o.Called().Bool(0)
}

func (o ScannerMock) Scan(dest ...interface{}) error {
return o.Called(dest).Error(0)
}

func (o ScannerMock) Err() error {
return o.Called().Error(0)
}

func (o ResultMock) Close() error {
return o.Called().Error(0)
}

func (o ResultMock) Columns() []gocql.ColumnInfo {
return o.Called().Get(0).([]gocql.ColumnInfo)
}

func (o ResultMock) Scanner() gocql.Scanner {
return o.Called().Get(0).(gocql.Scanner)
}

func (o ResultMock) PageState() []byte {
return o.Called().Get(0).([]byte)
}

func (o ResultMock) Scan(dest ...interface{}) bool {
return o.Called(dest).Bool(0)
func (o ResultMock) PageState() string {
return o.Called().String(0)
}

func (o ResultMock) MapScan(m map[string]interface{}) bool {
return o.Called(m).Bool(0)
func (o ResultMock) Values() []map[string]interface{} {
args := o.Called()
return args.Get(0).([]map[string]interface{})
}
26 changes: 25 additions & 1 deletion graphql/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,18 +235,42 @@ func queryFieldResolver(keyspace *gocql.KeyspaceMetadata, dbClient *db.Db) graph
if err != nil {
return nil, err
}
return dbClient.Select(&db.SelectInfo{

result, err := dbClient.Select(&db.SelectInfo{
Keyspace: keyspace.Name,
Table: table.Name,
Columns: columnNames,
Values: queryParams,
OrderBy: parseColumnOrder(orderBy),
Options: &options,
}, db.NewQueryOptions().WithUserOrRole(userOrRole))

if err != nil {
return nil, err
}

return &types.QueryResult{
PageState: result.PageState(),
Values: adaptResultValues(result.Values()),
}, nil
}
}
}

func adaptResultValues(values []map[string]interface{}) []map[string]interface{} {
result := make([]map[string]interface{}, 0, len(values))
// TODO: Use naming conventions
for _, item := range values {
resultItem := make(map[string]interface{})
for k, v := range item {
resultItem[strcase.ToLowerCamel(k)] = v
}
result = append(result, resultItem)
}

return result
}

func mutationFieldResolver(keyspace *gocql.KeyspaceMetadata, dbClient *db.Db) graphql.FieldResolveFn {
return func(params graphql.ResolveParams) (interface{}, error) {
fieldName := params.Info.FieldName
Expand Down
18 changes: 8 additions & 10 deletions graphql/updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,15 @@ func NewUpdater(ksName string, dbClient *db.Db, updateInterval time.Duration) (*
func (su *SchemaUpdater) Start() {
su.ctx, su.cancel = context.WithCancel(context.Background())
for {
iter := su.dbClient.Execute("SELECT schema_version FROM system.local", nil)
result, err := su.dbClient.Execute("SELECT schema_version FROM system.local", nil)

if err != nil {
// TODO: Log error
fmt.Fprintf(os.Stderr, "error attempting to determine schema version: %s", err)
}

shouldUpdate := false
row := make(map[string]interface{})
for iter.MapScan(row) {
for _, row := range result.Values() {
if schemaVersion, ok := row["schema_version"].(gocql.UUID); ok {
if schemaVersion != su.schemaVersion {
shouldUpdate = true
Expand All @@ -62,12 +66,6 @@ func (su *SchemaUpdater) Start() {
}
}

err := iter.Close()
if err != nil {
// TODO: Log error
fmt.Fprintf(os.Stderr, "error attempting to determine schema version: %s", err)
}

if shouldUpdate {
schema, err := BuildSchema(su.ksName, su.dbClient)
if err != nil {
Expand Down Expand Up @@ -97,4 +95,4 @@ func (su *SchemaUpdater) sleep() bool {
case <-su.ctx.Done():
return false
}
}
}