diff --git a/router/middleware/middleware.go b/router/middleware/middleware.go index 576bacd..b195135 100644 --- a/router/middleware/middleware.go +++ b/router/middleware/middleware.go @@ -4,11 +4,10 @@ import ( "errors" "fmt" "net/http" + "strconv" ) -type Middleware interface { - Serve(r http.HandlerFunc) http.HandlerFunc -} +type Middleware func(next http.Handler) http.Handler type MiddlewaredReponse struct { w http.ResponseWriter @@ -21,6 +20,7 @@ func NewMiddlewaredResponse(w http.ResponseWriter) *MiddlewaredReponse { } func (m *MiddlewaredReponse) WriteHeader(s int) { + m.Header().Set("Status", strconv.Itoa(s)) m.statuses = append(m.statuses, s) } diff --git a/router/router.go b/router/router.go index ef6542f..72f6585 100644 --- a/router/router.go +++ b/router/router.go @@ -62,15 +62,15 @@ func (rt *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { rt.serveHttp(w, r) } -func (r *Router) wrapMiddlewares(ms []middleware.Middleware, h http.HandlerFunc) http.HandlerFunc { - wh := h.ServeHTTP +func (r *Router) wrapMiddlewares(ms []middleware.Middleware, h http.Handler) http.HandlerFunc { + wh := h for _, m := range ms { - wh = m.Serve(wh) + wh = m(wh) } return func(w http.ResponseWriter, r *http.Request) { mw := middleware.NewMiddlewaredResponse(w) - wh(mw, r) + wh.ServeHTTP(mw, r) if _, err := mw.ReallyWriteHeader(); err != nil { _, _ = w.Write( []byte(