chore(groute): remove unused groute package
This commit is contained in:
@@ -1,298 +0,0 @@
|
||||
package cookies
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"forge.capytal.company/loreddev/x/groute/router/rerrors"
|
||||
)
|
||||
|
||||
type Marshaler interface {
|
||||
MarshalCookie() (*http.Cookie, error)
|
||||
}
|
||||
|
||||
type Unmarshaler interface {
|
||||
UnmarshalCookie(*http.Cookie) error
|
||||
}
|
||||
|
||||
func Marshal(v any) (*http.Cookie, error) {
|
||||
if m, ok := v.(Marshaler); ok {
|
||||
return m.MarshalCookie()
|
||||
}
|
||||
|
||||
c, err := marshalValue(v)
|
||||
if err != nil {
|
||||
return c, err
|
||||
}
|
||||
|
||||
if err := setCookieProps(c, v); err != nil {
|
||||
return c, err
|
||||
}
|
||||
|
||||
return c, err
|
||||
}
|
||||
|
||||
func MarshalToWriter(v any, w http.ResponseWriter) error {
|
||||
if ck, err := Marshal(v); err != nil {
|
||||
return err
|
||||
} else {
|
||||
http.SetCookie(w, ck)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Unmarshal(c *http.Cookie, v any) error {
|
||||
if m, ok := v.(Unmarshaler); ok {
|
||||
return m.UnmarshalCookie(c)
|
||||
}
|
||||
|
||||
value := c.Value
|
||||
b, err := base64.URLEncoding.DecodeString(value)
|
||||
if err != nil {
|
||||
return errors.Join(ErrDecodeBase64, err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(b, v); err != nil {
|
||||
return errors.Join(ErrUnmarshal, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func UnmarshalRequest(r *http.Request, v any) error {
|
||||
name, err := getCookieName(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c, err := r.Cookie(name)
|
||||
if errors.Is(err, http.ErrNoCookie) {
|
||||
return ErrNoCookie{name}
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return Unmarshal(c, v)
|
||||
}
|
||||
|
||||
func UnmarshalIfRequest(r *http.Request, v any) (bool, error) {
|
||||
if err := UnmarshalRequest(r, v); err != nil {
|
||||
if _, ok := err.(ErrNoCookie); ok {
|
||||
return false, nil
|
||||
} else {
|
||||
return true, err
|
||||
}
|
||||
} else {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
func RerrUnmarshalCookie(err error) rerrors.RouteError {
|
||||
if e, ok := err.(ErrNoCookie); ok {
|
||||
return rerrors.MissingCookies([]string{e.name})
|
||||
} else {
|
||||
return rerrors.InternalError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func marshalValue(v any) (*http.Cookie, error) {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return &http.Cookie{}, errors.Join(ErrMarshal, err)
|
||||
}
|
||||
|
||||
s := base64.URLEncoding.EncodeToString(b)
|
||||
|
||||
return &http.Cookie{
|
||||
Value: s,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var COOKIE_EXPIRE_VALID_FORMATS = []string{
|
||||
time.DateOnly, time.DateTime,
|
||||
time.RFC1123, time.RFC1123Z,
|
||||
}
|
||||
|
||||
func setCookieProps(c *http.Cookie, v any) error {
|
||||
tag, err := getCookieTag(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Name, err = getCookieName(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tvs := strings.Split(tag, ",")
|
||||
|
||||
if len(tvs) == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tvs = tvs[1:]
|
||||
|
||||
for _, tv := range tvs {
|
||||
var k, v string
|
||||
if strings.Contains(tv, "=") {
|
||||
s := strings.Split(tv, "=")
|
||||
k = s[0]
|
||||
v = s[1]
|
||||
} else {
|
||||
k = tv
|
||||
v = ""
|
||||
}
|
||||
|
||||
switch k {
|
||||
case "SECURE":
|
||||
c.Name = "__Secure-" + c.Name
|
||||
c.Secure = true
|
||||
|
||||
case "HOST":
|
||||
c.Name = "__Host" + c.Name
|
||||
c.Secure = true
|
||||
c.Path = "/"
|
||||
|
||||
case "path":
|
||||
c.Path = v
|
||||
|
||||
case "domain":
|
||||
c.Domain = v
|
||||
|
||||
case "httponly":
|
||||
if v == "" {
|
||||
c.HttpOnly = true
|
||||
} else if v, err := strconv.ParseBool(v); err != nil {
|
||||
c.HttpOnly = false
|
||||
} else {
|
||||
c.HttpOnly = v
|
||||
}
|
||||
|
||||
case "samesite":
|
||||
if v == "" {
|
||||
c.SameSite = http.SameSiteDefaultMode
|
||||
} else if v == "strict" {
|
||||
c.SameSite = http.SameSiteStrictMode
|
||||
} else if v == "lax" {
|
||||
c.SameSite = http.SameSiteLaxMode
|
||||
} else {
|
||||
c.SameSite = http.SameSiteNoneMode
|
||||
}
|
||||
case "secure":
|
||||
if v == "" {
|
||||
c.Secure = true
|
||||
} else if v, err := strconv.ParseBool(v); err != nil {
|
||||
c.Secure = false
|
||||
} else {
|
||||
c.Secure = v
|
||||
}
|
||||
|
||||
case "max-age", "age":
|
||||
if v == "" {
|
||||
c.MaxAge = 0
|
||||
} else if v, err := strconv.Atoi(v); err != nil {
|
||||
c.MaxAge = 0
|
||||
} else {
|
||||
c.MaxAge = v
|
||||
}
|
||||
|
||||
case "expires":
|
||||
if v == "" {
|
||||
c.Expires = time.Now()
|
||||
} else if v, err := timeParseMultiple(v, COOKIE_EXPIRE_VALID_FORMATS...); err != nil {
|
||||
c.Expires = time.Now()
|
||||
} else {
|
||||
c.Expires = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getCookieName(v any) (name string, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = errors.Join(ErrReflectPanic, fmt.Errorf("Panic recovered: %#v", r))
|
||||
}
|
||||
}()
|
||||
|
||||
tag, err := getCookieTag(v)
|
||||
if err != nil {
|
||||
return name, err
|
||||
}
|
||||
|
||||
tvs := strings.Split(tag, ",")
|
||||
if len(tvs) == 0 {
|
||||
t := reflect.TypeOf(v)
|
||||
name = t.Name()
|
||||
} else {
|
||||
name = tvs[0]
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
return name, ErrMissingName
|
||||
}
|
||||
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func getCookieTag(v any) (t string, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = errors.Join(ErrReflectPanic, fmt.Errorf("Panic recovered: %#v", r))
|
||||
}
|
||||
}()
|
||||
|
||||
rt := reflect.TypeOf(v)
|
||||
|
||||
if rt.Kind() == reflect.Pointer {
|
||||
rt = rt.Elem()
|
||||
}
|
||||
|
||||
for i := 0; i < rt.NumField(); i++ {
|
||||
ft := rt.Field(i)
|
||||
if t := ft.Tag.Get("cookie"); t != "" {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func timeParseMultiple(v string, formats ...string) (time.Time, error) {
|
||||
errs := []error{}
|
||||
for _, f := range formats {
|
||||
t, err := time.Parse(v, f)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
} else {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
return time.Time{}, errs[len(errs)-1]
|
||||
}
|
||||
|
||||
var (
|
||||
ErrDecodeBase64 = errors.New("Failed to decode base64 string from cookie value")
|
||||
ErrMarshal = errors.New("Failed to marhal JSON value for cookie value")
|
||||
ErrUnmarshal = errors.New("Failed to unmarshal JSON value from cookie value")
|
||||
ErrReflectPanic = errors.New("Reflect panic while trying to get tag from value")
|
||||
ErrMissingName = errors.New("Failed to get name of cookie")
|
||||
)
|
||||
|
||||
type ErrNoCookie struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (e ErrNoCookie) Error() string {
|
||||
return fmt.Sprintf("Cookie \"%s\" missing from request", e.name)
|
||||
}
|
||||
@@ -1,215 +0,0 @@
|
||||
package forms
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"forge.capytal.company/loreddev/x/groute/router/rerrors"
|
||||
)
|
||||
|
||||
type Unmarshaler interface {
|
||||
UnmarshalForm(r *http.Request) error
|
||||
}
|
||||
|
||||
func Unmarshal(r *http.Request, v any) (err error) {
|
||||
if u, ok := v.(Unmarshaler); ok {
|
||||
return u.UnmarshalForm(r)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = errors.Join(ErrReflectPanic, fmt.Errorf("Panic recovered: %#v", r))
|
||||
}
|
||||
}()
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() == reflect.Pointer {
|
||||
rv = rv.Elem()
|
||||
}
|
||||
rt := rv.Type()
|
||||
|
||||
for i := 0; i < rv.NumField(); i++ {
|
||||
ft := rt.Field(i)
|
||||
fv := rv.FieldByName(ft.Name)
|
||||
|
||||
log.Print(ft.Name)
|
||||
|
||||
if !fv.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO: Support embedded fields
|
||||
if ft.Anonymous {
|
||||
continue
|
||||
}
|
||||
|
||||
var tv string
|
||||
if t := ft.Tag.Get("form"); t != "" {
|
||||
tv = t
|
||||
} else if t = ft.Tag.Get("query"); t != "" {
|
||||
tv = t
|
||||
} else {
|
||||
tv = ft.Name
|
||||
}
|
||||
|
||||
tvs := strings.Split(tv, ",")
|
||||
|
||||
name := tvs[0]
|
||||
required := false
|
||||
defaultv := ""
|
||||
|
||||
for _, v := range tvs {
|
||||
if v == "required" {
|
||||
required = true
|
||||
} else if strings.HasPrefix(v, "default=") {
|
||||
defaultv = strings.TrimPrefix(v, "default=")
|
||||
}
|
||||
}
|
||||
|
||||
qv := r.FormValue(name)
|
||||
if qv == "" {
|
||||
if defaultv != "" {
|
||||
qv = defaultv
|
||||
} else if required {
|
||||
return &ErrMissingRequiredValue{name}
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err := setFieldValue(fv, qv); errors.Is(err, &ErrInvalidValueType{}) {
|
||||
e, _ := err.(*ErrInvalidValueType)
|
||||
e.value = name
|
||||
return e
|
||||
} else if errors.Is(err, &ErrUnsuportedValueType{}) {
|
||||
e, _ := err.(*ErrUnsuportedValueType)
|
||||
e.value = name
|
||||
return e
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RerrUnsmarshal(err error) rerrors.RouteError {
|
||||
if e, ok := err.(*ErrMissingRequiredValue); ok {
|
||||
return rerrors.MissingParameters([]string{e.value})
|
||||
} else if e, ok := err.(*ErrInvalidValueType); ok {
|
||||
return rerrors.BadRequest(e.Error())
|
||||
} else {
|
||||
return rerrors.InternalError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func setFieldValue(rv reflect.Value, v string) error {
|
||||
switch rv.Kind() {
|
||||
|
||||
case reflect.Pointer:
|
||||
return setFieldValue(rv.Elem(), v)
|
||||
|
||||
case reflect.String:
|
||||
rv.SetString(v)
|
||||
|
||||
case reflect.Bool:
|
||||
if cv, err := strconv.ParseBool(v); err != nil {
|
||||
return &ErrInvalidValueType{"bool", err, ""}
|
||||
} else {
|
||||
rv.SetBool(cv)
|
||||
}
|
||||
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
if cv, err := strconv.Atoi(v); err != nil {
|
||||
return &ErrInvalidValueType{"int", err, ""}
|
||||
} else {
|
||||
rv.SetInt(int64(cv))
|
||||
}
|
||||
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
if cv, err := strconv.Atoi(v); err != nil {
|
||||
return &ErrInvalidValueType{"uint", err, ""}
|
||||
} else {
|
||||
rv.SetUint(uint64(cv))
|
||||
}
|
||||
|
||||
case reflect.Float32, reflect.Float64:
|
||||
if cv, err := strconv.ParseFloat(v, 64); err != nil {
|
||||
return &ErrInvalidValueType{"float64", err, ""}
|
||||
} else {
|
||||
rv.SetFloat(cv)
|
||||
}
|
||||
|
||||
case reflect.Complex64, reflect.Complex128:
|
||||
if cv, err := strconv.ParseComplex(v, 128); err != nil {
|
||||
return &ErrInvalidValueType{"complex128", err, ""}
|
||||
} else {
|
||||
rv.SetComplex(cv)
|
||||
}
|
||||
|
||||
// TODO: Support strucys
|
||||
// TODO: Support slices
|
||||
// TODO: Support maps
|
||||
default:
|
||||
return &ErrUnsuportedValueType{
|
||||
[]string{
|
||||
"string",
|
||||
"bool",
|
||||
"int", "int8", "int16", "int32", "int64",
|
||||
"uint", "uint8", "uint16", "uint32", "uint64",
|
||||
"float32", "float64",
|
||||
"complex64", "complex64",
|
||||
},
|
||||
"",
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type ErrInvalidValueType struct {
|
||||
expected string
|
||||
err error
|
||||
value string
|
||||
}
|
||||
|
||||
func (e ErrInvalidValueType) Error() string {
|
||||
return fmt.Sprintf(
|
||||
"Value \"%s\" is a invalid type, expected type \"%s\". Got err: %s",
|
||||
e.value,
|
||||
e.expected,
|
||||
e.err.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
type ErrUnsuportedValueType struct {
|
||||
supported []string
|
||||
value string
|
||||
}
|
||||
|
||||
func (e ErrUnsuportedValueType) Error() string {
|
||||
return fmt.Sprintf(
|
||||
"Value \"%s\" is a unsupported type, supported types are: \"%s\"",
|
||||
e.value,
|
||||
strings.Join(e.supported, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
type ErrMissingRequiredValue struct {
|
||||
value string
|
||||
}
|
||||
|
||||
func (e ErrMissingRequiredValue) Error() string {
|
||||
return fmt.Sprintf("Required value \"%s\" missing from query", e.value)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrParseForm = errors.New("Failed to parse form from body or query parameters")
|
||||
ErrReflectPanic = errors.New("Reflect panic while trying to parse request")
|
||||
)
|
||||
@@ -1,12 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func CacheMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Cache-Control", "max-age=604800, stale-while-revalidate=86400, public")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func DevMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
type loggerReponse struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (lr *loggerReponse) WriteHeader(s int) {
|
||||
lr.status = s
|
||||
lr.ResponseWriter.WriteHeader(s)
|
||||
}
|
||||
|
||||
func NewLoggerMiddleware(l *slog.Logger) Middleware {
|
||||
l = l.WithGroup("logger_middleware")
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
id := randHash(5)
|
||||
|
||||
l.Info("NEW REQUEST",
|
||||
slog.String("id", id),
|
||||
slog.String("status", "xxx"),
|
||||
slog.String("method", r.Method),
|
||||
slog.String("path", r.URL.Path),
|
||||
)
|
||||
|
||||
lw := &loggerReponse{w, http.StatusOK}
|
||||
next.ServeHTTP(lw, r)
|
||||
|
||||
if lw.status >= 400 {
|
||||
l.Warn("ERR REQUEST",
|
||||
slog.String("id", id),
|
||||
slog.Int("status", lw.status),
|
||||
slog.String("method", r.Method),
|
||||
slog.String("path", r.URL.Path),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
l.Info("END REQUEST",
|
||||
slog.String("id", id),
|
||||
slog.Int("status", lw.status),
|
||||
slog.String("method", r.Method),
|
||||
slog.String("path", r.URL.Path),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const HASH_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
// This is not the most performant function, as a TODO we could
|
||||
// improve based on this Stackoberflow thread:
|
||||
// https://stackoverflow.com/questions/22892120/how-to-generate-a-random-string-of-a-fixed-length-in-go
|
||||
func randHash(n int) string {
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = HASH_CHARS[rand.Int63()%int64(len(HASH_CHARS))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
@@ -1,108 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Middleware = func(next http.Handler) http.Handler
|
||||
|
||||
type MiddlewaredReponse struct {
|
||||
w http.ResponseWriter
|
||||
statuses []int
|
||||
bodyWrites [][]byte
|
||||
}
|
||||
|
||||
func NewMiddlewaredResponse(w http.ResponseWriter) *MiddlewaredReponse {
|
||||
return &MiddlewaredReponse{w, []int{500}, [][]byte{[]byte("")}}
|
||||
}
|
||||
|
||||
func (m *MiddlewaredReponse) WriteHeader(s int) {
|
||||
m.Header().Set("Status", strconv.Itoa(s))
|
||||
m.statuses = append(m.statuses, s)
|
||||
}
|
||||
|
||||
func (m *MiddlewaredReponse) Header() http.Header {
|
||||
return m.w.Header()
|
||||
}
|
||||
|
||||
func (m *MiddlewaredReponse) Write(b []byte) (int, error) {
|
||||
m.bodyWrites = append(m.bodyWrites, b)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (m *MiddlewaredReponse) ReallyWriteHeader() (int, error) {
|
||||
status := m.statuses[len(m.statuses)-1]
|
||||
m.w.WriteHeader(status)
|
||||
bytes := 0
|
||||
for _, b := range m.bodyWrites {
|
||||
by, err := m.w.Write(b)
|
||||
if err != nil {
|
||||
return bytes, errors.Join(
|
||||
fmt.Errorf(
|
||||
"Failed to write to response in middleware."+
|
||||
"\nStatuses are %v"+
|
||||
"\nTried to write %v bytes"+
|
||||
"\nTried to write response:\n%s",
|
||||
m.statuses, bytes, string(b),
|
||||
),
|
||||
err,
|
||||
)
|
||||
}
|
||||
bytes += by
|
||||
}
|
||||
|
||||
return bytes, nil
|
||||
}
|
||||
|
||||
type multiResponseWriter struct {
|
||||
response http.ResponseWriter
|
||||
writers []io.Writer
|
||||
}
|
||||
|
||||
func MultiResponseWriter(
|
||||
w http.ResponseWriter,
|
||||
writers ...io.Writer,
|
||||
) http.ResponseWriter {
|
||||
if mw, ok := w.(*multiResponseWriter); ok {
|
||||
mw.writers = append(mw.writers, writers...)
|
||||
return mw
|
||||
}
|
||||
|
||||
allWriters := make([]io.Writer, 0, len(writers))
|
||||
for _, iow := range writers {
|
||||
if mw, ok := iow.(*multiResponseWriter); ok {
|
||||
allWriters = append(allWriters, mw.writers...)
|
||||
} else {
|
||||
allWriters = append(allWriters, iow)
|
||||
}
|
||||
}
|
||||
|
||||
return &multiResponseWriter{w, allWriters}
|
||||
}
|
||||
|
||||
func (w *multiResponseWriter) WriteHeader(status int) {
|
||||
w.Header().Set("Status", strconv.Itoa(status))
|
||||
w.response.WriteHeader(status)
|
||||
}
|
||||
|
||||
func (w *multiResponseWriter) Write(p []byte) (int, error) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
for _, w := range w.writers {
|
||||
n, err := w.Write(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
if n != len(p) {
|
||||
return n, io.ErrShortWrite
|
||||
}
|
||||
}
|
||||
return w.response.Write(p)
|
||||
}
|
||||
|
||||
func (w *multiResponseWriter) Header() http.Header {
|
||||
return w.response.Header()
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"forge.capytal.company/loreddev/x/groute/middleware"
|
||||
)
|
||||
|
||||
var DefaultRouter = NewRouter()
|
||||
|
||||
func Handle(pattern string, handler http.Handler) {
|
||||
DefaultRouter.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func HandleFunc(pattern string, handler http.HandlerFunc) {
|
||||
DefaultRouter.HandleFunc(pattern, handler)
|
||||
}
|
||||
|
||||
func Use(m middleware.Middleware) {
|
||||
DefaultRouter.Use(m)
|
||||
}
|
||||
|
||||
func ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
DefaultRouter.ServeHTTP(w, r)
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
package rerrors
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func BadRequest(reason ...string) RouteError {
|
||||
info := map[string]any{}
|
||||
|
||||
if len(reason) == 1 {
|
||||
info["reason"] = reason[0]
|
||||
} else if len(reason) > 1 {
|
||||
for i, r := range reason {
|
||||
info["reason_"+strconv.Itoa(i)] = r
|
||||
}
|
||||
}
|
||||
|
||||
return NewRouteError(http.StatusBadRequest, "Bad Request", info)
|
||||
}
|
||||
|
||||
func NotFound() RouteError {
|
||||
return NewRouteError(http.StatusNotFound, "Not Found", map[string]any{})
|
||||
}
|
||||
|
||||
func MissingCookies(cookies []string) RouteError {
|
||||
return NewRouteError(http.StatusBadRequest, "Missing cookies", map[string]any{
|
||||
"missing_cookies": cookies,
|
||||
})
|
||||
}
|
||||
|
||||
func MethodNowAllowed(method string, allowedMethods []string) RouteError {
|
||||
return NewRouteError(http.StatusMethodNotAllowed, "Method not allowed", map[string]any{
|
||||
"method": method,
|
||||
"allowed_methods": allowedMethods,
|
||||
})
|
||||
}
|
||||
|
||||
func MissingParameters(params []string) RouteError {
|
||||
return NewRouteError(http.StatusBadRequest, "Missing parameters", map[string]any{
|
||||
"missing_parameters": params,
|
||||
})
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package rerrors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func InternalError(errs ...error) RouteError {
|
||||
err := errors.Join(errs...)
|
||||
return NewRouteError(http.StatusInternalServerError, "Internal server error", map[string]any{
|
||||
"errors": err,
|
||||
"errors_desc": err.Error(),
|
||||
})
|
||||
}
|
||||
@@ -1,171 +0,0 @@
|
||||
package rerrors
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"forge.capytal.company/loreddev/x/groute/middleware"
|
||||
"github.com/a-h/templ"
|
||||
)
|
||||
|
||||
const (
|
||||
ERROR_MIDDLEWARE_HEADER = "XX-Error-Middleware"
|
||||
ERROR_VALUE_HEADER = "X-Error-Value"
|
||||
)
|
||||
|
||||
type RouteError struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Err string `json:"error"`
|
||||
Info map[string]any `json:"info"`
|
||||
Endpoint string
|
||||
}
|
||||
|
||||
func NewRouteError(status int, error string, info ...map[string]any) RouteError {
|
||||
rerr := RouteError{StatusCode: status, Err: error}
|
||||
if len(info) > 0 {
|
||||
rerr.Info = info[0]
|
||||
} else {
|
||||
rerr.Info = map[string]any{}
|
||||
}
|
||||
return rerr
|
||||
}
|
||||
|
||||
func (rerr RouteError) Error() string {
|
||||
return fmt.Sprintf("route error %d %s: %v", rerr.StatusCode, rerr.Endpoint, rerr.Info)
|
||||
}
|
||||
|
||||
func (rerr RouteError) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if rerr.StatusCode == 0 {
|
||||
rerr.StatusCode = http.StatusNotImplemented
|
||||
}
|
||||
|
||||
if rerr.Err == "" {
|
||||
rerr.Err = "MISSING ERROR DESCRIPTION"
|
||||
}
|
||||
|
||||
if rerr.Info == nil {
|
||||
rerr.Info = map[string]any{}
|
||||
}
|
||||
|
||||
j, err := json.Marshal(rerr)
|
||||
if err != nil {
|
||||
j, _ = json.Marshal(RouteError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Err: "Failed to marshal error message to JSON",
|
||||
Info: map[string]any{
|
||||
"source_value": fmt.Sprintf("%#v", rerr),
|
||||
"error": err.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if r.Header.Get(ERROR_MIDDLEWARE_HEADER) == "enable" && prefersHtml(r.Header) {
|
||||
q := r.URL.Query()
|
||||
q.Set("error", base64.URLEncoding.EncodeToString(j))
|
||||
r.URL.RawQuery = q.Encode()
|
||||
|
||||
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
w.WriteHeader(rerr.StatusCode)
|
||||
if _, err = w.Write(j); err != nil {
|
||||
_, _ = w.Write([]byte("Failed to write error JSON string to body"))
|
||||
}
|
||||
}
|
||||
|
||||
type ErrorMiddlewarePage func(err RouteError) templ.Component
|
||||
|
||||
type ErrorDisplayer struct {
|
||||
log *slog.Logger
|
||||
page ErrorMiddlewarePage
|
||||
}
|
||||
|
||||
func (h ErrorDisplayer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
e, err := base64.URLEncoding.DecodeString(r.URL.Query().Get("error"))
|
||||
if err != nil {
|
||||
h.log.Error("Failed to decode \"error\" parameter from error redirect",
|
||||
slog.String("method", r.Method),
|
||||
slog.String("path", r.URL.Path),
|
||||
slog.Int("status", 0),
|
||||
slog.String("data", string(e)),
|
||||
)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(
|
||||
fmt.Sprintf("Data %s\nError %s", string(e), err.Error()),
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
var rerr RouteError
|
||||
if err := json.Unmarshal(e, &rerr); err != nil {
|
||||
h.log.Error("Failed to decode \"error\" parameter from error redirect",
|
||||
slog.String("method", r.Method),
|
||||
slog.String("path", r.URL.Path),
|
||||
slog.Int("status", 0),
|
||||
slog.String("data", string(e)),
|
||||
)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(
|
||||
fmt.Sprintf("Data %s\nError %s", string(e), err.Error()),
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
if rerr.Endpoint == "" {
|
||||
q := r.URL.Query()
|
||||
q.Del("error")
|
||||
r.URL.RawQuery = q.Encode()
|
||||
|
||||
rerr.Endpoint = r.URL.String()
|
||||
}
|
||||
|
||||
w.WriteHeader(rerr.StatusCode)
|
||||
if err := h.page(rerr).Render(r.Context(), w); err != nil {
|
||||
_, _ = w.Write(e)
|
||||
}
|
||||
}
|
||||
|
||||
func NewErrorMiddleware(
|
||||
p ErrorMiddlewarePage,
|
||||
l *slog.Logger,
|
||||
notfound ...ErrorMiddlewarePage,
|
||||
) middleware.Middleware {
|
||||
var nf ErrorMiddlewarePage
|
||||
if len(notfound) > 0 {
|
||||
nf = notfound[0]
|
||||
} else {
|
||||
nf = p
|
||||
}
|
||||
|
||||
l = l.WithGroup("error_middleware")
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
r.Header.Set(ERROR_MIDDLEWARE_HEADER, "enable")
|
||||
|
||||
if uerr := r.URL.Query().Get("error"); uerr != "" && prefersHtml(r.Header) {
|
||||
ErrorDisplayer{l, nf}.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func prefersHtml(h http.Header) bool {
|
||||
if h.Get("Accept") == "" {
|
||||
return false
|
||||
}
|
||||
return (strings.Contains(h.Get("Accept"), "text/html") ||
|
||||
strings.Contains(h.Get("Accept"), "application/xhtml+xml") ||
|
||||
strings.Contains(h.Get("Accept"), "application/xml")) &&
|
||||
!strings.Contains(h.Get("Accept"), "application/json")
|
||||
}
|
||||
@@ -1,222 +0,0 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"forge.capytal.company/loreddev/x/groute/middleware"
|
||||
)
|
||||
|
||||
type Router interface {
|
||||
Handle(pattern string, handler http.Handler)
|
||||
HandleFunc(pattern string, handler http.HandlerFunc)
|
||||
|
||||
Use(middleware middleware.Middleware)
|
||||
|
||||
http.Handler
|
||||
}
|
||||
|
||||
type RouterWithRoutes interface {
|
||||
Router
|
||||
Routes() []Route
|
||||
}
|
||||
|
||||
type RouterWithMiddlewares interface {
|
||||
RouterWithRoutes
|
||||
Middlewares() []middleware.Middleware
|
||||
}
|
||||
|
||||
type RouterWithMiddlewaresWrapper interface {
|
||||
RouterWithMiddlewares
|
||||
WrapMiddlewares(ms []middleware.Middleware, h http.Handler) http.Handler
|
||||
}
|
||||
|
||||
type Route struct {
|
||||
Path string
|
||||
Method string
|
||||
Host string
|
||||
Handler http.Handler
|
||||
}
|
||||
|
||||
func NewRouter(mux ...*http.ServeMux) Router {
|
||||
var m *http.ServeMux
|
||||
if len(mux) > 0 {
|
||||
m = mux[0]
|
||||
} else {
|
||||
m = http.NewServeMux()
|
||||
}
|
||||
|
||||
return &defaultRouter{
|
||||
m,
|
||||
[]middleware.Middleware{},
|
||||
map[string]Route{},
|
||||
}
|
||||
}
|
||||
|
||||
type defaultRouter struct {
|
||||
mux *http.ServeMux
|
||||
middlewares []middleware.Middleware
|
||||
routes map[string]Route
|
||||
}
|
||||
|
||||
func (r *defaultRouter) Handle(pattern string, h http.Handler) {
|
||||
if sr, ok := h.(Router); ok {
|
||||
r.handleRouter(pattern, sr)
|
||||
} else {
|
||||
r.handle(pattern, h)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *defaultRouter) HandleFunc(pattern string, hf http.HandlerFunc) {
|
||||
r.handle(pattern, hf)
|
||||
}
|
||||
|
||||
func (r *defaultRouter) Use(m middleware.Middleware) {
|
||||
r.middlewares = append(r.middlewares, m)
|
||||
}
|
||||
|
||||
func (r *defaultRouter) Routes() []Route {
|
||||
rs := make([]Route, len(r.routes))
|
||||
i := 0
|
||||
for _, r := range r.routes {
|
||||
rs[i] = r
|
||||
i++
|
||||
}
|
||||
return rs
|
||||
}
|
||||
|
||||
func (r *defaultRouter) Middlewares() []middleware.Middleware {
|
||||
return r.middlewares
|
||||
}
|
||||
|
||||
func (r defaultRouter) WrapMiddlewares(ms []middleware.Middleware, h http.Handler) http.Handler {
|
||||
hf := h
|
||||
for _, m := range ms {
|
||||
hf = m(hf)
|
||||
}
|
||||
return hf
|
||||
}
|
||||
|
||||
func (r *defaultRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
r.mux.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
func (r defaultRouter) handle(pattern string, hf http.Handler) {
|
||||
m, h, p := r.parsePattern(pattern)
|
||||
rt := Route{
|
||||
Method: m,
|
||||
Host: h,
|
||||
Path: p,
|
||||
Handler: hf,
|
||||
}
|
||||
r.handleRoute(rt)
|
||||
}
|
||||
|
||||
func (r defaultRouter) handleRouter(pattern string, rr Router) {
|
||||
m, h, p := r.parsePattern(pattern)
|
||||
|
||||
rs, ok := rr.(RouterWithRoutes)
|
||||
if !ok {
|
||||
r.handle(p, rr)
|
||||
}
|
||||
|
||||
routes := rs.Routes()
|
||||
middlewares := []middleware.Middleware{}
|
||||
if rw, ok := rs.(RouterWithMiddlewares); ok {
|
||||
middlewares = rw.Middlewares()
|
||||
}
|
||||
|
||||
wrap := r.WrapMiddlewares
|
||||
if rw, ok := rs.(RouterWithMiddlewaresWrapper); ok {
|
||||
wrap = rw.WrapMiddlewares
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
route.Handler = wrap(middlewares, route.Handler)
|
||||
route.Path = path.Join(p, route.Path)
|
||||
|
||||
if m != "" && route.Method != "" && m != route.Method {
|
||||
panic(
|
||||
fmt.Sprintf(
|
||||
"Nested router's route has incompatible method than defined in path %q."+
|
||||
"Router's route method is %q, while path's is %q",
|
||||
p, route.Method, m,
|
||||
),
|
||||
)
|
||||
}
|
||||
if h != "" && route.Host != "" && h != route.Host {
|
||||
panic(
|
||||
fmt.Sprintf(
|
||||
"Nested router's route has incompatible host than defined in path %q."+
|
||||
"Router's route host is %q, while path's is %q",
|
||||
p, route.Host, h,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
r.handleRoute(route)
|
||||
}
|
||||
}
|
||||
|
||||
func (r defaultRouter) handleRoute(rt Route) {
|
||||
if len(r.middlewares) > 0 {
|
||||
rt.Handler = r.WrapMiddlewares(r.middlewares, rt.Handler)
|
||||
}
|
||||
|
||||
if rt.Path == "" || !strings.HasPrefix(rt.Path, "/") {
|
||||
panic(
|
||||
fmt.Sprintf(
|
||||
"INVALID STATE: Path of route (%#v) does not start with a leading slash",
|
||||
rt,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
p := rt.Path
|
||||
if rt.Host != "" {
|
||||
p = fmt.Sprintf("%s%s", rt.Host, p)
|
||||
}
|
||||
if rt.Method != "" {
|
||||
p = fmt.Sprintf("%s %s", rt.Method, p)
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(p, "/") {
|
||||
p = fmt.Sprintf("%s/", p)
|
||||
}
|
||||
|
||||
if strings.HasSuffix(p, "...}/") {
|
||||
p = strings.TrimSuffix(p, "/")
|
||||
}
|
||||
|
||||
r.routes[p] = rt
|
||||
r.mux.Handle(p, rt.Handler)
|
||||
}
|
||||
|
||||
func (r *defaultRouter) parsePattern(pattern string) (method, host, p string) {
|
||||
pattern = strings.TrimSpace(pattern)
|
||||
|
||||
// ServerMux patterns are "[METHOD ][HOST]/[PATH]", so to parsing it, we must
|
||||
// first split it between "[METHOD ][HOST]" and "[PATH]"
|
||||
ps := strings.Split(pattern, "/")
|
||||
|
||||
p = path.Join("/", strings.Join(ps[1:], "/"))
|
||||
|
||||
// If "[METHOD ][HOST]" is empty, we just have the path and can send it back
|
||||
if ps[0] == "" {
|
||||
return "", "", p
|
||||
}
|
||||
|
||||
// Split string again, if method is not defined, this will end up being just []string{"[HOST]"}
|
||||
// since there isn't a space before the host. If there is a method defined, this will end up as
|
||||
// []string{"[METHOD]","[HOST]"}, with "[HOST]" being possibly a empty string.
|
||||
mh := strings.Split(ps[0], " ")
|
||||
|
||||
// If slice is of length 1, this means it is []string{"[HOST]"}
|
||||
if len(mh) == 1 {
|
||||
return "", host, p
|
||||
}
|
||||
|
||||
return mh[0], mh[1], p
|
||||
}
|
||||
Reference in New Issue
Block a user