diff --git a/smalltrip/problem/multiplexer.go b/smalltrip/problem/multiplexer.go new file mode 100644 index 0000000..76d973f --- /dev/null +++ b/smalltrip/problem/multiplexer.go @@ -0,0 +1,148 @@ +// Copyright 2025-present Gustavo "Guz" L. de Mello +// Copyright 2025-present The Lored.dev Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package problem + +import ( + "net/http" + "strings" + + "forge.capytal.company/loreddev/x/smalltrip/multiplexer" +) + +func Multiplexer(m multiplexer.Multiplexer, opts ...MultiplexerOption) multiplexer.Multiplexer { + mux := &mux{Multiplexer: m, methodList: DefaultMethods} + + for _, opt := range opts { + opt(mux) + } + + return mux +} + +type mux struct { + notFound http.Handler + methodNotAllowed http.Handler + methodList []string + + multiplexer.Multiplexer +} + +func (m *mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + i := &interceptor{ + notFound: m.notFound, + methodNotAllowed: m.methodNotAllowed, + methodList: m.methodList, + + mux: m, + + w: w, + r: r, + } + m.Multiplexer.ServeHTTP(i, r) +} + +type interceptor struct { + notFound http.Handler + methodNotAllowed http.Handler + methodList []string + + mux multiplexer.Multiplexer + + statusCode int + intercept bool + + w http.ResponseWriter + r *http.Request +} + +var _ http.ResponseWriter = (*interceptor)(nil) + +func (e *interceptor) Header() http.Header { + return e.w.Header() +} + +func (e *interceptor) WriteHeader(statusCode int) { + if statusCode > 399 && strings.Contains(e.w.Header().Get("Content-Type"), "text/plain") { + e.w.Header().Del("Content-Type") + e.intercept = true + e.statusCode = statusCode + return + } + e.w.WriteHeader(statusCode) +} + +func (e *interceptor) Write(data []byte) (int, error) { + if e.intercept && e.statusCode == http.StatusMethodNotAllowed { + method := e.r.Method + _, current := e.mux.Handler(e.r) + + var allowed []string + for _, m := range e.methodList { + e.r.Method = m + if _, p := e.mux.Handler(e.r); p != current { + allowed = append(allowed, m) + } + } + + e.r.Method = method + + if e.methodNotAllowed != nil { + e.w.Header().Set("Allow", strings.Join(allowed, ", ")) + e.methodNotAllowed.ServeHTTP(e.w, e.r) + } else { + NewMethodNotAllowed(allowed).ServeHTTP(e.w, e.r) + } + return len(data), nil + + } else if e.intercept && e.statusCode == http.StatusNotFound && e.notFound != nil { + e.notFound.ServeHTTP(e.w, e.r) + return len(data), nil + + } else if e.intercept { + NewDetailed(e.statusCode, strings.TrimSpace(string(data))).ServeHTTP(e.w, e.r) + return len(data), nil + } + + return e.w.Write(data) +} + +var DefaultMethods = multiplexer.DefaultMethods + +type MultiplexerOption func(*mux) + +func WithNotFound(h http.Handler) MultiplexerOption { + return func(m *mux) { + m.notFound = h + } +} + +func WithMethodNotAllowed(h http.Handler) MultiplexerOption { + return func(m *mux) { + m.methodNotAllowed = h + } +} + +func WithMethod(method string) MultiplexerOption { + return func(m *mux) { + m.methodList = append(m.methodList, method) + } +} + +func WithMethodList(list []string) MultiplexerOption { + return func(m *mux) { + m.methodList = list + } +}