diff --git a/middleware/logger.go b/middleware/logger.go index 9fd523a..fba48bf 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -23,167 +23,44 @@ import ( "net/http" ) -const ( - defaultMsgNew = "NEW REQUEST" - defaultMsg200 = "END REQUEST" - defaultMsg400 = "INV REQUEST" - defaultMsg500 = "ERR REQUEST" -) - -func Logger(logger *slog.Logger, options ...LoggerOption) Middleware { - state := &loggerState{ - levelNew: slog.LevelDebug, - msgNew: "", - argsNew: LoggerArgsDefault, - - level200: -1, - msg200: "", - args200: LoggerArgsDefault, - - level400: -1, - msg400: "", - args400: LoggerArgsDefault, - - level500: -1, - msg500: "", - args500: LoggerArgsDefault, - - hashFunction: randHash, - - logger: logger, - } - - for _, option := range options { - option(state) - } - - if state.level200 == -1 { - state.level200 = state.levelNew + 4 - } - if state.level400 == -1 { - state.level400 = state.level200 + 4 - } - if state.level500 == -1 { - state.level500 = state.level500 + 4 - } - - if state.msgNew == "" { - state.msgNew = defaultMsgNew - } - if state.msg200 == "" { - if state.msgNew != "" && state.msgNew != defaultMsgNew { - state.msg200 = state.msgNew - } else { - state.msg200 = defaultMsg200 - } - } - if state.msg400 == "" { - if state.msg200 != "" && state.msg200 != defaultMsg200 { - state.msg400 = state.msg200 - } else { - state.msg400 = defaultMsg400 - } - } - if state.msg500 == "" { - if state.msg400 != "" && state.msg400 != defaultMsg400 { - state.msg500 = state.msg400 - } else { - state.msg500 = defaultMsg500 - } - } - +func Logger(logger *slog.Logger) Middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - id := randHash(5) - - state.logger = logger.With(slog.String("id", id)) - lw := &loggerResponseWriter{w, 0} - state.lNew(state.argsNew(lw, r)...) + addr := loggerGetAddr(r) + if net.ParseIP(addr) == nil { + addr = fmt.Sprintf("INVALID %s", addr) + } + + log := logger.With( + slog.String("id", randHash(5)), + slog.String("method", fmt.Sprintf("%4s", r.Method)), + slog.String("addr", addr), + slog.String("path", r.URL.Path), + ) + + log.Debug("NEW REQUEST", slog.String("status", "000")) next.ServeHTTP(lw, r) + log = log.With(slog.String("status", fmt.Sprintf("%3d", lw.statusCode))) + switch { - case lw.StatusCode() >= 500: - state.l500(state.args500(lw, r)...) - case lw.StatusCode() >= 400: - state.l400(state.args400(lw, r)...) + case lw.statusCode >= 500: + log.Warn("ERR REQUEST") + case lw.statusCode >= 400: + log.Info("INV REQUEST") + case lw.statusCode >= 200: + log.Debug("END REQUEST") default: - state.l200(state.args200(lw, r)...) + log.Debug("MSC REQUEST") } }) } } -type LoggerOption func(*loggerState) - -func LoggerWithLevel(l slog.Level) LoggerOption { - return func(ls *loggerState) { ls.levelNew = l } -} - -func LoggerWithMsg(msg string) LoggerOption { - return func(ls *loggerState) { ls.msgNew = msg } -} - -func LoggerWithArgs(args LoggerArgs) LoggerOption { - return func(ls *loggerState) { ls.argsNew = args } -} - -func LoggerWith200Level(l slog.Level) LoggerOption { - return func(ls *loggerState) { ls.level200 = l } -} - -func LoggerWith200Msg(msg string) LoggerOption { - return func(ls *loggerState) { ls.msg200 = msg } -} - -func LoggerWith200Args(args LoggerArgs) LoggerOption { - return func(ls *loggerState) { ls.args200 = args } -} - -func LoggerWith400Level(l slog.Level) LoggerOption { - return func(ls *loggerState) { ls.level400 = l } -} - -func LoggerWith400Msg(msg string) LoggerOption { - return func(ls *loggerState) { ls.msg400 = msg } -} - -func LoggerWith400Args(args LoggerArgs) LoggerOption { - return func(ls *loggerState) { ls.args400 = args } -} - -func LoggerWith500Level(l slog.Level) LoggerOption { - return func(ls *loggerState) { ls.level500 = l } -} - -func LoggerWith500Msg(msg string) LoggerOption { - return func(ls *loggerState) { ls.msg500 = msg } -} - -func LoggerWith500Args(args LoggerArgs) LoggerOption { - return func(ls *loggerState) { ls.args500 = args } -} - -type LoggerArgs func(LoggerResponseWriter, *http.Request) []any - -func LoggerArgsDefault(lw LoggerResponseWriter, r *http.Request) []any { - addr := LoggerGetAddr(r) - - if net.ParseIP(addr) == nil { - addr = fmt.Sprintf("INVALID ADDR %s", addr) - } - - return []any{ - slog.String("status", fmt.Sprintf("%3d", lw.StatusCode())), - slog.String("method", fmt.Sprintf("%3s", r.Method)), - slog.String("addr", addr), - slog.String("path", r.URL.Path), - } -} - -func LoggerGetAddr(r *http.Request) string { +func loggerGetAddr(r *http.Request) string { if i := r.Header.Get("CF-Connecting-IP"); i != "" { return i } @@ -196,11 +73,6 @@ func LoggerGetAddr(r *http.Request) string { return r.RemoteAddr } -type LoggerResponseWriter interface { - http.ResponseWriter - StatusCode() int -} - type loggerResponseWriter struct { http.ResponseWriter statusCode int @@ -211,72 +83,7 @@ func (w *loggerResponseWriter) WriteHeader(s int) { w.ResponseWriter.WriteHeader(s) } -func (w *loggerResponseWriter) StatusCode() int { - return w.statusCode -} - -type loggerState struct { - levelNew slog.Level - msgNew string - argsNew LoggerArgs - - level200 slog.Level - msg200 string - args200 LoggerArgs - - level400 slog.Level - msg400 string - args400 LoggerArgs - - level500 slog.Level - msg500 string - args500 LoggerArgs - - hashFunction func(n int) string - - logger *slog.Logger -} - -func (l *loggerState) lNew(args ...any) { - l.logLevel(l.levelNew, l.msgNew, args...) -} - -func (l *loggerState) l200(args ...any) { - l.logLevel(l.level200, l.msg200, args...) -} - -func (l *loggerState) l400(args ...any) { - l.logLevel(l.level400, l.msg400, args...) -} - -func (l *loggerState) l500(args ...any) { - l.logLevel(l.level500, l.msg500, args...) -} - -func (l *loggerState) logLevel(level slog.Level, msg string, args ...any) { - switch true { - case level >= slog.LevelError: - l.logger.Error(msg, args...) - case level >= slog.LevelWarn: - l.logger.Warn(msg, args...) - case level >= slog.LevelInfo: - l.logger.Info(msg, args...) - default: - l.logger.Debug(msg, args...) - } -} - -func getBiggestLength(s ...string) int { - var l int - for _, s := range s { - if len(s) > l { - l = len(s) - } - } - return l -} - -const HASH_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +const hashChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" // This is not the most performant function, as a TODO we could // improve based on this Stackoberflow thread: @@ -284,7 +91,7 @@ const HASH_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ01234567 func randHash(n int) string { b := make([]byte, n) for i := range b { - b[i] = HASH_CHARS[rand.Int63()%int64(len(HASH_CHARS))] + b[i] = hashChars[rand.Int63()%int64(len(hashChars))] } return string(b) }