Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restrict external data access #130

Merged
merged 3 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion server/ai/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"
_ "time/tzdata" // Needed to fill time.LoadLocation db

"github.com/mattermost/mattermost-plugin-ai/server/mmapi"
"github.com/mattermost/mattermost/server/public/model"
)

Expand All @@ -24,6 +25,7 @@ type Post struct {
}

type ConversationContext struct {
BotID string
jupenur marked this conversation as resolved.
Show resolved Hide resolved
Time string
ServerName string
CompanyName string
Expand All @@ -34,7 +36,7 @@ type ConversationContext struct {
PromptParameters map[string]string
}

func NewConversationContext(requestingUser *model.User, channel *model.Channel, post *model.Post) ConversationContext {
func NewConversationContext(botID string, requestingUser *model.User, channel *model.Channel, post *model.Post) ConversationContext {
// Get current time and date formatted nicely with the user's locale
now := time.Now()
nowString := now.Format(time.RFC1123)
Expand All @@ -51,9 +53,14 @@ func NewConversationContext(requestingUser *model.User, channel *model.Channel,
RequestingUser: requestingUser,
Channel: channel,
Post: post,
BotID: botID,
}
}

func (c *ConversationContext) IsDMWithBot() bool {
return mmapi.IsDMWith(c.BotID, c.Channel)
}

func (c ConversationContext) String() string {
var result strings.Builder
result.WriteString(fmt.Sprintf("Time: %v\nServerName: %v\nCompanyName: %v", c.Time, c.ServerName, c.CompanyName))
Expand Down
12 changes: 7 additions & 5 deletions server/ai/prompts.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ import (
"github.com/pkg/errors"
)

type BuiltInToolsFunc func(isDM bool) []Tool

type Prompts struct {
templates *template.Template
getBuiltInTools func() []Tool
getBuiltInTools BuiltInToolsFunc
}

const PromptExtension = "tmpl"
Expand All @@ -37,7 +39,7 @@ const (
PromptFindOpenQuestionsSince = "find_open_questions_since"
)

func NewPrompts(input fs.FS, getBuiltInTools func() []Tool) (*Prompts, error) {
func NewPrompts(input fs.FS, getBuiltInTools BuiltInToolsFunc) (*Prompts, error) {
templates, err := template.ParseFS(input, "ai/prompts/*")
if err != nil {
return nil, errors.Wrap(err, "unable to parse prompt templates")
Expand All @@ -53,17 +55,17 @@ func withPromptExtension(filename string) string {
return filename + "." + PromptExtension
}

func (p *Prompts) getDefaultTools() ToolStore {
func (p *Prompts) getDefaultTools(isDMWithBot bool) ToolStore {
tools := NewToolStore()
tools.AddTools(p.getBuiltInTools())
tools.AddTools(p.getBuiltInTools(isDMWithBot))
return tools
}

func (p *Prompts) ChatCompletion(templateName string, context ConversationContext) (BotConversation, error) {
conversation := BotConversation{
Posts: []Post{},
Context: context,
Tools: p.getDefaultTools(),
Tools: p.getDefaultTools(context.IsDMWithBot()),
}

template := p.templates.Lookup(withPromptExtension(templateName))
Expand Down
2 changes: 1 addition & 1 deletion server/api_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (p *Plugin) handleSince(c *gin.Context) {

formattedThread := formatThread(threadData)

context := ai.NewConversationContext(user, channel, nil)
context := p.MakeConversationContext(user, channel, nil)
context.PromptParameters = map[string]string{
"Posts": formattedThread,
}
Expand Down
51 changes: 29 additions & 22 deletions server/built_in_tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ func (p *Plugin) toolResolveLookupMattermostUser(context ai.ConversationContext,
}
result += fmt.Sprintf("\nTimezone: %s", model.GetPreferredTimezone(user.Timezone))
result += fmt.Sprintf("\nLast Activity: %s", model.GetTimeForMillis(userStatus.LastActivityAt).Format("2006-01-02 15:04:05 MST"))
result += fmt.Sprintf("\nStatus: %s", userStatus.Status)
// Exclude manual statuses because they could be prompt injections
if userStatus.Status != "" && !userStatus.Manual {
result += fmt.Sprintf("\nStatus: %s", userStatus.Status)
}

return result, nil
}
Expand Down Expand Up @@ -178,34 +181,38 @@ func (p *Plugin) toolGetGithubIssue(context ai.ConversationContext, argsGetter a
return formatIssue(&issue), nil
}

func (p *Plugin) getBuiltInTools() []ai.Tool {
// Unconditional tools
builtInTools := []ai.Tool{
{
Name: "LookupMattermostUser",
Description: "Lookup a Mattermost user by their username. Available information includes: username, full name, email, nickname, position, locale, timezone, last activity, and status.",
Schema: LookupMattermostUserArgs{},
Resolver: p.toolResolveLookupMattermostUser,
},
{
// getBuiltInTools returns the built-in tools that are available to all users.
// isDM is true if the response will be in a DM with the user. More tools are available in DMs because of security properties.
func (p *Plugin) getBuiltInTools(isDM bool) []ai.Tool {
builtInTools := []ai.Tool{}

if isDM {
builtInTools = append(builtInTools, ai.Tool{
Name: "GetChannelPosts",
Description: "Get the most recent posts from a Mattermost channel. Returns posts in the format 'username: message'",
Schema: GetChannelPosts{},
Resolver: p.toolResolveGetChannelPosts,
},
}
})

// Github plugin tools
status, err := p.pluginAPI.Plugin.GetPluginStatus("github")
if err != nil {
p.API.LogError("failed to get github plugin status", "error", err.Error())
} else if status != nil && status.State == model.PluginStateRunning {
builtInTools = append(builtInTools, ai.Tool{
Name: "GetGithubIssue",
Description: "Retrieve a single GitHub issue by owner, repo, and issue number.",
Schema: GetGithubIssueArgs{},
Resolver: p.toolGetGithubIssue,
Name: "LookupMattermostUser",
Description: "Lookup a Mattermost user by their username. Available information includes: username, full name, email, nickname, position, locale, timezone, last activity, and status.",
Schema: LookupMattermostUserArgs{},
Resolver: p.toolResolveLookupMattermostUser,
})

// Github plugin tools
status, err := p.pluginAPI.Plugin.GetPluginStatus("github")
if err != nil {
p.API.LogError("failed to get github plugin status", "error", err.Error())
} else if status != nil && status.State == model.PluginStateRunning {
builtInTools = append(builtInTools, ai.Tool{
Name: "GetGithubIssue",
Description: "Retrieve a single GitHub issue by owner, repo, and issue number.",
Schema: GetGithubIssueArgs{},
Resolver: p.toolGetGithubIssue,
})
}
}

return builtInTools
Expand Down
2 changes: 1 addition & 1 deletion server/conversation_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

func (p *Plugin) MakeConversationContext(user *model.User, channel *model.Channel, post *model.Post) ai.ConversationContext {
context := ai.NewConversationContext(user, channel, post)
context := ai.NewConversationContext(p.botid, user, channel, post)
if p.pluginAPI.Configuration.GetConfig().TeamSettings.SiteName != nil {
context.ServerName = *p.pluginAPI.Configuration.GetConfig().TeamSettings.SiteName
}
Expand Down
14 changes: 14 additions & 0 deletions server/mmapi/channels.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package mmapi

import (
"strings"

"github.com/mattermost/mattermost/server/public/model"
)

func IsDMWith(userID string, channel *model.Channel) bool {
return channel != nil &&
channel.Type == model.ChannelTypeDirect &&
userID != "" &&
strings.Contains(channel.Name, userID)
}
58 changes: 58 additions & 0 deletions server/mmapi/channels_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package mmapi

import (
"testing"

"github.com/mattermost/mattermost/server/public/model"
"github.com/stretchr/testify/require"
)

func TestIsDMWith(t *testing.T) {
for _, tc := range []struct {
name string
userID string
channel *model.Channel
want bool
}{
{
name: "nil channel",
userID: "thisisuserid",
channel: nil,
want: false,
},
{
name: "not direct channel",
userID: "thisisuserid",
channel: &model.Channel{Type: model.ChannelTypeGroup},
want: false,
},
{
name: "empty user",
userID: "",
channel: &model.Channel{Type: model.ChannelTypeDirect, Name: "thisisuserid__otheruserid"},
want: false,
},
{
name: "not DM with user",
userID: "thisisuserid",
channel: &model.Channel{Type: model.ChannelTypeDirect, Name: "someotheruser__otheruserid"},
want: false,
},
{
name: "DM with user",
userID: "thisisuserid",
channel: &model.Channel{Type: model.ChannelTypeDirect, Name: "thisisuserid__otheruserid"},
want: true,
},
{
name: "DM with user reversed",
userID: "thisisuserid",
channel: &model.Channel{Type: model.ChannelTypeDirect, Name: "otheruserid__thisisuserid"},
want: true,
},
} {
t.Run(tc.name, func(t *testing.T) {
require.Equal(t, tc.want, IsDMWith(tc.userID, tc.channel))
})
}
}
3 changes: 2 additions & 1 deletion server/permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"strings"

"github.com/mattermost/mattermost-plugin-ai/server/mmapi"
"github.com/mattermost/mattermost/server/public/model"
"github.com/pkg/errors"
)
Expand Down Expand Up @@ -30,7 +31,7 @@ func (p *Plugin) checkUsageRestrictionsForChannel(channel *model.Channel) error

if !cfg.AllowPrivateChannels {
if channel.Type != model.ChannelTypeOpen {
if !(channel.Type == model.ChannelTypeDirect && strings.Contains(channel.Name, p.botid)) {
if !mmapi.IsDMWith(p.botid, channel) {
return errors.Wrap(ErrUsageRestriction, "can't work on private channels")
}
}
Expand Down
6 changes: 3 additions & 3 deletions server/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"embed"
"os/exec"
"strings"
"sync"

sq "github.com/Masterminds/squirrel"
Expand All @@ -14,6 +13,7 @@ import (
"github.com/mattermost/mattermost-plugin-ai/server/ai/asksage"
"github.com/mattermost/mattermost-plugin-ai/server/ai/openai"
"github.com/mattermost/mattermost-plugin-ai/server/enterprise"
"github.com/mattermost/mattermost-plugin-ai/server/mmapi"
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/plugin"
"github.com/mattermost/mattermost/server/public/pluginapi"
Expand Down Expand Up @@ -225,8 +225,8 @@ func (p *Plugin) handleMessages(post *model.Post) error {
case userIsMentionedMarkdown(post.Message, BotUsername):
return p.handleMentions(post, postingUser, channel)

// Check if this is post in the DM channel with the bot
case channel.Type == model.ChannelTypeDirect && strings.Contains(channel.Name, p.botid):
// Check if this is post in the DM channel with the bot
case mmapi.IsDMWith(p.botid, channel):
return p.handleDMs(channel, postingUser, post)
}

Expand Down
Loading