Skip to content
This repository was archived by the owner on Jan 15, 2023. It is now read-only.

feat: graceful termination of the init process #334

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 46 additions & 14 deletions provided/run/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,22 +64,25 @@ var logsBuf bytes.Buffer
var serverInitEnd time.Time

func newContext() *mockLambdaContext {

context := &mockLambdaContext{
RequestID: fakeGUID(),
FnName: getEnv("AWS_LAMBDA_FUNCTION_NAME", "test"),
Version: getEnv("AWS_LAMBDA_FUNCTION_VERSION", "$LATEST"),
MemSize: getEnv("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "1536"),
Timeout: getEnv("AWS_LAMBDA_FUNCTION_TIMEOUT", "300"),
Region: getEnv("AWS_REGION", getEnv("AWS_DEFAULT_REGION", "us-east-1")),
AccountID: getEnv("AWS_ACCOUNT_ID", strconv.FormatInt(int64(rand.Int31()), 10)),
XAmznTraceID: getEnv("_X_AMZN_TRACE_ID", ""),
ClientContext: getEnv("AWS_LAMBDA_CLIENT_CONTEXT", ""),
CognitoIdentity: getEnv("AWS_LAMBDA_COGNITO_IDENTITY", ""),
Start: time.Now(),
Done: make(chan bool),
RequestID: fakeGUID(),
FnName: getEnv("AWS_LAMBDA_FUNCTION_NAME", "test"),
Version: getEnv("AWS_LAMBDA_FUNCTION_VERSION", "$LATEST"),
MemSize: getEnv("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "1536"),
Timeout: getEnv("AWS_LAMBDA_FUNCTION_TIMEOUT", "300"),
Region: getEnv("AWS_REGION", getEnv("AWS_DEFAULT_REGION", "us-east-1")),
AccountID: getEnv("AWS_ACCOUNT_ID", strconv.FormatInt(int64(rand.Int31()), 10)),
XAmznTraceID: getEnv("_X_AMZN_TRACE_ID", ""),
ClientContext: getEnv("AWS_LAMBDA_CLIENT_CONTEXT", ""),
CognitoIdentity: getEnv("AWS_LAMBDA_COGNITO_IDENTITY", ""),
GracefulTerminationDelay: getEnv("AWS_LAMBDA_GRACEFUL_TERMINATION_DELAY","30s"),
Start: time.Now(),
Done: make(chan bool),
}
context.ParseTimeout()
context.ParseFunctionArn()
context.ParseGracefulTerminationDelay()
return context
}

Expand Down Expand Up @@ -158,7 +161,7 @@ func main() {

var runtimeServer *http.Server

runtimeRouter := createRuntimeRouter()
runtimeRouter := createRuntimeRouter(interrupt)
runtimeServer = &http.Server{Handler: addAPIRoutes(runtimeRouter)}

go runtimeServer.Serve(runtimeListener)
Expand All @@ -172,6 +175,7 @@ func main() {
setupFileWatchers()
}
setupSighupHandler()
setupSigStopHandler(curContext.GracefulTerminationDelayDuration)
systemLog(fmt.Sprintf("Lambda API listening on port %s...", apiPort))
<-interrupt
} else {
Expand Down Expand Up @@ -215,6 +219,20 @@ func setupSighupHandler() {
}()
}

func setupSigStopHandler(delay time.Duration) {
sighupReceiver := make(chan os.Signal, 1)
signal.Notify(sighupReceiver, syscall.SIGTERM)
go func() {
for {
<-sighupReceiver
systemLog(fmt.Sprintf("SIGTERM received, waiting for end of execution..."))
time.Sleep(delay)
systemLog(fmt.Sprintf("SIGTERM received, waited but still processing, going to kill the process."))
exit(0)
}
}()
}

func setupFileWatchers() {
fileWatcher := make(chan notify.EventInfo, 1)
if err := notify.Watch("/var/task/...", fileWatcher, notify.All); err != nil {
Expand Down Expand Up @@ -457,14 +475,18 @@ func addAPIRoutes(r *chi.Mux) *chi.Mux {
return r
}

func createRuntimeRouter() *chi.Mux {
func createRuntimeRouter(interrupt chan os.Signal) *chi.Mux {
r := chi.NewRouter()

r.Route("/2018-06-01", func(r chi.Router) {
r.Get("/ping", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("pong"))
})

r.Post("/stop",func(w http.ResponseWriter, r *http.Request) {
close(interrupt)
})

r.Route("/runtime", func(r chi.Router) {
r.
With(updateState("STATE_INIT_ERROR")).
Expand Down Expand Up @@ -801,6 +823,8 @@ type mockLambdaContext struct {
InvokedFunctionArn string
ClientContext string
CognitoIdentity string
GracefulTerminationDelay string
GracefulTerminationDelayDuration time.Duration
Start time.Time
InvokeWait time.Time
InitEnd time.Time
Expand All @@ -824,6 +848,14 @@ func (mc *mockLambdaContext) ParseTimeout() {
mc.TimeoutDuration = timeoutDuration
}

func (mc *mockLambdaContext) ParseGracefulTerminationDelay() {
gracefulTerminationDelayDuration, err := time.ParseDuration(mc.GracefulTerminationDelay + "s")
if err != nil {
panic(err)
}
mc.GracefulTerminationDelayDuration = gracefulTerminationDelayDuration
}

func (mc *mockLambdaContext) ParseFunctionArn() {
mc.InvokedFunctionArn = getEnv("AWS_LAMBDA_FUNCTION_INVOKED_ARN", arn(mc.Region, mc.AccountID, mc.FnName))
}
Expand Down