From 9f57b5cd100e6b1c5bb6cbac6268cb98e531f9c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc=20Sch=C3=A4fer?= Date: Sun, 7 Jun 2026 17:49:04 +0200 Subject: [PATCH] refactor!: improve code structure and fix APIBaseUrl naming MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Marc Schäfer --- badger.go | 418 ++++++++++++++++++++++++++++++------------------- badger_test.go | 8 +- 2 files changed, 258 insertions(+), 168 deletions(-) diff --git a/badger.go b/badger.go index 8fdf221..4dfda95 100644 --- a/badger.go +++ b/badger.go @@ -5,16 +5,33 @@ import ( "context" "encoding/json" "fmt" + "html" + "html/template" + "log" "net" "net/http" + "net/url" "strings" + "time" "github.com/fosrl/badger/ips" "github.com/fosrl/badger/version" ) +const ( + errInternalServer = "Internal Server Error" + errUnauthorized = "Unauthorized" + headerSetCookie = "Set-Cookie" + headerRemoteUserID = "Remote-User-Id" + headerRemoteUser = "Remote-User" + headerRemoteEmail = "Remote-Email" + headerRemoteName = "Remote-Name" + headerRemoteRole = "Remote-Role" + headerContentType = "Content-Type" +) + type Config struct { - APIBaseUrl string `json:"apiBaseUrl,omitempty"` + APIBaseURL string `json:"apiBaseUrl,omitempty"` UserSessionCookieName string `json:"userSessionCookieName,omitempty"` ResourceSessionRequestParam string `json:"resourceSessionRequestParam,omitempty"` AccessTokenQueryParam string `json:"accessTokenQueryParam,omitempty"` @@ -37,7 +54,7 @@ const ( type Badger struct { next http.Handler name string - apiBaseUrl string + apiBaseURL string userSessionCookieName string resourceSessionRequestParam string accessTokenQueryParam string @@ -46,6 +63,7 @@ type Badger struct { disableForwardAuth bool trustIP []*net.IPNet customIPHeader string + httpClient *http.Client } type VerifyBody struct { @@ -62,20 +80,22 @@ type VerifyBody struct { BadgerVersion string `json:"badgerVersion,omitempty"` } +type VerifyResponseData struct { + HeaderAuthChallenged bool `json:"headerAuthChallenged"` + Valid bool `json:"valid"` + RedirectURL *string `json:"redirectUrl"` + UserID *string `json:"userId,omitempty"` + DontStripSession bool `json:"dontStripSession,omitempty"` + Username *string `json:"username,omitempty"` + Email *string `json:"email,omitempty"` + Name *string `json:"name,omitempty"` + Role *string `json:"role,omitempty"` + ResponseHeaders map[string]string `json:"responseHeaders,omitempty"` + PangolinVersion *string `json:"pangolinVersion,omitempty"` +} + type VerifyResponse struct { - Data struct { - HeaderAuthChallenged bool `json:"headerAuthChallenged"` - Valid bool `json:"valid"` - RedirectURL *string `json:"redirectUrl"` - UserId *string `json:"userId,omitempty"` - DontStripSession bool `json:"dontStripSession,omitempty"` - Username *string `json:"username,omitempty"` - Email *string `json:"email,omitempty"` - Name *string `json:"name,omitempty"` - Role *string `json:"role,omitempty"` - ResponseHeaders map[string]string `json:"responseHeaders,omitempty"` - PangolinVersion *string `json:"pangolinVersion,omitempty"` - } `json:"data"` + Data VerifyResponseData `json:"data"` } type ExchangeSessionBody struct { @@ -96,11 +116,15 @@ func CreateConfig() *Config { return &Config{} } -func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) { +func New(_ context.Context, next http.Handler, config *Config, name string) (http.Handler, error) { + if err := validateConfig(config); err != nil { + return nil, err + } + badger := &Badger{ next: next, name: name, - apiBaseUrl: config.APIBaseUrl, + apiBaseURL: config.APIBaseURL, userSessionCookieName: config.UserSessionCookieName, resourceSessionRequestParam: config.ResourceSessionRequestParam, accessTokenQueryParam: config.AccessTokenQueryParam, @@ -108,42 +132,53 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h accessTokenHeader: config.AccessTokenHeader, disableForwardAuth: config.DisableForwardAuth, customIPHeader: config.CustomIPHeader, + httpClient: &http.Client{Timeout: 10 * time.Second}, } - // Validate required fields only if forward auth is enabled - if !config.DisableForwardAuth { - if config.APIBaseUrl == "" { - return nil, fmt.Errorf("apiBaseUrl is required when forward auth is enabled") - } - if config.UserSessionCookieName == "" { - return nil, fmt.Errorf("userSessionCookieName is required when forward auth is enabled") - } - if config.ResourceSessionRequestParam == "" { - return nil, fmt.Errorf("resourceSessionRequestParam is required when forward auth is enabled") - } + if err := badger.parseTrustedIPs(config.TrustIP, config.DisableDefaultCFIPs); err != nil { + return nil, err } - if config.TrustIP != nil { - for _, v := range config.TrustIP { - _, trustip, err := net.ParseCIDR(v) - if err != nil { - return nil, err - } - badger.trustIP = append(badger.trustIP, trustip) + return badger, nil +} + +// validateConfig checks required fields when forward auth is enabled. +func validateConfig(config *Config) error { + if config.DisableForwardAuth { + return nil + } + if config.APIBaseURL == "" { + return fmt.Errorf("apiBaseURL is required when forward auth is enabled") + } + if config.UserSessionCookieName == "" { + return fmt.Errorf("userSessionCookieName is required when forward auth is enabled") + } + if config.ResourceSessionRequestParam == "" { + return fmt.Errorf("resourceSessionRequestParam is required when forward auth is enabled") + } + return nil +} + +// parseTrustedIPs parses configured and default Cloudflare IP ranges into the Badger's trustIP list. +func (p *Badger) parseTrustedIPs(trustIPs []string, disableDefaultCFIPs bool) error { + for _, v := range trustIPs { + _, trustip, err := net.ParseCIDR(v) + if err != nil { + return err } + p.trustIP = append(p.trustIP, trustip) } - if !config.DisableDefaultCFIPs { + if !disableDefaultCFIPs { for _, v := range ips.CFIPs() { _, trustip, err := net.ParseCIDR(v) if err != nil { - return nil, err + return err } - badger.trustIP = append(badger.trustIP, trustip) + p.trustIP = append(p.trustIP, trustip) } } - - return badger, nil + return nil } func (p *Badger) ServeHTTP(rw http.ResponseWriter, req *http.Request) { @@ -156,67 +191,133 @@ func (p *Badger) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } cookies := p.extractCookies(req) - queryValues := req.URL.Query() if sessionRequestValue := queryValues.Get(p.resourceSessionRequestParam); sessionRequestValue != "" { - body := ExchangeSessionBody{ - RequestToken: &sessionRequestValue, - RequestHost: &req.Host, - RequestIP: &realIP, - } - - jsonData, err := json.Marshal(body) - if err != nil { - http.Error(rw, "Internal Server Error", http.StatusInternalServerError) + if p.handleSessionExchange(rw, req, sessionRequestValue, realIP) { return } + } - verifyURL := fmt.Sprintf("%s/badger/exchange-session", p.apiBaseUrl) - resp, err := http.Post(verifyURL, "application/json", bytes.NewBuffer(jsonData)) - if err != nil { - http.Error(rw, "Internal Server Error", http.StatusInternalServerError) - return - } - defer resp.Body.Close() + originalRequestURL := buildOriginalURL(req, queryValues) + verifyURL := fmt.Sprintf("%s/badger/verify-session", p.apiBaseURL) - var result ExchangeSessionResponse - err = json.NewDecoder(resp.Body).Decode(&result) - if err != nil { - http.Error(rw, "Internal Server Error", http.StatusInternalServerError) - return - } + cookieData := buildVerifyBody(req, cookies, originalRequestURL, realIP, queryValues) + + jsonData, err := json.Marshal(cookieData) + if err != nil { + http.Error(rw, errInternalServer, http.StatusInternalServerError) + return + } - if result.Data.Cookie != nil && *result.Data.Cookie != "" { - rw.Header().Add("Set-Cookie", *result.Data.Cookie) + httpReq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, verifyURL, bytes.NewBuffer(jsonData)) //nolint:gosec // G704: URL is constructed from configured apiBaseURL + if err != nil { + http.Error(rw, errInternalServer, http.StatusInternalServerError) + return + } + httpReq.Header.Set(headerContentType, "application/json") - queryValues.Del(p.resourceSessionRequestParam) - cleanedQuery := queryValues.Encode() - originalRequestURL := fmt.Sprintf("%s://%s%s", p.getScheme(req), req.Host, req.URL.Path) - if cleanedQuery != "" { - originalRequestURL = fmt.Sprintf("%s?%s", originalRequestURL, cleanedQuery) - } + resp, err := p.httpClient.Do(httpReq) //nolint:gosec // G704: URL is constructed from configured apiBaseURL + if err != nil { + http.Error(rw, errInternalServer, http.StatusInternalServerError) + return + } + defer resp.Body.Close() - if result.Data.ResponseHeaders != nil { - for key, value := range result.Data.ResponseHeaders { - rw.Header().Add(key, value) - } - } + for _, setCookie := range resp.Header[headerSetCookie] { + rw.Header().Add(headerSetCookie, setCookie) + } - fmt.Println("Got exchange token, redirecting to", originalRequestURL) - http.Redirect(rw, req, originalRequestURL, http.StatusFound) - return - } + if resp.StatusCode != http.StatusOK { + http.Error(rw, errInternalServer, http.StatusInternalServerError) + return + } + + var result VerifyResponse + err = json.NewDecoder(resp.Body).Decode(&result) + if err != nil { + http.Error(rw, errInternalServer, http.StatusInternalServerError) + return + } + + p.handleVerifyResponse(rw, req, result) +} + +// handleSessionExchange processes a session exchange request. +// Returns true if the request was handled (response written), false if it should fall through to verification. +func (p *Badger) handleSessionExchange(rw http.ResponseWriter, req *http.Request, sessionRequestValue string, realIP string) bool { + body := ExchangeSessionBody{ + RequestToken: &sessionRequestValue, + RequestHost: &req.Host, + RequestIP: &realIP, + } + + jsonData, err := json.Marshal(body) + if err != nil { + http.Error(rw, errInternalServer, http.StatusInternalServerError) + return true + } + + verifyURL := fmt.Sprintf("%s/badger/exchange-session", p.apiBaseURL) + httpReq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, verifyURL, bytes.NewBuffer(jsonData)) //nolint:gosec // G704: URL is constructed from configured apiBaseURL + if err != nil { + http.Error(rw, errInternalServer, http.StatusInternalServerError) + return true + } + httpReq.Header.Set(headerContentType, "application/json") + + resp, err := p.httpClient.Do(httpReq) //nolint:gosec // G704: URL is constructed from configured apiBaseURL + if err != nil { + http.Error(rw, errInternalServer, http.StatusInternalServerError) + return true + } + defer resp.Body.Close() + + var result ExchangeSessionResponse + err = json.NewDecoder(resp.Body).Decode(&result) + if err != nil { + http.Error(rw, errInternalServer, http.StatusInternalServerError) + return true } + if result.Data.Cookie == nil || *result.Data.Cookie == "" { + // No valid session cookie; fall through to verification + return false + } + + rw.Header().Add(headerSetCookie, *result.Data.Cookie) + + queryValues := req.URL.Query() + queryValues.Del(p.resourceSessionRequestParam) cleanedQuery := queryValues.Encode() - originalRequestURL := fmt.Sprintf("%s://%s%s", p.getScheme(req), req.Host, req.URL.Path) + originalRequestURL := fmt.Sprintf("%s://%s%s", getScheme(req), req.Host, req.URL.Path) if cleanedQuery != "" { originalRequestURL = fmt.Sprintf("%s?%s", originalRequestURL, cleanedQuery) } - verifyURL := fmt.Sprintf("%s/badger/verify-session", p.apiBaseUrl) + if result.Data.ResponseHeaders != nil { + for key, value := range result.Data.ResponseHeaders { + rw.Header().Add(key, value) + } + } + + log.Printf("badger: got exchange token, redirecting to %s", originalRequestURL) //nolint:gosec // G706: originalRequestURL is derived from the incoming request + http.Redirect(rw, req, originalRequestURL, http.StatusFound) //nolint:gosec // G710: redirect URL is constructed from the original request + return true +} + +// buildOriginalURL reconstructs the original request URL, stripping the session param. +func buildOriginalURL(req *http.Request, queryValues url.Values) string { + cleanedQuery := queryValues.Encode() + originalRequestURL := fmt.Sprintf("%s://%s%s", getScheme(req), req.Host, req.URL.Path) + if cleanedQuery != "" { + originalRequestURL = fmt.Sprintf("%s?%s", originalRequestURL, cleanedQuery) + } + return originalRequestURL +} +// buildVerifyBody constructs the verification request payload. +func buildVerifyBody(req *http.Request, cookies map[string]string, originalRequestURL string, realIP string, queryValues url.Values) VerifyBody { headers := make(map[string]string) for name, values := range req.Header { if len(values) > 0 { @@ -231,10 +332,11 @@ func (p *Badger) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } } - cookieData := VerifyBody{ + scheme := getScheme(req) + return VerifyBody{ Sessions: cookies, OriginalRequestURL: originalRequestURL, - RequestScheme: &req.URL.Scheme, + RequestScheme: &scheme, RequestHost: &req.Host, RequestPath: &req.URL.Path, RequestMethod: &req.Method, @@ -244,102 +346,89 @@ func (p *Badger) ServeHTTP(rw http.ResponseWriter, req *http.Request) { Query: queryParams, BadgerVersion: version.Version, } +} - jsonData, err := json.Marshal(cookieData) - if err != nil { - http.Error(rw, "Internal Server Error", http.StatusInternalServerError) // TODO: redirect to error page - return - } - - resp, err := http.Post(verifyURL, "application/json", bytes.NewBuffer(jsonData)) - if err != nil { - http.Error(rw, "Internal Server Error", http.StatusInternalServerError) - return - } - defer resp.Body.Close() - - for _, setCookie := range resp.Header["Set-Cookie"] { - rw.Header().Add("Set-Cookie", setCookie) - } - - if resp.StatusCode != http.StatusOK { - http.Error(rw, "Internal Server Error", http.StatusInternalServerError) - return - } - - var result VerifyResponse - err = json.NewDecoder(resp.Body).Decode(&result) - if err != nil { - http.Error(rw, "Internal Server Error", http.StatusInternalServerError) - return - } - - req.Header.Del("Remote-User") - req.Header.Del("Remote-Email") - req.Header.Del("Remote-Name") - req.Header.Del("Remote-Role") - req.Header.Del("Remote-User-Id") - - if result.Data.ResponseHeaders != nil { - for key, value := range result.Data.ResponseHeaders { - rw.Header().Add(key, value) - } - } +// handleVerifyResponse processes the verification response and writes the appropriate result. +func (p *Badger) handleVerifyResponse(rw http.ResponseWriter, req *http.Request, result VerifyResponse) { + clearRemoteHeaders(req) + applyResponseHeaders(rw, result.Data.ResponseHeaders) if result.Data.HeaderAuthChallenged { - fmt.Println("Badger: challenging client for header authentication") - rw.Header().Add("WWW-Authenticate", "Basic realm=\"pangolin\"") - - if result.Data.RedirectURL != nil && *result.Data.RedirectURL != "" { - rw.Header().Set("Content-Type", "text/html; charset=utf-8") - rw.WriteHeader(http.StatusUnauthorized) - rw.Write([]byte(p.renderRedirectPage(*result.Data.RedirectURL))) - } else { - http.Error(rw, "Unauthorized", http.StatusUnauthorized) - } + handleHeaderAuthChallenge(rw, result.Data.RedirectURL) return } if result.Data.RedirectURL != nil && *result.Data.RedirectURL != "" { - fmt.Println("Badger: Redirecting to", *result.Data.RedirectURL) - http.Redirect(rw, req, *result.Data.RedirectURL, http.StatusFound) + log.Printf("badger: redirecting to %s", *result.Data.RedirectURL) //nolint:gosec // G706: redirectURL comes from trusted auth server + http.Redirect(rw, req, *result.Data.RedirectURL, http.StatusFound) //nolint:gosec // G710: redirect URL comes from the auth server return } if result.Data.Valid { - - if result.Data.UserId != nil { - req.Header.Add("Remote-User-Id", *result.Data.UserId) - } - - if result.Data.Username != nil { - req.Header.Add("Remote-User", *result.Data.Username) - } - - if result.Data.Email != nil { - req.Header.Add("Remote-Email", *result.Data.Email) - } - - if result.Data.Name != nil { - req.Header.Add("Remote-Name", *result.Data.Name) - } - - if result.Data.Role != nil { - req.Header.Add("Remote-Role", *result.Data.Role) - } - + setUserHeaders(req, &result.Data) if !result.Data.DontStripSession { p.stripSessionParam(req) p.stripSessionCookies(req) p.stripAccessTokenHeaders(req) } - - fmt.Println("Badger: Valid session") + log.Printf("badger: valid session") p.next.ServeHTTP(rw, req) return } - http.Error(rw, "Unauthorized", http.StatusUnauthorized) + http.Error(rw, errUnauthorized, http.StatusUnauthorized) +} + +// clearRemoteHeaders removes all remote-user headers from the request. +func clearRemoteHeaders(req *http.Request) { + req.Header.Del(headerRemoteUser) + req.Header.Del(headerRemoteEmail) + req.Header.Del(headerRemoteName) + req.Header.Del(headerRemoteRole) + req.Header.Del(headerRemoteUserID) +} + +// applyResponseHeaders copies response headers from the verification result to the response writer. +func applyResponseHeaders(rw http.ResponseWriter, headers map[string]string) { + if headers == nil { + return + } + for key, value := range headers { + rw.Header().Add(key, value) + } +} + +// handleHeaderAuthChallenge responds with a 401 and optional redirect page for header-based auth. +func handleHeaderAuthChallenge(rw http.ResponseWriter, redirectURL *string) { + log.Printf("badger: challenging client for header authentication") + rw.Header().Add("WWW-Authenticate", "Basic realm=\"pangolin\"") + + if redirectURL != nil && *redirectURL != "" { + rw.Header().Set(headerContentType, "text/html; charset=utf-8") + rw.WriteHeader(http.StatusUnauthorized) + _, _ = rw.Write([]byte(renderRedirectPage(*redirectURL))) //nolint:gosec // G705: redirectURL comes from trusted auth server + } else { + http.Error(rw, errUnauthorized, http.StatusUnauthorized) + } +} + +// setUserHeaders sets the remote-user headers from the verification result. +func setUserHeaders(req *http.Request, data *VerifyResponseData) { + if data.UserID != nil { + req.Header.Add(headerRemoteUserID, *data.UserID) + } + if data.Username != nil { + req.Header.Add(headerRemoteUser, *data.Username) + } + if data.Email != nil { + req.Header.Add(headerRemoteEmail, *data.Email) + } + if data.Name != nil { + req.Header.Add(headerRemoteName, *data.Name) + } + if data.Role != nil { + req.Header.Add(headerRemoteRole, *data.Role) + } } func (p *Badger) extractCookies(req *http.Request) map[string]string { @@ -358,14 +447,16 @@ func (p *Badger) extractCookies(req *http.Request) map[string]string { return cookies } -func (p *Badger) getScheme(req *http.Request) string { +func getScheme(req *http.Request) string { if req.TLS != nil { return "https" } return "http" } -func (p *Badger) renderRedirectPage(redirectURL string) string { +func renderRedirectPage(redirectURL string) string { + htmlEscaped := html.EscapeString(redirectURL) + jsEscaped := template.JSEscapeString(redirectURL) return fmt.Sprintf(` @@ -406,7 +497,7 @@ func (p *Badger) renderRedirectPage(redirectURL string) string { window.location.href = "%s"; -`, redirectURL, redirectURL) +`, htmlEscaped, jsEscaped) } func (p *Badger) getRealIP(req *http.Request) string { @@ -449,7 +540,6 @@ func (p *Badger) stripSessionParam(req *http.Request) { } if modified { req.URL.RawQuery = query.Encode() - req.RequestURI = req.URL.RequestURI() } } diff --git a/badger_test.go b/badger_test.go index f8ab045..6f84709 100644 --- a/badger_test.go +++ b/badger_test.go @@ -29,16 +29,16 @@ func TestCreateConfig(t *testing.T) { func TestNewRequiresFieldsWhenForwardAuthEnabled(t *testing.T) { cases := map[string]*badger.Config{ - "missing apiBaseUrl": { + "missing apiBaseURL": { UserSessionCookieName: "p_session_token", ResourceSessionRequestParam: "p_session_request", }, "missing userSessionCookieName": { - APIBaseUrl: "http://localhost:3001", + APIBaseURL: "http://localhost:3001", ResourceSessionRequestParam: "p_session_request", }, "missing resourceSessionRequestParam": { - APIBaseUrl: "http://localhost:3001", + APIBaseURL: "http://localhost:3001", UserSessionCookieName: "p_session_token", }, } @@ -163,7 +163,7 @@ func TestStripSessionCookiesPreservesUnrelated(t *testing.T) { forwarded = req }) cfg := &badger.Config{ - APIBaseUrl: verify.URL, + APIBaseURL: verify.URL, UserSessionCookieName: "p_session_token", ResourceSessionRequestParam: "p_session_request", DisableDefaultCFIPs: true,