From c0854dea2c1e35b051a73e998e099dc4319f36a5 Mon Sep 17 00:00:00 2001 From: "Gustavo L de Mello (Guz)" Date: Wed, 18 Dec 2024 15:33:29 -0300 Subject: [PATCH] feat: move capytalcode/project-comicverse/libs to ./groute --- go.mod | 2 + go.sum | 4 + groute/cookies/cookies.go | 298 ++++++++++++++++++++++++++++++++ groute/forms/forms.go | 215 +++++++++++++++++++++++ groute/middleware/cache.go | 12 ++ groute/middleware/dev.go | 73 ++++++++ groute/middleware/middleware.go | 108 ++++++++++++ groute/router/default.go | 25 +++ groute/router/rerrors/400s.go | 43 +++++ groute/router/rerrors/500s.go | 14 ++ groute/router/rerrors/errors.go | 167 ++++++++++++++++++ groute/router/router.go | 231 +++++++++++++++++++++++++ 12 files changed, 1192 insertions(+) create mode 100644 go.sum create mode 100644 groute/cookies/cookies.go create mode 100644 groute/forms/forms.go create mode 100644 groute/middleware/cache.go create mode 100644 groute/middleware/dev.go create mode 100644 groute/middleware/middleware.go create mode 100644 groute/router/default.go create mode 100644 groute/router/rerrors/400s.go create mode 100644 groute/router/rerrors/500s.go create mode 100644 groute/router/rerrors/errors.go create mode 100644 groute/router/router.go diff --git a/go.mod b/go.mod index 4a0280b..ad6de2f 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module forge.capytal.company/loreddev/x go 1.23.3 + +require github.com/a-h/templ v0.2.793 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..dab7b07 --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +github.com/a-h/templ v0.2.793 h1:Io+/ocnfGWYO4VHdR0zBbf39PQlnzVCVVD+wEEs6/qY= +github.com/a-h/templ v0.2.793/go.mod h1:lq48JXoUvuQrU0VThrK31yFwdRjTCnIE5bcPCM9IP1w= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= diff --git a/groute/cookies/cookies.go b/groute/cookies/cookies.go new file mode 100644 index 0000000..4c263a3 --- /dev/null +++ b/groute/cookies/cookies.go @@ -0,0 +1,298 @@ +package cookies + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + "reflect" + "strconv" + "strings" + "time" + + "forge.capytal.company/loreddev/x/groute/router/rerrors" +) + +type Marshaler interface { + MarshalCookie() (*http.Cookie, error) +} + +type Unmarshaler interface { + UnmarshalCookie(*http.Cookie) error +} + +func Marshal(v any) (*http.Cookie, error) { + if m, ok := v.(Marshaler); ok { + return m.MarshalCookie() + } + + c, err := marshalValue(v) + if err != nil { + return c, err + } + + if err := setCookieProps(c, v); err != nil { + return c, err + } + + return c, err +} + +func MarshalToWriter(v any, w http.ResponseWriter) error { + if ck, err := Marshal(v); err != nil { + return err + } else { + http.SetCookie(w, ck) + } + return nil +} + +func Unmarshal(c *http.Cookie, v any) error { + if m, ok := v.(Unmarshaler); ok { + return m.UnmarshalCookie(c) + } + + value := c.Value + b, err := base64.URLEncoding.DecodeString(value) + if err != nil { + return errors.Join(ErrDecodeBase64, err) + } + + if err := json.Unmarshal(b, v); err != nil { + return errors.Join(ErrUnmarshal, err) + } + + return nil +} + +func UnmarshalRequest(r *http.Request, v any) error { + name, err := getCookieName(v) + if err != nil { + return err + } + + c, err := r.Cookie(name) + if errors.Is(err, http.ErrNoCookie) { + return ErrNoCookie{name} + } else if err != nil { + return err + } + + return Unmarshal(c, v) +} + +func UnmarshalIfRequest(r *http.Request, v any) (bool, error) { + if err := UnmarshalRequest(r, v); err != nil { + if _, ok := err.(ErrNoCookie); ok { + return false, nil + } else { + return true, err + } + } else { + return true, nil + } +} + +func RerrUnmarshalCookie(err error) rerrors.RouteError { + if e, ok := err.(ErrNoCookie); ok { + return rerrors.MissingCookies([]string{e.name}) + } else { + return rerrors.InternalError(err) + } +} + +func marshalValue(v any) (*http.Cookie, error) { + b, err := json.Marshal(v) + if err != nil { + return &http.Cookie{}, errors.Join(ErrMarshal, err) + } + + s := base64.URLEncoding.EncodeToString(b) + + return &http.Cookie{ + Value: s, + }, nil +} + +var COOKIE_EXPIRE_VALID_FORMATS = []string{ + time.DateOnly, time.DateTime, + time.RFC1123, time.RFC1123Z, +} + +func setCookieProps(c *http.Cookie, v any) error { + tag, err := getCookieTag(v) + if err != nil { + return err + } + + c.Name, err = getCookieName(v) + if err != nil { + return err + } + + tvs := strings.Split(tag, ",") + + if len(tvs) == 1 { + return nil + } + + tvs = tvs[1:] + + for _, tv := range tvs { + var k, v string + if strings.Contains(tv, "=") { + s := strings.Split(tv, "=") + k = s[0] + v = s[1] + } else { + k = tv + v = "" + } + + switch k { + case "SECURE": + c.Name = "__Secure-" + c.Name + c.Secure = true + + case "HOST": + c.Name = "__Host" + c.Name + c.Secure = true + c.Path = "/" + + case "path": + c.Path = v + + case "domain": + c.Domain = v + + case "httponly": + if v == "" { + c.HttpOnly = true + } else if v, err := strconv.ParseBool(v); err != nil { + c.HttpOnly = false + } else { + c.HttpOnly = v + } + + case "samesite": + if v == "" { + c.SameSite = http.SameSiteDefaultMode + } else if v == "strict" { + c.SameSite = http.SameSiteStrictMode + } else if v == "lax" { + c.SameSite = http.SameSiteLaxMode + } else { + c.SameSite = http.SameSiteNoneMode + } + case "secure": + if v == "" { + c.Secure = true + } else if v, err := strconv.ParseBool(v); err != nil { + c.Secure = false + } else { + c.Secure = v + } + + case "max-age", "age": + if v == "" { + c.MaxAge = 0 + } else if v, err := strconv.Atoi(v); err != nil { + c.MaxAge = 0 + } else { + c.MaxAge = v + } + + case "expires": + if v == "" { + c.Expires = time.Now() + } else if v, err := timeParseMultiple(v, COOKIE_EXPIRE_VALID_FORMATS...); err != nil { + c.Expires = time.Now() + } else { + c.Expires = v + } + } + } + + return nil +} + +func getCookieName(v any) (name string, err error) { + defer func() { + if r := recover(); r != nil { + err = errors.Join(ErrReflectPanic, fmt.Errorf("Panic recovered: %#v", r)) + } + }() + + tag, err := getCookieTag(v) + if err != nil { + return name, err + } + + tvs := strings.Split(tag, ",") + if len(tvs) == 0 { + t := reflect.TypeOf(v) + name = t.Name() + } else { + name = tvs[0] + } + + if name == "" { + return name, ErrMissingName + } + + return name, nil +} + +func getCookieTag(v any) (t string, err error) { + defer func() { + if r := recover(); r != nil { + err = errors.Join(ErrReflectPanic, fmt.Errorf("Panic recovered: %#v", r)) + } + }() + + rt := reflect.TypeOf(v) + + if rt.Kind() == reflect.Pointer { + rt = rt.Elem() + } + + for i := 0; i < rt.NumField(); i++ { + ft := rt.Field(i) + if t := ft.Tag.Get("cookie"); t != "" { + return t, nil + } + } + + return "", nil +} + +func timeParseMultiple(v string, formats ...string) (time.Time, error) { + errs := []error{} + for _, f := range formats { + t, err := time.Parse(v, f) + if err != nil { + errs = append(errs, err) + } else { + return t, nil + } + } + + return time.Time{}, errs[len(errs)-1] +} + +var ( + ErrDecodeBase64 = errors.New("Failed to decode base64 string from cookie value") + ErrMarshal = errors.New("Failed to marhal JSON value for cookie value") + ErrUnmarshal = errors.New("Failed to unmarshal JSON value from cookie value") + ErrReflectPanic = errors.New("Reflect panic while trying to get tag from value") + ErrMissingName = errors.New("Failed to get name of cookie") +) + +type ErrNoCookie struct { + name string +} + +func (e ErrNoCookie) Error() string { + return fmt.Sprintf("Cookie \"%s\" missing from request", e.name) +} diff --git a/groute/forms/forms.go b/groute/forms/forms.go new file mode 100644 index 0000000..e93b8a6 --- /dev/null +++ b/groute/forms/forms.go @@ -0,0 +1,215 @@ +package forms + +import ( + "errors" + "fmt" + "log" + "net/http" + "reflect" + "strconv" + "strings" + + "forge.capytal.company/loreddev/x/groute/router/rerrors" +) + +type Unmarshaler interface { + UnmarshalForm(r *http.Request) error +} + +func Unmarshal(r *http.Request, v any) (err error) { + if u, ok := v.(Unmarshaler); ok { + return u.UnmarshalForm(r) + } + + defer func() { + if r := recover(); r != nil { + err = errors.Join(ErrReflectPanic, fmt.Errorf("Panic recovered: %#v", r)) + } + }() + + rv := reflect.ValueOf(v) + if rv.Kind() == reflect.Pointer { + rv = rv.Elem() + } + rt := rv.Type() + + for i := 0; i < rv.NumField(); i++ { + ft := rt.Field(i) + fv := rv.FieldByName(ft.Name) + + log.Print(ft.Name) + + if !fv.CanSet() { + continue + } + + // TODO: Support embedded fields + if ft.Anonymous { + continue + } + + var tv string + if t := ft.Tag.Get("form"); t != "" { + tv = t + } else if t = ft.Tag.Get("query"); t != "" { + tv = t + } else { + tv = ft.Name + } + + tvs := strings.Split(tv, ",") + + name := tvs[0] + required := false + defaultv := "" + + for _, v := range tvs { + if v == "required" { + required = true + } else if strings.HasPrefix(v, "default=") { + defaultv = strings.TrimPrefix(v, "default=") + } + } + + qv := r.FormValue(name) + if qv == "" { + if defaultv != "" { + qv = defaultv + } else if required { + return &ErrMissingRequiredValue{name} + } else { + continue + } + } + + if err := setFieldValue(fv, qv); errors.Is(err, &ErrInvalidValueType{}) { + e, _ := err.(*ErrInvalidValueType) + e.value = name + return e + } else if errors.Is(err, &ErrUnsuportedValueType{}) { + e, _ := err.(*ErrUnsuportedValueType) + e.value = name + return e + } else if err != nil { + return err + } + } + + return nil +} + +func RerrUnsmarshal(err error) rerrors.RouteError { + if e, ok := err.(*ErrMissingRequiredValue); ok { + return rerrors.MissingParameters([]string{e.value}) + } else if e, ok := err.(*ErrInvalidValueType); ok { + return rerrors.BadRequest(e.Error()) + } else { + return rerrors.InternalError(err) + } +} + +func setFieldValue(rv reflect.Value, v string) error { + switch rv.Kind() { + + case reflect.Pointer: + return setFieldValue(rv.Elem(), v) + + case reflect.String: + rv.SetString(v) + + case reflect.Bool: + if cv, err := strconv.ParseBool(v); err != nil { + return &ErrInvalidValueType{"bool", err, ""} + } else { + rv.SetBool(cv) + } + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if cv, err := strconv.Atoi(v); err != nil { + return &ErrInvalidValueType{"int", err, ""} + } else { + rv.SetInt(int64(cv)) + } + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if cv, err := strconv.Atoi(v); err != nil { + return &ErrInvalidValueType{"uint", err, ""} + } else { + rv.SetUint(uint64(cv)) + } + + case reflect.Float32, reflect.Float64: + if cv, err := strconv.ParseFloat(v, 64); err != nil { + return &ErrInvalidValueType{"float64", err, ""} + } else { + rv.SetFloat(cv) + } + + case reflect.Complex64, reflect.Complex128: + if cv, err := strconv.ParseComplex(v, 128); err != nil { + return &ErrInvalidValueType{"complex128", err, ""} + } else { + rv.SetComplex(cv) + } + + // TODO: Support strucys + // TODO: Support slices + // TODO: Support maps + default: + return &ErrUnsuportedValueType{ + []string{ + "string", + "bool", + "int", "int8", "int16", "int32", "int64", + "uint", "uint8", "uint16", "uint32", "uint64", + "float32", "float64", + "complex64", "complex64", + }, + "", + } + + } + + return nil +} + +type ErrInvalidValueType struct { + expected string + err error + value string +} + +func (e ErrInvalidValueType) Error() string { + return fmt.Sprintf( + "Value \"%s\" is a invalid type, expected type \"%s\". Got err: %s", + e.value, + e.expected, + e.err.Error(), + ) +} + +type ErrUnsuportedValueType struct { + supported []string + value string +} + +func (e ErrUnsuportedValueType) Error() string { + return fmt.Sprintf( + "Value \"%s\" is a unsupported type, supported types are: \"%s\"", + e.value, + strings.Join(e.supported, ", "), + ) +} + +type ErrMissingRequiredValue struct { + value string +} + +func (e ErrMissingRequiredValue) Error() string { + return fmt.Sprintf("Required value \"%s\" missing from query", e.value) +} + +var ( + ErrParseForm = errors.New("Failed to parse form from body or query parameters") + ErrReflectPanic = errors.New("Reflect panic while trying to parse request") +) diff --git a/groute/middleware/cache.go b/groute/middleware/cache.go new file mode 100644 index 0000000..78fe291 --- /dev/null +++ b/groute/middleware/cache.go @@ -0,0 +1,12 @@ +package middleware + +import ( + "net/http" +) + +func CacheMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=604800, stale-while-revalidate=86400, public") + next.ServeHTTP(w, r) + }) +} diff --git a/groute/middleware/dev.go b/groute/middleware/dev.go new file mode 100644 index 0000000..fd95a8d --- /dev/null +++ b/groute/middleware/dev.go @@ -0,0 +1,73 @@ +package middleware + +import ( + "log/slog" + "math/rand" + "net/http" +) + +func DevMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "no-store") + next.ServeHTTP(w, r) + }) +} + +type loggerReponse struct { + http.ResponseWriter + status int +} + +func (lr *loggerReponse) WriteHeader(s int) { + lr.status = s + lr.ResponseWriter.WriteHeader(s) +} + +func NewLoggerMiddleware(l *slog.Logger) Middleware { + l = l.WithGroup("logger_middleware") + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + id := randHash(5) + + l.Info("NEW REQUEST", + slog.String("id", id), + slog.String("status", "xxx"), + slog.String("method", r.Method), + slog.String("path", r.URL.Path), + ) + + lw := &loggerReponse{w, http.StatusOK} + next.ServeHTTP(lw, r) + + if lw.status >= 400 { + l.Warn("ERR REQUEST", + slog.String("id", id), + slog.Int("status", lw.status), + slog.String("method", r.Method), + slog.String("path", r.URL.Path), + ) + return + } + + l.Info("END REQUEST", + slog.String("id", id), + slog.Int("status", lw.status), + slog.String("method", r.Method), + slog.String("path", r.URL.Path), + ) + }) + } +} + +const HASH_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + +// This is not the most performant function, as a TODO we could +// improve based on this Stackoberflow thread: +// https://stackoverflow.com/questions/22892120/how-to-generate-a-random-string-of-a-fixed-length-in-go +func randHash(n int) string { + b := make([]byte, n) + for i := range b { + b[i] = HASH_CHARS[rand.Int63()%int64(len(HASH_CHARS))] + } + return string(b) +} diff --git a/groute/middleware/middleware.go b/groute/middleware/middleware.go new file mode 100644 index 0000000..517b791 --- /dev/null +++ b/groute/middleware/middleware.go @@ -0,0 +1,108 @@ +package middleware + +import ( + "errors" + "fmt" + "io" + "net/http" + "strconv" +) + +type Middleware = func(next http.Handler) http.Handler + +type MiddlewaredReponse struct { + w http.ResponseWriter + statuses []int + bodyWrites [][]byte +} + +func NewMiddlewaredResponse(w http.ResponseWriter) *MiddlewaredReponse { + return &MiddlewaredReponse{w, []int{500}, [][]byte{[]byte("")}} +} + +func (m *MiddlewaredReponse) WriteHeader(s int) { + m.Header().Set("Status", strconv.Itoa(s)) + m.statuses = append(m.statuses, s) +} + +func (m *MiddlewaredReponse) Header() http.Header { + return m.w.Header() +} + +func (m *MiddlewaredReponse) Write(b []byte) (int, error) { + m.bodyWrites = append(m.bodyWrites, b) + return len(b), nil +} + +func (m *MiddlewaredReponse) ReallyWriteHeader() (int, error) { + status := m.statuses[len(m.statuses)-1] + m.w.WriteHeader(status) + bytes := 0 + for _, b := range m.bodyWrites { + by, err := m.w.Write(b) + if err != nil { + return bytes, errors.Join( + fmt.Errorf( + "Failed to write to response in middleware."+ + "\nStatuses are %v"+ + "\nTried to write %v bytes"+ + "\nTried to write response:\n%s", + m.statuses, bytes, string(b), + ), + err, + ) + } + bytes += by + } + + return bytes, nil +} + +type multiResponseWriter struct { + response http.ResponseWriter + writers []io.Writer +} + +func MultiResponseWriter( + w http.ResponseWriter, + writers ...io.Writer, +) http.ResponseWriter { + if mw, ok := w.(*multiResponseWriter); ok { + mw.writers = append(mw.writers, writers...) + return mw + } + + allWriters := make([]io.Writer, 0, len(writers)) + for _, iow := range writers { + if mw, ok := iow.(*multiResponseWriter); ok { + allWriters = append(allWriters, mw.writers...) + } else { + allWriters = append(allWriters, iow) + } + } + + return &multiResponseWriter{w, allWriters} +} + +func (w *multiResponseWriter) WriteHeader(status int) { + w.Header().Set("Status", strconv.Itoa(status)) + w.response.WriteHeader(status) +} + +func (w *multiResponseWriter) Write(p []byte) (int, error) { + w.WriteHeader(http.StatusOK) + for _, w := range w.writers { + n, err := w.Write(p) + if err != nil { + return n, err + } + if n != len(p) { + return n, io.ErrShortWrite + } + } + return w.response.Write(p) +} + +func (w *multiResponseWriter) Header() http.Header { + return w.response.Header() +} diff --git a/groute/router/default.go b/groute/router/default.go new file mode 100644 index 0000000..81c2ca3 --- /dev/null +++ b/groute/router/default.go @@ -0,0 +1,25 @@ +package router + +import ( + "net/http" + + "forge.capytal.company/loreddev/x/groute/middleware" +) + +var DefaultRouter = NewRouter() + +func Handle(pattern string, handler http.Handler) { + DefaultRouter.Handle(pattern, handler) +} + +func HandleFunc(pattern string, handler http.HandlerFunc) { + DefaultRouter.HandleFunc(pattern, handler) +} + +func Use(m middleware.Middleware) { + DefaultRouter.Use(m) +} + +func ServeHTTP(w http.ResponseWriter, r *http.Request) { + DefaultRouter.ServeHTTP(w, r) +} diff --git a/groute/router/rerrors/400s.go b/groute/router/rerrors/400s.go new file mode 100644 index 0000000..6463484 --- /dev/null +++ b/groute/router/rerrors/400s.go @@ -0,0 +1,43 @@ +package rerrors + +import ( + "net/http" + "strconv" +) + +func BadRequest(reason ...string) RouteError { + info := map[string]any{} + + if len(reason) == 1 { + info["reason"] = reason[0] + } else if len(reason) > 1 { + for i, r := range reason { + info["reason_"+strconv.Itoa(i)] = r + } + } + + return NewRouteError(http.StatusBadRequest, "Bad Request", info) +} + +func NotFound() RouteError { + return NewRouteError(http.StatusNotFound, "Not Found", map[string]any{}) +} + +func MissingCookies(cookies []string) RouteError { + return NewRouteError(http.StatusBadRequest, "Missing cookies", map[string]any{ + "missing_cookies": cookies, + }) +} + +func MethodNowAllowed(method string, allowedMethods []string) RouteError { + return NewRouteError(http.StatusMethodNotAllowed, "Method not allowed", map[string]any{ + "method": method, + "allowed_methods": allowedMethods, + }) +} + +func MissingParameters(params []string) RouteError { + return NewRouteError(http.StatusBadRequest, "Missing parameters", map[string]any{ + "missing_parameters": params, + }) +} diff --git a/groute/router/rerrors/500s.go b/groute/router/rerrors/500s.go new file mode 100644 index 0000000..a705a1e --- /dev/null +++ b/groute/router/rerrors/500s.go @@ -0,0 +1,14 @@ +package rerrors + +import ( + "errors" + "net/http" +) + +func InternalError(errs ...error) RouteError { + err := errors.Join(errs...) + return NewRouteError(http.StatusInternalServerError, "Internal server error", map[string]any{ + "errors": err, + "errors_desc": err.Error(), + }) +} diff --git a/groute/router/rerrors/errors.go b/groute/router/rerrors/errors.go new file mode 100644 index 0000000..deb6c34 --- /dev/null +++ b/groute/router/rerrors/errors.go @@ -0,0 +1,167 @@ +package rerrors + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "strings" + + "forge.capytal.company/loreddev/x/groute/middleware" + "github.com/a-h/templ" +) + +const ( + ERROR_MIDDLEWARE_HEADER = "XX-Error-Middleware" + ERROR_VALUE_HEADER = "X-Error-Value" +) + +type RouteError struct { + StatusCode int `json:"status_code"` + Error string `json:"error"` + Info map[string]any `json:"info"` + Endpoint string +} + +func NewRouteError(status int, error string, info ...map[string]any) RouteError { + rerr := RouteError{StatusCode: status, Error: error} + if len(info) > 0 { + rerr.Info = info[0] + } else { + rerr.Info = map[string]any{} + } + return rerr +} + +func (rerr RouteError) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if rerr.StatusCode == 0 { + rerr.StatusCode = http.StatusNotImplemented + } + + if rerr.Error == "" { + rerr.Error = "MISSING ERROR DESCRIPTION" + } + + if rerr.Info == nil { + rerr.Info = map[string]any{} + } + + j, err := json.Marshal(rerr) + if err != nil { + j, _ = json.Marshal(RouteError{ + StatusCode: http.StatusInternalServerError, + Error: "Failed to marshal error message to JSON", + Info: map[string]any{ + "source_value": fmt.Sprintf("%#v", rerr), + "error": err.Error(), + }, + }) + } + + if r.Header.Get(ERROR_MIDDLEWARE_HEADER) == "enable" && prefersHtml(r.Header) { + q := r.URL.Query() + q.Set("error", base64.URLEncoding.EncodeToString(j)) + r.URL.RawQuery = q.Encode() + + http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect) + return + } + + w.Header().Set("Content-Type", "application/json") + + w.WriteHeader(rerr.StatusCode) + if _, err = w.Write(j); err != nil { + _, _ = w.Write([]byte("Failed to write error JSON string to body")) + } +} + +type ErrorMiddlewarePage func(err RouteError) templ.Component + +type ErrorDisplayer struct { + log *slog.Logger + page ErrorMiddlewarePage +} + +func (h ErrorDisplayer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + e, err := base64.URLEncoding.DecodeString(r.URL.Query().Get("error")) + if err != nil { + h.log.Error("Failed to decode \"error\" parameter from error redirect", + slog.String("method", r.Method), + slog.String("path", r.URL.Path), + slog.Int("status", 0), + slog.String("data", string(e)), + ) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte( + fmt.Sprintf("Data %s\nError %s", string(e), err.Error()), + )) + return + } + + var rerr RouteError + if err := json.Unmarshal(e, &rerr); err != nil { + h.log.Error("Failed to decode \"error\" parameter from error redirect", + slog.String("method", r.Method), + slog.String("path", r.URL.Path), + slog.Int("status", 0), + slog.String("data", string(e)), + ) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte( + fmt.Sprintf("Data %s\nError %s", string(e), err.Error()), + )) + return + } + + if rerr.Endpoint == "" { + q := r.URL.Query() + q.Del("error") + r.URL.RawQuery = q.Encode() + + rerr.Endpoint = r.URL.String() + } + + w.WriteHeader(rerr.StatusCode) + if err := h.page(rerr).Render(r.Context(), w); err != nil { + _, _ = w.Write(e) + } +} + +func NewErrorMiddleware( + p ErrorMiddlewarePage, + l *slog.Logger, + notfound ...ErrorMiddlewarePage, +) middleware.Middleware { + var nf ErrorMiddlewarePage + if len(notfound) > 0 { + nf = notfound[0] + } else { + nf = p + } + + l = l.WithGroup("error_middleware") + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Header.Set(ERROR_MIDDLEWARE_HEADER, "enable") + + if uerr := r.URL.Query().Get("error"); uerr != "" && prefersHtml(r.Header) { + ErrorDisplayer{l, nf}.ServeHTTP(w, r) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +func prefersHtml(h http.Header) bool { + if h.Get("Accept") == "" { + return false + } + return (strings.Contains(h.Get("Accept"), "text/html") || + strings.Contains(h.Get("Accept"), "application/xhtml+xml") || + strings.Contains(h.Get("Accept"), "application/xml")) && + !strings.Contains(h.Get("Accept"), "application/json") +} diff --git a/groute/router/router.go b/groute/router/router.go new file mode 100644 index 0000000..c76d526 --- /dev/null +++ b/groute/router/router.go @@ -0,0 +1,231 @@ +package router + +import ( + "fmt" + "net/http" + "path" + "strings" + + "forge.capytal.company/loreddev/x/groute/middleware" +) + +type Router interface { + Handle(pattern string, handler http.Handler) + HandleFunc(pattern string, handler http.HandlerFunc) + + Use(middleware middleware.Middleware) + + http.Handler +} + +type RouterWithRoutes interface { + Router + Routes() []Route +} + +type RouterWithMiddlewares interface { + RouterWithRoutes + Middlewares() []middleware.Middleware +} + +type RouterWithMiddlewaresWrapper interface { + RouterWithMiddlewares + WrapMiddlewares(ms []middleware.Middleware, h http.Handler) http.Handler +} + +type Route struct { + Path string + Method string + Host string + Handler http.Handler +} + +func NewRouter(mux ...*http.ServeMux) Router { + var m *http.ServeMux + if len(mux) > 0 { + m = mux[0] + } else { + m = http.NewServeMux() + } + + return &defaultRouter{ + m, + []middleware.Middleware{}, + map[string]Route{}, + } +} + +type defaultRouter struct { + mux *http.ServeMux + middlewares []middleware.Middleware + routes map[string]Route +} + +func (r *defaultRouter) Handle(pattern string, h http.Handler) { + if sr, ok := h.(Router); ok { + r.handleRouter(pattern, sr) + } else { + r.handle(pattern, h) + } +} + +func (r *defaultRouter) HandleFunc(pattern string, hf http.HandlerFunc) { + r.handle(pattern, hf) +} + +func (r *defaultRouter) Use(m middleware.Middleware) { + r.middlewares = append(r.middlewares, m) +} + +func (r *defaultRouter) Routes() []Route { + rs := make([]Route, len(r.routes)) + i := 0 + for _, r := range r.routes { + rs[i] = r + i++ + } + return rs +} + +func (r *defaultRouter) Middlewares() []middleware.Middleware { + return r.middlewares +} + +func (r defaultRouter) WrapMiddlewares(ms []middleware.Middleware, h http.Handler) http.Handler { + hf := h + for _, m := range ms { + hf = m(hf) + } + return hf +} + +func (r *defaultRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { + r.mux.ServeHTTP(w, req) +} + +func (r defaultRouter) handle(pattern string, hf http.Handler) { + m, h, p := r.parsePattern(pattern) + rt := Route{ + Method: m, + Host: h, + Path: p, + Handler: hf, + } + r.handleRoute(rt) +} + +func (r defaultRouter) handleRouter(pattern string, rr Router) { + m, h, p := r.parsePattern(pattern) + + rs, ok := rr.(RouterWithRoutes) + if !ok { + r.handle(p, rr) + } + + routes := rs.Routes() + middlewares := []middleware.Middleware{} + if rw, ok := rs.(RouterWithMiddlewares); ok { + middlewares = rw.Middlewares() + } + + wrap := r.WrapMiddlewares + if rw, ok := rs.(RouterWithMiddlewaresWrapper); ok { + wrap = rw.WrapMiddlewares + } + + for _, route := range routes { + route.Handler = wrap(middlewares, route.Handler) + route.Path = path.Join(p, route.Path) + + if m != "" && route.Method != "" && m != route.Method { + panic( + fmt.Sprintf( + "Nested router's route has incompatible method than defined in path %q."+ + "Router's route method is %q, while path's is %q", + p, route.Method, m, + ), + ) + } + if h != "" && route.Host != "" && h != route.Host { + panic( + fmt.Sprintf( + "Nested router's route has incompatible host than defined in path %q."+ + "Router's route host is %q, while path's is %q", + p, route.Host, h, + ), + ) + } + + r.handleRoute(route) + } +} + +func (r defaultRouter) handleRoute(rt Route) { + if len(r.middlewares) > 0 { + rt.Handler = r.WrapMiddlewares(r.middlewares, rt.Handler) + } + + if rt.Path == "" || !strings.HasPrefix(rt.Path, "/") { + panic( + fmt.Sprintf( + "INVALID STATE: Path of route (%#v) does not start with a leading slash", + rt, + ), + ) + } + + p := rt.Path + if rt.Host != "" { + p = fmt.Sprintf("%s%s", rt.Host, p) + } + if rt.Method != "" { + p = fmt.Sprintf("%s %s", rt.Method, p) + } + + if !strings.HasSuffix(p, "/") { + p = fmt.Sprintf("%s/", p) + } + + r.routes[p] = rt + r.mux.Handle(p, rt.Handler) +} + +func (r *defaultRouter) parsePattern(pattern string) (method, host, p string) { + pattern = strings.TrimSpace(pattern) + + // ServerMux patterns are "[METHOD ][HOST]/[PATH]", so to parsing it, we must + // first split it between "[METHOD ][HOST]" and "[PATH]" + ps := strings.Split(pattern, "/") + + p = path.Join("/", strings.Join(ps[1:], "/")) + + // path.Join adds a trailing slash, if the original pattern doesn't has one, the parsed + // path shouldn't also + if !strings.HasSuffix(pattern, "/") { + p = strings.TrimSuffix(p, "/") + } + + // Since path.Join adds a trailing slash, it can break the {pattern...} syntax. + // So we check if it has the suffix "...}/" to see if it ends in "/{pattern...}/" + if strings.HasSuffix(p, "...}/") { + // If it does, we remove the any possible trailing slash + p = strings.TrimSuffix(p, "/") + } + + // If "[METHOD ][HOST]" is empty, we just have the path and can send it back + if ps[0] == "" { + return "", "", p + } + + // Split string again, if method is not defined, this will end up being just []string{"[HOST]"} + // since there isn't a space before the host. If there is a method defined, this will end up as + // []string{"[METHOD]","[HOST]"}, with "[HOST]" being possibly a empty string. + mh := strings.Split(ps[0], " ") + + // If slice is of length 1, this means it is []string{"[HOST]"} + if len(mh) == 1 { + return "", host, p + } + + return mh[0], mh[1], p +}