feat(cookies,router): new cookies marshaller and unmarshaller
This commit is contained in:
@@ -5,11 +5,41 @@ import (
|
||||
|
||||
"forge.capytal.company/capytalcode/project-comicverse/router/rerrors"
|
||||
"forge.capytal.company/capytalcode/project-comicverse/templates/layouts"
|
||||
"forge.capytal.company/capytalcode/project-comicverse/router/cookies"
|
||||
"errors"
|
||||
"log"
|
||||
)
|
||||
|
||||
type Dashboard struct{}
|
||||
|
||||
type DashboardCookie struct {
|
||||
Hello string `cookie:"dashboard-cookie"`
|
||||
Bool bool
|
||||
Test int
|
||||
}
|
||||
|
||||
func (p *Dashboard) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
hasCookie := true
|
||||
|
||||
var c DashboardCookie
|
||||
if err := cookies.UnmarshalRequest(r, &c); errors.Is(err, cookies.ErrNoCookie) {
|
||||
hasCookie = false
|
||||
c = DashboardCookie{Hello: "Hello world", Bool: true, Test: 69420}
|
||||
} else if err != nil {
|
||||
rerrors.InternalError(err).ServeHTTP(w, r)
|
||||
return
|
||||
} else {
|
||||
hasCookie = true
|
||||
}
|
||||
|
||||
log.Print(hasCookie, c)
|
||||
|
||||
if ck, err := cookies.Marshal(c); err != nil {
|
||||
rerrors.InternalError(err).ServeHTTP(w, r)
|
||||
} else {
|
||||
http.SetCookie(w, ck)
|
||||
}
|
||||
|
||||
if err := p.Component().Render(r.Context(), w); err != nil {
|
||||
rerrors.InternalError(err).ServeHTTP(w, r)
|
||||
return
|
||||
|
||||
242
router/cookies/cookies.go
Normal file
242
router/cookies/cookies.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package cookies
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDecodeBase64 = errors.New("Failed to decode base64 string from cookie value")
|
||||
ErrMarshal = errors.New("Failed to marhal JSON value for cookie value")
|
||||
ErrUnmarshal = errors.New("Failed to unmarshal JSON value from cookie value")
|
||||
ErrReflectPanic = errors.New("Reflect panic while trying to get tag from value")
|
||||
ErrMissingName = errors.New("Failed to get name of cookie")
|
||||
ErrNoCookie = http.ErrNoCookie
|
||||
)
|
||||
|
||||
var COOKIE_EXPIRE_VALID_FORMATS = []string{
|
||||
time.DateOnly, time.DateTime,
|
||||
time.RFC1123, time.RFC1123Z,
|
||||
}
|
||||
|
||||
func Marshal(v any) (*http.Cookie, error) {
|
||||
c, err := marshalValue(v)
|
||||
if err != nil {
|
||||
return c, err
|
||||
}
|
||||
|
||||
if err := setCookieProps(c, v); err != nil {
|
||||
return c, err
|
||||
}
|
||||
|
||||
return c, err
|
||||
}
|
||||
|
||||
func Unmarshal(c *http.Cookie, v any) error {
|
||||
value := c.Value
|
||||
b, err := base64.URLEncoding.DecodeString(value)
|
||||
if err != nil {
|
||||
return errors.Join(ErrDecodeBase64, err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(b, v); err != nil {
|
||||
return errors.Join(ErrUnmarshal, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func UnmarshalRequest(r *http.Request, v any) error {
|
||||
name, err := getCookieName(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c, err := r.Cookie(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return Unmarshal(c, v)
|
||||
}
|
||||
|
||||
func marshalValue(v any) (*http.Cookie, error) {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return &http.Cookie{}, errors.Join(ErrMarshal, err)
|
||||
}
|
||||
|
||||
s := base64.URLEncoding.EncodeToString(b)
|
||||
|
||||
return &http.Cookie{
|
||||
Value: s,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func setCookieProps(c *http.Cookie, v any) error {
|
||||
tag, err := getCookieTag(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Name, err = getCookieName(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tvs := strings.Split(tag, ",")
|
||||
|
||||
if len(tvs) == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tvs = tvs[1:]
|
||||
|
||||
for _, tv := range tvs {
|
||||
var k, v string
|
||||
if strings.Contains(tv, "=") {
|
||||
s := strings.Split(tv, "=")
|
||||
k = s[0]
|
||||
v = s[1]
|
||||
} else {
|
||||
k = tv
|
||||
v = ""
|
||||
}
|
||||
|
||||
switch k {
|
||||
case "SECURE":
|
||||
c.Name = "__Secure-" + c.Name
|
||||
c.Secure = true
|
||||
|
||||
case "HOST":
|
||||
c.Name = "__Host" + c.Name
|
||||
c.Secure = true
|
||||
c.Path = "/"
|
||||
|
||||
case "path":
|
||||
c.Path = v
|
||||
|
||||
case "domain":
|
||||
c.Domain = v
|
||||
|
||||
case "httponly":
|
||||
if v == "" {
|
||||
c.HttpOnly = true
|
||||
} else if v, err := strconv.ParseBool(v); err != nil {
|
||||
c.HttpOnly = false
|
||||
} else {
|
||||
c.HttpOnly = v
|
||||
}
|
||||
|
||||
case "samesite":
|
||||
if v == "" {
|
||||
c.SameSite = http.SameSiteDefaultMode
|
||||
} else if v == "strict" {
|
||||
c.SameSite = http.SameSiteStrictMode
|
||||
} else if v == "lax" {
|
||||
c.SameSite = http.SameSiteLaxMode
|
||||
} else {
|
||||
c.SameSite = http.SameSiteNoneMode
|
||||
}
|
||||
case "secure":
|
||||
if v == "" {
|
||||
c.Secure = true
|
||||
} else if v, err := strconv.ParseBool(v); err != nil {
|
||||
c.Secure = false
|
||||
} else {
|
||||
c.Secure = v
|
||||
}
|
||||
|
||||
case "max-age", "age":
|
||||
if v == "" {
|
||||
c.MaxAge = 0
|
||||
} else if v, err := strconv.Atoi(v); err != nil {
|
||||
c.MaxAge = 0
|
||||
} else {
|
||||
c.MaxAge = v
|
||||
}
|
||||
|
||||
case "expires":
|
||||
if v == "" {
|
||||
c.Expires = time.Now()
|
||||
} else if v, err := timeParseMultiple(v, COOKIE_EXPIRE_VALID_FORMATS...); err != nil {
|
||||
c.Expires = time.Now()
|
||||
} else {
|
||||
c.Expires = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getCookieName(v any) (name string, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = errors.Join(ErrReflectPanic, fmt.Errorf("Panic recovered: %#v", r))
|
||||
}
|
||||
}()
|
||||
|
||||
tag, err := getCookieTag(v)
|
||||
if err != nil {
|
||||
return name, err
|
||||
}
|
||||
|
||||
tvs := strings.Split(tag, ",")
|
||||
if len(tvs) == 0 {
|
||||
t := reflect.TypeOf(v)
|
||||
name = t.Name()
|
||||
} else {
|
||||
name = tvs[0]
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
return name, ErrMissingName
|
||||
}
|
||||
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func getCookieTag(v any) (t string, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = errors.Join(ErrReflectPanic, fmt.Errorf("Panic recovered: %#v", r))
|
||||
}
|
||||
}()
|
||||
|
||||
rt := reflect.TypeOf(v)
|
||||
|
||||
if rt.Kind() == reflect.Pointer {
|
||||
rt = rt.Elem()
|
||||
}
|
||||
|
||||
for i := 0; i < rt.NumField(); i++ {
|
||||
ft := rt.Field(i)
|
||||
if t := ft.Tag.Get("cookie"); t != "" {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func timeParseMultiple(v string, formats ...string) (time.Time, error) {
|
||||
errs := []error{}
|
||||
for _, f := range formats {
|
||||
t, err := time.Parse(v, f)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
} else {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
return time.Time{}, errs[len(errs)-1]
|
||||
}
|
||||
@@ -12,6 +12,12 @@ func MissingParameters(params []string) RouteError {
|
||||
})
|
||||
}
|
||||
|
||||
func MissingCookies(cookies []string) RouteError {
|
||||
return NewRouteError(http.StatusBadRequest, "Missing cookies", map[string]any{
|
||||
"missing_cookies": cookies,
|
||||
})
|
||||
}
|
||||
|
||||
func MethodNowAllowed(method string, allowedMethods []string) RouteError {
|
||||
return NewRouteError(http.StatusMethodNotAllowed, "Method not allowed", map[string]any{
|
||||
"method": method,
|
||||
|
||||
Reference in New Issue
Block a user