diff --git a/.changeset/chatty-melons-mix.md b/.changeset/chatty-melons-mix.md new file mode 100644 index 0000000000..ef4dc4ff48 --- /dev/null +++ b/.changeset/chatty-melons-mix.md @@ -0,0 +1,6 @@ +--- +"llamaindex": minor +"docs": minor +--- + +Implement context-aware agent functionality diff --git a/apps/docs/docs/examples/context_aware_agent.md b/apps/docs/docs/examples/context_aware_agent.md new file mode 100644 index 0000000000..2a5fe4f687 --- /dev/null +++ b/apps/docs/docs/examples/context_aware_agent.md @@ -0,0 +1,63 @@ +--- +sidebar_position: 14 +--- + +# Context-Aware Agent + +The Context-Aware Agent enhances the capabilities of standard LLM agents by incorporating relevant context from a retriever for each query. This allows the agent to provide more informed and specific responses based on the available information. + +## Usage + +Here's a simple example of how to use the Context-Aware Agent: + +```typescript +import { + Document, + VectorStoreIndex, + OpenAIContextAwareAgent, + OpenAI, +} from "llamaindex"; + +async function createContextAwareAgent() { + // Create and index some documents + const documents = [ + new Document({ + text: "LlamaIndex is a data framework for LLM applications.", + id_: "doc1", + }), + new Document({ + text: "The Eiffel Tower is located in Paris, France.", + id_: "doc2", + }), + ]; + + const index = await VectorStoreIndex.fromDocuments(documents); + const retriever = index.asRetriever({ similarityTopK: 1 }); + + // Create the Context-Aware Agent + const agent = new OpenAIContextAwareAgent({ + llm: new OpenAI({ model: "gpt-3.5-turbo" }), + contextRetriever: retriever, + }); + + // Use the agent to answer queries + const response = await agent.chat({ + message: "What is LlamaIndex used for?", + }); + + console.log("Agent Response:", response.response); +} + +createContextAwareAgent().catch(console.error); +``` + +In this example, the Context-Aware Agent uses the retriever to fetch relevant context for each query, allowing it to provide more accurate and informed responses based on the indexed documents. + +## Key Components + +- `contextRetriever`: A retriever (e.g., from a VectorStoreIndex) that fetches relevant documents or passages for each query. + +## Available Context-Aware Agents + +- `OpenAIContextAwareAgent`: A context-aware agent using OpenAI's models. +- `AnthropicContextAwareAgent`: A context-aware agent using Anthropic's models. diff --git a/packages/llamaindex/src/agent/anthropic.ts b/packages/llamaindex/src/agent/anthropic.ts index 8f17b360d3..67ae76d690 100644 --- a/packages/llamaindex/src/agent/anthropic.ts +++ b/packages/llamaindex/src/agent/anthropic.ts @@ -5,6 +5,10 @@ import type { EngineResponse, } from "../index.edge.js"; import { Anthropic } from "../llm/anthropic.js"; +import { + withContextAwareness, + type ContextAwareConfig, +} from "./contextAwareMixin.js"; import { LLMAgent, LLMAgentWorker, type LLMAgentParams } from "./llm.js"; export type AnthropicAgentParams = LLMAgentParams; @@ -36,3 +40,13 @@ export class AnthropicAgent extends LLMAgent { return super.chat(params); } } + +export class AnthropicContextAwareAgent extends (withContextAwareness( + AnthropicAgent, +) as new ( + params: AnthropicAgentParams & ContextAwareConfig, +) => AnthropicAgent) { + constructor(params: AnthropicAgentParams & ContextAwareConfig) { + super(params); + } +} diff --git a/packages/llamaindex/src/agent/contextAwareMixin.ts b/packages/llamaindex/src/agent/contextAwareMixin.ts new file mode 100644 index 0000000000..7569d73f00 --- /dev/null +++ b/packages/llamaindex/src/agent/contextAwareMixin.ts @@ -0,0 +1,80 @@ +import type { ChatMessage, LLM, MessageContent } from "@llamaindex/core/llms"; +import type { NodeWithScore } from "@llamaindex/core/schema"; +import { EngineResponse, MetadataMode } from "@llamaindex/core/schema"; +import type { + ChatEngineParamsNonStreaming, + ChatEngineParamsStreaming, +} from "../engines/chat/index.js"; +import type { BaseRetriever } from "../Retriever.js"; +import type { AgentRunner } from "./base.js"; + +type Constructor = new (...args: any[]) => T; + +export interface ContextAwareConfig { + contextRetriever: BaseRetriever; +} + +export interface ContextAwareAgentRunner extends AgentRunner { + contextRetriever: BaseRetriever; + retrievedContext: string | null; + retrieveContext(query: MessageContent): Promise; + injectContext(context: string): Promise; +} + +/** + * ContextAwareAgentRunner enhances the base AgentRunner with the ability to retrieve and inject relevant context + * for each query. This allows the agent to access and utilize appropriate information from a given index or retriever, + * providing more informed and context-specific responses to user queries. + */ +export function withContextAwareness>>( + BaseClass: T, +) { + return class extends BaseClass implements ContextAwareAgentRunner { + contextRetriever: BaseRetriever; + retrievedContext: string | null = null; + + constructor(...args: any[]) { + super(...args); + const config = args[args.length - 1] as ContextAwareConfig; + this.contextRetriever = config.contextRetriever; + } + + createStore(): object { + return {}; + } + + async retrieveContext(query: MessageContent): Promise { + const nodes = await this.contextRetriever.retrieve({ query }); + return nodes + .map((node: NodeWithScore) => node.node.getContent(MetadataMode.NONE)) + .join("\n"); + } + + async injectContext(context: string): Promise { + const chatHistory = (this as any).chatHistory as ChatMessage[]; + const systemMessage = chatHistory.find((msg) => msg.role === "system"); + if (systemMessage) { + systemMessage.content = `${context}\n\n${systemMessage.content}`; + } else { + chatHistory.unshift({ role: "system", content: context }); + } + } + + async chat(params: ChatEngineParamsNonStreaming): Promise; + async chat( + params: ChatEngineParamsStreaming, + ): Promise>; + async chat( + params: ChatEngineParamsNonStreaming | ChatEngineParamsStreaming, + ): Promise> { + const context = await this.retrieveContext(params.message); + await this.injectContext(context); + + if ("stream" in params && params.stream === true) { + return super.chat(params); + } else { + return super.chat(params as ChatEngineParamsNonStreaming); + } + } + }; +} diff --git a/packages/llamaindex/src/agent/index.ts b/packages/llamaindex/src/agent/index.ts index feda11bd40..6507b72925 100644 --- a/packages/llamaindex/src/agent/index.ts +++ b/packages/llamaindex/src/agent/index.ts @@ -4,10 +4,16 @@ export { type AnthropicAgentParams, } from "./anthropic.js"; export { AgentRunner, AgentWorker, type AgentParamsBase } from "./base.js"; +export { + withContextAwareness, + type ContextAwareAgentRunner, + type ContextAwareConfig, +} from "./contextAwareMixin.js"; export { LLMAgent, LLMAgentWorker, type LLMAgentParams } from "./llm.js"; export { OpenAIAgent, OpenAIAgentWorker, + OpenAIContextAwareAgent, type OpenAIAgentParams, } from "./openai.js"; export { diff --git a/packages/llamaindex/src/agent/openai.ts b/packages/llamaindex/src/agent/openai.ts index 8cd84bc675..3659ccad67 100644 --- a/packages/llamaindex/src/agent/openai.ts +++ b/packages/llamaindex/src/agent/openai.ts @@ -1,5 +1,9 @@ import { OpenAI } from "@llamaindex/openai"; import { Settings } from "../Settings.js"; +import { + withContextAwareness, + type ContextAwareConfig, +} from "./contextAwareMixin.js"; import { LLMAgent, LLMAgentWorker, type LLMAgentParams } from "./llm.js"; // This is likely not necessary anymore but leaving it here just incase it's in use elsewhere @@ -21,3 +25,11 @@ export class OpenAIAgent extends LLMAgent { }); } } + +export class OpenAIContextAwareAgent extends (withContextAwareness( + OpenAIAgent, +) as new (params: OpenAIAgentParams & ContextAwareConfig) => OpenAIAgent) { + constructor(params: OpenAIAgentParams & ContextAwareConfig) { + super(params); + } +}