diff --git a/provided/run/init.go b/provided/run/init.go index 62cdff6..0040af0 100644 --- a/provided/run/init.go +++ b/provided/run/init.go @@ -64,22 +64,25 @@ var logsBuf bytes.Buffer var serverInitEnd time.Time func newContext() *mockLambdaContext { + context := &mockLambdaContext{ - RequestID: fakeGUID(), - FnName: getEnv("AWS_LAMBDA_FUNCTION_NAME", "test"), - Version: getEnv("AWS_LAMBDA_FUNCTION_VERSION", "$LATEST"), - MemSize: getEnv("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "1536"), - Timeout: getEnv("AWS_LAMBDA_FUNCTION_TIMEOUT", "300"), - Region: getEnv("AWS_REGION", getEnv("AWS_DEFAULT_REGION", "us-east-1")), - AccountID: getEnv("AWS_ACCOUNT_ID", strconv.FormatInt(int64(rand.Int31()), 10)), - XAmznTraceID: getEnv("_X_AMZN_TRACE_ID", ""), - ClientContext: getEnv("AWS_LAMBDA_CLIENT_CONTEXT", ""), - CognitoIdentity: getEnv("AWS_LAMBDA_COGNITO_IDENTITY", ""), - Start: time.Now(), - Done: make(chan bool), + RequestID: fakeGUID(), + FnName: getEnv("AWS_LAMBDA_FUNCTION_NAME", "test"), + Version: getEnv("AWS_LAMBDA_FUNCTION_VERSION", "$LATEST"), + MemSize: getEnv("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "1536"), + Timeout: getEnv("AWS_LAMBDA_FUNCTION_TIMEOUT", "300"), + Region: getEnv("AWS_REGION", getEnv("AWS_DEFAULT_REGION", "us-east-1")), + AccountID: getEnv("AWS_ACCOUNT_ID", strconv.FormatInt(int64(rand.Int31()), 10)), + XAmznTraceID: getEnv("_X_AMZN_TRACE_ID", ""), + ClientContext: getEnv("AWS_LAMBDA_CLIENT_CONTEXT", ""), + CognitoIdentity: getEnv("AWS_LAMBDA_COGNITO_IDENTITY", ""), + GracefulTerminationDelay: getEnv("AWS_LAMBDA_GRACEFUL_TERMINATION_DELAY","30s"), + Start: time.Now(), + Done: make(chan bool), } context.ParseTimeout() context.ParseFunctionArn() + context.ParseGracefulTerminationDelay() return context } @@ -158,7 +161,7 @@ func main() { var runtimeServer *http.Server - runtimeRouter := createRuntimeRouter() + runtimeRouter := createRuntimeRouter(interrupt) runtimeServer = &http.Server{Handler: addAPIRoutes(runtimeRouter)} go runtimeServer.Serve(runtimeListener) @@ -172,6 +175,7 @@ func main() { setupFileWatchers() } setupSighupHandler() + setupSigStopHandler(curContext.GracefulTerminationDelayDuration) systemLog(fmt.Sprintf("Lambda API listening on port %s...", apiPort)) <-interrupt } else { @@ -215,6 +219,20 @@ func setupSighupHandler() { }() } +func setupSigStopHandler(delay time.Duration) { + sighupReceiver := make(chan os.Signal, 1) + signal.Notify(sighupReceiver, syscall.SIGTERM) + go func() { + for { + <-sighupReceiver + systemLog(fmt.Sprintf("SIGTERM received, waiting for end of execution...")) + time.Sleep(delay) + systemLog(fmt.Sprintf("SIGTERM received, waited but still processing, going to kill the process.")) + exit(0) + } + }() +} + func setupFileWatchers() { fileWatcher := make(chan notify.EventInfo, 1) if err := notify.Watch("/var/task/...", fileWatcher, notify.All); err != nil { @@ -457,7 +475,7 @@ func addAPIRoutes(r *chi.Mux) *chi.Mux { return r } -func createRuntimeRouter() *chi.Mux { +func createRuntimeRouter(interrupt chan os.Signal) *chi.Mux { r := chi.NewRouter() r.Route("/2018-06-01", func(r chi.Router) { @@ -465,6 +483,10 @@ func createRuntimeRouter() *chi.Mux { w.Write([]byte("pong")) }) + r.Post("/stop",func(w http.ResponseWriter, r *http.Request) { + close(interrupt) + }) + r.Route("/runtime", func(r chi.Router) { r. With(updateState("STATE_INIT_ERROR")). @@ -801,6 +823,8 @@ type mockLambdaContext struct { InvokedFunctionArn string ClientContext string CognitoIdentity string + GracefulTerminationDelay string + GracefulTerminationDelayDuration time.Duration Start time.Time InvokeWait time.Time InitEnd time.Time @@ -824,6 +848,14 @@ func (mc *mockLambdaContext) ParseTimeout() { mc.TimeoutDuration = timeoutDuration } +func (mc *mockLambdaContext) ParseGracefulTerminationDelay() { + gracefulTerminationDelayDuration, err := time.ParseDuration(mc.GracefulTerminationDelay + "s") + if err != nil { + panic(err) + } + mc.GracefulTerminationDelayDuration = gracefulTerminationDelayDuration +} + func (mc *mockLambdaContext) ParseFunctionArn() { mc.InvokedFunctionArn = getEnv("AWS_LAMBDA_FUNCTION_INVOKED_ARN", arn(mc.Region, mc.AccountID, mc.FnName)) }