Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions server/plugin/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -779,19 +779,31 @@ func (p *Plugin) getPrsDetails(c *UserContext, w http.ResponseWriter, r *http.Re

prDetails := make([]*PRDetails, len(validPRs))
var wg sync.WaitGroup
var fetchErr error
var fetchErrMu sync.Mutex
for i, pr := range validPRs {
wg.Go(func() {
prDetail := p.fetchPRDetails(c, githubClient, pr.URL, pr.Number)
prDetail, err := p.fetchPRDetails(c, githubClient, pr.URL, pr.Number)
prDetails[i] = prDetail
if err != nil {
fetchErrMu.Lock()
if fetchErr == nil {
fetchErr = err
}
fetchErrMu.Unlock()
}
})
}

wg.Wait()

if isGitHubAuthFailure(fetchErr) {
p.handleRevokedToken(c.GHInfo)
}
p.writeJSON(w, prDetails)
}

func (p *Plugin) fetchPRDetails(c *UserContext, client *github.Client, prURL string, prNumber int) *PRDetails {
func (p *Plugin) fetchPRDetails(c *UserContext, client *github.Client, prURL string, prNumber int) (*PRDetails, error) {
var status string
var mergeable bool
// Initialize to a non-nil slice to simplify JSON handling semantics
Expand All @@ -806,16 +818,26 @@ func (p *Plugin) fetchPRDetails(c *UserContext, client *github.Client, prURL str
Number: prNumber,
RequestedReviewers: requestedReviewers,
Reviews: reviewsList,
}
}, nil
}

var wg sync.WaitGroup
var fetchErr error
var fetchErrMu sync.Mutex
recordFetchErr := func(err error, msg string) {
c.Log.WithError(err).Warnf("%s", msg)
fetchErrMu.Lock()
if fetchErr == nil {
fetchErr = err
}
fetchErrMu.Unlock()
}

// Fetch reviews
wg.Go(func() {
fetchedReviews, err := fetchReviews(c, client, repoOwner, repoName, prNumber)
if err != nil {
c.Log.WithError(err).Warnf("Failed to fetch reviews for PR details")
recordFetchErr(err, "Failed to fetch reviews for PR details")
return
}
reviewsList = fetchedReviews
Expand All @@ -825,7 +847,7 @@ func (p *Plugin) fetchPRDetails(c *UserContext, client *github.Client, prURL str
wg.Go(func() {
prInfo, _, err := client.PullRequests.Get(c.Ctx, repoOwner, repoName, prNumber)
if err != nil {
c.Log.WithError(err).Warnf("Failed to fetch PR for PR details")
recordFetchErr(err, "Failed to fetch PR for PR details")
return
}

Expand All @@ -836,7 +858,7 @@ func (p *Plugin) fetchPRDetails(c *UserContext, client *github.Client, prURL str
}
statuses, _, err := client.Repositories.GetCombinedStatus(c.Ctx, repoOwner, repoName, prInfo.GetHead().GetSHA(), nil)
if err != nil {
c.Log.WithError(err).Warnf("Failed to fetch combined status")
recordFetchErr(err, "Failed to fetch combined status")
return
}
status = *statuses.State
Expand All @@ -850,7 +872,7 @@ func (p *Plugin) fetchPRDetails(c *UserContext, client *github.Client, prURL str
Mergeable: mergeable,
RequestedReviewers: requestedReviewers,
Reviews: reviewsList,
}
}, fetchErr
}

func fetchReviews(c *UserContext, client *github.Client, repoOwner string, repoName string, number int) ([]*github.PullRequestReview, error) {
Expand Down Expand Up @@ -1100,6 +1122,9 @@ func (p *Plugin) getLHSData(c *UserContext) (reviewResp []*graphql.GithubPRDetai
graphQLClient := p.graphQLConnect(c.GHInfo)

reviewResp, assignmentResp, openPRResp, err = graphQLClient.GetLHSData(c.Ctx)
if isGitHubAuthFailure(err) {
p.handleRevokedToken(c.GHInfo)
}
if err != nil {
return []*graphql.GithubPRDetails{}, []*github.Issue{}, []*graphql.GithubPRDetails{}, err
}
Expand Down
10 changes: 8 additions & 2 deletions server/plugin/graphql/lhs_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,25 @@ func (c *Client) GetLHSData(ctx context.Context) ([]*GithubPRDetails, []*github.
var resultAssignee []*github.Issue
var resultReview, resultOpenPR []*GithubPRDetails

var err error
var firstErr error
for _, org := range orgsList {
var err error
resultReview, resultAssignee, resultOpenPR, err = c.fetchLHSData(ctx, resultReview, resultAssignee, resultOpenPR, org, c.username)
if err != nil {
c.logger.Error("Error fetching LHS data for org", "org", org, "error", err.Error())
if firstErr == nil {
firstErr = err
}
}
}

if len(orgsList) == 0 {
return c.fetchLHSData(ctx, resultReview, resultAssignee, resultOpenPR, "", c.username)
}

return resultReview, resultAssignee, resultOpenPR, nil
// Return partial results alongside the error so callers can detect auth failures
// while still rendering whatever orgs succeeded.
return resultReview, resultAssignee, resultOpenPR, firstErr
}

func (c *Client) fetchLHSData(
Expand Down
53 changes: 52 additions & 1 deletion server/plugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -1349,13 +1349,64 @@ func (p *Plugin) useGitHubClient(info *GitHubUserInfo, toRun func(info *GitHubUs
p.client.Log.Warn("Error occurred while using the Github client", "error", err.Error())
}

if err != nil && strings.Contains(err.Error(), invalidTokenError) {
if isGitHubAuthFailure(err) {
p.handleRevokedToken(info)
}

return err
}

// isGitHubAuthFailure reports whether err indicates the stored OAuth token is no
// longer usable: a 401, or a 403 from SAML SSO enforcement.
//
// We use two detection paths because the plugin talks to GitHub through both the
// REST client (go-github) and the GraphQL client (githubv4), and they surface
// errors differently.
func isGitHubAuthFailure(err error) bool {
if err == nil {
return false
}

// REST API: go-github returns a typed *github.ErrorResponse with the HTTP status.
var ghErr *github.ErrorResponse
if errors.As(err, &ghErr) && ghErr.Response != nil {
switch ghErr.Response.StatusCode {
case http.StatusUnauthorized:
return true
case http.StatusForbidden:
// Not every 403 is an auth failure; only SAML SSO revocation requires reconnect.
return isSAMLError(ghErr)
}
}

// GraphQL API: githubv4 returns untyped errors with status and message embedded
// in err.Error(), so match on known substrings instead.
errMsg := err.Error()
if strings.Contains(errMsg, invalidTokenError) {
return true
}
if strings.Contains(errMsg, "non-200 OK status code: 401") {
return true
}
// Match SAML content, not bare 403 — unrelated permission errors also return 403.
return strings.Contains(errMsg, "SAML enforcement") || strings.Contains(errMsg, "saml_failure")
}

func isSAMLError(ghErr *github.ErrorResponse) bool {
if ghErr.Response.Header.Get("X-GitHub-SSO") != "" {
return true
}
if strings.Contains(ghErr.Message, "SAML") {
return true
}
for _, e := range ghErr.Errors {
if strings.Contains(e.Message, "SAML") {
return true
}
}
return false
}

func (p *Plugin) handleRevokedToken(info *GitHubUserInfo) {
p.disconnectGitHubAccount(info.UserID)
p.CreateBotDMPost(info.UserID, "Your Github account was disconnected due to an invalid or revoked authorization token. Reconnect your account using the `/github connect` command.", "custom_git_revoked_token")
Expand Down
155 changes: 155 additions & 0 deletions server/plugin/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ package plugin

import (
"encoding/json"
"net/http"
"strings"
"testing"
"unicode/utf8"

"github.com/golang/mock/gomock"
"github.com/google/go-github/v54/github"
"github.com/pkg/errors"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -43,6 +45,7 @@ func setupRotationTest(t *testing.T) (*Plugin, *plugintest.API, *mocks.MockKvSto

api.On("KVSetWithOptions", mock.Anything, mock.Anything, mock.Anything).Return(true, nil).Maybe()
api.On("LogError", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Maybe()
api.On("LogWarn", mock.Anything, mock.Anything, mock.Anything).Maybe()
api.On("LogAuditRec", mock.Anything).Maybe()

return p, api, mockKvStore, ctrl
Expand Down Expand Up @@ -406,3 +409,155 @@ func TestTruncatePostMessage(t *testing.T) {
require.True(t, strings.HasSuffix(out, "_… message truncated_"))
})
}

func TestIsGitHubAuthFailure(t *testing.T) {
t.Run("nil", func(t *testing.T) {
require.False(t, isGitHubAuthFailure(nil))
})

t.Run("401 bad credentials string", func(t *testing.T) {
require.True(t, isGitHubAuthFailure(errors.New(invalidTokenError)))
})

t.Run("401 from ErrorResponse", func(t *testing.T) {
err := &github.ErrorResponse{Response: &http.Response{StatusCode: http.StatusUnauthorized}}
require.True(t, isGitHubAuthFailure(err))
})

t.Run("401 from graphql client message", func(t *testing.T) {
err := errors.New("non-200 OK status code: 401 Unauthorized")
require.True(t, isGitHubAuthFailure(err))
})

t.Run("403 SAML from ErrorResponse", func(t *testing.T) {
err := &github.ErrorResponse{
Message: "Resource protected by organization SAML enforcement. You must grant your OAuth token access to this organization.",
Response: &http.Response{
StatusCode: http.StatusForbidden,
Header: http.Header{"X-Github-Sso": []string{"required; url=https://github.com/orgs/foo/sso"}},
},
}
require.True(t, isGitHubAuthFailure(err))
})

t.Run("403 SAML from graphql error string", func(t *testing.T) {
err := errors.New("error in executing query: GraphQL: Resource protected by organization SAML enforcement. You must grant your OAuth token access to this organization.")
require.True(t, isGitHubAuthFailure(err))
})

t.Run("403 unrelated", func(t *testing.T) {
err := &github.ErrorResponse{
Message: "Forbidden",
Response: &http.Response{StatusCode: http.StatusForbidden},
}
require.False(t, isGitHubAuthFailure(err))
})

t.Run("generic error", func(t *testing.T) {
require.False(t, isGitHubAuthFailure(errors.New("connection reset")))
})
}

func connectedGitHubUserInfo(t *testing.T) *GitHubUserInfo {
t.Helper()
encryptedToken, err := encrypt([]byte(testNewKey), MockAccessToken)
require.NoError(t, err)
return &GitHubUserInfo{
UserID: "user1",
GitHubUsername: "ghuser1",
Token: &oauth2.Token{AccessToken: encryptedToken},
Settings: &UserSettings{},
}
}

func expectRevokedTokenNotification(api *plugintest.API, mockKvStore *mocks.MockKvStore, userInfo *GitHubUserInfo) {
mockKvStore.EXPECT().Get(userInfo.UserID+githubTokenKey, gomock.Any()).DoAndReturn(
func(_ string, out any) error {
userInfoBytes, err := json.Marshal(userInfo)
if err != nil {
return err
}
return json.Unmarshal(userInfoBytes, out)
},
)
mockKvStore.EXPECT().Delete(userInfo.UserID + githubTokenKey).Return(nil)
mockKvStore.EXPECT().Delete(userInfo.GitHubUsername + githubUsernameKey).Return(nil)
mockKvStore.EXPECT().Delete(userInfo.UserID + githubPrivateRepoKey).Return(nil)
api.On("GetUser", userInfo.UserID).Return(&model.User{
Id: userInfo.UserID,
Props: model.StringMap{"git_user": userInfo.GitHubUsername},
}, nil)
api.On("UpdateUser", mock.Anything).Return(&model.User{Id: userInfo.UserID, Props: model.StringMap{}}, nil)
api.On("PublishWebSocketEvent", wsEventDisconnect, map[string]any(nil),
&model.WebsocketBroadcast{UserId: userInfo.UserID}).Return()
api.On("GetDirectChannel", userInfo.UserID, MockBotID).Return(&model.Channel{Id: "dmchannel"}, nil)
api.On("CreatePost", mock.MatchedBy(func(post *model.Post) bool {
return post.UserId == MockBotID &&
post.ChannelId == "dmchannel" &&
post.Type == "custom_git_revoked_token"
})).Return(&model.Post{}, nil).Once()
}

func TestUseGitHubClient_AuthFailureNotifiesUser(t *testing.T) {
samlGraphQLErr := errors.New("error in executing query: GraphQL: Resource protected by organization SAML enforcement. You must grant your OAuth token access to this organization.")

tests := []struct {
name string
err error
notify bool
}{
{
name: "401 bad credentials",
err: errors.New(invalidTokenError),
notify: true,
},
{
name: "403 SAML REST",
err: &github.ErrorResponse{
Message: "Resource protected by organization SAML enforcement. You must grant your OAuth token access to this organization.",
Response: &http.Response{
StatusCode: http.StatusForbidden,
Header: http.Header{"X-Github-Sso": []string{"required; url=https://github.com/orgs/foo/sso"}},
Request: &http.Request{},
},
},
notify: true,
},
{
name: "403 SAML graphql",
err: samlGraphQLErr,
notify: true,
},
{
name: "403 unrelated",
err: &github.ErrorResponse{
Message: "Forbidden",
Response: &http.Response{StatusCode: http.StatusForbidden, Request: &http.Request{}},
},
notify: false,
},
{
name: "generic error",
err: errors.New("connection reset"),
notify: false,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
p, api, mockKvStore, ctrl := setupRotationTest(t)
defer ctrl.Finish()

userInfo := connectedGitHubUserInfo(t)
if tc.notify {
expectRevokedTokenNotification(api, mockKvStore, userInfo)
}

err := p.useGitHubClient(userInfo, func(_ *GitHubUserInfo, _ *oauth2.Token) error {
return tc.err
})
require.Equal(t, tc.err, err)
api.AssertExpectations(t)
})
}
}
Loading