Source code for cyber_query_ai.server

"""Server for the CyberQueryAI application."""

import json
import logging

from fastapi import HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from python_template_server.models import BaseResponse, ResponseCode
from python_template_server.template_server import TemplateServer

from cyber_query_ai.chatbot import Chatbot
from cyber_query_ai.helpers import (
    clean_json_response,
    get_rag_tools_path,
    sanitize_text,
)
from cyber_query_ai.models import (
    CyberQueryAIConfig,
    GetApiConfigResponse,
    PostChatRequest,
    PostChatResponse,
    PostCodeExplanationResponse,
    PostCodeGenerationResponse,
    PostExploitSearchResponse,
    PostPromptRequest,
)

logger = logging.getLogger(__name__)

# Field names every response model inherits from BaseResponse. Subtracting them from a
# response model's fields leaves the endpoint-specific fields, which the handlers require
# to be present in the LLM's parsed JSON output (via validate_keys).
_BASE_RESPONSE_FIELDS = BaseResponse.model_fields.keys()

CHAT_FIELDS = PostChatResponse.model_fields.keys() - _BASE_RESPONSE_FIELDS
CODE_GENERATE_FIELDS = PostCodeGenerationResponse.model_fields.keys() - _BASE_RESPONSE_FIELDS
CODE_EXPLAIN_FIELDS = PostCodeExplanationResponse.model_fields.keys() - _BASE_RESPONSE_FIELDS
EXPLOIT_SEARCH_FIELDS = PostExploitSearchResponse.model_fields.keys() - _BASE_RESPONSE_FIELDS


[docs] class CyberQueryAIServer(TemplateServer): """AI chatbot server application inheriting from TemplateServer."""
def __init__(self, config: CyberQueryAIConfig | None = None) -> None:
    """Initialise the CyberQueryAIServer by delegating to the template server.

    Once the base ``TemplateServer`` initialiser has run, a ``Chatbot`` is built from the
    configured LLM name, the embedding model name and the RAG tools JSON path.

    :param CyberQueryAIConfig config: CyberQueryAI server configuration
    """
    # Declare the concrete config type for type checkers. The value is not assigned here;
    # presumably TemplateServer.__init__ populates it (it is read right after the super() call).
    self.config: CyberQueryAIConfig
    super().__init__(package_name="cyber-query-ai", config=config)

    model_settings = self.config.model
    self.chatbot = Chatbot(
        model=model_settings.model,
        embedding_model=model_settings.embedding_model,
        tools_json_filepath=get_rag_tools_path(),
    )
    logger.info("Initialized Chatbot with LLMs: %s & %s", model_settings.model, model_settings.embedding_model)
@staticmethod
def parse_response(response_str: str) -> dict:
    """Parse the LLM response string into a dictionary.

    The raw text is first passed through ``clean_json_response`` and then decoded with
    ``json.loads``. Only a top-level JSON object is accepted, because callers look up fields
    by key and call ``.keys()`` on the result.

    :param str response_str: LLM response string
    :return dict: Parsed response dictionary
    :raises json.JSONDecodeError: If the response cannot be parsed as JSON, or if it parses to
        a JSON value other than an object (e.g. a list, string or number)
    """
    cleaned_response = clean_json_response(response_str)
    parsed = json.loads(cleaned_response)
    # json.loads happily returns lists/strings/numbers. Report those as malformed output
    # (same exception type as a syntax error) instead of letting a later .keys() or ["field"]
    # access fail with an unrelated AttributeError/TypeError.
    if not isinstance(parsed, dict):
        msg = f"Expected a JSON object, got {type(parsed).__name__}"
        raise json.JSONDecodeError(msg, cleaned_response, 0)
    return parsed
[docs] @staticmethod def validate_keys(required_keys: set[str], response_dict: dict) -> None: """Validate that all required keys are present in the response dictionary. :param set[str] required_keys: Set of required keys :param dict response_dict: Response dictionary to validate :raises KeyError: If any required keys are missing """ if missing_keys := list(required_keys - response_dict.keys()): msg = f"Missing required keys in LLM response: {missing_keys}" raise KeyError(msg)
def validate_config(self, config_data: dict) -> CyberQueryAIConfig:
    """Build a CyberQueryAIConfig from raw configuration data.

    :param dict config_data: Raw configuration data
    :return CyberQueryAIConfig: Validated CyberQueryAI server configuration
    """
    validated: CyberQueryAIConfig = CyberQueryAIConfig.model_validate(config_data)
    return validated
def setup_routes(self) -> None:
    """Set up API routes.

    ``GET /config`` is registered with ``limited=False`` and no authentication. Every other
    route is a POST registered with ``limited=True`` and ``authentication_required=True``.
    """
    self.add_route(
        endpoint="/config",
        handler_function=self.get_api_config,
        response_model=GetApiConfigResponse,
        methods=["GET"],
        limited=False,
        authentication_required=False,
    )

    # (endpoint, handler, response model) for the protected POST routes, in registration order.
    protected_routes = (
        ("/model/chat", self.post_chat, PostChatResponse),
        ("/code/generate", self.post_generate_code, PostCodeGenerationResponse),
        ("/code/explain", self.post_explain_code, PostCodeExplanationResponse),
        ("/exploit/search", self.post_exploit_search, PostExploitSearchResponse),
    )
    for path, handler, response_model in protected_routes:
        self.add_route(
            endpoint=path,
            handler_function=handler,
            response_model=response_model,
            methods=["POST"],
            limited=True,
            authentication_required=True,
        )
async def get_api_config(self, request: Request) -> GetApiConfigResponse:
    """Get the API configuration including model configuration and version.

    :param Request request: Incoming request (unused)
    :return GetApiConfigResponse: The configured model settings and the package ``Version``
    """
    logger.info("Received request for API configuration.")
    model_settings = self.config.model
    package_version = self.package_metadata["Version"]
    return GetApiConfigResponse(
        message="Successfully retrieved chatbot configuration.",
        model=model_settings,
        version=package_version,
    )
async def post_chat(self, request: Request) -> PostChatResponse:
    """Chat with the AI assistant using conversation history.

    Previous turns are flattened into newline-terminated ``"role: content"`` lines and combined
    with the new message via ``Chatbot.prompt_chat``. The result goes through ``sanitize_text``
    and is sent to the LLM. The LLM reply must be JSON containing every field in
    ``CHAT_FIELDS``.

    :param Request request: Incoming request whose JSON body validates as ``PostChatRequest``
    :return PostChatResponse: The assistant's reply
    :raises HTTPException: ``ResponseCode.INTERNAL_SERVER_ERROR`` when the LLM reply is not
        valid JSON, lacks required keys, or any other error occurs during the LLM call/parsing.
        Errors while reading or validating the request body are not caught here.
    """
    payload = PostChatRequest.model_validate(await request.json())
    logger.info("Received chat request: %s", payload.message)

    history_text = "".join(f"{turn.role}: {turn.content}\n" for turn in payload.history)
    prompt = sanitize_text(self.chatbot.prompt_chat(payload.message, history_text))

    try:
        # run_in_threadpool runs llm.invoke in a worker thread so the event loop is not blocked.
        llm_reply = await run_in_threadpool(self.chatbot.llm.invoke, prompt)
        fields = self.parse_response(str(llm_reply.content))
        self.validate_keys(CHAT_FIELDS, fields)
        logger.info("Successfully generated chat response.")
        return PostChatResponse(
            message="Successfully generated chat response.",
            model_message=fields["model_message"],
        )
    except Exception as exc:
        # Pick the client-facing detail by failure category: JSON decode errors first,
        # then missing keys, then everything else.
        if isinstance(exc, json.JSONDecodeError):
            detail = f"Invalid JSON response from LLM: {llm_reply.content}"
        elif isinstance(exc, KeyError):
            detail = "LLM response missing required keys."
        else:
            detail = "An unexpected error occurred during chat."
        logger.exception(detail)
        raise HTTPException(
            status_code=ResponseCode.INTERNAL_SERVER_ERROR,
            detail=detail,
        ) from exc
async def post_generate_code(self, request: Request) -> PostCodeGenerationResponse:
    """Generate cybersecurity code based on user prompt.

    The prompt is wrapped by ``Chatbot.prompt_code_generation``, passed through
    ``sanitize_text`` and sent to the LLM. The LLM reply must be JSON containing every field in
    ``CODE_GENERATE_FIELDS``.

    :param Request request: Incoming request whose JSON body validates as ``PostPromptRequest``
    :return PostCodeGenerationResponse: Generated code, its explanation and its language
    :raises HTTPException: ``ResponseCode.INTERNAL_SERVER_ERROR`` when the LLM reply is not
        valid JSON, lacks required keys, or any other error occurs during the LLM call/parsing.
        Errors while reading or validating the request body are not caught here.
    """
    payload = PostPromptRequest.model_validate(await request.json())
    logger.info("Received code generation request: %s", payload.prompt)

    prompt = sanitize_text(self.chatbot.prompt_code_generation(payload.prompt))

    try:
        # run_in_threadpool runs llm.invoke in a worker thread so the event loop is not blocked.
        llm_reply = await run_in_threadpool(self.chatbot.llm.invoke, prompt)
        fields = self.parse_response(str(llm_reply.content))
        self.validate_keys(CODE_GENERATE_FIELDS, fields)
        logger.info("Successfully generated code.")
        return PostCodeGenerationResponse(
            message="Successfully generated code.",
            generated_code=fields["generated_code"],
            explanation=fields["explanation"],
            language=fields["language"],
        )
    except Exception as exc:
        # Pick the client-facing detail by failure category: JSON decode errors first,
        # then missing keys, then everything else.
        if isinstance(exc, json.JSONDecodeError):
            detail = f"Invalid JSON response from LLM: {llm_reply.content}"
        elif isinstance(exc, KeyError):
            detail = "LLM response missing required keys."
        else:
            detail = "An unexpected error occurred during code generation."
        logger.exception(detail)
        raise HTTPException(
            status_code=ResponseCode.INTERNAL_SERVER_ERROR,
            detail=detail,
        ) from exc
async def post_explain_code(self, request: Request) -> PostCodeExplanationResponse:
    """Explain code step-by-step.

    The submitted code is wrapped by ``Chatbot.prompt_code_explanation``, passed through
    ``sanitize_text`` and sent to the LLM. The LLM reply must be JSON containing every field in
    ``CODE_EXPLAIN_FIELDS``.

    :param Request request: Incoming request whose JSON body validates as ``PostPromptRequest``
    :return PostCodeExplanationResponse: The explanation, with ``code=ResponseCode.OK``
    :raises HTTPException: ``ResponseCode.INTERNAL_SERVER_ERROR`` when the LLM reply is not
        valid JSON, lacks required keys, or any other error occurs during the LLM call/parsing.
        Errors while reading or validating the request body are not caught here.
    """
    payload = PostPromptRequest.model_validate(await request.json())
    logger.info("Received code explanation request: %s", payload.prompt)

    prompt = sanitize_text(self.chatbot.prompt_code_explanation(payload.prompt))

    try:
        # run_in_threadpool runs llm.invoke in a worker thread so the event loop is not blocked.
        llm_reply = await run_in_threadpool(self.chatbot.llm.invoke, prompt)
        fields = self.parse_response(str(llm_reply.content))
        self.validate_keys(CODE_EXPLAIN_FIELDS, fields)
        logger.info("Successfully explained code.")
        return PostCodeExplanationResponse(
            code=ResponseCode.OK,
            message="Successfully explained code.",
            explanation=fields["explanation"],
        )
    except Exception as exc:
        # Pick the client-facing detail by failure category: JSON decode errors first,
        # then missing keys, then everything else.
        if isinstance(exc, json.JSONDecodeError):
            detail = f"Invalid JSON response from LLM: {llm_reply.content}"
        elif isinstance(exc, KeyError):
            detail = "LLM response missing required keys."
        else:
            detail = "An unexpected error occurred during code explanation."
        logger.exception(detail)
        raise HTTPException(
            status_code=ResponseCode.INTERNAL_SERVER_ERROR,
            detail=detail,
        ) from exc