.
├── cmd
│ └── main.go
├── internal
│ ├── agents
│ │ ├── agent_manager.go
│ │ ├── agent_utils.go
│ │ ├── agent_query.go
│ │ └── agent_initializer.go
│ ├── config
│ │ └── config.go
│ ├── handlers
│ │ ├── agent_handler.go
│ │ └── memory_handler.go
│ ├── logger
│ │ └── logger.go
│ ├── middleware
│ │ └── middleware.go
│ └── router
│ └── router.go
├── go.mod
├── go.sum
└── .env
# Create the Go module for the project (run once, from the repository root).
go mod init github.com/blog/conversational-agent
# Fetch the libraries the code imports: Echo (HTTP), langchaingo (LLM, memory,
# chains, embeddings, prompts, Weaviate vector store), viper (config) and
# zerolog (presumably used by internal/logger, which is not shown).
go get github.com/labstack/echo/v4
go get github.com/tmc/langchaingo
go get github.com/tmc/langchaingo/llms/openai
go get github.com/tmc/langchaingo/vectorstores/weaviate
go get github.com/tmc/langchaingo/memory
go get github.com/tmc/langchaingo/chains
go get github.com/tmc/langchaingo/embeddings
go get github.com/tmc/langchaingo/prompts
go get github.com/spf13/viper
go get github.com/rs/zerolog
# NOTE: internal/agents/agent_utils.go also imports
# github.com/weaviate/weaviate-go-client/v4/weaviate/filters directly. Once the
# source files are in place, run `go mod tidy` so go.mod records that module
# as a direct dependency.
# OpenAI settings. Replace the placeholder key; OPENAI_MODEL selects the chat
# model passed to the LLM client.
OPENAI_API_KEY="open_ai_key"
OPENAI_MODEL="gpt-4o-mini"
# Weaviate host name only, without a scheme: the client always connects over https.
WEAVIATE_HOST="weaviate_hostname.gcp.weaviate.cloud"
WEAVIATE_API_KEY="weaviate_api_key"
# Index name under which documents and conversation memory are stored in Weaviate.
WEAVIATE_INDEX_NAME=AgentMemory
# Loaded into Config.Debug as a plain string; not read by the code shown here.
DEBUG=true
// internal/config/config.go
package config
import (
"log"
"github.com/spf13/viper"
)
type Config struct {
OpenAIAPIKey string `mapstructure:"OPENAI_API_KEY"`
OpenAIModel string `mapstructure:"OPENAI_MODEL"`
WeaviateHost string `mapstructure:"WEAVIATE_HOST"`
WeaviateAPIKey string `mapstructure:"WEAVIATE_API_KEY"`
WeaviateIndexName string `mapstructure:"WEAVIATE_INDEX_NAME"`
Debug string `mapstructure:"DEBUG"`
}
// LoadConfig loads environment variables into the Config struct
func LoadConfig() (*Config, error) {
viper.SetConfigName(".env")
viper.SetConfigType("env")
viper.AddConfigPath(".")
viper.AutomaticEnv()
// Load the config file
if err := viper.ReadInConfig(); err != nil {
log.Printf("Warning: No .env file found (%v), loading from environment variables only.", err)
}
var config Config
if err := viper.Unmarshal(&config); err != nil {
return nil, err
}
return &config, nil
}
// internal/config/config.go
package config
import (
"log"
"github.com/spf13/viper"
)
type Config struct {
OpenAIAPIKey string `mapstructure:"OPENAI_API_KEY"`
OpenAIModel string `mapstructure:"OPENAI_MODEL"`
WeaviateHost string `mapstructure:"WEAVIATE_HOST"`
WeaviateAPIKey string `mapstructure:"WEAVIATE_API_KEY"`
WeaviateIndexName string `mapstructure:"WEAVIATE_INDEX_NAME"`
Debug string `mapstructure:"DEBUG"`
}
// LoadConfig loads environment variables into the Config struct
func LoadConfig() (*Config, error) {
viper.SetConfigName(".env")
viper.SetConfigType("env")
viper.AddConfigPath(".")
viper.AutomaticEnv()
// Load the config file
if err := viper.ReadInConfig(); err != nil {
log.Printf("Warning: No .env file found (%v), loading from environment variables only.", err)
}
var config Config
if err := viper.Unmarshal(&config); err != nil {
return nil, err
}
return &config, nil
}
package middleware
import (
"github.com/Ingenimax/starops-infra-agent/internal/logger"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
)
// LoggingMiddleware logs HTTP requests and responses
func LoggingMiddleware() echo.MiddlewareFunc {
log := logger.GetLogger()
return middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
LogURI: true,
LogMethod: true,
LogStatus: true,
LogLatency: true,
LogError: true,
LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
log.Info().
Str("method", v.Method).
Str("uri", v.URI).
Int("status", v.Status).
Dur("latency", v.Latency).
Msg("HTTP request processed")
return nil
},
})
}
// /internal/agents/agent_initializer.go
package agents
import (
"context"
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/openai"
"github.com/tmc/langchaingo/memory"
"github.com/tmc/langchaingo/prompts"
"github.com/tmc/langchaingo/vectorstores/weaviate"
)
// InitializeLLM builds an OpenAI-backed LLM client that authenticates with
// openAIApiKey and sends requests to the openAIModel model. Any construction
// error from the OpenAI client is returned unchanged.
func InitializeLLM(openAIApiKey, openAIModel string) (*openai.LLM, error) {
	opts := []openai.Option{
		openai.WithToken(openAIApiKey),
		openai.WithModel(openAIModel),
	}
	return openai.New(opts...)
}
// InitializeMemory creates the conversation buffer that NewAgentManager
// attaches to the shared LLMChain. The buffer exposes its history under the
// memory key "history" as a single string (WithReturnMessages(false)), and
// labels human turns "You" and AI turns "StarOps".
//
// NOTE(review): the prompt built by InitializeChain only references
// {{.context}}, so this "history" variable is never rendered into it.
func InitializeMemory() *memory.ConversationBuffer {
	opts := []memory.ConversationBufferOption{
		memory.WithMemoryKey("history"),
		memory.WithReturnMessages(false), // history as one string, not a message list
		memory.WithHumanPrefix("You"),
		memory.WithAIPrefix("StarOps"),
	}
	return memory.NewConversationBuffer(opts...)
}
// InitializeChain builds an LLMChain around llm whose prompt renders a single
// "context" input followed by an "AI Response:" cue, and attaches memory as
// the chain's memory. The error result is currently always nil.
//
// NOTE(review): langchaingo's chains.Call appears to load and save the
// chain's memory on every call. If so, this one buffer is shared by every
// thread and grows with each query, while Query already tracks per-thread
// history itself. Confirm whether the chain needs memory at all.
func InitializeChain(llm llms.Model, memory *memory.ConversationBuffer) (*chains.LLMChain, error) {
	const template = "\n{{.context}}\nAI Response:"
	inputVariables := []string{"context"} // the only key callers must supply

	chain := chains.NewLLMChain(llm, prompts.NewPromptTemplate(template, inputVariables))
	chain.Memory = memory
	return chain, nil
}
// InitializeEmbedder returns an embedder whose vectors come from llm's
// CreateEmbedding method, so embeddings use the same OpenAI client and
// credentials as completions.
func InitializeEmbedder(llm *openai.LLM) (*embeddings.EmbedderImpl, error) {
	var embed embeddings.EmbedderClientFunc = func(ctx context.Context, texts []string) ([][]float32, error) {
		return llm.CreateEmbedding(ctx, texts)
	}
	return embeddings.NewEmbedder(embed)
}
// InitializeVectorStore connects to the Weaviate instance at weaviateHost over
// https, authenticating with weaviateApiKey. Documents are kept in the
// weaviateIndex index with their text under the "text" property, and
// embedder is the embedder handed to the store.
func InitializeVectorStore(
	weaviateHost, weaviateApiKey, weaviateIndex string,
	embedder *embeddings.EmbedderImpl,
) (weaviate.Store, error) {
	opts := []weaviate.Option{
		weaviate.WithHost(weaviateHost),
		weaviate.WithScheme("https"),
		weaviate.WithAPIKey(weaviateApiKey),
		weaviate.WithIndexName(weaviateIndex),
		weaviate.WithTextKey("text"),
		weaviate.WithEmbedder(embedder),
	}
	return weaviate.New(opts...)
}
// /internal/agents/agent_manager.go
package agents
import (
"fmt"
"sync"
"github.com/blog/conversational-agent/internal/logger"
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/llms/openai"
"github.com/tmc/langchaingo/memory"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/vectorstores/weaviate"
)
// AgentManager holds the components behind the agent API. These are the
// OpenAI LLM and the LLMChain used to answer queries, the Weaviate vector
// store used for retrieval and persistence, per-thread conversation buffers,
// and a buffer of conversation documents not yet written to Weaviate.
type AgentManager struct {
// LLM is the OpenAI client created by InitializeLLM.
LLM *openai.LLM
// VectorStore is searched by Query and receives buffered and imported documents.
VectorStore weaviate.Store
// AgentMemory maps a thread ID to that thread's conversation buffer; guarded by memoryMutex.
AgentMemory map[string]*memory.ConversationBuffer
// ConversationalChain is not set by NewAgentManager.
ConversationalChain *chains.ConversationalRetrievalQA
// WeaviateIndex is the index name the vector store was created with.
WeaviateIndex string
// LLMChain renders the prompt and calls the LLM for Query.
LLMChain *chains.LLMChain
// messageBuffer holds conversation chunks waiting to be written to Weaviate; guarded by bufferMutex.
messageBuffer []schema.Document
bufferMutex sync.Mutex
// maxBufferMessages is the buffered-document count at which addToBuffer starts a background flush.
maxBufferMessages int
// memoryMutex guards AgentMemory.
memoryMutex sync.Mutex
}
// NewAgentManager builds every component an AgentManager needs, in order.
// The components are the OpenAI LLM, an embedder backed by that LLM, the
// Weaviate vector store, a conversation buffer, and the LLMChain that uses
// that buffer. Each step is logged, and the first failure is returned
// wrapped with the name of the step.
//
// maxBufferMessages is the buffered-document count at which addToBuffer
// triggers a background flush to Weaviate. The manager starts with an empty
// per-thread memory map and an empty (non-nil) message buffer;
// ConversationalChain is left nil.
func NewAgentManager(
	openAIApiKey, openAIModel, weaviateHost, weaviateApiKey, weaviateIndex string,
	maxBufferMessages int,
) (*AgentManager, error) {
	log := logger.GetLogger()

	log.Info().Msgf("Initializing OpenAI LLM with Model: %s", openAIModel)
	model, err := InitializeLLM(openAIApiKey, openAIModel)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize OpenAI LLM: %w", err)
	}
	log.Info().Msg("OpenAI LLM initialized successfully")

	log.Info().Msg("Initializing OpenAI Embedder...")
	emb, err := InitializeEmbedder(model)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize OpenAI embedder: %w", err)
	}
	log.Info().Msg("OpenAI Embedder initialized successfully")

	log.Info().Msg("Initializing Weaviate vector store...")
	store, err := InitializeVectorStore(weaviateHost, weaviateApiKey, weaviateIndex, emb)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize Weaviate vector store: %w", err)
	}
	log.Info().Msg("Weaviate vector store initialized successfully")

	log.Info().Msg("Initializing ConversationBuffer memory...")
	convMemory := InitializeMemory()
	log.Info().Msg("ConversationBuffer memory initialized successfully")

	log.Info().Msg("Initializing Chain...")
	llmChain, err := InitializeChain(model, convMemory)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize chain: %w", err)
	}
	log.Info().Msg("Chain initialized successfully")

	mgr := new(AgentManager)
	mgr.LLM = model
	mgr.VectorStore = store
	mgr.AgentMemory = make(map[string]*memory.ConversationBuffer)
	mgr.LLMChain = llmChain
	mgr.WeaviateIndex = weaviateIndex
	mgr.messageBuffer = []schema.Document{}
	mgr.maxBufferMessages = maxBufferMessages
	return mgr, nil
}
// /internal/agents/agent_query.go
package agents
import (
"bytes"
"context"
"fmt"
"github.com/blog/conversational-agent/internal/logger"
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/vectorstores"
)
// Query answers input for threadID with retrieval-augmented generation and
// returns the model's complete reply. It works in four steps:
//
//  1. Similarity-search the vector store (top 5) in the orgID namespace, and
//     again (top 5) without a namespace option.
//  2. Render the thread's chat history, the retrieved documents and input
//     into the chain's single "context" variable.
//  3. Run am.LLMChain, passing each streamed chunk to chunkCallback when it
//     is non-nil.
//  4. Record the exchange in the thread's memory and queue it via
//     addToBuffer for persistence under orgID.
//
// Failure to load or save thread history is logged and tolerated. Search and
// LLM failures are returned as errors.
//
// NOTE(review): confirm how the Weaviate store scopes a search made without a
// namespace option. If it is not restricted to un-namespaced documents, the
// second search can surface other orgs' data in this prompt.
func (am *AgentManager) Query(
	ctx context.Context,
	userID, orgID, threadID, input string,
	chunkCallback func([]byte),
) (string, error) {
	log := logger.GetLogger()
	threadMemory := am.GetThreadMemory(threadID)

	// Org-specific knowledge is stored under the org_id namespace.
	namespace := orgID
	log.Debug().Msgf("Performing similarity search in namespace: %s", namespace)
	orgDocs, err := am.VectorStore.SimilaritySearch(ctx, input, 5, vectorstores.WithNameSpace(namespace))
	// An error whose text is "empty response" is treated as zero hits, not a failure.
	if err != nil && err.Error() != "empty response" {
		log.Error().Err(err).Msg("Failed to perform org-specific similarity search.")
		return "", fmt.Errorf("org similarity search failed: %w", err)
	}

	log.Debug().Msg("Performing similarity search in default namespace")
	defaultDocs, err := am.VectorStore.SimilaritySearch(ctx, input, 5)
	if err != nil && err.Error() != "empty response" {
		log.Error().Err(err).Msg("Failed to perform default similarity search.")
		return "", fmt.Errorf("default similarity search failed: %w", err)
	}

	similarDocs := append(orgDocs, defaultDocs...)
	log.Debug().Msgf("Retrieved documents for thread %s: %+v", threadID, similarDocs)

	var docContext bytes.Buffer
	if len(similarDocs) == 0 {
		log.Warn().Msg("No relevant documents found in either namespace. Proceeding with empty context.")
		docContext.WriteString("No relevant documents found.\n")
	}
	for _, doc := range similarDocs {
		fmt.Fprintf(&docContext, "Document: %s\nMetadata: %+v\n", doc.PageContent, doc.Metadata)
	}

	history, err := threadMemory.ChatHistory.Messages(ctx)
	if err != nil {
		// Best effort: answer without history rather than failing the request.
		log.Warn().Err(err).Msgf("Failed to load chat history for thread %s; continuing without it.", threadID)
		history = nil
	}
	// Render history as "role: content" lines; the model should not see Go map syntax.
	var historyText bytes.Buffer
	for _, msg := range formatMessages(history) {
		fmt.Fprintf(&historyText, "%s: %s\n", msg["role"], msg["content"])
	}
	if historyText.Len() == 0 {
		historyText.WriteString("No previous messages.\n")
	}

	chainInputs := map[string]any{
		"context": fmt.Sprintf(
			"History:\n%s\n\nRelevant Documents:\n%s\n\nUser Input:\n%s",
			historyText.String(), docContext.String(), input,
		),
	}
	log.Debug().Msgf("LLM context:\n%s", chainInputs["context"])

	chainOutputs, err := chains.Call(ctx, am.LLMChain, chainInputs, chains.WithStreamingFunc(
		func(ctx context.Context, chunk []byte) error {
			if chunkCallback != nil {
				chunkCallback(chunk)
			}
			return nil
		},
	))
	if err != nil {
		log.Error().Err(err).Msg("Failed to execute LLMChain.")
		return "", fmt.Errorf("failed to execute LLMChain: %w", err)
	}

	// Checked assertion: a missing or non-string "text" output must not panic the server.
	fullResponse, ok := chainOutputs["text"].(string)
	if !ok {
		log.Error().Msgf("LLMChain returned no string \"text\" output: %+v", chainOutputs)
		return "", fmt.Errorf("LLMChain output %q is missing or not a string", "text")
	}
	log.Debug().Msgf("Response: %s", fullResponse)

	if err := threadMemory.SaveContext(ctx,
		map[string]any{"input": input},
		map[string]any{"response": fullResponse},
	); err != nil {
		// The answer is still valid; only the thread history misses this turn.
		log.Warn().Err(err).Msgf("Failed to save exchange to memory for thread %s.", threadID)
	}

	am.addToBuffer(threadID, input, fullResponse, userID, orgID)
	return fullResponse, nil
}
// /internal/agents/agent_utils.go
package agents
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
"time"
"github.com/blog/conversational-agent/internal/logger"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/memory"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/vectorstores"
"github.com/weaviate/weaviate-go-client/v4/weaviate/filters"
)
// GetThreadMemory returns the ConversationBuffer cached for threadID. On
// first use it creates a new buffer (memory key "thread:<threadID>", history
// returned as a string), caches it, and returns it. The cache is accessed
// under memoryMutex.
//
// NOTE(review): entries are never evicted, so the cache grows with the number
// of distinct thread IDs.
func (am *AgentManager) GetThreadMemory(threadID string) *memory.ConversationBuffer {
	am.memoryMutex.Lock()
	defer am.memoryMutex.Unlock()

	buf, cached := am.AgentMemory[threadID]
	if !cached {
		buf = memory.NewConversationBuffer(
			memory.WithMemoryKey("thread:"+threadID),
			memory.WithReturnMessages(false),
		)
		am.AgentMemory[threadID] = buf
	}
	return buf
}
// RetrieveMemory returns the chat history of threadID as a list of
// {"role", "content"} maps (see formatMessages). The result is nil when the
// thread has no messages. Errors from reading the history are wrapped with
// the thread ID.
func (am *AgentManager) RetrieveMemory(ctx context.Context, threadID string) ([]map[string]string, error) {
	log := logger.GetLogger()
	log.Debug().Msgf("Starting RetrieveMemory function for thread ID: %s", threadID)

	mem := am.GetThreadMemory(threadID)
	if mem == nil {
		// Defensive: GetThreadMemory creates a buffer on first use.
		log.Debug().Msgf("AgentMemory for thread ID '%s' is not initialized.", threadID)
		return nil, fmt.Errorf("memory is not initialized for thread ID: %s", threadID)
	}

	history, err := mem.ChatHistory.Messages(ctx)
	if err != nil {
		log.Error().Err(err).Msg("Failed to retrieve messages from chat history.")
		return nil, fmt.Errorf("failed to retrieve messages for thread_id %s: %w", threadID, err)
	}

	log.Debug().Msgf("Retrieved messages for thread ID %s: %+v", threadID, history)
	return formatMessages(history), nil
}
// formatMessages converts chat messages into {"role": ..., "content": ...}
// maps, preserving order. AI messages get role "ai"; every other message type
// gets role "user". An empty input yields nil.
func formatMessages(messages []llms.ChatMessage) []map[string]string {
	var out []map[string]string
	for _, m := range messages {
		var role string
		switch m.GetType() {
		case llms.ChatMessageTypeAI:
			role = "ai"
		default:
			role = "user"
		}
		out = append(out, map[string]string{"role": role, "content": m.GetContent()})
	}
	return out
}
// addToBuffer queues one user/assistant exchange for persistence.
//
// The text "input\nresponse" is split into chunks of at most 300 words (see
// ChunkContent). Each chunk becomes a schema.Document tagged with thread_id,
// user_id, org_id, a timestamp shared by the whole exchange (RFC 3339), and
// source "conversation". Duplicates are not filtered here; that happens when
// the buffer is flushed. Once the buffer holds at least maxBufferMessages
// documents, a background flush is started.
func (am *AgentManager) addToBuffer(threadID, input, response, userID, orgID string) {
	// Upper bound for one background flush, so an unresponsive vector store
	// cannot pin the goroutine forever.
	const flushTimeout = 30 * time.Second

	chunks := ChunkContent(input+"\n"+response, 300)
	// One timestamp per exchange, so every chunk of it carries the same value.
	timestamp := time.Now().Format(time.RFC3339)

	am.bufferMutex.Lock()
	defer am.bufferMutex.Unlock()

	for _, chunk := range chunks {
		am.messageBuffer = append(am.messageBuffer, schema.Document{
			PageContent: chunk,
			Metadata: map[string]any{
				"thread_id": threadID,
				"user_id":   userID,
				"org_id":    orgID,
				"timestamp": timestamp,
				"source":    "conversation",
			},
		})
	}

	if len(am.messageBuffer) >= am.maxBufferMessages {
		// The flush must outlive the request that triggered it, so it gets its
		// own bounded context instead of the request's.
		go func() {
			ctx, cancel := context.WithTimeout(context.Background(), flushTimeout)
			defer cancel()
			am.flushBufferToWeaviate(ctx)
		}()
	}
}
// flushBufferToWeaviate drains the message buffer and writes its documents to
// the vector store. Each document goes under the namespace named by its
// "org_id" metadata (see storeByOrgNamespace), so one org's conversation
// never lands in another org's namespace. This is best-effort: documents
// whose write fails are logged and dropped, not re-queued.
func (am *AgentManager) flushBufferToWeaviate(ctx context.Context) {
	am.bufferMutex.Lock()
	buffer := am.messageBuffer
	am.messageBuffer = nil
	am.bufferMutex.Unlock()

	if len(buffer) == 0 {
		return
	}

	log := logger.GetLogger()
	if _, err := am.storeByOrgNamespace(ctx, buffer); err != nil {
		log.Error().Err(err).Msg("Failed to batch insert messages into Weaviate.")
		return
	}
	log.Debug().Msg("Batch insertion to Weaviate completed successfully.")
}

// deduplicateDocuments removes documents whose PageContent repeats an earlier
// one, keeping first occurrences in their original order. The result is
// never nil.
func deduplicateDocuments(docs []schema.Document) []schema.Document {
	seen := make(map[string]struct{}, len(docs))
	uniqueDocs := make([]schema.Document, 0, len(docs))
	for _, doc := range docs {
		if _, dup := seen[doc.PageContent]; dup {
			continue
		}
		seen[doc.PageContent] = struct{}{}
		uniqueDocs = append(uniqueDocs, doc)
	}
	return uniqueDocs
}

// SyncMemory writes every buffered document to the vector store under its
// org namespace (see storeByOrgNamespace). The buffer is taken out under the
// lock, and network I/O happens without holding it. Documents whose write
// failed are put back at the front of the buffer for a later retry, and the
// error is returned.
func (am *AgentManager) SyncMemory(ctx context.Context) error {
	log := logger.GetLogger()

	am.bufferMutex.Lock()
	pending := am.messageBuffer
	am.messageBuffer = nil
	am.bufferMutex.Unlock()

	if len(pending) == 0 {
		log.Debug().Msg("No messages to sync from buffer.")
		return nil
	}

	log.Debug().Msgf("Syncing %d messages from buffer to Weaviate...", len(pending))
	failed, err := am.storeByOrgNamespace(ctx, pending)
	if err != nil {
		if len(failed) > 0 {
			am.bufferMutex.Lock()
			am.messageBuffer = append(failed, am.messageBuffer...)
			am.bufferMutex.Unlock()
		}
		log.Error().Err(err).Msg("Failed to sync messages to Weaviate.")
		return err
	}

	log.Debug().Msg("Memory synced successfully to Weaviate.")
	return nil
}

// storeByOrgNamespace writes docs to the vector store grouped by their
// "org_id" metadata. Each group is deduplicated by PageContent and stored
// under the namespace equal to its org_id. Groups are written in order of
// first appearance.
//
// It returns the documents of every group whose write failed, so callers can
// retry them, together with an error describing the first problem. Documents
// without a string org_id are logged, reported in the error, and dropped
// rather than returned, because retrying them cannot succeed.
func (am *AgentManager) storeByOrgNamespace(ctx context.Context, docs []schema.Document) ([]schema.Document, error) {
	log := logger.GetLogger()

	groups := make(map[string][]schema.Document)
	var order []string
	orphans := 0
	for _, doc := range docs {
		orgID, ok := doc.Metadata["org_id"].(string)
		if !ok {
			orphans++
			continue
		}
		if _, seen := groups[orgID]; !seen {
			order = append(order, orgID)
		}
		groups[orgID] = append(groups[orgID], doc)
	}

	var firstErr error
	if orphans > 0 {
		log.Error().Msgf("Dropping %d buffered document(s) without a string org_id in metadata.", orphans)
		firstErr = fmt.Errorf("%d document(s) have no string org_id metadata", orphans)
	}

	var failed []schema.Document
	for _, namespace := range order {
		uniqueDocs := deduplicateDocuments(groups[namespace])
		log.Debug().Msgf("Batch inserting %d unique messages into Weaviate namespace %s...", len(uniqueDocs), namespace)
		if _, err := am.VectorStore.AddDocuments(ctx, uniqueDocs, vectorstores.WithNameSpace(namespace)); err != nil {
			log.Error().Err(err).Msgf("Failed to insert documents into Weaviate namespace %s.", namespace)
			failed = append(failed, groups[namespace]...)
			if firstErr == nil {
				firstErr = fmt.Errorf("adding documents to namespace %s: %w", namespace, err)
			}
		}
	}
	return failed, firstErr
}
// AddDocuments stores docs in the vector store under the orgID namespace.
//
// Before writing, it sets "user_id" and "org_id" in every document's metadata.
// This mutates the caller's slice elements, and nil metadata maps are
// allocated. The org_id stamp matches what conversation documents carry, so
// QueryMemoryByUserAndOrgID can find these documents too.
//
// The write is attempted up to three times, pausing 0.5s and then 1s between
// attempts. It stops early with an error if ctx is cancelled while waiting.
func (am *AgentManager) AddDocuments(ctx context.Context, docs []schema.Document, userID, orgID string) error {
	log := logger.GetLogger()

	for i := range docs {
		// Writing into a nil map would panic.
		if docs[i].Metadata == nil {
			docs[i].Metadata = make(map[string]any)
		}
		docs[i].Metadata["user_id"] = userID
		docs[i].Metadata["org_id"] = orgID
	}

	// Use org_id as namespace
	namespace := orgID
	log.Debug().Msgf("Using namespace: %s for adding documents", namespace)

	const maxRetries = 3
	var err error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		if _, err = am.VectorStore.AddDocuments(ctx, docs, vectorstores.WithNameSpace(namespace)); err == nil {
			log.Info().Msgf("Documents added to vector store successfully on attempt %d.", attempt)
			return nil
		}
		log.Warn().Err(err).Msgf("Attempt %d to add documents to vector store failed.", attempt)
		if attempt == maxRetries {
			break
		}
		// Linear backoff between attempts, abandoned if the caller gives up.
		select {
		case <-ctx.Done():
			log.Error().Err(err).Msg("Context cancelled while retrying document insertion.")
			return fmt.Errorf("failed to add documents to vector store: %w (last attempt: %v)", ctx.Err(), err)
		case <-time.After(time.Duration(attempt) * 500 * time.Millisecond):
		}
	}

	log.Error().Err(err).Msg("Failed to add documents to vector store after multiple attempts.")
	return fmt.Errorf("failed to add documents to vector store: %w", err)
}
// AddDataset reads a JSON-encoded []schema.Document from filePath and appends
// each document's PageContent to the chat history of threadID as an AI
// message. Despite the name, nothing is written to the vector store. It stops
// at the first failure, which may be opening the file, decoding it, or adding
// a message.
func (am *AgentManager) AddDataset(ctx context.Context, threadID, filePath string) error {
	f, openErr := os.Open(filePath)
	if openErr != nil {
		return fmt.Errorf("failed to open file: %w", openErr)
	}
	defer f.Close()

	var documents []schema.Document
	if decodeErr := json.NewDecoder(f).Decode(&documents); decodeErr != nil {
		return fmt.Errorf("failed to parse JSON dataset: %w", decodeErr)
	}

	history := am.GetThreadMemory(threadID)
	if history == nil {
		return fmt.Errorf("failed to retrieve memory for thread ID: %s", threadID)
	}

	for i := range documents {
		if addErr := history.ChatHistory.AddAIMessage(ctx, documents[i].PageContent); addErr != nil {
			return fmt.Errorf("failed to add document to memory for thread ID %s: %w", threadID, addErr)
		}
	}
	return nil
}
// ChunkContent splits content on whitespace (strings.Fields) and regroups the
// words into chunks of at most maxWords words, each joined with single spaces.
// Only the last chunk may be shorter. Content with no words yields nil. A
// maxWords of zero or less means "no limit", so all words form one chunk.
func ChunkContent(content string, maxWords int) []string {
	words := strings.Fields(content)
	if len(words) == 0 {
		return nil
	}
	// Clamp the step. Without this, a step of 0 would never advance the loop,
	// a negative step would produce negative slice bounds, and a huge step
	// could overflow start+maxWords.
	if maxWords <= 0 || maxWords > len(words) {
		maxWords = len(words)
	}

	chunks := make([]string, 0, (len(words)+maxWords-1)/maxWords)
	for start := 0; start < len(words); start += maxWords {
		end := start + maxWords
		if end > len(words) {
			end = len(words)
		}
		chunks = append(chunks, strings.Join(words[start:end], " "))
	}
	return chunks
}
// QueryMemoryByUserAndOrgID returns up to 10 documents whose "user_id" and
// "org_id" metadata both equal the given values, found with the store's
// MetadataSearch and a Weaviate AND filter. Search errors are logged and
// returned wrapped with both IDs.
//
// NOTE(review): no namespace option is passed; confirm how the Weaviate store
// scopes a metadata search made without one.
func (am *AgentManager) QueryMemoryByUserAndOrgID(ctx context.Context, userID, orgID string) ([]schema.Document, error) {
	const maxResults = 10

	log := logger.GetLogger()
	log.Debug().Msgf("Querying memory for user_id: %s and org_id: %s", userID, orgID)

	userMatch := filters.Where().WithPath([]string{"user_id"}).WithOperator(filters.Equal).WithValueString(userID)
	orgMatch := filters.Where().WithPath([]string{"org_id"}).WithOperator(filters.Equal).WithValueString(orgID)
	bothMatch := filters.Where().
		WithOperator(filters.And).
		WithOperands([]*filters.WhereBuilder{userMatch, orgMatch})

	docs, err := am.VectorStore.MetadataSearch(ctx, maxResults, vectorstores.WithFilters(bothMatch))
	if err != nil {
		log.Error().Err(err).Msg("Failed to query memory by user_id and org_id.")
		return nil, fmt.Errorf("failed to query memory for user_id %s and org_id %s: %w", userID, orgID, err)
	}

	log.Debug().Msgf("Retrieved %d documents for user_id: %s and org_id: %s", len(docs), userID, orgID)
	return docs, nil
}
// ConvertContentToString turns a decoded JSON "page_content" value into
// document text. Strings are returned unchanged. JSON objects
// (map[string]interface{}) and arrays ([]interface{}) are re-encoded with
// json.Marshal. Any other dynamic type, including nil, numbers and booleans,
// yields an "unsupported content type" error.
func ConvertContentToString(content interface{}) (string, error) {
	if text, isString := content.(string); isString {
		return text, nil
	}

	switch content.(type) {
	case map[string]interface{}, []interface{}:
		// Structured JSON: serialized below.
	default:
		return "", fmt.Errorf("unsupported content type: %T", content)
	}

	encoded, err := json.Marshal(content)
	if err != nil {
		return "", fmt.Errorf("failed to serialize JSON content: %w", err)
	}
	return string(encoded), nil
}
// ConvertMetadata copies string-valued metadata into the map[string]any shape
// used by schema.Document. It always returns a new, non-nil map, even for a
// nil input.
func ConvertMetadata(metadata map[string]string) map[string]any {
	out := make(map[string]any, len(metadata))
	for key, value := range metadata {
		out[key] = value
	}
	return out
}
// internal/handlers/query_handler.go
package handlers
import (
"fmt"
"net/http"
"github.com/blog/conversational-agent/internal/agents"
"github.com/labstack/echo/v4"
)
type AgentHandler struct {
AgentManager *agents.AgentManager
}
// QueryHandler handles the query request from the client.
func (h *AgentHandler) QueryHandler(c echo.Context) error {
type RequestBody struct {
Query string `json:"query"`
Stream bool `json:"stream"`
}
var req RequestBody
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request payload"})
}
// Retrieve query parameters
userID := c.Param("user_id")
orgID := c.Param("org_id")
threadID := c.Param("thread_id")
// Validate required parameters
if threadID == "" || userID == "" || orgID == "" {
return c.JSON(http.StatusBadRequest, map[string]string{
"error": "'thread_id', 'user_id', and 'org_id' are required query parameters",
})
}
if req.Stream {
// Streamed response
c.Response().Header().Set(echo.HeaderContentType, "text/event-stream")
c.Response().Header().Set("Cache-Control", "no-cache")
c.Response().Header().Set("Connection", "keep-alive")
c.Response().WriteHeader(http.StatusOK)
_, err := h.AgentManager.Query(
c.Request().Context(),
userID,
orgID,
threadID,
req.Query,
func(chunk []byte) {
c.Response().Write([]byte(fmt.Sprintf("data: %s\n\n", chunk)))
c.Response().Flush()
},
)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
c.Response().Write([]byte("data: [DONE]\n\n"))
c.Response().Flush()
return nil
}
// Non-streamed response
response, err := h.AgentManager.Query(c.Request().Context(), userID, orgID, threadID, req.Query, nil)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, map[string]string{"response": response})
}
// internal/handlers/memory_handler.go
package handlers
import (
"context"
"encoding/json"
"net/http"
"os"
"github.com/blog/conversational-agent/internal/agents"
"github.com/blog/conversational-agent/internal/logger"
"github.com/labstack/echo/v4"
"github.com/tmc/langchaingo/schema"
)
// ImportDatasetRequest describes a request body carrying only a "file_path".
//
// NOTE(review): nothing in this file uses it. ImportMemoryHandler declares its
// own ImportRequest (file_path, user_id, org_id). Confirm there are no other
// users before removing it.
type ImportDatasetRequest struct {
FilePath string `json:"file_path"`
}
// GetMemoryHandler responds with the chat history of the thread named by the
// :thread_id path parameter, as {"memory": [{"role", "content"}, ...]}. An
// empty history yields {"memory": "Memory is empty"} instead of an empty
// list. Retrieval errors produce a 500 with the error text.
func (h *AgentHandler) GetMemoryHandler(c echo.Context) error {
	reqCtx := c.Request().Context()

	threadID := c.Param("thread_id")
	if threadID == "" {
		return c.JSON(http.StatusBadRequest, map[string]string{"error": "Missing thread_id"})
	}

	log := logger.GetLogger()
	history, err := h.AgentManager.RetrieveMemory(reqCtx, threadID)
	switch {
	case err != nil:
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
	case len(history) == 0:
		log.Debug().Msgf("Handler: No memory found for thread ID: %s", threadID)
		return c.JSON(http.StatusOK, map[string]string{"memory": "Memory is empty"})
	default:
		return c.JSON(http.StatusOK, map[string]any{"memory": history})
	}
}
// AddDocumentHandler stores a single document in the vector store and
// responds with its ID as {"doc_id": "..."}.
//
// The request body is {"page_content": <string | object | array>,
// "metadata": {string: string}}. Objects and arrays are stored as their JSON
// encoding (see agents.ConvertContentToString). Empty content is rejected
// with 400.
//
// NOTE(review): the document is written without a namespace option, and
// Query also searches without one for every org. This content is therefore
// presumably visible to all orgs; confirm that is intended.
func (h *AgentHandler) AddDocumentHandler(c echo.Context) error {
	type AddDocumentRequest struct {
		Content  interface{}       `json:"page_content"` // Accepts string or JSON
		Metadata map[string]string `json:"metadata"`
	}

	var req AddDocumentRequest
	if err := c.Bind(&req); err != nil {
		return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request payload"})
	}

	content, err := agents.ConvertContentToString(req.Content)
	if err != nil {
		return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
	}
	if content == "" {
		return c.JSON(http.StatusBadRequest, map[string]string{"error": "'page_content' must not be empty"})
	}

	doc := schema.Document{
		PageContent: content,
		Metadata:    agents.ConvertMetadata(req.Metadata),
	}

	log := logger.GetLogger()
	docIDs, err := h.AgentManager.VectorStore.AddDocuments(c.Request().Context(), []schema.Document{doc})
	if err != nil {
		log.Error().Err(err).Msg("Failed to add document to vector store.")
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Failed to add document"})
	}
	// Guard the index below: a store that reports success without an ID must not crash the handler.
	if len(docIDs) == 0 {
		log.Error().Msg("Vector store returned no document ID for the added document.")
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Failed to add document"})
	}
	return c.JSON(http.StatusOK, map[string]string{"doc_id": docIDs[0]})
}
// ImportMemoryHandler imports a JSON dataset file from the server's
// filesystem into the vector store, under the org's namespace.
//
// The request body is {"file_path": ..., "user_id": ..., "org_id": ...}.
// When user_id or org_id is absent from the body, the :user_id / :org_id
// path parameters of the route are used; body values win when both are set.
// The file must contain a JSON array of objects. Every object with a string
// "summary" field becomes one document, whose PageContent is the summary and
// whose metadata is the whole object. A file with no such objects is
// rejected with 400.
//
// SECURITY NOTE(review): file_path comes from the client and is opened as-is.
// Any JSON file readable by the server process can be imported and later
// surfaced through queries. Restrict it to an allow-listed directory before
// exposing this endpoint.
func (h *AgentHandler) ImportMemoryHandler(c echo.Context) error {
	type ImportRequest struct {
		FilePath string `json:"file_path"`
		UserID   string `json:"user_id"`
		OrgID    string `json:"org_id"`
	}
	const invalidPayload = "Invalid request payload. 'file_path', 'user_id', and 'org_id' are required."

	log := logger.GetLogger()

	var req ImportRequest
	if err := c.Bind(&req); err != nil {
		return c.JSON(http.StatusBadRequest, map[string]string{"error": invalidPayload})
	}
	// The route is /v1/agent/memory/import/:org_id/:user_id, so honour the
	// path when the body omits the IDs.
	if req.UserID == "" {
		req.UserID = c.Param("user_id")
	}
	if req.OrgID == "" {
		req.OrgID = c.Param("org_id")
	}
	if req.FilePath == "" || req.UserID == "" || req.OrgID == "" {
		return c.JSON(http.StatusBadRequest, map[string]string{"error": invalidPayload})
	}

	file, err := os.Open(req.FilePath)
	if err != nil {
		log.Error().Err(err).Msgf("Failed to open dataset file %q.", req.FilePath)
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Failed to open the file."})
	}
	defer file.Close()

	var dataset []map[string]interface{}
	if err := json.NewDecoder(file).Decode(&dataset); err != nil {
		return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid JSON file format."})
	}

	docs := make([]schema.Document, 0, len(dataset))
	for _, item := range dataset {
		summary, ok := item["summary"].(string)
		if !ok {
			continue
		}
		docs = append(docs, schema.Document{PageContent: summary, Metadata: item})
	}
	if len(docs) == 0 {
		return c.JSON(http.StatusBadRequest, map[string]string{
			"error": "No entries with a string 'summary' field were found in the file.",
		})
	}

	// Not the request context: presumably so a client disconnect does not
	// abort an import midway. TODO confirm this is intended.
	ctx := context.Background()
	if err := h.AgentManager.AddDocuments(ctx, docs, req.UserID, req.OrgID); err != nil {
		// AddDocuments already logs the failure.
		return c.JSON(http.StatusInternalServerError, map[string]string{"error": "Failed to add documents to vector store."})
	}
	return c.JSON(http.StatusOK, map[string]string{"message": "Dataset imported successfully."})
}
// /internal/router/router.go
package router
import (
"github.com/blog/conversational-agent/internal/handlers"
"github.com/labstack/echo/v4"
)
// RegisterRoutes mounts the agent HTTP API on e. Every route lives under
// /v1/agent:
//
//	POST /v1/agent/query/:org_id/:user_id/:thread_id  -> QueryHandler
//	GET  /v1/agent/memory/thread/:thread_id           -> GetMemoryHandler
//	POST /v1/agent/memory/update                      -> AddDocumentHandler
//	POST /v1/agent/memory/import/:org_id/:user_id     -> ImportMemoryHandler
func RegisterRoutes(e *echo.Echo, agentHandler *handlers.AgentHandler) {
	api := e.Group("/v1/agent") // no group middleware: routes behave exactly as if added on e
	api.POST("/query/:org_id/:user_id/:thread_id", agentHandler.QueryHandler)
	api.GET("/memory/thread/:thread_id", agentHandler.GetMemoryHandler)
	api.POST("/memory/update", agentHandler.AddDocumentHandler)
	api.POST("/memory/import/:org_id/:user_id", agentHandler.ImportMemoryHandler)
}
// /cmd/main.go
package main
import (
"github.com/blog/conversational-agent/internal/agents"
"github.com/blog/conversational-agent/internal/config"
"github.com/blog/conversational-agent/internal/handlers"
"github.com/blog/conversational-agent/internal/logger"
"github.com/blog/conversational-agent/internal/middleware"
"github.com/blog/conversational-agent/internal/router"
"github.com/labstack/echo/v4"
)
func main() {
// Initialize Echo server
e := echo.New()
log := logger.GetLogger()
// Load configuration
cfg, err := config.LoadConfig()
if err != nil {
log.Fatal().Err(err).Msg("Failed to load configuration")
}
// Load middleware
e.Use(middleware.LoggingMiddleware())
// Create AgentManager
agentManager, err := agents.NewAgentManager(
cfg.OpenAIAPIKey,
cfg.OpenAIModel,
cfg.WeaviateHost,
cfg.WeaviateAPIKey,
cfg.WeaviateIndexName,
5, // maxBufferMessages
)
if err != nil {
log.Fatal().Err(err).Msg("Failed to initialize agent manager")
}
// Initialize handlers
agentHandler := &handlers.AgentHandler{
AgentManager: agentManager,
}
// Register routes
router.RegisterRoutes(e, agentHandler)
// Start the server
log.Info().Msg("Server is starting on port 8080")
if err := e.Start(":8080"); err != nil {
log.Fatal().Err(err).Msg("Server failed to start")
}
}
# build: compile the server binary from the repository root
go build -o agent cmd/main.go
# run: the server listens on :8080 and reads settings from .env and/or the environment
./agent
# endpoint /v1/agent/query/:org_id/:user_id/:thread_id (org, user and thread come from the path)
# non-streaming: the reply is returned as JSON {"response": "..."}
curl -X POST http://localhost:8080/v1/agent/query/myorg/myuser/thread1 \
-H "Content-Type: application/json" \
-d '{
"query": "What is the capital of California?",
"stream": false
}'
# ask a follow-up question in the same thread (thread1); the earlier exchange is sent to the model as history
curl -X POST http://localhost:8080/v1/agent/query/myorg/myuser/thread1 \
-H "Content-Type: application/json" \
-d '{
"query": "What is the population size?",
"stream": false
}'
# with streaming: the reply arrives as Server-Sent Events ("data: <chunk>" lines, ending with "data: [DONE]");
# add curl's -N (--no-buffer) flag to see chunks as they arrive
curl -X POST http://localhost:8080/v1/agent/query/myorg/myuser/thread2 \
-H "Content-Type: application/json" \
-d '{
"query": "Tell me a short story",
"stream": true
}'