From 33741b12c1593aca5e1c423dbbf00b86a6f48996 Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Wed, 28 Aug 2024 23:17:35 -0700 Subject: [PATCH] chore: refactor logic for tool sharing --- pkg/engine/engine.go | 6 +- pkg/runner/input.go | 3 +- pkg/runner/output.go | 3 +- pkg/runner/runner.go | 6 +- pkg/tests/testdata/TestAgentOnly/call2.golden | 8 +- pkg/tests/testdata/TestAgentOnly/step1.golden | 8 +- .../testdata/TestAgents/call3-resp.golden | 2 +- pkg/tests/testdata/TestAgents/call3.golden | 8 +- pkg/tests/testdata/TestAgents/step1.golden | 12 +- .../testdata/TestExport/call1-resp.golden | 2 +- pkg/tests/testdata/TestExport/call1.golden | 8 +- pkg/tests/testdata/TestExport/call3.golden | 12 +- .../testdata/TestExportContext/call1.golden | 2 +- .../testdata/TestToolRefAll/call1.golden | 18 +- pkg/types/completion.go | 20 +- pkg/types/set.go | 11 + pkg/types/tool.go | 395 +++++++----------- 17 files changed, 211 insertions(+), 313 deletions(-) diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 14b75e0a..d028d50b 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -204,7 +204,7 @@ func NewContext(ctx context.Context, prg *types.Program, input string) (Context, Input: input, } - agentGroup, err := callCtx.Tool.GetAgents(*prg) + agentGroup, err := callCtx.Tool.GetToolsByType(prg, types.ToolTypeAgent) if err != nil { return callCtx, err } @@ -225,7 +225,7 @@ func (c *Context) SubCallContext(ctx context.Context, input, toolID, callID stri callID = counter.Next() } - agentGroup, err := c.Tool.GetNextAgentGroup(*c.Program, c.AgentGroup, toolID) + agentGroup, err := c.Tool.GetNextAgentGroup(c.Program, c.AgentGroup, toolID) if err != nil { return Context{}, err } @@ -272,7 +272,7 @@ func populateMessageParams(ctx Context, completion *types.CompletionRequest, too } var err error - completion.Tools, err = tool.GetCompletionTools(*ctx.Program, ctx.AgentGroup...) + completion.Tools, err = tool.GetChatCompletionTools(*ctx.Program, ctx.AgentGroup...) if err != nil { return err } diff --git a/pkg/runner/input.go b/pkg/runner/input.go index 7d77330e..a211ec9d 100644 --- a/pkg/runner/input.go +++ b/pkg/runner/input.go @@ -5,10 +5,11 @@ import ( "fmt" "github.com/gptscript-ai/gptscript/pkg/engine" + "github.com/gptscript-ai/gptscript/pkg/types" ) func (r *Runner) handleInput(callCtx engine.Context, monitor Monitor, env []string, input string) (string, error) { - inputToolRefs, err := callCtx.Tool.GetInputFilterTools(*callCtx.Program) + inputToolRefs, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeInput) if err != nil { return "", err } diff --git a/pkg/runner/output.go b/pkg/runner/output.go index d4cb4b9b..e5fe849d 100644 --- a/pkg/runner/output.go +++ b/pkg/runner/output.go @@ -6,10 +6,11 @@ import ( "fmt" "github.com/gptscript-ai/gptscript/pkg/engine" + "github.com/gptscript-ai/gptscript/pkg/types" ) func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []string, state *State, retErr error) (*State, error) { - outputToolRefs, err := callCtx.Tool.GetOutputFilterTools(*callCtx.Program) + outputToolRefs, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeOutput) if err != nil { return nil, err } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 3035a1d1..c843b6b5 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -330,7 +330,7 @@ func getToolRefInput(prg *types.Program, ref types.ToolReference, input string) } func (r *Runner) getContext(callCtx engine.Context, state *State, monitor Monitor, env []string, input string) (result []engine.InputContext, _ error) { - toolRefs, err := callCtx.Tool.GetContextTools(*callCtx.Program) + toolRefs, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeContext) if err != nil { return nil, err } @@ -387,7 +387,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en return nil, err } - credTools, err := callCtx.Tool.GetCredentialTools(*callCtx.Program, callCtx.AgentGroup) + credTools, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeCredential) if err != nil { return nil, err } @@ -503,7 +503,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s progress, progressClose := streamProgress(&callCtx, monitor) defer progressClose() - credTools, err := callCtx.Tool.GetCredentialTools(*callCtx.Program, callCtx.AgentGroup) + credTools, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeCredential) if err != nil { return nil, err } diff --git a/pkg/tests/testdata/TestAgentOnly/call2.golden b/pkg/tests/testdata/TestAgentOnly/call2.golden index 82f95523..7f6b155b 100644 --- a/pkg/tests/testdata/TestAgentOnly/call2.golden +++ b/pkg/tests/testdata/TestAgentOnly/call2.golden @@ -4,8 +4,8 @@ "tools": [ { "function": { - "toolID": "testdata/TestAgentOnly/test.gpt:agent1", - "name": "agent1", + "toolID": "testdata/TestAgentOnly/test.gpt:agent3", + "name": "agent3", "parameters": { "properties": { "defaultPromptParameter": { @@ -19,8 +19,8 @@ }, { "function": { - "toolID": "testdata/TestAgentOnly/test.gpt:agent3", - "name": "agent3", + "toolID": "testdata/TestAgentOnly/test.gpt:agent1", + "name": "agent1", "parameters": { "properties": { "defaultPromptParameter": { diff --git a/pkg/tests/testdata/TestAgentOnly/step1.golden b/pkg/tests/testdata/TestAgentOnly/step1.golden index 662dbf04..2cda2025 100644 --- a/pkg/tests/testdata/TestAgentOnly/step1.golden +++ b/pkg/tests/testdata/TestAgentOnly/step1.golden @@ -96,8 +96,8 @@ "tools": [ { "function": { - "toolID": "testdata/TestAgentOnly/test.gpt:agent1", - "name": "agent1", + "toolID": "testdata/TestAgentOnly/test.gpt:agent3", + "name": "agent3", "parameters": { "properties": { "defaultPromptParameter": { @@ -111,8 +111,8 @@ }, { "function": { - "toolID": "testdata/TestAgentOnly/test.gpt:agent3", - "name": "agent3", + "toolID": "testdata/TestAgentOnly/test.gpt:agent1", + "name": "agent1", "parameters": { "properties": { "defaultPromptParameter": { diff --git a/pkg/tests/testdata/TestAgents/call3-resp.golden b/pkg/tests/testdata/TestAgents/call3-resp.golden index e2a65c99..7568fc69 100644 --- a/pkg/tests/testdata/TestAgents/call3-resp.golden +++ b/pkg/tests/testdata/TestAgents/call3-resp.golden @@ -3,7 +3,7 @@ "content": [ { "toolCall": { - "index": 1, + "index": 0, "id": "call_3", "function": { "name": "agent3" diff --git a/pkg/tests/testdata/TestAgents/call3.golden b/pkg/tests/testdata/TestAgents/call3.golden index f9b45a1b..5b1638e0 100644 --- a/pkg/tests/testdata/TestAgents/call3.golden +++ b/pkg/tests/testdata/TestAgents/call3.golden @@ -4,8 +4,8 @@ "tools": [ { "function": { - "toolID": "testdata/TestAgents/test.gpt:agent1", - "name": "agent1", + "toolID": "testdata/TestAgents/test.gpt:agent3", + "name": "agent3", "parameters": { "properties": { "defaultPromptParameter": { @@ -19,8 +19,8 @@ }, { "function": { - "toolID": "testdata/TestAgents/test.gpt:agent3", - "name": "agent3", + "toolID": "testdata/TestAgents/test.gpt:agent1", + "name": "agent1", "parameters": { "properties": { "defaultPromptParameter": { diff --git a/pkg/tests/testdata/TestAgents/step1.golden b/pkg/tests/testdata/TestAgents/step1.golden index 3047e695..72e01114 100644 --- a/pkg/tests/testdata/TestAgents/step1.golden +++ b/pkg/tests/testdata/TestAgents/step1.golden @@ -178,8 +178,8 @@ "tools": [ { "function": { - "toolID": "testdata/TestAgents/test.gpt:agent1", - "name": "agent1", + "toolID": "testdata/TestAgents/test.gpt:agent3", + "name": "agent3", "parameters": { "properties": { "defaultPromptParameter": { @@ -193,8 +193,8 @@ }, { "function": { - "toolID": "testdata/TestAgents/test.gpt:agent3", - "name": "agent3", + "toolID": "testdata/TestAgents/test.gpt:agent1", + "name": "agent1", "parameters": { "properties": { "defaultPromptParameter": { @@ -222,7 +222,7 @@ "content": [ { "toolCall": { - "index": 1, + "index": 0, "id": "call_3", "function": { "name": "agent3" @@ -237,7 +237,7 @@ }, "pending": { "call_3": { - "index": 1, + "index": 0, "id": "call_3", "function": { "name": "agent3" diff --git a/pkg/tests/testdata/TestExport/call1-resp.golden b/pkg/tests/testdata/TestExport/call1-resp.golden index 8462d188..7fe59586 100644 --- a/pkg/tests/testdata/TestExport/call1-resp.golden +++ b/pkg/tests/testdata/TestExport/call1-resp.golden @@ -3,7 +3,7 @@ "content": [ { "toolCall": { - "index": 2, + "index": 1, "id": "call_1", "function": { "name": "transient" diff --git a/pkg/tests/testdata/TestExport/call1.golden b/pkg/tests/testdata/TestExport/call1.golden index 9f8b650d..b700ee55 100644 --- a/pkg/tests/testdata/TestExport/call1.golden +++ b/pkg/tests/testdata/TestExport/call1.golden @@ -18,8 +18,8 @@ }, { "function": { - "toolID": "testdata/TestExport/parent.gpt:parent-local", - "name": "parentLocal", + "toolID": "testdata/TestExport/sub/child.gpt:transient", + "name": "transient", "parameters": { "properties": { "defaultPromptParameter": { @@ -33,8 +33,8 @@ }, { "function": { - "toolID": "testdata/TestExport/sub/child.gpt:transient", - "name": "transient", + "toolID": "testdata/TestExport/parent.gpt:parent-local", + "name": "parentLocal", "parameters": { "properties": { "defaultPromptParameter": { diff --git a/pkg/tests/testdata/TestExport/call3.golden b/pkg/tests/testdata/TestExport/call3.golden index ccf7e980..d2abca0c 100644 --- a/pkg/tests/testdata/TestExport/call3.golden +++ b/pkg/tests/testdata/TestExport/call3.golden @@ -18,8 +18,8 @@ }, { "function": { - "toolID": "testdata/TestExport/parent.gpt:parent-local", - "name": "parentLocal", + "toolID": "testdata/TestExport/sub/child.gpt:transient", + "name": "transient", "parameters": { "properties": { "defaultPromptParameter": { @@ -33,8 +33,8 @@ }, { "function": { - "toolID": "testdata/TestExport/sub/child.gpt:transient", - "name": "transient", + "toolID": "testdata/TestExport/parent.gpt:parent-local", + "name": "parentLocal", "parameters": { "properties": { "defaultPromptParameter": { @@ -62,7 +62,7 @@ "content": [ { "toolCall": { - "index": 2, + "index": 1, "id": "call_1", "function": { "name": "transient" @@ -80,7 +80,7 @@ } ], "toolCall": { - "index": 2, + "index": 1, "id": "call_1", "function": { "name": "transient" diff --git a/pkg/tests/testdata/TestExportContext/call1.golden b/pkg/tests/testdata/TestExportContext/call1.golden index bec15478..0ee8f9fe 100644 --- a/pkg/tests/testdata/TestExportContext/call1.golden +++ b/pkg/tests/testdata/TestExportContext/call1.golden @@ -38,7 +38,7 @@ "role": "system", "content": [ { - "text": "this is from external context\nthis is from context\nThis is from tool" + "text": "this is from context\nthis is from external context\nThis is from tool" } ], "usage": {} diff --git a/pkg/tests/testdata/TestToolRefAll/call1.golden b/pkg/tests/testdata/TestToolRefAll/call1.golden index ef36e3fb..9289affa 100644 --- a/pkg/tests/testdata/TestToolRefAll/call1.golden +++ b/pkg/tests/testdata/TestToolRefAll/call1.golden @@ -18,12 +18,12 @@ }, { "function": { - "toolID": "testdata/TestToolRefAll/test.gpt:none", - "name": "none", + "toolID": "testdata/TestToolRefAll/test.gpt:agentAssistant", + "name": "agentAssistant", "parameters": { "properties": { - "noneArg": { - "description": "stuff", + "defaultPromptParameter": { + "description": "Prompt to send to the tool. This may be an instruction or question.", "type": "string" } }, @@ -33,12 +33,12 @@ }, { "function": { - "toolID": "testdata/TestToolRefAll/test.gpt:agentAssistant", - "name": "agent", + "toolID": "testdata/TestToolRefAll/test.gpt:none", + "name": "none", "parameters": { "properties": { - "defaultPromptParameter": { - "description": "Prompt to send to the tool. This may be an instruction or question.", + "noneArg": { + "description": "stuff", "type": "string" } }, @@ -52,7 +52,7 @@ "role": "system", "content": [ { - "text": "\nShared context\n\nContext Body\nMain tool" + "text": "\nContext Body\n\nShared context\nMain tool" } ], "usage": {} diff --git a/pkg/types/completion.go b/pkg/types/completion.go index dd70ad50..5b3899c3 100644 --- a/pkg/types/completion.go +++ b/pkg/types/completion.go @@ -9,15 +9,15 @@ import ( ) type CompletionRequest struct { - Model string `json:"model,omitempty"` - InternalSystemPrompt *bool `json:"internalSystemPrompt,omitempty"` - Tools []CompletionTool `json:"tools,omitempty"` - Messages []CompletionMessage `json:"messages,omitempty"` - MaxTokens int `json:"maxTokens,omitempty"` - Chat bool `json:"chat,omitempty"` - Temperature *float32 `json:"temperature,omitempty"` - JSONResponse bool `json:"jsonResponse,omitempty"` - Cache *bool `json:"cache,omitempty"` + Model string `json:"model,omitempty"` + InternalSystemPrompt *bool `json:"internalSystemPrompt,omitempty"` + Tools []ChatCompletionTool `json:"tools,omitempty"` + Messages []CompletionMessage `json:"messages,omitempty"` + MaxTokens int `json:"maxTokens,omitempty"` + Chat bool `json:"chat,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + JSONResponse bool `json:"jsonResponse,omitempty"` + Cache *bool `json:"cache,omitempty"` } func (r *CompletionRequest) GetCache() bool { @@ -27,7 +27,7 @@ func (r *CompletionRequest) GetCache() bool { return *r.Cache } -type CompletionTool struct { +type ChatCompletionTool struct { Function CompletionFunctionDefinition `json:"function,omitempty"` } diff --git a/pkg/types/set.go b/pkg/types/set.go index 230e112b..65b73d22 100644 --- a/pkg/types/set.go +++ b/pkg/types/set.go @@ -19,6 +19,17 @@ func (t *toolRefSet) List() (result []ToolReference, err error) { return result, t.err } +func (t *toolRefSet) Contains(value ToolReference) bool { + key := toolRefKey{ + name: value.Named, + toolID: value.ToolID, + arg: value.Arg, + } + + _, ok := t.set[key] + return ok +} + func (t *toolRefSet) HasTool(toolID string) bool { for _, ref := range t.set { if ref.ToolID == toolID { diff --git a/pkg/types/tool.go b/pkg/types/tool.go index 0d2a5cc0..d9d59837 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -33,11 +33,15 @@ const ( ToolTypeAgent = ToolType("agent") ToolTypeOutput = ToolType("output") ToolTypeInput = ToolType("input") - ToolTypeAssistant = ToolType("assistant") ToolTypeTool = ToolType("tool") ToolTypeCredential = ToolType("credential") - ToolTypeProvider = ToolType("provider") ToolTypeDefault = ToolType("") + + // The following types logically exist but have no real code reference. These are kept + // here just so that we have a comprehensive list + + ToolTypeAssistant = ToolType("assistant") + ToolTypeProvider = ToolType("provider") ) type ErrToolNotFound struct { @@ -140,6 +144,28 @@ type Parameters struct { Type ToolType `json:"type,omitempty"` } +func (p Parameters) allExports() []string { + return slices.Concat( + p.ExportContext, + p.Export, + p.ExportCredentials, + p.ExportInputFilters, + p.ExportOutputFilters, + ) +} + +func (p Parameters) allReferences() []string { + return slices.Concat( + p.GlobalTools, + p.Tools, + p.Context, + p.Agents, + p.Credentials, + p.InputFilters, + p.OutputFilters, + ) +} + func (p Parameters) ToolRefNames() []string { return slices.Concat( p.Tools, @@ -335,39 +361,6 @@ func ParseCredentialArgs(toolName string, input string) (string, string, map[str return originalName, alias, args, nil } -func (t Tool) GetAgents(prg Program) (result []ToolReference, _ error) { - toolRefs, err := t.GetToolRefsFromNames(t.Agents) - if err != nil { - return nil, err - } - - genericToolRefs, err := t.getCompletionToolRefs(prg, nil, ToolTypeAgent) - if err != nil { - return nil, err - } - - toolRefs = append(toolRefs, genericToolRefs...) - - // Agent Tool refs must be named - for i, toolRef := range toolRefs { - if toolRef.Named != "" { - continue - } - tool := prg.ToolSet[toolRef.ToolID] - name := tool.Name - if name == "" { - name = toolRef.Reference - } - normed := ToolNormalizer(name) - if trimmed := strings.TrimSuffix(strings.TrimSuffix(normed, "Agent"), "Assistant"); trimmed != "" { - normed = trimmed - } - toolRefs[i].Named = normed - } - - return toolRefs, nil -} - func (t Tool) GetToolRefsFromNames(names []string) (result []ToolReference, _ error) { for _, toolName := range names { toolRefs, ok := t.ToolMapping[toolName] @@ -507,293 +500,185 @@ func (t ToolDef) String() string { return buf.String() } -func (t Tool) getExportedContext(prg Program) ([]ToolReference, error) { - result := &toolRefSet{} - - exportRefs, err := t.GetToolRefsFromNames(t.ExportContext) - if err != nil { - return nil, err - } - - for _, exportRef := range exportRefs { - result.Add(exportRef) - - tool := prg.ToolSet[exportRef.ToolID] - result.AddAll(tool.getExportedContext(prg)) - } - - return result.List() -} - -func (t Tool) getExportedTools(prg Program) ([]ToolReference, error) { - result := &toolRefSet{} - - exportRefs, err := t.GetToolRefsFromNames(t.Export) - if err != nil { - return nil, err - } - - for _, exportRef := range exportRefs { - result.Add(exportRef) - result.AddAll(prg.ToolSet[exportRef.ToolID].getExportedTools(prg)) - } - - return result.List() -} - -// GetContextTools returns all tools that are in the context of the tool including all the -// contexts that are exported by the context tools. This will recurse all exports. -func (t Tool) GetContextTools(prg Program) ([]ToolReference, error) { - result := &toolRefSet{} - result.AddAll(t.getDirectContextToolRefs(prg)) - - contextRefs, err := t.getCompletionToolRefs(prg, nil, ToolTypeContext) - if err != nil { - return nil, err - } - - for _, contextRef := range contextRefs { - result.AddAll(prg.ToolSet[contextRef.ToolID].getExportedContext(prg)) - result.Add(contextRef) - } - - exportOnlyTools, err := t.getCompletionToolRefs(prg, nil, ToolTypeDefault, ToolTypeContext) - if err != nil { - return nil, err - } +func (t Tool) GetNextAgentGroup(prg *Program, agentGroup []ToolReference, toolID string) (result []ToolReference, _ error) { + newAgentGroup := toolRefSet{} + newAgentGroup.AddAll(t.GetToolsByType(prg, ToolTypeAgent)) - for _, contextRef := range exportOnlyTools { - result.AddAll(prg.ToolSet[contextRef.ToolID].getExportedContext(prg)) + if newAgentGroup.HasTool(toolID) { + // Join new agent group + return newAgentGroup.List() } - return result.List() + return agentGroup, nil } -// GetContextTools returns all tools that are in the context of the tool including all the -// contexts that are exported by the context tools. This will recurse all exports. -func (t Tool) getDirectContextToolRefs(prg Program) ([]ToolReference, error) { - result := &toolRefSet{} - - contextRefs, err := t.GetToolRefsFromNames(t.Context) +func (t Tool) getAgents(prg *Program) (result []ToolReference, _ error) { + toolRefs, err := t.GetToolRefsFromNames(t.Agents) if err != nil { return nil, err } - for _, contextRef := range contextRefs { - result.AddAll(prg.ToolSet[contextRef.ToolID].getExportedContext(prg)) - result.Add(contextRef) + // Agent Tool refs must be named + for i, toolRef := range toolRefs { + if toolRef.Named != "" { + continue + } + tool := prg.ToolSet[toolRef.ToolID] + name := tool.Name + if name == "" { + name = toolRef.Reference + } + normed := ToolNormalizer(name) + if trimmed := strings.TrimSuffix(strings.TrimSuffix(normed, "Agent"), "Assistant"); trimmed != "" { + normed = trimmed + } + toolRefs[i].Named = normed } - return result.List() + return toolRefs, nil } -func (t Tool) GetOutputFilterTools(program Program) ([]ToolReference, error) { - result := &toolRefSet{} - - outputFilterRefs, err := t.GetToolRefsFromNames(t.OutputFilters) - if err != nil { - return nil, err - } - - for _, outputFilterRef := range outputFilterRefs { - result.Add(outputFilterRef) - } - - result.AddAll(t.getCompletionToolRefs(program, nil, ToolTypeOutput)) - - contextRefs, err := t.getDirectContextToolRefs(program) - if err != nil { - return nil, err +func (t Tool) GetToolsByType(prg *Program, toolType ToolType) ([]ToolReference, error) { + if toolType == ToolTypeAgent { + // Agents are special, they can only be sourced from direct references and not the generic 'tool:' or shared by references + return t.getAgents(prg) } - for _, contextRef := range contextRefs { - contextTool := program.ToolSet[contextRef.ToolID] - result.AddAll(contextTool.GetToolRefsFromNames(contextTool.ExportOutputFilters)) - } - - return result.List() -} - -func (t Tool) GetInputFilterTools(program Program) ([]ToolReference, error) { - result := &toolRefSet{} + toolSet := &toolRefSet{} - inputFilterRefs, err := t.GetToolRefsFromNames(t.InputFilters) - if err != nil { - return nil, err - } + var ( + directRefs []string + toolsListFilterType = []ToolType{toolType} + ) - for _, inputFilterRef := range inputFilterRefs { - result.Add(inputFilterRef) + switch toolType { + case ToolTypeContext: + directRefs = t.Context + case ToolTypeOutput: + directRefs = t.OutputFilters + case ToolTypeInput: + directRefs = t.InputFilters + case ToolTypeTool: + toolsListFilterType = append(toolsListFilterType, ToolTypeDefault, ToolTypeAgent) + case ToolTypeCredential: + directRefs = t.Credentials + default: + return nil, fmt.Errorf("unknown tool type %v", toolType) } - result.AddAll(t.getCompletionToolRefs(program, nil, ToolTypeInput)) + toolSet.AddAll(t.GetToolRefsFromNames(directRefs)) - contextRefs, err := t.getDirectContextToolRefs(program) + toolRefs, err := t.GetToolRefsFromNames(t.Tools) if err != nil { return nil, err } - for _, contextRef := range contextRefs { - contextTool := program.ToolSet[contextRef.ToolID] - result.AddAll(contextTool.GetToolRefsFromNames(contextTool.ExportInputFilters)) - } - - return result.List() -} - -func (t Tool) GetNextAgentGroup(prg Program, agentGroup []ToolReference, toolID string) (result []ToolReference, _ error) { - newAgentGroup := toolRefSet{} - if err := t.addAgents(prg, &newAgentGroup); err != nil { - return nil, err - } - - if newAgentGroup.HasTool(toolID) { - // Join new agent group - return newAgentGroup.List() - } - - return agentGroup, nil -} - -func filterRefs(prg Program, refs []ToolReference, types ...ToolType) (result []ToolReference) { - for _, ref := range refs { - if slices.Contains(types, prg.ToolSet[ref.ToolID].Type) { - result = append(result, ref) + for _, toolRef := range toolRefs { + tool, ok := prg.ToolSet[toolRef.ToolID] + if !ok { + continue + } + if slices.Contains(toolsListFilterType, tool.Type) { + toolSet.Add(toolRef) } - } - return -} - -func (t Tool) GetCompletionTools(prg Program, agentGroup ...ToolReference) (result []CompletionTool, err error) { - toolSet := &toolRefSet{} - toolSet.AddAll(t.getCompletionToolRefs(prg, agentGroup, ToolTypeDefault, ToolTypeTool)) - - if err := t.addAgents(prg, toolSet); err != nil { - return nil, err } - refs, err := toolSet.List() + exportSources, err := t.getExportSources(prg) if err != nil { return nil, err } - return toolRefsToCompletionTools(refs, prg), nil -} - -func (t Tool) addAgents(prg Program, result *toolRefSet) error { - subToolRefs, err := t.GetAgents(prg) - if err != nil { - return err - } - - for _, subToolRef := range subToolRefs { - // don't add yourself - if subToolRef.ToolID != t.ID { - // Add the tool itself and no exports - result.Add(subToolRef) + for _, exportSource := range exportSources { + var ( + tool = prg.ToolSet[exportSource.ToolID] + exportRefs []string + ) + + switch toolType { + case ToolTypeContext: + exportRefs = tool.ExportContext + case ToolTypeOutput: + exportRefs = tool.ExportOutputFilters + case ToolTypeInput: + exportRefs = tool.ExportInputFilters + case ToolTypeTool: + exportRefs = tool.Export + case ToolTypeCredential: + exportRefs = tool.ExportCredentials + default: + return nil, fmt.Errorf("unknown tool type %v", toolType) } + toolSet.AddAll(tool.GetToolRefsFromNames(exportRefs)) } - return nil + return toolSet.List() } -func (t Tool) addReferencedTools(prg Program, result *toolRefSet) error { - subToolRefs, err := t.GetToolRefsFromNames(t.Parameters.Tools) +func (t Tool) addExportsRecursively(prg *Program, toolSet *toolRefSet) error { + toolRefs, err := t.GetToolRefsFromNames(t.allExports()) if err != nil { return err } - for _, subToolRef := range subToolRefs { - // Add the tool - result.Add(subToolRef) + for _, toolRef := range toolRefs { + if toolSet.Contains(toolRef) { + continue + } - // Get all tools exports - result.AddAll(prg.ToolSet[subToolRef.ToolID].getExportedTools(prg)) + toolSet.Add(toolRef) + if err := prg.ToolSet[toolRef.ToolID].addExportsRecursively(prg, toolSet); err != nil { + return err + } } return nil } -func (t Tool) addContextExportedTools(prg Program, result *toolRefSet) error { - contextTools, err := t.getDirectContextToolRefs(prg) +func (t Tool) getExportSources(prg *Program) ([]ToolReference, error) { + // We start first with all references from this tool. This gives us the + // initial set of export sources. + // Then all tools in the export sources in the set we look for exports of those tools recursively. + // So a share of a share of a share should be added. + + toolSet := toolRefSet{} + toolRefs, err := t.GetToolRefsFromNames(t.allReferences()) if err != nil { - return err + return nil, err } - for _, contextTool := range contextTools { - result.AddAll(prg.ToolSet[contextTool.ToolID].getExportedTools(prg)) + for _, toolRef := range toolRefs { + if err := prg.ToolSet[toolRef.ToolID].addExportsRecursively(prg, &toolSet); err != nil { + return nil, err + } + toolSet.Add(toolRef) } - return nil + return toolSet.List() } -func (t Tool) getCompletionToolRefs(prg Program, agentGroup []ToolReference, types ...ToolType) ([]ToolReference, error) { - if len(types) == 0 { - types = []ToolType{ToolTypeDefault, ToolTypeTool} - } - - result := toolRefSet{} +func (t Tool) GetChatCompletionTools(prg Program, agentGroup ...ToolReference) (result []ChatCompletionTool, err error) { + toolSet := &toolRefSet{} + toolSet.AddAll(t.GetToolsByType(&prg, ToolTypeTool)) + toolSet.AddAll(t.GetToolsByType(&prg, ToolTypeAgent)) if t.Chat { for _, agent := range agentGroup { // don't add yourself if agent.ToolID != t.ID { - result.Add(agent) + toolSet.Add(agent) } } } - if err := t.addReferencedTools(prg, &result); err != nil { - return nil, err - } - - if err := t.addContextExportedTools(prg, &result); err != nil { - return nil, err - } - - refs, err := result.List() - return filterRefs(prg, refs, types...), err -} - -func (t Tool) GetCredentialTools(prg Program, agentGroup []ToolReference) ([]ToolReference, error) { - result := toolRefSet{} - - result.AddAll(t.GetToolRefsFromNames(t.Credentials)) - - result.AddAll(t.getCompletionToolRefs(prg, nil, ToolTypeCredential)) - - toolRefs, err := result.List() - if err != nil { - return nil, err - } - for _, toolRef := range toolRefs { - referencedTool := prg.ToolSet[toolRef.ToolID] - result.AddAll(referencedTool.GetToolRefsFromNames(referencedTool.ExportCredentials)) - } - - toolRefs, err = t.getCompletionToolRefs(prg, agentGroup) - if err != nil { - return nil, err - } - for _, toolRef := range toolRefs { - referencedTool := prg.ToolSet[toolRef.ToolID] - result.AddAll(referencedTool.GetToolRefsFromNames(referencedTool.ExportCredentials)) - } - - contextToolRefs, err := t.GetContextTools(prg) + refs, err := toolSet.List() if err != nil { return nil, err } - for _, contextToolRef := range contextToolRefs { - contextTool := prg.ToolSet[contextToolRef.ToolID] - result.AddAll(contextTool.GetToolRefsFromNames(contextTool.ExportCredentials)) - } - - return result.List() + return toolRefsToCompletionTools(refs, prg), nil } -func toolRefsToCompletionTools(completionTools []ToolReference, prg Program) (result []CompletionTool) { +func toolRefsToCompletionTools(completionTools []ToolReference, prg Program) (result []ChatCompletionTool) { toolNames := map[string]struct{}{} for _, subToolRef := range completionTools { @@ -814,7 +699,7 @@ func toolRefsToCompletionTools(completionTools []ToolReference, prg Program) (re if subTool.Instructions == "" { log.Debugf("Skipping zero instruction tool %s (%s)", subToolName, subTool.ID) } else { - result = append(result, CompletionTool{ + result = append(result, ChatCompletionTool{ Function: CompletionFunctionDefinition{ ToolID: subTool.ID, Name: PickToolName(subToolName, toolNames),