From 10abfb8c3cc070bda494b0a65e916d263022da97 Mon Sep 17 00:00:00 2001 From: Jorge Bay Gondra Date: Thu, 19 Mar 2020 14:04:15 +0100 Subject: [PATCH] ResultSet interface to simplify mocking and testing --- db/db.go | 14 ++++---- db/db_session.go | 70 ++++++++++++++++++++++++++++--------- db/query_generators.go | 31 ++-------------- db/query_generators_test.go | 55 ++++++----------------------- graphql/schema.go | 26 +++++++++++++- graphql/updater.go | 18 +++++----- 6 files changed, 105 insertions(+), 109 deletions(-) diff --git a/db/db.go b/db/db.go index fa7bb96..d1b2c1e 100644 --- a/db/db.go +++ b/db/db.go @@ -42,16 +42,14 @@ func (db *Db) Keyspace(keyspace string) (*gocql.KeyspaceMetadata, error) { // Keyspaces Retrieves all the keyspace names func (db *Db) Keyspaces() ([]string, error) { - iter := db.session.ExecuteIter("SELECT keyspace_name FROM system_schema.keyspaces", nil) + iter, err := db.session.ExecuteIter("SELECT keyspace_name FROM system_schema.keyspaces", nil) + if err != nil { + return nil, err + } var keyspaces []string - - var name string - for iter.Scan(&name) { - keyspaces = append(keyspaces, name) - } - if err := iter.Close(); err != nil { - return nil, err + for _, row := range iter.Values() { + keyspaces = append(keyspaces, *row["keyspace_name"].(*string)) } return keyspaces, nil diff --git a/db/db_session.go b/db/db_session.go index 28fc07a..0eb3aee 100644 --- a/db/db_session.go +++ b/db/db_session.go @@ -1,9 +1,12 @@ package db -import "github.com/gocql/gocql" +import ( + "encoding/hex" + "github.com/gocql/gocql" +) type QueryOptions struct { - UserOrRole string + UserOrRole string Consistency gocql.Consistency } @@ -28,51 +31,84 @@ type DbSession interface { Execute(query string, options *QueryOptions, values ...interface{}) error // ExecuteIterSimple executes a statement and returns iterator to the result set - ExecuteIter(query string, options *QueryOptions, values ...interface{}) ResultIterator + ExecuteIter(query string, options *QueryOptions, values ...interface{}) (ResultSet, error) //TODO: Extract metadata methods from interface into another interface KeyspaceMetadata(keyspaceName string) (*gocql.KeyspaceMetadata, error) } -type ResultIterator interface { - Close() error - Columns() []gocql.ColumnInfo - Scanner() gocql.Scanner - PageState() []byte - Scan(dest ...interface{}) bool - MapScan(m map[string]interface{}) bool +type ResultSet interface { + PageState() string + Values() []map[string]interface{} +} + +func (r *goCqlResultIterator) PageState() string { + return hex.EncodeToString(r.pageState) +} + +func (r *goCqlResultIterator) Values() []map[string]interface{} { + return r.values +} + +type goCqlResultIterator struct { + pageState []byte + values []map[string]interface{} +} + +func newResultIterator(iter *gocql.Iter) (*goCqlResultIterator, error) { + columns := iter.Columns() + scanner := iter.Scanner() + + items := make([]map[string]interface{}, 0) + + for scanner.Next() { + row, err := mapScan(scanner, columns) + if err != nil { + return nil, err + } + items = append(items, row) + } + + if err := iter.Close(); err != nil { + return nil, err + } + + return &goCqlResultIterator{ + pageState: iter.PageState(), + values: items, + }, nil } type GoCqlSession struct { ref *gocql.Session } -func (db *Db) Execute(query string, options *QueryOptions, values ...interface{}) ResultIterator { +func (db *Db) Execute(query string, options *QueryOptions, values ...interface{}) (ResultSet, error) { return db.session.ExecuteIter(query, options, values...) } -func (db *Db) ExecuteNoResult(query string, options* QueryOptions, values ...interface{}) error { +func (db *Db) ExecuteNoResult(query string, options *QueryOptions, values ...interface{}) error { return db.session.Execute(query, options, values) } func (session *GoCqlSession) Execute(query string, options *QueryOptions, values ...interface{}) error { - return session.ExecuteIter(query, options, values...).Close() + _, err := session.ExecuteIter(query, options, values...) + return err } -func (session *GoCqlSession) ExecuteIter(query string, options *QueryOptions, values ...interface{}) ResultIterator { +func (session *GoCqlSession) ExecuteIter(query string, options *QueryOptions, values ...interface{}) (ResultSet, error) { q := session.ref.Query(query, values...) if options != nil { q.Consistency(options.Consistency) if options.UserOrRole != "" { - q.CustomPayload(map[string][]byte { + q.CustomPayload(map[string][]byte{ "ProxyExecute": []byte(options.UserOrRole), }) } } - return q.Iter() + return newResultIterator(q.Iter()) } func (session *GoCqlSession) KeyspaceMetadata(keyspaceName string) (*gocql.KeyspaceMetadata, error) { return session.ref.KeyspaceMetadata(keyspaceName) } - diff --git a/db/query_generators.go b/db/query_generators.go index de18a30..d4fabd9 100644 --- a/db/query_generators.go +++ b/db/query_generators.go @@ -1,10 +1,8 @@ package db import ( - "encoding/hex" "fmt" "github.com/gocql/gocql" - "github.com/iancoleman/strcase" "github.com/riptano/data-endpoints/types" "reflect" "strings" @@ -86,13 +84,13 @@ func mapScan(scanner gocql.Scanner, columns []gocql.ColumnInfo) (map[string]inte value = reflect.Indirect(reflect.ValueOf(value)).Interface() } - mapped[strcase.ToLowerCamel(column.Name)] = value + mapped[column.Name] = value } return mapped, nil } -func (db *Db) Select(info *SelectInfo, options *QueryOptions) (*types.QueryResult, error) { +func (db *Db) Select(info *SelectInfo, options *QueryOptions) (ResultSet, error) { values := make([]interface{}, 0, len(info.Columns)) whereClause := "" for i := 0; i < len(info.Columns); i++ { @@ -122,30 +120,7 @@ func (db *Db) Select(info *SelectInfo, options *QueryOptions) (*types.QueryResul } } - iter := db.session.ExecuteIter(query, options, values...) - - pageState := hex.EncodeToString(iter.PageState()) - columns := iter.Columns() - scanner := iter.Scanner() - - items := make([]map[string]interface{}, 0) - - for scanner.Next() { - row, err := mapScan(scanner, columns) - if err != nil { - return nil, err - } - items = append(items, row) - } - - if err := iter.Close(); err != nil { - return nil, err - } - - return &types.QueryResult{ - PageState: pageState, - Values: items, - }, nil + return db.session.ExecuteIter(query, options, values...) } func (db *Db) Insert(info *InsertInfo, options *QueryOptions) (*types.ModificationResult, error) { diff --git a/db/query_generators_test.go b/db/query_generators_test.go index 7791543..55f51c9 100644 --- a/db/query_generators_test.go +++ b/db/query_generators_test.go @@ -83,13 +83,9 @@ func TestInsertGeneration(t *testing.T) { func TestSelectGeneration(t *testing.T) { resultMock := &ResultMock{} - scannerMock := &ScannerMock{} resultMock. - On("PageState").Return([]byte{}). - On("Columns").Return([]gocql.ColumnInfo{}). - On("Scanner").Return(scannerMock). - On("Close").Return(nil) - scannerMock.On("Next").Return(false) + On("PageState").Return(""). + On("Values").Return([]map[string]interface{}{}, nil) items := []struct { columnNames []string @@ -113,7 +109,7 @@ func TestSelectGeneration(t *testing.T) { db := &Db{ session: &sessionMock, } - sessionMock.On("ExecuteIter", mock.Anything, mock.Anything, mock.Anything).Return(resultMock) + sessionMock.On("ExecuteIter", mock.Anything, mock.Anything, mock.Anything).Return(resultMock, nil) queryParams := make([]interface{}, 0) for _, v := range item.values { @@ -147,9 +143,9 @@ func (o *SessionMock) Execute(query string, options *QueryOptions, values ...int return args.Error(0) } -func (o *SessionMock) ExecuteIter(query string, options *QueryOptions, values ...interface{}) ResultIterator { +func (o *SessionMock) ExecuteIter(query string, options *QueryOptions, values ...interface{}) (ResultSet, error) { args := o.Called(query, options, values) - return args.Get(0).(ResultIterator) + return args.Get(0).(ResultSet), args.Error(1) } func (o *SessionMock) KeyspaceMetadata(keyspaceName string) (*gocql.KeyspaceMetadata, error) { @@ -161,42 +157,11 @@ type ResultMock struct { mock.Mock } -type ScannerMock struct { - mock.Mock -} - -func (o ScannerMock) Next() bool { - return o.Called().Bool(0) -} - -func (o ScannerMock) Scan(dest ...interface{}) error { - return o.Called(dest).Error(0) -} - -func (o ScannerMock) Err() error { - return o.Called().Error(0) -} - -func (o ResultMock) Close() error { - return o.Called().Error(0) -} - -func (o ResultMock) Columns() []gocql.ColumnInfo { - return o.Called().Get(0).([]gocql.ColumnInfo) -} - -func (o ResultMock) Scanner() gocql.Scanner { - return o.Called().Get(0).(gocql.Scanner) -} - -func (o ResultMock) PageState() []byte { - return o.Called().Get(0).([]byte) -} - -func (o ResultMock) Scan(dest ...interface{}) bool { - return o.Called(dest).Bool(0) +func (o ResultMock) PageState() string { + return o.Called().String(0) } -func (o ResultMock) MapScan(m map[string]interface{}) bool { - return o.Called(m).Bool(0) +func (o ResultMock) Values() []map[string]interface{} { + args := o.Called() + return args.Get(0).([]map[string]interface{}) } diff --git a/graphql/schema.go b/graphql/schema.go index 852a309..ac232ab 100644 --- a/graphql/schema.go +++ b/graphql/schema.go @@ -235,7 +235,8 @@ func queryFieldResolver(keyspace *gocql.KeyspaceMetadata, dbClient *db.Db) graph if err != nil { return nil, err } - return dbClient.Select(&db.SelectInfo{ + + result, err := dbClient.Select(&db.SelectInfo{ Keyspace: keyspace.Name, Table: table.Name, Columns: columnNames, @@ -243,10 +244,33 @@ func queryFieldResolver(keyspace *gocql.KeyspaceMetadata, dbClient *db.Db) graph OrderBy: parseColumnOrder(orderBy), Options: &options, }, db.NewQueryOptions().WithUserOrRole(userOrRole)) + + if err != nil { + return nil, err + } + + return &types.QueryResult{ + PageState: result.PageState(), + Values: adaptResultValues(result.Values()), + }, nil } } } +func adaptResultValues(values []map[string]interface{}) []map[string]interface{} { + result := make([]map[string]interface{}, 0, len(values)) + // TODO: Use naming conventions + for _, item := range values { + resultItem := make(map[string]interface{}) + for k, v := range item { + resultItem[strcase.ToLowerCamel(k)] = v + } + result = append(result, resultItem) + } + + return result +} + func mutationFieldResolver(keyspace *gocql.KeyspaceMetadata, dbClient *db.Db) graphql.FieldResolveFn { return func(params graphql.ResolveParams) (interface{}, error) { fieldName := params.Info.FieldName diff --git a/graphql/updater.go b/graphql/updater.go index e795229..3a5be94 100644 --- a/graphql/updater.go +++ b/graphql/updater.go @@ -49,11 +49,15 @@ func NewUpdater(ksName string, dbClient *db.Db, updateInterval time.Duration) (* func (su *SchemaUpdater) Start() { su.ctx, su.cancel = context.WithCancel(context.Background()) for { - iter := su.dbClient.Execute("SELECT schema_version FROM system.local", nil) + result, err := su.dbClient.Execute("SELECT schema_version FROM system.local", nil) + + if err != nil { + // TODO: Log error + fmt.Fprintf(os.Stderr, "error attempting to determine schema version: %s", err) + } shouldUpdate := false - row := make(map[string]interface{}) - for iter.MapScan(row) { + for _, row := range result.Values() { if schemaVersion, ok := row["schema_version"].(gocql.UUID); ok { if schemaVersion != su.schemaVersion { shouldUpdate = true @@ -62,12 +66,6 @@ func (su *SchemaUpdater) Start() { } } - err := iter.Close() - if err != nil { - // TODO: Log error - fmt.Fprintf(os.Stderr, "error attempting to determine schema version: %s", err) - } - if shouldUpdate { schema, err := BuildSchema(su.ksName, su.dbClient) if err != nil { @@ -97,4 +95,4 @@ func (su *SchemaUpdater) sleep() bool { case <-su.ctx.Done(): return false } -} \ No newline at end of file +}