Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ var reservedClaims = []string{
type Client struct {
*baseClient
TenantManager *TenantManager
signer cryptoSigner
clock internal.Clock
}

// NewClient creates a new instance of the Firebase Auth Client.
Expand Down Expand Up @@ -116,11 +114,11 @@ func NewClient(ctx context.Context, conf *internal.AuthConfig) (*Client, error)
httpClient: hc,
idTokenVerifier: idTokenVerifier,
cookieVerifier: cookieVerifier,
signer: signer,
clock: internal.SystemClock,
}
return &Client{
baseClient: base,
signer: signer,
clock: internal.SystemClock,
TenantManager: newTenantManager(hc, conf, base),
}, nil
}
Expand All @@ -144,13 +142,13 @@ func NewClient(ctx context.Context, conf *internal.AuthConfig) (*Client, error)
// conjunction with the IAM service to sign tokens remotely.
//
// CustomToken returns an error the SDK fails to discover a viable mechanism for signing tokens.
func (c *Client) CustomToken(ctx context.Context, uid string) (string, error) {
func (c *baseClient) CustomToken(ctx context.Context, uid string) (string, error) {
return c.CustomTokenWithClaims(ctx, uid, nil)
}

// CustomTokenWithClaims is similar to CustomToken, but in addition to the user ID, it also encodes
// all the key-value pairs in the provided map as claims in the resulting JWT.
func (c *Client) CustomTokenWithClaims(ctx context.Context, uid string, devClaims map[string]interface{}) (string, error) {
func (c *baseClient) CustomTokenWithClaims(ctx context.Context, uid string, devClaims map[string]interface{}) (string, error) {
iss, err := c.signer.Email(ctx)
if err != nil {
return "", err
Expand All @@ -176,13 +174,14 @@ func (c *Client) CustomTokenWithClaims(ctx context.Context, uid string, devClaim
info := &jwtInfo{
header: jwtHeader{Algorithm: "RS256", Type: "JWT"},
payload: &customToken{
Iss: iss,
Sub: iss,
Aud: firebaseAudience,
UID: uid,
Iat: now,
Exp: now + oneHourInSeconds,
Claims: devClaims,
Iss: iss,
Sub: iss,
Aud: firebaseAudience,
UID: uid,
Iat: now,
Exp: now + oneHourInSeconds,
TenantID: c.tenantID,
Claims: devClaims,
},
}
return info.Token(ctx, c.signer)
Expand Down Expand Up @@ -235,6 +234,8 @@ type baseClient struct {
httpClient *internal.HTTPClient
idTokenVerifier *tokenVerifier
cookieVerifier *tokenVerifier
signer cryptoSigner
clock internal.Clock
}

func (c *baseClient) withTenantID(tenantID string) *baseClient {
Expand Down
92 changes: 67 additions & 25 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,20 +282,26 @@ func TestNewClientExplicitNoAuth(t *testing.T) {

func TestCustomToken(t *testing.T) {
client := &Client{
signer: testSigner,
clock: testClock,
baseClient: &baseClient{
signer: testSigner,
clock: testClock,
},
}
token, err := client.CustomToken(context.Background(), "user1")
if err != nil {
t.Fatal(err)
}
verifyCustomToken(context.Background(), token, nil, t)
if err := verifyCustomToken(context.Background(), token, nil, ""); err != nil {
t.Fatal(err)
}
}

func TestCustomTokenWithClaims(t *testing.T) {
client := &Client{
signer: testSigner,
clock: testClock,
baseClient: &baseClient{
signer: testSigner,
clock: testClock,
},
}
claims := map[string]interface{}{
"foo": "bar",
Expand All @@ -306,19 +312,46 @@ func TestCustomTokenWithClaims(t *testing.T) {
if err != nil {
t.Fatal(err)
}
verifyCustomToken(context.Background(), token, claims, t)
if err := verifyCustomToken(context.Background(), token, claims, ""); err != nil {
t.Fatal(err)
}
}

func TestCustomTokenWithNilClaims(t *testing.T) {
client := &Client{
signer: testSigner,
clock: testClock,
baseClient: &baseClient{
signer: testSigner,
clock: testClock,
},
}
token, err := client.CustomTokenWithClaims(context.Background(), "user1", nil)
if err != nil {
t.Fatal(err)
}
verifyCustomToken(context.Background(), token, nil, t)
if err := verifyCustomToken(context.Background(), token, nil, ""); err != nil {
t.Fatal(err)
}
}

func TestCustomTokenForTenant(t *testing.T) {
client := &Client{
baseClient: &baseClient{
tenantID: "tenantID",
signer: testSigner,
clock: testClock,
},
}
claims := map[string]interface{}{
"foo": "bar",
"premium": true,
}
token, err := client.CustomTokenWithClaims(context.Background(), "user1", claims)
if err != nil {
t.Fatal(err)
}
if err := verifyCustomToken(context.Background(), token, claims, "tenantID"); err != nil {
t.Fatal(err)
}
}

func TestCustomTokenError(t *testing.T) {
Expand All @@ -333,7 +366,7 @@ func TestCustomTokenError(t *testing.T) {
{"ReservedClaims", "uid", map[string]interface{}{"sub": "1234", "aud": "foo"}},
}

client := &Client{
client := &baseClient{
signer: testSigner,
clock: testClock,
}
Expand Down Expand Up @@ -628,9 +661,9 @@ func TestCustomTokenVerification(t *testing.T) {
client := &Client{
baseClient: &baseClient{
idTokenVerifier: testIDTokenVerifier,
signer: testSigner,
clock: testClock,
},
signer: testSigner,
clock: testClock,
}
token, err := client.CustomToken(context.Background(), "user1")
if err != nil {
Expand Down Expand Up @@ -1137,52 +1170,61 @@ func checkBaseClient(client *Client, wantProjectID string) error {
return nil
}

func verifyCustomToken(ctx context.Context, token string, expected map[string]interface{}, t *testing.T) {
func verifyCustomToken(
ctx context.Context, token string, expected map[string]interface{}, tenantID string) error {

if err := testIDTokenVerifier.verifySignature(ctx, token); err != nil {
t.Fatal(err)
return err
}

var (
header jwtHeader
payload customToken
)
segments := strings.Split(token, ".")
if err := decode(segments[0], &header); err != nil {
t.Fatal(err)
return err
}
if err := decode(segments[1], &payload); err != nil {
t.Fatal(err)
return err
}

email, err := testSigner.Email(ctx)
if err != nil {
t.Fatal(err)
return err
}

if header.Algorithm != "RS256" {
t.Errorf("Algorithm: %q; want: 'RS256'", header.Algorithm)
return fmt.Errorf("Algorithm: %q; want: 'RS256'", header.Algorithm)
} else if header.Type != "JWT" {
t.Errorf("Type: %q; want: 'JWT'", header.Type)
return fmt.Errorf("Type: %q; want: 'JWT'", header.Type)
} else if payload.Aud != firebaseAudience {
t.Errorf("Audience: %q; want: %q", payload.Aud, firebaseAudience)
return fmt.Errorf("Audience: %q; want: %q", payload.Aud, firebaseAudience)
} else if payload.Iss != email {
t.Errorf("Issuer: %q; want: %q", payload.Iss, email)
return fmt.Errorf("Issuer: %q; want: %q", payload.Iss, email)
} else if payload.Sub != email {
t.Errorf("Subject: %q; want: %q", payload.Sub, email)
return fmt.Errorf("Subject: %q; want: %q", payload.Sub, email)
}

now := testClock.Now().Unix()
if payload.Exp != now+3600 {
t.Errorf("Exp: %d; want: %d", payload.Exp, now+3600)
return fmt.Errorf("Exp: %d; want: %d", payload.Exp, now+3600)
}
if payload.Iat != now {
t.Errorf("Iat: %d; want: %d", payload.Iat, now)
return fmt.Errorf("Iat: %d; want: %d", payload.Iat, now)
}

for k, v := range expected {
if payload.Claims[k] != v {
t.Errorf("Claim[%q]: %v; want: %v", k, payload.Claims[k], v)
return fmt.Errorf("Claim[%q]: %v; want: %v", k, payload.Claims[k], v)
}
}

if payload.TenantID != tenantID {
return fmt.Errorf("Tenant ID: %q; want: %q", payload.TenantID, tenantID)
}

return nil
}

func logFatal(err error) {
Expand Down
15 changes: 8 additions & 7 deletions auth/token_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,14 @@ type jwtHeader struct {
}

type customToken struct {
Iss string `json:"iss"`
Aud string `json:"aud"`
Exp int64 `json:"exp"`
Iat int64 `json:"iat"`
Sub string `json:"sub,omitempty"`
UID string `json:"uid,omitempty"`
Claims map[string]interface{} `json:"claims,omitempty"`
Iss string `json:"iss"`
Aud string `json:"aud"`
Exp int64 `json:"exp"`
Iat int64 `json:"iat"`
Sub string `json:"sub,omitempty"`
UID string `json:"uid,omitempty"`
TenantID string `json:"tenant_id,omitempty"`
Claims map[string]interface{} `json:"claims,omitempty"`
}

type jwtInfo struct {
Expand Down
13 changes: 11 additions & 2 deletions integration/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,19 @@ func verifyCustomToken(t *testing.T, ct, uid string) *auth.Token {
}

func signInWithCustomToken(token string) (string, error) {
req, err := json.Marshal(map[string]interface{}{
return signInWithCustomTokenForTenant(token, "")
}

func signInWithCustomTokenForTenant(token string, tenantID string) (string, error) {
payload := map[string]interface{}{
"token": token,
"returnSecureToken": true,
})
}
if tenantID != "" {
payload["tenantId"] = tenantID
}

req, err := json.Marshal(payload)
if err != nil {
return "", err
}
Expand Down
38 changes: 38 additions & 0 deletions integration/auth/tenant_mgt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ func TestTenantManager(t *testing.T) {
}
})

t.Run("CustomTokens", func(t *testing.T) {
testTenantAwareCustomToken(t, id)
})

t.Run("UserManagement", func(t *testing.T) {
testTenantAwareUserManagement(t, id)
})
Expand Down Expand Up @@ -154,6 +158,40 @@ func TestTenantManager(t *testing.T) {
})
}

func testTenantAwareCustomToken(t *testing.T, id string) {
tenantClient, err := client.TenantManager.AuthForTenant(id)
if err != nil {
t.Fatalf("AuthForTenant() = %v", err)
}

uid := randomUID()
ct, err := tenantClient.CustomToken(context.Background(), uid)
if err != nil {
t.Fatal(err)
}

idToken, err := signInWithCustomTokenForTenant(ct, id)
if err != nil {
t.Fatal(err)
}

defer func() {
tenantClient.DeleteUser(context.Background(), uid)
}()

vt, err := tenantClient.VerifyIDToken(context.Background(), idToken)
if err != nil {
t.Fatal(err)
}

if vt.UID != uid {
t.Errorf("UID = %q; want UID = %q", vt.UID, uid)
}
if vt.Firebase.Tenant != id {
t.Errorf("Tenant = %q; want = %q", vt.Firebase.Tenant, id)
}
}

func testTenantAwareUserManagement(t *testing.T, id string) {
tenantClient, err := client.TenantManager.AuthForTenant(id)
if err != nil {
Expand Down