Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,46 @@ Default: nil

The `queryTimeout` parameter sets a timeout for the query. If the query takes longer than the timeout, it will be cancelled. If it is not set the default context timeout will be used.


#### `roles`

```
Type: string
Format: roles=catalog1=ROLE{role1},catalog2=ROLE{role2}
Valid values: A comma-separated list of catalog-to-role assignments, where each assignment maps a catalog to a role.
Default: empty
```
The roles parameter defines authorization roles to assume for one or more catalogs during the Trino session.

You can assign roles either as a map of catalog-to-role pairs or a string. When a string is used, it applies the role to the `system` catalog by default.

##### Example

``` go
c := &Config{
ServerURI: "https://foobar@localhost:8090",
SessionProperties: map[string]string{"query_priority": "1"},
Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"},
}

dsn, err := c.FormatDSN()
// Result: https://foobar@localhost:8090?roles=catalog1%3DROLE%7B%22role1%22%7D%2Ccatalog2%3DROLE%7B%22role2%22%7D&session_properties=query_priority%3A1&source=trino-go-client
```

**Example using a string (applies to system catalog)**

``` go
c := &Config{
ServerURI: "https://foobar@localhost:8090",
SessionProperties: map[string]string{"query_priority": "1"},
Roles: "admin", // equivalent to map[string]string{"system": "admin"}
}

dsn, err := c.FormatDSN()
// Result: https://foobar@localhost:8090?roles=system%3DROLE%7B%22admin%22%7D&session_properties=query_priority%3A1&source=trino-go-client
```


#### Examples

```
Expand Down
79 changes: 79 additions & 0 deletions trino/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import (
"github.com/golang-jwt/jwt/v5"
dt "github.com/ory/dockertest/v3"
docker "github.com/ory/dockertest/v3/docker"
"github.com/stretchr/testify/require"
)

const (
Expand Down Expand Up @@ -1024,6 +1025,84 @@ func TestIntegrationNoResults(t *testing.T) {
t.Fatal(err)
}
}
func TestRoleHeaderSupport(t *testing.T) {
tests := []struct {
name string
config Config
rawDSN string
expectError bool
errorSubstr string
}{
{
name: "Valid roles via Config",
config: Config{
ServerURI: *integrationServerFlag,
Roles: map[string]string{"tpch": "role1", "memory": "role2"},
},
expectError: false,
},
{
name: "Valid single role via DSN",
rawDSN: *integrationServerFlag + "?roles=tpch%3DROLE%7Brole1%7D",
expectError: false,
},
{
name: "Non-existent catalog role",
config: Config{
ServerURI: *integrationServerFlag,
Roles: map[string]string{"not-exist-catalog": "role1"},
},
expectError: true,
errorSubstr: "USER_ERROR: Catalog 'not-exist-catalog' not found",
},
{
name: "Invalid role format with colon",
rawDSN: *integrationServerFlag + "?roles=not-exist-catalog%3Arole1",
expectError: true,
errorSubstr: "Invalid X-Trino-Role header",
},
{
name: "Invalid role format missing ROLE{}",
rawDSN: *integrationServerFlag + "?roles=catolog%3Drole1",
expectError: true,
errorSubstr: "Invalid X-Trino-Role header",
},
{
name: "Invalid role format missing ROLE{}",
rawDSN: *integrationServerFlag + "?roles=catolog%3Drole1",
expectError: true,
errorSubstr: "Invalid X-Trino-Role header",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var dns string
var err error

if tt.rawDSN != "" {
dns = tt.rawDSN
} else {
dns, err = tt.config.FormatDSN()
if err != nil {
t.Fatal(err)
}
}

db := integrationOpen(t, dns)
_, err = db.Query("SELECT 1")

if tt.expectError {
require.Error(t, err)
if tt.errorSubstr != "" {
require.Contains(t, err.Error(), tt.errorSubstr)
}
} else {
require.NoError(t, err)
}
})
}
}

func TestIntegrationQueryParametersSelect(t *testing.T) {
scenarios := []struct {
Expand Down
26 changes: 26 additions & 0 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ const (
trinoSetSessionHeader = trinoHeaderPrefix + `Set-Session`
trinoClearSessionHeader = trinoHeaderPrefix + `Clear-Session`
trinoSetRoleHeader = trinoHeaderPrefix + `Set-Role`
trinoRoleHeader = trinoHeaderPrefix + `Role`
trinoExtraCredentialHeader = trinoHeaderPrefix + `Extra-Credential`

trinoProgressCallbackParam = trinoHeaderPrefix + `Progress-Callback`
Expand Down Expand Up @@ -153,6 +154,9 @@ const (

mapKeySeparator = ":"
mapEntrySeparator = ";"
mapCommaSeparator = ","
mapRolesSeparator = "="
sistemRole = "system"
)

var (
Expand Down Expand Up @@ -194,6 +198,7 @@ type Config struct {
AccessToken string // An access token (JWT) for authentication (optional)
ForwardAuthorizationHeader bool // Allow forwarding the `accessToken` named query parameter in the authorization header, overwriting the `AccessToken` option, if set (optional)
QueryTimeout *time.Duration // Configurable timeout for query (optional)
Roles interface{} // Roles (optional)
}

// FormatDSN returns a DSN string from the configuration.
Expand All @@ -214,6 +219,20 @@ func (c *Config) FormatDSN() (string, error) {
credkv = append(credkv, k+mapKeySeparator+v)
}
}

var roles []string
if c.Roles != nil {
if v, ok := c.Roles.(string); ok {
roles = append(roles, sistemRole+mapRolesSeparator+fmt.Sprintf("ROLE{%q}", v))
} else if v, ok := c.Roles.(map[string]string); ok {
for k, v := range v {
roles = append(roles, k+mapRolesSeparator+fmt.Sprintf("ROLE{%q}", v))
}
} else {
return "", fmt.Errorf("Invalid roles type %T", c.Roles)
}
}

source := c.Source
if source == "" {
source = "trino-go-client"
Expand Down Expand Up @@ -284,6 +303,7 @@ func (c *Config) FormatDSN() (string, error) {
"extra_credentials": strings.Join(credkv, mapEntrySeparator),
"custom_client": c.CustomClientName,
accessTokenConfig: c.AccessToken,
"roles": strings.Join(roles, mapCommaSeparator),
} {
if v != "" {
query[k] = []string{v}
Expand All @@ -307,6 +327,7 @@ type Conn struct {
useExplicitPrepare bool
forwardAuthorizationHeader bool
queryTimeout *time.Duration
Roles string
}

var (
Expand Down Expand Up @@ -400,6 +421,7 @@ func newConn(dsn string) (*Conn, error) {
useExplicitPrepare: useExplicitPrepare,
forwardAuthorizationHeader: forwardAuthorizationHeader,
queryTimeout: queryTimeout,
Roles: query.Get("roles"),
}

var user string
Expand Down Expand Up @@ -931,6 +953,10 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
// Ensure the server returns timestamps preserving their precision, without truncating them to timestamp(3).
hs.Add("X-Trino-Client-Capabilities", "PARAMETRIC_DATETIME")

if st.conn.Roles != "" {
hs.Add(trinoRoleHeader, st.conn.Roles)
}

if len(args) > 0 {
var ss []string
for _, arg := range args {
Expand Down
65 changes: 65 additions & 0 deletions trino/trino_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,46 @@ func TestKerberosConfig(t *testing.T) {
assert.Equal(t, want, dsn)
}

func TestFormatDSNWithRoles(t *testing.T) {
tests := []struct {
name string
config *Config
wantDSN string
expectError bool
}{
{
name: "Multiple catalog roles",
config: &Config{
ServerURI: "https://foobar@localhost:8090",
SessionProperties: map[string]string{"query_priority": "1"},
Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"},
},
wantDSN: "https://foobar@localhost:8090?roles=catalog1%3DROLE%7B%22role1%22%7D%2Ccatalog2%3DROLE%7B%22role2%22%7D&session_properties=query_priority%3A1&source=trino-go-client",
},
{
name: "Default system role as string",
config: &Config{
ServerURI: "https://foobar@localhost:8090",
SessionProperties: map[string]string{"query_priority": "1"},
Roles: "role1",
},
wantDSN: "https://foobar@localhost:8090?roles=system%3DROLE%7B%22role1%22%7D&session_properties=query_priority%3A1&source=trino-go-client",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dsn, err := tt.config.FormatDSN()
if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tt.wantDSN, dsn)
}
})
}
}

func TestInvalidKerberosConfig(t *testing.T) {
c := &Config{
ServerURI: "http://foobar@localhost:8090",
Expand Down Expand Up @@ -1098,6 +1138,31 @@ func TestQueryCancellation(t *testing.T) {
assert.EqualError(t, err, ErrQueryCancelled.Error(), "unexpected error")
}

func TestTrinoRoleHeaderSent(t *testing.T) {
var receivedHeader string

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHeader = r.Header.Get(trinoRoleHeader)
}))
t.Cleanup(ts.Close)

c := &Config{
ServerURI: ts.URL,
SessionProperties: map[string]string{"query_priority": "1"},
Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"},
}

dsn, err := c.FormatDSN()
require.NoError(t, err)
db, err := sql.Open("trino", dsn)
require.NoError(t, err)

_, _ = db.Query("SHOW TABLES")
require.NoError(t, err)

assert.Equal(t, `catalog1=ROLE{"role1"},catalog2=ROLE{"role2"}`, receivedHeader, "expected X-Trino-Role header to be set")
}

func TestQueryFailure(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
Expand Down
Loading