Skip to content

Commit

Permalink
Merge pull request #831 from ibuildthecloud/share-context-bug
Browse files Browse the repository at this point in the history
chore: refactor logic for tool sharing
  • Loading branch information
ibuildthecloud authored Aug 29, 2024
2 parents fbb8f5d + 33741b1 commit bfe96cf
Show file tree
Hide file tree
Showing 17 changed files with 211 additions and 313 deletions.
6 changes: 3 additions & 3 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func NewContext(ctx context.Context, prg *types.Program, input string) (Context,
Input: input,
}

agentGroup, err := callCtx.Tool.GetAgents(*prg)
agentGroup, err := callCtx.Tool.GetToolsByType(prg, types.ToolTypeAgent)
if err != nil {
return callCtx, err
}
Expand All @@ -225,7 +225,7 @@ func (c *Context) SubCallContext(ctx context.Context, input, toolID, callID stri
callID = counter.Next()
}

agentGroup, err := c.Tool.GetNextAgentGroup(*c.Program, c.AgentGroup, toolID)
agentGroup, err := c.Tool.GetNextAgentGroup(c.Program, c.AgentGroup, toolID)
if err != nil {
return Context{}, err
}
Expand Down Expand Up @@ -272,7 +272,7 @@ func populateMessageParams(ctx Context, completion *types.CompletionRequest, too
}

var err error
completion.Tools, err = tool.GetCompletionTools(*ctx.Program, ctx.AgentGroup...)
completion.Tools, err = tool.GetChatCompletionTools(*ctx.Program, ctx.AgentGroup...)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/runner/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import (
"fmt"

"github.com/gptscript-ai/gptscript/pkg/engine"
"github.com/gptscript-ai/gptscript/pkg/types"
)

func (r *Runner) handleInput(callCtx engine.Context, monitor Monitor, env []string, input string) (string, error) {
inputToolRefs, err := callCtx.Tool.GetInputFilterTools(*callCtx.Program)
inputToolRefs, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeInput)
if err != nil {
return "", err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/runner/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ import (
"fmt"

"github.com/gptscript-ai/gptscript/pkg/engine"
"github.com/gptscript-ai/gptscript/pkg/types"
)

func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []string, state *State, retErr error) (*State, error) {
outputToolRefs, err := callCtx.Tool.GetOutputFilterTools(*callCtx.Program)
outputToolRefs, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeOutput)
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ func getToolRefInput(prg *types.Program, ref types.ToolReference, input string)
}

func (r *Runner) getContext(callCtx engine.Context, state *State, monitor Monitor, env []string, input string) (result []engine.InputContext, _ error) {
toolRefs, err := callCtx.Tool.GetContextTools(*callCtx.Program)
toolRefs, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeContext)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -387,7 +387,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
return nil, err
}

credTools, err := callCtx.Tool.GetCredentialTools(*callCtx.Program, callCtx.AgentGroup)
credTools, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeCredential)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -503,7 +503,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
progress, progressClose := streamProgress(&callCtx, monitor)
defer progressClose()

credTools, err := callCtx.Tool.GetCredentialTools(*callCtx.Program, callCtx.AgentGroup)
credTools, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeCredential)
if err != nil {
return nil, err
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/tests/testdata/TestAgentOnly/call2.golden
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"tools": [
{
"function": {
"toolID": "testdata/TestAgentOnly/test.gpt:agent1",
"name": "agent1",
"toolID": "testdata/TestAgentOnly/test.gpt:agent3",
"name": "agent3",
"parameters": {
"properties": {
"defaultPromptParameter": {
Expand All @@ -19,8 +19,8 @@
},
{
"function": {
"toolID": "testdata/TestAgentOnly/test.gpt:agent3",
"name": "agent3",
"toolID": "testdata/TestAgentOnly/test.gpt:agent1",
"name": "agent1",
"parameters": {
"properties": {
"defaultPromptParameter": {
Expand Down
8 changes: 4 additions & 4 deletions pkg/tests/testdata/TestAgentOnly/step1.golden
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@
"tools": [
{
"function": {
"toolID": "testdata/TestAgentOnly/test.gpt:agent1",
"name": "agent1",
"toolID": "testdata/TestAgentOnly/test.gpt:agent3",
"name": "agent3",
"parameters": {
"properties": {
"defaultPromptParameter": {
Expand All @@ -111,8 +111,8 @@
},
{
"function": {
"toolID": "testdata/TestAgentOnly/test.gpt:agent3",
"name": "agent3",
"toolID": "testdata/TestAgentOnly/test.gpt:agent1",
"name": "agent1",
"parameters": {
"properties": {
"defaultPromptParameter": {
Expand Down
2 changes: 1 addition & 1 deletion pkg/tests/testdata/TestAgents/call3-resp.golden
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"content": [
{
"toolCall": {
"index": 1,
"index": 0,
"id": "call_3",
"function": {
"name": "agent3"
Expand Down
8 changes: 4 additions & 4 deletions pkg/tests/testdata/TestAgents/call3.golden
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"tools": [
{
"function": {
"toolID": "testdata/TestAgents/test.gpt:agent1",
"name": "agent1",
"toolID": "testdata/TestAgents/test.gpt:agent3",
"name": "agent3",
"parameters": {
"properties": {
"defaultPromptParameter": {
Expand All @@ -19,8 +19,8 @@
},
{
"function": {
"toolID": "testdata/TestAgents/test.gpt:agent3",
"name": "agent3",
"toolID": "testdata/TestAgents/test.gpt:agent1",
"name": "agent1",
"parameters": {
"properties": {
"defaultPromptParameter": {
Expand Down
12 changes: 6 additions & 6 deletions pkg/tests/testdata/TestAgents/step1.golden
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@
"tools": [
{
"function": {
"toolID": "testdata/TestAgents/test.gpt:agent1",
"name": "agent1",
"toolID": "testdata/TestAgents/test.gpt:agent3",
"name": "agent3",
"parameters": {
"properties": {
"defaultPromptParameter": {
Expand All @@ -193,8 +193,8 @@
},
{
"function": {
"toolID": "testdata/TestAgents/test.gpt:agent3",
"name": "agent3",
"toolID": "testdata/TestAgents/test.gpt:agent1",
"name": "agent1",
"parameters": {
"properties": {
"defaultPromptParameter": {
Expand Down Expand Up @@ -222,7 +222,7 @@
"content": [
{
"toolCall": {
"index": 1,
"index": 0,
"id": "call_3",
"function": {
"name": "agent3"
Expand All @@ -237,7 +237,7 @@
},
"pending": {
"call_3": {
"index": 1,
"index": 0,
"id": "call_3",
"function": {
"name": "agent3"
Expand Down
2 changes: 1 addition & 1 deletion pkg/tests/testdata/TestExport/call1-resp.golden
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"content": [
{
"toolCall": {
"index": 2,
"index": 1,
"id": "call_1",
"function": {
"name": "transient"
Expand Down
8 changes: 4 additions & 4 deletions pkg/tests/testdata/TestExport/call1.golden
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
},
{
"function": {
"toolID": "testdata/TestExport/parent.gpt:parent-local",
"name": "parentLocal",
"toolID": "testdata/TestExport/sub/child.gpt:transient",
"name": "transient",
"parameters": {
"properties": {
"defaultPromptParameter": {
Expand All @@ -33,8 +33,8 @@
},
{
"function": {
"toolID": "testdata/TestExport/sub/child.gpt:transient",
"name": "transient",
"toolID": "testdata/TestExport/parent.gpt:parent-local",
"name": "parentLocal",
"parameters": {
"properties": {
"defaultPromptParameter": {
Expand Down
12 changes: 6 additions & 6 deletions pkg/tests/testdata/TestExport/call3.golden
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
},
{
"function": {
"toolID": "testdata/TestExport/parent.gpt:parent-local",
"name": "parentLocal",
"toolID": "testdata/TestExport/sub/child.gpt:transient",
"name": "transient",
"parameters": {
"properties": {
"defaultPromptParameter": {
Expand All @@ -33,8 +33,8 @@
},
{
"function": {
"toolID": "testdata/TestExport/sub/child.gpt:transient",
"name": "transient",
"toolID": "testdata/TestExport/parent.gpt:parent-local",
"name": "parentLocal",
"parameters": {
"properties": {
"defaultPromptParameter": {
Expand Down Expand Up @@ -62,7 +62,7 @@
"content": [
{
"toolCall": {
"index": 2,
"index": 1,
"id": "call_1",
"function": {
"name": "transient"
Expand All @@ -80,7 +80,7 @@
}
],
"toolCall": {
"index": 2,
"index": 1,
"id": "call_1",
"function": {
"name": "transient"
Expand Down
2 changes: 1 addition & 1 deletion pkg/tests/testdata/TestExportContext/call1.golden
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"role": "system",
"content": [
{
"text": "this is from external context\nthis is from context\nThis is from tool"
"text": "this is from context\nthis is from external context\nThis is from tool"
}
],
"usage": {}
Expand Down
18 changes: 9 additions & 9 deletions pkg/tests/testdata/TestToolRefAll/call1.golden
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
},
{
"function": {
"toolID": "testdata/TestToolRefAll/test.gpt:none",
"name": "none",
"toolID": "testdata/TestToolRefAll/test.gpt:agentAssistant",
"name": "agentAssistant",
"parameters": {
"properties": {
"noneArg": {
"description": "stuff",
"defaultPromptParameter": {
"description": "Prompt to send to the tool. This may be an instruction or question.",
"type": "string"
}
},
Expand All @@ -33,12 +33,12 @@
},
{
"function": {
"toolID": "testdata/TestToolRefAll/test.gpt:agentAssistant",
"name": "agent",
"toolID": "testdata/TestToolRefAll/test.gpt:none",
"name": "none",
"parameters": {
"properties": {
"defaultPromptParameter": {
"description": "Prompt to send to the tool. This may be an instruction or question.",
"noneArg": {
"description": "stuff",
"type": "string"
}
},
Expand All @@ -52,7 +52,7 @@
"role": "system",
"content": [
{
"text": "\nShared context\n\nContext Body\nMain tool"
"text": "\nContext Body\n\nShared context\nMain tool"
}
],
"usage": {}
Expand Down
20 changes: 10 additions & 10 deletions pkg/types/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ import (
)

type CompletionRequest struct {
Model string `json:"model,omitempty"`
InternalSystemPrompt *bool `json:"internalSystemPrompt,omitempty"`
Tools []CompletionTool `json:"tools,omitempty"`
Messages []CompletionMessage `json:"messages,omitempty"`
MaxTokens int `json:"maxTokens,omitempty"`
Chat bool `json:"chat,omitempty"`
Temperature *float32 `json:"temperature,omitempty"`
JSONResponse bool `json:"jsonResponse,omitempty"`
Cache *bool `json:"cache,omitempty"`
Model string `json:"model,omitempty"`
InternalSystemPrompt *bool `json:"internalSystemPrompt,omitempty"`
Tools []ChatCompletionTool `json:"tools,omitempty"`
Messages []CompletionMessage `json:"messages,omitempty"`
MaxTokens int `json:"maxTokens,omitempty"`
Chat bool `json:"chat,omitempty"`
Temperature *float32 `json:"temperature,omitempty"`
JSONResponse bool `json:"jsonResponse,omitempty"`
Cache *bool `json:"cache,omitempty"`
}

func (r *CompletionRequest) GetCache() bool {
Expand All @@ -27,7 +27,7 @@ func (r *CompletionRequest) GetCache() bool {
return *r.Cache
}

type CompletionTool struct {
type ChatCompletionTool struct {
Function CompletionFunctionDefinition `json:"function,omitempty"`
}

Expand Down
11 changes: 11 additions & 0 deletions pkg/types/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@ func (t *toolRefSet) List() (result []ToolReference, err error) {
return result, t.err
}

func (t *toolRefSet) Contains(value ToolReference) bool {
key := toolRefKey{
name: value.Named,
toolID: value.ToolID,
arg: value.Arg,
}

_, ok := t.set[key]
return ok
}

func (t *toolRefSet) HasTool(toolID string) bool {
for _, ref := range t.set {
if ref.ToolID == toolID {
Expand Down
Loading

0 comments on commit bfe96cf

Please sign in to comment.