Skip to content

Commit 40fc7ef

Browse files
committed
Add more tests for token client
1 parent 59dc349 commit 40fc7ef

File tree

2 files changed

+48
-6
lines changed

2 files changed

+48
-6
lines changed

client.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ var (
3636
HTTPClientTimeout = 30 * time.Second
3737
)
3838

39+
// DialTLS is the default dial function for creating TLS connections for
40+
// non-proxied HTTPS requests.
41+
var DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
42+
return tls.DialWithDialer(&net.Dialer{Timeout: TLSDialTimeout}, network, addr, cfg)
43+
}
44+
3945
// Client represents a connection with the APNs
4046
type Client struct {
4147
Host string
@@ -64,9 +70,7 @@ func NewClient(certificate tls.Certificate) *Client {
6470
}
6571
transport := &http2.Transport{
6672
TLSClientConfig: tlsConfig,
67-
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
68-
return tls.DialWithDialer(&net.Dialer{Timeout: TLSDialTimeout}, network, addr, cfg)
69-
},
73+
DialTLS: DialTLS,
7074
}
7175
return &Client{
7276
HTTPClient: &http.Client{
@@ -88,9 +92,7 @@ func NewClient(certificate tls.Certificate) *Client {
8892
// connection and disconnection as a denial-of-service attack.
8993
func NewTokenClient(token *token.Token) *Client {
9094
transport := &http2.Transport{
91-
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
92-
return tls.DialWithDialer(&net.Dialer{Timeout: TLSDialTimeout}, network, addr, cfg)
93-
},
95+
DialTLS: DialTLS,
9496
}
9597
return &Client{
9698
Token: token,

client_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package apns2_test
22

33
import (
4+
"crypto/ecdsa"
5+
"crypto/elliptic"
6+
"crypto/rand"
47
"crypto/tls"
58
"fmt"
69
"io/ioutil"
@@ -15,6 +18,7 @@ import (
1518

1619
apns "github.com/sideshow/apns2"
1720
"github.com/sideshow/apns2/certificate"
21+
"github.com/sideshow/apns2/token"
1822
"github.com/stretchr/testify/assert"
1923
)
2024

@@ -27,6 +31,12 @@ func mockNotification() *apns.Notification {
2731
return n
2832
}
2933

34+
func mockToken() *token.Token {
35+
pubkeyCurve := elliptic.P256()
36+
authKey, _ := ecdsa.GenerateKey(pubkeyCurve, rand.Reader)
37+
return &token.Token{AuthKey: authKey}
38+
}
39+
3040
func mockCert() tls.Certificate {
3141
return tls.Certificate{}
3242
}
@@ -42,16 +52,31 @@ func TestClientDefaultHost(t *testing.T) {
4252
assert.Equal(t, "https://api.development.push.apple.com", client.Host)
4353
}
4454

55+
func TestTokenDefaultHost(t *testing.T) {
56+
client := apns.NewTokenClient(mockToken()).Development()
57+
assert.Equal(t, "https://api.development.push.apple.com", client.Host)
58+
}
59+
4560
func TestClientDevelopmentHost(t *testing.T) {
4661
client := apns.NewClient(mockCert()).Development()
4762
assert.Equal(t, "https://api.development.push.apple.com", client.Host)
4863
}
4964

65+
func TestTokenClientDevelopmentHost(t *testing.T) {
66+
client := apns.NewTokenClient(mockToken()).Development()
67+
assert.Equal(t, "https://api.development.push.apple.com", client.Host)
68+
}
69+
5070
func TestClientProductionHost(t *testing.T) {
5171
client := apns.NewClient(mockCert()).Production()
5272
assert.Equal(t, "https://api.push.apple.com", client.Host)
5373
}
5474

75+
func TestTokenClientProductionHost(t *testing.T) {
76+
client := apns.NewTokenClient(mockToken()).Production()
77+
assert.Equal(t, "https://api.push.apple.com", client.Host)
78+
}
79+
5580
func TestClientBadUrlError(t *testing.T) {
5681
n := mockNotification()
5782
res, err := mockClient("badurl://badurl.com").Push(n)
@@ -150,6 +175,21 @@ func TestHeaders(t *testing.T) {
150175
assert.NoError(t, err)
151176
}
152177

178+
func TestAuthorizationHeader(t *testing.T) {
179+
n := mockNotification()
180+
token := mockToken()
181+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
182+
assert.Equal(t, "application/json; charset=utf-8", r.Header.Get("Content-Type"))
183+
assert.Equal(t, fmt.Sprintf("bearer %v", token.Bearer), r.Header.Get("authorization"))
184+
}))
185+
defer server.Close()
186+
187+
client := mockClient(server.URL)
188+
client.Token = token
189+
_, err := client.Push(n)
190+
assert.NoError(t, err)
191+
}
192+
153193
func TestPayload(t *testing.T) {
154194
n := mockNotification()
155195
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

0 commit comments

Comments
 (0)