Skip to content

Commit 9e635ca

Browse files
loganluluzhao
authored andcommitted
Add xray.AWSSession to install handlers on session (#97)
An application has to call xray.AWS for each AWS client it constructs. This creates opportunities for blind spots if someone forgets to configure a new client. The xray.AWSSession installs the same handlers at the Session level. Clients inherit handlers from the session they're created with. As long as the application systematically reuses the same session to create clients, it only needs to install X-Ray handlers once.
1 parent bf07cf7 commit 9e635ca

File tree

2 files changed

+106
-48
lines changed

2 files changed

+106
-48
lines changed

xray/aws.go

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121

2222
"github.com/aws/aws-sdk-go/aws/client"
2323
"github.com/aws/aws-sdk-go/aws/request"
24+
"github.com/aws/aws-sdk-go/aws/session"
2425
"github.com/aws/aws-xray-sdk-go/internal/logger"
2526
"github.com/aws/aws-xray-sdk-go/resources"
2627
)
@@ -137,33 +138,46 @@ var xRayAfterRetryHandler = request.NamedHandler{
137138
},
138139
}
139140

140-
func pushHandlers(c *client.Client) {
141-
c.Handlers.Validate.PushFrontNamed(xRayBeforeValidateHandler)
142-
c.Handlers.Build.PushBackNamed(xRayAfterBuildHandler)
143-
c.Handlers.Sign.PushFrontNamed(xRayBeforeSignHandler)
144-
c.Handlers.Send.PushBackNamed(xRayAfterSendHandler)
145-
c.Handlers.Unmarshal.PushFrontNamed(xRayBeforeUnmarshalHandler)
146-
c.Handlers.Unmarshal.PushBackNamed(xRayAfterUnmarshalHandler)
147-
c.Handlers.Retry.PushFrontNamed(xRayBeforeRetryHandler)
148-
c.Handlers.AfterRetry.PushBackNamed(xRayAfterRetryHandler)
141+
func pushHandlers(handlers *request.Handlers, completionWhitelistFilename string) {
142+
handlers.Validate.PushFrontNamed(xRayBeforeValidateHandler)
143+
handlers.Build.PushBackNamed(xRayAfterBuildHandler)
144+
handlers.Sign.PushFrontNamed(xRayBeforeSignHandler)
145+
handlers.Send.PushBackNamed(xRayAfterSendHandler)
146+
handlers.Unmarshal.PushFrontNamed(xRayBeforeUnmarshalHandler)
147+
handlers.Unmarshal.PushBackNamed(xRayAfterUnmarshalHandler)
148+
handlers.Retry.PushFrontNamed(xRayBeforeRetryHandler)
149+
handlers.AfterRetry.PushBackNamed(xRayAfterRetryHandler)
150+
handlers.Complete.PushFrontNamed(xrayCompleteHandler(completionWhitelistFilename))
149151
}
150152

151153
// AWS adds X-Ray tracing to an AWS client.
152154
func AWS(c *client.Client) {
153155
if c == nil {
154156
panic("Please initialize the provided AWS client before passing to the AWS() method.")
155157
}
156-
pushHandlers(c)
157-
c.Handlers.Complete.PushFrontNamed(xrayCompleteHandler(""))
158+
pushHandlers(&c.Handlers, "")
158159
}
159160

160161
// AWSWithWhitelist allows a custom parameter whitelist JSON file to be defined.
161162
func AWSWithWhitelist(c *client.Client, filename string) {
162163
if c == nil {
163164
panic("Please initialize the provided AWS client before passing to the AWSWithWhitelist() method.")
164165
}
165-
pushHandlers(c)
166-
c.Handlers.Complete.PushFrontNamed(xrayCompleteHandler(filename))
166+
pushHandlers(&c.Handlers, filename)
167+
}
168+
169+
// AWSSession adds X-Ray tracing to an AWS session. Clients created under this
170+
// session will inherit X-Ray tracing.
171+
func AWSSession(s *session.Session) *session.Session {
172+
pushHandlers(&s.Handlers, "")
173+
return s
174+
}
175+
176+
// AWSSessionWithWhitelist allows a custom parameter whitelist JSON file to be
177+
// defined.
178+
func AWSSessionWithWhitelist(s *session.Session, filename string) *session.Session {
179+
pushHandlers(&s.Handlers, filename)
180+
return s
167181
}
168182

169183
func xrayCompleteHandler(filename string) request.NamedHandler {

xray/aws_test.go

Lines changed: 79 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,86 @@ import (
1414
"github.com/stretchr/testify/assert"
1515
)
1616

17-
func TestClientSuccessfulConnection(t *testing.T) {
18-
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
19-
b := []byte(`{}`)
20-
w.WriteHeader(http.StatusOK)
21-
w.Write(b)
22-
}))
23-
24-
svc := lambda.New(session.Must(session.NewSession(&aws.Config{
25-
Endpoint: aws.String(ts.URL),
26-
Region: aws.String("fake-moon-1"),
27-
Credentials: credentials.NewStaticCredentials("akid", "secret", "noop")})))
17+
func TestAWS(t *testing.T) {
18+
// Runs a suite of tests against two different methods of registering
19+
// handlers on an AWS client.
20+
21+
type test func(*testing.T, *lambda.Lambda)
22+
tests := []struct {
23+
name string
24+
test test
25+
failConn bool
26+
}{
27+
{"failed connection", testClientFailedConnection, true},
28+
{"successful connection", testClientSuccessfulConnection, false},
29+
{"without segment", testClientWithoutSegment, false},
30+
}
2831

29-
ctx, root := BeginSegment(context.Background(), "Test")
32+
onClient := func(s *session.Session) *lambda.Lambda {
33+
svc := lambda.New(s)
34+
AWS(svc.Client)
35+
return svc
36+
}
3037

31-
AWS(svc.Client)
38+
onSession := func(s *session.Session) *lambda.Lambda {
39+
return lambda.New(AWSSession(s))
40+
}
41+
42+
const whitelist = "../resources/AWSWhitelist.json"
43+
44+
onClientWithWhitelist := func(s *session.Session) *lambda.Lambda {
45+
svc := lambda.New(s)
46+
AWSWithWhitelist(svc.Client, whitelist)
47+
return svc
48+
}
49+
50+
onSessionWithWhitelist := func(s *session.Session) *lambda.Lambda {
51+
return lambda.New(AWSSessionWithWhitelist(s, whitelist))
52+
}
53+
54+
type constructor func(*session.Session) *lambda.Lambda
55+
constructors := []struct {
56+
name string
57+
constructor constructor
58+
}{
59+
{"AWS()", onClient},
60+
{"AWSSession()", onSession},
61+
{"AWSWithWhitelist()", onClientWithWhitelist},
62+
{"AWSSessionWithWhitelist()", onSessionWithWhitelist},
63+
}
64+
65+
// Run all combinations of constructors + tests.
66+
for _, cons := range constructors {
67+
t.Run(cons.name, func(t *testing.T) {
68+
for _, test := range tests {
69+
t.Run(test.name, func(t *testing.T) {
70+
test.test(t, cons.constructor(fakeSession(t, test.failConn)))
71+
})
72+
}
73+
})
74+
}
75+
}
3276

77+
func fakeSession(t *testing.T, failConn bool) *session.Session {
78+
cfg := &aws.Config{
79+
Region: aws.String("fake-moon-1"),
80+
Credentials: credentials.NewStaticCredentials("akid", "secret", "noop"),
81+
}
82+
if !failConn {
83+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
84+
b := []byte(`{}`)
85+
w.WriteHeader(http.StatusOK)
86+
w.Write(b)
87+
}))
88+
cfg.Endpoint = aws.String(ts.URL)
89+
}
90+
s, err := session.NewSession(cfg)
91+
assert.NoError(t, err)
92+
return s
93+
}
94+
95+
func testClientSuccessfulConnection(t *testing.T, svc *lambda.Lambda) {
96+
ctx, root := BeginSegment(context.Background(), "Test")
3397
_, err := svc.ListFunctionsWithContext(ctx, &lambda.ListFunctionsInput{})
3498
root.Close(nil)
3599
assert.NoError(t, err)
@@ -76,15 +140,8 @@ func TestClientSuccessfulConnection(t *testing.T) {
76140
}
77141
}
78142

79-
func TestClientFailedConnection(t *testing.T) {
80-
svc := lambda.New(session.Must(session.NewSession(&aws.Config{
81-
Region: aws.String("fake-moon-1"),
82-
Credentials: credentials.NewStaticCredentials("akid", "secret", "noop")})))
83-
143+
func testClientFailedConnection(t *testing.T, svc *lambda.Lambda) {
84144
ctx, root := BeginSegment(context.Background(), "Test")
85-
86-
AWS(svc.Client)
87-
88145
_, err := svc.ListFunctionsWithContext(ctx, &lambda.ListFunctionsInput{})
89146
root.Close(nil)
90147
assert.Error(t, err)
@@ -116,24 +173,11 @@ func TestClientFailedConnection(t *testing.T) {
116173
assert.NotEmpty(t, connectSubseg.Subsegments)
117174
}
118175

119-
func TestClientWithoutSegment(t *testing.T) {
176+
func testClientWithoutSegment(t *testing.T, svc *lambda.Lambda) {
120177
Configure(Config{ContextMissingStrategy: &TestContextMissingStrategy{}})
121178
defer ResetConfig()
122-
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
123-
b := []byte(`{}`)
124-
w.WriteHeader(http.StatusOK)
125-
w.Write(b)
126-
}))
127-
128-
svc := lambda.New(session.Must(session.NewSession(&aws.Config{
129-
Endpoint: aws.String(ts.URL),
130-
Region: aws.String("fake-moon-1"),
131-
Credentials: credentials.NewStaticCredentials("akid", "secret", "noop")})))
132179

133180
ctx := context.Background()
134-
135-
AWS(svc.Client)
136-
137181
_, err := svc.ListFunctionsWithContext(ctx, &lambda.ListFunctionsInput{})
138182
assert.NoError(t, err)
139183
}

0 commit comments

Comments
 (0)