Skip to content

Commit 40d81a6

Browse files
committed
Add executeAs to the DB layer
1 parent c64503d commit 40d81a6

11 files changed

+217
-124
lines changed

db/db.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func (db *Db) Keyspace(keyspace string) (*gocql.KeyspaceMetadata, error) {
4242

4343
// Keyspaces Retrieves all the keyspace names
4444
func (db *Db) Keyspaces() ([]string, error) {
45-
iter := db.session.ExecuteIterSimple("SELECT keyspace_name FROM system_schema.keyspaces", gocql.One)
45+
iter := db.session.ExecuteIter("SELECT keyspace_name FROM system_schema.keyspaces", nil)
4646

4747
var keyspaces []string
4848

db/db_session.go

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,33 @@ package db
22

33
import "github.com/gocql/gocql"
44

5-
func (db *Db) Execute(query string, consistency gocql.Consistency, values ...interface{}) ResultIterator {
6-
return db.session.ExecuteIter(query, consistency, values...)
5+
type QueryOptions struct {
6+
UserOrRole string
7+
Consistency gocql.Consistency
78
}
89

9-
func (db *Db) ExecuteNoResult(query string, consistency gocql.Consistency, values ...interface{}) error {
10-
return db.session.Execute(query, consistency, values)
10+
func NewQueryOptions() *QueryOptions {
11+
return &QueryOptions{
12+
Consistency: gocql.LocalOne,
13+
}
1114
}
1215

13-
type DbSession interface {
14-
// Execute executes a prepared statement without returning row results
15-
Execute(query string, consistency gocql.Consistency, values ...interface{}) error
16+
func (q *QueryOptions) WithUserOrRole(userOrRole string) *QueryOptions {
17+
q.UserOrRole = userOrRole
18+
return q
19+
}
1620

17-
// Execute executes a simple statement without returning row results
18-
ExecuteSimple(query string, consistency gocql.Consistency, values ...interface{}) error
21+
func (q *QueryOptions) WithConsistency(userOrRole string) *QueryOptions {
22+
q.UserOrRole = userOrRole
23+
return q
24+
}
1925

20-
// ExecuteIter executes a prepared statement and returns iterator to the result set
21-
ExecuteIter(query string, consistency gocql.Consistency, values ...interface{}) ResultIterator
26+
type DbSession interface {
27+
// Execute executes a statement without returning row results
28+
Execute(query string, options *QueryOptions, values ...interface{}) error
2229

23-
// ExecuteIterSimple executes a simple statement and returns iterator to the result set
24-
ExecuteIterSimple(query string, consistency gocql.Consistency, values ...interface{}) ResultIterator
30+
// ExecuteIterSimple executes a statement and returns iterator to the result set
31+
ExecuteIter(query string, options *QueryOptions, values ...interface{}) ResultIterator
2532

2633
//TODO: Extract metadata methods from interface into another interface
2734
KeyspaceMetadata(keyspaceName string) (*gocql.KeyspaceMetadata, error)
@@ -40,21 +47,32 @@ type GoCqlSession struct {
4047
ref *gocql.Session
4148
}
4249

43-
func (session *GoCqlSession) Execute(query string, consistency gocql.Consistency, values ...interface{}) error {
44-
return session.ref.Query(query).Bind(values...).Consistency(consistency).Exec()
50+
func (db *Db) Execute(query string, options *QueryOptions, values ...interface{}) ResultIterator {
51+
return db.session.ExecuteIter(query, options, values...)
4552
}
4653

47-
func (session *GoCqlSession) ExecuteSimple(query string, consistency gocql.Consistency, values ...interface{}) error {
48-
return session.ref.Query(query, values...).Consistency(consistency).Exec()
54+
func (db *Db) ExecuteNoResult(query string, options* QueryOptions, values ...interface{}) error {
55+
return db.session.Execute(query, options, values)
4956
}
5057

51-
func (session *GoCqlSession) ExecuteIter(query string, consistency gocql.Consistency, values ...interface{}) ResultIterator {
52-
return session.ref.Query(query).Bind(values...).Consistency(consistency).Iter()
58+
func (session *GoCqlSession) Execute(query string, options *QueryOptions, values ...interface{}) error {
59+
return session.ExecuteIter(query, options, values...).Close()
5360
}
5461

55-
func (session *GoCqlSession) ExecuteIterSimple(query string, consistency gocql.Consistency, values ...interface{}) ResultIterator {
56-
return session.ref.Query(query, values...).Consistency(consistency).Iter()
62+
func (session *GoCqlSession) ExecuteIter(query string, options *QueryOptions, values ...interface{}) ResultIterator {
63+
q := session.ref.Query(query, values...)
64+
if options != nil {
65+
q.Consistency(options.Consistency)
66+
if options.UserOrRole != "" {
67+
q.CustomPayload(map[string][]byte {
68+
"ProxyExecute": []byte(options.UserOrRole),
69+
})
70+
}
71+
}
72+
return q.Iter()
5773
}
74+
5875
func (session *GoCqlSession) KeyspaceMetadata(keyspaceName string) (*gocql.KeyspaceMetadata, error) {
5976
return session.ref.KeyspaceMetadata(keyspaceName)
6077
}
78+

db/keyspace.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@ package db
22

33
import (
44
"fmt"
5-
"github.com/gocql/gocql"
65
)
76

8-
func (db *Db) CreateKeyspace(name string, dcReplicas map[string]int) (bool, error) {
7+
func (db *Db) CreateKeyspace(name string, dcReplicas map[string]int, options *QueryOptions) (bool, error) {
98
// TODO: Escape keyspace datacenter names?
109
dcs := ""
1110
for name, replicas := range dcReplicas {
@@ -18,15 +17,15 @@ func (db *Db) CreateKeyspace(name string, dcReplicas map[string]int) (bool, erro
1817

1918
query := fmt.Sprintf("CREATE KEYSPACE %s WITH REPLICATION = { 'class': 'NetworkTopologyStrategy', %s }", name, dcs)
2019

21-
err := db.session.ExecuteSimple(query, gocql.Any)
20+
err := db.session.Execute(query, options)
2221

2322
return err == nil, err
2423
}
2524

26-
func (db *Db) DropKeyspace(name string) (bool, error) {
25+
func (db *Db) DropKeyspace(name string, options *QueryOptions) (bool, error) {
2726
// TODO: Escape keyspace name?
2827
query := fmt.Sprintf("DROP KEYSPACE %s", name)
29-
err := db.session.ExecuteSimple(query, gocql.Any)
28+
err := db.session.Execute(query, options)
3029

3130
return err == nil, err
3231
}

db/query_generators.go

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,24 @@ type SelectInfo struct {
1919
OrderBy []ColumnOrder
2020
}
2121

22+
type InsertInfo struct {
23+
Keyspace string
24+
Table string
25+
Columns []string
26+
QueryParams []interface{}
27+
IfNotExists bool
28+
TTL int
29+
}
30+
31+
type DeleteInfo struct {
32+
Keyspace string
33+
Table string
34+
Columns []string
35+
QueryParams []interface{}
36+
IfCondition map[string]interface{}
37+
IfExists bool
38+
}
39+
2240
type ColumnOrder struct {
2341
Column string
2442
Order string
@@ -74,7 +92,7 @@ func mapScan(scanner gocql.Scanner, columns []gocql.ColumnInfo) (map[string]inte
7492
return mapped, nil
7593
}
7694

77-
func (db *Db) Select(info *SelectInfo) (*types.QueryResult, error) {
95+
func (db *Db) Select(info *SelectInfo, options *QueryOptions) (*types.QueryResult, error) {
7896
values := make([]interface{}, 0, len(info.Columns))
7997
whereClause := ""
8098
for i := 0; i < len(info.Columns); i++ {
@@ -104,7 +122,7 @@ func (db *Db) Select(info *SelectInfo) (*types.QueryResult, error) {
104122
}
105123
}
106124

107-
iter := db.session.ExecuteIter(query, gocql.LocalOne, values...)
125+
iter := db.session.ExecuteIter(query, options, values...)
108126

109127
pageState := hex.EncodeToString(iter.PageState())
110128
columns := iter.Columns()
@@ -130,38 +148,35 @@ func (db *Db) Select(info *SelectInfo) (*types.QueryResult, error) {
130148
}, nil
131149
}
132150

133-
func (db *Db) Insert(ksName string, tableName string, columnNames []string,
134-
queryParams []interface{}, ifNotExists bool, ttl int) (*types.ModificationResult, error) {
151+
func (db *Db) Insert(info *InsertInfo, options *QueryOptions) (*types.ModificationResult, error) {
135152

136153
placeholders := "?"
137-
for i := 1; i < len(columnNames); i++ {
154+
for i := 1; i < len(info.Columns); i++ {
138155
placeholders += ", ?"
139156
}
140157

141158
query := fmt.Sprintf(
142159
"INSERT INTO %s.%s (%s) VALUES (%s)",
143-
ksName, tableName, strings.Join(columnNames, ","), placeholders)
160+
info.Keyspace, info.Table, strings.Join(info.Columns, ","), placeholders)
144161

145-
if ifNotExists {
162+
if info.IfNotExists {
146163
query += " IF NOT EXISTS"
147164
}
148165

149-
if ttl >= 0 {
166+
if info.TTL >= 0 {
150167
query += " USING TTL ?"
151-
queryParams = append(queryParams, ttl)
168+
info.QueryParams = append(info.QueryParams, info.TTL)
152169
}
153170

154-
err := db.session.Execute(query, gocql.LocalOne, queryParams...)
171+
err := db.session.Execute(query, options, info.QueryParams...)
155172

156173
return &types.ModificationResult{Applied: err == nil}, err
157174
}
158175

159-
func (db *Db) Delete(ksName string, tableName string, columnNames []string, queryParams []interface{},
160-
ifCondition map[string]interface{}, ifExists bool) (*types.ModificationResult, error) {
161-
162-
whereClause := buildWhereClause(columnNames)
163-
query := fmt.Sprintf("DELETE FROM %s.%s WHERE %s", ksName, tableName, whereClause)
164-
err := db.session.Execute(query, gocql.LocalOne, queryParams...)
176+
func (db *Db) Delete(info *DeleteInfo, options *QueryOptions) (*types.ModificationResult, error) {
177+
whereClause := buildWhereClause(info.Columns)
178+
query := fmt.Sprintf("DELETE FROM %s.%s WHERE %s", info.Keyspace, info.Table, whereClause)
179+
err := db.session.Execute(query, options, info.QueryParams...)
165180
return &types.ModificationResult{Applied: err == nil}, err
166181
}
167182

db/query_generators_test.go

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ func TestDeleteGeneration(t *testing.T) {
2727
}
2828

2929
for _, item := range items {
30-
_, err := db.Delete("ks1", "tbl1", item.columnNames, item.queryParams, nil, false)
30+
_, err := db.Delete(&DeleteInfo{
31+
Keyspace: "ks1",
32+
Table: "tbl1",
33+
Columns: item.columnNames,
34+
QueryParams: item.queryParams}, nil)
3135
assert.Nil(t, err)
3236
sessionMock.AssertCalled(t, "Execute", item.query, consistency, item.queryParams)
3337
}
@@ -47,11 +51,11 @@ func TestSelectGeneration(t *testing.T) {
4751
//items := []struct {
4852
// columnNames []string
4953
// values []types.OperatorAndValue
50-
// options *types.QueryOptions
54+
// options *types.ExecuteOptions
5155
// orderBy []ColumnOrder
5256
// query string
5357
//}{
54-
// {[]string{"a"}, []types.OperatorAndValue{{"=", 1}}, &types.QueryOptions{}, nil,
58+
// {[]string{"a"}, []types.OperatorAndValue{{"=", 1}}, &types.ExecuteOptions{}, nil,
5559
// "SELECT * FROM ks1.tbl1 WHERE a = ?"},
5660
//}
5761
//
@@ -81,21 +85,12 @@ type SessionMock struct {
8185
mock.Mock
8286
}
8387

84-
func (o *SessionMock) Execute(query string, consistency gocql.Consistency, values ...interface{}) error {
88+
func (o *SessionMock) Execute(query string, options *QueryOptions, values ...interface{}) error {
8589
args := o.Called(query, consistency, values)
8690
return args.Error(0)
8791
}
8892

89-
func (o *SessionMock) ExecuteSimple(query string, consistency gocql.Consistency, values ...interface{}) error {
90-
args := o.Called(query, consistency, values)
91-
return args.Error(0)
92-
}
93-
94-
func (o *SessionMock) ExecuteIter(query string, consistency gocql.Consistency, values ...interface{}) ResultIterator {
95-
return nil
96-
}
97-
98-
func (o *SessionMock) ExecuteIterSimple(query string, consistency gocql.Consistency, values ...interface{}) ResultIterator {
93+
func (o *SessionMock) ExecuteIter(query string, options *QueryOptions, values ...interface{}) ResultIterator {
9994
return nil
10095
}
10196

db/table.go

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,37 @@ import (
55
"github.com/gocql/gocql"
66
)
77

8-
func (db *Db) CreateTable(
9-
ksName string, name string, partitionKeys []*gocql.ColumnMetadata,
10-
clusteringKeys []*gocql.ColumnMetadata, values []*gocql.ColumnMetadata) error {
8+
type CreateTableInfo struct {
9+
Keyspace string
10+
Table string
11+
PartitionKeys []*gocql.ColumnMetadata
12+
ClusteringKeys []*gocql.ColumnMetadata
13+
Values []*gocql.ColumnMetadata
14+
}
15+
16+
type DropTableInfo struct {
17+
Keyspace string
18+
Table string
19+
}
20+
21+
func (db *Db) CreateTable(info* CreateTableInfo, options *QueryOptions) (bool, error) {
1122

1223
columns := ""
1324
primaryKeys := ""
1425
clusteringOrder := ""
1526

16-
for _, c := range partitionKeys {
27+
for _, c := range info.PartitionKeys {
1728
columns += fmt.Sprintf("%s %s, ", c.Name, c.Type)
1829
if len(primaryKeys) > 0 {
1930
primaryKeys += ", "
2031
}
2132
primaryKeys += c.Name
2233
}
2334

24-
if clusteringKeys != nil {
35+
if info.ClusteringKeys != nil {
2536
primaryKeys = fmt.Sprintf("(%s)", primaryKeys)
2637

27-
for _, c := range clusteringKeys {
38+
for _, c := range info.ClusteringKeys {
2839
columns += fmt.Sprintf("%s %s, ", c.Name, c.Type)
2940
primaryKeys += fmt.Sprintf(", %s", c.Name)
3041
if len(clusteringOrder) > 0 {
@@ -38,25 +49,25 @@ func (db *Db) CreateTable(
3849
}
3950
}
4051

41-
if values != nil {
42-
for _, c := range values {
52+
if info.Values != nil {
53+
for _, c := range info.Values {
4354
columns += fmt.Sprintf("%s %s, ", c.Name, c.Type)
4455
}
4556
}
4657

47-
query := fmt.Sprintf("CREATE TABLE %s.%s (%sPRIMARY KEY (%s))", ksName, name, columns, primaryKeys)
58+
query := fmt.Sprintf("CREATE TABLE %s.%s (%sPRIMARY KEY (%s))", info.Keyspace, info.Table, columns, primaryKeys)
4859

4960
if clusteringOrder != "" {
5061
query += fmt.Sprintf(" WITH CLUSTERING ORDER BY (%s)", clusteringOrder)
5162
}
5263

53-
return db.session.ExecuteSimple(query, gocql.Any)
64+
err := db.session.Execute(query, options)
65+
return err == nil, err
5466
}
5567

56-
func (db *Db) DropTable(ksName string, tableName string) (bool, error) {
68+
func (db *Db) DropTable(info* DropTableInfo, options *QueryOptions) (bool, error) {
5769
// TODO: Escape keyspace/table name?
58-
query := fmt.Sprintf("DROP TABLE %s.%s", ksName, tableName)
59-
err := db.session.ExecuteSimple(query, gocql.Any)
60-
70+
query := fmt.Sprintf("DROP TABLE %s.%s", info.Table, info.Keyspace)
71+
err := db.session.Execute(query, options)
6172
return err == nil, err
6273
}

0 commit comments

Comments
 (0)