diff --git a/internals/middlewares/cookies_crypto.go b/internals/middlewares/cookies_crypto.go index e367421..b5bf5a2 100644 --- a/internals/middlewares/cookies_crypto.go +++ b/internals/middlewares/cookies_crypto.go @@ -5,12 +5,18 @@ import ( "crypto/cipher" "crypto/rand" "encoding/base64" + "log" "net/http" "strings" ) type CookiesCryptoMiddleware struct { - Key string + key string + logger *log.Logger +} + +func NewCookiesCryptoMiddleware(key string, logger *log.Logger) CookiesCryptoMiddleware { + return CookiesCryptoMiddleware{key, logger} } func (m CookiesCryptoMiddleware) Serve(handler http.HandlerFunc) http.HandlerFunc { @@ -21,26 +27,26 @@ func (m CookiesCryptoMiddleware) Serve(handler http.HandlerFunc) http.HandlerFun } } -func (m CookiesCryptoMiddleware) encrypt(pt string) string { - aes, err := aes.NewCipher([]byte(m.Key)) +func (m CookiesCryptoMiddleware) encrypt(pt string) (string, error) { + aes, err := aes.NewCipher([]byte(m.key)) if err != nil { - panic(err) + return "", err } gcm, err := cipher.NewGCM(aes) if err != nil { - panic(err) + return "", err } nonce := make([]byte, gcm.NonceSize()) _, err = rand.Read(nonce) if err != nil { - panic(err) + return "", err } ct := gcm.Seal(nonce, nonce, []byte(pt), nil) - return base64.URLEncoding.EncodeToString(ct) + return base64.URLEncoding.EncodeToString(ct), nil } func (m CookiesCryptoMiddleware) encryptCookies(w http.ResponseWriter) { @@ -58,7 +64,14 @@ func (m CookiesCryptoMiddleware) encryptCookies(w http.ResponseWriter) { cn, v := strings.Split(c, "=")[0], strings.Split(c, "=")[1] - v = m.encrypt(strings.Trim(v, "\"")) + v, err := m.encrypt(strings.Trim(v, "\"")) + if err != nil { + m.logger.Panicf( + "ERRO: Unable to encrypt cookie \"%s\", skipping. Error: %s", + cn, err.Error(), + ) + continue + } c = cn + "=\"" + v + "\";" + attrs @@ -66,21 +79,21 @@ func (m CookiesCryptoMiddleware) encryptCookies(w http.ResponseWriter) { } } -func (m CookiesCryptoMiddleware) decrypt(ct string) string { +func (m CookiesCryptoMiddleware) decrypt(ct string) (string, error) { cb, err := base64.URLEncoding.DecodeString(ct) if err != nil { - panic(err) + return "", err } ct = string(cb) - aes, err := aes.NewCipher([]byte(m.Key)) + aes, err := aes.NewCipher([]byte(m.key)) if err != nil { - panic(err) + return "", err } gcm, err := cipher.NewGCM(aes) if err != nil { - panic(err) + return "", err } nonceSize := gcm.NonceSize() @@ -88,17 +101,25 @@ func (m CookiesCryptoMiddleware) decrypt(ct string) string { pt, err := gcm.Open(nil, []byte(nonce), []byte(ct), nil) if err != nil { - panic(err) + return "", err } - return string(pt) + return string(pt), nil } func (m CookiesCryptoMiddleware) decryptCookies(r *http.Request) { rcookies := r.Cookies() r.Header.Del("Cookie") for _, c := range rcookies { - c.Value = m.decrypt(c.Value) + cv, err := m.decrypt(c.Value) + if err != nil { + m.logger.Panicf( + "ERRO: Unable to decrypt cookie \"%s\", skipping: Error: %s", + c.Name, err.Error(), + ) + continue + } + c.Value = cv r.AddCookie(c) } } diff --git a/internals/middlewares/development.go b/internals/middlewares/development.go index 4ce5ae3..acf7875 100644 --- a/internals/middlewares/development.go +++ b/internals/middlewares/development.go @@ -6,12 +6,16 @@ import ( ) type DevelopmentMiddleware struct { - Logger *log.Logger + logger *log.Logger +} + +func NewDevelopmentMiddleware(logger *log.Logger) DevelopmentMiddleware { + return DevelopmentMiddleware{logger} } func (m DevelopmentMiddleware) Serve(handler http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - m.Logger.Printf("New request: %s", r.URL.Path) + m.logger.Printf("New request: %s", r.URL.Path) handler(w, r) diff --git a/main.go b/main.go index c68bad0..c5ca3b2 100644 --- a/main.go +++ b/main.go @@ -26,9 +26,9 @@ func main() { r := router.NewRouter(routes.ROUTES) if *dev { - r.AddMiddleware(middlewares.DevelopmentMiddleware{Logger: logger}) + r.AddMiddleware(middlewares.NewDevelopmentMiddleware(logger)) } - r.AddMiddleware(middlewares.CookiesCryptoMiddleware{os.Getenv("CRYPTO_COOKIES_KEY")}) + r.AddMiddleware(middlewares.NewCookiesCryptoMiddleware(os.Getenv("CRYPTO_COOKIES_KEY"), logger)) err := http.ListenAndServe(fmt.Sprintf(":%v", *port), r) if err != nil {