Feature/context aware agent #1235

Open · wants to merge 4 commits into base: main

Changes from all commits
6 changes: 6 additions & 0 deletions .changeset/chatty-melons-mix.md
@@ -0,0 +1,6 @@
---
"llamaindex": minor
"docs": minor
---

Implement context-aware agent functionality
63 changes: 63 additions & 0 deletions apps/docs/docs/examples/context_aware_agent.md
@@ -0,0 +1,63 @@
---
sidebar_position: 14
---

# Context-Aware Agent

The Context-Aware Agent enhances the capabilities of standard LLM agents by incorporating relevant context from a retriever for each query. This allows the agent to provide more informed and specific responses based on the available information.

## Usage

Here's a simple example of how to use the Context-Aware Agent:

```typescript
import {
Document,
VectorStoreIndex,
OpenAIContextAwareAgent,
OpenAI,
} from "llamaindex";

async function createContextAwareAgent() {
// Create and index some documents
const documents = [
new Document({
text: "LlamaIndex is a data framework for LLM applications.",
id_: "doc1",
}),
new Document({
text: "The Eiffel Tower is located in Paris, France.",
id_: "doc2",
}),
];

const index = await VectorStoreIndex.fromDocuments(documents);
const retriever = index.asRetriever({ similarityTopK: 1 });

// Create the Context-Aware Agent
const agent = new OpenAIContextAwareAgent({
llm: new OpenAI({ model: "gpt-3.5-turbo" }),
contextRetriever: retriever,
});

// Use the agent to answer queries
const response = await agent.chat({
message: "What is LlamaIndex used for?",
});

console.log("Agent Response:", response.response);
}

createContextAwareAgent().catch(console.error);
```

In this example, the Context-Aware Agent uses the retriever to fetch relevant context for each query, allowing it to provide more accurate and informed responses based on the indexed documents.

## Key Components

- `contextRetriever`: A retriever (e.g., from a VectorStoreIndex) that fetches relevant documents or passages for each query.

## Available Context-Aware Agents

- `OpenAIContextAwareAgent`: A context-aware agent using OpenAI's models.
- `AnthropicContextAwareAgent`: A context-aware agent using Anthropic's models.
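
Under the hood, both agents prepend the retrieved context to the chat history's system message before each model call. The following standalone sketch illustrates that injection step; the `ChatMessage` shape and `injectContext` helper here are simplified stand-ins for illustration, not the library's API:

```typescript
// Simplified illustration of how retrieved context is injected into the
// chat history. This is self-contained demo code, not llamaindex itself.

type ChatMessage = { role: "system" | "user" | "assistant"; content: string };

// Prepend the retrieved context to the existing system message,
// or create a new system message at the front of the history.
function injectContext(chatHistory: ChatMessage[], context: string): void {
  const systemMessage = chatHistory.find((msg) => msg.role === "system");
  if (systemMessage) {
    systemMessage.content = `${context}\n\n${systemMessage.content}`;
  } else {
    chatHistory.unshift({ role: "system", content: context });
  }
}
```

Because the context lands in the system message, the model sees it before every user turn without the user prompt itself being modified.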
14 changes: 14 additions & 0 deletions packages/llamaindex/src/agent/anthropic.ts
@@ -5,6 +5,10 @@ import type {
EngineResponse,
} from "../index.edge.js";
import { Anthropic } from "../llm/anthropic.js";
import {
withContextAwareness,
type ContextAwareConfig,
} from "./contextAwareMixin.js";
import { LLMAgent, LLMAgentWorker, type LLMAgentParams } from "./llm.js";

export type AnthropicAgentParams = LLMAgentParams;
@@ -36,3 +40,13 @@ export class AnthropicAgent extends LLMAgent {
return super.chat(params);
}
}

export class AnthropicContextAwareAgent extends (withContextAwareness(
AnthropicAgent,
) as new (
params: AnthropicAgentParams & ContextAwareConfig,
) => AnthropicAgent) {
constructor(params: AnthropicAgentParams & ContextAwareConfig) {
super(params);
}
}
80 changes: 80 additions & 0 deletions packages/llamaindex/src/agent/contextAwareMixin.ts
@@ -0,0 +1,80 @@
import type { ChatMessage, LLM, MessageContent } from "@llamaindex/core/llms";
import type { NodeWithScore } from "@llamaindex/core/schema";
import { EngineResponse, MetadataMode } from "@llamaindex/core/schema";
import type {
ChatEngineParamsNonStreaming,
ChatEngineParamsStreaming,
} from "../engines/chat/index.js";
import type { BaseRetriever } from "../Retriever.js";
import type { AgentRunner } from "./base.js";

type Constructor<T = {}> = new (...args: any[]) => T;

export interface ContextAwareConfig {
contextRetriever: BaseRetriever;
}

export interface ContextAwareAgentRunner extends AgentRunner<LLM> {
contextRetriever: BaseRetriever;
retrievedContext: string | null;
retrieveContext(query: MessageContent): Promise<string>;
injectContext(context: string): Promise<void>;
}

/**
* ContextAwareAgentRunner enhances the base AgentRunner with the ability to retrieve and inject relevant context
* for each query. This allows the agent to access and utilize appropriate information from a given index or retriever,
* providing more informed and context-specific responses to user queries.
*/
export function withContextAwareness<T extends Constructor<AgentRunner<LLM>>>(
BaseClass: T,
) {
return class extends BaseClass implements ContextAwareAgentRunner {
contextRetriever: BaseRetriever;
retrievedContext: string | null = null;

constructor(...args: any[]) {
super(...args);
const config = args[args.length - 1] as ContextAwareConfig;
this.contextRetriever = config.contextRetriever;
}

createStore(): object {
return {};
}

async retrieveContext(query: MessageContent): Promise<string> {
const nodes = await this.contextRetriever.retrieve({ query });
return nodes
.map((node: NodeWithScore) => node.node.getContent(MetadataMode.NONE))
.join("\n");
}

async injectContext(context: string): Promise<void> {
const chatHistory = (this as any).chatHistory as ChatMessage[];
const systemMessage = chatHistory.find((msg) => msg.role === "system");
if (systemMessage) {
systemMessage.content = `${context}\n\n${systemMessage.content}`;
} else {
chatHistory.unshift({ role: "system", content: context });
}
}

async chat(params: ChatEngineParamsNonStreaming): Promise<EngineResponse>;
async chat(
params: ChatEngineParamsStreaming,
): Promise<ReadableStream<EngineResponse>>;
async chat(
params: ChatEngineParamsNonStreaming | ChatEngineParamsStreaming,
): Promise<EngineResponse | ReadableStream<EngineResponse>> {
const context = await this.retrieveContext(params.message);
await this.injectContext(context);

if ("stream" in params && params.stream === true) {
return super.chat(params);
} else {
return super.chat(params as ChatEngineParamsNonStreaming);
}
}
};
}
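
The `withContextAwareness` factory above follows TypeScript's standard class-mixin pattern: a function that accepts a base class and returns an extended subclass. A minimal standalone sketch of the same pattern (all names here are illustrative, not the PR's code):

```typescript
// Self-contained sketch of the mixin pattern used by withContextAwareness.

type Constructor<T = {}> = new (...args: any[]) => T;

class BaseAgent {
  async chat(message: string): Promise<string> {
    return `echo: ${message}`;
  }
}

// Returns a subclass whose chat() prepends stored context to every message,
// mirroring how the real mixin injects retrieved context before delegating
// to the base class via super.chat().
function withContext<T extends Constructor<BaseAgent>>(Base: T) {
  return class extends Base {
    context = "retrieved-context";
    async chat(message: string): Promise<string> {
      return super.chat(`${this.context}\n${message}`);
    }
  };
}

const ContextAgent = withContext(BaseAgent);
```

The `as new (...) => OpenAIAgent` casts in the concrete classes exist because TypeScript mixins erase the constructor parameter types; the cast restores a typed constructor signature for consumers.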
6 changes: 6 additions & 0 deletions packages/llamaindex/src/agent/index.ts
@@ -4,10 +4,16 @@ export {
type AnthropicAgentParams,
} from "./anthropic.js";
export { AgentRunner, AgentWorker, type AgentParamsBase } from "./base.js";
export {
withContextAwareness,
type ContextAwareAgentRunner,
type ContextAwareConfig,
} from "./contextAwareMixin.js";
export { LLMAgent, LLMAgentWorker, type LLMAgentParams } from "./llm.js";
export {
OpenAIAgent,
OpenAIAgentWorker,
OpenAIContextAwareAgent,
type OpenAIAgentParams,
} from "./openai.js";
export {
12 changes: 12 additions & 0 deletions packages/llamaindex/src/agent/openai.ts
@@ -1,5 +1,9 @@
import { OpenAI } from "@llamaindex/openai";
import { Settings } from "../Settings.js";
import {
withContextAwareness,
type ContextAwareConfig,
} from "./contextAwareMixin.js";
import { LLMAgent, LLMAgentWorker, type LLMAgentParams } from "./llm.js";

// This is likely not necessary anymore but leaving it here just in case it's in use elsewhere
@@ -21,3 +25,11 @@ export class OpenAIAgent extends LLMAgent {
});
}
}

export class OpenAIContextAwareAgent extends (withContextAwareness(
OpenAIAgent,
) as new (params: OpenAIAgentParams & ContextAwareConfig) => OpenAIAgent) {
constructor(params: OpenAIAgentParams & ContextAwareConfig) {
super(params);
}
}