From 049b70f23d0783c51a300da4cf87fa86e626a052 Mon Sep 17 00:00:00 2001 From: "Gustavo \"Guz\" L de Mello" Date: Mon, 24 Feb 2025 08:04:28 -0300 Subject: [PATCH] feat(smalltrip): cache middleware --- middleware/cache.go | 181 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 181 insertions(+) create mode 100644 middleware/cache.go diff --git a/middleware/cache.go b/middleware/cache.go new file mode 100644 index 0000000..b7aef61 --- /dev/null +++ b/middleware/cache.go @@ -0,0 +1,181 @@ +package middleware + +import ( + "fmt" + "net/http" + "strings" + "time" +) + +func Cache(options ...CacheOption) Middleware { + d := defaultCacheDirectives + + for _, option := range options { + option(&d) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", d.String()) + next.ServeHTTP(w, r) + }) + } +} + + +type CacheOption func(*directives) + +func CacheMaxAge(t time.Duration) CacheOption { + return func(d *directives) { d.maxAge = &t } +} + +func CacheSMaxAge(t time.Duration) CacheOption { + return func(d *directives) { d.sMaxage = &t } +} + +func CacheNoCache(b ...bool) CacheOption { + bool := optionalTrue(b) + return func(d *directives) { d.noCache = &bool } +} + +func CacheNoStore(b ...bool) CacheOption { + bool := optionalTrue(b) + return func(d *directives) { d.noStore = &bool } +} + +func CacheNoTransform(b ...bool) CacheOption { + bool := optionalTrue(b) + return func(d *directives) { d.noTransform = &bool } +} + +func CacheMustRevalidate(b ...bool) CacheOption { + bool := optionalTrue(b) + return func(d *directives) { d.mustRevalidate = &bool } +} + +func CacheProxyRevalidate(b ...bool) CacheOption { + bool := optionalTrue(b) + return func(d *directives) { d.proxyRevalidate = &bool } +} + +func CacheMustUnderstand(b ...bool) CacheOption { + bool := optionalTrue(b) + return func(d *directives) { d.mustUnderstand = &bool } +} + +func CachePrivate(b ...bool) CacheOption { + bool := optionalTrue(b) + return func(d *directives) { d.private = &bool } +} + +func CachePublic(b ...bool) CacheOption { + bool := optionalTrue(b) + return func(d *directives) { d.public = &bool } +} + +func CacheImmutable(b ...bool) CacheOption { + bool := optionalTrue(b) + return func(d *directives) { d.immutable = &bool } +} + +func CacheStaleWhileRevalidate(t time.Duration) CacheOption { + return func(d *directives) { d.staleWhileRevalidate = &t } +} + +func CacheStaleIfError(t time.Duration) CacheOption { + return func(d *directives) { d.staleIfError = &t } +} + +func optionalTrue(b []bool) bool { + bool := true + if len(b) > 0 { + bool = b[1] + } + return bool +} + +var ( + defaultCacheDirectives = directives{ + maxAge: &day, + sMaxage: &day, + + mustRevalidate: &tru, + private: &tru, + + staleWhileRevalidate: &twoDays, + staleIfError: &twoDays, + } + tru, fals = true, false + day = time.Duration(time.Hour * 24) + twoDays = time.Duration(time.Hour * 48) +) + +type directives struct { + maxAge *time.Duration + sMaxage *time.Duration + + noCache *bool + noStore *bool + noTransform *bool + + mustRevalidate *bool + proxyRevalidate *bool + mustUnderstand *bool + + private *bool + public *bool + immutable *bool + + staleWhileRevalidate *time.Duration + staleIfError *time.Duration +} + +func (d directives) String() string { + ds := []string{} + + if d.maxAge != nil { + ds = append(ds, fmt.Sprintf("max-age=%d", d.maxAge.Seconds())) + } + if d.sMaxage != nil { + ds = append(ds, fmt.Sprintf("s-maxage=%d", d.sMaxage.Seconds())) + } + + if d.noCache != nil && *d.noCache { + ds = append(ds, "no-cache") + } + if d.noStore != nil && *d.noStore { + ds = append(ds, "no-store") + } + if d.noTransform != nil && *d.noTransform { + ds = append(ds, "no-transform") + } + + if d.mustRevalidate != nil && *d.mustRevalidate { + ds = append(ds, "must-revalidate") + } + if d.proxyRevalidate != nil && *d.proxyRevalidate { + ds = append(ds, "proxy-revalidate") + } + if d.mustUnderstand != nil && *d.mustRevalidate { + ds = append(ds, "must-understand") + } + + if d.private != nil && *d.private { + ds = append(ds, "private") + } + if d.public != nil && *d.public { + ds = append(ds, "public") + } + if d.immutable != nil && *d.immutable { + ds = append(ds, "immutable") + } + + if d.staleWhileRevalidate != nil { + ds = append(ds, fmt.Sprintf("stale-while-revalidate=%d", d.staleWhileRevalidate.Seconds())) + } + if d.staleIfError != nil { + ds = append(ds, fmt.Sprintf("stale-if-error=%d", d.staleIfError.Seconds())) + } + + return strings.Join(ds, ", ") +}