diff --git a/groute/cookies/cookies.go b/groute/cookies/cookies.go deleted file mode 100644 index 4c263a3..0000000 --- a/groute/cookies/cookies.go +++ /dev/null @@ -1,298 +0,0 @@ -package cookies - -import ( - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "net/http" - "reflect" - "strconv" - "strings" - "time" - - "forge.capytal.company/loreddev/x/groute/router/rerrors" -) - -type Marshaler interface { - MarshalCookie() (*http.Cookie, error) -} - -type Unmarshaler interface { - UnmarshalCookie(*http.Cookie) error -} - -func Marshal(v any) (*http.Cookie, error) { - if m, ok := v.(Marshaler); ok { - return m.MarshalCookie() - } - - c, err := marshalValue(v) - if err != nil { - return c, err - } - - if err := setCookieProps(c, v); err != nil { - return c, err - } - - return c, err -} - -func MarshalToWriter(v any, w http.ResponseWriter) error { - if ck, err := Marshal(v); err != nil { - return err - } else { - http.SetCookie(w, ck) - } - return nil -} - -func Unmarshal(c *http.Cookie, v any) error { - if m, ok := v.(Unmarshaler); ok { - return m.UnmarshalCookie(c) - } - - value := c.Value - b, err := base64.URLEncoding.DecodeString(value) - if err != nil { - return errors.Join(ErrDecodeBase64, err) - } - - if err := json.Unmarshal(b, v); err != nil { - return errors.Join(ErrUnmarshal, err) - } - - return nil -} - -func UnmarshalRequest(r *http.Request, v any) error { - name, err := getCookieName(v) - if err != nil { - return err - } - - c, err := r.Cookie(name) - if errors.Is(err, http.ErrNoCookie) { - return ErrNoCookie{name} - } else if err != nil { - return err - } - - return Unmarshal(c, v) -} - -func UnmarshalIfRequest(r *http.Request, v any) (bool, error) { - if err := UnmarshalRequest(r, v); err != nil { - if _, ok := err.(ErrNoCookie); ok { - return false, nil - } else { - return true, err - } - } else { - return true, nil - } -} - -func RerrUnmarshalCookie(err error) rerrors.RouteError { - if e, ok := err.(ErrNoCookie); ok { - return rerrors.MissingCookies([]string{e.name}) - } else { - return rerrors.InternalError(err) - } -} - -func marshalValue(v any) (*http.Cookie, error) { - b, err := json.Marshal(v) - if err != nil { - return &http.Cookie{}, errors.Join(ErrMarshal, err) - } - - s := base64.URLEncoding.EncodeToString(b) - - return &http.Cookie{ - Value: s, - }, nil -} - -var COOKIE_EXPIRE_VALID_FORMATS = []string{ - time.DateOnly, time.DateTime, - time.RFC1123, time.RFC1123Z, -} - -func setCookieProps(c *http.Cookie, v any) error { - tag, err := getCookieTag(v) - if err != nil { - return err - } - - c.Name, err = getCookieName(v) - if err != nil { - return err - } - - tvs := strings.Split(tag, ",") - - if len(tvs) == 1 { - return nil - } - - tvs = tvs[1:] - - for _, tv := range tvs { - var k, v string - if strings.Contains(tv, "=") { - s := strings.Split(tv, "=") - k = s[0] - v = s[1] - } else { - k = tv - v = "" - } - - switch k { - case "SECURE": - c.Name = "__Secure-" + c.Name - c.Secure = true - - case "HOST": - c.Name = "__Host" + c.Name - c.Secure = true - c.Path = "/" - - case "path": - c.Path = v - - case "domain": - c.Domain = v - - case "httponly": - if v == "" { - c.HttpOnly = true - } else if v, err := strconv.ParseBool(v); err != nil { - c.HttpOnly = false - } else { - c.HttpOnly = v - } - - case "samesite": - if v == "" { - c.SameSite = http.SameSiteDefaultMode - } else if v == "strict" { - c.SameSite = http.SameSiteStrictMode - } else if v == "lax" { - c.SameSite = http.SameSiteLaxMode - } else { - c.SameSite = http.SameSiteNoneMode - } - case "secure": - if v == "" { - c.Secure = true - } else if v, err := strconv.ParseBool(v); err != nil { - c.Secure = false - } else { - c.Secure = v - } - - case "max-age", "age": - if v == "" { - c.MaxAge = 0 - } else if v, err := strconv.Atoi(v); err != nil { - c.MaxAge = 0 - } else { - c.MaxAge = v - } - - case "expires": - if v == "" { - c.Expires = time.Now() - } else if v, err := timeParseMultiple(v, COOKIE_EXPIRE_VALID_FORMATS...); err != nil { - c.Expires = time.Now() - } else { - c.Expires = v - } - } - } - - return nil -} - -func getCookieName(v any) (name string, err error) { - defer func() { - if r := recover(); r != nil { - err = errors.Join(ErrReflectPanic, fmt.Errorf("Panic recovered: %#v", r)) - } - }() - - tag, err := getCookieTag(v) - if err != nil { - return name, err - } - - tvs := strings.Split(tag, ",") - if len(tvs) == 0 { - t := reflect.TypeOf(v) - name = t.Name() - } else { - name = tvs[0] - } - - if name == "" { - return name, ErrMissingName - } - - return name, nil -} - -func getCookieTag(v any) (t string, err error) { - defer func() { - if r := recover(); r != nil { - err = errors.Join(ErrReflectPanic, fmt.Errorf("Panic recovered: %#v", r)) - } - }() - - rt := reflect.TypeOf(v) - - if rt.Kind() == reflect.Pointer { - rt = rt.Elem() - } - - for i := 0; i < rt.NumField(); i++ { - ft := rt.Field(i) - if t := ft.Tag.Get("cookie"); t != "" { - return t, nil - } - } - - return "", nil -} - -func timeParseMultiple(v string, formats ...string) (time.Time, error) { - errs := []error{} - for _, f := range formats { - t, err := time.Parse(v, f) - if err != nil { - errs = append(errs, err) - } else { - return t, nil - } - } - - return time.Time{}, errs[len(errs)-1] -} - -var ( - ErrDecodeBase64 = errors.New("Failed to decode base64 string from cookie value") - ErrMarshal = errors.New("Failed to marhal JSON value for cookie value") - ErrUnmarshal = errors.New("Failed to unmarshal JSON value from cookie value") - ErrReflectPanic = errors.New("Reflect panic while trying to get tag from value") - ErrMissingName = errors.New("Failed to get name of cookie") -) - -type ErrNoCookie struct { - name string -} - -func (e ErrNoCookie) Error() string { - return fmt.Sprintf("Cookie \"%s\" missing from request", e.name) -} diff --git a/groute/forms/forms.go b/groute/forms/forms.go deleted file mode 100644 index e93b8a6..0000000 --- a/groute/forms/forms.go +++ /dev/null @@ -1,215 +0,0 @@ -package forms - -import ( - "errors" - "fmt" - "log" - "net/http" - "reflect" - "strconv" - "strings" - - "forge.capytal.company/loreddev/x/groute/router/rerrors" -) - -type Unmarshaler interface { - UnmarshalForm(r *http.Request) error -} - -func Unmarshal(r *http.Request, v any) (err error) { - if u, ok := v.(Unmarshaler); ok { - return u.UnmarshalForm(r) - } - - defer func() { - if r := recover(); r != nil { - err = errors.Join(ErrReflectPanic, fmt.Errorf("Panic recovered: %#v", r)) - } - }() - - rv := reflect.ValueOf(v) - if rv.Kind() == reflect.Pointer { - rv = rv.Elem() - } - rt := rv.Type() - - for i := 0; i < rv.NumField(); i++ { - ft := rt.Field(i) - fv := rv.FieldByName(ft.Name) - - log.Print(ft.Name) - - if !fv.CanSet() { - continue - } - - // TODO: Support embedded fields - if ft.Anonymous { - continue - } - - var tv string - if t := ft.Tag.Get("form"); t != "" { - tv = t - } else if t = ft.Tag.Get("query"); t != "" { - tv = t - } else { - tv = ft.Name - } - - tvs := strings.Split(tv, ",") - - name := tvs[0] - required := false - defaultv := "" - - for _, v := range tvs { - if v == "required" { - required = true - } else if strings.HasPrefix(v, "default=") { - defaultv = strings.TrimPrefix(v, "default=") - } - } - - qv := r.FormValue(name) - if qv == "" { - if defaultv != "" { - qv = defaultv - } else if required { - return &ErrMissingRequiredValue{name} - } else { - continue - } - } - - if err := setFieldValue(fv, qv); errors.Is(err, &ErrInvalidValueType{}) { - e, _ := err.(*ErrInvalidValueType) - e.value = name - return e - } else if errors.Is(err, &ErrUnsuportedValueType{}) { - e, _ := err.(*ErrUnsuportedValueType) - e.value = name - return e - } else if err != nil { - return err - } - } - - return nil -} - -func RerrUnsmarshal(err error) rerrors.RouteError { - if e, ok := err.(*ErrMissingRequiredValue); ok { - return rerrors.MissingParameters([]string{e.value}) - } else if e, ok := err.(*ErrInvalidValueType); ok { - return rerrors.BadRequest(e.Error()) - } else { - return rerrors.InternalError(err) - } -} - -func setFieldValue(rv reflect.Value, v string) error { - switch rv.Kind() { - - case reflect.Pointer: - return setFieldValue(rv.Elem(), v) - - case reflect.String: - rv.SetString(v) - - case reflect.Bool: - if cv, err := strconv.ParseBool(v); err != nil { - return &ErrInvalidValueType{"bool", err, ""} - } else { - rv.SetBool(cv) - } - - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if cv, err := strconv.Atoi(v); err != nil { - return &ErrInvalidValueType{"int", err, ""} - } else { - rv.SetInt(int64(cv)) - } - - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if cv, err := strconv.Atoi(v); err != nil { - return &ErrInvalidValueType{"uint", err, ""} - } else { - rv.SetUint(uint64(cv)) - } - - case reflect.Float32, reflect.Float64: - if cv, err := strconv.ParseFloat(v, 64); err != nil { - return &ErrInvalidValueType{"float64", err, ""} - } else { - rv.SetFloat(cv) - } - - case reflect.Complex64, reflect.Complex128: - if cv, err := strconv.ParseComplex(v, 128); err != nil { - return &ErrInvalidValueType{"complex128", err, ""} - } else { - rv.SetComplex(cv) - } - - // TODO: Support strucys - // TODO: Support slices - // TODO: Support maps - default: - return &ErrUnsuportedValueType{ - []string{ - "string", - "bool", - "int", "int8", "int16", "int32", "int64", - "uint", "uint8", "uint16", "uint32", "uint64", - "float32", "float64", - "complex64", "complex64", - }, - "", - } - - } - - return nil -} - -type ErrInvalidValueType struct { - expected string - err error - value string -} - -func (e ErrInvalidValueType) Error() string { - return fmt.Sprintf( - "Value \"%s\" is a invalid type, expected type \"%s\". Got err: %s", - e.value, - e.expected, - e.err.Error(), - ) -} - -type ErrUnsuportedValueType struct { - supported []string - value string -} - -func (e ErrUnsuportedValueType) Error() string { - return fmt.Sprintf( - "Value \"%s\" is a unsupported type, supported types are: \"%s\"", - e.value, - strings.Join(e.supported, ", "), - ) -} - -type ErrMissingRequiredValue struct { - value string -} - -func (e ErrMissingRequiredValue) Error() string { - return fmt.Sprintf("Required value \"%s\" missing from query", e.value) -} - -var ( - ErrParseForm = errors.New("Failed to parse form from body or query parameters") - ErrReflectPanic = errors.New("Reflect panic while trying to parse request") -) diff --git a/groute/middleware/cache.go b/groute/middleware/cache.go deleted file mode 100644 index 78fe291..0000000 --- a/groute/middleware/cache.go +++ /dev/null @@ -1,12 +0,0 @@ -package middleware - -import ( - "net/http" -) - -func CacheMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=604800, stale-while-revalidate=86400, public") - next.ServeHTTP(w, r) - }) -} diff --git a/groute/middleware/dev.go b/groute/middleware/dev.go deleted file mode 100644 index fd95a8d..0000000 --- a/groute/middleware/dev.go +++ /dev/null @@ -1,73 +0,0 @@ -package middleware - -import ( - "log/slog" - "math/rand" - "net/http" -) - -func DevMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "no-store") - next.ServeHTTP(w, r) - }) -} - -type loggerReponse struct { - http.ResponseWriter - status int -} - -func (lr *loggerReponse) WriteHeader(s int) { - lr.status = s - lr.ResponseWriter.WriteHeader(s) -} - -func NewLoggerMiddleware(l *slog.Logger) Middleware { - l = l.WithGroup("logger_middleware") - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - id := randHash(5) - - l.Info("NEW REQUEST", - slog.String("id", id), - slog.String("status", "xxx"), - slog.String("method", r.Method), - slog.String("path", r.URL.Path), - ) - - lw := &loggerReponse{w, http.StatusOK} - next.ServeHTTP(lw, r) - - if lw.status >= 400 { - l.Warn("ERR REQUEST", - slog.String("id", id), - slog.Int("status", lw.status), - slog.String("method", r.Method), - slog.String("path", r.URL.Path), - ) - return - } - - l.Info("END REQUEST", - slog.String("id", id), - slog.Int("status", lw.status), - slog.String("method", r.Method), - slog.String("path", r.URL.Path), - ) - }) - } -} - -const HASH_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - -// This is not the most performant function, as a TODO we could -// improve based on this Stackoberflow thread: -// https://stackoverflow.com/questions/22892120/how-to-generate-a-random-string-of-a-fixed-length-in-go -func randHash(n int) string { - b := make([]byte, n) - for i := range b { - b[i] = HASH_CHARS[rand.Int63()%int64(len(HASH_CHARS))] - } - return string(b) -} diff --git a/groute/middleware/middleware.go b/groute/middleware/middleware.go deleted file mode 100644 index 517b791..0000000 --- a/groute/middleware/middleware.go +++ /dev/null @@ -1,108 +0,0 @@ -package middleware - -import ( - "errors" - "fmt" - "io" - "net/http" - "strconv" -) - -type Middleware = func(next http.Handler) http.Handler - -type MiddlewaredReponse struct { - w http.ResponseWriter - statuses []int - bodyWrites [][]byte -} - -func NewMiddlewaredResponse(w http.ResponseWriter) *MiddlewaredReponse { - return &MiddlewaredReponse{w, []int{500}, [][]byte{[]byte("")}} -} - -func (m *MiddlewaredReponse) WriteHeader(s int) { - m.Header().Set("Status", strconv.Itoa(s)) - m.statuses = append(m.statuses, s) -} - -func (m *MiddlewaredReponse) Header() http.Header { - return m.w.Header() -} - -func (m *MiddlewaredReponse) Write(b []byte) (int, error) { - m.bodyWrites = append(m.bodyWrites, b) - return len(b), nil -} - -func (m *MiddlewaredReponse) ReallyWriteHeader() (int, error) { - status := m.statuses[len(m.statuses)-1] - m.w.WriteHeader(status) - bytes := 0 - for _, b := range m.bodyWrites { - by, err := m.w.Write(b) - if err != nil { - return bytes, errors.Join( - fmt.Errorf( - "Failed to write to response in middleware."+ - "\nStatuses are %v"+ - "\nTried to write %v bytes"+ - "\nTried to write response:\n%s", - m.statuses, bytes, string(b), - ), - err, - ) - } - bytes += by - } - - return bytes, nil -} - -type multiResponseWriter struct { - response http.ResponseWriter - writers []io.Writer -} - -func MultiResponseWriter( - w http.ResponseWriter, - writers ...io.Writer, -) http.ResponseWriter { - if mw, ok := w.(*multiResponseWriter); ok { - mw.writers = append(mw.writers, writers...) - return mw - } - - allWriters := make([]io.Writer, 0, len(writers)) - for _, iow := range writers { - if mw, ok := iow.(*multiResponseWriter); ok { - allWriters = append(allWriters, mw.writers...) - } else { - allWriters = append(allWriters, iow) - } - } - - return &multiResponseWriter{w, allWriters} -} - -func (w *multiResponseWriter) WriteHeader(status int) { - w.Header().Set("Status", strconv.Itoa(status)) - w.response.WriteHeader(status) -} - -func (w *multiResponseWriter) Write(p []byte) (int, error) { - w.WriteHeader(http.StatusOK) - for _, w := range w.writers { - n, err := w.Write(p) - if err != nil { - return n, err - } - if n != len(p) { - return n, io.ErrShortWrite - } - } - return w.response.Write(p) -} - -func (w *multiResponseWriter) Header() http.Header { - return w.response.Header() -} diff --git a/groute/router/default.go b/groute/router/default.go deleted file mode 100644 index 81c2ca3..0000000 --- a/groute/router/default.go +++ /dev/null @@ -1,25 +0,0 @@ -package router - -import ( - "net/http" - - "forge.capytal.company/loreddev/x/groute/middleware" -) - -var DefaultRouter = NewRouter() - -func Handle(pattern string, handler http.Handler) { - DefaultRouter.Handle(pattern, handler) -} - -func HandleFunc(pattern string, handler http.HandlerFunc) { - DefaultRouter.HandleFunc(pattern, handler) -} - -func Use(m middleware.Middleware) { - DefaultRouter.Use(m) -} - -func ServeHTTP(w http.ResponseWriter, r *http.Request) { - DefaultRouter.ServeHTTP(w, r) -} diff --git a/groute/router/rerrors/400s.go b/groute/router/rerrors/400s.go deleted file mode 100644 index 6463484..0000000 --- a/groute/router/rerrors/400s.go +++ /dev/null @@ -1,43 +0,0 @@ -package rerrors - -import ( - "net/http" - "strconv" -) - -func BadRequest(reason ...string) RouteError { - info := map[string]any{} - - if len(reason) == 1 { - info["reason"] = reason[0] - } else if len(reason) > 1 { - for i, r := range reason { - info["reason_"+strconv.Itoa(i)] = r - } - } - - return NewRouteError(http.StatusBadRequest, "Bad Request", info) -} - -func NotFound() RouteError { - return NewRouteError(http.StatusNotFound, "Not Found", map[string]any{}) -} - -func MissingCookies(cookies []string) RouteError { - return NewRouteError(http.StatusBadRequest, "Missing cookies", map[string]any{ - "missing_cookies": cookies, - }) -} - -func MethodNowAllowed(method string, allowedMethods []string) RouteError { - return NewRouteError(http.StatusMethodNotAllowed, "Method not allowed", map[string]any{ - "method": method, - "allowed_methods": allowedMethods, - }) -} - -func MissingParameters(params []string) RouteError { - return NewRouteError(http.StatusBadRequest, "Missing parameters", map[string]any{ - "missing_parameters": params, - }) -} diff --git a/groute/router/rerrors/500s.go b/groute/router/rerrors/500s.go deleted file mode 100644 index a705a1e..0000000 --- a/groute/router/rerrors/500s.go +++ /dev/null @@ -1,14 +0,0 @@ -package rerrors - -import ( - "errors" - "net/http" -) - -func InternalError(errs ...error) RouteError { - err := errors.Join(errs...) - return NewRouteError(http.StatusInternalServerError, "Internal server error", map[string]any{ - "errors": err, - "errors_desc": err.Error(), - }) -} diff --git a/groute/router/rerrors/errors.go b/groute/router/rerrors/errors.go deleted file mode 100644 index ca67bd1..0000000 --- a/groute/router/rerrors/errors.go +++ /dev/null @@ -1,171 +0,0 @@ -package rerrors - -import ( - "encoding/base64" - "encoding/json" - "fmt" - "log/slog" - "net/http" - "strings" - - "forge.capytal.company/loreddev/x/groute/middleware" - "github.com/a-h/templ" -) - -const ( - ERROR_MIDDLEWARE_HEADER = "XX-Error-Middleware" - ERROR_VALUE_HEADER = "X-Error-Value" -) - -type RouteError struct { - StatusCode int `json:"status_code"` - Err string `json:"error"` - Info map[string]any `json:"info"` - Endpoint string -} - -func NewRouteError(status int, error string, info ...map[string]any) RouteError { - rerr := RouteError{StatusCode: status, Err: error} - if len(info) > 0 { - rerr.Info = info[0] - } else { - rerr.Info = map[string]any{} - } - return rerr -} - -func (rerr RouteError) Error() string { - return fmt.Sprintf("route error %d %s: %v", rerr.StatusCode, rerr.Endpoint, rerr.Info) -} - -func (rerr RouteError) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if rerr.StatusCode == 0 { - rerr.StatusCode = http.StatusNotImplemented - } - - if rerr.Err == "" { - rerr.Err = "MISSING ERROR DESCRIPTION" - } - - if rerr.Info == nil { - rerr.Info = map[string]any{} - } - - j, err := json.Marshal(rerr) - if err != nil { - j, _ = json.Marshal(RouteError{ - StatusCode: http.StatusInternalServerError, - Err: "Failed to marshal error message to JSON", - Info: map[string]any{ - "source_value": fmt.Sprintf("%#v", rerr), - "error": err.Error(), - }, - }) - } - - if r.Header.Get(ERROR_MIDDLEWARE_HEADER) == "enable" && prefersHtml(r.Header) { - q := r.URL.Query() - q.Set("error", base64.URLEncoding.EncodeToString(j)) - r.URL.RawQuery = q.Encode() - - http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect) - return - } - - w.Header().Set("Content-Type", "application/json") - - w.WriteHeader(rerr.StatusCode) - if _, err = w.Write(j); err != nil { - _, _ = w.Write([]byte("Failed to write error JSON string to body")) - } -} - -type ErrorMiddlewarePage func(err RouteError) templ.Component - -type ErrorDisplayer struct { - log *slog.Logger - page ErrorMiddlewarePage -} - -func (h ErrorDisplayer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - e, err := base64.URLEncoding.DecodeString(r.URL.Query().Get("error")) - if err != nil { - h.log.Error("Failed to decode \"error\" parameter from error redirect", - slog.String("method", r.Method), - slog.String("path", r.URL.Path), - slog.Int("status", 0), - slog.String("data", string(e)), - ) - w.WriteHeader(http.StatusInternalServerError) - _, _ = w.Write([]byte( - fmt.Sprintf("Data %s\nError %s", string(e), err.Error()), - )) - return - } - - var rerr RouteError - if err := json.Unmarshal(e, &rerr); err != nil { - h.log.Error("Failed to decode \"error\" parameter from error redirect", - slog.String("method", r.Method), - slog.String("path", r.URL.Path), - slog.Int("status", 0), - slog.String("data", string(e)), - ) - w.WriteHeader(http.StatusInternalServerError) - _, _ = w.Write([]byte( - fmt.Sprintf("Data %s\nError %s", string(e), err.Error()), - )) - return - } - - if rerr.Endpoint == "" { - q := r.URL.Query() - q.Del("error") - r.URL.RawQuery = q.Encode() - - rerr.Endpoint = r.URL.String() - } - - w.WriteHeader(rerr.StatusCode) - if err := h.page(rerr).Render(r.Context(), w); err != nil { - _, _ = w.Write(e) - } -} - -func NewErrorMiddleware( - p ErrorMiddlewarePage, - l *slog.Logger, - notfound ...ErrorMiddlewarePage, -) middleware.Middleware { - var nf ErrorMiddlewarePage - if len(notfound) > 0 { - nf = notfound[0] - } else { - nf = p - } - - l = l.WithGroup("error_middleware") - - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r.Header.Set(ERROR_MIDDLEWARE_HEADER, "enable") - - if uerr := r.URL.Query().Get("error"); uerr != "" && prefersHtml(r.Header) { - ErrorDisplayer{l, nf}.ServeHTTP(w, r) - return - } - - next.ServeHTTP(w, r) - }) - } -} - -func prefersHtml(h http.Header) bool { - if h.Get("Accept") == "" { - return false - } - return (strings.Contains(h.Get("Accept"), "text/html") || - strings.Contains(h.Get("Accept"), "application/xhtml+xml") || - strings.Contains(h.Get("Accept"), "application/xml")) && - !strings.Contains(h.Get("Accept"), "application/json") -} diff --git a/groute/router/router.go b/groute/router/router.go deleted file mode 100644 index eb30b41..0000000 --- a/groute/router/router.go +++ /dev/null @@ -1,222 +0,0 @@ -package router - -import ( - "fmt" - "net/http" - "path" - "strings" - - "forge.capytal.company/loreddev/x/groute/middleware" -) - -type Router interface { - Handle(pattern string, handler http.Handler) - HandleFunc(pattern string, handler http.HandlerFunc) - - Use(middleware middleware.Middleware) - - http.Handler -} - -type RouterWithRoutes interface { - Router - Routes() []Route -} - -type RouterWithMiddlewares interface { - RouterWithRoutes - Middlewares() []middleware.Middleware -} - -type RouterWithMiddlewaresWrapper interface { - RouterWithMiddlewares - WrapMiddlewares(ms []middleware.Middleware, h http.Handler) http.Handler -} - -type Route struct { - Path string - Method string - Host string - Handler http.Handler -} - -func NewRouter(mux ...*http.ServeMux) Router { - var m *http.ServeMux - if len(mux) > 0 { - m = mux[0] - } else { - m = http.NewServeMux() - } - - return &defaultRouter{ - m, - []middleware.Middleware{}, - map[string]Route{}, - } -} - -type defaultRouter struct { - mux *http.ServeMux - middlewares []middleware.Middleware - routes map[string]Route -} - -func (r *defaultRouter) Handle(pattern string, h http.Handler) { - if sr, ok := h.(Router); ok { - r.handleRouter(pattern, sr) - } else { - r.handle(pattern, h) - } -} - -func (r *defaultRouter) HandleFunc(pattern string, hf http.HandlerFunc) { - r.handle(pattern, hf) -} - -func (r *defaultRouter) Use(m middleware.Middleware) { - r.middlewares = append(r.middlewares, m) -} - -func (r *defaultRouter) Routes() []Route { - rs := make([]Route, len(r.routes)) - i := 0 - for _, r := range r.routes { - rs[i] = r - i++ - } - return rs -} - -func (r *defaultRouter) Middlewares() []middleware.Middleware { - return r.middlewares -} - -func (r defaultRouter) WrapMiddlewares(ms []middleware.Middleware, h http.Handler) http.Handler { - hf := h - for _, m := range ms { - hf = m(hf) - } - return hf -} - -func (r *defaultRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { - r.mux.ServeHTTP(w, req) -} - -func (r defaultRouter) handle(pattern string, hf http.Handler) { - m, h, p := r.parsePattern(pattern) - rt := Route{ - Method: m, - Host: h, - Path: p, - Handler: hf, - } - r.handleRoute(rt) -} - -func (r defaultRouter) handleRouter(pattern string, rr Router) { - m, h, p := r.parsePattern(pattern) - - rs, ok := rr.(RouterWithRoutes) - if !ok { - r.handle(p, rr) - } - - routes := rs.Routes() - middlewares := []middleware.Middleware{} - if rw, ok := rs.(RouterWithMiddlewares); ok { - middlewares = rw.Middlewares() - } - - wrap := r.WrapMiddlewares - if rw, ok := rs.(RouterWithMiddlewaresWrapper); ok { - wrap = rw.WrapMiddlewares - } - - for _, route := range routes { - route.Handler = wrap(middlewares, route.Handler) - route.Path = path.Join(p, route.Path) - - if m != "" && route.Method != "" && m != route.Method { - panic( - fmt.Sprintf( - "Nested router's route has incompatible method than defined in path %q."+ - "Router's route method is %q, while path's is %q", - p, route.Method, m, - ), - ) - } - if h != "" && route.Host != "" && h != route.Host { - panic( - fmt.Sprintf( - "Nested router's route has incompatible host than defined in path %q."+ - "Router's route host is %q, while path's is %q", - p, route.Host, h, - ), - ) - } - - r.handleRoute(route) - } -} - -func (r defaultRouter) handleRoute(rt Route) { - if len(r.middlewares) > 0 { - rt.Handler = r.WrapMiddlewares(r.middlewares, rt.Handler) - } - - if rt.Path == "" || !strings.HasPrefix(rt.Path, "/") { - panic( - fmt.Sprintf( - "INVALID STATE: Path of route (%#v) does not start with a leading slash", - rt, - ), - ) - } - - p := rt.Path - if rt.Host != "" { - p = fmt.Sprintf("%s%s", rt.Host, p) - } - if rt.Method != "" { - p = fmt.Sprintf("%s %s", rt.Method, p) - } - - if !strings.HasSuffix(p, "/") { - p = fmt.Sprintf("%s/", p) - } - - if strings.HasSuffix(p, "...}/") { - p = strings.TrimSuffix(p, "/") - } - - r.routes[p] = rt - r.mux.Handle(p, rt.Handler) -} - -func (r *defaultRouter) parsePattern(pattern string) (method, host, p string) { - pattern = strings.TrimSpace(pattern) - - // ServerMux patterns are "[METHOD ][HOST]/[PATH]", so to parsing it, we must - // first split it between "[METHOD ][HOST]" and "[PATH]" - ps := strings.Split(pattern, "/") - - p = path.Join("/", strings.Join(ps[1:], "/")) - - // If "[METHOD ][HOST]" is empty, we just have the path and can send it back - if ps[0] == "" { - return "", "", p - } - - // Split string again, if method is not defined, this will end up being just []string{"[HOST]"} - // since there isn't a space before the host. If there is a method defined, this will end up as - // []string{"[METHOD]","[HOST]"}, with "[HOST]" being possibly a empty string. - mh := strings.Split(ps[0], " ") - - // If slice is of length 1, this means it is []string{"[HOST]"} - if len(mh) == 1 { - return "", host, p - } - - return mh[0], mh[1], p -}