From 14377f3ff3ea63c544566ba9567019b7a428e836 Mon Sep 17 00:00:00 2001 From: Rohan Saxena Date: Sat, 20 Jun 2026 09:27:36 +0530 Subject: [PATCH] backend: portforward: Fix context key to match rest of backend --- backend/cmd/headlamp.go | 31 +++++- backend/pkg/portforward/handler.go | 52 +++------- backend/pkg/portforward/handler_test.go | 10 +- backend/pkg/portforward/handler_unit_test.go | 104 +++++++++---------- backend/pkg/portforward/internal_test.go | 29 +++--- 5 files changed, 111 insertions(+), 115 deletions(-) diff --git a/backend/cmd/headlamp.go b/backend/cmd/headlamp.go index 431fda644da..73c25caab7d 100644 --- a/backend/cmd/headlamp.go +++ b/backend/cmd/headlamp.go @@ -653,24 +653,49 @@ func createHeadlampHandler(ctx context.Context, config *HeadlampConfig) http.Han // Setup port forwarding handlers. r.HandleFunc("/clusters/{clusterName}/portforward", func(w http.ResponseWriter, r *http.Request) { + contextKey, err := config.getContextKeyForRequest(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + portforward.StartPortForward( config.KubeConfigStore, config.Cache, config.shouldUseUnsafeServiceAccountToken(), + contextKey, w, r, ) }).Methods("POST") r.HandleFunc("/clusters/{clusterName}/portforward", func(w http.ResponseWriter, r *http.Request) { - portforward.StopOrDeletePortForward(config.Cache, w, r) + contextKey, err := config.getContextKeyForRequest(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + portforward.StopOrDeletePortForward(config.Cache, contextKey, w, r) }).Methods("DELETE") r.HandleFunc("/clusters/{clusterName}/portforward/list", func(w http.ResponseWriter, r *http.Request) { - portforward.GetPortForwards(config.Cache, w, r) + contextKey, err := config.getContextKeyForRequest(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + portforward.GetPortForwards(config.Cache, contextKey, w, r) }) r.HandleFunc("/clusters/{clusterName}/portforward", func(w http.ResponseWriter, r *http.Request) { - portforward.GetPortForwardByID(config.Cache, w, r) + contextKey, err := config.getContextKeyForRequest(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + portforward.GetPortForwardByID(config.Cache, contextKey, w, r) }).Methods("GET") // Expose user info so the frontend can show the current user in the top bar using the per-cluster auth cookie. diff --git a/backend/pkg/portforward/handler.go b/backend/pkg/portforward/handler.go index bd3f8b0a2d0..6cb4bf29db1 100644 --- a/backend/pkg/portforward/handler.go +++ b/backend/pkg/portforward/handler.go @@ -136,6 +136,7 @@ func getFreePort() (int, error) { //nolint:funlen func StartPortForward(kubeConfigStore kubeconfig.ContextStore, cache cache.Cache[interface{}], unsafeUseServiceAccountToken bool, + contextKey string, w http.ResponseWriter, r *http.Request, ) { var p portForwardRequest @@ -170,17 +171,11 @@ func StartPortForward(kubeConfigStore kubeconfig.ContextStore, cache cache.Cache p.Port = strconv.Itoa(freePort) } - userID := r.Header.Get("X-HEADLAMP-USER-ID") requestClusterName := mux.Vars(r)["clusterName"] - clusterName := requestClusterName - if userID != "" { - clusterName += userID - } - - kContext, err := kubeConfigStore.GetContext(clusterName) + kContext, err := kubeConfigStore.GetContext(contextKey) if err != nil { - logger.Log(logger.LevelError, map[string]string{"cluster": clusterName}, + logger.Log(logger.LevelError, map[string]string{"cluster": contextKey}, err, "getting kubeconfig context") http.Error(w, err.Error(), http.StatusInternalServerError) @@ -192,7 +187,7 @@ func StartPortForward(kubeConfigStore kubeconfig.ContextStore, cache cache.Cache token, _ = auth.GetTokenFromCookie(r, requestClusterName) } - err = startPortForward(kContext, cache, p, token, clusterName) + err = startPortForward(kContext, cache, p, token, contextKey) if err != nil { logger.Log(logger.LevelError, nil, err, "starting portforward") http.Error(w, err.Error(), http.StatusInternalServerError) @@ -671,7 +666,9 @@ func (r *stopOrDeletePortForwardRequest) Validate() error { } // StopOrDeletePortForward handles stop or delete port forward request. -func StopOrDeletePortForward(cache cache.Cache[interface{}], w http.ResponseWriter, r *http.Request) { +func StopOrDeletePortForward(cache cache.Cache[interface{}], contextKey string, + w http.ResponseWriter, r *http.Request, +) { var p stopOrDeletePortForwardRequest err := json.NewDecoder(r.Body).Decode(&p) @@ -689,14 +686,7 @@ func StopOrDeletePortForward(cache cache.Cache[interface{}], w http.ResponseWrit return } - userID := r.Header.Get("X-HEADLAMP-USER-ID") - clusterName := mux.Vars(r)["clusterName"] - - if userID != "" { - clusterName += userID - } - - err = stopOrDeletePortForward(cache, clusterName, p.ID, p.StopOrDelete) + err = stopOrDeletePortForward(cache, contextKey, p.ID, p.StopOrDelete) if err == nil { if _, err := w.Write([]byte("stopped")); err != nil { logger.Log(logger.LevelError, nil, err, "writing response") @@ -710,7 +700,9 @@ func StopOrDeletePortForward(cache cache.Cache[interface{}], w http.ResponseWrit } // GetPortForwards handles get port forwards request. -func GetPortForwards(cache cache.Cache[interface{}], w http.ResponseWriter, r *http.Request) { +func GetPortForwards(cache cache.Cache[interface{}], contextKey string, + w http.ResponseWriter, r *http.Request, +) { cluster := mux.Vars(r)["clusterName"] if cluster == "" { logger.Log(logger.LevelError, nil, errors.New("cluster is required"), "getting portforwards") @@ -719,14 +711,7 @@ func GetPortForwards(cache cache.Cache[interface{}], w http.ResponseWriter, r *h return } - userID := r.Header.Get("X-HEADLAMP-USER-ID") - clusterName := cluster - - if userID != "" { - clusterName = cluster + userID - } - - ports := getPortForwardList(cache, clusterName) + ports := getPortForwardList(cache, contextKey) w.Header().Set("Content-Type", "application/json") @@ -739,7 +724,9 @@ func GetPortForwards(cache cache.Cache[interface{}], w http.ResponseWriter, r *h } // GetPortForwardByID handles get port forward by id request. -func GetPortForwardByID(cache cache.Cache[interface{}], w http.ResponseWriter, r *http.Request) { +func GetPortForwardByID(cache cache.Cache[interface{}], contextKey string, + w http.ResponseWriter, r *http.Request, +) { cluster := mux.Vars(r)["clusterName"] if cluster == "" { logger.Log(logger.LevelError, nil, errors.New("cluster is required"), "getting portforward by id") @@ -756,14 +743,7 @@ func GetPortForwardByID(cache cache.Cache[interface{}], w http.ResponseWriter, r return } - userID := r.Header.Get("X-HEADLAMP-USER-ID") - clusterName := cluster - - if userID != "" { - clusterName = cluster + userID - } - - p, err := getPortForwardByID(cache, clusterName, id) + p, err := getPortForwardByID(cache, contextKey, id) if err != nil { logger.Log(logger.LevelError, nil, err, "getting portforward by id") http.Error(w, "no portforward running with id "+id, http.StatusNotFound) diff --git a/backend/pkg/portforward/handler_test.go b/backend/pkg/portforward/handler_test.go index 47b9a03fbd0..d47a56fdd18 100644 --- a/backend/pkg/portforward/handler_test.go +++ b/backend/pkg/portforward/handler_test.go @@ -133,7 +133,7 @@ func TestStartPortForward(t *testing.T) { req.Body = io.NopCloser(bytes.NewReader(jsonReq)) req.Header.Set("Content-Type", "application/json") - portforward.StartPortForward(kubeConfigStore, ch, false, resp, req) + portforward.StartPortForward(kubeConfigStore, ch, false, minikubeName, resp, req) res := resp.Result() @@ -204,7 +204,7 @@ func TestStartPortForward(t *testing.T) { stopReq.Header.Set("Content-Type", "application/json") stopReq = mux.SetURLVars(stopReq, map[string]string{"clusterName": minikubeName}) - portforward.StopOrDeletePortForward(ch, stopResp, stopReq) + portforward.StopOrDeletePortForward(ch, minikubeName, stopResp, stopReq) stopRes := stopResp.Result() @@ -230,7 +230,7 @@ func TestStartPortForward(t *testing.T) { listReq.URL = &url.URL{} listReq = mux.SetURLVars(listReq, map[string]string{"clusterName": minikubeName}) - portforward.GetPortForwards(ch, listResp, listReq) + portforward.GetPortForwards(ch, minikubeName, listResp, listReq) listRes := listResp.Result() @@ -280,7 +280,7 @@ func TestStartPortForward(t *testing.T) { getReq.URL.RawQuery = "id=" + id getReq = mux.SetURLVars(getReq, map[string]string{"clusterName": minikubeName}) - portforward.GetPortForwardByID(ch, getResp, getReq) + portforward.GetPortForwardByID(ch, minikubeName, getResp, getReq) getRes := getResp.Result() @@ -315,7 +315,7 @@ func TestStartPortForward(t *testing.T) { deleteReq.Header.Set("Content-Type", "application/json") deleteReq = mux.SetURLVars(deleteReq, map[string]string{"clusterName": minikubeName}) - portforward.StopOrDeletePortForward(ch, deleteResp, deleteReq) + portforward.StopOrDeletePortForward(ch, minikubeName, deleteResp, deleteReq) deleteRes := deleteResp.Result() diff --git a/backend/pkg/portforward/handler_unit_test.go b/backend/pkg/portforward/handler_unit_test.go index 9800a4d56fb..438724e5c1e 100644 --- a/backend/pkg/portforward/handler_unit_test.go +++ b/backend/pkg/portforward/handler_unit_test.go @@ -52,7 +52,7 @@ func TestGetPortForwards_MissingCluster(t *testing.T) { // No clusterName in mux vars — simulates a request without the route parameter. r := newRequestWithVars(http.MethodGet, "/portforward/list", nil, map[string]string{}) - portforward.GetPortForwards(ch, w, r) + portforward.GetPortForwards(ch, "", w, r) res := w.Result() @@ -74,7 +74,7 @@ func TestGetPortForwards_EmptyList(t *testing.T) { "clusterName": "test-cluster", }) - portforward.GetPortForwards(ch, w, r) + portforward.GetPortForwards(ch, "test-cluster", w, r) res := w.Result() @@ -99,7 +99,7 @@ func TestGetPortForwards_ContentTypeHeader(t *testing.T) { "clusterName": "any-cluster", }) - portforward.GetPortForwards(ch, w, r) + portforward.GetPortForwards(ch, "any-cluster", w, r) res := w.Result() @@ -119,7 +119,7 @@ func TestGetPortForwardByID_MissingCluster(t *testing.T) { r := newRequestWithVars(http.MethodGet, "/portforward?id=abc", nil, map[string]string{}) r.URL = &url.URL{RawQuery: "id=abc"} - portforward.GetPortForwardByID(ch, w, r) + portforward.GetPortForwardByID(ch, "", w, r) res := w.Result() @@ -142,7 +142,7 @@ func TestGetPortForwardByID_MissingID(t *testing.T) { }) r.URL = &url.URL{} - portforward.GetPortForwardByID(ch, w, r) + portforward.GetPortForwardByID(ch, "test-cluster", w, r) res := w.Result() @@ -165,7 +165,7 @@ func TestGetPortForwardByID_NotFound(t *testing.T) { }) r.URL = &url.URL{RawQuery: "id=nonexistent"} - portforward.GetPortForwardByID(ch, w, r) + portforward.GetPortForwardByID(ch, "test-cluster", w, r) res := w.Result() @@ -191,7 +191,7 @@ func TestStopOrDeletePortForward_InvalidJSON(t *testing.T) { "clusterName": "test-cluster", }) - portforward.StopOrDeletePortForward(ch, w, r) + portforward.StopOrDeletePortForward(ch, "test-cluster", w, r) res := w.Result() @@ -200,66 +200,60 @@ func TestStopOrDeletePortForward_InvalidJSON(t *testing.T) { assert.Equal(t, http.StatusBadRequest, res.StatusCode) } -func TestStopOrDeletePortForward_MissingID(t *testing.T) { +func TestStopOrDeletePortForward_ErrorCases(t *testing.T) { t.Parallel() - ch := cache.New[interface{}]() - w := httptest.NewRecorder() - - payload := map[string]interface{}{ - "id": "", - "stopOrDelete": true, + tests := []struct { + name string + id string + wantStatusCode int + wantBody string + }{ + { + name: "missing id", + id: "", + wantStatusCode: http.StatusBadRequest, + wantBody: "id is required", + }, + { + name: "id not found in cache", + id: "does-not-exist", + wantStatusCode: http.StatusInternalServerError, + wantBody: "failed to delete port forward", + }, } - jsonPayload, err := json.Marshal(payload) - require.NoError(t, err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() - body := bytes.NewReader(jsonPayload) - r := newRequestWithVars(http.MethodDelete, "/portforward", body, map[string]string{ - "clusterName": "test-cluster", - }) + ch := cache.New[interface{}]() + w := httptest.NewRecorder() - portforward.StopOrDeletePortForward(ch, w, r) + payload := map[string]interface{}{ + "id": tt.id, + "stopOrDelete": true, + } - res := w.Result() + jsonPayload, err := json.Marshal(payload) + require.NoError(t, err) - defer func() { _ = res.Body.Close() }() + body := bytes.NewReader(jsonPayload) + r := newRequestWithVars(http.MethodDelete, "/portforward", body, map[string]string{ + "clusterName": "test-cluster", + }) - assert.Equal(t, http.StatusBadRequest, res.StatusCode) + portforward.StopOrDeletePortForward(ch, "test-cluster", w, r) - respBody, err := io.ReadAll(res.Body) - require.NoError(t, err) - assert.Contains(t, string(respBody), "id is required") -} + res := w.Result() -func TestStopOrDeletePortForward_NotFoundInCache(t *testing.T) { - t.Parallel() + defer func() { _ = res.Body.Close() }() - ch := cache.New[interface{}]() - w := httptest.NewRecorder() + assert.Equal(t, tt.wantStatusCode, res.StatusCode) - payload := map[string]interface{}{ - "id": "does-not-exist", - "stopOrDelete": true, + respBody, err := io.ReadAll(res.Body) + require.NoError(t, err) + assert.Contains(t, string(respBody), tt.wantBody) + }) } - - jsonPayload, err := json.Marshal(payload) - require.NoError(t, err) - - body := bytes.NewReader(jsonPayload) - r := newRequestWithVars(http.MethodDelete, "/portforward", body, map[string]string{ - "clusterName": "test-cluster", - }) - - portforward.StopOrDeletePortForward(ch, w, r) - - res := w.Result() - - defer func() { _ = res.Body.Close() }() - - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) - - respBody, err := io.ReadAll(res.Body) - require.NoError(t, err) - assert.Contains(t, string(respBody), "failed to delete port forward") } diff --git a/backend/pkg/portforward/internal_test.go b/backend/pkg/portforward/internal_test.go index 7a1bb17ca22..75653ce7aa7 100644 --- a/backend/pkg/portforward/internal_test.go +++ b/backend/pkg/portforward/internal_test.go @@ -529,7 +529,7 @@ func TestGetPortForwardByID_UserIDKeyIsolation(t *testing.T) { } // TestGetPortForwardsHandler_UserIDKeyIsolation uses the exported HTTP handler -// to verify that the X-HEADLAMP-USER-ID header causes a different cache lookup. +// to verify that a different context key causes a different cache lookup. func TestGetPortForwardsHandler_UserIDKeyIsolation(t *testing.T) { c := cache.New[interface{}]() @@ -537,12 +537,12 @@ func TestGetPortForwardsHandler_UserIDKeyIsolation(t *testing.T) { pf := portForward{ID: "pf-3", Cluster: "cluster", Pod: "nginx", Namespace: "default", Status: RUNNING} portforwardstore(c, pf) - // Request WITHOUT user ID header — should return the seeded entry. + // Request with the base cluster context key — should return the seeded entry. w := httptest.NewRecorder() r := httptest.NewRequestWithContext(context.Background(), http.MethodGet, "/portforward/list", nil) r = mux.SetURLVars(r, map[string]string{"clusterName": "cluster"}) - GetPortForwards(c, w, r) + GetPortForwards(c, "cluster", w, r) res := w.Result() @@ -554,13 +554,12 @@ func TestGetPortForwardsHandler_UserIDKeyIsolation(t *testing.T) { require.NoError(t, err) assert.Contains(t, string(body), "pf-3") - // Request WITH user ID header — should return empty list. + // Request with a user-specific context key — should return empty list. w2 := httptest.NewRecorder() r2 := httptest.NewRequestWithContext(context.Background(), http.MethodGet, "/portforward/list", nil) r2 = mux.SetURLVars(r2, map[string]string{"clusterName": "cluster"}) - r2.Header.Set("X-HEADLAMP-USER-ID", "user999") - GetPortForwards(c, w2, r2) + GetPortForwards(c, "clusteruser999", w2, r2) res2 := w2.Result() @@ -574,7 +573,7 @@ func TestGetPortForwardsHandler_UserIDKeyIsolation(t *testing.T) { } // TestGetPortForwardByIDHandler_UserIDKeyIsolation uses the exported HTTP handler -// to verify that the X-HEADLAMP-USER-ID header causes a different cache lookup. +// to verify that a different context key causes a different cache lookup. func TestGetPortForwardByIDHandler_UserIDKeyIsolation(t *testing.T) { c := cache.New[interface{}]() @@ -582,13 +581,13 @@ func TestGetPortForwardByIDHandler_UserIDKeyIsolation(t *testing.T) { pf := portForward{ID: "pf-4", Cluster: "cluster", Pod: "redis", Namespace: "cache", Status: RUNNING} portforwardstore(c, pf) - // Request WITHOUT user ID header — should find the entry. + // Request with the base cluster context key — should find the entry. w := httptest.NewRecorder() r := httptest.NewRequestWithContext(context.Background(), http.MethodGet, "/portforward?id=pf-4", nil) r = mux.SetURLVars(r, map[string]string{"clusterName": "cluster"}) r.URL = &url.URL{RawQuery: "id=pf-4"} - GetPortForwardByID(c, w, r) + GetPortForwardByID(c, "cluster", w, r) res := w.Result() @@ -596,14 +595,13 @@ func TestGetPortForwardByIDHandler_UserIDKeyIsolation(t *testing.T) { assert.Equal(t, http.StatusOK, res.StatusCode) - // Request WITH user ID header — should NOT find it. + // Request with a user-specific context key — should NOT find it. w2 := httptest.NewRecorder() r2 := httptest.NewRequestWithContext(context.Background(), http.MethodGet, "/portforward?id=pf-4", nil) r2 = mux.SetURLVars(r2, map[string]string{"clusterName": "cluster"}) r2.URL = &url.URL{RawQuery: "id=pf-4"} - r2.Header.Set("X-HEADLAMP-USER-ID", "user999") - GetPortForwardByID(c, w2, r2) + GetPortForwardByID(c, "clusteruser999", w2, r2) res2 := w2.Result() @@ -614,7 +612,7 @@ func TestGetPortForwardByIDHandler_UserIDKeyIsolation(t *testing.T) { } // TestStopOrDeletePortForwardHandler_UserIDKeyIsolation verifies that -// StopOrDeletePortForward uses cluster+userID as the cache key. +// StopOrDeletePortForward uses the provided context key as the cache key. func TestStopOrDeletePortForwardHandler_UserIDKeyIsolation(t *testing.T) { c := cache.New[interface{}]() @@ -623,16 +621,15 @@ func TestStopOrDeletePortForwardHandler_UserIDKeyIsolation(t *testing.T) { pf := portForward{ID: "pf-5", Cluster: "cluster", Pod: "app", Namespace: "ns", Status: RUNNING, closeChan: ch} portforwardstore(c, pf) - // Try to stop with a user ID header — should fail because the key is different. + // Try to stop with a user-specific context key — should fail because the key is different. payload, err := json.Marshal(map[string]interface{}{"id": "pf-5", "stopOrDelete": true}) require.NoError(t, err) w := httptest.NewRecorder() r := httptest.NewRequestWithContext(context.Background(), http.MethodDelete, "/portforward", bytes.NewReader(payload)) r = mux.SetURLVars(r, map[string]string{"clusterName": "cluster"}) - r.Header.Set("X-HEADLAMP-USER-ID", "user999") - StopOrDeletePortForward(c, w, r) + StopOrDeletePortForward(c, "clusteruser999", w, r) res := w.Result()