Source code for cyber_query_ai.rag
"""RAG system for the CyberQueryAI application."""
from __future__ import annotations
import json
from pathlib import Path
from langchain_community.document_loaders import TextLoader
from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_ollama import OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pydantic import BaseModel
[docs]
class RAGSystem:
    """RAG (Retrieval-Augmented Generation) system for cybersecurity documentation.

    Loads the ``.txt`` files that sit next to a tools JSON file, splits them
    into chunks, indexes the chunks in an in-memory vector store, and formats
    the best matches for a query as prompt context.
    """

    def __init__(self, model: str, embedding_model: str, tools_json_filepath: Path) -> None:
        """Initialize the RAG system without building the vector store.

        Args:
            model: Model name; stored on the instance as ``self.model``.
            embedding_model: Model name handed to ``OllamaEmbeddings``.
            tools_json_filepath: Path to the tools JSON file. Its parent
                directory is where ``load_documents`` looks for ``.txt`` files.
        """
        self.model = model
        self.embedding_model = embedding_model
        self.tools_json_filepath = tools_json_filepath
        self.embeddings = OllamaEmbeddings(model=embedding_model)
        # Stays None until create_vector_store() has been called.
        self.vector_store: InMemoryVectorStore | None = None
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            add_start_index=True,
        )

    @classmethod
    def create(cls, model: str, embedding_model: str, tools_json_filepath: Path) -> RAGSystem:
        """Create a RAG system and build its vector store in one step.

        Args:
            model: Forwarded to ``__init__``.
            embedding_model: Forwarded to ``__init__``.
            tools_json_filepath: Forwarded to ``__init__``.

        Returns:
            An instance on which ``create_vector_store`` has already been called.
        """
        rag_system = cls(
            model=model,
            embedding_model=embedding_model,
            tools_json_filepath=tools_json_filepath,
        )
        rag_system.create_vector_store()
        return rag_system

    def load_documents(self) -> list[Document]:
        """Load all text documents from the rag_data directory with JSON metadata.

        Every ``*.txt`` file in the directory containing the tools JSON file is
        loaded as UTF-8. When a tool entry's ``file`` equals the file name, that
        tool's ``metadata_dict`` is merged into the metadata of each document
        loaded from the file.

        Returns:
            The loaded documents, or an empty list when the tools JSON file
            does not exist.
        """
        documents: list[Document] = []
        if not self.tools_json_filepath.exists():
            return documents

        tool_suite = ToolSuite.from_json(str(self.tools_json_filepath))

        # Load all .txt files from the directory that holds the tools JSON file.
        for txt_file in self.tools_json_filepath.parent.glob("*.txt"):
            loader = TextLoader(str(txt_file), encoding="utf-8")
            docs = loader.load()

            # Attach the metadata of every tool entry that points at this file.
            for tool in tool_suite.tools.values():
                if tool.file == txt_file.name:
                    for doc in docs:
                        doc.metadata.update(tool.metadata_dict)

            # NOTE(review): files are kept whether or not a tool entry matched
            # them; this nesting was reconstructed from a copy of the source
            # that had lost its indentation - confirm against the repository.
            documents.extend(docs)

        return documents

    def create_vector_store(self) -> None:
        """Build a fresh in-memory vector store from the loaded documents.

        Any previously assigned store is replaced. If no documents are loaded,
        or splitting yields no chunks, the new store is left empty.
        """
        self.vector_store = InMemoryVectorStore(self.embeddings)

        # Load and split documents; only index when there is something to add.
        if documents := self.load_documents():
            if splits := self.text_splitter.split_documents(documents):
                self.vector_store.add_documents(splits)

    def format_context(self, documents: list[Document]) -> str:
        """Format retrieved documents into a context string with rich metadata.

        Each document becomes a ``[Tool: ... | Category: ... | ...]`` header
        line, a ``Source:`` line and the stripped page content. Header fields
        whose metadata value is missing or empty are omitted; ``tool`` and
        ``source`` fall back to ``"unknown"``.

        Args:
            documents: Documents to render.

        Returns:
            The rendered documents separated by a line of 80 ``=`` characters
            surrounded by blank lines, or ``""`` when ``documents`` is empty.
        """
        if not documents:
            return ""

        context_parts = []
        for doc in documents:
            tool = doc.metadata.get("tool", "unknown")
            source = doc.metadata.get("source", "unknown")
            content = doc.page_content.strip()

            # Build header with metadata, skipping fields that are absent/empty.
            header_parts = [f"Tool: {tool}"]
            if category := doc.metadata.get("category", ""):
                header_parts.append(f"Category: {category}")
            if subcategory := doc.metadata.get("subcategory", ""):
                header_parts.append(f"Subcategory: {subcategory}")
            if description := doc.metadata.get("description", ""):
                header_parts.append(f"Description: {description}")
            if tags := doc.metadata.get("tags", []):
                header_parts.append(f"Tags: {', '.join(tags)}")
            if use_cases := doc.metadata.get("use_cases", []):
                header_parts.append(f"Use Cases: {', '.join(use_cases)}")

            header = " | ".join(header_parts)
            context_parts.append(f"[{header}]\nSource: {source}\n\n{content}")

        # Build the whole separator before joining. Writing
        # `"\n\n" + "=" * 80 + "\n\n".join(...)` would apply .join to the last
        # "\n\n" only, because a method call binds tighter than `+`.
        separator = "\n\n" + "=" * 80 + "\n\n"
        return separator.join(context_parts)

    def get_context_for_template(self, query: str) -> str:
        """Get RAG context for a specific query.

        Args:
            query: Text used for the similarity search.

        Returns:
            The formatted context for up to three matching documents, or ``""``
            when the vector store has not been created or nothing matches.
        """
        if self.vector_store is None:
            return ""

        if relevant_docs := self.vector_store.similarity_search(query, k=3):
            return self.format_context(relevant_docs)
        return ""

    def generate_rag_content(self, query: str) -> str:
        """Generate the documentation section to embed in a prompt.

        Args:
            query: Text used to look up relevant documentation.

        Returns:
            A "RELEVANT DOCUMENTATION" section containing the retrieved context
            followed by usage instructions, or ``""`` when no context is found.
        """
        rag_context = self.get_context_for_template(query)
        if not rag_context:
            return ""

        # Double every brace in the retrieved text - presumably so the result
        # survives a later str.format-style template pass; confirm with callers.
        escaped_context = rag_context.replace("{", "{{").replace("}", "}}")
        return (
            "\nRELEVANT DOCUMENTATION:\n"
            f"{escaped_context}\n\n"
            "Use the above documentation to provide more accurate and detailed responses. "
            "Reference specific tool options, syntax, and examples from the documentation when relevant.\n\n"
        )