diff --git a/db/db.go b/db/db.go index 3bf3d0b..fa7bb96 100644 --- a/db/db.go +++ b/db/db.go @@ -42,7 +42,7 @@ func (db *Db) Keyspace(keyspace string) (*gocql.KeyspaceMetadata, error) { // Keyspaces Retrieves all the keyspace names func (db *Db) Keyspaces() ([]string, error) { - iter := db.session.ExecuteIterSimple("SELECT keyspace_name FROM system_schema.keyspaces", gocql.One) + iter := db.session.ExecuteIter("SELECT keyspace_name FROM system_schema.keyspaces", nil) var keyspaces []string diff --git a/db/db_session.go b/db/db_session.go index 0556dfc..2e52e80 100644 --- a/db/db_session.go +++ b/db/db_session.go @@ -2,26 +2,33 @@ package db import "github.com/gocql/gocql" -func (db *Db) Execute(query string, consistency gocql.Consistency, values ...interface{}) ResultIterator { - return db.session.ExecuteIter(query, consistency, values...) +type QueryOptions struct { + UserOrRole string + Consistency gocql.Consistency } -func (db *Db) ExecuteNoResult(query string, consistency gocql.Consistency, values ...interface{}) error { - return db.session.Execute(query, consistency, values) +func NewQueryOptions() *QueryOptions { + return &QueryOptions{ + Consistency: gocql.LocalOne, + } } -type DbSession interface { - // Execute executes a prepared statement without returning row results - Execute(query string, consistency gocql.Consistency, values ...interface{}) error +func (q *QueryOptions) WithUserOrRole(userOrRole string) *QueryOptions { + q.UserOrRole = userOrRole + return q +} - // Execute executes a simple statement without returning row results - ExecuteSimple(query string, consistency gocql.Consistency, values ...interface{}) error +func (q *QueryOptions) WithConsistency(userOrRole string) *QueryOptions { + q.UserOrRole = userOrRole + return q +} - // ExecuteIter executes a prepared statement and returns iterator to the result set - ExecuteIter(query string, consistency gocql.Consistency, values ...interface{}) ResultIterator +type DbSession interface { + // Execute executes a statement without returning row results + Execute(query string, options *QueryOptions, values ...interface{}) error - // ExecuteIterSimple executes a simple statement and returns iterator to the result set - ExecuteIterSimple(query string, consistency gocql.Consistency, values ...interface{}) ResultIterator + // ExecuteIterSimple executes a statement and returns iterator to the result set + ExecuteIter(query string, options *QueryOptions, values ...interface{}) ResultIterator //TODO: Extract metadata methods from interface into another interface KeyspaceMetadata(keyspaceName string) (*gocql.KeyspaceMetadata, error) @@ -40,21 +47,32 @@ type GoCqlSession struct { ref *gocql.Session } -func (session *GoCqlSession) Execute(query string, consistency gocql.Consistency, values ...interface{}) error { - return session.ref.Query(query).Bind(values...).Consistency(consistency).Exec() +func (db *Db) Execute(query string, options *QueryOptions, values ...interface{}) ResultIterator { + return db.session.ExecuteIter(query, options, values...) } -func (session *GoCqlSession) ExecuteSimple(query string, consistency gocql.Consistency, values ...interface{}) error { - return session.ref.Query(query, values...).Consistency(consistency).Exec() +func (db *Db) ExecuteNoResult(query string, options* QueryOptions, values ...interface{}) error { + return db.session.Execute(query, options, values) } -func (session *GoCqlSession) ExecuteIter(query string, consistency gocql.Consistency, values ...interface{}) ResultIterator { - return session.ref.Query(query).Bind(values...).Consistency(consistency).Iter() +func (session *GoCqlSession) Execute(query string, options *QueryOptions, values ...interface{}) error { + return session.ExecuteIter(query, options, values...).Close() } -func (session *GoCqlSession) ExecuteIterSimple(query string, consistency gocql.Consistency, values ...interface{}) ResultIterator { - return session.ref.Query(query, values...).Consistency(consistency).Iter() +func (session *GoCqlSession) ExecuteIter(query string, options *QueryOptions, values ...interface{}) ResultIterator { + q := session.ref.Query(query, values) + if options != nil { + q.Consistency(options.Consistency) + if options.UserOrRole != "" { + q.CustomPayload(map[string][]byte { + "ProxyExecute": []byte(options.UserOrRole), + }) + } + } + return session.ref.Query(query, values).Iter() } + func (session *GoCqlSession) KeyspaceMetadata(keyspaceName string) (*gocql.KeyspaceMetadata, error) { return session.ref.KeyspaceMetadata(keyspaceName) } + diff --git a/db/keyspace.go b/db/keyspace.go index 375f086..d411b69 100644 --- a/db/keyspace.go +++ b/db/keyspace.go @@ -2,10 +2,9 @@ package db import ( "fmt" - "github.com/gocql/gocql" ) -func (db *Db) CreateKeyspace(name string, dcReplicas map[string]int) (bool, error) { +func (db *Db) CreateKeyspace(name string, dcReplicas map[string]int, options *QueryOptions) (bool, error) { // TODO: Escape keyspace datacenter names? dcs := "" for name, replicas := range dcReplicas { @@ -18,15 +17,15 @@ func (db *Db) CreateKeyspace(name string, dcReplicas map[string]int) (bool, erro query := fmt.Sprintf("CREATE KEYSPACE %s WITH REPLICATION = { 'class': 'NetworkTopologyStrategy', %s }", name, dcs) - err := db.session.ExecuteSimple(query, gocql.Any) + err := db.session.Execute(query, options) return err == nil, err } -func (db *Db) DropKeyspace(name string) (bool, error) { +func (db *Db) DropKeyspace(name string, options *QueryOptions) (bool, error) { // TODO: Escape keyspace name? query := fmt.Sprintf("DROP KEYSPACE %s", name) - err := db.session.ExecuteSimple(query, gocql.Any) + err := db.session.Execute(query, options) return err == nil, err } diff --git a/db/query_generators.go b/db/query_generators.go index e0123e9..6f2576d 100644 --- a/db/query_generators.go +++ b/db/query_generators.go @@ -19,6 +19,24 @@ type SelectInfo struct { OrderBy []ColumnOrder } +type InsertInfo struct { + Keyspace string + Table string + Columns []string + QueryParams []interface{} + IfNotExists bool + TTL int +} + +type DeleteInfo struct { + Keyspace string + Table string + Columns []string + QueryParams []interface{} + IfCondition map[string]interface{} + IfExists bool +} + type ColumnOrder struct { Column string Order string @@ -74,7 +92,7 @@ func mapScan(scanner gocql.Scanner, columns []gocql.ColumnInfo) (map[string]inte return mapped, nil } -func (db *Db) Select(info *SelectInfo) (*types.QueryResult, error) { +func (db *Db) Select(info *SelectInfo, options *QueryOptions) (*types.QueryResult, error) { values := make([]interface{}, 0, len(info.Columns)) whereClause := "" for i := 0; i < len(info.Columns); i++ { @@ -104,7 +122,7 @@ func (db *Db) Select(info *SelectInfo) (*types.QueryResult, error) { } } - iter := db.session.ExecuteIter(query, gocql.LocalOne, values...) + iter := db.session.ExecuteIter(query, options, values...) pageState := hex.EncodeToString(iter.PageState()) columns := iter.Columns() @@ -130,38 +148,35 @@ func (db *Db) Select(info *SelectInfo) (*types.QueryResult, error) { }, nil } -func (db *Db) Insert(ksName string, tableName string, columnNames []string, - queryParams []interface{}, ifNotExists bool, ttl int) (*types.ModificationResult, error) { +func (db *Db) Insert(info *InsertInfo, options *QueryOptions) (*types.ModificationResult, error) { placeholders := "?" - for i := 1; i < len(columnNames); i++ { + for i := 1; i < len(info.Columns); i++ { placeholders += ", ?" } query := fmt.Sprintf( "INSERT INTO %s.%s (%s) VALUES (%s)", - ksName, tableName, strings.Join(columnNames, ","), placeholders) + info.Keyspace, info.Table, strings.Join(info.Columns, ","), placeholders) - if ifNotExists { + if info.IfNotExists { query += " IF NOT EXISTS" } - if ttl >= 0 { + if info.TTL >= 0 { query += " USING TTL ?" - queryParams = append(queryParams, ttl) + info.QueryParams = append(info.QueryParams, info.TTL) } - err := db.session.Execute(query, gocql.LocalOne, queryParams...) + err := db.session.Execute(query, options, info.QueryParams...) return &types.ModificationResult{Applied: err == nil}, err } -func (db *Db) Delete(ksName string, tableName string, columnNames []string, queryParams []interface{}, - ifCondition map[string]interface{}, ifExists bool) (*types.ModificationResult, error) { - - whereClause := buildWhereClause(columnNames) - query := fmt.Sprintf("DELETE FROM %s.%s WHERE %s", ksName, tableName, whereClause) - err := db.session.Execute(query, gocql.LocalOne, queryParams...) +func (db *Db) Delete(info *DeleteInfo, options *QueryOptions) (*types.ModificationResult, error) { + whereClause := buildWhereClause(info.Columns) + query := fmt.Sprintf("DELETE FROM %s.%s WHERE %s", info.Keyspace, info.Table, whereClause) + err := db.session.Execute(query, options, info.QueryParams...) return &types.ModificationResult{Applied: err == nil}, err } diff --git a/db/query_generators_test.go b/db/query_generators_test.go index d70e213..52b4da2 100644 --- a/db/query_generators_test.go +++ b/db/query_generators_test.go @@ -27,7 +27,11 @@ func TestDeleteGeneration(t *testing.T) { } for _, item := range items { - _, err := db.Delete("ks1", "tbl1", item.columnNames, item.queryParams, nil, false) + _, err := db.Delete(&DeleteInfo{ + Keyspace: "ks1", + Table: "tbl1", + Columns: item.columnNames, + QueryParams: item.queryParams}, nil) assert.Nil(t, err) sessionMock.AssertCalled(t, "Execute", item.query, consistency, item.queryParams) } @@ -47,11 +51,11 @@ func TestSelectGeneration(t *testing.T) { //items := []struct { // columnNames []string // values []types.OperatorAndValue - // options *types.QueryOptions + // options *types.ExecuteOptions // orderBy []ColumnOrder // query string //}{ - // {[]string{"a"}, []types.OperatorAndValue{{"=", 1}}, &types.QueryOptions{}, nil, + // {[]string{"a"}, []types.OperatorAndValue{{"=", 1}}, &types.ExecuteOptions{}, nil, // "SELECT * FROM ks1.tbl1 WHERE a = ?"}, //} // @@ -81,21 +85,12 @@ type SessionMock struct { mock.Mock } -func (o *SessionMock) Execute(query string, consistency gocql.Consistency, values ...interface{}) error { +func (o *SessionMock) Execute(query string, options *QueryOptions, values ...interface{}) error { args := o.Called(query, consistency, values) return args.Error(0) } -func (o *SessionMock) ExecuteSimple(query string, consistency gocql.Consistency, values ...interface{}) error { - args := o.Called(query, consistency, values) - return args.Error(0) -} - -func (o *SessionMock) ExecuteIter(query string, consistency gocql.Consistency, values ...interface{}) ResultIterator { - return nil -} - -func (o *SessionMock) ExecuteIterSimple(query string, consistency gocql.Consistency, values ...interface{}) ResultIterator { +func (o *SessionMock) ExecuteIter(query string, options *QueryOptions, values ...interface{}) ResultIterator { return nil } diff --git a/db/table.go b/db/table.go index 890ed9b..47882ef 100644 --- a/db/table.go +++ b/db/table.go @@ -5,15 +5,26 @@ import ( "github.com/gocql/gocql" ) -func (db *Db) CreateTable( - ksName string, name string, partitionKeys []*gocql.ColumnMetadata, - clusteringKeys []*gocql.ColumnMetadata, values []*gocql.ColumnMetadata) error { +type CreateTableInfo struct { + Keyspace string + Table string + PartitionKeys []*gocql.ColumnMetadata + ClusteringKeys []*gocql.ColumnMetadata + Values []*gocql.ColumnMetadata +} + +type DropTableInfo struct { + Keyspace string + Table string +} + +func (db *Db) CreateTable(info* CreateTableInfo, options *QueryOptions) (bool, error) { columns := "" primaryKeys := "" clusteringOrder := "" - for _, c := range partitionKeys { + for _, c := range info.PartitionKeys { columns += fmt.Sprintf("%s %s, ", c.Name, c.Type) if len(primaryKeys) > 0 { primaryKeys += ", " @@ -21,10 +32,10 @@ func (db *Db) CreateTable( primaryKeys += c.Name } - if clusteringKeys != nil { + if info.ClusteringKeys != nil { primaryKeys = fmt.Sprintf("(%s)", primaryKeys) - for _, c := range clusteringKeys { + for _, c := range info.ClusteringKeys { columns += fmt.Sprintf("%s %s, ", c.Name, c.Type) primaryKeys += fmt.Sprintf(", %s", c.Name) if len(clusteringOrder) > 0 { @@ -38,25 +49,25 @@ func (db *Db) CreateTable( } } - if values != nil { - for _, c := range values { + if info.Values != nil { + for _, c := range info.Values { columns += fmt.Sprintf("%s %s, ", c.Name, c.Type) } } - query := fmt.Sprintf("CREATE TABLE %s.%s (%sPRIMARY KEY (%s))", ksName, name, columns, primaryKeys) + query := fmt.Sprintf("CREATE TABLE %s.%s (%sPRIMARY KEY (%s))", info.Keyspace, info.Table, columns, primaryKeys) if clusteringOrder != "" { query += fmt.Sprintf(" WITH CLUSTERING ORDER BY (%s)", clusteringOrder) } - return db.session.ExecuteSimple(query, gocql.Any) + err := db.session.Execute(query, options) + return err == nil, err } -func (db *Db) DropTable(ksName string, tableName string) (bool, error) { +func (db *Db) DropTable(info* DropTableInfo, options *QueryOptions) (bool, error) { // TODO: Escape keyspace/table name? - query := fmt.Sprintf("DROP TABLE %s.%s", ksName, tableName) - err := db.session.ExecuteSimple(query, gocql.Any) - + query := fmt.Sprintf("DROP TABLE %s.%s", info.Table, info.Keyspace) + err := db.session.Execute(query, options) return err == nil, err } diff --git a/graphql/keyspace.go b/graphql/keyspace.go index cd2e803..735cc70 100644 --- a/graphql/keyspace.go +++ b/graphql/keyspace.go @@ -46,11 +46,11 @@ var keyspaceType = graphql.NewObject(graphql.ObjectConfig{ }, }) -func BuildKeyspaceSchema(db *db.Db) (graphql.Schema, error) { +func BuildKeyspaceSchema(dbClient *db.Db) (graphql.Schema, error) { return graphql.NewSchema( graphql.SchemaConfig{ - Query: buildKeyspaceQuery(db), - Mutation: buildKeyspaceMutation(db), + Query: buildKeyspaceQuery(dbClient), + Mutation: buildKeyspaceMutation(dbClient), }) } @@ -83,7 +83,7 @@ func buildKeyspaceValue(keyspace *gocql.KeyspaceMetadata) ksValue { return ksValue{keyspace.Name, dcs} } -func buildKeyspaceQuery(db *db.Db) *graphql.Object { +func buildKeyspaceQuery(dbClient *db.Db) *graphql.Object { return graphql.NewObject(graphql.ObjectConfig{ Name: "KeyspaceQuery", Fields: graphql.Fields{ @@ -96,7 +96,7 @@ func buildKeyspaceQuery(db *db.Db) *graphql.Object { }, Resolve: func(params graphql.ResolveParams) (interface{}, error) { ksName := params.Args["name"].(string) - keyspace, err := db.Keyspace(ksName) + keyspace, err := dbClient.Keyspace(ksName) if err != nil { return nil, err } @@ -107,14 +107,14 @@ func buildKeyspaceQuery(db *db.Db) *graphql.Object { "keyspaces": &graphql.Field{ Type: graphql.NewList(keyspaceType), Resolve: func(params graphql.ResolveParams) (interface{}, error) { - ksNames, err := db.Keyspaces() + ksNames, err := dbClient.Keyspaces() if err != nil { return nil, err } ksValues := make([]ksValue, 0) for _, ksName := range ksNames { - keyspace, err := db.Keyspace(ksName) + keyspace, err := dbClient.Keyspace(ksName) if err != nil { return nil, err } @@ -128,7 +128,7 @@ func buildKeyspaceQuery(db *db.Db) *graphql.Object { }) } -func buildKeyspaceMutation(db *db.Db) *graphql.Object { +func buildKeyspaceMutation(dbClient *db.Db) *graphql.Object { return graphql.NewObject(graphql.ObjectConfig{ Name: "KeyspaceMutation", Fields: graphql.Fields{ @@ -152,7 +152,11 @@ func buildKeyspaceMutation(db *db.Db) *graphql.Object { dcReplicas[dcReplica["name"].(string)] = dcReplica["replicas"].(int) } - return db.CreateKeyspace(ksName, dcReplicas) + userOrRole, err := checkAuthUserOrRole(params) + if err != nil { + return nil, err + } + return dbClient.CreateKeyspace(ksName, dcReplicas, db.NewQueryOptions().WithUserOrRole(userOrRole)) }, }, "dropKeyspace": &graphql.Field{ @@ -165,7 +169,11 @@ func buildKeyspaceMutation(db *db.Db) *graphql.Object { Resolve: func(params graphql.ResolveParams) (interface{}, error) { ksName := params.Args["name"].(string) - return db.DropKeyspace(ksName) + userOrRole, err := checkAuthUserOrRole(params) + if err != nil { + return nil, err + } + return dbClient.DropKeyspace(ksName, db.NewQueryOptions().WithUserOrRole(userOrRole)) }, }, }, diff --git a/graphql/routes.go b/graphql/routes.go index b9d0db1..c42efaa 100644 --- a/graphql/routes.go +++ b/graphql/routes.go @@ -1,6 +1,7 @@ package graphql import ( + "context" "encoding/json" "fmt" "github.com/graphql-go/graphql" @@ -16,7 +17,7 @@ var systemKeyspaces = []string{ "solr_admin", } -type executeQueryFunc func(query string) *graphql.Result +type executeQueryFunc func(query string, ctx context.Context) *graphql.Result type Route struct { Method string @@ -65,8 +66,8 @@ func RoutesKeyspaceManagement(pattern string, db *db.Db) ([]Route, error) { if err != nil { return nil, fmt.Errorf("unable to build graphql schema for keyspace management: %s", err) } - return routesForSchema(pattern, func(query string) *graphql.Result { - return executeQuery(query, schema) + return routesForSchema(pattern, func(query string, ctx context.Context) *graphql.Result { + return executeQuery(query, ctx, schema) }), nil } @@ -76,8 +77,8 @@ func RoutesKeyspace(pattern string, ksName string, db *db.Db, updateInterval tim return nil, fmt.Errorf("unable to build graphql schema for keyspace '%s': %s", ksName, err) } go updater.Start() - return routesForSchema(pattern, func(query string) *graphql.Result { - return executeQuery(query, *updater.Schema()) + return routesForSchema(pattern, func(query string, ctx context.Context) *graphql.Result { + return executeQuery(query, ctx, *updater.Schema()) }), nil } @@ -96,7 +97,7 @@ func routesForSchema(pattern string, execute executeQueryFunc) []Route { Method: http.MethodGet, Pattern: pattern, HandlerFunc: func(w http.ResponseWriter, r *http.Request) { - result:= execute(r.URL.Query().Get("query")) + result:= execute(r.URL.Query().Get("query"), r.Context()) json.NewEncoder(w).Encode(result) }, }, @@ -116,17 +117,18 @@ func routesForSchema(pattern string, execute executeQueryFunc) []Route { return } - result := execute(body.Query) + result := execute(body.Query, r.Context()) json.NewEncoder(w).Encode(result) }, }, } } -func executeQuery(query string, schema graphql.Schema) *graphql.Result { +func executeQuery(query string, ctx context.Context, schema graphql.Schema) *graphql.Result { result := graphql.Do(graphql.Params{ Schema: schema, RequestString: query, + Context: ctx, }) if len(result.Errors) > 0 { fmt.Printf("wrong result, unexpected errors: %v", result.Errors) diff --git a/graphql/schema.go b/graphql/schema.go index eec65c2..852a309 100644 --- a/graphql/schema.go +++ b/graphql/schema.go @@ -17,6 +17,8 @@ const insertPrefix = "insert" const deletePrefix = "delete" const updatePrefix = "update" +const AuthUserOrRole = "userOrRole" + func buildType(typeInfo gocql.TypeInfo) graphql.Output { switch typeInfo.Type() { case gocql.TypeInt, gocql.TypeTinyInt, gocql.TypeSmallInt: @@ -148,8 +150,8 @@ func buildMutation(schema *KeyspaceGraphQLSchema, tables map[string]*gocql.Table } // Build GraphQL schema for tables in the provided keyspace metadata -func BuildSchema(keyspaceName string, db *db.Db) (graphql.Schema, error) { - keyspace, err := db.Keyspace(keyspaceName) +func BuildSchema(keyspaceName string, dbClient *db.Db) (graphql.Schema, error) { + keyspace, err := dbClient.Keyspace(keyspaceName) if err != nil { return graphql.Schema{}, err } @@ -161,8 +163,8 @@ func BuildSchema(keyspaceName string, db *db.Db) (graphql.Schema, error) { return graphql.NewSchema( graphql.SchemaConfig{ - Query: buildQuery(keyspaceSchema, keyspace.Tables, queryFieldResolver(keyspace, db)), - Mutation: buildMutation(keyspaceSchema, keyspace.Tables, mutationFieldResolver(keyspace, db)), + Query: buildQuery(keyspaceSchema, keyspace.Tables, queryFieldResolver(keyspace, dbClient)), + Mutation: buildMutation(keyspaceSchema, keyspace.Tables, mutationFieldResolver(keyspace, dbClient)), }, ) } @@ -229,6 +231,10 @@ func queryFieldResolver(keyspace *gocql.KeyspaceMetadata, dbClient *db.Db) graph orderBy = params.Args["orderBy"].([]interface{}) } + userOrRole, err := checkAuthUserOrRole(params) + if err != nil { + return nil, err + } return dbClient.Select(&db.SelectInfo{ Keyspace: keyspace.Name, Table: table.Name, @@ -236,19 +242,19 @@ func queryFieldResolver(keyspace *gocql.KeyspaceMetadata, dbClient *db.Db) graph Values: queryParams, OrderBy: parseColumnOrder(orderBy), Options: &options, - }) + }, db.NewQueryOptions().WithUserOrRole(userOrRole)) } } } -func mutationFieldResolver(keyspace *gocql.KeyspaceMetadata, db *db.Db) graphql.FieldResolveFn { +func mutationFieldResolver(keyspace *gocql.KeyspaceMetadata, dbClient *db.Db) graphql.FieldResolveFn { return func(params graphql.ResolveParams) (interface{}, error) { fieldName := params.Info.FieldName switch fieldName { case "createTable": - return createTable(db, keyspace.Name, params.Args) + return createTable(dbClient, keyspace.Name, params) case "dropTable": - return dropTable(db, keyspace.Name, params.Args) + return dropTable(dbClient, keyspace.Name, params) default: operation, typeName := mutationPrefix(fieldName) if table, ok := keyspace.Tables[strcase.ToSnake(typeName)]; ok { @@ -267,6 +273,11 @@ func mutationFieldResolver(keyspace *gocql.KeyspaceMetadata, db *db.Db) graphql. options = params.Args["options"].(map[string]interface{}) } + userOrRole, err := checkAuthUserOrRole(params) + if err != nil { + return nil, err + } + queryOptions := db.NewQueryOptions().WithUserOrRole(userOrRole) switch operation { case insertPrefix: ttl := -1 @@ -274,14 +285,26 @@ func mutationFieldResolver(keyspace *gocql.KeyspaceMetadata, db *db.Db) graphql. ttl = options["ttl"].(int) } ifNotExists := params.Args["ifNotExists"] == true - return db.Insert(keyspace.Name, table.Name, columnNames, queryParams, ifNotExists, ttl) + return dbClient.Insert(&db.InsertInfo{ + Keyspace: keyspace.Name, + Table: table.Name, + Columns: columnNames, + QueryParams: queryParams, + IfNotExists: ifNotExists, + TTL: ttl, + }, queryOptions) case deletePrefix: var ifCondition map[string]interface{} if params.Args["ifCondition"] != nil { ifCondition = params.Args["ifCondition"].(map[string]interface{}) } - return db.Delete(keyspace.Name, table.Name, columnNames, - queryParams, ifCondition, params.Args["ifExists"] == true) + return dbClient.Delete(&db.DeleteInfo{ + Keyspace: keyspace.Name, + Table: table.Name, + Columns: columnNames, + QueryParams: queryParams, + IfCondition: ifCondition, + IfExists: params.Args["ifExists"] == true}, queryOptions) } return false, fmt.Errorf("operation '%s' not supported", operation) @@ -319,3 +342,12 @@ func parseColumnOrder(values []interface{}) []db.ColumnOrder { return result } + +func checkAuthUserOrRole(params graphql.ResolveParams) (string, error) { + // TODO: Return an error if we're expecting a user/role, but one isn't provided + value := params.Context.Value(AuthUserOrRole) + if value == nil { + return "", nil + } + return value.(string), nil +} diff --git a/graphql/table.go b/graphql/table.go index 254a580..b6c9e9e 100644 --- a/graphql/table.go +++ b/graphql/table.go @@ -247,9 +247,10 @@ func decodeClusteringInfo(columns []interface{}) ([]*gocql.ColumnMetadata, error return columnValues, nil } -func createTable(db *db.Db, ksName string, args map[string]interface{}) (interface{}, error) { +func createTable(dbClient *db.Db, ksName string, params graphql.ResolveParams) (interface{}, error) { var values []*gocql.ColumnMetadata = nil var clusteringKeys []*gocql.ColumnMetadata = nil + args := params.Args name := args["name"].(string) partitionKeys, err := decodeColumns(args["partitionKeys"].([]interface{})) @@ -270,15 +271,27 @@ func createTable(db *db.Db, ksName string, args map[string]interface{}) (interfa } } - if err := db.CreateTable(ksName, name, partitionKeys, clusteringKeys, values); err != nil { - return false, err + userOrRole, err := checkAuthUserOrRole(params) + if err != nil { + return nil, err } - - return true, nil + return dbClient.CreateTable(&db.CreateTableInfo{ + Keyspace: ksName, + Table: name, + PartitionKeys: partitionKeys, + ClusteringKeys: clusteringKeys, + Values: values}, db.NewQueryOptions().WithUserOrRole(userOrRole)) } -func dropTable(db *db.Db, ksName string, args map[string]interface{}) (interface{}, error) { - return db.DropTable(ksName, strcase.ToSnake(args["name"].(string))) +func dropTable(dbClient *db.Db, ksName string, params graphql.ResolveParams) (interface{}, error) { + name := strcase.ToSnake(params.Args["name"].(string)) + userOrRole, err := checkAuthUserOrRole(params) + if err != nil { + return nil, err + } + return dbClient.DropTable(&db.DropTableInfo{ + Keyspace: ksName, + Table: name}, db.NewQueryOptions().WithUserOrRole(userOrRole)) } func toColumnKind(kind gocql.ColumnKind) int { diff --git a/graphql/updater.go b/graphql/updater.go index 5320326..e795229 100644 --- a/graphql/updater.go +++ b/graphql/updater.go @@ -12,14 +12,14 @@ import ( ) type SchemaUpdater struct { - ctx context.Context - cancel context.CancelFunc - mutex sync.Mutex + ctx context.Context + cancel context.CancelFunc + mutex sync.Mutex updateInterval time.Duration - schema *graphql.Schema - ksName string - db *db.Db - schemaVersion gocql.UUID + schema *graphql.Schema + ksName string + dbClient *db.Db + schemaVersion gocql.UUID } func (su *SchemaUpdater) Schema() *graphql.Schema { @@ -29,19 +29,19 @@ func (su *SchemaUpdater) Schema() *graphql.Schema { return su.schema } -func NewUpdater(ksName string, db *db.Db, updateInterval time.Duration) (*SchemaUpdater, error) { - schema, err := BuildSchema(ksName, db) +func NewUpdater(ksName string, dbClient *db.Db, updateInterval time.Duration) (*SchemaUpdater, error) { + schema, err := BuildSchema(ksName, dbClient) if err != nil { return nil, err } updater := &SchemaUpdater{ - ctx: nil, - cancel: nil, - mutex: sync.Mutex{}, + ctx: nil, + cancel: nil, + mutex: sync.Mutex{}, updateInterval: updateInterval, - schema: &schema, - ksName: ksName, - db: db, + schema: &schema, + ksName: ksName, + dbClient: dbClient, } return updater, nil } @@ -49,7 +49,7 @@ func NewUpdater(ksName string, db *db.Db, updateInterval time.Duration) (*Schema func (su *SchemaUpdater) Start() { su.ctx, su.cancel = context.WithCancel(context.Background()) for { - iter := su.db.Execute("SELECT schema_version FROM system.local", gocql.LocalOne) + iter := su.dbClient.Execute("SELECT schema_version FROM system.local", nil) shouldUpdate := false row := make(map[string]interface{}) @@ -69,7 +69,7 @@ func (su *SchemaUpdater) Start() { } if shouldUpdate { - schema, err := BuildSchema(su.ksName, su.db) + schema, err := BuildSchema(su.ksName, su.dbClient) if err != nil { // TODO: Log error fmt.Fprintf(os.Stderr, "error trying to build graphql schema for keyspace '%s': %s", su.ksName, err)