diff --git a/server/plugin/api.go b/server/plugin/api.go index a425d3b92..104cca5bb 100644 --- a/server/plugin/api.go +++ b/server/plugin/api.go @@ -779,19 +779,31 @@ func (p *Plugin) getPrsDetails(c *UserContext, w http.ResponseWriter, r *http.Re prDetails := make([]*PRDetails, len(validPRs)) var wg sync.WaitGroup + var fetchErr error + var fetchErrMu sync.Mutex for i, pr := range validPRs { wg.Go(func() { - prDetail := p.fetchPRDetails(c, githubClient, pr.URL, pr.Number) + prDetail, err := p.fetchPRDetails(c, githubClient, pr.URL, pr.Number) prDetails[i] = prDetail + if err != nil { + fetchErrMu.Lock() + if fetchErr == nil { + fetchErr = err + } + fetchErrMu.Unlock() + } }) } wg.Wait() + if isGitHubAuthFailure(fetchErr) { + p.handleRevokedToken(c.GHInfo) + } p.writeJSON(w, prDetails) } -func (p *Plugin) fetchPRDetails(c *UserContext, client *github.Client, prURL string, prNumber int) *PRDetails { +func (p *Plugin) fetchPRDetails(c *UserContext, client *github.Client, prURL string, prNumber int) (*PRDetails, error) { var status string var mergeable bool // Initialize to a non-nil slice to simplify JSON handling semantics @@ -806,16 +818,26 @@ func (p *Plugin) fetchPRDetails(c *UserContext, client *github.Client, prURL str Number: prNumber, RequestedReviewers: requestedReviewers, Reviews: reviewsList, - } + }, nil } var wg sync.WaitGroup + var fetchErr error + var fetchErrMu sync.Mutex + recordFetchErr := func(err error, msg string) { + c.Log.WithError(err).Warnf("%s", msg) + fetchErrMu.Lock() + if fetchErr == nil { + fetchErr = err + } + fetchErrMu.Unlock() + } // Fetch reviews wg.Go(func() { fetchedReviews, err := fetchReviews(c, client, repoOwner, repoName, prNumber) if err != nil { - c.Log.WithError(err).Warnf("Failed to fetch reviews for PR details") + recordFetchErr(err, "Failed to fetch reviews for PR details") return } reviewsList = fetchedReviews @@ -825,7 +847,7 @@ func (p *Plugin) fetchPRDetails(c *UserContext, client *github.Client, prURL str wg.Go(func() { prInfo, _, err := client.PullRequests.Get(c.Ctx, repoOwner, repoName, prNumber) if err != nil { - c.Log.WithError(err).Warnf("Failed to fetch PR for PR details") + recordFetchErr(err, "Failed to fetch PR for PR details") return } @@ -836,7 +858,7 @@ func (p *Plugin) fetchPRDetails(c *UserContext, client *github.Client, prURL str } statuses, _, err := client.Repositories.GetCombinedStatus(c.Ctx, repoOwner, repoName, prInfo.GetHead().GetSHA(), nil) if err != nil { - c.Log.WithError(err).Warnf("Failed to fetch combined status") + recordFetchErr(err, "Failed to fetch combined status") return } status = *statuses.State @@ -850,7 +872,7 @@ func (p *Plugin) fetchPRDetails(c *UserContext, client *github.Client, prURL str Mergeable: mergeable, RequestedReviewers: requestedReviewers, Reviews: reviewsList, - } + }, fetchErr } func fetchReviews(c *UserContext, client *github.Client, repoOwner string, repoName string, number int) ([]*github.PullRequestReview, error) { @@ -1100,6 +1122,9 @@ func (p *Plugin) getLHSData(c *UserContext) (reviewResp []*graphql.GithubPRDetai graphQLClient := p.graphQLConnect(c.GHInfo) reviewResp, assignmentResp, openPRResp, err = graphQLClient.GetLHSData(c.Ctx) + if isGitHubAuthFailure(err) { + p.handleRevokedToken(c.GHInfo) + } if err != nil { return []*graphql.GithubPRDetails{}, []*github.Issue{}, []*graphql.GithubPRDetails{}, err } diff --git a/server/plugin/graphql/lhs_request.go b/server/plugin/graphql/lhs_request.go index 2bcbb93c5..270b4aabc 100644 --- a/server/plugin/graphql/lhs_request.go +++ b/server/plugin/graphql/lhs_request.go @@ -36,11 +36,15 @@ func (c *Client) GetLHSData(ctx context.Context) ([]*GithubPRDetails, []*github. var resultAssignee []*github.Issue var resultReview, resultOpenPR []*GithubPRDetails - var err error + var firstErr error for _, org := range orgsList { + var err error resultReview, resultAssignee, resultOpenPR, err = c.fetchLHSData(ctx, resultReview, resultAssignee, resultOpenPR, org, c.username) if err != nil { c.logger.Error("Error fetching LHS data for org", "org", org, "error", err.Error()) + if firstErr == nil { + firstErr = err + } } } @@ -48,7 +52,9 @@ func (c *Client) GetLHSData(ctx context.Context) ([]*GithubPRDetails, []*github. return c.fetchLHSData(ctx, resultReview, resultAssignee, resultOpenPR, "", c.username) } - return resultReview, resultAssignee, resultOpenPR, nil + // Return partial results alongside the error so callers can detect auth failures + // while still rendering whatever orgs succeeded. + return resultReview, resultAssignee, resultOpenPR, firstErr } func (c *Client) fetchLHSData( diff --git a/server/plugin/plugin.go b/server/plugin/plugin.go index 96a723da5..7f27b3bed 100644 --- a/server/plugin/plugin.go +++ b/server/plugin/plugin.go @@ -1349,13 +1349,64 @@ func (p *Plugin) useGitHubClient(info *GitHubUserInfo, toRun func(info *GitHubUs p.client.Log.Warn("Error occurred while using the Github client", "error", err.Error()) } - if err != nil && strings.Contains(err.Error(), invalidTokenError) { + if isGitHubAuthFailure(err) { p.handleRevokedToken(info) } return err } +// isGitHubAuthFailure reports whether err indicates the stored OAuth token is no +// longer usable: a 401, or a 403 from SAML SSO enforcement. +// +// We use two detection paths because the plugin talks to GitHub through both the +// REST client (go-github) and the GraphQL client (githubv4), and they surface +// errors differently. +func isGitHubAuthFailure(err error) bool { + if err == nil { + return false + } + + // REST API: go-github returns a typed *github.ErrorResponse with the HTTP status. + var ghErr *github.ErrorResponse + if errors.As(err, &ghErr) && ghErr.Response != nil { + switch ghErr.Response.StatusCode { + case http.StatusUnauthorized: + return true + case http.StatusForbidden: + // Not every 403 is an auth failure; only SAML SSO revocation requires reconnect. + return isSAMLError(ghErr) + } + } + + // GraphQL API: githubv4 returns untyped errors with status and message embedded + // in err.Error(), so match on known substrings instead. + errMsg := err.Error() + if strings.Contains(errMsg, invalidTokenError) { + return true + } + if strings.Contains(errMsg, "non-200 OK status code: 401") { + return true + } + // Match SAML content, not bare 403 — unrelated permission errors also return 403. + return strings.Contains(errMsg, "SAML enforcement") || strings.Contains(errMsg, "saml_failure") +} + +func isSAMLError(ghErr *github.ErrorResponse) bool { + if ghErr.Response.Header.Get("X-GitHub-SSO") != "" { + return true + } + if strings.Contains(ghErr.Message, "SAML") { + return true + } + for _, e := range ghErr.Errors { + if strings.Contains(e.Message, "SAML") { + return true + } + } + return false +} + func (p *Plugin) handleRevokedToken(info *GitHubUserInfo) { p.disconnectGitHubAccount(info.UserID) p.CreateBotDMPost(info.UserID, "Your Github account was disconnected due to an invalid or revoked authorization token. Reconnect your account using the `/github connect` command.", "custom_git_revoked_token") diff --git a/server/plugin/plugin_test.go b/server/plugin/plugin_test.go index 46c6b24f4..e55fdd35f 100644 --- a/server/plugin/plugin_test.go +++ b/server/plugin/plugin_test.go @@ -5,11 +5,13 @@ package plugin import ( "encoding/json" + "net/http" "strings" "testing" "unicode/utf8" "github.com/golang/mock/gomock" + "github.com/google/go-github/v54/github" "github.com/pkg/errors" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -43,6 +45,7 @@ func setupRotationTest(t *testing.T) (*Plugin, *plugintest.API, *mocks.MockKvSto api.On("KVSetWithOptions", mock.Anything, mock.Anything, mock.Anything).Return(true, nil).Maybe() api.On("LogError", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Maybe() + api.On("LogWarn", mock.Anything, mock.Anything, mock.Anything).Maybe() api.On("LogAuditRec", mock.Anything).Maybe() return p, api, mockKvStore, ctrl @@ -406,3 +409,155 @@ func TestTruncatePostMessage(t *testing.T) { require.True(t, strings.HasSuffix(out, "_… message truncated_")) }) } + +func TestIsGitHubAuthFailure(t *testing.T) { + t.Run("nil", func(t *testing.T) { + require.False(t, isGitHubAuthFailure(nil)) + }) + + t.Run("401 bad credentials string", func(t *testing.T) { + require.True(t, isGitHubAuthFailure(errors.New(invalidTokenError))) + }) + + t.Run("401 from ErrorResponse", func(t *testing.T) { + err := &github.ErrorResponse{Response: &http.Response{StatusCode: http.StatusUnauthorized}} + require.True(t, isGitHubAuthFailure(err)) + }) + + t.Run("401 from graphql client message", func(t *testing.T) { + err := errors.New("non-200 OK status code: 401 Unauthorized") + require.True(t, isGitHubAuthFailure(err)) + }) + + t.Run("403 SAML from ErrorResponse", func(t *testing.T) { + err := &github.ErrorResponse{ + Message: "Resource protected by organization SAML enforcement. You must grant your OAuth token access to this organization.", + Response: &http.Response{ + StatusCode: http.StatusForbidden, + Header: http.Header{"X-Github-Sso": []string{"required; url=https://github.com/orgs/foo/sso"}}, + }, + } + require.True(t, isGitHubAuthFailure(err)) + }) + + t.Run("403 SAML from graphql error string", func(t *testing.T) { + err := errors.New("error in executing query: GraphQL: Resource protected by organization SAML enforcement. You must grant your OAuth token access to this organization.") + require.True(t, isGitHubAuthFailure(err)) + }) + + t.Run("403 unrelated", func(t *testing.T) { + err := &github.ErrorResponse{ + Message: "Forbidden", + Response: &http.Response{StatusCode: http.StatusForbidden}, + } + require.False(t, isGitHubAuthFailure(err)) + }) + + t.Run("generic error", func(t *testing.T) { + require.False(t, isGitHubAuthFailure(errors.New("connection reset"))) + }) +} + +func connectedGitHubUserInfo(t *testing.T) *GitHubUserInfo { + t.Helper() + encryptedToken, err := encrypt([]byte(testNewKey), MockAccessToken) + require.NoError(t, err) + return &GitHubUserInfo{ + UserID: "user1", + GitHubUsername: "ghuser1", + Token: &oauth2.Token{AccessToken: encryptedToken}, + Settings: &UserSettings{}, + } +} + +func expectRevokedTokenNotification(api *plugintest.API, mockKvStore *mocks.MockKvStore, userInfo *GitHubUserInfo) { + mockKvStore.EXPECT().Get(userInfo.UserID+githubTokenKey, gomock.Any()).DoAndReturn( + func(_ string, out any) error { + userInfoBytes, err := json.Marshal(userInfo) + if err != nil { + return err + } + return json.Unmarshal(userInfoBytes, out) + }, + ) + mockKvStore.EXPECT().Delete(userInfo.UserID + githubTokenKey).Return(nil) + mockKvStore.EXPECT().Delete(userInfo.GitHubUsername + githubUsernameKey).Return(nil) + mockKvStore.EXPECT().Delete(userInfo.UserID + githubPrivateRepoKey).Return(nil) + api.On("GetUser", userInfo.UserID).Return(&model.User{ + Id: userInfo.UserID, + Props: model.StringMap{"git_user": userInfo.GitHubUsername}, + }, nil) + api.On("UpdateUser", mock.Anything).Return(&model.User{Id: userInfo.UserID, Props: model.StringMap{}}, nil) + api.On("PublishWebSocketEvent", wsEventDisconnect, map[string]any(nil), + &model.WebsocketBroadcast{UserId: userInfo.UserID}).Return() + api.On("GetDirectChannel", userInfo.UserID, MockBotID).Return(&model.Channel{Id: "dmchannel"}, nil) + api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool { + return post.UserId == MockBotID && + post.ChannelId == "dmchannel" && + post.Type == "custom_git_revoked_token" + })).Return(&model.Post{}, nil).Once() +} + +func TestUseGitHubClient_AuthFailureNotifiesUser(t *testing.T) { + samlGraphQLErr := errors.New("error in executing query: GraphQL: Resource protected by organization SAML enforcement. You must grant your OAuth token access to this organization.") + + tests := []struct { + name string + err error + notify bool + }{ + { + name: "401 bad credentials", + err: errors.New(invalidTokenError), + notify: true, + }, + { + name: "403 SAML REST", + err: &github.ErrorResponse{ + Message: "Resource protected by organization SAML enforcement. You must grant your OAuth token access to this organization.", + Response: &http.Response{ + StatusCode: http.StatusForbidden, + Header: http.Header{"X-Github-Sso": []string{"required; url=https://github.com/orgs/foo/sso"}}, + Request: &http.Request{}, + }, + }, + notify: true, + }, + { + name: "403 SAML graphql", + err: samlGraphQLErr, + notify: true, + }, + { + name: "403 unrelated", + err: &github.ErrorResponse{ + Message: "Forbidden", + Response: &http.Response{StatusCode: http.StatusForbidden, Request: &http.Request{}}, + }, + notify: false, + }, + { + name: "generic error", + err: errors.New("connection reset"), + notify: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + p, api, mockKvStore, ctrl := setupRotationTest(t) + defer ctrl.Finish() + + userInfo := connectedGitHubUserInfo(t) + if tc.notify { + expectRevokedTokenNotification(api, mockKvStore, userInfo) + } + + err := p.useGitHubClient(userInfo, func(_ *GitHubUserInfo, _ *oauth2.Token) error { + return tc.err + }) + require.Equal(t, tc.err, err) + api.AssertExpectations(t) + }) + } +}