// boilerplate-go/core/appserver/middleware.go
//
// 196 lines
// 5.3 KiB
// Go

package appserver
import (
"bytes"
"encoding/json"
"fmt"
"github.com/felixge/httpsnoop"
"github.com/google/uuid"
"gitlab.com/arkadooti.sarkar/go-boilerplate/core/appcontext"
"gitlab.com/arkadooti.sarkar/go-boilerplate/core/log"
"go.elastic.co/apm"
"go.elastic.co/apm/module/apmhttp"
"io/ioutil"
"net/http"
"runtime/debug"
"strings"
"time"
)
// TraceID is the response header key under which SetTraceID echoes the
// request's trace ID back to the client.
const TraceID = "traceid"
// SetTraceID is HTTP middleware that makes sure every request carries a
// usable W3C trace context before it reaches next.
//
// If the request has exactly one non-empty traceparent header that parses
// successfully, its trace ID is reused. Otherwise a trace ID is minted from a
// random UUID. Its span ID is copied from the second half of that trace ID and
// it is flagged as recorded. The result is written back into the request's
// traceparent header, replacing any unusable value.
//
// In both cases the trace ID is echoed to the client in the TraceID response
// header and copied into the request's requestID header.
func SetTraceID(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var tid apm.TraceID
		if hdr := r.Header[apmhttp.W3CTraceparentHeader]; len(hdr) == 1 && hdr[0] != "" {
			if tc, err := apmhttp.ParseTraceparentHeader(hdr[0]); err == nil {
				tid = tc.Trace
			}
		}

		if tid.Validate() != nil {
			// No usable incoming trace: mint one from a fresh random UUID.
			seed := uuid.New()
			copy(tid[:], seed[:])

			var sid apm.SpanID
			copy(sid[:], tid[8:])

			var opts apm.TraceOptions
			generated := apm.TraceContext{
				Trace:   tid,
				Span:    sid,
				Options: opts.WithRecorded(true),
			}
			r.Header.Set(apmhttp.W3CTraceparentHeader, apmhttp.FormatTraceparentHeader(generated))
		}

		idStr := tid.String()
		w.Header().Set(TraceID, idStr)
		r.Header.Set(requestID, idStr)
		next.ServeHTTP(w, r)
	})
}
func enableCompression(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.Header.Get("Accept-Encoding"), "br") && !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
next(w, r)
return
} else if !strings.Contains(r.Header.Get("Accept-Encoding"), "br") {
gzr := pool.Get().(*gzipResponseWriter)
gzr.statusCode = 0
gzr.headerWritten = false
gzr.ResponseWriter = w
gzr.w.Reset(w)
defer func() {
// gzr.w.Close will write a footer even if no data has been written.
// StatusNotModified and StatusNoContent expect an empty body so don't close it.
if gzr.statusCode != http.StatusNotModified && gzr.statusCode != http.StatusNoContent {
if err := gzr.w.Close(); err != nil {
ctx := appcontext.UpgradeCtx(r.Context())
log.GenericError(ctx, err, nil)
}
}
pool.Put(gzr)
}()
next(gzr, r)
return
}
br := poolbr.Get().(*brotliResponseWriter)
br.statusCode = 0
br.headerWritten = false
br.ResponseWriter = w
br.w.Reset(w)
defer func() {
// brotli.w.Close will write a footer even if no data has been written.x
// StatusNotModified and StatusNoContent expect an empty body so don't close it.
if br.statusCode != http.StatusNotModified && br.statusCode != http.StatusNoContent {
if err := br.w.Close(); err != nil {
ctx := appcontext.UpgradeCtx(r.Context())
log.GenericError(ctx, err, nil)
}
}
poolbr.Put(br)
}()
next(br, r)
}
}
// recovery wraps next with a deferred recover. A panic in the handler chain
// is logged with a flattened stack trace, reported to the APM tracer under a
// "recovery" span, and answered with a 500 JSON error body.
//
// http.ErrAbortHandler is re-panicked untouched. net/http uses it as a
// deliberate signal to abort the response without logging, so converting it
// into a 500 here would break that contract.
//
// NOTE(review): if next already wrote headers or body before panicking, the
// WriteHeader/Write below cannot replace that output. Confirm this is
// acceptable for the handlers being wrapped.
func recovery(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			rec := recover()
			if rec == nil {
				return
			}
			if rec == http.ErrAbortHandler {
				panic(rec)
			}

			ctx := appcontext.UpgradeCtx(r.Context())
			span, _ := apm.StartSpan(ctx.Context, "recovery", "custom")
			defer span.End()

			// Flatten the stack onto one line so it stays a single log record.
			trace := strings.NewReplacer("\n", " ", "\t", " ").Replace(string(debug.Stack()))

			// Preserve the error chain when the panic value is itself an error.
			// Going through fmt keeps the message identical and is safe even for
			// typed-nil error values.
			var err error
			if recErr, ok := rec.(error); ok {
				err = fmt.Errorf("%w", recErr)
			} else {
				err = fmt.Errorf("%v", rec)
			}
			log.GenericError(ctx, err,
				log.FieldsMap{
					"msg":        "recovering from panic",
					"stackTrace": trace,
				})

			e := apm.DefaultTracer.Recovered(rec)
			e.SetSpan(span)
			e.Send()

			// Marshalling a constant map[string]string cannot fail.
			jsonBody, _ := json.Marshal(map[string]string{
				"error": "There was an internal server error",
			})
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusInternalServerError)
			// Best effort: the client may already be gone, and nothing useful
			// remains to be done if this write fails.
			_, _ = w.Write(jsonBody)
		}()
		next(w, r)
	}
}
func logRequest(handler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
m := httpsnoop.CaptureMetrics(handler, w, r)
log.HTTPLog(constructHTTPLog(r, m, time.Since(start)))
}
}
// constructHTTPLog renders one pipe-delimited access-log line for r, using
// the status code and bytes written from m plus the supplied duration.
//
// If the request context holds an appcontext.AppExtContext under
// appcontext.AppCtx, the line also carries the user e-mail and request ID. In
// that case the body is logged as compact JSON, or "null" if it is empty or
// not valid JSON. Otherwise the raw body bytes are logged.
//
// The body is read in full. If it is non-empty, r.Body is replaced with an
// in-memory copy so it can be read again.
//
// NOTE(review): the full request body is written to the log. Make sure no
// credentials or other sensitive payloads pass through here.
func constructHTTPLog(r *http.Request, m httpsnoop.Metrics, duration time.Duration) string {
	var rawBody []byte
	if r.Body != nil {
		// Best effort: a read error just yields whatever was read so far.
		rawBody, _ = ioutil.ReadAll(r.Body)
		if len(rawBody) > 0 {
			r.Body = ioutil.NopCloser(bytes.NewBuffer(rawBody))
		}
	}

	// comma-ok: a missing or differently-typed context value falls back to the
	// context-free format instead of panicking.
	tCtx, ok := r.Context().Value(appcontext.AppCtx).(appcontext.AppExtContext)
	if !ok {
		return fmt.Sprintf("|%s|%s|%s|%d|%d|%s|%s|%s|",
			r.RemoteAddr,
			r.Method,
			r.URL,
			m.Code,
			m.Written,
			r.UserAgent(),
			duration,
			"Body:"+string(rawBody),
		)
	}

	// json.Compact drops insignificant whitespace so the body fits on one log
	// line. It avoids a round-trip through interface{}, which would reorder keys
	// and turn large integers into lossy float64 values.
	body := []byte("null")
	var compacted bytes.Buffer
	if err := json.Compact(&compacted, rawBody); err == nil {
		body = compacted.Bytes()
	}
	return fmt.Sprintf("|%s|%s|%s|%s|%s|%d|%d|%s|%s|%s|",
		tCtx.UserEmail,
		"requestId="+tCtx.RequestID,
		r.RemoteAddr,
		r.Method,
		r.URL,
		m.Code,
		m.Written,
		r.UserAgent(),
		duration,
		"Body:"+string(body),
	)
}
// createContext builds an appcontext.AppExtContext from the request headers
// named by requestID, userEmail, locale and application. It then calls next
// with a copy of r whose context carries that value.
//
// When the request-ID header is absent, a random UUID with its dashes removed
// is used instead.
func createContext(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		header := r.Header
		reqID := header.Get(requestID)
		if reqID == "" {
			reqID = strings.ReplaceAll(uuid.NewString(), "-", "")
		}
		// Read the headers directly into the struct. A local named `locale`
		// would shadow the package-level header-name identifier.
		appCtx := appcontext.AppExtContext{
			RequestID:   reqID,
			UserEmail:   header.Get(userEmail),
			Locale:      header.Get(locale),
			Application: header.Get(application),
		}
		ctx := appcontext.WithAppCtx(r.Context(), appCtx)
		next(w, r.WithContext(ctx))
	}
}