Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions backend/cmd/headlamp.go
Original file line number Diff line number Diff line change
Expand Up @@ -653,24 +653,49 @@ func createHeadlampHandler(ctx context.Context, config *HeadlampConfig) http.Han

// Setup port forwarding handlers.
r.HandleFunc("/clusters/{clusterName}/portforward", func(w http.ResponseWriter, r *http.Request) {
contextKey, err := config.getContextKeyForRequest(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

portforward.StartPortForward(
config.KubeConfigStore,
config.Cache,
config.shouldUseUnsafeServiceAccountToken(),
contextKey,
w,
r,
)
}).Methods("POST")

r.HandleFunc("/clusters/{clusterName}/portforward", func(w http.ResponseWriter, r *http.Request) {
portforward.StopOrDeletePortForward(config.Cache, w, r)
contextKey, err := config.getContextKeyForRequest(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

portforward.StopOrDeletePortForward(config.Cache, contextKey, w, r)
}).Methods("DELETE")

r.HandleFunc("/clusters/{clusterName}/portforward/list", func(w http.ResponseWriter, r *http.Request) {
portforward.GetPortForwards(config.Cache, w, r)
contextKey, err := config.getContextKeyForRequest(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

portforward.GetPortForwards(config.Cache, contextKey, w, r)
})
r.HandleFunc("/clusters/{clusterName}/portforward", func(w http.ResponseWriter, r *http.Request) {
portforward.GetPortForwardByID(config.Cache, w, r)
contextKey, err := config.getContextKeyForRequest(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

portforward.GetPortForwardByID(config.Cache, contextKey, w, r)
}).Methods("GET")

// Expose user info so the frontend can show the current user in the top bar using the per-cluster auth cookie.
Expand Down
52 changes: 16 additions & 36 deletions backend/pkg/portforward/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ func getFreePort() (int, error) {
//nolint:funlen
func StartPortForward(kubeConfigStore kubeconfig.ContextStore, cache cache.Cache[interface{}],
unsafeUseServiceAccountToken bool,
contextKey string,
w http.ResponseWriter, r *http.Request,
) {
var p portForwardRequest
Expand Down Expand Up @@ -170,17 +171,11 @@ func StartPortForward(kubeConfigStore kubeconfig.ContextStore, cache cache.Cache
p.Port = strconv.Itoa(freePort)
}

userID := r.Header.Get("X-HEADLAMP-USER-ID")
requestClusterName := mux.Vars(r)["clusterName"]
clusterName := requestClusterName

if userID != "" {
clusterName += userID
}

kContext, err := kubeConfigStore.GetContext(clusterName)
kContext, err := kubeConfigStore.GetContext(contextKey)
if err != nil {
logger.Log(logger.LevelError, map[string]string{"cluster": clusterName},
logger.Log(logger.LevelError, map[string]string{"cluster": contextKey},
err, "getting kubeconfig context")
http.Error(w, err.Error(), http.StatusInternalServerError)

Expand All @@ -192,7 +187,7 @@ func StartPortForward(kubeConfigStore kubeconfig.ContextStore, cache cache.Cache
token, _ = auth.GetTokenFromCookie(r, requestClusterName)
}

err = startPortForward(kContext, cache, p, token, clusterName)
err = startPortForward(kContext, cache, p, token, contextKey)
if err != nil {
logger.Log(logger.LevelError, nil, err, "starting portforward")
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down Expand Up @@ -671,7 +666,9 @@ func (r *stopOrDeletePortForwardRequest) Validate() error {
}

// StopOrDeletePortForward handles stop or delete port forward request.
func StopOrDeletePortForward(cache cache.Cache[interface{}], w http.ResponseWriter, r *http.Request) {
func StopOrDeletePortForward(cache cache.Cache[interface{}], contextKey string,
w http.ResponseWriter, r *http.Request,
) {
var p stopOrDeletePortForwardRequest

err := json.NewDecoder(r.Body).Decode(&p)
Expand All @@ -689,14 +686,7 @@ func StopOrDeletePortForward(cache cache.Cache[interface{}], w http.ResponseWrit
return
}

userID := r.Header.Get("X-HEADLAMP-USER-ID")
clusterName := mux.Vars(r)["clusterName"]

if userID != "" {
clusterName += userID
}

err = stopOrDeletePortForward(cache, clusterName, p.ID, p.StopOrDelete)
err = stopOrDeletePortForward(cache, contextKey, p.ID, p.StopOrDelete)
if err == nil {
if _, err := w.Write([]byte("stopped")); err != nil {
logger.Log(logger.LevelError, nil, err, "writing response")
Expand All @@ -710,7 +700,9 @@ func StopOrDeletePortForward(cache cache.Cache[interface{}], w http.ResponseWrit
}

// GetPortForwards handles get port forwards request.
func GetPortForwards(cache cache.Cache[interface{}], w http.ResponseWriter, r *http.Request) {
func GetPortForwards(cache cache.Cache[interface{}], contextKey string,
w http.ResponseWriter, r *http.Request,
) {
cluster := mux.Vars(r)["clusterName"]
if cluster == "" {
logger.Log(logger.LevelError, nil, errors.New("cluster is required"), "getting portforwards")
Expand All @@ -719,14 +711,7 @@ func GetPortForwards(cache cache.Cache[interface{}], w http.ResponseWriter, r *h
return
}

userID := r.Header.Get("X-HEADLAMP-USER-ID")
clusterName := cluster

if userID != "" {
clusterName = cluster + userID
}

ports := getPortForwardList(cache, clusterName)
ports := getPortForwardList(cache, contextKey)

w.Header().Set("Content-Type", "application/json")

Expand All @@ -739,7 +724,9 @@ func GetPortForwards(cache cache.Cache[interface{}], w http.ResponseWriter, r *h
}

// GetPortForwardByID handles get port forward by id request.
func GetPortForwardByID(cache cache.Cache[interface{}], w http.ResponseWriter, r *http.Request) {
func GetPortForwardByID(cache cache.Cache[interface{}], contextKey string,
w http.ResponseWriter, r *http.Request,
) {
cluster := mux.Vars(r)["clusterName"]
if cluster == "" {
logger.Log(logger.LevelError, nil, errors.New("cluster is required"), "getting portforward by id")
Expand All @@ -756,14 +743,7 @@ func GetPortForwardByID(cache cache.Cache[interface{}], w http.ResponseWriter, r
return
}

userID := r.Header.Get("X-HEADLAMP-USER-ID")
clusterName := cluster

if userID != "" {
clusterName = cluster + userID
}

p, err := getPortForwardByID(cache, clusterName, id)
p, err := getPortForwardByID(cache, contextKey, id)
if err != nil {
logger.Log(logger.LevelError, nil, err, "getting portforward by id")
http.Error(w, "no portforward running with id "+id, http.StatusNotFound)
Comment on lines +746 to 749
Expand Down
10 changes: 5 additions & 5 deletions backend/pkg/portforward/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func TestStartPortForward(t *testing.T) {
req.Body = io.NopCloser(bytes.NewReader(jsonReq))
req.Header.Set("Content-Type", "application/json")

portforward.StartPortForward(kubeConfigStore, ch, false, resp, req)
portforward.StartPortForward(kubeConfigStore, ch, false, minikubeName, resp, req)

res := resp.Result()

Expand Down Expand Up @@ -204,7 +204,7 @@ func TestStartPortForward(t *testing.T) {
stopReq.Header.Set("Content-Type", "application/json")
stopReq = mux.SetURLVars(stopReq, map[string]string{"clusterName": minikubeName})

portforward.StopOrDeletePortForward(ch, stopResp, stopReq)
portforward.StopOrDeletePortForward(ch, minikubeName, stopResp, stopReq)

stopRes := stopResp.Result()

Expand All @@ -230,7 +230,7 @@ func TestStartPortForward(t *testing.T) {
listReq.URL = &url.URL{}
listReq = mux.SetURLVars(listReq, map[string]string{"clusterName": minikubeName})

portforward.GetPortForwards(ch, listResp, listReq)
portforward.GetPortForwards(ch, minikubeName, listResp, listReq)

listRes := listResp.Result()

Expand Down Expand Up @@ -280,7 +280,7 @@ func TestStartPortForward(t *testing.T) {
getReq.URL.RawQuery = "id=" + id
getReq = mux.SetURLVars(getReq, map[string]string{"clusterName": minikubeName})

portforward.GetPortForwardByID(ch, getResp, getReq)
portforward.GetPortForwardByID(ch, minikubeName, getResp, getReq)

getRes := getResp.Result()

Expand Down Expand Up @@ -315,7 +315,7 @@ func TestStartPortForward(t *testing.T) {
deleteReq.Header.Set("Content-Type", "application/json")
deleteReq = mux.SetURLVars(deleteReq, map[string]string{"clusterName": minikubeName})

portforward.StopOrDeletePortForward(ch, deleteResp, deleteReq)
portforward.StopOrDeletePortForward(ch, minikubeName, deleteResp, deleteReq)

deleteRes := deleteResp.Result()

Expand Down
104 changes: 49 additions & 55 deletions backend/pkg/portforward/handler_unit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestGetPortForwards_MissingCluster(t *testing.T) {
// No clusterName in mux vars — simulates a request without the route parameter.
r := newRequestWithVars(http.MethodGet, "/portforward/list", nil, map[string]string{})

portforward.GetPortForwards(ch, w, r)
portforward.GetPortForwards(ch, "", w, r)

res := w.Result()

Expand All @@ -74,7 +74,7 @@ func TestGetPortForwards_EmptyList(t *testing.T) {
"clusterName": "test-cluster",
})

portforward.GetPortForwards(ch, w, r)
portforward.GetPortForwards(ch, "test-cluster", w, r)

res := w.Result()

Expand All @@ -99,7 +99,7 @@ func TestGetPortForwards_ContentTypeHeader(t *testing.T) {
"clusterName": "any-cluster",
})

portforward.GetPortForwards(ch, w, r)
portforward.GetPortForwards(ch, "any-cluster", w, r)

res := w.Result()

Expand All @@ -119,7 +119,7 @@ func TestGetPortForwardByID_MissingCluster(t *testing.T) {
r := newRequestWithVars(http.MethodGet, "/portforward?id=abc", nil, map[string]string{})
r.URL = &url.URL{RawQuery: "id=abc"}

portforward.GetPortForwardByID(ch, w, r)
portforward.GetPortForwardByID(ch, "", w, r)

res := w.Result()

Expand All @@ -142,7 +142,7 @@ func TestGetPortForwardByID_MissingID(t *testing.T) {
})
r.URL = &url.URL{}

portforward.GetPortForwardByID(ch, w, r)
portforward.GetPortForwardByID(ch, "test-cluster", w, r)

res := w.Result()

Expand All @@ -165,7 +165,7 @@ func TestGetPortForwardByID_NotFound(t *testing.T) {
})
r.URL = &url.URL{RawQuery: "id=nonexistent"}

portforward.GetPortForwardByID(ch, w, r)
portforward.GetPortForwardByID(ch, "test-cluster", w, r)

res := w.Result()

Expand All @@ -191,7 +191,7 @@ func TestStopOrDeletePortForward_InvalidJSON(t *testing.T) {
"clusterName": "test-cluster",
})

portforward.StopOrDeletePortForward(ch, w, r)
portforward.StopOrDeletePortForward(ch, "test-cluster", w, r)

res := w.Result()

Expand All @@ -200,66 +200,60 @@ func TestStopOrDeletePortForward_InvalidJSON(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, res.StatusCode)
}

func TestStopOrDeletePortForward_MissingID(t *testing.T) {
func TestStopOrDeletePortForward_ErrorCases(t *testing.T) {
t.Parallel()

ch := cache.New[interface{}]()
w := httptest.NewRecorder()

payload := map[string]interface{}{
"id": "",
"stopOrDelete": true,
tests := []struct {
name string
id string
wantStatusCode int
wantBody string
}{
{
name: "missing id",
id: "",
wantStatusCode: http.StatusBadRequest,
wantBody: "id is required",
},
{
name: "id not found in cache",
id: "does-not-exist",
wantStatusCode: http.StatusInternalServerError,
wantBody: "failed to delete port forward",
},
}

jsonPayload, err := json.Marshal(payload)
require.NoError(t, err)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

Comment on lines +226 to 229
body := bytes.NewReader(jsonPayload)
r := newRequestWithVars(http.MethodDelete, "/portforward", body, map[string]string{
"clusterName": "test-cluster",
})
ch := cache.New[interface{}]()
w := httptest.NewRecorder()

portforward.StopOrDeletePortForward(ch, w, r)
payload := map[string]interface{}{
"id": tt.id,
"stopOrDelete": true,
}

res := w.Result()
jsonPayload, err := json.Marshal(payload)
require.NoError(t, err)

defer func() { _ = res.Body.Close() }()
body := bytes.NewReader(jsonPayload)
r := newRequestWithVars(http.MethodDelete, "/portforward", body, map[string]string{
"clusterName": "test-cluster",
})

assert.Equal(t, http.StatusBadRequest, res.StatusCode)
portforward.StopOrDeletePortForward(ch, "test-cluster", w, r)

respBody, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.Contains(t, string(respBody), "id is required")
}
res := w.Result()

func TestStopOrDeletePortForward_NotFoundInCache(t *testing.T) {
t.Parallel()
defer func() { _ = res.Body.Close() }()

ch := cache.New[interface{}]()
w := httptest.NewRecorder()
assert.Equal(t, tt.wantStatusCode, res.StatusCode)

payload := map[string]interface{}{
"id": "does-not-exist",
"stopOrDelete": true,
respBody, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.Contains(t, string(respBody), tt.wantBody)
})
}

jsonPayload, err := json.Marshal(payload)
require.NoError(t, err)

body := bytes.NewReader(jsonPayload)
r := newRequestWithVars(http.MethodDelete, "/portforward", body, map[string]string{
"clusterName": "test-cluster",
})

portforward.StopOrDeletePortForward(ch, w, r)

res := w.Result()

defer func() { _ = res.Body.Close() }()

assert.Equal(t, http.StatusInternalServerError, res.StatusCode)

respBody, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.Contains(t, string(respBody), "failed to delete port forward")
}
Loading
Loading