Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions cmd/optimizely/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (

"github.com/optimizely/agent/config"
"github.com/optimizely/agent/pkg/metrics"
"github.com/optimizely/agent/pkg/middleware"
"github.com/optimizely/agent/pkg/optimizely"
"github.com/optimizely/agent/pkg/routers"
"github.com/optimizely/agent/pkg/server"
Expand Down Expand Up @@ -157,17 +158,18 @@ func getStdOutTraceProvider(conf config.TracingExporterConfig) (*sdktrace.Tracer
return sdktrace.NewTracerProvider(
sdktrace.WithBatcher(exp),
sdktrace.WithResource(res),
sdktrace.WithIDGenerator(middleware.NewTraceIDGenerator()),
), nil
}

func getOTELTraceClient(conf config.TracingExporterConfig) (otlptrace.Client, error) {
switch conf.Services.Remote.Protocal {
case config.TracingRemoteProtocalHTTP:
switch conf.Services.Remote.Protocol {
case config.TracingRemoteProtocolHTTP:
return otlptracehttp.NewClient(
otlptracehttp.WithInsecure(),
otlptracehttp.WithEndpoint(conf.Services.Remote.Endpoint),
), nil
case config.TracingRemoteProtocalGRPC:
case config.TracingRemoteProtocolGRPC:
return otlptracegrpc.NewClient(
otlptracegrpc.WithInsecure(),
otlptracegrpc.WithEndpoint(conf.Services.Remote.Endpoint),
Expand Down Expand Up @@ -204,6 +206,7 @@ func getRemoteTraceProvider(conf config.TracingExporterConfig) (*sdktrace.Tracer
sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(conf.Services.Remote.SampleRate))),
sdktrace.WithResource(res),
sdktrace.WithSpanProcessor(bsp),
sdktrace.WithIDGenerator(middleware.NewTraceIDGenerator()),
), nil
}

Expand Down
6 changes: 3 additions & 3 deletions cmd/optimizely/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ func Test_initTracing(t *testing.T) {
Services: config.TracingServiceConfig{
Remote: config.TracingRemoteConfig{
Endpoint: "localhost:1234",
Protocal: "http",
Protocol: "http",
},
},
},
Expand All @@ -474,7 +474,7 @@ func Test_initTracing(t *testing.T) {
Services: config.TracingServiceConfig{
Remote: config.TracingRemoteConfig{
Endpoint: "localhost:1234",
Protocal: "grpc",
Protocol: "grpc",
},
},
},
Expand All @@ -489,7 +489,7 @@ func Test_initTracing(t *testing.T) {
Services: config.TracingServiceConfig{
Remote: config.TracingRemoteConfig{
Endpoint: "localhost:1234",
Protocal: "udp/invalid",
Protocol: "udp/invalid",
},
},
},
Expand Down
4 changes: 2 additions & 2 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ tracing:
remote:
## remote collector endpoint
endpoint: "localhost:4317"
## supported protocals are "http" and "grpc"
protocal: "grpc"
## supported protocols are "http" and "grpc"
protocol: "grpc"
## "sampleRate" refers to the rate at which traces are collected and recorded.
## sampleRate >= 1 will always sample.
## sampleRate < 0 are treated as zero i.e. never sample.
Expand Down
8 changes: 4 additions & 4 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,11 @@ const (
TracingServiceTypeRemote TracingServiceType = "remote"
)

type TracingRemoteProtocal string
type TracingRemoteProtocol string

const (
TracingRemoteProtocalGRPC TracingRemoteProtocal = "grpc"
TracingRemoteProtocalHTTP TracingRemoteProtocal = "http"
TracingRemoteProtocolGRPC TracingRemoteProtocol = "grpc"
TracingRemoteProtocolHTTP TracingRemoteProtocol = "http"
)

type TracingExporterConfig struct {
Expand All @@ -242,7 +242,7 @@ type TracingStdOutConfig struct {

type TracingRemoteConfig struct {
Endpoint string `json:"endpoint"`
Protocal TracingRemoteProtocal `json:"protocal"`
Protocol TracingRemoteProtocol `json:"protocol"`
SampleRate float64 `json:"sampleRate"`
}

Expand Down
5 changes: 5 additions & 0 deletions pkg/middleware/cached.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ const OptlyUPSHeader = "X-Optimizely-UPS-Name"
// OptlyODPCacheHeader is the header key for an ad-hoc ODP Cache name
const OptlyODPCacheHeader = "X-Optimizely-ODP-Cache-Name"

// OptlyTraceIDHeader is the header key for trace-id in distributed tracing.
// The value set in HTTP Header must be a hex compliant with the W3C trace-context specification.
// See more at https://www.w3.org/TR/trace-context/#trace-id
const OptlyTraceIDHeader = "X-Optimizely-Trace-ID"

// CachedOptlyMiddleware implements OptlyMiddleware backed by a cache
type CachedOptlyMiddleware struct {
Cache optimizely.Cache
Expand Down
58 changes: 57 additions & 1 deletion pkg/middleware/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,65 @@
package middleware

import (
"context"
crand "crypto/rand"
"encoding/binary"
"math/rand"
"net/http"
"sync"

"github.com/rs/zerolog/log"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
"go.opentelemetry.io/otel/trace"
)

type traceIDGenerator struct {
sync.Mutex
randSource *rand.Rand
}

func NewTraceIDGenerator() *traceIDGenerator {
var rngSeed int64
_ = binary.Read(crand.Reader, binary.LittleEndian, &rngSeed)
return &traceIDGenerator{
randSource: rand.New(rand.NewSource(rngSeed)),
}
}

func (gen *traceIDGenerator) NewSpanID(ctx context.Context, traceID trace.TraceID) trace.SpanID {
gen.Lock()
defer gen.Unlock()
sid := trace.SpanID{}
_, _ = gen.randSource.Read(sid[:])
return sid
}

func (gen *traceIDGenerator) NewIDs(ctx context.Context) (trace.TraceID, trace.SpanID) {
gen.Lock()
defer gen.Unlock()
tid := trace.TraceID{}
_, _ = gen.randSource.Read(tid[:])
sid := trace.SpanID{}
_, _ = gen.randSource.Read(sid[:])

// read trace id from header if provided
traceIDHeader := ctx.Value(OptlyTraceIDHeader)
if val, ok := traceIDHeader.(string); ok {
if val != "" {
headerTraceId, err := trace.TraceIDFromHex(val)
if err == nil {
tid = headerTraceId
} else {
log.Error().Err(err).Msg("failed to parse trace id from header, invalid trace id")
}
}
}

return tid, sid
}

type statusRecorder struct {
http.ResponseWriter
statusCode int
Expand All @@ -37,7 +90,9 @@ func (r *statusRecorder) WriteHeader(code int) {
func AddTracing(tracerName, spanName string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ctx, span := otel.Tracer(tracerName).Start(r.Context(), spanName)
pctx := context.WithValue(r.Context(), OptlyTraceIDHeader, r.Header.Get(OptlyTraceIDHeader))

ctx, span := otel.Tracer(tracerName).Start(pctx, spanName)
defer span.End()

span.SetAttributes(
Expand All @@ -46,6 +101,7 @@ func AddTracing(tracerName, spanName string) func(http.Handler) http.Handler {
semconv.HTTPURLKey.String(r.URL.String()),
semconv.HTTPHostKey.String(r.Host),
semconv.HTTPSchemeKey.String(r.URL.Scheme),
attribute.String(OptlySDKHeader, r.Header.Get(OptlySDKHeader)),
)

rec := &statusRecorder{
Expand Down
53 changes: 53 additions & 0 deletions pkg/middleware/trace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,20 @@
package middleware

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"go.opentelemetry.io/otel/trace"
)

func TestAddTracing(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/text")
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})

req := httptest.NewRequest("GET", "/", nil)
Expand All @@ -38,4 +44,51 @@ func TestAddTracing(t *testing.T) {
if status := rr.Code; status != http.StatusOK {
t.Errorf("Expected status code %v, but got %v", http.StatusOK, status)
}

if body := rr.Body.String(); body != "OK" {
t.Errorf("Expected response body %v, but got %v", "OK", body)
}

if typeHeader := rr.Header().Get("Content-Type"); typeHeader != "application/text" {
t.Errorf("Expected Content-Type header %v, but got %v", "application/text", typeHeader)
}
}

func TestNewIDs(t *testing.T) {
gen := NewTraceIDGenerator()
n := 1000

for i := 0; i < n; i++ {
traceID, spanID := gen.NewIDs(context.Background())
assert.Truef(t, traceID.IsValid(), "trace id: %s", traceID.String())
assert.Truef(t, spanID.IsValid(), "span id: %s", spanID.String())
}
}

func TestNewSpanID(t *testing.T) {
gen := NewTraceIDGenerator()
testTraceID := [16]byte{123, 123}
n := 1000

for i := 0; i < n; i++ {
spanID := gen.NewSpanID(context.Background(), testTraceID)
assert.Truef(t, spanID.IsValid(), "span id: %s", spanID.String())
}
}

func TestNewSpanIDWithInvalidTraceID(t *testing.T) {
gen := NewTraceIDGenerator()
spanID := gen.NewSpanID(context.Background(), trace.TraceID{})
assert.Truef(t, spanID.IsValid(), "span id: %s", spanID.String())
}

func TestTraceIDWithGivenHeaderValue(t *testing.T) {
gen := NewTraceIDGenerator()

traceID := "9b8eac67e332c6f8baf1e013de6891bb"

ctx := context.WithValue(context.Background(), OptlyTraceIDHeader, traceID)
genTraceID, _ := gen.NewIDs(ctx)
assert.Truef(t, genTraceID.IsValid(), "trace id: %s", genTraceID.String())
assert.Equal(t, traceID, genTraceID.String())
}